// File: sc_neurocore_engine/simd/neon.rs

// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — ARM NEON SIMD kernels (popcount and f64 vector ops)

8#[cfg(target_arch = "aarch64")]
9use core::arch::aarch64::*;
10
/// Counts the set bits across a slice of 64-bit words using ARM NEON.
///
/// Words are consumed two at a time as one 128-bit vector: `vcntq_u8`
/// produces a bit count per byte and `vaddlvq_u8` adds the 16 byte counts
/// into one widened total. A trailing odd word, if present, is counted by
/// `crate::bitstream::popcount_words_portable`.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
    let pairs = data.chunks_exact(2);
    // `remainder` borrows the underlying slice, not the iterator, so it stays
    // valid after `pairs` is consumed below.
    let tail = pairs.remainder();

    let vector_bits: u64 = pairs
        .map(|pair| {
            let bytes = vld1q_u8(pair.as_ptr().cast::<u8>());
            u64::from(vaddlvq_u8(vcntq_u8(bytes)))
        })
        .sum();

    vector_bits + crate::bitstream::popcount_words_portable(tail)
}
32
33#[cfg(not(target_arch = "aarch64"))]
34/// Fallback popcount when NEON is unavailable on this architecture.
35///
36/// # Safety
37/// This function is marked unsafe for API parity with the NEON variant.
38pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
39    crate::bitstream::popcount_words_portable(data)
40}

// --- f64 SIMD operations (NEON: 2-wide f64 on AArch64; scalar fallbacks elsewhere) ---

/// Computes the dot product `Σ a[i] * b[i]` of two `f64` slices using NEON.
///
/// Only the first `min(a.len(), b.len())` elements participate. Pairs of
/// elements are folded into a two-lane accumulator with fused multiply-add
/// (`vfmaq_f64`). The lanes are then added together, and a trailing odd
/// element is multiplied and added in scalar code.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    let (lhs, rhs) = (&a[..n], &b[..n]);

    // Two-lane running sum of products.
    let mut lanes = vdupq_n_f64(0.0);
    for (pa, pb) in lhs.chunks_exact(2).zip(rhs.chunks_exact(2)) {
        lanes = vfmaq_f64(lanes, vld1q_f64(pa.as_ptr()), vld1q_f64(pb.as_ptr()));
    }

    // Horizontal reduction (lane 0 + lane 1), then the odd trailing element.
    let mut total = vaddvq_f64(lanes);
    if n % 2 == 1 {
        total += lhs[n - 1] * rhs[n - 1];
    }
    total
}
68
/// Returns the maximum element of an `f64` slice using NEON.
///
/// NaN elements are skipped, matching `f64::max` and the portable fallback
/// (`fold(f64::NEG_INFINITY, f64::max)`). An empty slice, or one containing
/// only NaNs, therefore yields `f64::NEG_INFINITY`.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }

    // `vmaxnmq_f64` (FMAXNM, IEEE 754 maxNum) returns the numeric operand when
    // the other is NaN. The previous `vmaxq_f64` (FMAX) propagated the NaN into
    // the lane and discarded the maximum accumulated so far. That made the
    // result depend on the NaN's position: [9, 1, NaN, 2] returned 2.0, and it
    // diverged from the scalar fallback.
    let mut vmax = vdupq_n_f64(f64::NEG_INFINITY);
    let mut chunks = a.chunks_exact(2);

    for chunk in chunks.by_ref() {
        let va = vld1q_f64(chunk.as_ptr());
        vmax = vmaxnmq_f64(vmax, va);
    }

    // `f64::max` also ignores NaN, so the lane reduction and the scalar tail
    // follow the same rule as the vector loop.
    let mut m = f64::max(vgetq_lane_f64(vmax, 0), vgetq_lane_f64(vmax, 1));
    for &v in chunks.remainder() {
        m = m.max(v);
    }
    m
}
93
/// Adds up every element of an `f64` slice using NEON.
///
/// Elements are accumulated pairwise into a two-lane vector (even indices in
/// lane 0, odd indices in lane 1). The two lanes are then combined, and a
/// trailing odd element is added last. An empty slice sums to `0.0`.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    let pairs = a.chunks_exact(2);
    let leftover = pairs.remainder();

    let mut lanes = vdupq_n_f64(0.0);
    for pair in pairs {
        lanes = vaddq_f64(lanes, vld1q_f64(pair.as_ptr()));
    }

    leftover
        .iter()
        .fold(vaddvq_f64(lanes), |running, &x| running + x)
}
115
/// Multiplies every element of `y` by `alpha` in place using NEON.
///
/// The even-length prefix is processed two elements at a time: load, scale
/// with `vmulq_f64`, then store back. A trailing odd element is scaled with a
/// scalar multiply.
///
/// # Safety
/// Caller must ensure the current CPU supports `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    let factor = vdupq_n_f64(alpha);
    let paired_len = y.len() & !1;
    let (body, tail) = y.split_at_mut(paired_len);

    for pair in body.chunks_exact_mut(2) {
        let ptr = pair.as_mut_ptr();
        vst1q_f64(ptr, vmulq_f64(vld1q_f64(ptr), factor));
    }

    tail.iter_mut().for_each(|v| *v *= alpha);
}
136
/// Portable dot product used when the target architecture is not AArch64.
///
/// Pairs elements up to the shorter slice's length (extra elements in the
/// longer slice are ignored), multiplies each pair, and sums the products
/// in order.
///
/// # Safety
/// Performs no unsafe operations; `unsafe` is kept only so the signature
/// matches the NEON variant.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter input, which is equivalent to truncating both
    // slices to `min(a.len(), b.len())`.
    let products = a.iter().zip(b.iter()).map(|(x, y)| x * y);
    products.sum()
}
144
/// Portable maximum used when the target architecture is not AArch64.
///
/// Scans left to right with `f64::max`, so NaN elements are skipped. An empty
/// or all-NaN slice yields `f64::NEG_INFINITY`.
///
/// # Safety
/// Performs no unsafe operations; `unsafe` is kept only so the signature
/// matches the NEON variant.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &x in a {
        best = best.max(x);
    }
    best
}
151
/// Portable sum used when the target architecture is not AArch64.
///
/// Adds the elements in slice order using the standard library's `f64`
/// summation.
///
/// # Safety
/// Performs no unsafe operations; `unsafe` is kept only so the signature
/// matches the NEON variant.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    a.iter().copied().sum::<f64>()
}
158
/// Portable in-place scaling (`y[i] *= alpha`) used when the target
/// architecture is not AArch64.
///
/// # Safety
/// Performs no unsafe operations; `unsafe` is kept only so the signature
/// matches the NEON variant.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|elem| *elem *= alpha);
}
167
// Unit tests for the NEON kernels (or their portable fallbacks on other
// architectures). The inputs use lengths that exercise the 2-wide vector
// path both with and without a scalar remainder.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_popcount_empty() {
        assert_eq!(unsafe { popcount_neon(&[]) }, 0);
    }

    #[test]
    fn test_popcount_known_values() {
        // 0xFFFF_FFFF_FFFF_FFFF has 64 set bits
        assert_eq!(unsafe { popcount_neon(&[u64::MAX]) }, 64);
        assert_eq!(unsafe { popcount_neon(&[0]) }, 0);
        assert_eq!(unsafe { popcount_neon(&[1]) }, 1);
        assert_eq!(unsafe { popcount_neon(&[0b1010_1010]) }, 4);
    }

    #[test]
    fn test_popcount_multiple_words() {
        let data = [u64::MAX, u64::MAX, 1];
        assert_eq!(unsafe { popcount_neon(&data) }, 129); // 64+64+1
    }

    #[test]
    fn test_popcount_even_length_matches_count_ones() {
        // An even length takes only the vector path (no scalar remainder), and
        // several chunks exercise the per-chunk accumulation.
        let data = [u64::MAX, 0x8000_0000_0000_0001, 0xF0F0, 0, 7, u64::MAX >> 1];
        let expected: u64 = data.iter().map(|w| u64::from(w.count_ones())).sum();
        assert_eq!(unsafe { popcount_neon(&data) }, expected);
    }

    #[test]
    fn test_dot_f64_simple() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = unsafe { dot_f64_neon(&a, &b) };
        assert!((result - 32.0).abs() < 1e-10); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_dot_f64_empty() {
        let result = unsafe { dot_f64_neon(&[], &[]) };
        assert!((result - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_dot_f64_mismatched_length() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [1.0, 1.0];
        let result = unsafe { dot_f64_neon(&a, &b) };
        assert!((result - 3.0).abs() < 1e-10); // 1*1 + 2*1
    }

    #[test]
    fn test_dot_f64_multiple_chunks() {
        // Lengths 4 and 5 cover the vector path with and without a remainder.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0];
        let b = [1.0; 5];
        assert!((unsafe { dot_f64_neon(&a[..4], &b[..4]) } - 10.0).abs() < 1e-10);
        assert!((unsafe { dot_f64_neon(&a, &b) } - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_max_f64() {
        let a = [1.0, 5.0, 3.0, 2.0, 4.0];
        assert!((unsafe { max_f64_neon(&a) } - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_max_f64_empty() {
        assert!(unsafe { max_f64_neon(&[]) } == f64::NEG_INFINITY);
    }

    #[test]
    fn test_max_f64_negative() {
        let a = [-5.0, -1.0, -3.0];
        assert!((unsafe { max_f64_neon(&a) } - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_max_f64_in_each_position() {
        // The maximum lands in vector lane 0, in lane 1, and in the scalar remainder.
        assert!(unsafe { max_f64_neon(&[7.0, 1.0, 2.0]) } == 7.0);
        assert!(unsafe { max_f64_neon(&[1.0, 7.0, 2.0]) } == 7.0);
        assert!(unsafe { max_f64_neon(&[1.0, 2.0, 7.0]) } == 7.0);
    }

    #[test]
    fn test_sum_f64() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!((unsafe { sum_f64_neon(&a) } - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_sum_f64_empty() {
        assert!((unsafe { sum_f64_neon(&[]) } - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_scale_f64() {
        let mut y = [1.0, 2.0, 3.0, 4.0, 5.0];
        unsafe { scale_f64_neon(2.0, &mut y) };
        // Check every element, not just the ends: indices 0..4 go through the
        // vector path and index 4 through the scalar remainder.
        assert_eq!(y, [2.0, 4.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_scale_f64_zero() {
        let mut y = [1.0, 2.0, 3.0];
        unsafe { scale_f64_neon(0.0, &mut y) };
        assert!(y.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_scale_f64_empty() {
        let mut y: [f64; 0] = [];
        unsafe { scale_f64_neon(3.0, &mut y) };
        assert!(y.is_empty());
    }
}