sc_neurocore_engine/simd/
neon.rs1#[cfg(target_arch = "aarch64")]
10use core::arch::aarch64::*;
11
/// Population count of a slice of `u64` words using the NEON `CNT`
/// instruction, processing 8 words (64 bytes) per iteration.
///
/// The tail of fewer than 8 words is delegated to the portable scalar
/// routine in `bitstream`.
///
/// # Safety
/// The caller must guarantee the `neon` target feature is available on
/// the executing CPU (always true on AArch64, but required by
/// `#[target_feature]`).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
    let mut total = 0_u64;
    let mut chunks = data.chunks_exact(8);

    for chunk in chunks.by_ref() {
        let p = chunk.as_ptr() as *const u8;

        // Four 16-byte loads cover the full 64-byte chunk.
        let c0 = vcntq_u8(vld1q_u8(p));
        let c1 = vcntq_u8(vld1q_u8(p.add(16)));
        let c2 = vcntq_u8(vld1q_u8(p.add(32)));
        let c3 = vcntq_u8(vld1q_u8(p.add(48)));

        // Every per-byte popcount is <= 8, so summing four count vectors
        // in u8 lanes cannot exceed 32 and cannot overflow. A single
        // widening horizontal add then reduces all 16 lanes at once
        // (max 16 * 32 = 512, which fits the u16 result) — this replaces
        // four independent u8->u16->u32->u64 widening chains and eight
        // lane extractions per iteration.
        let byte_sums = vaddq_u8(vaddq_u8(c0, c1), vaddq_u8(c2, c3));
        total += vaddlvq_u8(byte_sums) as u64;
    }

    // Remainder (< 8 words) handled by the portable implementation.
    total + crate::bitstream::popcount_words_portable(chunks.remainder())
}
56
57#[cfg(not(target_arch = "aarch64"))]
58pub unsafe fn popcount_neon(data: &[u64]) -> u64 {
63 crate::bitstream::popcount_words_portable(data)
64}
65
/// Dot product of the overlapping prefix of `a` and `b` (the shorter
/// length wins) using NEON fused multiply-add, 8 lanes per iteration.
///
/// Four independent accumulators are used to hide FMA latency; the
/// scalar tail handles the final `len % 8` elements.
///
/// # Safety
/// The caller must guarantee the `neon` target feature is available.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());

    let mut sums = [vdupq_n_f64(0.0); 4];
    let mut lhs = a[..n].chunks_exact(8);
    let mut rhs = b[..n].chunks_exact(8);

    for (ca, cb) in lhs.by_ref().zip(rhs.by_ref()) {
        let pa = ca.as_ptr();
        let pb = cb.as_ptr();
        for lane in 0..4 {
            let off = 2 * lane;
            sums[lane] = vfmaq_f64(sums[lane], vld1q_f64(pa.add(off)), vld1q_f64(pb.add(off)));
        }
    }

    // Pairwise fold preserves the original accumulation order:
    // (s0 + s1) + (s2 + s3).
    let pair01 = vaddq_f64(sums[0], sums[1]);
    let pair23 = vaddq_f64(sums[2], sums[3]);
    let folded = vaddq_f64(pair01, pair23);

    let mut dot = vgetq_lane_f64(folded, 0) + vgetq_lane_f64(folded, 1);
    for (&x, &y) in lhs.remainder().iter().zip(rhs.remainder()) {
        dot += x * y;
    }
    dot
}
113
/// Maximum element of `a` using NEON `FMAX`, 8 lanes per iteration.
/// Returns `f64::NEG_INFINITY` for an empty slice.
///
/// NOTE(review): `vmaxq_f64` propagates NaN whereas the scalar tail and
/// the portable fallback use `f64::max`, which ignores NaN — confirm
/// NaN inputs are out of scope for callers.
///
/// # Safety
/// The caller must guarantee the `neon` target feature is available.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }

    let mut maxes = [vdupq_n_f64(f64::NEG_INFINITY); 4];
    let mut blocks = a.chunks_exact(8);

    for block in blocks.by_ref() {
        let p = block.as_ptr();
        for lane in 0..4 {
            maxes[lane] = vmaxq_f64(maxes[lane], vld1q_f64(p.add(2 * lane)));
        }
    }

    let hi = vmaxq_f64(vmaxq_f64(maxes[0], maxes[1]), vmaxq_f64(maxes[2], maxes[3]));

    let mut best = f64::max(vgetq_lane_f64(hi, 0), vgetq_lane_f64(hi, 1));
    for &v in blocks.remainder() {
        best = best.max(v);
    }
    best
}
147
/// Sum of all elements of `a` using NEON vector adds, 8 lanes per
/// iteration, with a scalar tail for the final `len % 8` elements.
///
/// # Safety
/// The caller must guarantee the `neon` target feature is available.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    let mut accs = [vdupq_n_f64(0.0); 4];
    let mut blocks = a.chunks_exact(8);

    for block in blocks.by_ref() {
        let p = block.as_ptr();
        for lane in 0..4 {
            accs[lane] = vaddq_f64(accs[lane], vld1q_f64(p.add(2 * lane)));
        }
    }

    // Pairwise fold preserves the original accumulation order:
    // (a0 + a1) + (a2 + a3).
    let pair01 = vaddq_f64(accs[0], accs[1]);
    let pair23 = vaddq_f64(accs[2], accs[3]);
    let folded = vaddq_f64(pair01, pair23);

    let mut sum = vgetq_lane_f64(folded, 0) + vgetq_lane_f64(folded, 1);
    for &v in blocks.remainder() {
        sum += v;
    }
    sum
}
178
/// In-place scaling `y[i] *= alpha` using NEON multiplies, 8 lanes per
/// iteration, with a scalar tail for the final `len % 8` elements.
///
/// # Safety
/// The caller must guarantee the `neon` target feature is available.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    let factor = vdupq_n_f64(alpha);
    let mut blocks = y.chunks_exact_mut(8);

    for block in blocks.by_ref() {
        let p = block.as_mut_ptr();
        for lane in 0..4 {
            let off = 2 * lane;
            vst1q_f64(p.add(off), vmulq_f64(vld1q_f64(p.add(off)), factor));
        }
    }

    for tail in blocks.into_remainder() {
        *tail *= alpha;
    }
}
212
/// Portable dot product over the overlapping prefix of `a` and `b`;
/// `zip` naturally truncates to the shorter slice.
///
/// # Safety
/// Marked `unsafe` only so the signature matches the NEON path; the
/// body performs no unsafe operations.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn dot_f64_neon(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
220
/// Portable maximum; returns `f64::NEG_INFINITY` for an empty slice.
///
/// # Safety
/// Marked `unsafe` only so the signature matches the NEON path; the
/// body performs no unsafe operations.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn max_f64_neon(a: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &v in a {
        best = best.max(v);
    }
    best
}
227
/// Portable left-to-right sum of all elements (same order and therefore
/// the same floating-point result as `iter().sum()`).
///
/// # Safety
/// Marked `unsafe` only so the signature matches the NEON path; the
/// body performs no unsafe operations.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn sum_f64_neon(a: &[f64]) -> f64 {
    a.iter().fold(0.0, |acc, &v| acc + v)
}
234
/// Portable in-place scaling: `y[i] *= alpha` for every element.
///
/// # Safety
/// Marked `unsafe` only so the signature matches the NEON path; the
/// body performs no unsafe operations.
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn scale_f64_neon(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|v| *v *= alpha);
}
243
#[cfg(test)]
mod tests {
    use super::*;

    // --- popcount ---

    #[test]
    fn popcount_of_empty_slice_is_zero() {
        assert_eq!(unsafe { popcount_neon(&[]) }, 0);
    }

    #[test]
    fn popcount_matches_known_single_words() {
        assert_eq!(unsafe { popcount_neon(&[u64::MAX]) }, 64);
        assert_eq!(unsafe { popcount_neon(&[0]) }, 0);
        assert_eq!(unsafe { popcount_neon(&[1]) }, 1);
        assert_eq!(unsafe { popcount_neon(&[0b1010_1010]) }, 4);
    }

    #[test]
    fn popcount_sums_across_words() {
        let words = [u64::MAX, u64::MAX, 1];
        assert_eq!(unsafe { popcount_neon(&words) }, 129);
    }

    // --- dot product ---

    #[test]
    fn dot_product_of_small_vectors() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        let got = unsafe { dot_f64_neon(&lhs, &rhs) };
        assert!((got - 32.0).abs() < 1e-10);
    }

    #[test]
    fn dot_product_of_empty_slices_is_zero() {
        let got = unsafe { dot_f64_neon(&[], &[]) };
        assert!((got - 0.0).abs() < 1e-10);
    }

    #[test]
    fn dot_product_truncates_to_shorter_slice() {
        let lhs = [1.0, 2.0, 3.0, 4.0];
        let rhs = [1.0, 1.0];
        let got = unsafe { dot_f64_neon(&lhs, &rhs) };
        assert!((got - 3.0).abs() < 1e-10);
    }

    // --- max ---

    #[test]
    fn max_finds_largest_element() {
        let xs = [1.0, 5.0, 3.0, 2.0, 4.0];
        assert!((unsafe { max_f64_neon(&xs) } - 5.0).abs() < 1e-10);
    }

    #[test]
    fn max_of_empty_slice_is_negative_infinity() {
        assert!(unsafe { max_f64_neon(&[]) } == f64::NEG_INFINITY);
    }

    #[test]
    fn max_works_with_all_negative_values() {
        let xs = [-5.0, -1.0, -3.0];
        assert!((unsafe { max_f64_neon(&xs) } - (-1.0)).abs() < 1e-10);
    }

    // --- sum ---

    #[test]
    fn sum_adds_all_elements() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!((unsafe { sum_f64_neon(&xs) } - 15.0).abs() < 1e-10);
    }

    #[test]
    fn sum_of_empty_slice_is_zero() {
        assert!((unsafe { sum_f64_neon(&[]) } - 0.0).abs() < 1e-10);
    }

    // --- scale ---

    #[test]
    fn scale_multiplies_every_element() {
        let mut ys = [1.0, 2.0, 3.0, 4.0, 5.0];
        unsafe { scale_f64_neon(2.0, &mut ys) };
        assert!((ys[0] - 2.0).abs() < 1e-10);
        assert!((ys[4] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn scale_by_zero_clears_every_element() {
        let mut ys = [1.0, 2.0, 3.0];
        unsafe { scale_f64_neon(0.0, &mut ys) };
        assert!(ys.iter().all(|&v| v == 0.0));
    }
}