Skip to main content

sc_neurocore_engine/simd/
mod.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — SIMD Popcount Dispatch
8
9//! # SIMD Popcount Dispatch
10//!
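//! Runtime CPU-feature dispatch for packed-bit kernels (bit packing, popcount, fused
//! AND/XOR popcount, Bernoulli threshold compare) and for f64 vector helpers (dot
//! product, max, sum, scale, softmax).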
12//! Supported backends: AVX-512, AVX2, ARM NEON, ARM SVE, RISC-V RVV.
13
14use rand::Rng;
15
16pub mod avx2;
17pub mod avx512;
18pub mod neon;
19pub mod rvv;
20pub mod sve;
21
22/// Pack u8 bits into u64 words using the best available SIMD path.
23pub fn pack_dispatch(bits: &[u8]) -> crate::bitstream::BitStreamTensor {
24    let length = bits.len();
25
26    #[cfg(target_arch = "x86_64")]
27    {
28        if is_x86_feature_detected!("avx512bw") {
29            // SAFETY: Guarded by runtime feature detection.
30            let data = unsafe { avx512::pack_avx512(bits) };
31            return crate::bitstream::BitStreamTensor { data, length };
32        }
33        if is_x86_feature_detected!("avx2") {
34            // SAFETY: Guarded by runtime feature detection.
35            let data = unsafe { avx2::pack_avx2(bits) };
36            return crate::bitstream::BitStreamTensor { data, length };
37        }
38    }
39
40    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
41    {
42        // SAFETY: SVE target feature is compile-time guaranteed.
43        let data = unsafe { sve::pack_sve(bits) };
44        return crate::bitstream::BitStreamTensor { data, length };
45    }
46
47    crate::bitstream::pack_fast(bits)
48}
49
50/// Count set bits in packed `u64` words using the best available SIMD path.
51pub fn popcount_dispatch(data: &[u64]) -> u64 {
52    #[cfg(target_arch = "x86_64")]
53    {
54        if is_x86_feature_detected!("avx512vpopcntdq") {
55            // SAFETY: Guarded by runtime feature detection.
56            return unsafe { avx512::popcount_avx512(data) };
57        }
58        if is_x86_feature_detected!("avx2") {
59            // SAFETY: Guarded by runtime feature detection.
60            return unsafe { avx2::popcount_avx2(data) };
61        }
62    }
63
64    #[cfg(target_arch = "aarch64")]
65    {
66        #[cfg(target_feature = "sve")]
67        {
68            // SAFETY: SVE target feature is compile-time guaranteed.
69            return unsafe { sve::popcount_sve(data) };
70        }
71        #[cfg(not(target_feature = "sve"))]
72        {
73            // SAFETY: NEON is baseline on aarch64 targets.
74            return unsafe { neon::popcount_neon(data) };
75        }
76    }
77
78    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
79    {
80        // SAFETY: RVV target feature is compile-time guaranteed.
81        return unsafe { rvv::popcount_rvv(data) };
82    }
83
84    crate::bitstream::popcount_words_portable(data)
85}
86
87/// Fused AND+popcount dispatch using the best available SIMD path.
88pub fn fused_and_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
89    let len = a.len().min(b.len());
90    let a = &a[..len];
91    let b = &b[..len];
92
93    #[cfg(target_arch = "x86_64")]
94    {
95        if is_x86_feature_detected!("avx512vpopcntdq") {
96            // SAFETY: Guarded by runtime feature detection.
97            return unsafe { avx512::fused_and_popcount_avx512(a, b) };
98        }
99        if is_x86_feature_detected!("avx2") {
100            // SAFETY: Guarded by runtime feature detection.
101            return unsafe { avx2::fused_and_popcount_avx2(a, b) };
102        }
103    }
104
105    #[cfg(target_arch = "aarch64")]
106    {
107        #[cfg(target_feature = "sve")]
108        {
109            return unsafe { sve::fused_and_popcount_sve(a, b) };
110        }
111    }
112
113    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
114    {
115        return unsafe { rvv::fused_and_popcount_rvv(a, b) };
116    }
117
118    let mut total = 0_u64;
119    let mut chunks_a = a.chunks_exact(4);
120    let mut chunks_b = b.chunks_exact(4);
121    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
122        total += (ca[0] & cb[0]).count_ones() as u64;
123        total += (ca[1] & cb[1]).count_ones() as u64;
124        total += (ca[2] & cb[2]).count_ones() as u64;
125        total += (ca[3] & cb[3]).count_ones() as u64;
126    }
127    total += chunks_a
128        .remainder()
129        .iter()
130        .zip(chunks_b.remainder().iter())
131        .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
132        .sum::<u64>();
133    total
134}
135
136/// Fused XOR+popcount dispatch using the best available SIMD path.
137pub fn fused_xor_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
138    let len = a.len().min(b.len());
139    let a = &a[..len];
140    let b = &b[..len];
141
142    #[cfg(target_arch = "x86_64")]
143    {
144        if is_x86_feature_detected!("avx512vpopcntdq") {
145            // SAFETY: Guarded by runtime feature detection.
146            return unsafe { avx512::fused_xor_popcount_avx512(a, b) };
147        }
148        if is_x86_feature_detected!("avx2") {
149            // SAFETY: Guarded by runtime feature detection.
150            return unsafe { avx2::fused_xor_popcount_avx2(a, b) };
151        }
152        if is_x86_feature_detected!("avx") {
153            let mut total = 0_u64;
154            let mut chunks_a = a.chunks_exact(16);
155            let mut chunks_b = b.chunks_exact(16);
156            for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
157                for i in 0..16 {
158                    total += (ca[i] ^ cb[i]).count_ones() as u64;
159                }
160            }
161            total += chunks_a
162                .remainder()
163                .iter()
164                .zip(chunks_b.remainder().iter())
165                .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
166                .sum::<u64>();
167            return total;
168        }
169    }
170
171    #[cfg(target_arch = "aarch64")]
172    {
173        #[cfg(target_feature = "sve")]
174        {
175            return unsafe { sve::fused_xor_popcount_sve(a, b) };
176        }
177    }
178
179    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
180    {
181        return unsafe { rvv::fused_xor_popcount_rvv(a, b) };
182    }
183
184    let mut total = 0_u64;
185    let mut chunks_a = a.chunks_exact(4);
186    let mut chunks_b = b.chunks_exact(4);
187    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
188        total += (ca[0] ^ cb[0]).count_ones() as u64;
189        total += (ca[1] ^ cb[1]).count_ones() as u64;
190        total += (ca[2] ^ cb[2]).count_ones() as u64;
191        total += (ca[3] ^ cb[3]).count_ones() as u64;
192    }
193    total += chunks_a
194        .remainder()
195        .iter()
196        .zip(chunks_b.remainder().iter())
197        .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
198        .sum::<u64>();
199    total
200}
201
202// --- f64 dispatch functions ---
203
204/// Dot product of two f64 slices using the best available SIMD path.
205pub fn dot_f64_dispatch(a: &[f64], b: &[f64]) -> f64 {
206    #[cfg(target_arch = "x86_64")]
207    {
208        if is_x86_feature_detected!("avx512f") {
209            return unsafe { avx512::dot_f64_avx512(a, b) };
210        }
211        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
212            return unsafe { avx2::dot_f64_avx2(a, b) };
213        }
214        if is_x86_feature_detected!("avx") {
215            return unsafe { avx2::dot_f64_avx(a, b) };
216        }
217        if is_x86_feature_detected!("sse2") {
218            let len = a.len().min(b.len());
219            let mut sum = 0.0_f64;
220            let mut chunks_a = a[..len].chunks_exact(4);
221            let mut chunks_b = b[..len].chunks_exact(4);
222            for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
223                sum += ca[0] * cb[0] + ca[1] * cb[1] + ca[2] * cb[2] + ca[3] * cb[3];
224            }
225            sum += chunks_a
226                .remainder()
227                .iter()
228                .zip(chunks_b.remainder())
229                .map(|(x, y)| x * y)
230                .sum::<f64>();
231            return sum;
232        }
233    }
234
235    #[cfg(target_arch = "aarch64")]
236    {
237        return unsafe { neon::dot_f64_neon(a, b) };
238    }
239
240    let len = a.len().min(b.len());
241    a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
242}
243
244/// Maximum of f64 slice using the best available SIMD path.
245pub fn max_f64_dispatch(a: &[f64]) -> f64 {
246    #[cfg(target_arch = "x86_64")]
247    {
248        if is_x86_feature_detected!("avx512f") {
249            return unsafe { avx512::max_f64_avx512(a) };
250        }
251        if is_x86_feature_detected!("avx2") {
252            return unsafe { avx2::max_f64_avx2(a) };
253        }
254        if is_x86_feature_detected!("avx") {
255            return unsafe { avx2::max_f64_avx(a) };
256        }
257        if is_x86_feature_detected!("sse2") {
258            let mut m = f64::NEG_INFINITY;
259            let mut chunks = a.chunks_exact(4);
260            for c in chunks.by_ref() {
261                m = m.max(c[0].max(c[1]).max(c[2].max(c[3])));
262            }
263            for &v in chunks.remainder() {
264                m = m.max(v);
265            }
266            return m;
267        }
268    }
269
270    #[cfg(target_arch = "aarch64")]
271    {
272        return unsafe { neon::max_f64_neon(a) };
273    }
274
275    a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
276}
277
278/// Sum of f64 slice using the best available SIMD path.
279pub fn sum_f64_dispatch(a: &[f64]) -> f64 {
280    #[cfg(target_arch = "x86_64")]
281    {
282        if is_x86_feature_detected!("avx512f") {
283            return unsafe { avx512::sum_f64_avx512(a) };
284        }
285        if is_x86_feature_detected!("avx2") {
286            return unsafe { avx2::sum_f64_avx2(a) };
287        }
288        if is_x86_feature_detected!("avx") {
289            return unsafe { avx2::sum_f64_avx(a) };
290        }
291        if is_x86_feature_detected!("sse2") {
292            let mut s = 0.0_f64;
293            let mut chunks = a.chunks_exact(4);
294            for c in chunks.by_ref() {
295                s += c[0] + c[1] + c[2] + c[3];
296            }
297            s += chunks.remainder().iter().sum::<f64>();
298            return s;
299        }
300    }
301
302    #[cfg(target_arch = "aarch64")]
303    {
304        return unsafe { neon::sum_f64_neon(a) };
305    }
306
307    a.iter().sum()
308}
309
310/// Scale f64 slice in-place: y[i] *= alpha, using the best available SIMD path.
311pub fn scale_f64_dispatch(alpha: f64, y: &mut [f64]) {
312    #[cfg(target_arch = "x86_64")]
313    {
314        if is_x86_feature_detected!("avx512f") {
315            unsafe { avx512::scale_f64_avx512(alpha, y) };
316            return;
317        }
318        if is_x86_feature_detected!("avx2") {
319            unsafe { avx2::scale_f64_avx2(alpha, y) };
320            return;
321        }
322        if is_x86_feature_detected!("avx") {
323            unsafe { avx2::scale_f64_avx(alpha, y) };
324            return;
325        }
326        if is_x86_feature_detected!("sse2") {
327            let mut chunks = y.chunks_exact_mut(4);
328            for c in chunks.by_ref() {
329                c[0] *= alpha;
330                c[1] *= alpha;
331                c[2] *= alpha;
332                c[3] *= alpha;
333            }
334            for v in chunks.into_remainder() {
335                *v *= alpha;
336            }
337            return;
338        }
339    }
340
341    #[cfg(target_arch = "aarch64")]
342    {
343        unsafe { neon::scale_f64_neon(alpha, y) };
344        return;
345    }
346
347    for x in y.iter_mut() {
348        *x *= alpha;
349    }
350}
351
352/// Hamming distance between two packed bitstream slices.
353pub fn hamming_distance_dispatch(a: &[u64], b: &[u64]) -> u64 {
354    fused_xor_popcount_dispatch(a, b)
355}
356
357/// In-place softmax over an f64 slice (numerically stable).
358///
359/// Computes: subtract max → exp → normalize by sum.
360/// Uses SIMD dispatch for max-find, scaling, and sum reduction.
361pub fn softmax_inplace_f64_dispatch(scores: &mut [f64]) {
362    if scores.is_empty() {
363        return;
364    }
365
366    #[cfg(target_arch = "x86_64")]
367    {
368        if is_x86_feature_detected!("avx2") {
369            unsafe { avx2::softmax_inplace_f64_avx2(scores) };
370            return;
371        }
372    }
373
374    let max_val = max_f64_dispatch(scores);
375    let mut chunks = scores.chunks_exact_mut(4);
376    for c in chunks.by_ref() {
377        c[0] = (c[0] - max_val).exp();
378        c[1] = (c[1] - max_val).exp();
379        c[2] = (c[2] - max_val).exp();
380        c[3] = (c[3] - max_val).exp();
381    }
382    for s in chunks.into_remainder() {
383        *s = (*s - max_val).exp();
384    }
385
386    let exp_sum = sum_f64_dispatch(scores);
387    if exp_sum > 0.0 {
388        scale_f64_dispatch(1.0 / exp_sum, scores);
389    }
390}
391
392/// Fused encode+AND+popcount dispatch.
393///
394/// Delegates to the scalar-control implementation in `bitstream`,
395/// which already performs SIMD Bernoulli compare where available.
396pub fn encode_and_popcount_dispatch<R: Rng + ?Sized>(
397    weight_words: &[u64],
398    prob: f64,
399    length: usize,
400    rng: &mut R,
401) -> u64 {
402    crate::bitstream::encode_and_popcount(weight_words, prob, length, rng)
403}
404
405/// Batch compare 1024 bytes against threshold using best SIMD.
406pub fn bernoulli_compare_batch_1024(buf: &[u8], threshold: u8, out: &mut [u64]) {
407    #[cfg(target_arch = "x86_64")]
408    {
409        if is_x86_feature_detected!("avx512bw") {
410            return unsafe { avx512::bernoulli_compare_batch_avx512(buf, threshold, out) };
411        }
412        if is_x86_feature_detected!("avx2") {
413            return unsafe { avx2::bernoulli_compare_batch_avx2(buf, threshold, out) };
414        }
415    }
416
417    #[cfg(target_arch = "x86_64")]
418    {
419        // SSE2 fallback (available on all x86_64)
420        use core::arch::x86_64::*;
421        unsafe {
422            let v_thresh = _mm_set1_epi8(threshold as i8);
423            let bias = _mm_set1_epi8(i8::MIN);
424            let v_thresh_biased = _mm_xor_si128(v_thresh, bias);
425
426            for i in 0..16 {
427                let chunk = &buf[i * 64..(i + 1) * 64];
428                let mut word = 0_u64;
429                for j in 0..4 {
430                    let v = _mm_loadu_si128(chunk.as_ptr().add(j * 16) as *const __m128i);
431                    let v_biased = _mm_xor_si128(v, bias);
432                    let m = _mm_cmpgt_epi8(v_thresh_biased, v_biased);
433                    let mask = _mm_movemask_epi8(m) as u32;
434                    word |= (mask as u64) << (j * 16);
435                }
436                out[i] = word;
437            }
438        }
439    }
440
441    // Generic fallback: 16 scalar calls (non-x86_64 architectures)
442    #[cfg(not(target_arch = "x86_64"))]
443    for i in 0..16 {
444        out[i] =
445            crate::bitstream::simd_bernoulli_compare_exposed(&buf[i * 64..(i + 1) * 64], threshold);
446    }
447}