sc_neurocore_engine/simd/avx512.rs

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

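/// Counts set bits in `data` using AVX-512 VPOPCNTDQ, two 8-word vectors per
/// iteration; the tail falls back to the portable word popcount.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f` and `avx512vpopcntdq`
/// (e.g. via `is_x86_feature_detected!`).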
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
    let mut total = 0_u64;
    let mut chunks = data.chunks_exact(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm512_loadu_si512(chunk.as_ptr() as *const __m512i);
        let v1 = _mm512_loadu_si512(chunk.as_ptr().add(8) as *const __m512i);

        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(v0)) as u64;
        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(v1)) as u64;
    }

    total + crate::bitstream::popcount_words_portable(chunks.remainder())
}

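/// Packs a `0`/`1` byte slice into `u64` words: bit `i` of each output word
/// is set when the corresponding input byte is non-zero. Sub-word tails are
/// delegated to the portable `pack_fast`.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f` and `avx512bw`.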
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
    let length = bits.len();
    let words = length.div_ceil(64);
    let mut data = vec![0_u64; words];
    let full_words = length / 64;
    let zero = _mm512_setzero_si512();

    let mut chunks = data[..full_words].chunks_exact_mut(4);
    let mut word_idx = 0;
    for chunk in chunks.by_ref() {
        let base = word_idx * 64;
        for i in 0..4 {
            let v = _mm512_loadu_si512(bits.as_ptr().add(base + i * 64) as *const __m512i);
            chunk[i] = _mm512_cmpneq_epi8_mask(v, zero);
        }
        word_idx += 4;
    }

    for i in word_idx..full_words {
        let v = _mm512_loadu_si512(bits.as_ptr().add(i * 64) as *const __m512i);
        data[i] = _mm512_cmpneq_epi8_mask(v, zero);
    }

    if full_words < words {
        let tail_start = full_words * 64;
        let tail = crate::bitstream::pack_fast(&bits[tail_start..]);
        data[full_words] = tail.data.first().copied().unwrap_or(0);
    }

    data
}

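/// Computes `popcount(a & b)` without materializing the intermediate AND.
/// Operates on the common prefix `min(a.len(), b.len())`.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f` and `avx512vpopcntdq`.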
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let mut total = 0_u64;
    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
        let vb0 = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
        let va1 = _mm512_loadu_si512(ca.as_ptr().add(8) as *const __m512i);
        let vb1 = _mm512_loadu_si512(cb.as_ptr().add(8) as *const __m512i);

        let and0 = _mm512_and_si512(va0, vb0);
        let and1 = _mm512_and_si512(va1, vb1);

        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(and0)) as u64;
        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(and1)) as u64;
    }

    total
        + chunks_a
            .remainder()
            .iter()
            .zip(chunks_b.remainder().iter())
            .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
            .sum::<u64>()
}

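/// Computes `popcount(a ^ b)` (Hamming distance) over the common prefix of
/// `a` and `b`.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f` and `avx512vpopcntdq`.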
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let mut total = 0_u64;
    let mut chunks_a = a[..len].chunks_exact(16);
    let mut chunks_b = b[..len].chunks_exact(16);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va0 = _mm512_loadu_si512(ca.as_ptr() as *const __m512i);
        let vb0 = _mm512_loadu_si512(cb.as_ptr() as *const __m512i);
        let va1 = _mm512_loadu_si512(ca.as_ptr().add(8) as *const __m512i);
        let vb1 = _mm512_loadu_si512(cb.as_ptr().add(8) as *const __m512i);

        let xor0 = _mm512_xor_si512(va0, vb0);
        let xor1 = _mm512_xor_si512(va1, vb1);

        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(xor0)) as u64;
        total += _mm512_reduce_add_epi64(_mm512_popcnt_epi64(xor1)) as u64;
    }

    total
        + chunks_a
            .remainder()
            .iter()
            .zip(chunks_b.remainder().iter())
            .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
            .sum::<u64>()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_xor_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b.iter())
        .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
        .sum()
}

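/// Compares the first 64 bytes of `buf` against `threshold`, returning a
/// bitmask with bit `i` set when `buf[i] < threshold` (unsigned compare).
/// Panics if `buf` holds fewer than 64 bytes.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f` and `avx512bw`.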
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
    assert!(buf.len() >= 64, "buffer must contain at least 64 bytes");
    let data = _mm512_loadu_si512(buf.as_ptr() as *const __m512i);
    let thresh = _mm512_set1_epi8(threshold as i8);
    _mm512_cmplt_epu8_mask(data, thresh)
}

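// Portable fallbacks: the same entry points on non-x86_64 targets, backed by
// scalar code so callers need no feature gating of their own.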
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn popcount_avx512(data: &[u64]) -> u64 {
    crate::bitstream::popcount_words_portable(data)
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn pack_avx512(bits: &[u8]) -> Vec<u64> {
    crate::bitstream::pack_fast(bits).data
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn fused_and_popcount_avx512(a: &[u64], b: &[u64]) -> u64 {
    a.iter()
        .zip(b.iter())
        .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
        .sum()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn bernoulli_compare_avx512(buf: &[u8], threshold: u8) -> u64 {
    let mut mask = 0_u64;
    for (bit, &rb) in buf.iter().take(64).enumerate() {
        if rb < threshold {
            mask |= 1_u64 << bit;
        }
    }
    mask
}

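/// Dot product over the common prefix of `a` and `b`, eight lanes at a time
/// with FMA; the scalar tail is added after the vector reduce.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f`.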
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    let mut acc = _mm512_setzero_pd();
    let mut chunks_a = a[..len].chunks_exact(8);
    let mut chunks_b = b[..len].chunks_exact(8);

    for (ca, cb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
        let va = _mm512_loadu_pd(ca.as_ptr());
        let vb = _mm512_loadu_pd(cb.as_ptr());
        acc = _mm512_fmadd_pd(va, vb, acc);
    }

    let mut sum = _mm512_reduce_add_pd(acc);
    for (&ra, &rb) in chunks_a.remainder().iter().zip(chunks_b.remainder()) {
        sum += ra * rb;
    }
    sum
}

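/// Maximum of `a`, or `f64::NEG_INFINITY` for an empty slice. Uses two
/// independent vector accumulators to hide `max` latency.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f`.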
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
    if a.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut vmax0 = _mm512_set1_pd(f64::NEG_INFINITY);
    let mut vmax1 = _mm512_set1_pd(f64::NEG_INFINITY);
    let mut chunks = a.chunks_exact(16);

    for chunk in chunks.by_ref() {
        vmax0 = _mm512_max_pd(vmax0, _mm512_loadu_pd(chunk.as_ptr()));
        vmax1 = _mm512_max_pd(vmax1, _mm512_loadu_pd(chunk.as_ptr().add(8)));
    }

    let mut m = _mm512_reduce_max_pd(_mm512_max_pd(vmax0, vmax1));
    for &v in chunks.remainder() {
        m = m.max(v);
    }
    m
}

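/// Sums `a` with two vector accumulators plus a scalar tail. The result can
/// differ from a strict left-to-right scalar sum in the last ULPs because the
/// accumulation order differs.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f`.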
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
    let mut acc0 = _mm512_setzero_pd();
    let mut acc1 = _mm512_setzero_pd();
    let mut chunks = a.chunks_exact(16);

    for chunk in chunks.by_ref() {
        acc0 = _mm512_add_pd(acc0, _mm512_loadu_pd(chunk.as_ptr()));
        acc1 = _mm512_add_pd(acc1, _mm512_loadu_pd(chunk.as_ptr().add(8)));
    }

    let mut sum = _mm512_reduce_add_pd(_mm512_add_pd(acc0, acc1));
    for &v in chunks.remainder() {
        sum += v;
    }
    sum
}

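/// In-place scaling `y[i] *= alpha`, 16 elements per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f`.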
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
    let valpha = _mm512_set1_pd(alpha);
    let mut chunks = y.chunks_exact_mut(16);

    for chunk in chunks.by_ref() {
        let v0 = _mm512_loadu_pd(chunk.as_ptr());
        let v1 = _mm512_loadu_pd(chunk.as_ptr().add(8));
        _mm512_storeu_pd(chunk.as_mut_ptr(), _mm512_mul_pd(v0, valpha));
        _mm512_storeu_pd(chunk.as_mut_ptr().add(8), _mm512_mul_pd(v1, valpha));
    }

    for v in chunks.into_remainder() {
        *v *= alpha;
    }
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn max_f64_avx512(a: &[f64]) -> f64 {
    a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn sum_f64_avx512(a: &[f64]) -> f64 {
    a.iter().sum()
}

#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn scale_f64_avx512(alpha: f64, y: &mut [f64]) {
    for v in y.iter_mut() {
        *v *= alpha;
    }
}

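/// Batch variant of `bernoulli_compare_avx512`: fills `out[0..16]` from 16
/// consecutive 64-byte blocks of `buf` (1024 bytes total). Panics if `buf`
/// or `out` is too short.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f` and `avx512bw`.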
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
pub unsafe fn bernoulli_compare_batch_avx512(buf: &[u8], threshold: u8, out: &mut [u64]) {
    assert!(buf.len() >= 16 * 64, "buffer must contain at least 1024 bytes");
    assert!(out.len() >= 16, "output must hold at least 16 words");
    let v_thresh = _mm512_set1_epi8(threshold as i8);
    for i in 0..16 {
        let chunk = &buf[i * 64..(i + 1) * 64];
        let v = _mm512_loadu_si512(chunk.as_ptr() as *const _);
        out[i] = _mm512_cmplt_epu8_mask(v, v_thresh);
    }
}

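// Hypothetical portable companion for the batch compare: every other entry
// point above has a non-x86_64 shim, so this is a minimal sketch of the same
// pattern, reusing the scalar logic of `bernoulli_compare_avx512`'s fallback.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn bernoulli_compare_batch_avx512(buf: &[u8], threshold: u8, out: &mut [u64]) {
    for i in 0..16 {
        let mut mask = 0_u64;
        for (bit, &rb) in buf[i * 64..(i + 1) * 64].iter().enumerate() {
            if rb < threshold {
                mask |= 1_u64 << bit;
            }
        }
        out[i] = mask;
    }
}
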
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use crate::bitstream::pack;

    #[test]
    fn pack_avx512_matches_pack() {
        if !is_x86_feature_detected!("avx512bw") {
            return;
        }

        let lengths = [1_usize, 7, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024, 1031];
        for length in lengths {
            let bits: Vec<u8> = (0..length)
                .map(|i| if (i * 19 + 11) % 4 == 0 { 1 } else { 0 })
                .collect();
            let got = unsafe { super::pack_avx512(&bits) };
            let expected = pack(&bits).data;
            assert_eq!(got, expected, "Mismatch at length={length}");
        }
    }

    #[test]
    fn fused_and_popcount_avx512_matches_scalar() {
        if !is_x86_feature_detected!("avx512vpopcntdq") {
            return;
        }

        let lengths = [1_usize, 7, 8, 15, 16, 17, 31, 32, 64, 128];
        for len in lengths {
            let a: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0xD6E8_FD9D_5A2B_1C47) ^ 0x1357_9BDF_2468_ACE0)
                .collect();
            let b: Vec<u64> = (0..len)
                .map(|i| (i as u64).wrapping_mul(0x94D0_49BB_1331_11EB) ^ 0xF0F0_0F0F_AAAA_5555)
                .collect();

            let expected: u64 = a
                .iter()
                .zip(b.iter())
                .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
                .sum();

            let got = unsafe { super::fused_and_popcount_avx512(&a, &b) };
            assert_eq!(got, expected, "Mismatch at len={len}");
        }
    }

    #[test]
    fn bernoulli_compare_avx512_matches_scalar() {
        if !is_x86_feature_detected!("avx512bw") {
            return;
        }

        let buf: Vec<u8> = (0..64).map(|i| (i * 41 + 23) as u8).collect();
        let thresholds = [0_u8, 1, 2, 17, 64, 127, 128, 200, 255];

        for threshold in thresholds {
            let expected = buf.iter().enumerate().fold(0_u64, |acc, (bit, &rb)| {
                acc | (u64::from(rb < threshold) << bit)
            });

            let got = unsafe { super::bernoulli_compare_avx512(&buf, threshold) };
            assert_eq!(
                got, expected,
                "Mismatch for threshold={threshold} buf={buf:?}"
            );
        }
    }
}