sc_neurocore_engine/simd/avx2.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — AVX2

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Count set bits in 64-bit words using AVX2.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    let mut total = 0_u64;
    let mut chunks = data.chunks_exact(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
        let v1 = _mm256_loadu_si256(chunk.as_ptr().add(4) as *const __m256i);
        let v2 = _mm256_loadu_si256(chunk.as_ptr().add(8) as *const __m256i);
        let v3 = _mm256_loadu_si256(chunk.as_ptr().add(12) as *const __m256i);

        let mut lanes = [0_u64; 16];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, v0);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(4) as *mut __m256i, v1);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(8) as *mut __m256i, v2);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(12) as *mut __m256i, v3);

        for &w in lanes.iter() {
            total += w.count_ones() as u64;
        }
    }

    total
        + chunks
            .remainder()
            .iter()
            .map(|&w| w.count_ones() as u64)
            .sum::<u64>()
}
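
// Illustrative dispatch sketch (an addition for documentation purposes, not
// referenced elsewhere in this file): every `unsafe` entry point in this
// module assumes the caller performs runtime feature detection first, as the
// `# Safety` contracts state.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn popcount_dispatch_sketch(words: &[u64]) -> u64 {
    if is_x86_feature_detected!("avx2") {
        // SAFETY: guarded by the runtime feature check above.
        unsafe { popcount_avx2(words) }
    } else {
        words.iter().map(|w| w.count_ones() as u64).sum()
    }
}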

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Pack u8 bits into u64 words using AVX2 movemask.
///
/// Processes 64 bytes into one u64 word by building two 32-bit masks.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
    let length = bits.len();
    let words = length.div_ceil(64);
    let mut data = vec![0_u64; words];
    let full_words = length / 64;
    let zero = _mm256_setzero_si256();

    let mut chunks = data[..full_words].chunks_exact_mut(4);
    let mut word_idx = 0;
    for chunk in chunks.by_ref() {
        let base = word_idx * 64;
        for i in 0..4 {
            let b = base + i * 64;
            let lo = _mm256_loadu_si256(bits.as_ptr().add(b) as *const __m256i);
            let hi = _mm256_loadu_si256(bits.as_ptr().add(b + 32) as *const __m256i);
            // cmpeq+movemask sets bit i where byte i == 0; the negation flips
            // it so bit i is set exactly where byte i is nonzero.
            let lo_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(lo, zero)) as u32);
            let hi_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(hi, zero)) as u32);
            chunk[i] = ((hi_mask as u64) << 32) | (lo_mask as u64);
        }
        word_idx += 4;
    }

    // Leftover full words (full_words % 4) missed by the 4-word chunks above.
    for i in word_idx..full_words {
        let base = i * 64;
        let lo = _mm256_loadu_si256(bits.as_ptr().add(base) as *const __m256i);
        let hi = _mm256_loadu_si256(bits.as_ptr().add(base + 32) as *const __m256i);
        let lo_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(lo, zero)) as u32);
        let hi_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(hi, zero)) as u32);
        data[i] = ((hi_mask as u64) << 32) | (lo_mask as u64);
    }

    // A trailing partial word (fewer than 64 bits) goes through the scalar packer.
    if full_words < words {
        let tail_start = full_words * 64;
        let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
        data[full_words] = tail.data.first().copied().unwrap_or(0);
    }

    data
}
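
// Portable reference for a single packed word (an illustrative sketch, not
// used by the SIMD path): `pack_avx2` must agree with this for any 64 bytes.
#[allow(dead_code)]
fn pack_word_reference(bytes: &[u8; 64]) -> u64 {
    bytes
        .iter()
        .enumerate()
        .fold(0_u64, |word, (i, &b)| word | (u64::from(b != 0) << i))
}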

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Fused AND+popcount over packed words using AVX2 for the AND stage.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let mut total = 0_u64;
    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm256_loadu_si256(ca.as_ptr() as *const __m256i);
        let vb0 = _mm256_loadu_si256(cb.as_ptr() as *const __m256i);
        let va1 = _mm256_loadu_si256(ca.as_ptr().add(4) as *const __m256i);
        let vb1 = _mm256_loadu_si256(cb.as_ptr().add(4) as *const __m256i);
        let va2 = _mm256_loadu_si256(ca.as_ptr().add(8) as *const __m256i);
        let vb2 = _mm256_loadu_si256(cb.as_ptr().add(8) as *const __m256i);
        let va3 = _mm256_loadu_si256(ca.as_ptr().add(12) as *const __m256i);
        let vb3 = _mm256_loadu_si256(cb.as_ptr().add(12) as *const __m256i);

        let and0 = _mm256_and_si256(va0, vb0);
        let and1 = _mm256_and_si256(va1, vb1);
        let and2 = _mm256_and_si256(va2, vb2);
        let and3 = _mm256_and_si256(va3, vb3);

        let mut lanes = [0_u64; 16];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, and0);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(4) as *mut __m256i, and1);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(8) as *mut __m256i, and2);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(12) as *mut __m256i, and3);

        for &w in lanes.iter() {
            total += w.count_ones() as u64;
        }
    }

    total
        + chunks_a
            .remainder()
            .iter()
            .zip(chunks_b.remainder().iter())
            .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
            .sum::<u64>()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Fused XOR+popcount over packed words using AVX2 for the XOR stage.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let mut total = 0_u64;
    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm256_loadu_si256(ca.as_ptr() as *const __m256i);
        let vb0 = _mm256_loadu_si256(cb.as_ptr() as *const __m256i);
        let va1 = _mm256_loadu_si256(ca.as_ptr().add(4) as *const __m256i);
        let vb1 = _mm256_loadu_si256(cb.as_ptr().add(4) as *const __m256i);
        let va2 = _mm256_loadu_si256(ca.as_ptr().add(8) as *const __m256i);
        let vb2 = _mm256_loadu_si256(cb.as_ptr().add(8) as *const __m256i);
        let va3 = _mm256_loadu_si256(ca.as_ptr().add(12) as *const __m256i);
        let vb3 = _mm256_loadu_si256(cb.as_ptr().add(12) as *const __m256i);

        let xor0 = _mm256_xor_si256(va0, vb0);
        let xor1 = _mm256_xor_si256(va1, vb1);
        let xor2 = _mm256_xor_si256(va2, vb2);
        let xor3 = _mm256_xor_si256(va3, vb3);

        let mut lanes = [0_u64; 16];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, xor0);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(4) as *mut __m256i, xor1);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(8) as *mut __m256i, xor2);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(12) as *mut __m256i, xor3);

        for &w in lanes.iter() {
            total += w.count_ones() as u64;
        }
    }

    total
        + chunks_a
            .remainder()
            .iter()
            .zip(chunks_b.remainder().iter())
            .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
            .sum::<u64>()
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback fused XOR+popcount when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b.iter())
        .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
        .sum()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Compare 32 random bytes against an unsigned threshold and return a bit mask.
///
/// Bit `i` in the returned mask is 1 iff `buf[i] < threshold`.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
/// `buf` must have at least 32 elements.
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    assert!(buf.len() >= 32, "buffer must contain at least 32 bytes");

    let data = _mm256_loadu_si256(buf.as_ptr() as *const __m256i);
    let bias = _mm256_set1_epi8(i8::MIN);
    let data_biased = _mm256_xor_si256(data, bias);
    let thresh_biased = _mm256_set1_epi8((threshold ^ 0x80) as i8);
    let lt = _mm256_cmpgt_epi8(thresh_biased, data_biased);
    _mm256_movemask_epi8(lt) as u32
}
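
// Why the sign-bias trick above works (illustrative note): AVX2 only provides
// signed byte compares, but XOR with 0x80 maps unsigned order onto signed
// order, since 0u8 becomes -128i8 (smallest) and 255u8 becomes 127i8
// (largest). Hence for unsigned bytes x and t:
//
//     ((t ^ 0x80) as i8) > ((x ^ 0x80) as i8)  <=>  x < t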

#[cfg(not(target_arch = "x86_64"))]
/// Fallback popcount when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    crate::bitstream::popcount_words_portable(data)
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback pack when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
    crate::bitstream::pack_fast(bits).data
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback fused AND+popcount when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b.iter())
        .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
        .sum()
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback Bernoulli compare when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    let mut mask = 0_u32;
    for (bit, &rb) in buf.iter().take(32).enumerate() {
        if rb < threshold {
            mask |= 1_u32 << bit;
        }
    }
    mask
}

// --- f64 SIMD operations (AVX2: 4-wide f64) ---

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
/// Dot product of two f64 slices using AVX2 FMA.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2` and `fma`.
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    let mut acc = _mm256_setzero_pd();
    let mut chunks_a = a[..len].chunks_exact(4);
    let mut chunks_b = b[..len].chunks_exact(4);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va = _mm256_loadu_pd(ca.as_ptr());
        let vb = _mm256_loadu_pd(cb.as_ptr());
        acc = _mm256_fmadd_pd(va, vb, acc);
    }

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
        sum += ra * rb;
    }
    sum
}
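
// Note (illustrative): `_mm256_fmadd_pd` rounds once per fused multiply-add,
// so this dot product can differ from a plain scalar loop by a few ulps. The
// tests below therefore compare against a small epsilon rather than exact
// equality.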

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Maximum of f64 slice using AVX2.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut vmax = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut chunks = a.chunks_exact(4);

    for chunk in chunks.by_ref() {
        let va = _mm256_loadu_pd(chunk.as_ptr());
        vmax = _mm256_max_pd(vmax, va);
    }

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), vmax);
    let mut m = lanes[0].max(lanes[1]).max(lanes[2].max(lanes[3]));
    for &v in chunks.remainder() {
        m = m.max(v);
    }
    m
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Sum of f64 slice using AVX2.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    let mut acc = _mm256_setzero_pd();
    let mut chunks = a.chunks_exact(4);

    for chunk in chunks.by_ref() {
        let va = _mm256_loadu_pd(chunk.as_ptr());
        acc = _mm256_add_pd(acc, va);
    }

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for &v in chunks.remainder() {
        sum += v;
    }
    sum
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Scale f64 slice in-place: y[i] *= alpha, using AVX2.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    let valpha = _mm256_set1_pd(alpha);
    let mut chunks = y.chunks_exact_mut(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm256_loadu_pd(chunk.as_ptr());
        let v1 = _mm256_loadu_pd(chunk.as_ptr().add(4));
        let v2 = _mm256_loadu_pd(chunk.as_ptr().add(8));
        let v3 = _mm256_loadu_pd(chunk.as_ptr().add(12));

        _mm256_storeu_pd(chunk.as_mut_ptr(), _mm256_mul_pd(v0, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(4), _mm256_mul_pd(v1, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(8), _mm256_mul_pd(v2, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(12), _mm256_mul_pd(v3, valpha));
    }

    for v in chunks.into_remainder() {
        *v *= alpha;
    }
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback dot product when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback maximum when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback sum when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    a.iter().sum()
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback in-place scale when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    for v in y.iter_mut() {
        *v *= alpha;
    }
}

/// Hamming distance between two packed bitstream slices.
///
/// Delegates to `fused_xor_popcount_avx2`, which resolves to the portable
/// fallback on non-x86_64 targets.
///
/// # Safety
/// On x86_64, caller must ensure the current CPU supports `avx2`.
pub unsafe fn hamming_distance_avx2(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_avx2(a, b)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// In-place softmax using AVX2 for the max, sum, and scale steps.
///
/// # Safety
/// Caller must ensure the current CPU supports `avx2`.
pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }
    // Subtracting the maximum first is the standard guard against overflow in
    // exp(); it leaves the softmax result unchanged.
    let max_val = max_f64_avx2(scores);

    // core::arch has no AVX2 exp() intrinsic, so the exponential stage stays
    // scalar; only the max, sum, and final scale steps are vectorised.
    for s in scores.iter_mut() {
        *s = (*s - max_val).exp();
    }

    let exp_sum = sum_f64_avx2(scores);
    if exp_sum > 0.0 {
        scale_f64_avx2(1.0 / exp_sum, scores);
    }
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback softmax when AVX2 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }
    let max_val = max_f64_avx2(scores);
    for s in scores.iter_mut() {
        *s = (*s - max_val).exp();
    }

    let exp_sum = sum_f64_avx2(scores);
    if exp_sum > 0.0 {
        scale_f64_avx2(1.0 / exp_sum, scores);
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
/// Dot product of two f64 slices using AVX (no FMA).
///
/// # Safety
/// Caller must ensure the current CPU supports `avx`.
pub unsafe fn dot_f64_avx(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    // Four independent accumulators break the floating-point add dependency
    // chain so consecutive iterations can overlap in the pipeline.
    let mut acc0 = _mm256_setzero_pd();
    let mut acc1 = _mm256_setzero_pd();
    let mut acc2 = _mm256_setzero_pd();
    let mut acc3 = _mm256_setzero_pd();

    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm256_loadu_pd(ca.as_ptr());
        let vb0 = _mm256_loadu_pd(cb.as_ptr());
        acc0 = _mm256_add_pd(acc0, _mm256_mul_pd(va0, vb0));

        let va1 = _mm256_loadu_pd(ca.as_ptr().add(4));
        let vb1 = _mm256_loadu_pd(cb.as_ptr().add(4));
        acc1 = _mm256_add_pd(acc1, _mm256_mul_pd(va1, vb1));

        let va2 = _mm256_loadu_pd(ca.as_ptr().add(8));
        let vb2 = _mm256_loadu_pd(cb.as_ptr().add(8));
        acc2 = _mm256_add_pd(acc2, _mm256_mul_pd(va2, vb2));

        let va3 = _mm256_loadu_pd(ca.as_ptr().add(12));
        let vb3 = _mm256_loadu_pd(cb.as_ptr().add(12));
        acc3 = _mm256_add_pd(acc3, _mm256_mul_pd(va3, vb3));
    }

    acc0 = _mm256_add_pd(acc0, acc1);
    acc2 = _mm256_add_pd(acc2, acc3);
    acc0 = _mm256_add_pd(acc0, acc2);

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc0);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
        sum += ra * rb;
    }
    sum
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
/// Sum of f64 slice using AVX.
///
/// # Safety
/// Caller must ensure AVX is available on the current CPU.
pub unsafe fn sum_f64_avx(a: &[f64]) -> f64 {
    let mut acc0 = _mm256_setzero_pd();
    let mut acc1 = _mm256_setzero_pd();
    let mut acc2 = _mm256_setzero_pd();
    let mut acc3 = _mm256_setzero_pd();

    let mut chunks = a.chunks_exact(16);
    for chunk in chunks.by_ref() {
        acc0 = _mm256_add_pd(acc0, _mm256_loadu_pd(chunk.as_ptr()));
        acc1 = _mm256_add_pd(acc1, _mm256_loadu_pd(chunk.as_ptr().add(4)));
        acc2 = _mm256_add_pd(acc2, _mm256_loadu_pd(chunk.as_ptr().add(8)));
        acc3 = _mm256_add_pd(acc3, _mm256_loadu_pd(chunk.as_ptr().add(12)));
    }

    acc0 = _mm256_add_pd(acc0, acc1);
    acc2 = _mm256_add_pd(acc2, acc3);
    acc0 = _mm256_add_pd(acc0, acc2);

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc0);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for &v in chunks.remainder() {
        sum += v;
    }
    sum
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Compare 1024 random bytes against an unsigned threshold and fill 16 u64
/// words: bit `j` of `out[i]` is 1 iff `buf[i * 64 + j] < threshold`.
///
/// # Safety
/// Caller must ensure AVX2 is available on the current CPU.
/// `buf` must have at least 1024 elements and `out` at least 16.
pub unsafe fn bernoulli_compare_batch_avx2(buf: &[u8], threshold: u8, out: &mut [u64]) {
    assert!(buf.len() >= 1024, "buffer must contain at least 1024 bytes");
    assert!(out.len() >= 16, "output must hold at least 16 words");

    let v_thresh = _mm256_set1_epi8(threshold as i8);
    // epi8 comparison is signed; the xor-0x80 bias makes it behave unsigned.
    let bias = _mm256_set1_epi8(i8::MIN);
    let v_thresh_biased = _mm256_xor_si256(v_thresh, bias);

    for i in 0..16 {
        // Each loop iteration packs 64 bytes (2x 256-bit registers) into one word.
        let chunk = &buf[i * 64..(i + 1) * 64];
        let v0 = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
        let v1 = _mm256_loadu_si256(chunk.as_ptr().add(32) as *const __m256i);

        let v0_biased = _mm256_xor_si256(v0, bias);
        let v1_biased = _mm256_xor_si256(v1, bias);

        let m0 = _mm256_cmpgt_epi8(v_thresh_biased, v0_biased);
        let m1 = _mm256_cmpgt_epi8(v_thresh_biased, v1_biased);

        let mask0 = _mm256_movemask_epi8(m0) as u32;
        let mask1 = _mm256_movemask_epi8(m1) as u32;
        out[i] = (mask0 as u64) | ((mask1 as u64) << 32);
    }
}
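
#[cfg(not(target_arch = "x86_64"))]
/// Fallback batch Bernoulli compare when AVX2 is unavailable on this
/// architecture (a sketch added for API parity with the other fallbacks; the
/// AVX2 variant above defines the semantics: bit `j` of `out[i]` is 1 iff
/// `buf[i * 64 + j] < threshold`).
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX2 variant.
pub unsafe fn bernoulli_compare_batch_avx2(buf: &[u8], threshold: u8, out: &mut [u64]) {
    for (i, word) in out.iter_mut().take(16).enumerate() {
        let mut w = 0_u64;
        for (bit, &rb) in buf[i * 64..(i + 1) * 64].iter().enumerate() {
            w |= u64::from(rb < threshold) << bit;
        }
        *word = w;
    }
}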

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
/// Maximum of f64 slice using AVX (v1).
///
/// # Safety
/// Caller must ensure AVX is available on the current CPU.
pub unsafe fn max_f64_avx(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut max_vec0 = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut max_vec1 = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut max_vec2 = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut max_vec3 = _mm256_set1_pd(f64::NEG_INFINITY);

    let mut chunks = a.chunks_exact(16);
    for chunk in chunks.by_ref() {
        max_vec0 = _mm256_max_pd(max_vec0, _mm256_loadu_pd(chunk.as_ptr()));
        max_vec1 = _mm256_max_pd(max_vec1, _mm256_loadu_pd(chunk.as_ptr().add(4)));
        max_vec2 = _mm256_max_pd(max_vec2, _mm256_loadu_pd(chunk.as_ptr().add(8)));
        max_vec3 = _mm256_max_pd(max_vec3, _mm256_loadu_pd(chunk.as_ptr().add(12)));
    }

    max_vec0 = _mm256_max_pd(max_vec0, max_vec1);
    max_vec2 = _mm256_max_pd(max_vec2, max_vec3);
    max_vec0 = _mm256_max_pd(max_vec0, max_vec2);

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), max_vec0);
    let mut m = lanes[0].max(lanes[1]).max(lanes[2].max(lanes[3]));
    for &v in chunks.remainder() {
        m = m.max(v);
    }
    m
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
/// Scale f64 slice using AVX (v1).
///
/// # Safety
/// Caller must ensure AVX is available on the current CPU.
pub unsafe fn scale_f64_avx(alpha: f64, y: &mut [f64]) {
    let valpha = _mm256_set1_pd(alpha);
    let mut chunks = y.chunks_exact_mut(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm256_loadu_pd(chunk.as_ptr());
        let v1 = _mm256_loadu_pd(chunk.as_ptr().add(4));
        let v2 = _mm256_loadu_pd(chunk.as_ptr().add(8));
        let v3 = _mm256_loadu_pd(chunk.as_ptr().add(12));

        _mm256_storeu_pd(chunk.as_mut_ptr(), _mm256_mul_pd(v0, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(4), _mm256_mul_pd(v1, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(8), _mm256_mul_pd(v2, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(12), _mm256_mul_pd(v3, valpha));
    }

    for v in chunks.into_remainder() {
        *v *= alpha;
    }
}

#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    #[test]
    fn pack_avx2_matches_pack() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [
            1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031,
        ];
        for length in lengths {
            let bits: Vec<u8> = (0..length)
                .map(|i| if (i * 17 + 5) % 3 == 0 { 1 } else { 0 })
                .collect();
            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::pack_avx2(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    #[test]
    fn fused_and_popcount_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xA5A5_A5A5_5A5A_5A5A)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xC2B2_AE3D_27D4_EB4F) ^ 0x0F0F_F0F0_33CC_CC33)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
                .sum();

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::fused_and_popcount_avx2(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn dot_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.1).collect();
        let b: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 5.0).collect();
        let expected: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let got = unsafe { super::dot_f64_avx2(&a, &b) };
        assert!(
            (got - expected).abs() < 1e-9,
            "dot: got {got}, expected {expected}"
        );
    }

    #[test]
    fn max_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| (i as f64 * 7.3).sin()).collect();
        let expected = a.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let got = unsafe { super::max_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-12,
            "max: got {got}, expected {expected}"
        );
    }

    #[test]
    fn sum_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.01).collect();
        let expected: f64 = a.iter().sum();
        let got = unsafe { super::sum_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-9,
            "sum: got {got}, expected {expected}"
        );
    }

    #[test]
    fn softmax_avx2_sums_to_one() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut scores: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 10.0).collect();
        unsafe { super::softmax_inplace_f64_avx2(&mut scores) };
        let sum: f64 = scores.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax must sum to 1.0, got {sum}"
        );
        assert!(scores.iter().all(|&s| s >= 0.0), "all values must be >= 0");
    }

    #[test]
    fn bernoulli_compare_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let buf: Vec<u8> = (0..32).map(|i| (i * 73 + 17) as u8).collect();
        let thresholds = [0_u8, 1, 2, 17, 64, 127, 128, 200, 255];

        for threshold in thresholds {
            let expected = buf.iter().enumerate().fold(0_u32, |acc, (bit, &rb)| {
                acc | (u32::from(rb < threshold) << bit)
            });

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::bernoulli_compare_avx2(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }
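
    // Additional coverage sketch, following the patterns above: the fused XOR
    // popcount and its `hamming_distance_avx2` wrapper against a scalar model.
    #[test]
    fn fused_xor_popcount_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let lengths = [1_usize, 7, 15, 16, 17, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15))
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xC2B2_AE3D_27D4_EB4F))
                .collect();
            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
                .sum();
            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::fused_xor_popcount_avx2(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
            let dist = unsafe { super::hamming_distance_avx2(&a, &b) };
            assert_eq!(dist, expected, "hamming mismatch at len={len}");
        }
    }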

    #[test]
    fn dot_f64_avx_matches_scalar() {
        if !is_x86_feature_detected!("avx") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.1).collect();
        let b: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 5.0).collect();
        let expected: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let got = unsafe { super::dot_f64_avx(&a, &b) };
        assert!(
            (got - expected).abs() < 1e-9,
            "dot_avx: got {got}, expected {expected}"
        );
    }
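
    // Further coverage sketches for the in-place scale and the batch
    // Bernoulli compare, checked against scalar models.
    #[test]
    fn scale_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut y: Vec<f64> = (0..67).map(|i| i as f64 * 0.25 - 3.0).collect();
        let expected: Vec<f64> = y.iter().map(|v| v * 1.5).collect();
        // SAFETY: Runtime-guarded by feature detection in this test.
        unsafe { super::scale_f64_avx2(1.5, &mut y) };
        for (got, want) in y.iter().zip(&expected) {
            assert!((got - want).abs() < 1e-12, "scale: got {got}, want {want}");
        }
    }

    #[test]
    fn bernoulli_compare_batch_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let buf: Vec<u8> = (0..1024).map(|i| (i * 31 + 7) as u8).collect();
        let mut out = [0_u64; 16];
        // SAFETY: Runtime-guarded by feature detection in this test.
        unsafe { super::bernoulli_compare_batch_avx2(&buf, 100, &mut out) };
        for (i, &word) in out.iter().enumerate() {
            for bit in 0..64 {
                let expected = u64::from(buf[i * 64 + bit] < 100);
                assert_eq!((word >> bit) & 1, expected, "word {i}, bit {bit}");
            }
        }
    }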
}