Skip to main content

sc_neurocore_engine/simd/
avx512.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — AVX512
8
9#[cfg(target_arch = "x86_64")]
10use core::arch::x86_64::*;
11
12#[cfg(target_arch = "x86_64")]
13#[target_feature(enable = "avx512f,avx512vpopcntdq")]
14/// Count set bits in 64-bit words using AVX-512 VPOPCNTDQ.
15///
16/// # Safety
17/// Caller must ensure the current CPU supports `avx512f` and `avx512vpopcntdq`.
18pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
19    let mut total = 0_u64;
20    let mut chunks = data.chunks_exact(16);
21
22    for chunk in chunks.by_ref() {
23        let v0 = _mm512_loadu_si512(chunk.as_ptr() as *const __m512i);
24        let v1 = _mm512_loadu_si512(chunk.as_ptr().add(8) as *const __m512i);
25
26        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(v0)) as u64;
27        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(v1)) as u64;
28    }
29
30    total + crate::bitstream::popcount_words_portable(chunks.remainder())
31}
32
33#[cfg(target_arch = "x86_64")]
34#[target_feature(enable = "avx512f,avx512bw")]
35/// Pack u8 bits into u64 words using AVX-512 k-mask compare.
36///
37/// Processes 64 bytes per iteration where each compare result bit maps
38/// directly to one packed output bit.
39///
40/// # Safety
41/// Caller must ensure the current CPU supports `avx512f` and `avx512bw`.
42pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
43    let length = bits.len();
44    let words = length.div_ceil(64);
45    let mut data = vec![0_u64; words];
46    let full_words = length / 64;
47    let zero = _mm512_setzero_si512();
48
49    let mut chunks = data[..full_words].chunks_exact_mut(4);
50    let mut word_idx = 0;
51    for chunk in chunks.by_ref() {
52        let base = word_idx * 64;
53        for i in 0..4 {
54            let v = _mm512_loadu_si512(bits.as_ptr().add(base + i * 64) as *const __m512i);
55            chunk[i] = _mm512_cmpneq_epi8_mask(v, zero);
56        }
57        word_idx += 4;
58    }
59
60    for i in word_idx..full_words {
61        let v = _mm512_loadu_si512(bits.as_ptr().add(i * 64) as *const __m512i);
62        data[i] = _mm512_cmpneq_epi8_mask(v, zero);
63    }
64
65    if full_words < words {
66        let tail_start = full_words * 64;
67        let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
68        data[full_words] = tail.data.first().copied().unwrap_or(0);
69    }
70
71    data
72}
73
74#[cfg(target_arch = "x86_64")]
75#[target_feature(enable = "avx512f,avx512vpopcntdq")]
76/// Fused AND+popcount over packed words using AVX-512 VPOPCNTDQ.
77///
78/// # Safety
79/// Caller must ensure the current CPU supports `avx512f` and `avx512vpopcntdq`.
80pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
81    let len = a.len().min(b.len());
82    let mut total = 0_u64;
83    let mut chunks_a = a[..len].chunks_exact(16);
84    let mut chunks_b = b[..len].chunks_exact(16);
85
86    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
87        let va0 = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
88        let vb0 = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
89        let va1 = _mm512_loadu_si512(ca.as_ptr().add(8) as *const __m512i);
90        let vb1 = _mm512_loadu_si512(cb.as_ptr().add(8) as *const __m512i);
91
92        let and0 = _mm512_and_si512(va0, vb0);
93        let and1 = _mm512_and_si512(va1, vb1);
94
95        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(and0)) as u64;
96        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(and1)) as u64;
97    }
98
99    total
100        + chunks_a
101            .remainder()
102            .iter()
103            .zip(chunks_b.remainder().iter())
104            .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
105            .sum::<u64>()
106}
107
108#[cfg(target_arch = "x86_64")]
109#[target_feature(enable = "avx512f,avx512vpopcntdq")]
110/// Fused XOR+popcount over packed words using AVX-512 VPOPCNTDQ.
111///
112/// # Safety
113/// Caller must ensure the current CPU supports `avx512f` and `avx512vpopcntdq`.
114pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
115    let len = a.len().min(b.len());
116    let mut total = 0_u64;
117    let mut chunks_a = a[..len].chunks_exact(16);
118    let mut chunks_b = b[..len].chunks_exact(16);
119
120    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
121        let va0 = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
122        let vb0 = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
123        let va1 = _mm512_loadu_si512(ca.as_ptr().add(8) as *const __m512i);
124        let vb1 = _mm512_loadu_si512(cb.as_ptr().add(8) as *const __m512i);
125
126        let xor0 = _mm512_xor_si512(va0, vb0);
127        let xor1 = _mm512_xor_si512(va1, vb1);
128
129        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(xor0)) as u64;
130        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(xor1)) as u64;
131    }
132
133    total
134        + chunks_a
135            .remainder()
136            .iter()
137            .zip(chunks_b.remainder().iter())
138            .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
139            .sum::<u64>()
140}
141
/// Portable stand-in for the AVX-512 fused XOR+popcount on non-x86_64 targets.
///
/// Sums the set bits of `a[i] ^ b[i]` over the common prefix of the two
/// slices; words beyond the shorter slice are ignored.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    let mut differing_bits = 0_u64;
    for (wa, wb) in a.iter().zip(b) {
        differing_bits += u64::from((wa ^ wb).count_ones());
    }
    differing_bits
}
153
154#[cfg(target_arch = "x86_64")]
155#[target_feature(enable = "avx512f,avx512bw")]
156/// Compare 64 random bytes against an unsigned threshold and return bit mask.
157///
158/// Bit `i` in the returned mask is 1 iff `buf[i] < threshold`.
159///
160/// # Safety
161/// Caller must ensure the current CPU supports `avx512f` and `avx512bw`.
162/// `buf` must have at least 64 elements.
163pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
164    assert!(buf.len() >= 64, "buffer must contain at least 64 bytes");
165    let data = _mm512_loadu_si512(buf.as_ptr() as *const __m512i);
166    let thresh = _mm512_set1_epi8(threshold as i8);
167    _mm512_cmplt_epu8_mask(data, thresh)
168}
169
/// Portable stand-in for the AVX-512 popcount on non-x86_64 targets.
///
/// Forwards the whole slice to `crate::bitstream::popcount_words_portable`.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
    use crate::bitstream::popcount_words_portable;

    popcount_words_portable(data)
}
178
/// Portable stand-in for the AVX-512 bit packer on non-x86_64 targets.
///
/// Packs `bits` with `crate::bitstream::pack_fast` and returns the `data`
/// field of its result.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
    let packed = crate::bitstream::pack_fast(bits);
    packed.data
}
187
/// Portable stand-in for the AVX-512 fused AND+popcount on non-x86_64 targets.
///
/// Sums the set bits of `a[i] & b[i]` over the common prefix of the two
/// slices; words beyond the shorter slice are ignored.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    let mut shared_bits = 0_u64;
    for (wa, wb) in a.iter().zip(b) {
        shared_bits += u64::from((wa & wb).count_ones());
    }
    shared_bits
}
199
/// Portable stand-in for the AVX-512 Bernoulli compare on non-x86_64 targets.
///
/// Examines at most the first 64 bytes of `buf`; bit `i` of the result is 1
/// iff `buf[i] < threshold`. Unlike the AVX-512 variant, a buffer shorter
/// than 64 bytes is accepted here: the bits for the missing bytes stay 0.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
    buf.iter()
        .take(64)
        .enumerate()
        .filter(|&(_, &rb)| rb < threshold)
        .fold(0_u64, |mask, (bit, _)| mask | (1_u64 << bit))
}
214
215// --- f64 SIMD operations (AVX-512: 8-wide f64) ---
216
217#[cfg(target_arch = "x86_64")]
218#[target_feature(enable = "avx512f")]
219/// Dot product of two f64 slices using AVX-512.
220///
221/// # Safety
222/// Caller must ensure the current CPU supports `avx512f`.
223pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
224    let len = a.len().min(b.len());
225    let mut acc = _mm512_setzero_pd();
226    let mut chunks_a = a[..len].chunks_exact(8);
227    let mut chunks_b = b[..len].chunks_exact(8);
228
229    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
230        let va = _mm512_loadu_pd(ca.as_ptr());
231        let vb = _mm512_loadu_pd(cb.as_ptr());
232        acc = _mm512_fmadd_pd(va, vb, acc);
233    }
234
235    let mut sum = _mm512_reduce_add_pd(acc);
236    for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
237        sum += ra * rb;
238    }
239    sum
240}
241
242#[cfg(target_arch = "x86_64")]
243#[target_feature(enable = "avx512f")]
244/// Maximum of f64 slice using AVX-512.
245///
246/// # Safety
247/// Caller must ensure the current CPU supports `avx512f`.
248pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
249    if a.is_empty() {
250        return f64::NEG_INFINITY;
251    }
252    let mut vmax0 = _mm512_set1_pd(f64::NEG_INFINITY);
253    let mut vmax1 = _mm512_set1_pd(f64::NEG_INFINITY);
254    let mut chunks = a.chunks_exact(16);
255
256    for chunk in chunks.by_ref() {
257        vmax0 = _mm512_max_pd(vmax0, _mm512_loadu_pd(chunk.as_ptr()));
258        vmax1 = _mm512_max_pd(vmax1, _mm512_loadu_pd(chunk.as_ptr().add(8)));
259    }
260
261    let mut m = _mm512_reduce_max_pd(_mm512_max_pd(vmax0, vmax1));
262    for &v in chunks.remainder() {
263        m = m.max(v);
264    }
265    m
266}
267
268#[cfg(target_arch = "x86_64")]
269#[target_feature(enable = "avx512f")]
270/// Sum of f64 slice using AVX-512.
271///
272/// # Safety
273/// Caller must ensure the current CPU supports `avx512f`.
274pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
275    let mut acc0 = _mm512_setzero_pd();
276    let mut acc1 = _mm512_setzero_pd();
277    let mut chunks = a.chunks_exact(16);
278
279    for chunk in chunks.by_ref() {
280        acc0 = _mm512_add_pd(acc0, _mm512_loadu_pd(chunk.as_ptr()));
281        acc1 = _mm512_add_pd(acc1, _mm512_loadu_pd(chunk.as_ptr().add(8)));
282    }
283
284    let mut sum = _mm512_reduce_add_pd(_mm512_add_pd(acc0, acc1));
285    for &v in chunks.remainder() {
286        sum += v;
287    }
288    sum
289}
290
291#[cfg(target_arch = "x86_64")]
292#[target_feature(enable = "avx512f")]
293/// Scale f64 slice in-place: y[i] *= alpha, using AVX-512.
294///
295/// # Safety
296/// Caller must ensure the current CPU supports `avx512f`.
297pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
298    let valpha = _mm512_set1_pd(alpha);
299    let mut chunks = y.chunks_exact_mut(16);
300
301    for chunk in chunks.by_ref() {
302        let v0 = _mm512_loadu_pd(chunk.as_ptr());
303        let v1 = _mm512_loadu_pd(chunk.as_ptr().add(8));
304        _mm512_storeu_pd(chunk.as_mut_ptr(), _mm512_mul_pd(v0, valpha));
305        _mm512_storeu_pd(chunk.as_mut_ptr().add(8), _mm512_mul_pd(v1, valpha));
306    }
307
308    for v in chunks.into_remainder() {
309        *v *= alpha;
310    }
311}
312
/// Portable stand-in for the AVX-512 dot product on non-x86_64 targets.
///
/// Multiplies the slices element-wise over their common prefix and sums the
/// products; elements beyond the shorter slice are ignored.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
    // `zip` already stops at the shorter slice, so no explicit truncation
    // is required.
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
318
/// Portable stand-in for the AVX-512 maximum on non-x86_64 targets.
///
/// Folds the slice with `f64::max`, starting from `f64::NEG_INFINITY`, which
/// is therefore the result for an empty slice.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
    a.iter()
        .fold(f64::NEG_INFINITY, |largest, &v| largest.max(v))
}
323
/// Portable stand-in for the AVX-512 sum on non-x86_64 targets.
///
/// Adds the elements in slice order using the standard iterator `sum`.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
    a.iter().copied().sum::<f64>()
}
328
/// Portable stand-in for the AVX-512 in-place scale on non-x86_64 targets.
///
/// Multiplies every element of `y` by `alpha`.
///
/// # Safety
/// This function has no safety requirements of its own; it is `unsafe` only
/// so that its signature matches the AVX-512 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|v| *v *= alpha);
}
335
336#[cfg(target_arch = "x86_64")]
337#[target_feature(enable = "avx512bw")]
338/// Compare 1024 random bytes against a threshold and return 16 u64 words.
339///
340/// # Safety
341/// Caller must ensure AVX-512 BW is available on the current CPU.
342pub unsafe fn bernoulli_compare_batch_avx512(buf: &[u8], threshold: u8, out: &mut [u64]) {
343    let v_thresh = _mm512_set1_epi8(threshold as i8);
344    for i in 0..16 {
345        let chunk = &buf[i * 64..(i + 1) * 64];
346        let v = _mm512_loadu_si512(chunk.as_ptr() as *const _);
347        // AVX-512 has direct unsigned comparison
348        out[i] = _mm512_cmplt_epu8_mask(v, v_thresh);
349    }
350}
351
/// Cross-checks of the AVX-512 kernels against scalar reference code.
///
/// Every test returns early (and therefore passes vacuously) when the running
/// CPU lacks any feature named in the kernel's `# Safety` contract.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    /// True when the CPU supports both features required by the
    /// byte-compare kernels (`avx512f` and `avx512bw`).
    fn has_avx512_bw() -> bool {
        is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw")
    }

    /// True when the CPU supports both features required by the popcount
    /// kernels (`avx512f` and `avx512vpopcntdq`).
    fn has_avx512_vpopcntdq() -> bool {
        is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vpopcntdq")
    }

    #[test]
    fn pack_avx512_matches_pack() {
        if !has_avx512_bw() {
            return;
        }

        // Lengths straddle the 64-byte word boundary and the multi-word path.
        let lengths = [
            1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031,
        ];
        for length in lengths {
            let bits: Vec<u8> = (0..length)
                .map(|i| if (i * 19 + 11) % 4 == 0 { 1 } else { 0 })
                .collect();
            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::pack_avx512(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    #[test]
    fn fused_and_popcount_avx512_matches_scalar() {
        if !has_avx512_vpopcntdq() {
            return;
        }

        // Lengths straddle the 16-word vector block size.
        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xD6E8_FD9D_5A2B_1C47) ^ 0x1357_9BDF_2468_ACE0)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x94D0_49BB_1331_11EB) ^ 0xF0F0_0F0F_AAAA_5555)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
                .sum();

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::fused_and_popcount_avx512(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn bernoulli_compare_avx512_matches_scalar() {
        if !has_avx512_bw() {
            return;
        }

        let buf: Vec<u8> = (0..64).map(|i| (i * 41 + 23) as u8).collect();
        // Thresholds include both extremes and the signed/unsigned boundary.
        let thresholds = [0_u8, 1, 2, 17, 64, 127, 128, 200, 255];

        for threshold in thresholds {
            let expected = buf.iter().enumerate().fold(0_u64, |acc, (bit, &rb)| {
                acc | (u64::from(rb < threshold) << bit)
            });

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::bernoulli_compare_avx512(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }
}
425}