Skip to main content

sc_neurocore_engine/simd/
neon.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Neon
8
9#[cfg(target_arch = "aarch64")]
10use core::arch::aarch64::*;
11
/// Total number of set bits across all words of `data`, using ARM NEON.
///
/// Complete groups of eight words (64 bytes) are processed as four 128-bit
/// vectors; any words left over are counted by
/// `crate::bitstream::popcount_words_portable`.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
    // Two u64 words fill one 128-bit vector; four vectors form one block.
    const WORDS_PER_VECTOR: usize = 2;
    const VECTORS_PER_BLOCK: usize = 4;

    let mut blocks = data.chunks_exact(WORDS_PER_VECTOR * VECTORS_PER_BLOCK);
    let mut set_bits = 0_u64;

    for block in blocks.by_ref() {
        let base = block.as_ptr();
        for vector_idx in 0..VECTORS_PER_BLOCK {
            // Reads 16 bytes (two words) that lie entirely inside `block`.
            let bytes = vld1q_u8(base.add(vector_idx * WORDS_PER_VECTOR) as *const u8);
            // Per-byte bit counts, widened pairwise: u8 -> u16 -> u32 -> u64.
            let halves = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(vcntq_u8(bytes))));
            set_bits += vgetq_lane_u64(halves, 0) + vgetq_lane_u64(halves, 1);
        }
    }

    set_bits + crate::bitstream::popcount_words_portable(blocks.remainder())
}
56
57#[cfg(not(target_arch = "aarch64"))]
58/// Fallback popcount when NEON is unavailable on this architecture.
59///
60/// # Safety
61/// This function is marked unsafe for API parity with the NEON variant.
62pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
63    crate::bitstream::popcount_words_portable(data)
64}
65
66// --- f64 SIMD operations (NEON: 2-wide f64, AArch64 only) ---
67
/// Dot product of `a` and `b` using NEON fused multiply-add.
///
/// When the slices differ in length, only the first
/// `min(a.len(), b.len())` elements of each contribute. Empty input
/// yields `0.0`.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    let common = a.len().min(b.len());
    let mut lhs_blocks = a[..common].chunks_exact(8);
    let mut rhs_blocks = b[..common].chunks_exact(8);

    // One accumulator per 2-lane vector inside an 8-element block.
    let zero = vdupq_n_f64(0.0);
    let (mut acc0, mut acc1, mut acc2, mut acc3) = (zero, zero, zero, zero);

    for (x, y) in lhs_blocks.by_ref().zip(rhs_blocks.by_ref()) {
        let (px, py) = (x.as_ptr(), y.as_ptr());
        acc0 = vfmaq_f64(acc0, vld1q_f64(px), vld1q_f64(py));
        acc1 = vfmaq_f64(acc1, vld1q_f64(px.add(2)), vld1q_f64(py.add(2)));
        acc2 = vfmaq_f64(acc2, vld1q_f64(px.add(4)), vld1q_f64(py.add(4)));
        acc3 = vfmaq_f64(acc3, vld1q_f64(px.add(6)), vld1q_f64(py.add(6)));
    }

    // Pairwise fold: (acc0 + acc1) + (acc2 + acc3), then add the two lanes.
    let folded = vaddq_f64(vaddq_f64(acc0, acc1), vaddq_f64(acc2, acc3));
    let vector_part = vgetq_lane_f64(folded, 0) + vgetq_lane_f64(folded, 1);

    // Scalar tail for the (at most seven) elements after the last full block.
    lhs_blocks
        .remainder()
        .iter()
        .zip(rhs_blocks.remainder())
        .fold(vector_part, |acc, (&x, &y)| acc + x * y)
}
113
/// Maximum of an f64 slice using NEON.
///
/// NaN elements are ignored, matching `f64::max` and the non-AArch64
/// fallback of this function. An empty slice, or a slice that contains only
/// NaNs, yields `f64::NEG_INFINITY`.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    // Every accumulator starts at -inf, so an empty slice falls through all
    // the steps below and returns NEG_INFINITY without a dedicated branch.
    let floor = vdupq_n_f64(f64::NEG_INFINITY);
    let (mut vmax0, mut vmax1, mut vmax2, mut vmax3) = (floor, floor, floor, floor);

    let mut chunks = a.chunks_exact(8);
    for chunk in chunks.by_ref() {
        let p = chunk.as_ptr();
        // `vmaxnmq_f64` (FMAXNM) returns the numeric operand when the other
        // one is a quiet NaN. `vmaxq_f64` (FMAX) would return NaN instead,
        // poisoning the lane; the scalar `f64::max` reduction below would
        // then discard that lane together with every real value it had seen.
        vmax0 = vmaxnmq_f64(vmax0, vld1q_f64(p));
        vmax1 = vmaxnmq_f64(vmax1, vld1q_f64(p.add(2)));
        vmax2 = vmaxnmq_f64(vmax2, vld1q_f64(p.add(4)));
        vmax3 = vmaxnmq_f64(vmax3, vld1q_f64(p.add(6)));
    }

    let folded = vmaxnmq_f64(vmaxnmq_f64(vmax0, vmax1), vmaxnmq_f64(vmax2, vmax3));
    let head = f64::max(vgetq_lane_f64(folded, 0), vgetq_lane_f64(folded, 1));

    // Scalar tail for the (at most seven) elements after the last full chunk.
    chunks.remainder().iter().copied().fold(head, f64::max)
}
147
/// Sum of all elements of `a` using NEON.
///
/// Elements are accumulated in four 2-lane partial sums per 8-element block,
/// so the addition order differs from a plain left-to-right sum. An empty
/// slice yields `0.0`.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    let zero = vdupq_n_f64(0.0);
    let (mut part0, mut part1, mut part2, mut part3) = (zero, zero, zero, zero);

    let mut blocks = a.chunks_exact(8);
    for block in blocks.by_ref() {
        let p = block.as_ptr();
        part0 = vaddq_f64(part0, vld1q_f64(p));
        part1 = vaddq_f64(part1, vld1q_f64(p.add(2)));
        part2 = vaddq_f64(part2, vld1q_f64(p.add(4)));
        part3 = vaddq_f64(part3, vld1q_f64(p.add(6)));
    }

    // Pairwise fold: (part0 + part1) + (part2 + part3), then add both lanes.
    let folded = vaddq_f64(vaddq_f64(part0, part1), vaddq_f64(part2, part3));
    let vector_part = vgetq_lane_f64(folded, 0) + vgetq_lane_f64(folded, 1);

    // Scalar tail for the (at most seven) elements after the last full block.
    blocks
        .remainder()
        .iter()
        .fold(vector_part, |acc, &v| acc + v)
}
178
/// Multiply every element of `y` by `alpha` in place, using NEON.
///
/// Complete groups of eight elements are processed as four 2-lane vectors;
/// the remaining elements are scaled one at a time.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    let factor = vdupq_n_f64(alpha);
    let mut blocks = y.chunks_exact_mut(8);

    for block in blocks.by_ref() {
        let base = block.as_mut_ptr();
        for pair in 0..4 {
            // Load, scale and store two adjacent elements through one pointer.
            let lanes = base.add(2 * pair);
            vst1q_f64(lanes, vmulq_f64(vld1q_f64(lanes), factor));
        }
    }

    blocks
        .into_remainder()
        .iter_mut()
        .for_each(|slot| *slot *= alpha);
}
212
/// Scalar dot product used on targets other than AArch64.
///
/// Only the first `min(a.len(), b.len())` elements of each slice contribute.
///
/// # Safety
/// Fallback for non-AArch64; unsafe for API parity.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter slice, which yields exactly the common prefix.
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
220
/// Scalar maximum used on targets other than AArch64.
///
/// Starts from `f64::NEG_INFINITY` and combines elements with `f64::max`, so
/// an empty slice yields `f64::NEG_INFINITY`.
///
/// # Safety
/// Fallback for non-AArch64; unsafe for API parity.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &value in a {
        best = best.max(value);
    }
    best
}
227
/// Scalar sum used on targets other than AArch64.
///
/// # Safety
/// Fallback for non-AArch64; unsafe for API parity.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    let total: f64 = a.iter().copied().sum();
    total
}
234
/// Scalar in-place scaling (`y[i] *= alpha`) used on targets other than
/// AArch64.
///
/// # Safety
/// Fallback for non-AArch64; unsafe for API parity.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|slot| *slot *= alpha);
}
243
/// Unit tests for the NEON kernels and their portable fallbacks.
///
/// The vector loops consume eight elements per iteration, so alongside the
/// small cases there are inputs longer than eight elements that exercise
/// both the vector loop and the scalar tail.
#[cfg(test)]
mod tests {
    use super::*;

    /// Length that covers two full 8-element blocks plus a 3-element tail.
    const LONG_LEN: usize = 19;

    /// Builds `[1.0, 2.0, ..., LONG_LEN as f64]`.
    fn ramp() -> Vec<f64> {
        (1..=LONG_LEN).map(|i| i as f64).collect()
    }

    #[test]
    fn test_popcount_empty() {
        assert_eq!(unsafe { popcount_neon(&[]) }, 0);
    }

    #[test]
    fn test_popcount_known_values() {
        // 0xFFFF_FFFF_FFFF_FFFF has 64 set bits
        assert_eq!(unsafe { popcount_neon(&[u64::MAX]) }, 64);
        assert_eq!(unsafe { popcount_neon(&[0]) }, 0);
        assert_eq!(unsafe { popcount_neon(&[1]) }, 1);
        assert_eq!(unsafe { popcount_neon(&[0b1010_1010]) }, 4);
    }

    #[test]
    fn test_popcount_multiple_words() {
        let data = [u64::MAX, u64::MAX, 1];
        assert_eq!(unsafe { popcount_neon(&data) }, 129); // 64+64+1
    }

    #[test]
    fn test_popcount_crosses_block_boundary() {
        // Word i has exactly i low bits set, so every word contributes a
        // different count and a misplaced load would change the total.
        let data: Vec<u64> = (0..LONG_LEN as u32).map(|i| (1_u64 << i) - 1).collect();
        let expected: u64 = data.iter().map(|w| u64::from(w.count_ones())).sum();
        assert_eq!(unsafe { popcount_neon(&data) }, expected);
    }

    #[test]
    fn test_dot_f64_simple() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = unsafe { dot_f64_neon(&a, &b) };
        assert!((result - 32.0).abs() < 1e-10); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_dot_f64_empty() {
        let result = unsafe { dot_f64_neon(&[], &[]) };
        assert!((result - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_dot_f64_mismatched_length() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [1.0, 1.0];
        let result = unsafe { dot_f64_neon(&a, &b) };
        assert!((result - 3.0).abs() < 1e-10); // 1*1 + 2*1
    }

    #[test]
    fn test_dot_f64_crosses_block_boundary() {
        let a = ramp();
        let b = vec![2.0; LONG_LEN];
        let result = unsafe { dot_f64_neon(&a, &b) };
        assert!((result - 380.0).abs() < 1e-10); // 2 * (1 + 2 + ... + 19)
    }

    #[test]
    fn test_max_f64() {
        let a = [1.0, 5.0, 3.0, 2.0, 4.0];
        assert!((unsafe { max_f64_neon(&a) } - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_max_f64_empty() {
        assert!(unsafe { max_f64_neon(&[]) } == f64::NEG_INFINITY);
    }

    #[test]
    fn test_max_f64_negative() {
        let a = [-5.0, -1.0, -3.0];
        assert!((unsafe { max_f64_neon(&a) } - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_max_f64_peak_at_every_position() {
        // Move the maximum through every index so each vector lane and each
        // tail position is responsible for the result at least once.
        for peak in 0..LONG_LEN {
            let mut a = ramp();
            a[peak] = 100.0;
            assert!((unsafe { max_f64_neon(&a) } - 100.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_sum_f64() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!((unsafe { sum_f64_neon(&a) } - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_sum_f64_empty() {
        assert!((unsafe { sum_f64_neon(&[]) } - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_sum_f64_crosses_block_boundary() {
        let a = ramp();
        assert!((unsafe { sum_f64_neon(&a) } - 190.0).abs() < 1e-10); // 1 + 2 + ... + 19
    }

    #[test]
    fn test_scale_f64() {
        let mut y = [1.0, 2.0, 3.0, 4.0, 5.0];
        unsafe { scale_f64_neon(2.0, &mut y) };
        assert!((y[0] - 2.0).abs() < 1e-10);
        assert!((y[4] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_scale_f64_zero() {
        let mut y = [1.0, 2.0, 3.0];
        unsafe { scale_f64_neon(0.0, &mut y) };
        assert!(y.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_scale_f64_crosses_block_boundary() {
        let mut y = ramp();
        unsafe { scale_f64_neon(3.0, &mut y) };
        for (i, &v) in y.iter().enumerate() {
            let expected = 3.0 * (i as f64 + 1.0);
            assert!((v - expected).abs() < 1e-10);
        }
    }
}