// sc_neurocore_engine/simd/avx2.rs
// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — AVX2

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Count set bits in 64-bit words using AVX2.
///
/// Runs the classic SWAR (SIMD-within-a-register) population-count sequence
/// on four u64 lanes at a time, then reduces each lane's per-byte counts with
/// a multiply-and-shift. Trailing words (fewer than 4) use the portable path.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    let mut total = 0_u64;
    let mut chunks = data.chunks_exact(4);

    // SWAR masks: alternating 1-bit, 2-bit, and 4-bit groups.
    let m1 = _mm256_set1_epi64x(0x5555_5555_5555_5555_u64 as i64);
    let m2 = _mm256_set1_epi64x(0x3333_3333_3333_3333_u64 as i64);
    let m4 = _mm256_set1_epi64x(0x0f0f_0f0f_0f0f_0f0f_u64 as i64);

    for chunk in &mut chunks {
        let mut x = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
        // Step 1: x - ((x >> 1) & m1) leaves a 2-bit count in each 2-bit group.
        x = _mm256_sub_epi64(x, _mm256_and_si256(_mm256_srli_epi64::<1>(x), m1));
        // Step 2: fold adjacent 2-bit counts into 4-bit groups.
        x = _mm256_add_epi64(
            _mm256_and_si256(x, m2),
            _mm256_and_si256(_mm256_srli_epi64::<2>(x), m2),
        );
        // Step 3: fold into per-byte counts; each byte now holds 0..=8.
        x = _mm256_and_si256(_mm256_add_epi64(x, _mm256_srli_epi64::<4>(x)), m4);

        // Horizontal byte reduction: multiplying by 0x0101...01 accumulates
        // all eight byte counts into the top byte; >> 56 extracts the sum.
        let mut lanes = [0_u64; 4];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, x);
        total += lanes
            .iter()
            .copied()
            .map(|lane| lane.wrapping_mul(0x0101_0101_0101_0101) >> 56)
            .sum::<u64>();
    }

    // Scalar fallback for the trailing 0..=3 words.
    total + crate::bitstream::popcount_words_portable(chunks.remainder())
}
45
46#[cfg(target_arch = "x86_64")]
47#[target_feature(enable = "avx2")]
48/// Pack u8 bits into u64 words using AVX2 movemask.
49///
50/// Processes 64 bytes into one u64 word by building two 32-bit masks.
51///
52/// # Safety
53/// Caller must ensure the current CPU supports `avx2`.
54pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
55    let length = bits.len();
56    let words = length.div_ceil(64);
57    let mut data = vec![0_u64; words];
58    let full_words = length / 64;
59    let zero = _mm256_setzero_si256();
60
61    for (word_idx, word) in data.iter_mut().take(full_words).enumerate() {
62        let base = word_idx * 64;
63        let lo = _mm256_loadu_si256(bits.as_ptr().add(base) as *const __m256i);
64        let hi = _mm256_loadu_si256(bits.as_ptr().add(base + 32) as *const __m256i);
65
66        let lo_eq_zero = _mm256_cmpeq_epi8(lo, zero);
67        let hi_eq_zero = _mm256_cmpeq_epi8(hi, zero);
68        let lo_mask = !(_mm256_movemask_epi8(lo_eq_zero) as u32);
69        let hi_mask = !(_mm256_movemask_epi8(hi_eq_zero) as u32);
70
71        *word = ((hi_mask as u64) << 32) | (lo_mask as u64);
72    }
73
74    if full_words < words {
75        let tail_start = full_words * 64;
76        let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
77        data[full_words] = tail.data.first().copied().unwrap_or(0);
78    }
79
80    data
81}
82
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Fused AND+popcount over packed words using AVX2 for the AND stage.
///
/// Only the overlapping prefix of the two slices is combined.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let common = a.len().min(b.len());
    let mut count = 0_u64;
    let mut iter_a = a[..common].chunks_exact(4);
    let mut iter_b = b[..common].chunks_exact(4);

    // Vector stage: AND four words at a time, then popcount each lane.
    for (wa, wb) in iter_a.by_ref().zip(iter_b.by_ref()) {
        let va = _mm256_loadu_si256(wa.as_ptr() as *const __m256i);
        let vb = _mm256_loadu_si256(wb.as_ptr() as *const __m256i);
        let masked = _mm256_and_si256(va, vb);

        let mut out = [0_u64; 4];
        _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, masked);
        for lane in out {
            count += u64::from(lane.count_ones());
        }
    }

    // Scalar tail for the final 0..=3 words.
    for (&wa, &wb) in iter_a.remainder().iter().zip(iter_b.remainder()) {
        count += u64::from((wa & wb).count_ones());
    }
    count
}
113
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Fused XOR+popcount over packed words using AVX2 for the XOR stage.
///
/// Only the overlapping prefix of the two slices is combined.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let common = a.len().min(b.len());
    let mut count = 0_u64;
    let mut iter_a = a[..common].chunks_exact(4);
    let mut iter_b = b[..common].chunks_exact(4);

    // Vector stage: XOR four words at a time, then popcount each lane.
    for (wa, wb) in iter_a.by_ref().zip(iter_b.by_ref()) {
        let va = _mm256_loadu_si256(wa.as_ptr() as *const __m256i);
        let vb = _mm256_loadu_si256(wb.as_ptr() as *const __m256i);
        let mixed = _mm256_xor_si256(va, vb);

        let mut out = [0_u64; 4];
        _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, mixed);
        for lane in out {
            count += u64::from(lane.count_ones());
        }
    }

    // Scalar tail for the final 0..=3 words.
    for (&wa, &wb) in iter_a.remainder().iter().zip(iter_b.remainder()) {
        count += u64::from((wa ^ wb).count_ones());
    }
    count
}
144
#[cfg(not(target_arch = "x86_64"))]
/// Fallback fused XOR+popcount when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let mut total = 0_u64;
    // Plain scalar loop; zip truncates to the shorter slice.
    for (&wa, &wb) in a.iter().zip(b) {
        total += u64::from((wa ^ wb).count_ones());
    }
    total
}
156
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Compare 32 random bytes against an unsigned threshold and return bit mask.
///
/// Bit `i` in the returned mask is 1 iff `buf[i] < threshold`.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
/// `buf` must have at least 32 elements.
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    assert!(buf.len() >= 32, "buffer must contain at least 32 bytes");

    // AVX2 only offers signed byte compares; XOR-ing both sides with 0x80
    // maps the unsigned ordering onto the signed one, so a signed `cmpgt`
    // implements unsigned `<`.
    let bytes = _mm256_loadu_si256(buf.as_ptr() as *const __m256i);
    let flip = _mm256_set1_epi8(i8::MIN);
    let biased = _mm256_xor_si256(bytes, flip);
    let limit = _mm256_set1_epi8((threshold ^ 0x80) as i8);
    let below = _mm256_cmpgt_epi8(limit, biased);
    _mm256_movemask_epi8(below) as u32
}
176
#[cfg(not(target_arch = "x86_64"))]
/// Fallback popcount when AVX2 is unavailable on this architecture.
///
/// Delegates directly to the portable scalar implementation.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    crate::bitstream::popcount_words_portable(data)
}
185
#[cfg(not(target_arch = "x86_64"))]
/// Fallback pack when AVX2 is unavailable on this architecture.
///
/// Delegates to the portable packer and returns its raw word vector.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
    crate::bitstream::pack_fast(bits).data
}
194
#[cfg(not(target_arch = "x86_64"))]
/// Fallback fused AND+popcount when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let mut total = 0_u64;
    // Plain scalar loop; zip truncates to the shorter slice.
    for (&wa, &wb) in a.iter().zip(b) {
        total += u64::from((wa & wb).count_ones());
    }
    total
}
206
#[cfg(not(target_arch = "x86_64"))]
/// Fallback Bernoulli compare when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    // Fold the first (at most) 32 comparisons into a bit mask.
    buf.iter()
        .take(32)
        .enumerate()
        .fold(0_u32, |mask, (bit, &value)| {
            mask | (u32::from(value < threshold) << bit)
        })
}
221
222// --- f64 SIMD operations (AVX2: 4-wide f64) ---
223
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
/// Dot product of two f64 slices using AVX2 FMA.
///
/// Only the overlapping prefix of the two slices is multiplied.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2` and `fma`.
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    let mut vacc = _mm256_setzero_pd();
    let mut blocks_a = a[..n].chunks_exact(4);
    let mut blocks_b = b[..n].chunks_exact(4);

    // Fused multiply-add accumulates four products per iteration.
    for (pa, pb) in blocks_a.by_ref().zip(blocks_b.by_ref()) {
        vacc = _mm256_fmadd_pd(
            _mm256_loadu_pd(pa.as_ptr()),
            _mm256_loadu_pd(pb.as_ptr()),
            vacc,
        );
    }

    // Horizontal sum of the four accumulator lanes.
    let mut parts = [0.0_f64; 4];
    _mm256_storeu_pd(parts.as_mut_ptr(), vacc);
    let mut result = parts[0] + parts[1] + parts[2] + parts[3];

    // Scalar tail for the final 0..=3 elements.
    for (&x, &y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
        result += x * y;
    }
    result
}
251
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Maximum of f64 slice using AVX2.
///
/// Returns `f64::NEG_INFINITY` for an empty slice.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }
    // Seed all four lanes with -inf so any real value wins.
    let mut running = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut blocks = a.chunks_exact(4);

    for block in blocks.by_ref() {
        running = _mm256_max_pd(running, _mm256_loadu_pd(block.as_ptr()));
    }

    // Pairwise lane reduction, then the scalar tail.
    let mut parts = [0.0_f64; 4];
    _mm256_storeu_pd(parts.as_mut_ptr(), running);
    let mut best = parts[0].max(parts[1]).max(parts[2].max(parts[3]));
    for &value in blocks.remainder() {
        best = best.max(value);
    }
    best
}
278
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Sum of f64 slice using AVX2.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    let mut vacc = _mm256_setzero_pd();
    let mut blocks = a.chunks_exact(4);

    // Accumulate four partial sums, one per lane.
    for block in blocks.by_ref() {
        vacc = _mm256_add_pd(vacc, _mm256_loadu_pd(block.as_ptr()));
    }

    // Left-to-right lane reduction followed by the scalar tail; the
    // accumulation order is kept fixed for reproducible rounding.
    let mut parts = [0.0_f64; 4];
    _mm256_storeu_pd(parts.as_mut_ptr(), vacc);
    let mut total = parts[0] + parts[1] + parts[2] + parts[3];
    for &value in blocks.remainder() {
        total += value;
    }
    total
}
302
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Scale f64 slice in-place: y[i] *= alpha, using AVX2.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    let factor = _mm256_set1_pd(alpha);
    let mut blocks = y.chunks_exact_mut(4);

    // Multiply four elements per iteration and write them back in place.
    for block in blocks.by_ref() {
        let scaled = _mm256_mul_pd(_mm256_loadu_pd(block.as_ptr()), factor);
        _mm256_storeu_pd(block.as_mut_ptr(), scaled);
    }

    // Scalar tail for the final 0..=3 elements.
    for slot in blocks.into_remainder() {
        *slot *= alpha;
    }
}
323
#[cfg(not(target_arch = "x86_64"))]
/// Fallback dot product when AVX2 is unavailable on this architecture.
///
/// Only the overlapping prefix of the two slices is multiplied.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
}
329
#[cfg(not(target_arch = "x86_64"))]
/// Fallback maximum when AVX2 is unavailable on this architecture.
///
/// Returns `f64::NEG_INFINITY` for an empty slice.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}
334
#[cfg(not(target_arch = "x86_64"))]
/// Fallback sum when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    a.iter().sum()
}
339
#[cfg(not(target_arch = "x86_64"))]
/// Fallback in-place scaling when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    for v in y.iter_mut() {
        *v *= alpha;
    }
}
346
/// Hamming distance between two packed bitstream slices using AVX2.
///
/// The distance equals `popcount(a XOR b)`, so this is a thin wrapper over
/// [`fused_xor_popcount_avx2`]. Note this function has no `cfg` guard: on
/// non-x86_64 targets it resolves to the scalar fallback of the same name.
/// Only the overlapping prefix of the two slices is compared.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn hamming_distance_avx2(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_avx2(a, b)
}
354
355#[cfg(target_arch = "x86_64")]
356#[target_feature(enable = "avx2")]
357/// In-place softmax using AVX2 for max, sum, and scale steps.
358///
359/// # Safety
360/// Caller must ensure the current CPU supports `avx2`.
361pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
362    if scores.is_empty() {
363        return;
364    }
365    let max_val = max_f64_avx2(scores);
366    for s in scores.iter_mut() {
367        *s = (*s - max_val).exp();
368    }
369    let exp_sum = sum_f64_avx2(scores);
370    if exp_sum > 0.0 {
371        scale_f64_avx2(1.0 / exp_sum, scores);
372    }
373}
374
#[cfg(not(target_arch = "x86_64"))]
/// Fallback in-place softmax when AVX2 is unavailable on this architecture.
///
/// Same max-shifted exponentiation as the AVX2 variant, built on the scalar
/// fallbacks of the same names.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }
    // Subtract the maximum before exponentiating for numerical stability.
    let max_val = max_f64_avx2(scores);
    for s in scores.iter_mut() {
        *s = (*s - max_val).exp();
    }
    // Normalise so the entries sum to one; skip when the sum degenerates.
    let exp_sum = sum_f64_avx2(scores);
    if exp_sum > 0.0 {
        scale_f64_avx2(1.0 / exp_sum, scores);
    }
}
389
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    // Coverage note: every AVX2 entry point in this module now has a test
    // comparing it against a scalar reference. Previously popcount_avx2,
    // fused_xor_popcount_avx2, and scale_f64_avx2 were untested.

    #[test]
    fn pack_avx2_matches_pack() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [
            1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031,
        ];
        for length in lengths {
            let bits: Vec<u8> = (0..length)
                .map(|i| if (i * 17 + 5) % 3 == 0 { 1 } else { 0 })
                .collect();
            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::pack_avx2(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    #[test]
    fn popcount_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        // Lengths chosen to exercise both the 4-word SIMD loop and the
        // scalar remainder path.
        let lengths = [0_usize, 1, 3, 4, 5, 8, 9, 64, 127];
        for len in lengths {
            let data: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15))
                .collect();
            let expected: u64 = data.iter().map(|w| w.count_ones() as u64).sum();
            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::popcount_avx2(&data) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn fused_and_popcount_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xA5A5_A5A5_5A5A_5A5A)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xC2B2_AE3D_27D4_EB4F) ^ 0x0F0F_F0F0_33CC_CC33)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
                .sum();

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::fused_and_popcount_avx2(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn fused_xor_popcount_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [1_usize, 3, 4, 5, 8, 17, 32, 33, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15))
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).rotate_left(23) ^ 0xFF00_FF00_FF00_FF00)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
                .sum();

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::fused_xor_popcount_avx2(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn dot_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.1).collect();
        let b: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 5.0).collect();
        let expected: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let got = unsafe { super::dot_f64_avx2(&a, &b) };
        assert!(
            (got - expected).abs() < 1e-9,
            "dot: got {got}, expected {expected}"
        );
    }

    #[test]
    fn max_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| (i as f64 * 7.3).sin()).collect();
        let expected = a.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let got = unsafe { super::max_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-12,
            "max: got {got}, expected {expected}"
        );
    }

    #[test]
    fn sum_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.01).collect();
        let expected: f64 = a.iter().sum();
        let got = unsafe { super::sum_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-9,
            "sum: got {got}, expected {expected}"
        );
    }

    #[test]
    fn scale_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // Length 67 exercises both the SIMD loop and the scalar tail.
        let mut a: Vec<f64> = (0..67).map(|i| i as f64 * 0.5).collect();
        let expected: Vec<f64> = a.iter().map(|v| v * 2.0).collect();
        // SAFETY: Runtime-guarded by feature detection in this test.
        unsafe { super::scale_f64_avx2(2.0, &mut a) };
        // Per-element IEEE multiply is identical in SIMD and scalar form.
        assert_eq!(a, expected, "scale mismatch");
    }

    #[test]
    fn softmax_avx2_sums_to_one() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut scores: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 10.0).collect();
        unsafe { super::softmax_inplace_f64_avx2(&mut scores) };
        let sum: f64 = scores.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax must sum to 1.0, got {sum}"
        );
        assert!(scores.iter().all(|&s| s >= 0.0), "all values must be >= 0");
    }

    #[test]
    fn bernoulli_compare_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let buf: Vec<u8> = (0..32).map(|i| (i * 73 + 17) as u8).collect();
        let thresholds = [0_u8, 1, 2, 17, 64, 127, 128, 200, 255];

        for threshold in thresholds {
            let expected = buf.iter().enumerate().fold(0_u32, |acc, (bit, &rb)| {
                acc | (u32::from(rb < threshold) << bit)
            });

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::bernoulli_compare_avx2(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }
}