sc_neurocore_engine/simd/
avx2.rs1#[cfg(target_arch = "x86_64")]
9use core::arch::x86_64::*;
10
/// AVX2 population count over a slice of 64-bit words.
///
/// Processes four `u64` lanes per iteration with the classic SWAR
/// ("bits in parallel") popcount, then reduces each lane to a byte-sum via
/// the multiply-and-shift trick. Words left over after the 4-wide chunks
/// are counted by the portable scalar path.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`); executing AVX2 instructions on an
/// unsupported CPU is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    let mut total = 0_u64;
    let mut chunks = data.chunks_exact(4);

    // SWAR masks: bit pairs (0x55..), nibbles (0x33..), bytes (0x0f..).
    let m1 = _mm256_set1_epi64x(0x5555_5555_5555_5555_u64 as i64);
    let m2 = _mm256_set1_epi64x(0x3333_3333_3333_3333_u64 as i64);
    let m4 = _mm256_set1_epi64x(0x0f0f_0f0f_0f0f_0f0f_u64 as i64);

    for chunk in &mut chunks {
        let mut x = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
        // x = x - ((x >> 1) & m1): 2-bit partial counts per pair.
        x = _mm256_sub_epi64(x, _mm256_and_si256(_mm256_srli_epi64::<1>(x), m1));
        // Fold pairs into 4-bit partial counts per nibble.
        x = _mm256_add_epi64(
            _mm256_and_si256(x, m2),
            _mm256_and_si256(_mm256_srli_epi64::<2>(x), m2),
        );
        // Fold nibbles into 8-bit partial counts; each byte now holds 0..=8.
        x = _mm256_and_si256(_mm256_add_epi64(x, _mm256_srli_epi64::<4>(x)), m4);

        // Per-lane horizontal reduce: multiplying by 0x0101..01 accumulates
        // all byte counts into the top byte; >> 56 extracts that sum. The
        // total per word is at most 64, so the byte cannot overflow.
        let mut lanes = [0_u64; 4];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, x);
        total += lanes
            .iter()
            .copied()
            .map(|lane| lane.wrapping_mul(0x0101_0101_0101_0101) >> 56)
            .sum::<u64>();
    }

    // Remaining 0..=3 words: portable scalar popcount.
    total + crate::bitstream::popcount_words_portable(chunks.remainder())
}
45
/// Packs a byte-per-bit buffer into `u64` words using AVX2, consuming 64
/// input bytes per output word. A byte counts as a set bit iff it is
/// nonzero; bit `i` of word `w` corresponds to input byte `w * 64 + i`.
///
/// Trailing bits that do not fill a whole word are packed by the portable
/// `pack_fast` path and written into the final (partial) word.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
    let length = bits.len();
    let words = length.div_ceil(64);
    let mut data = vec![0_u64; words];
    let full_words = length / 64;
    let zero = _mm256_setzero_si256();

    for (word_idx, word) in data.iter_mut().take(full_words).enumerate() {
        let base = word_idx * 64;
        // Two unaligned 32-byte loads cover the 64 input bytes of this word.
        let lo = _mm256_loadu_si256(bits.as_ptr().add(base) as *const __m256i);
        let hi = _mm256_loadu_si256(bits.as_ptr().add(base + 32) as *const __m256i);

        // cmpeq-with-zero + movemask yields one bit per ZERO byte; the
        // bitwise NOT flips that into one bit per NONZERO byte.
        let lo_eq_zero = _mm256_cmpeq_epi8(lo, zero);
        let hi_eq_zero = _mm256_cmpeq_epi8(hi, zero);
        let lo_mask = !(_mm256_movemask_epi8(lo_eq_zero) as u32);
        let hi_mask = !(_mm256_movemask_epi8(hi_eq_zero) as u32);

        *word = ((hi_mask as u64) << 32) | (lo_mask as u64);
    }

    // Pack the final partial word (fewer than 64 remaining bits), if any;
    // the tail spans at most one word, so only the first word is taken.
    if full_words < words {
        let tail_start = full_words * 64;
        let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
        data[full_words] = tail.data.first().copied().unwrap_or(0);
    }

    data
}
82
83#[cfg(target_arch = "x86_64")]
84#[target_feature(enable = "avx2")]
85pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
90 let len = a.len().min(b.len());
91 let mut total = 0_u64;
92 let mut chunks_a = a[..len].chunks_exact(4);
93 let mut chunks_b = b[..len].chunks_exact(4);
94
95 for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
96 let va = _mm256_loadu_si256(ca.as_ptr() as *const __m256i);
97 let vb = _mm256_loadu_si256(cb.as_ptr() as *const __m256i);
98 let anded = _mm256_and_si256(va, vb);
99
100 let mut lanes = [0_u64; 4];
101 _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, anded);
102 total += lanes.iter().map(|w| w.count_ones() as u64).sum::<u64>();
103 }
104
105 total
106 + chunks_a
107 .remainder()
108 .iter()
109 .zip(chunks_b.remainder().iter())
110 .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
111 .sum::<u64>()
112}
113
/// Fused bitwise-XOR + population count over two word slices.
///
/// Computes `popcount(a ^ b)` over the common prefix of the two slices
/// (the shorter length wins). Four words are XORed per AVX2 iteration;
/// the bit counting itself is scalar.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let mut blocks_a = a[..len].chunks_exact(4);
    let mut blocks_b = b[..len].chunks_exact(4);
    let mut count = 0_u64;

    // Both slices were trimmed to the same length, so the two chunk
    // iterators always exhaust together.
    while let (Some(pa), Some(pb)) = (blocks_a.next(), blocks_b.next()) {
        let mixed = _mm256_xor_si256(
            _mm256_loadu_si256(pa.as_ptr() as *const __m256i),
            _mm256_loadu_si256(pb.as_ptr() as *const __m256i),
        );
        let mut out = [0_u64; 4];
        _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, mixed);
        for w in out.iter() {
            count += u64::from(w.count_ones());
        }
    }

    // Leftover 0..=3 words handled with plain scalar ops.
    count
        + blocks_a
            .remainder()
            .iter()
            .zip(blocks_b.remainder())
            .fold(0_u64, |acc, (&x, &y)| acc + u64::from((x ^ y).count_ones()))
}
144
/// Portable stand-in for the AVX2 fused XOR + popcount on non-x86_64
/// targets; pairs words up to the shorter slice and counts differing bits.
/// Declared `unsafe` to match the x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let mut total = 0_u64;
    for (&wa, &wb) in a.iter().zip(b.iter()) {
        total += u64::from((wa ^ wb).count_ones());
    }
    total
}
156
/// Compares the first 32 bytes of `buf` against `threshold` (unsigned) and
/// returns a bitmask with bit `i` set iff `buf[i] < threshold`.
///
/// AVX2 only has signed byte compares, so both operands get their sign bit
/// flipped first, which turns the signed ordering into the unsigned one.
///
/// # Panics
/// Panics if `buf` holds fewer than 32 bytes.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    assert!(buf.len() >= 32, "buffer must contain at least 32 bytes");

    let sign_flip = _mm256_set1_epi8(i8::MIN);
    let bytes = _mm256_xor_si256(
        _mm256_loadu_si256(buf.as_ptr() as *const __m256i),
        sign_flip,
    );
    let limit = _mm256_set1_epi8((threshold ^ 0x80) as i8);
    // Lane i is all-ones exactly when buf[i] < threshold (unsigned).
    let below = _mm256_cmpgt_epi8(limit, bytes);
    _mm256_movemask_epi8(below) as u32
}
176
/// Portable stand-in for the AVX2 popcount on non-x86_64 targets; delegates
/// to the scalar word popcount. Declared `unsafe` to match the x86_64
/// variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    crate::bitstream::popcount_words_portable(data)
}
185
/// Portable stand-in for the AVX2 bit packer on non-x86_64 targets; packs
/// via the scalar `pack_fast` path and returns just its word vector.
/// Declared `unsafe` to match the x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
    crate::bitstream::pack_fast(bits).data
}
194
/// Portable stand-in for the AVX2 fused AND + popcount on non-x86_64
/// targets; pairs words up to the shorter slice and counts shared bits.
/// Declared `unsafe` to match the x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b.iter())
        .fold(0_u64, |acc, (&wa, &wb)| acc + u64::from((wa & wb).count_ones()))
}
206
/// Portable stand-in for the AVX2 threshold compare on non-x86_64 targets:
/// bit `i` of the result is set iff `buf[i] < threshold`, over at most the
/// first 32 bytes. Declared `unsafe` to match the x86_64 signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    buf.iter()
        .take(32)
        .enumerate()
        .fold(0_u32, |mask, (bit, &byte)| {
            mask | (u32::from(byte < threshold) << bit)
        })
}
221
/// Dot product of two `f64` slices over their common prefix (the shorter
/// length wins), using 4-wide FMA accumulation.
///
/// NOTE: FMA rounds once per multiply-add, so results can differ from a
/// plain scalar multiply-then-add in the last bits.
///
/// # Safety
/// The caller must ensure the CPU supports both AVX2 and FMA (e.g. via
/// `is_x86_feature_detected!`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    let mut blocks_a = a[..n].chunks_exact(4);
    let mut blocks_b = b[..n].chunks_exact(4);
    let mut vacc = _mm256_setzero_pd();

    // Equal trimmed lengths mean the two iterators exhaust together.
    while let (Some(pa), Some(pb)) = (blocks_a.next(), blocks_b.next()) {
        vacc = _mm256_fmadd_pd(
            _mm256_loadu_pd(pa.as_ptr()),
            _mm256_loadu_pd(pb.as_ptr()),
            vacc,
        );
    }

    // Horizontal reduction of the four partial sums, left to right.
    let mut partial = [0.0_f64; 4];
    _mm256_storeu_pd(partial.as_mut_ptr(), vacc);
    let mut dot = partial[0] + partial[1] + partial[2] + partial[3];

    // Scalar tail for the 0..=3 leftover element pairs.
    for (x, y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
        dot += x * y;
    }
    dot
}
251
/// Maximum element of an `f64` slice, 4 lanes at a time; returns
/// `f64::NEG_INFINITY` for an empty slice.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut blocks = a.chunks_exact(4);
    // Seed every lane with -inf so unused lanes never win.
    let mut running = _mm256_set1_pd(f64::NEG_INFINITY);

    while let Some(block) = blocks.next() {
        running = _mm256_max_pd(running, _mm256_loadu_pd(block.as_ptr()));
    }

    // Pairwise scalar reduction of the four lanes, then fold in leftovers.
    let mut tmp = [0.0_f64; 4];
    _mm256_storeu_pd(tmp.as_mut_ptr(), running);
    let mut best = tmp[0].max(tmp[1]).max(tmp[2].max(tmp[3]));
    for &v in blocks.remainder() {
        best = best.max(v);
    }
    best
}
278
/// Sum of an `f64` slice using four parallel partial sums; returns 0.0 for
/// an empty slice.
///
/// NOTE: the 4-lane accumulation order differs from a straight left-to-right
/// scalar sum, so rounding can differ in the last bits.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    let mut blocks = a.chunks_exact(4);
    let mut running = _mm256_setzero_pd();

    while let Some(block) = blocks.next() {
        running = _mm256_add_pd(running, _mm256_loadu_pd(block.as_ptr()));
    }

    // Collapse the four lane sums left to right, then add the tail.
    let mut tmp = [0.0_f64; 4];
    _mm256_storeu_pd(tmp.as_mut_ptr(), running);
    let lane_total = tmp[0] + tmp[1] + tmp[2] + tmp[3];
    blocks.remainder().iter().fold(lane_total, |acc, &v| acc + v)
}
302
/// Multiplies every element of `y` by `alpha` in place, 4 lanes at a time.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    let factor = _mm256_set1_pd(alpha);
    let mut blocks = y.chunks_exact_mut(4);

    while let Some(block) = blocks.next() {
        let scaled = _mm256_mul_pd(_mm256_loadu_pd(block.as_ptr()), factor);
        _mm256_storeu_pd(block.as_mut_ptr(), scaled);
    }

    // Scalar tail for the 0..=3 leftover elements.
    for leftover in blocks.into_remainder() {
        *leftover *= alpha;
    }
}
323
/// Portable stand-in for the AVX2 dot product on non-x86_64 targets;
/// multiplies element pairs over the shorter slice and accumulates left to
/// right. Declared `unsafe` to match the x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    // zip already stops at the shorter input, matching min-length semantics.
    a.iter().zip(b.iter()).fold(0.0_f64, |acc, (&x, &y)| acc + x * y)
}
329
/// Portable stand-in for the AVX2 max on non-x86_64 targets; returns
/// `f64::NEG_INFINITY` for an empty slice. Declared `unsafe` to match the
/// x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &v in a {
        best = best.max(v);
    }
    best
}
334
/// Portable stand-in for the AVX2 sum on non-x86_64 targets; plain
/// left-to-right accumulation from 0.0. Declared `unsafe` to match the
/// x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    a.iter().fold(0.0_f64, |acc, &v| acc + v)
}
339
/// Portable stand-in for the AVX2 in-place scale on non-x86_64 targets.
/// Declared `unsafe` to match the x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|v| *v *= alpha);
}
346
/// Hamming distance between two bit vectors stored as `u64` words: the
/// number of bit positions where `a` and `b` differ (XOR + popcount).
/// Slices of unequal length are compared over the shorter prefix.
///
/// # Safety
/// Same contract as [`fused_xor_popcount_avx2`]: on x86_64 the caller must
/// ensure AVX2 is available before calling.
pub unsafe fn hamming_distance_avx2(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_avx2(a, b)
}
354
355#[cfg(target_arch = "x86_64")]
356#[target_feature(enable = "avx2")]
357pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
362 if scores.is_empty() {
363 return;
364 }
365 let max_val = max_f64_avx2(scores);
366 for s in scores.iter_mut() {
367 *s = (*s - max_val).exp();
368 }
369 let exp_sum = sum_f64_avx2(scores);
370 if exp_sum > 0.0 {
371 scale_f64_avx2(1.0 / exp_sum, scores);
372 }
373}
374
/// Portable stand-in for the AVX2 softmax on non-x86_64 targets; identical
/// numerics via the portable max/sum/scale fallbacks. Declared `unsafe` to
/// match the x86_64 variant's signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }
    let peak = max_f64_avx2(scores);
    scores.iter_mut().for_each(|s| *s = (*s - peak).exp());
    let total = sum_f64_avx2(scores);
    if total > 0.0 {
        scale_f64_avx2(1.0 / total, scores);
    }
}
389
/// Unit tests comparing each AVX2 path against a portable scalar reference.
/// Every test bails out early (passing trivially) when the host CPU lacks
/// the required feature, so the suite is safe to run anywhere.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    /// `pack_avx2` must match the portable `pack` across lengths that sit
    /// on and around the 32/64/128 alignment boundaries.
    #[test]
    fn pack_avx2_matches_pack() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [
            1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031,
        ];
        for length in lengths {
            // Deterministic 0/1 pattern with no period-64 structure.
            let bits: Vec<u8> = (0..length)
                .map(|i| if (i * 17 + 5) % 3 == 0 { 1 } else { 0 })
                .collect();
            let got = unsafe { super::pack_avx2(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    /// Fused AND+popcount must agree with the word-by-word scalar version,
    /// including lengths that leave a non-empty `chunks_exact` remainder.
    #[test]
    fn fused_and_popcount_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            // Pseudo-random words via multiplicative hashing constants.
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xA5A5_A5A5_5A5A_5A5A)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xC2B2_AE3D_27D4_EB4F) ^ 0x0F0F_F0F0_33CC_CC33)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
                .sum();

            let got = unsafe { super::fused_and_popcount_avx2(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    /// FMA dot product vs. scalar multiply-add; 67 elements exercises the
    /// 4-wide main loop plus a 3-element remainder.
    #[test]
    fn dot_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.1).collect();
        let b: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 5.0).collect();
        let expected: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let got = unsafe { super::dot_f64_avx2(&a, &b) };
        // Tolerance absorbs FMA vs. mul-then-add rounding differences.
        assert!(
            (got - expected).abs() < 1e-9,
            "dot: got {got}, expected {expected}"
        );
    }

    /// Horizontal max vs. scalar fold over a sign-alternating sequence.
    #[test]
    fn max_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| (i as f64 * 7.3).sin()).collect();
        let expected = a.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let got = unsafe { super::max_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-12,
            "max: got {got}, expected {expected}"
        );
    }

    /// Horizontal sum vs. sequential scalar sum, within FP reorder slack.
    #[test]
    fn sum_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.01).collect();
        let expected: f64 = a.iter().sum();
        let got = unsafe { super::sum_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-9,
            "sum: got {got}, expected {expected}"
        );
    }

    /// Softmax output must be a probability distribution: non-negative
    /// entries that sum to 1.
    #[test]
    fn softmax_avx2_sums_to_one() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut scores: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 10.0).collect();
        unsafe { super::softmax_inplace_f64_avx2(&mut scores) };
        let sum: f64 = scores.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax must sum to 1.0, got {sum}"
        );
        assert!(scores.iter().all(|&s| s >= 0.0), "all values must be >= 0");
    }

    /// SIMD byte-threshold mask vs. bit-by-bit scalar construction across
    /// edge-case thresholds (0, mid-range, the sign-bit boundary, 255).
    #[test]
    fn bernoulli_compare_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let buf: Vec<u8> = (0..32).map(|i| (i * 73 + 17) as u8).collect();
        let thresholds = [0_u8, 1, 2, 17, 64, 127, 128, 200, 255];

        for threshold in thresholds {
            let expected = buf.iter().enumerate().fold(0_u32, |acc, (bit, &rb)| {
                acc | (u32::from(rb < threshold) << bit)
            });

            let got = unsafe { super::bernoulli_compare_avx2(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }
}