sc_neurocore_engine/simd/
neon.rs1#[cfg(target_arch = "aarch64")]
9use core::arch::aarch64::*;
10
/// Counts the set bits in `data` using NEON, two 64-bit words per iteration.
///
/// Each pair of words is loaded as sixteen bytes, per-byte bit counts are
/// taken with `vcntq_u8`, and the counts are widened pairwise
/// (u8 -> u16 -> u32 -> u64) before the two 64-bit lanes are added to the
/// running total. A trailing unpaired word, if any, is delegated to
/// `crate::bitstream::popcount_words_portable`.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature that this function is compiled with.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
    let mut pairs = data.chunks_exact(2);
    let mut vector_bits = 0_u64;

    for pair in pairs.by_ref() {
        // Two u64 words are exactly the 16 bytes one 128-bit load reads.
        let bytes = vld1q_u8(pair.as_ptr() as *const u8);
        // Per-byte popcount, then pairwise widening adds down to two u64 lanes.
        let per_lane = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(vcntq_u8(bytes))));
        vector_bits += vgetq_lane_u64::<0>(per_lane) + vgetq_lane_u64::<1>(per_lane);
    }

    // At most one word is left over after pairing.
    let tail_bits = crate::bitstream::popcount_words_portable(pairs.remainder());
    vector_bits + tail_bits
}
32
33#[cfg(not(target_arch = "aarch64"))]
34pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
39 crate::bitstream::popcount_words_portable(data)
40}
41
/// Computes the dot product of `a` and `b` with NEON fused multiply-add.
///
/// Only the first `min(a.len(), b.len())` elements of each slice take part.
/// Elements are processed two at a time into a two-lane accumulator; the two
/// lanes are then added together and any single leftover element pair is
/// folded in with a scalar multiply-add.
///
/// Returns `0.0` for empty input.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature that this function is compiled with.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    let (lhs, rhs) = (&a[..n], &b[..n]);
    let mut lhs_pairs = lhs.chunks_exact(2);
    let mut rhs_pairs = rhs.chunks_exact(2);

    let mut lanes = vdupq_n_f64(0.0);
    for (pa, pb) in lhs_pairs.by_ref().zip(rhs_pairs.by_ref()) {
        // lanes += pa * pb, element-wise and fused.
        lanes = vfmaq_f64(lanes, vld1q_f64(pa.as_ptr()), vld1q_f64(pb.as_ptr()));
    }

    let horizontal = vgetq_lane_f64::<0>(lanes) + vgetq_lane_f64::<1>(lanes);

    // Both slices were truncated to the same length, so the remainders match.
    lhs_pairs
        .remainder()
        .iter()
        .zip(rhs_pairs.remainder())
        .fold(horizontal, |sum, (&x, &y)| sum + x * y)
}
68
/// Returns the largest element of `a`, or `f64::NEG_INFINITY` when `a` is
/// empty.
///
/// NaN elements are ignored, matching `f64::max` and the portable fallback
/// of this function. The vector loop therefore uses `vmaxnmq_f64` (FMAXNM,
/// which returns the numeric operand when the other is a quiet NaN) instead
/// of `vmaxq_f64` (FMAX, which propagates NaN). With the propagating form a
/// NaN stuck in its lane for the rest of the loop and every later value in
/// that lane was lost, so e.g. `[NaN, 1.0, 5.0, 3.0]` yielded `3.0`.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature that this function is compiled with.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }

    let mut lane_max = vdupq_n_f64(f64::NEG_INFINITY);
    let mut pairs = a.chunks_exact(2);

    for pair in pairs.by_ref() {
        // NaN-ignoring maximum keeps each lane numeric once it has seen a number.
        lane_max = vmaxnmq_f64(lane_max, vld1q_f64(pair.as_ptr()));
    }

    // Reduce the two lanes, then fold in the unpaired trailing element (if any).
    let across = f64::max(
        vgetq_lane_f64::<0>(lane_max),
        vgetq_lane_f64::<1>(lane_max),
    );
    pairs.remainder().iter().fold(across, |m, &v| m.max(v))
}
93
/// Sums the elements of `a` using a two-lane NEON accumulator.
///
/// Even-indexed elements of each pair accumulate in lane 0 and odd-indexed
/// ones in lane 1; the lanes are added together at the end, followed by the
/// single unpaired trailing element when `a.len()` is odd.
///
/// Returns `0.0` for empty input.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature that this function is compiled with.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    let mut pairs = a.chunks_exact(2);
    let mut lanes = vdupq_n_f64(0.0);

    for pair in pairs.by_ref() {
        lanes = vaddq_f64(lanes, vld1q_f64(pair.as_ptr()));
    }

    let partial = vgetq_lane_f64::<0>(lanes) + vgetq_lane_f64::<1>(lanes);
    pairs.remainder().iter().fold(partial, |sum, &x| sum + x)
}
115
/// Multiplies every element of `y` by `alpha` in place using NEON.
///
/// Elements are scaled two at a time with a vector multiply; a single
/// unpaired trailing element, if present, is scaled with a scalar multiply.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature that this function is compiled with.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    let factor = vdupq_n_f64(alpha);
    let mut pairs = y.chunks_exact_mut(2);

    for pair in pairs.by_ref() {
        // Load, scale and store back through the same two-element window.
        let slot = pair.as_mut_ptr();
        let product = vmulq_f64(vld1q_f64(slot), factor);
        vst1q_f64(slot, product);
    }

    pairs
        .into_remainder()
        .iter_mut()
        .for_each(|value| *value *= alpha);
}
136
/// Fallback for targets other than aarch64: scalar dot product of `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` element pairs contribute, because
/// zipping stops at the shorter slice. Products are summed in index order.
///
/// # Safety
///
/// This variant performs no unsafe operations itself; it has the same
/// `unsafe fn` signature as the aarch64 implementation of the same name.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>()
}
144
/// Fallback for targets other than aarch64: returns the largest element of
/// `a`, or `f64::NEG_INFINITY` when `a` is empty.
///
/// The running maximum is updated with `f64::max`, so NaN elements do not
/// replace a numeric maximum.
///
/// # Safety
///
/// This variant performs no unsafe operations itself; it has the same
/// `unsafe fn` signature as the aarch64 implementation of the same name.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &value in a {
        best = best.max(value);
    }
    best
}
151
/// Fallback for targets other than aarch64: sums the elements of `a` in
/// index order.
///
/// # Safety
///
/// This variant performs no unsafe operations itself; it has the same
/// `unsafe fn` signature as the aarch64 implementation of the same name.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    a.iter().copied().sum::<f64>()
}
158
/// Fallback for targets other than aarch64: multiplies every element of `y`
/// by `alpha` in place.
///
/// # Safety
///
/// This variant performs no unsafe operations itself; it has the same
/// `unsafe fn` signature as the aarch64 implementation of the same name.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|value| *value *= alpha);
}
167
#[cfg(test)]
mod tests {
    use super::*;

    // The functions under test are `unsafe fn`. On non-aarch64 targets they
    // perform no unsafe operations; on aarch64 these tests assume the host
    // CPU supports NEON (TODO confirm for the CI target).

    /// Absolute tolerance for floating-point comparisons in this module.
    const TOL: f64 = 1e-10;

    /// Returns `true` when `actual` lies within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_popcount_empty() {
        let bits = unsafe { popcount_neon(&[]) };
        assert_eq!(bits, 0);
    }

    #[test]
    fn test_popcount_known_values() {
        // Single-word inputs only reach the scalar remainder path.
        assert_eq!(unsafe { popcount_neon(&[u64::MAX]) }, 64);
        assert_eq!(unsafe { popcount_neon(&[0]) }, 0);
        assert_eq!(unsafe { popcount_neon(&[1]) }, 1);
        assert_eq!(unsafe { popcount_neon(&[0b1010_1010]) }, 4);
    }

    #[test]
    fn test_popcount_multiple_words() {
        // One full pair (128 bits set) plus a trailing word with one bit set.
        let data = [u64::MAX, u64::MAX, 1];
        let bits = unsafe { popcount_neon(&data) };
        assert_eq!(bits, 129);
    }

    #[test]
    fn test_dot_f64_simple() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = unsafe { dot_f64_neon(&a, &b) };
        assert!(close(result, 32.0));
    }

    #[test]
    fn test_dot_f64_empty() {
        let result = unsafe { dot_f64_neon(&[], &[]) };
        assert!(close(result, 0.0));
    }

    #[test]
    fn test_dot_f64_mismatched_length() {
        // Only the first two elements of `a` pair up with `b`.
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [1.0, 1.0];
        let result = unsafe { dot_f64_neon(&a, &b) };
        assert!(close(result, 3.0));
    }

    #[test]
    fn test_max_f64() {
        let a = [1.0, 5.0, 3.0, 2.0, 4.0];
        let largest = unsafe { max_f64_neon(&a) };
        assert!(close(largest, 5.0));
    }

    #[test]
    fn test_max_f64_empty() {
        let largest = unsafe { max_f64_neon(&[]) };
        assert!(largest == f64::NEG_INFINITY);
    }

    #[test]
    fn test_max_f64_negative() {
        let a = [-5.0, -1.0, -3.0];
        let largest = unsafe { max_f64_neon(&a) };
        assert!(close(largest, -1.0));
    }

    #[test]
    fn test_sum_f64() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0];
        let total = unsafe { sum_f64_neon(&a) };
        assert!(close(total, 15.0));
    }

    #[test]
    fn test_sum_f64_empty() {
        let total = unsafe { sum_f64_neon(&[]) };
        assert!(close(total, 0.0));
    }

    #[test]
    fn test_scale_f64() {
        let mut y = [1.0, 2.0, 3.0, 4.0, 5.0];
        unsafe { scale_f64_neon(2.0, &mut y) };
        // First element comes from the paired path, last from the remainder.
        assert!(close(y[0], 2.0));
        assert!(close(y[4], 10.0));
    }

    #[test]
    fn test_scale_f64_zero() {
        let mut y = [1.0, 2.0, 3.0];
        unsafe { scale_f64_neon(0.0, &mut y) };
        assert!(y.iter().all(|&v| v == 0.0));
    }
}