// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — AVX512 (sc_neurocore_engine/simd/avx512.rs)

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

11#[cfg(target_arch = "x86_64")]
12#[target_feature(enable = "avx512f,avx512vpopcntdq")]
13/// Count set bits in 64-bit words using AVX-512 VPOPCNTDQ.
14///
15/// # Safety
16/// Caller must ensure the current CPU supports `avx512f` and `avx512vpopcntdq`.
17pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
18    let mut total = 0_u64;
19    let mut chunks = data.chunks_exact(8);
20
21    for chunk in &mut chunks {
22        let v = _mm512_loadu_si512(chunk.as_ptr() as *const __m512i);
23        let counts = _mm512_popcnt_epi64(v);
24        let mut lanes = [0_u64; 8];
25        _mm512_storeu_si512(lanes.as_mut_ptr() as *mut __m512i, counts);
26        total += lanes.iter().sum::<u64>();
27    }
28
29    total + crate::bitstream::popcount_words_portable(chunks.remainder())
30}
31
32#[cfg(target_arch = "x86_64")]
33#[target_feature(enable = "avx512f,avx512bw")]
34/// Pack u8 bits into u64 words using AVX-512 k-mask compare.
35///
36/// Processes 64 bytes per iteration where each compare result bit maps
37/// directly to one packed output bit.
38///
39/// # Safety
40/// Caller must ensure the current CPU supports `avx512f` and `avx512bw`.
41pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
42    let length = bits.len();
43    let words = length.div_ceil(64);
44    let mut data = vec![0_u64; words];
45    let full_words = length / 64;
46    let zero = _mm512_setzero_si512();
47
48    for (word_idx, word) in data.iter_mut().take(full_words).enumerate() {
49        let base = word_idx * 64;
50        let v = _mm512_loadu_si512(bits.as_ptr().add(base) as *const __m512i);
51        let mask = _mm512_cmpneq_epi8_mask(v, zero);
52        *word = mask;
53    }
54
55    if full_words < words {
56        let tail_start = full_words * 64;
57        let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
58        data[full_words] = tail.data.first().copied().unwrap_or(0);
59    }
60
61    data
62}
63
64#[cfg(target_arch = "x86_64")]
65#[target_feature(enable = "avx512f,avx512vpopcntdq")]
66/// Fused AND+popcount over packed words using AVX-512 VPOPCNTDQ.
67///
68/// # Safety
69/// Caller must ensure the current CPU supports `avx512f` and `avx512vpopcntdq`.
70pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
71    let len = a.len().min(b.len());
72    let mut total = _mm512_setzero_si512();
73    let mut chunks_a = a[..len].chunks_exact(8);
74    let mut chunks_b = b[..len].chunks_exact(8);
75
76    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
77        let va = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
78        let vb = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
79        let anded = _mm512_and_epi64(va, vb);
80        let counts = _mm512_popcnt_epi64(anded);
81        total = _mm512_add_epi64(total, counts);
82    }
83
84    let mut lanes = [0_u64; 8];
85    _mm512_storeu_si512(lanes.as_mut_ptr() as *mut __m512i, total);
86    let mut sum: u64 = lanes.iter().sum();
87
88    for (&wa, &wb) in chunks_a.remainder().iter().zip(chunks_b.remainder().iter()) {
89        sum += (wa & wb).count_ones() as u64;
90    }
91    sum
92}
93
94#[cfg(target_arch = "x86_64")]
95#[target_feature(enable = "avx512f,avx512vpopcntdq")]
96/// Fused XOR+popcount over packed words using AVX-512 VPOPCNTDQ.
97///
98/// # Safety
99/// Caller must ensure the current CPU supports `avx512f` and `avx512vpopcntdq`.
100pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
101    let len = a.len().min(b.len());
102    let mut total = _mm512_setzero_si512();
103    let mut chunks_a = a[..len].chunks_exact(8);
104    let mut chunks_b = b[..len].chunks_exact(8);
105
106    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
107        let va = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
108        let vb = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
109        let xored = _mm512_xor_epi64(va, vb);
110        let counts = _mm512_popcnt_epi64(xored);
111        total = _mm512_add_epi64(total, counts);
112    }
113
114    let mut lanes = [0_u64; 8];
115    _mm512_storeu_si512(lanes.as_mut_ptr() as *mut __m512i, total);
116    let mut sum: u64 = lanes.iter().sum();
117
118    for (&wa, &wb) in chunks_a.remainder().iter().zip(chunks_b.remainder().iter()) {
119        sum += (wa ^ wb).count_ones() as u64;
120    }
121    sum
122}
123
#[cfg(not(target_arch = "x86_64"))]
/// Fallback fused XOR+popcount when AVX-512 is unavailable on this architecture.
///
/// Scalar loop over the common prefix of `a` and `b` (zip stops at the
/// shorter slice, matching the SIMD variant's `min`-length behavior).
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    let mut total = 0_u64;
    for (&wa, &wb) in a.iter().zip(b) {
        total += u64::from((wa ^ wb).count_ones());
    }
    total
}

136#[cfg(target_arch = "x86_64")]
137#[target_feature(enable = "avx512f,avx512bw")]
138/// Compare 64 random bytes against an unsigned threshold and return bit mask.
139///
140/// Bit `i` in the returned mask is 1 iff `buf[i] < threshold`.
141///
142/// # Safety
143/// Caller must ensure the current CPU supports `avx512f` and `avx512bw`.
144/// `buf` must have at least 64 elements.
145pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
146    assert!(buf.len() >= 64, "buffer must contain at least 64 bytes");
147    let data = _mm512_loadu_si512(buf.as_ptr() as *const __m512i);
148    let thresh = _mm512_set1_epi8(threshold as i8);
149    _mm512_cmplt_epu8_mask(data, thresh)
150}
151
#[cfg(not(target_arch = "x86_64"))]
/// Fallback popcount when AVX-512 is unavailable on this architecture.
///
/// Delegates directly to the portable word-wise popcount in
/// `crate::bitstream`, so both cfg variants share one scalar reference.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
    crate::bitstream::popcount_words_portable(data)
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback pack when AVX-512 is unavailable on this architecture.
///
/// Delegates to the portable `pack_fast` and returns only its packed
/// word vector (`.data`), matching the SIMD variant's `Vec<u64>` return.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
    crate::bitstream::pack_fast(bits).data
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback fused AND+popcount when AVX-512 is unavailable on this architecture.
///
/// Scalar fold over the common prefix of `a` and `b` (zip stops at the
/// shorter slice, matching the SIMD variant's `min`-length behavior).
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b)
        .fold(0_u64, |acc, (&wa, &wb)| acc + u64::from((wa & wb).count_ones()))
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback Bernoulli compare when AVX-512 is unavailable on this architecture.
///
/// Builds the mask scalar-wise: bit `i` is 1 iff `buf[i] < threshold`,
/// considering at most the first 64 bytes.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
    buf.iter()
        .take(64)
        .enumerate()
        .fold(0_u64, |mask, (bit, &rb)| mask | (u64::from(rb < threshold) << bit))
}

// --- f64 SIMD operations (AVX-512: 8-wide f64) ---

199#[cfg(target_arch = "x86_64")]
200#[target_feature(enable = "avx512f")]
201/// Dot product of two f64 slices using AVX-512.
202///
203/// # Safety
204/// Caller must ensure the current CPU supports `avx512f`.
205pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
206    let len = a.len().min(b.len());
207    let mut acc = _mm512_setzero_pd();
208    let mut chunks_a = a[..len].chunks_exact(8);
209    let mut chunks_b = b[..len].chunks_exact(8);
210
211    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
212        let va = _mm512_loadu_pd(ca.as_ptr());
213        let vb = _mm512_loadu_pd(cb.as_ptr());
214        acc = _mm512_fmadd_pd(va, vb, acc);
215    }
216
217    let mut sum = _mm512_reduce_add_pd(acc);
218    for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
219        sum += ra * rb;
220    }
221    sum
222}
223
224#[cfg(target_arch = "x86_64")]
225#[target_feature(enable = "avx512f")]
226/// Maximum of f64 slice using AVX-512.
227///
228/// # Safety
229/// Caller must ensure the current CPU supports `avx512f`.
230pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
231    if a.is_empty() {
232        return f64::NEG_INFINITY;
233    }
234    let mut vmax = _mm512_set1_pd(f64::NEG_INFINITY);
235    let mut chunks = a.chunks_exact(8);
236
237    for chunk in chunks.by_ref() {
238        let va = _mm512_loadu_pd(chunk.as_ptr());
239        vmax = _mm512_max_pd(vmax, va);
240    }
241
242    let mut m = _mm512_reduce_max_pd(vmax);
243    for &v in chunks.remainder() {
244        m = m.max(v);
245    }
246    m
247}
248
249#[cfg(target_arch = "x86_64")]
250#[target_feature(enable = "avx512f")]
251/// Sum of f64 slice using AVX-512.
252///
253/// # Safety
254/// Caller must ensure the current CPU supports `avx512f`.
255pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
256    let mut acc = _mm512_setzero_pd();
257    let mut chunks = a.chunks_exact(8);
258
259    for chunk in chunks.by_ref() {
260        let va = _mm512_loadu_pd(chunk.as_ptr());
261        acc = _mm512_add_pd(acc, va);
262    }
263
264    let mut sum = _mm512_reduce_add_pd(acc);
265    for &v in chunks.remainder() {
266        sum += v;
267    }
268    sum
269}
270
271#[cfg(target_arch = "x86_64")]
272#[target_feature(enable = "avx512f")]
273/// Scale f64 slice in-place: y[i] *= alpha, using AVX-512.
274///
275/// # Safety
276/// Caller must ensure the current CPU supports `avx512f`.
277pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
278    let valpha = _mm512_set1_pd(alpha);
279    let mut chunks = y.chunks_exact_mut(8);
280
281    for chunk in chunks.by_ref() {
282        let vy = _mm512_loadu_pd(chunk.as_ptr());
283        let scaled = _mm512_mul_pd(vy, valpha);
284        _mm512_storeu_pd(chunk.as_mut_ptr(), scaled);
285    }
286
287    for v in chunks.into_remainder() {
288        *v *= alpha;
289    }
290}
291
#[cfg(not(target_arch = "x86_64"))]
/// Fallback dot product when AVX-512 is unavailable on this architecture.
///
/// Scalar multiply-accumulate over the common prefix of `a` and `b`.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
    let common = a.len().min(b.len());
    let mut total = 0.0_f64;
    for (&x, &y) in a[..common].iter().zip(&b[..common]) {
        total += x * y;
    }
    total
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback maximum when AVX-512 is unavailable on this architecture.
///
/// Returns `f64::NEG_INFINITY` for an empty slice, matching the SIMD variant.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &x in a {
        best = best.max(x);
    }
    best
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback sum when AVX-512 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
    a.iter().fold(0.0_f64, |acc, &x| acc + x)
}

#[cfg(not(target_arch = "x86_64"))]
/// Fallback in-place scaling when AVX-512 is unavailable on this architecture.
///
/// # Safety
/// This function is marked unsafe for API parity with the AVX-512 variant.
pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|v| *v *= alpha);
}

#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    /// The kernels enable `avx512f` *and* a second extension via
    /// `#[target_feature]`, so the tests guard on both features rather
    /// than only the extension (the previous guards checked one feature).
    #[test]
    fn pack_avx512_matches_pack() {
        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512bw") {
            return;
        }

        let lengths = [
            1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031,
        ];
        for length in lengths {
            let bits: Vec<u8> = (0..length)
                .map(|i| if (i * 19 + 11) % 4 == 0 { 1 } else { 0 })
                .collect();
            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::pack_avx512(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    #[test]
    fn fused_and_popcount_avx512_matches_scalar() {
        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vpopcntdq") {
            return;
        }

        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xD6E8_FD9D_5A2B_1C47) ^ 0x1357_9BDF_2468_ACE0)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x94D0_49BB_1331_11EB) ^ 0xF0F0_0F0F_AAAA_5555)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
                .sum();

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::fused_and_popcount_avx512(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn fused_xor_popcount_avx512_matches_scalar() {
        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vpopcntdq") {
            return;
        }

        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xD6E8_FD9D_5A2B_1C47) ^ 0x1357_9BDF_2468_ACE0)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x94D0_49BB_1331_11EB) ^ 0xF0F0_0F0F_AAAA_5555)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
                .sum();

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::fused_xor_popcount_avx512(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn bernoulli_compare_avx512_matches_scalar() {
        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512bw") {
            return;
        }

        let buf: Vec<u8> = (0..64).map(|i| (i * 41 + 23) as u8).collect();
        let thresholds = [0_u8, 1, 2, 17, 64, 127, 128, 200, 255];

        for threshold in thresholds {
            let expected = buf.iter().enumerate().fold(0_u64, |acc, (bit, &rb)| {
                acc | (u64::from(rb < threshold) << bit)
            });

            // SAFETY: Runtime-guarded by feature detection in this test.
            let got = unsafe { super::bernoulli_compare_avx512(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }
}