#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

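/// Counts the set bits in `data`, processing 16 words per iteration via
/// 256-bit loads/stores and scalar `count_ones` on each lane; the remainder
/// is counted in scalar code.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`).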
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    let mut total = 0_u64;
    let mut chunks = data.chunks_exact(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
        let v1 = _mm256_loadu_si256(chunk.as_ptr().add(4) as *const __m256i);
        let v2 = _mm256_loadu_si256(chunk.as_ptr().add(8) as *const __m256i);
        let v3 = _mm256_loadu_si256(chunk.as_ptr().add(12) as *const __m256i);

        let mut lanes = [0_u64; 16];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, v0);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(4) as *mut __m256i, v1);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(8) as *mut __m256i, v2);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(12) as *mut __m256i, v3);

        for &w in lanes.iter() {
            total += w.count_ones() as u64;
        }
    }

    total
        + chunks
            .remainder()
            .iter()
            .map(|&w| w.count_ones() as u64)
            .sum::<u64>()
}

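/// Packs a slice of 0/1 bytes into `u64` words, LSB-first (bit `i` of a word
/// corresponds to input byte `i`), using AVX2 byte compares and `movemask`
/// to build 32 bits at a time. Any tail shorter than 64 bits is packed with
/// the portable path.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.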
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
    let length = bits.len();
    let words = length.div_ceil(64);
    let mut data = vec![0_u64; words];
    let full_words = length / 64;
    let zero = _mm256_setzero_si256();

    let mut chunks = data[..full_words].chunks_exact_mut(4);
    let mut word_idx = 0;
    for chunk in chunks.by_ref() {
        let base = word_idx * 64;
        for i in 0..4 {
            let b = base + i * 64;
            let lo = _mm256_loadu_si256(bits.as_ptr().add(b) as *const __m256i);
            let hi = _mm256_loadu_si256(bits.as_ptr().add(b + 32) as *const __m256i);
            // cmpeq-with-zero + movemask marks the zero bytes; inverting the
            // mask leaves a bit set for every nonzero input byte.
            let lo_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(lo, zero)) as u32);
            let hi_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(hi, zero)) as u32);
            chunk[i] = ((hi_mask as u64) << 32) | (lo_mask as u64);
        }
        word_idx += 4;
    }

    for i in word_idx..full_words {
        let base = i * 64;
        let lo = _mm256_loadu_si256(bits.as_ptr().add(base) as *const __m256i);
        let hi = _mm256_loadu_si256(bits.as_ptr().add(base + 32) as *const __m256i);
        let lo_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(lo, zero)) as u32);
        let hi_mask = !(_mm256_movemask_epi8(_mm256_cmpeq_epi8(hi, zero)) as u32);
        data[i] = ((hi_mask as u64) << 32) | (lo_mask as u64);
    }

    if full_words < words {
        let tail_start = full_words * 64;
        let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
        data[full_words] = tail.data.first().copied().unwrap_or(0);
    }

    data
}

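/// Computes `popcount(a & b)` without materializing the intermediate AND,
/// processing 16 words per iteration. The inputs are truncated to the
/// shorter length.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.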
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let mut total = 0_u64;
    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm256_loadu_si256(ca.as_ptr() as *const __m256i);
        let vb0 = _mm256_loadu_si256(cb.as_ptr() as *const __m256i);
        let va1 = _mm256_loadu_si256(ca.as_ptr().add(4) as *const __m256i);
        let vb1 = _mm256_loadu_si256(cb.as_ptr().add(4) as *const __m256i);
        let va2 = _mm256_loadu_si256(ca.as_ptr().add(8) as *const __m256i);
        let vb2 = _mm256_loadu_si256(cb.as_ptr().add(8) as *const __m256i);
        let va3 = _mm256_loadu_si256(ca.as_ptr().add(12) as *const __m256i);
        let vb3 = _mm256_loadu_si256(cb.as_ptr().add(12) as *const __m256i);

        let and0 = _mm256_and_si256(va0, vb0);
        let and1 = _mm256_and_si256(va1, vb1);
        let and2 = _mm256_and_si256(va2, vb2);
        let and3 = _mm256_and_si256(va3, vb3);

        let mut lanes = [0_u64; 16];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, and0);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(4) as *mut __m256i, and1);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(8) as *mut __m256i, and2);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(12) as *mut __m256i, and3);

        for &w in lanes.iter() {
            total += w.count_ones() as u64;
        }
    }

    total
        + chunks_a
            .remainder()
            .iter()
            .zip(chunks_b.remainder().iter())
            .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
            .sum::<u64>()
}

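/// Computes `popcount(a ^ b)` (the bit-level Hamming distance) without
/// materializing the intermediate XOR. The inputs are truncated to the
/// shorter length.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.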
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let mut total = 0_u64;
    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm256_loadu_si256(ca.as_ptr() as *const __m256i);
        let vb0 = _mm256_loadu_si256(cb.as_ptr() as *const __m256i);
        let va1 = _mm256_loadu_si256(ca.as_ptr().add(4) as *const __m256i);
        let vb1 = _mm256_loadu_si256(cb.as_ptr().add(4) as *const __m256i);
        let va2 = _mm256_loadu_si256(ca.as_ptr().add(8) as *const __m256i);
        let vb2 = _mm256_loadu_si256(cb.as_ptr().add(8) as *const __m256i);
        let va3 = _mm256_loadu_si256(ca.as_ptr().add(12) as *const __m256i);
        let vb3 = _mm256_loadu_si256(cb.as_ptr().add(12) as *const __m256i);

        let xor0 = _mm256_xor_si256(va0, vb0);
        let xor1 = _mm256_xor_si256(va1, vb1);
        let xor2 = _mm256_xor_si256(va2, vb2);
        let xor3 = _mm256_xor_si256(va3, vb3);

        let mut lanes = [0_u64; 16];
        _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, xor0);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(4) as *mut __m256i, xor1);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(8) as *mut __m256i, xor2);
        _mm256_storeu_si256(lanes.as_mut_ptr().add(12) as *mut __m256i, xor3);

        for &w in lanes.iter() {
            total += w.count_ones() as u64;
        }
    }

    total
        + chunks_a
            .remainder()
            .iter()
            .zip(chunks_b.remainder().iter())
            .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
            .sum::<u64>()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_xor_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b.iter())
        .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
        .sum()
}

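/// Compares 32 random bytes against `threshold` and returns a 32-bit mask
/// with bit `i` set when `buf[i] < threshold` (unsigned).
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.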
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    assert!(buf.len() >= 32, "buffer must contain at least 32 bytes");

    let data = _mm256_loadu_si256(buf.as_ptr() as *const __m256i);
    // AVX2 only has signed byte compares, so flip the sign bit of both sides
    // to turn the signed `cmpgt` into an unsigned comparison.
    let bias = _mm256_set1_epi8(i8::MIN);
    let data_biased = _mm256_xor_si256(data, bias);
    let thresh_biased = _mm256_set1_epi8((threshold ^ 0x80) as i8);
    let lt = _mm256_cmpgt_epi8(thresh_biased, data_biased);
    _mm256_movemask_epi8(lt) as u32
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn popcount_avx2(data: &[u64]) -> u64 {
    crate::bitstream::popcount_words_portable(data)
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn pack_avx2(bits: &[u8]) -> Vec<u64> {
    crate::bitstream::pack_fast(bits).data
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_and_popcount_avx2(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b.iter())
        .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
        .sum()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn bernoulli_compare_avx2(buf: &[u8], threshold: u8) -> u32 {
    let mut mask = 0_u32;
    for (bit, &rb) in buf.iter().take(32).enumerate() {
        if rb < threshold {
            mask |= 1_u32 << bit;
        }
    }
    mask
}

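/// Dot product of two `f64` slices using AVX2 + FMA, truncated to the
/// shorter length; the remainder is accumulated in scalar code.
///
/// # Safety
/// The caller must ensure the CPU supports both AVX2 and FMA.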
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    let mut acc = _mm256_setzero_pd();
    let mut chunks_a = a[..len].chunks_exact(4);
    let mut chunks_b = b[..len].chunks_exact(4);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va = _mm256_loadu_pd(ca.as_ptr());
        let vb = _mm256_loadu_pd(cb.as_ptr());
        acc = _mm256_fmadd_pd(va, vb, acc);
    }

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
        sum += ra * rb;
    }
    sum
}

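/// Maximum of an `f64` slice using AVX2; returns `f64::NEG_INFINITY` for an
/// empty slice.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.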
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut vmax = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut chunks = a.chunks_exact(4);

    for chunk in chunks.by_ref() {
        let va = _mm256_loadu_pd(chunk.as_ptr());
        vmax = _mm256_max_pd(vmax, va);
    }

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), vmax);
    let mut m = lanes[0].max(lanes[1]).max(lanes[2].max(lanes[3]));
    for &v in chunks.remainder() {
        m = m.max(v);
    }
    m
}

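/// Sums an `f64` slice with a single AVX2 accumulator; the remainder is
/// added in scalar code. Rounding may differ slightly from a sequential
/// scalar sum because the accumulation order differs.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.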
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    let mut acc = _mm256_setzero_pd();
    let mut chunks = a.chunks_exact(4);

    for chunk in chunks.by_ref() {
        let va = _mm256_loadu_pd(chunk.as_ptr());
        acc = _mm256_add_pd(acc, va);
    }

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for &v in chunks.remainder() {
        sum += v;
    }
    sum
}

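/// Multiplies every element of `y` by `alpha` in place, 16 elements per
/// iteration.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.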
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    let valpha = _mm256_set1_pd(alpha);
    let mut chunks = y.chunks_exact_mut(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm256_loadu_pd(chunk.as_ptr());
        let v1 = _mm256_loadu_pd(chunk.as_ptr().add(4));
        let v2 = _mm256_loadu_pd(chunk.as_ptr().add(8));
        let v3 = _mm256_loadu_pd(chunk.as_ptr().add(12));

        _mm256_storeu_pd(chunk.as_mut_ptr(), _mm256_mul_pd(v0, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(4), _mm256_mul_pd(v1, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(8), _mm256_mul_pd(v2, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(12), _mm256_mul_pd(v3, valpha));
    }

    for v in chunks.into_remainder() {
        *v *= alpha;
    }
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn max_f64_avx2(a: &[f64]) -> f64 {
    a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn sum_f64_avx2(a: &[f64]) -> f64 {
    a.iter().sum()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn scale_f64_avx2(alpha: f64, y: &mut [f64]) {
    for v in y.iter_mut() {
        *v *= alpha;
    }
}

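/// Hamming distance between two bit vectors stored as `u64` words; a thin
/// wrapper around `fused_xor_popcount_avx2`.
///
/// # Safety
/// On x86_64 the caller must ensure the CPU supports AVX2.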
pub unsafe fn hamming_distance_avx2(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_avx2(a, b)
}

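/// Numerically stable in-place softmax: subtracts the maximum, exponentiates,
/// and normalizes by the sum. The exponentiation itself is scalar; only the
/// max, sum, and scale steps are vectorized.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.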
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }
    let max_val = max_f64_avx2(scores);
    let mut chunks = scores.chunks_exact_mut(16);
    for chunk in chunks.by_ref() {
        for s in chunk.iter_mut() {
            *s = (*s - max_val).exp();
        }
    }
    for s in chunks.into_remainder() {
        *s = (*s - max_val).exp();
    }

    let exp_sum = sum_f64_avx2(scores);
    if exp_sum > 0.0 {
        scale_f64_avx2(1.0 / exp_sum, scores);
    }
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn softmax_inplace_f64_avx2(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }
    let max_val = max_f64_avx2(scores);
    let mut chunks = scores.chunks_exact_mut(16);
    for chunk in chunks.by_ref() {
        for s in chunk.iter_mut() {
            *s = (*s - max_val).exp();
        }
    }
    for s in chunks.into_remainder() {
        *s = (*s - max_val).exp();
    }

    let exp_sum = sum_f64_avx2(scores);
    if exp_sum > 0.0 {
        scale_f64_avx2(1.0 / exp_sum, scores);
    }
}

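/// Dot product using only AVX (no FMA): four independent multiply-add
/// accumulators, combined and reduced at the end. The inputs are truncated
/// to the shorter length.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.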
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
pub unsafe fn dot_f64_avx(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    let mut acc0 = _mm256_setzero_pd();
    let mut acc1 = _mm256_setzero_pd();
    let mut acc2 = _mm256_setzero_pd();
    let mut acc3 = _mm256_setzero_pd();

    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm256_loadu_pd(ca.as_ptr());
        let vb0 = _mm256_loadu_pd(cb.as_ptr());
        acc0 = _mm256_add_pd(acc0, _mm256_mul_pd(va0, vb0));

        let va1 = _mm256_loadu_pd(ca.as_ptr().add(4));
        let vb1 = _mm256_loadu_pd(cb.as_ptr().add(4));
        acc1 = _mm256_add_pd(acc1, _mm256_mul_pd(va1, vb1));

        let va2 = _mm256_loadu_pd(ca.as_ptr().add(8));
        let vb2 = _mm256_loadu_pd(cb.as_ptr().add(8));
        acc2 = _mm256_add_pd(acc2, _mm256_mul_pd(va2, vb2));

        let va3 = _mm256_loadu_pd(ca.as_ptr().add(12));
        let vb3 = _mm256_loadu_pd(cb.as_ptr().add(12));
        acc3 = _mm256_add_pd(acc3, _mm256_mul_pd(va3, vb3));
    }

    acc0 = _mm256_add_pd(acc0, acc1);
    acc2 = _mm256_add_pd(acc2, acc3);
    acc0 = _mm256_add_pd(acc0, acc2);

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc0);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
        sum += ra * rb;
    }
    sum
}

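/// Sums an `f64` slice using AVX with four independent accumulators to hide
/// addition latency.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.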
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
pub unsafe fn sum_f64_avx(a: &[f64]) -> f64 {
    let mut acc0 = _mm256_setzero_pd();
    let mut acc1 = _mm256_setzero_pd();
    let mut acc2 = _mm256_setzero_pd();
    let mut acc3 = _mm256_setzero_pd();

    let mut chunks = a.chunks_exact(16);
    for chunk in chunks.by_ref() {
        acc0 = _mm256_add_pd(acc0, _mm256_loadu_pd(chunk.as_ptr()));
        acc1 = _mm256_add_pd(acc1, _mm256_loadu_pd(chunk.as_ptr().add(4)));
        acc2 = _mm256_add_pd(acc2, _mm256_loadu_pd(chunk.as_ptr().add(8)));
        acc3 = _mm256_add_pd(acc3, _mm256_loadu_pd(chunk.as_ptr().add(12)));
    }

    acc0 = _mm256_add_pd(acc0, acc1);
    acc2 = _mm256_add_pd(acc2, acc3);
    acc0 = _mm256_add_pd(acc0, acc2);

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc0);
    let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for &v in chunks.remainder() {
        sum += v;
    }
    sum
}

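/// Batch version of `bernoulli_compare_avx2`: consumes 1024 random bytes from
/// `buf` and writes 16 comparison masks of 64 bits each into `out`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2. `buf` must hold at least
/// 1024 bytes and `out` at least 16 words.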
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn bernoulli_compare_batch_avx2(buf: &[u8], threshold: u8, out: &mut [u64]) {
    assert!(buf.len() >= 1024, "buffer must contain at least 1024 bytes");
    assert!(out.len() >= 16, "output must hold at least 16 words");

    let v_thresh = _mm256_set1_epi8(threshold as i8);
    // Flip the sign bit of both sides so the signed compare acts as unsigned.
    let bias = _mm256_set1_epi8(i8::MIN);
    let v_thresh_biased = _mm256_xor_si256(v_thresh, bias);

    for i in 0..16 {
        let chunk = &buf[i * 64..(i + 1) * 64];
        let v0 = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
        let v1 = _mm256_loadu_si256(chunk.as_ptr().add(32) as *const __m256i);

        let v0_biased = _mm256_xor_si256(v0, bias);
        let v1_biased = _mm256_xor_si256(v1, bias);

        let m0 = _mm256_cmpgt_epi8(v_thresh_biased, v0_biased);
        let m1 = _mm256_cmpgt_epi8(v_thresh_biased, v1_biased);

        let mask0 = _mm256_movemask_epi8(m0) as u32;
        let mask1 = _mm256_movemask_epi8(m1) as u32;
        out[i] = (mask0 as u64) | ((mask1 as u64) << 32);
    }
}

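/// Maximum of an `f64` slice using AVX with four accumulators; returns
/// `f64::NEG_INFINITY` for an empty slice.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.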
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
pub unsafe fn max_f64_avx(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut max_vec0 = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut max_vec1 = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut max_vec2 = _mm256_set1_pd(f64::NEG_INFINITY);
    let mut max_vec3 = _mm256_set1_pd(f64::NEG_INFINITY);

    let mut chunks = a.chunks_exact(16);
    for chunk in chunks.by_ref() {
        max_vec0 = _mm256_max_pd(max_vec0, _mm256_loadu_pd(chunk.as_ptr()));
        max_vec1 = _mm256_max_pd(max_vec1, _mm256_loadu_pd(chunk.as_ptr().add(4)));
        max_vec2 = _mm256_max_pd(max_vec2, _mm256_loadu_pd(chunk.as_ptr().add(8)));
        max_vec3 = _mm256_max_pd(max_vec3, _mm256_loadu_pd(chunk.as_ptr().add(12)));
    }

    max_vec0 = _mm256_max_pd(max_vec0, max_vec1);
    max_vec2 = _mm256_max_pd(max_vec2, max_vec3);
    max_vec0 = _mm256_max_pd(max_vec0, max_vec2);

    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), max_vec0);
    let mut m = lanes[0].max(lanes[1]).max(lanes[2].max(lanes[3]));
    for &v in chunks.remainder() {
        m = m.max(v);
    }
    m
}

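/// Multiplies every element of `y` by `alpha` in place using AVX.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.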
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
pub unsafe fn scale_f64_avx(alpha: f64, y: &mut [f64]) {
    let valpha = _mm256_set1_pd(alpha);
    let mut chunks = y.chunks_exact_mut(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm256_loadu_pd(chunk.as_ptr());
        let v1 = _mm256_loadu_pd(chunk.as_ptr().add(4));
        let v2 = _mm256_loadu_pd(chunk.as_ptr().add(8));
        let v3 = _mm256_loadu_pd(chunk.as_ptr().add(12));

        _mm256_storeu_pd(chunk.as_mut_ptr(), _mm256_mul_pd(v0, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(4), _mm256_mul_pd(v1, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(8), _mm256_mul_pd(v2, valpha));
        _mm256_storeu_pd(chunk.as_mut_ptr().add(12), _mm256_mul_pd(v3, valpha));
    }

    for v in chunks.into_remainder() {
        *v *= alpha;
    }
}

#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    #[test]
    fn pack_avx2_matches_pack() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [
            1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031,
        ];
        for length in lengths {
            let bits: Vec<u8> = (0..length)
                .map(|i| if (i * 17 + 5) % 3 == 0 { 1 } else { 0 })
                .collect();
            let got = unsafe { super::pack_avx2(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    #[test]
    fn fused_and_popcount_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xA5A5_A5A5_5A5A_5A5A)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xC2B2_AE3D_27D4_EB4F) ^ 0x0F0F_F0F0_33CC_CC33)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
                .sum();

            let got = unsafe { super::fused_and_popcount_avx2(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn dot_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.1).collect();
        let b: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 5.0).collect();
        let expected: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let got = unsafe { super::dot_f64_avx2(&a, &b) };
        assert!(
            (got - expected).abs() < 1e-9,
            "dot: got {got}, expected {expected}"
        );
    }

    #[test]
    fn max_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| (i as f64 * 7.3).sin()).collect();
        let expected = a.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let got = unsafe { super::max_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-12,
            "max: got {got}, expected {expected}"
        );
    }

    #[test]
    fn sum_f64_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.01).collect();
        let expected: f64 = a.iter().sum();
        let got = unsafe { super::sum_f64_avx2(&a) };
        assert!(
            (got - expected).abs() < 1e-9,
            "sum: got {got}, expected {expected}"
        );
    }

    #[test]
    fn softmax_avx2_sums_to_one() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut scores: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 10.0).collect();
        unsafe { super::softmax_inplace_f64_avx2(&mut scores) };
        let sum: f64 = scores.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax must sum to 1.0, got {sum}"
        );
        assert!(scores.iter().all(|&s| s >= 0.0), "all values must be >= 0");
    }

    #[test]
    fn bernoulli_compare_avx2_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let buf: Vec<u8> = (0..32).map(|i| (i * 73 + 17) as u8).collect();
        let thresholds = [0_u8, 1, 2, 17, 64, 127, 128, 200, 255];

        for threshold in thresholds {
            let expected = buf.iter().enumerate().fold(0_u32, |acc, (bit, &rb)| {
                acc | (u32::from(rb < threshold) << bit)
            });

            let got = unsafe { super::bernoulli_compare_avx2(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }

    #[test]
    fn dot_f64_avx_matches_scalar() {
        if !is_x86_feature_detected!("avx") {
            return;
        }
        let a: Vec<f64> = (0..67).map(|i| i as f64 * 0.1).collect();
        let b: Vec<f64> = (0..67).map(|i| (i as f64 * 0.3) - 5.0).collect();
        let expected: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let got = unsafe { super::dot_f64_avx(&a, &b) };
        assert!(
            (got - expected).abs() < 1e-9,
            "dot_avx: got {got}, expected {expected}"
        );
    }
}