sc_neurocore_engine/simd/mod.rs

use rand::Rng;

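// Runtime/compile-time SIMD dispatch layer for the bit-stream kernels: x86_64
// paths are selected at runtime via `is_x86_feature_detected!`, AArch64 and
// RISC-V paths at compile time via `target_feature`, with portable scalar
// fallbacks otherwise.
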
pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod rvv;
pub mod sve;

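/// Packs `bits` into a `BitStreamTensor`, using the widest available SIMD
/// packer and falling back to `crate::bitstream::pack_fast` otherwise.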
pub fn pack_dispatch(bits: &[u8]) -> crate::bitstream::BitStreamTensor {
    let length = bits.len();

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512bw") {
            let data = unsafe { avx512::pack_avx512(bits) };
            return crate::bitstream::BitStreamTensor { data, length };
        }
        if is_x86_feature_detected!("avx2") {
            let data = unsafe { avx2::pack_avx2(bits) };
            return crate::bitstream::BitStreamTensor { data, length };
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        let data = unsafe { sve::pack_sve(bits) };
        return crate::bitstream::BitStreamTensor { data, length };
    }

    crate::bitstream::pack_fast(bits)
}

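/// Total number of set bits in `data`, dispatched to AVX-512 VPOPCNTDQ, AVX2,
/// SVE, NEON, or RVV when available, with a portable word-by-word fallback.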
pub fn popcount_dispatch(data: &[u64]) -> u64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512vpopcntdq") {
            return unsafe { avx512::popcount_avx512(data) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::popcount_avx2(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        #[cfg(target_feature = "sve")]
        {
            return unsafe { sve::popcount_sve(data) };
        }
        #[cfg(not(target_feature = "sve"))]
        {
            return unsafe { neon::popcount_neon(data) };
        }
    }

    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        return unsafe { rvv::popcount_rvv(data) };
    }

    crate::bitstream::popcount_words_portable(data)
}

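/// Computes `popcount(a & b)` over the common prefix of the two slices without
/// materialising the intermediate AND result.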
pub fn fused_and_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let a = &a[..len];
    let b = &b[..len];

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512vpopcntdq") {
            return unsafe { avx512::fused_and_popcount_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::fused_and_popcount_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        #[cfg(target_feature = "sve")]
        {
            return unsafe { sve::fused_and_popcount_sve(a, b) };
        }
    }

    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        return unsafe { rvv::fused_and_popcount_rvv(a, b) };
    }

    let mut total = 0_u64;
    let mut chunks_a = a.chunks_exact(4);
    let mut chunks_b = b.chunks_exact(4);
    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        total += (ca[0] & cb[0]).count_ones() as u64;
        total += (ca[1] & cb[1]).count_ones() as u64;
        total += (ca[2] & cb[2]).count_ones() as u64;
        total += (ca[3] & cb[3]).count_ones() as u64;
    }
    total += chunks_a
        .remainder()
        .iter()
        .zip(chunks_b.remainder().iter())
        .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
        .sum::<u64>();
    total
}

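/// Computes `popcount(a ^ b)` over the common prefix of the two slices, i.e.
/// the Hamming distance between the packed bit streams.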
pub fn fused_xor_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let a = &a[..len];
    let b = &b[..len];

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512vpopcntdq") {
            return unsafe { avx512::fused_xor_popcount_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::fused_xor_popcount_avx2(a, b) };
        }
        if is_x86_feature_detected!("avx") {
            let mut total = 0_u64;
            let mut chunks_a = a.chunks_exact(16);
            let mut chunks_b = b.chunks_exact(16);
            for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
                for i in 0..16 {
                    total += (ca[i] ^ cb[i]).count_ones() as u64;
                }
            }
            total += chunks_a
                .remainder()
                .iter()
                .zip(chunks_b.remainder().iter())
                .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
                .sum::<u64>();
            return total;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        #[cfg(target_feature = "sve")]
        {
            return unsafe { sve::fused_xor_popcount_sve(a, b) };
        }
    }

    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        return unsafe { rvv::fused_xor_popcount_rvv(a, b) };
    }

    let mut total = 0_u64;
    let mut chunks_a = a.chunks_exact(4);
    let mut chunks_b = b.chunks_exact(4);
    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        total += (ca[0] ^ cb[0]).count_ones() as u64;
        total += (ca[1] ^ cb[1]).count_ones() as u64;
        total += (ca[2] ^ cb[2]).count_ones() as u64;
        total += (ca[3] ^ cb[3]).count_ones() as u64;
    }
    total += chunks_a
        .remainder()
        .iter()
        .zip(chunks_b.remainder().iter())
        .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
        .sum::<u64>();
    total
}

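/// Dot product of `a` and `b`, preferring FMA-capable SIMD paths when present.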
pub fn dot_f64_dispatch(a: &[f64], b: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            return unsafe { avx512::dot_f64_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return unsafe { avx2::dot_f64_avx2(a, b) };
        }
        if is_x86_feature_detected!("avx") {
            return unsafe { avx2::dot_f64_avx(a, b) };
        }
        if is_x86_feature_detected!("sse2") {
            let len = a.len().min(b.len());
            let mut sum = 0.0_f64;
            let mut chunks_a = a[..len].chunks_exact(4);
            let mut chunks_b = b[..len].chunks_exact(4);
            for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
                sum += ca[0] * cb[0] + ca[1] * cb[1] + ca[2] * cb[2] + ca[3] * cb[3];
            }
            sum += chunks_a
                .remainder()
                .iter()
                .zip(chunks_b.remainder())
                .map(|(x, y)| x * y)
                .sum::<f64>();
            return sum;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return unsafe { neon::dot_f64_neon(a, b) };
    }

    let len = a.len().min(b.len());
    a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
}

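/// Maximum element of `a`; the non-SIMD paths return `f64::NEG_INFINITY` for
/// an empty slice.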
pub fn max_f64_dispatch(a: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            return unsafe { avx512::max_f64_avx512(a) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::max_f64_avx2(a) };
        }
        if is_x86_feature_detected!("avx") {
            return unsafe { avx2::max_f64_avx(a) };
        }
        if is_x86_feature_detected!("sse2") {
            let mut m = f64::NEG_INFINITY;
            let mut chunks = a.chunks_exact(4);
            for c in chunks.by_ref() {
                m = m.max(c[0].max(c[1]).max(c[2].max(c[3])));
            }
            for &v in chunks.remainder() {
                m = m.max(v);
            }
            return m;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return unsafe { neon::max_f64_neon(a) };
    }

    a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}

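/// Sum of the elements of `a`.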
pub fn sum_f64_dispatch(a: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            return unsafe { avx512::sum_f64_avx512(a) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::sum_f64_avx2(a) };
        }
        if is_x86_feature_detected!("avx") {
            return unsafe { avx2::sum_f64_avx(a) };
        }
        if is_x86_feature_detected!("sse2") {
            let mut s = 0.0_f64;
            let mut chunks = a.chunks_exact(4);
            for c in chunks.by_ref() {
                s += c[0] + c[1] + c[2] + c[3];
            }
            s += chunks.remainder().iter().sum::<f64>();
            return s;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return unsafe { neon::sum_f64_neon(a) };
    }

    a.iter().sum()
}

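/// Scales every element of `y` by `alpha` in place.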
pub fn scale_f64_dispatch(alpha: f64, y: &mut [f64]) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            unsafe { avx512::scale_f64_avx512(alpha, y) };
            return;
        }
        if is_x86_feature_detected!("avx2") {
            unsafe { avx2::scale_f64_avx2(alpha, y) };
            return;
        }
        if is_x86_feature_detected!("avx") {
            unsafe { avx2::scale_f64_avx(alpha, y) };
            return;
        }
        if is_x86_feature_detected!("sse2") {
            let mut chunks = y.chunks_exact_mut(4);
            for c in chunks.by_ref() {
                c[0] *= alpha;
                c[1] *= alpha;
                c[2] *= alpha;
                c[3] *= alpha;
            }
            for v in chunks.into_remainder() {
                *v *= alpha;
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe { neon::scale_f64_neon(alpha, y) };
        return;
    }

    for x in y.iter_mut() {
        *x *= alpha;
    }
}

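/// Hamming distance between two packed bit streams; a thin wrapper over
/// `fused_xor_popcount_dispatch`.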
pub fn hamming_distance_dispatch(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_dispatch(a, b)
}

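/// In-place, numerically stable softmax: shift by the maximum, exponentiate,
/// then normalise by the sum of exponentials.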
pub fn softmax_inplace_f64_dispatch(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { avx2::softmax_inplace_f64_avx2(scores) };
            return;
        }
    }

    let max_val = max_f64_dispatch(scores);
    let mut chunks = scores.chunks_exact_mut(4);
    for c in chunks.by_ref() {
        c[0] = (c[0] - max_val).exp();
        c[1] = (c[1] - max_val).exp();
        c[2] = (c[2] - max_val).exp();
        c[3] = (c[3] - max_val).exp();
    }
    for s in chunks.into_remainder() {
        *s = (*s - max_val).exp();
    }

    let exp_sum = sum_f64_dispatch(scores);
    if exp_sum > 0.0 {
        scale_f64_dispatch(1.0 / exp_sum, scores);
    }
}

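/// Stochastic encode-and-popcount over `weight_words`; currently a thin
/// wrapper around `crate::bitstream::encode_and_popcount`.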
pub fn encode_and_popcount_dispatch<R: Rng + ?Sized>(
    weight_words: &[u64],
    prob: f64,
    length: usize,
    rng: &mut R,
) -> u64 {
    crate::bitstream::encode_and_popcount(weight_words, prob, length, rng)
}

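/// Compares 1024 random bytes from `buf` against `threshold`, setting one bit
/// per byte in the 16 output words (bit set where the byte is below the
/// threshold).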
pub fn bernoulli_compare_batch_1024(buf: &[u8], threshold: u8, out: &mut [u64]) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512bw") {
            return unsafe { avx512::bernoulli_compare_batch_avx512(buf, threshold, out) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::bernoulli_compare_batch_avx2(buf, threshold, out) };
        }
    }

    // SSE2 is part of the x86_64 baseline, so this fallback needs no runtime check.
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;

        unsafe {
            // SSE2 only provides a signed byte compare; XOR-ing both operands
            // with 0x80 (i8::MIN) turns it into an unsigned `threshold > byte`
            // compare.
            let v_thresh = _mm_set1_epi8(threshold as i8);
            let bias = _mm_set1_epi8(i8::MIN);
            let v_thresh_biased = _mm_xor_si128(v_thresh, bias);

            for i in 0..16 {
                let chunk = &buf[i * 64..(i + 1) * 64];
                let mut word = 0_u64;
                for j in 0..4 {
                    let v = _mm_loadu_si128(chunk.as_ptr().add(j * 16) as *const __m128i);
                    let v_biased = _mm_xor_si128(v, bias);
                    let m = _mm_cmpgt_epi8(v_thresh_biased, v_biased);
                    // movemask packs one bit per byte lane; four 16-bit masks
                    // accumulate into one 64-bit output word.
                    let mask = _mm_movemask_epi8(m) as u32;
                    word |= (mask as u64) << (j * 16);
                }
                out[i] = word;
            }
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    for i in 0..16 {
        out[i] =
            crate::bitstream::simd_bernoulli_compare_exposed(&buf[i * 64..(i + 1) * 64], threshold);
    }
}