// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — SIMD Popcount Dispatch

//! # SIMD Popcount Dispatch
//!
//! Runtime CPU-feature dispatch for packed-bit popcount kernels.
//! Supported backends: AVX-512, AVX2, ARM NEON, ARM SVE, RISC-V RVV.

use rand::Rng;

pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod rvv;
pub mod sve;
21/// Pack u8 bits into u64 words using the best available SIMD path.
22pub fn pack_dispatch(bits: &[u8]) -> crate::bitstream::BitStreamTensor {
23    let length = bits.len();
24
25    #[cfg(target_arch = "x86_64")]
26    {
27        if is_x86_feature_detected!("avx512bw") {
28            // SAFETY: Guarded by runtime feature detection.
29            let data = unsafe { avx512::pack_avx512(bits) };
30            return crate::bitstream::BitStreamTensor { data, length };
31        }
32        if is_x86_feature_detected!("avx2") {
33            // SAFETY: Guarded by runtime feature detection.
34            let data = unsafe { avx2::pack_avx2(bits) };
35            return crate::bitstream::BitStreamTensor { data, length };
36        }
37    }
38
39    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
40    {
41        // SAFETY: SVE target feature is compile-time guaranteed.
42        let data = unsafe { sve::pack_sve(bits) };
43        return crate::bitstream::BitStreamTensor { data, length };
44    }
45
46    crate::bitstream::pack_fast(bits)
47}
48
49/// Count set bits in packed `u64` words using the best available SIMD path.
50pub fn popcount_dispatch(data: &[u64]) -> u64 {
51    #[cfg(target_arch = "x86_64")]
52    {
53        if is_x86_feature_detected!("avx512vpopcntdq") {
54            // SAFETY: Guarded by runtime feature detection.
55            return unsafe { avx512::popcount_avx512(data) };
56        }
57        if is_x86_feature_detected!("avx2") {
58            // SAFETY: Guarded by runtime feature detection.
59            return unsafe { avx2::popcount_avx2(data) };
60        }
61    }
62
63    #[cfg(target_arch = "aarch64")]
64    {
65        #[cfg(target_feature = "sve")]
66        {
67            // SAFETY: SVE target feature is compile-time guaranteed.
68            return unsafe { sve::popcount_sve(data) };
69        }
70        #[cfg(not(target_feature = "sve"))]
71        {
72            // SAFETY: NEON is baseline on aarch64 targets.
73            return unsafe { neon::popcount_neon(data) };
74        }
75    }
76
77    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
78    {
79        // SAFETY: RVV target feature is compile-time guaranteed.
80        return unsafe { rvv::popcount_rvv(data) };
81    }
82
83    crate::bitstream::popcount_words_portable(data)
84}
85
86/// Fused AND+popcount dispatch using the best available SIMD path.
87pub fn fused_and_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
88    let len = a.len().min(b.len());
89    let a = &a[..len];
90    let b = &b[..len];
91
92    #[cfg(target_arch = "x86_64")]
93    {
94        if is_x86_feature_detected!("avx512vpopcntdq") {
95            // SAFETY: Guarded by runtime feature detection.
96            return unsafe { avx512::fused_and_popcount_avx512(a, b) };
97        }
98        if is_x86_feature_detected!("avx2") {
99            // SAFETY: Guarded by runtime feature detection.
100            return unsafe { avx2::fused_and_popcount_avx2(a, b) };
101        }
102    }
103
104    #[cfg(target_arch = "aarch64")]
105    {
106        #[cfg(target_feature = "sve")]
107        {
108            return unsafe { sve::fused_and_popcount_sve(a, b) };
109        }
110    }
111
112    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
113    {
114        return unsafe { rvv::fused_and_popcount_rvv(a, b) };
115    }
116
117    a.iter()
118        .zip(b.iter())
119        .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
120        .sum()
121}
122
123/// Fused XOR+popcount dispatch using the best available SIMD path.
124pub fn fused_xor_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
125    let len = a.len().min(b.len());
126    let a = &a[..len];
127    let b = &b[..len];
128
129    #[cfg(target_arch = "x86_64")]
130    {
131        if is_x86_feature_detected!("avx512vpopcntdq") {
132            // SAFETY: Guarded by runtime feature detection.
133            return unsafe { avx512::fused_xor_popcount_avx512(a, b) };
134        }
135        if is_x86_feature_detected!("avx2") {
136            // SAFETY: Guarded by runtime feature detection.
137            return unsafe { avx2::fused_xor_popcount_avx2(a, b) };
138        }
139    }
140
141    #[cfg(target_arch = "aarch64")]
142    {
143        #[cfg(target_feature = "sve")]
144        {
145            return unsafe { sve::fused_xor_popcount_sve(a, b) };
146        }
147    }
148
149    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
150    {
151        return unsafe { rvv::fused_xor_popcount_rvv(a, b) };
152    }
153
154    a.iter()
155        .zip(b.iter())
156        .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
157        .sum()
158}

// --- f64 dispatch functions ---

/// Dot product of two f64 slices using the best available SIMD path.
///
/// Dispatch order: AVX-512F → AVX2+FMA (runtime-detected on x86_64) →
/// NEON (aarch64) → scalar. The scalar fallback truncates to the shorter
/// slice; the SIMD kernels presumably handle mismatched lengths
/// themselves — TODO confirm against the kernel implementations.
pub fn dot_f64_dispatch(a: &[f64], b: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            // SAFETY: Guarded by runtime feature detection.
            return unsafe { avx512::dot_f64_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: Guarded by runtime feature detection.
            return unsafe { avx2::dot_f64_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64 targets.
        return unsafe { neon::dot_f64_neon(a, b) };
    }

    // Scalar fallback; unreachable on aarch64 (the NEON arm returns).
    #[allow(unreachable_code)]
    {
        let len = a.len().min(b.len());
        a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
    }
}

/// Maximum of f64 slice using the best available SIMD path.
///
/// Dispatch order: AVX-512F → AVX2 (runtime-detected on x86_64) →
/// NEON (aarch64) → scalar fold. The scalar fallback yields
/// `f64::NEG_INFINITY` for an empty slice; the SIMD kernels presumably
/// match that convention — TODO confirm.
pub fn max_f64_dispatch(a: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            // SAFETY: Guarded by runtime feature detection.
            return unsafe { avx512::max_f64_avx512(a) };
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: Guarded by runtime feature detection.
            return unsafe { avx2::max_f64_avx2(a) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64 targets.
        return unsafe { neon::max_f64_neon(a) };
    }

    // Scalar fallback; unreachable on aarch64 (the NEON arm returns).
    #[allow(unreachable_code)]
    a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}

/// Sum of f64 slice using the best available SIMD path.
///
/// Dispatch order: AVX-512F → AVX2 (runtime-detected on x86_64) →
/// NEON (aarch64) → scalar iterator sum. Note: SIMD and scalar paths may
/// differ in floating-point rounding due to different summation order.
pub fn sum_f64_dispatch(a: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            // SAFETY: Guarded by runtime feature detection.
            return unsafe { avx512::sum_f64_avx512(a) };
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: Guarded by runtime feature detection.
            return unsafe { avx2::sum_f64_avx2(a) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64 targets.
        return unsafe { neon::sum_f64_neon(a) };
    }

    // Scalar fallback; unreachable on aarch64 (the NEON arm returns).
    #[allow(unreachable_code)]
    a.iter().sum()
}

/// Scale f64 slice in-place: y[i] *= alpha, using the best available SIMD path.
///
/// Dispatch order: AVX-512F → AVX2 (runtime-detected on x86_64) →
/// NEON (aarch64) → scalar loop. Mutates `y` directly; no value returned.
pub fn scale_f64_dispatch(alpha: f64, y: &mut [f64]) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            // SAFETY: Guarded by runtime feature detection.
            unsafe { avx512::scale_f64_avx512(alpha, y) };
            return;
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: Guarded by runtime feature detection.
            unsafe { avx2::scale_f64_avx2(alpha, y) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64 targets.
        unsafe { neon::scale_f64_neon(alpha, y) };
        return;
    }

    // Scalar fallback; unreachable on aarch64 (the NEON arm returns early).
    #[allow(unreachable_code)]
    for v in y.iter_mut() {
        *v *= alpha;
    }
}

/// Hamming distance between two packed bitstream slices.
///
/// Equals popcount(a XOR b) over the common word prefix, so this simply
/// reuses the fused XOR+popcount dispatch (which truncates both inputs
/// to the shorter length before combining).
pub fn hamming_distance_dispatch(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_dispatch(a, b)
}

259/// In-place softmax over an f64 slice (numerically stable).
260///
261/// Computes: subtract max → exp → normalize by sum.
262/// Uses SIMD dispatch for max-find, scaling, and sum reduction.
263pub fn softmax_inplace_f64_dispatch(scores: &mut [f64]) {
264    if scores.is_empty() {
265        return;
266    }
267    let max_val = max_f64_dispatch(scores);
268    for s in scores.iter_mut() {
269        *s = (*s - max_val).exp();
270    }
271    let exp_sum = sum_f64_dispatch(scores);
272    if exp_sum > 0.0 {
273        scale_f64_dispatch(1.0 / exp_sum, scores);
274    }
275}
276
/// Fused encode+AND+popcount dispatch.
///
/// Delegates to the scalar-control implementation in `bitstream`,
/// which already performs SIMD Bernoulli compare where available.
///
/// * `weight_words` — packed weight bits combined with the encoded stream
///   (presumably ANDed; exact semantics live in
///   `bitstream::encode_and_popcount` — verify there).
/// * `prob` — Bernoulli probability driving the stochastic encoding.
/// * `length` — number of bits to encode.
/// * `rng` — randomness source for the encode step.
pub fn encode_and_popcount_dispatch<R: Rng + ?Sized>(
    weight_words: &[u64],
    prob: f64,
    length: usize,
    rng: &mut R,
) -> u64 {
    crate::bitstream::encode_and_popcount(weight_words, prob, length, rng)
}