//! sc_neurocore_engine/simd/mod.rs — runtime dispatch to the fastest
//! available SIMD backend (AVX2 / AVX-512 / NEON / SVE / RVV).
use rand::Rng;

pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod rvv;
pub mod sve;
21pub fn pack_dispatch(bits: &[u8]) -> crate::bitstream::BitStreamTensor {
23 let length = bits.len();
24
25 #[cfg(target_arch = "x86_64")]
26 {
27 if is_x86_feature_detected!("avx512bw") {
28 let data = unsafe { avx512::pack_avx512(bits) };
30 return crate::bitstream::BitStreamTensor { data, length };
31 }
32 if is_x86_feature_detected!("avx2") {
33 let data = unsafe { avx2::pack_avx2(bits) };
35 return crate::bitstream::BitStreamTensor { data, length };
36 }
37 }
38
39 #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
40 {
41 let data = unsafe { sve::pack_sve(bits) };
43 return crate::bitstream::BitStreamTensor { data, length };
44 }
45
46 crate::bitstream::pack_fast(bits)
47}
48
49pub fn popcount_dispatch(data: &[u64]) -> u64 {
51 #[cfg(target_arch = "x86_64")]
52 {
53 if is_x86_feature_detected!("avx512vpopcntdq") {
54 return unsafe { avx512::popcount_avx512(data) };
56 }
57 if is_x86_feature_detected!("avx2") {
58 return unsafe { avx2::popcount_avx2(data) };
60 }
61 }
62
63 #[cfg(target_arch = "aarch64")]
64 {
65 #[cfg(target_feature = "sve")]
66 {
67 return unsafe { sve::popcount_sve(data) };
69 }
70 #[cfg(not(target_feature = "sve"))]
71 {
72 return unsafe { neon::popcount_neon(data) };
74 }
75 }
76
77 #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
78 {
79 return unsafe { rvv::popcount_rvv(data) };
81 }
82
83 crate::bitstream::popcount_words_portable(data)
84}
85
86pub fn fused_and_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
88 let len = a.len().min(b.len());
89 let a = &a[..len];
90 let b = &b[..len];
91
92 #[cfg(target_arch = "x86_64")]
93 {
94 if is_x86_feature_detected!("avx512vpopcntdq") {
95 return unsafe { avx512::fused_and_popcount_avx512(a, b) };
97 }
98 if is_x86_feature_detected!("avx2") {
99 return unsafe { avx2::fused_and_popcount_avx2(a, b) };
101 }
102 }
103
104 #[cfg(target_arch = "aarch64")]
105 {
106 #[cfg(target_feature = "sve")]
107 {
108 return unsafe { sve::fused_and_popcount_sve(a, b) };
109 }
110 }
111
112 #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
113 {
114 return unsafe { rvv::fused_and_popcount_rvv(a, b) };
115 }
116
117 a.iter()
118 .zip(b.iter())
119 .map(|(&wa, &wb)| (wa & wb).count_ones() as u64)
120 .sum()
121}
122
123pub fn fused_xor_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
125 let len = a.len().min(b.len());
126 let a = &a[..len];
127 let b = &b[..len];
128
129 #[cfg(target_arch = "x86_64")]
130 {
131 if is_x86_feature_detected!("avx512vpopcntdq") {
132 return unsafe { avx512::fused_xor_popcount_avx512(a, b) };
134 }
135 if is_x86_feature_detected!("avx2") {
136 return unsafe { avx2::fused_xor_popcount_avx2(a, b) };
138 }
139 }
140
141 #[cfg(target_arch = "aarch64")]
142 {
143 #[cfg(target_feature = "sve")]
144 {
145 return unsafe { sve::fused_xor_popcount_sve(a, b) };
146 }
147 }
148
149 #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
150 {
151 return unsafe { rvv::fused_xor_popcount_rvv(a, b) };
152 }
153
154 a.iter()
155 .zip(b.iter())
156 .map(|(&wa, &wb)| (wa ^ wb).count_ones() as u64)
157 .sum()
158}
159
160pub fn dot_f64_dispatch(a: &[f64], b: &[f64]) -> f64 {
164 #[cfg(target_arch = "x86_64")]
165 {
166 if is_x86_feature_detected!("avx512f") {
167 return unsafe { avx512::dot_f64_avx512(a, b) };
168 }
169 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
170 return unsafe { avx2::dot_f64_avx2(a, b) };
171 }
172 }
173
174 #[cfg(target_arch = "aarch64")]
175 {
176 return unsafe { neon::dot_f64_neon(a, b) };
177 }
178
179 #[allow(unreachable_code)]
180 {
181 let len = a.len().min(b.len());
182 a[..len].iter().zip(&b[..len]).map(|(&x, &y)| x * y).sum()
183 }
184}
185
186pub fn max_f64_dispatch(a: &[f64]) -> f64 {
188 #[cfg(target_arch = "x86_64")]
189 {
190 if is_x86_feature_detected!("avx512f") {
191 return unsafe { avx512::max_f64_avx512(a) };
192 }
193 if is_x86_feature_detected!("avx2") {
194 return unsafe { avx2::max_f64_avx2(a) };
195 }
196 }
197
198 #[cfg(target_arch = "aarch64")]
199 {
200 return unsafe { neon::max_f64_neon(a) };
201 }
202
203 #[allow(unreachable_code)]
204 a.iter().copied().fold(f64::NEG_INFINITY, f64::max)
205}
206
207pub fn sum_f64_dispatch(a: &[f64]) -> f64 {
209 #[cfg(target_arch = "x86_64")]
210 {
211 if is_x86_feature_detected!("avx512f") {
212 return unsafe { avx512::sum_f64_avx512(a) };
213 }
214 if is_x86_feature_detected!("avx2") {
215 return unsafe { avx2::sum_f64_avx2(a) };
216 }
217 }
218
219 #[cfg(target_arch = "aarch64")]
220 {
221 return unsafe { neon::sum_f64_neon(a) };
222 }
223
224 #[allow(unreachable_code)]
225 a.iter().sum()
226}
227
228pub fn scale_f64_dispatch(alpha: f64, y: &mut [f64]) {
230 #[cfg(target_arch = "x86_64")]
231 {
232 if is_x86_feature_detected!("avx512f") {
233 unsafe { avx512::scale_f64_avx512(alpha, y) };
234 return;
235 }
236 if is_x86_feature_detected!("avx2") {
237 unsafe { avx2::scale_f64_avx2(alpha, y) };
238 return;
239 }
240 }
241
242 #[cfg(target_arch = "aarch64")]
243 {
244 unsafe { neon::scale_f64_neon(alpha, y) };
245 return;
246 }
247
248 #[allow(unreachable_code)]
249 for v in y.iter_mut() {
250 *v *= alpha;
251 }
252}
253
/// Hamming distance between two bit-packed word slices: the number of bit
/// positions that differ over the overlapping prefix of `a` and `b`.
///
/// Thin alias for [`fused_xor_popcount_dispatch`], which already selects
/// the best available SIMD backend.
pub fn hamming_distance_dispatch(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_dispatch(a, b)
}
258
259pub fn softmax_inplace_f64_dispatch(scores: &mut [f64]) {
264 if scores.is_empty() {
265 return;
266 }
267 let max_val = max_f64_dispatch(scores);
268 for s in scores.iter_mut() {
269 *s = (*s - max_val).exp();
270 }
271 let exp_sum = sum_f64_dispatch(scores);
272 if exp_sum > 0.0 {
273 scale_f64_dispatch(1.0 / exp_sum, scores);
274 }
275}
276
/// Dispatch point for the fused encode-and-popcount operation.
///
/// No SIMD specialization exists for this op yet, so it delegates
/// unconditionally to `crate::bitstream::encode_and_popcount`.
///
/// NOTE(review): the exact semantics (presumably: stochastically encode a
/// `length`-bit stream with ones-probability `prob` using `rng`, AND it
/// with `weight_words`, and return the popcount) live in the delegate —
/// confirm against `crate::bitstream` before relying on this description.
pub fn encode_and_popcount_dispatch<R: Rng + ?Sized>(
    weight_words: &[u64],
    prob: f64,
    length: usize,
    rng: &mut R,
) -> u64 {
    crate::bitstream::encode_and_popcount(weight_words, prob, length, rng)
}