// sc_neurocore_engine/simd/avx512.rs

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

11#[cfg(target_arch = "x86_64")]
12#[target_feature(enable = "avx512f,avx512vpopcntdq")]
13pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
18 let mut total = 0_u64;
19 let mut chunks = data.chunks_exact(8);
20
21 for chunk in &mut chunks {
22 let v = _mm512_loadu_si512(chunk.as_ptr() as *const __m512i);
23 let counts = _mm512_popcnt_epi64(v);
24 let mut lanes = [0_u64; 8];
25 _mm512_storeu_si512(lanes.as_mut_ptr() as *mut __m512i, counts);
26 total += lanes.iter().sum::<u64>();
27 }
28
29 total + crate::bitstream::popcount_words_portable(chunks.remainder())
30}
31
32#[cfg(target_arch = "x86_64")]
33#[target_feature(enable = "avx512f,avx512bw")]
34pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
42 let length = bits.len();
43 let words = length.div_ceil(64);
44 let mut data = vec![0_u64; words];
45 let full_words = length / 64;
46 let zero = _mm512_setzero_si512();
47
48 for (word_idx, word) in data.iter_mut().take(full_words).enumerate() {
49 let base = word_idx * 64;
50 let v = _mm512_loadu_si512(bits.as_ptr().add(base) as *const __m512i);
51 let mask = _mm512_cmpneq_epi8_mask(v, zero);
52 *word = mask;
53 }
54
55 if full_words < words {
56 let tail_start = full_words * 64;
57 let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
58 data[full_words] = tail.data.first().copied().unwrap_or(0);
59 }
60
61 data
62}
63
64#[cfg(target_arch = "x86_64")]
65#[target_feature(enable = "avx512f,avx512vpopcntdq")]
66pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
71 let len = a.len().min(b.len());
72 let mut total = _mm512_setzero_si512();
73 let mut chunks_a = a[..len].chunks_exact(8);
74 let mut chunks_b = b[..len].chunks_exact(8);
75
76 for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
77 let va = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
78 let vb = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
79 let anded = _mm512_and_epi64(va, vb);
80 let counts = _mm512_popcnt_epi64(anded);
81 total = _mm512_add_epi64(total, counts);
82 }
83
84 let mut lanes = [0_u64; 8];
85 _mm512_storeu_si512(lanes.as_mut_ptr() as *mut __m512i, total);
86 let mut sum: u64 = lanes.iter().sum();
87
88 for (&wa, &wb) in chunks_a.remainder().iter().zip(chunks_b.remainder().iter()) {
89 sum += (wa & wb).count_ones() as u64;
90 }
91 sum
92}
93
94#[cfg(target_arch = "x86_64")]
95#[target_feature(enable = "avx512f,avx512vpopcntdq")]
96pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
101 let len = a.len().min(b.len());
102 let mut total = _mm512_setzero_si512();
103 let mut chunks_a = a[..len].chunks_exact(8);
104 let mut chunks_b = b[..len].chunks_exact(8);
105
106 for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
107 let va = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
108 let vb = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
109 let xored = _mm512_xor_epi64(va, vb);
110 let counts = _mm512_popcnt_epi64(xored);
111 total = _mm512_add_epi64(total, counts);
112 }
113
114 let mut lanes = [0_u64; 8];
115 _mm512_storeu_si512(lanes.as_mut_ptr() as *mut __m512i, total);
116 let mut sum: u64 = lanes.iter().sum();
117
118 for (&wa, &wb) in chunks_a.remainder().iter().zip(chunks_b.remainder().iter()) {
119 sum += (wa ^ wb).count_ones() as u64;
120 }
121 sum
122}
123
/// Portable fallback for [`fused_xor_popcount_avx512`] on non-x86_64
/// targets; `unsafe` only to keep the signature identical to the SIMD path.
/// `zip` stops at the shorter slice, matching the SIMD path's min-length rule.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    let mut total = 0_u64;
    for (&wa, &wb) in a.iter().zip(b) {
        total += (wa ^ wb).count_ones() as u64;
    }
    total
}

136#[cfg(target_arch = "x86_64")]
137#[target_feature(enable = "avx512f,avx512bw")]
138pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
146 assert!(buf.len() >= 64, "buffer must contain at least 64 bytes");
147 let data = _mm512_loadu_si512(buf.as_ptr() as *const __m512i);
148 let thresh = _mm512_set1_epi8(threshold as i8);
149 _mm512_cmplt_epu8_mask(data, thresh)
150}
151
/// Portable fallback for the AVX-512 popcount on non-x86_64 targets.
///
/// Delegates to the scalar word-popcount helper; the fn is `unsafe` only so
/// the signature matches the SIMD variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
    crate::bitstream::popcount_words_portable(data)
}

/// Portable fallback for the AVX-512 bit packer on non-x86_64 targets.
///
/// Delegates to the scalar fast packer and returns its word buffer; the fn
/// is `unsafe` only so the signature matches the SIMD variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
    crate::bitstream::pack_fast(bits).data
}

/// Portable fallback for [`fused_and_popcount_avx512`] on non-x86_64
/// targets; `unsafe` only to keep the signature identical to the SIMD path.
/// `zip` stops at the shorter slice, matching the SIMD path's min-length rule.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    let mut total = 0_u64;
    for (&wa, &wb) in a.iter().zip(b) {
        total += (wa & wb).count_ones() as u64;
    }
    total
}

/// Portable fallback for [`bernoulli_compare_avx512`] on non-x86_64
/// targets: bit `i` of the result is set when `buf[i] < threshold`, for the
/// first 64 bytes at most. `unsafe` only to match the SIMD signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
    buf.iter()
        .take(64)
        .enumerate()
        .fold(0_u64, |mask, (bit, &byte)| {
            mask | (u64::from(byte < threshold) << bit)
        })
}

197#[cfg(target_arch = "x86_64")]
200#[target_feature(enable = "avx512f")]
201pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
206 let len = a.len().min(b.len());
207 let mut acc = _mm512_setzero_pd();
208 let mut chunks_a = a[..len].chunks_exact(8);
209 let mut chunks_b = b[..len].chunks_exact(8);
210
211 for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
212 let va = _mm512_loadu_pd(ca.as_ptr());
213 let vb = _mm512_loadu_pd(cb.as_ptr());
214 acc = _mm512_fmadd_pd(va, vb, acc);
215 }
216
217 let mut sum = _mm512_reduce_add_pd(acc);
218 for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
219 sum += ra * rb;
220 }
221 sum
222}
223
224#[cfg(target_arch = "x86_64")]
225#[target_feature(enable = "avx512f")]
226pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
231 if a.is_empty() {
232 return f64::NEG_INFINITY;
233 }
234 let mut vmax = _mm512_set1_pd(f64::NEG_INFINITY);
235 let mut chunks = a.chunks_exact(8);
236
237 for chunk in chunks.by_ref() {
238 let va = _mm512_loadu_pd(chunk.as_ptr());
239 vmax = _mm512_max_pd(vmax, va);
240 }
241
242 let mut m = _mm512_reduce_max_pd(vmax);
243 for &v in chunks.remainder() {
244 m = m.max(v);
245 }
246 m
247}
248
249#[cfg(target_arch = "x86_64")]
250#[target_feature(enable = "avx512f")]
251pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
256 let mut acc = _mm512_setzero_pd();
257 let mut chunks = a.chunks_exact(8);
258
259 for chunk in chunks.by_ref() {
260 let va = _mm512_loadu_pd(chunk.as_ptr());
261 acc = _mm512_add_pd(acc, va);
262 }
263
264 let mut sum = _mm512_reduce_add_pd(acc);
265 for &v in chunks.remainder() {
266 sum += v;
267 }
268 sum
269}
270
271#[cfg(target_arch = "x86_64")]
272#[target_feature(enable = "avx512f")]
273pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
278 let valpha = _mm512_set1_pd(alpha);
279 let mut chunks = y.chunks_exact_mut(8);
280
281 for chunk in chunks.by_ref() {
282 let vy = _mm512_loadu_pd(chunk.as_ptr());
283 let scaled = _mm512_mul_pd(vy, valpha);
284 _mm512_storeu_pd(chunk.as_mut_ptr(), scaled);
285 }
286
287 for v in chunks.into_remainder() {
288 *v *= alpha;
289 }
290}
291
/// Portable fallback for [`dot_f64_avx512`] on non-x86_64 targets;
/// `unsafe` only to keep the signature identical to the SIMD path.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    let mut acc = 0.0;
    for (x, y) in a[..n].iter().zip(&b[..n]) {
        acc += x * y;
    }
    acc
}

/// Portable fallback for [`max_f64_avx512`] on non-x86_64 targets; returns
/// `f64::NEG_INFINITY` for an empty slice. `unsafe` only to match the SIMD
/// signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
    let mut best = f64::NEG_INFINITY;
    for &v in a {
        best = best.max(v);
    }
    best
}

/// Portable fallback for [`sum_f64_avx512`] on non-x86_64 targets;
/// `unsafe` only to match the SIMD signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
    a.iter().fold(0.0, |acc, &v| acc + v)
}

/// Portable fallback for [`scale_f64_avx512`] on non-x86_64 targets:
/// scales `y` in place by `alpha`. `unsafe` only to match the SIMD
/// signature.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
    y.iter_mut().for_each(|v| *v *= alpha);
}

#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    /// `pack_avx512` must agree with the scalar `pack` across lengths that
    /// straddle the 64-bit word boundaries. Skipped when the CPU lacks
    /// AVX-512BW.
    #[test]
    fn pack_avx512_matches_pack() {
        if !is_x86_feature_detected!("avx512bw") {
            return;
        }

        for length in [1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031] {
            // Deterministic pseudo-random 0/1 pattern.
            let bits: Vec<u8> = (0..length)
                .map(|i| u8::from((i * 19 + 11) % 4 == 0))
                .collect();

            let got = unsafe { super::pack_avx512(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    /// The fused AND+popcount kernel must match the scalar reference for a
    /// range of lengths including non-multiples of the 8-word chunk size.
    #[test]
    fn fused_and_popcount_avx512_matches_scalar() {
        if !is_x86_feature_detected!("avx512vpopcntdq") {
            return;
        }

        for len in [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128] {
            // Two deterministic pseudo-random word streams.
            let make = |mul: u64, xor: u64| -> Vec<u64> {
                (0..len).map(|i| (i as u64).wrapping_mul(mul) ^ xor).collect()
            };
            let a = make(0xD6E8_FD9D_5A2B_1C47, 0x1357_9BDF_2468_ACE0);
            let b = make(0x94D0_49BB_1331_11EB, 0xF0F0_0F0F_AAAA_5555);

            let expected = a
                .iter()
                .zip(&b)
                .fold(0_u64, |acc, (&wa, &wb)| acc + (wa & wb).count_ones() as u64);

            let got = unsafe { super::fused_and_popcount_avx512(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    /// The SIMD byte-threshold mask must match a scalar bit-by-bit build
    /// for thresholds covering both extremes.
    #[test]
    fn bernoulli_compare_avx512_matches_scalar() {
        if !is_x86_feature_detected!("avx512bw") {
            return;
        }

        let buf: Vec<u8> = (0..64).map(|i| (i * 41 + 23) as u8).collect();

        for threshold in [0_u8, 1, 2, 17, 64, 127, 128, 200, 255] {
            let mut expected = 0_u64;
            for (bit, &rb) in buf.iter().enumerate() {
                expected |= u64::from(rb < threshold) << bit;
            }

            let got = unsafe { super::bernoulli_compare_avx512(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }
}