1use rand::SeedableRng;
13use rand_chacha::ChaCha8Rng;
14use rayon::prelude::*;
15
/// Attention-layer configuration: a key dimension and a score temperature.
///
/// Both fields are public and may be set directly. [`StochasticAttention::new`]
/// fills in `temperature = sqrt(dim_k)`.
#[derive(Debug, Clone, PartialEq)]
pub struct StochasticAttention {
    /// Key/query feature width supplied at construction time.
    pub dim_k: usize,
    /// Divisor applied to raw scores by `forward_softmax`. A value that is not
    /// strictly positive disables the scaling there.
    pub temperature: f64,
}
21
22impl StochasticAttention {
23 pub fn new(dim_k: usize) -> Self {
24 Self {
25 dim_k,
26 temperature: (dim_k as f64).sqrt(),
27 }
28 }
29
30 pub fn with_temperature(dim_k: usize, temperature: f64) -> Self {
31 Self { dim_k, temperature }
32 }
33
34 #[allow(clippy::too_many_arguments)]
38 pub fn forward_softmax(
39 &self,
40 q: &[f64],
41 q_rows: usize,
42 q_cols: usize,
43 k: &[f64],
44 k_rows: usize,
45 k_cols: usize,
46 v: &[f64],
47 v_rows: usize,
48 v_cols: usize,
49 ) -> Result<Vec<f64>, String> {
50 validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
51 let inv_temp = if self.temperature > 0.0 {
52 1.0 / self.temperature
53 } else {
54 1.0
55 };
56
57 let out_rows: Vec<Vec<f64>> = (0..q_rows)
58 .into_par_iter()
59 .map(|i| {
60 let q_row = &q[i * q_cols..(i + 1) * q_cols];
61
62 let mut scores = vec![0.0_f64; k_rows];
63 for j in 0..k_rows {
64 let k_row = &k[j * k_cols..(j + 1) * k_cols];
65 scores[j] = crate::simd::dot_f64_dispatch(q_row, k_row) * inv_temp;
66 }
67
68 crate::simd::softmax_inplace_f64_dispatch(&mut scores);
69
70 let mut out = vec![0.0_f64; v_cols];
72 for d in 0..v_cols {
73 let mut acc = 0.0_f64;
74 for j in 0..k_rows {
75 acc += scores[j] * v[j * v_cols + d];
76 }
77 out[d] = acc;
78 }
79 out
80 })
81 .collect();
82
83 Ok(flatten_rows(out_rows, q_rows, v_cols))
84 }
85
86 #[allow(clippy::too_many_arguments)]
88 pub fn forward_multihead_softmax(
89 &self,
90 q: &[f64],
91 q_rows: usize,
92 q_total_cols: usize,
93 k: &[f64],
94 k_rows: usize,
95 k_total_cols: usize,
96 v: &[f64],
97 v_rows: usize,
98 v_total_cols: usize,
99 n_heads: usize,
100 ) -> Result<Vec<f64>, String> {
101 validate_multihead_shapes(
102 q,
103 q_rows,
104 q_total_cols,
105 k,
106 k_rows,
107 k_total_cols,
108 v,
109 v_rows,
110 v_total_cols,
111 n_heads,
112 )?;
113
114 let dk = q_total_cols / n_heads;
115 let dv = v_total_cols / n_heads;
116 let head_attn = Self::with_temperature(dk, self.temperature);
117
118 let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
119 .into_par_iter()
120 .map(|h| {
121 let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
122 let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
123 let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
124 head_attn.forward_softmax(
125 &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
126 )
127 })
128 .collect();
129 let head_outputs = head_outputs?;
130
131 let out_cols = dv * n_heads;
132 let mut out = Vec::with_capacity(q_rows * out_cols);
133 for i in 0..q_rows {
134 for head in head_outputs.iter().take(n_heads) {
135 out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
136 }
137 }
138 Ok(out)
139 }
140
141 #[allow(clippy::too_many_arguments)]
148 pub fn forward(
149 &self,
150 q: &[f64],
151 q_rows: usize,
152 q_cols: usize,
153 k: &[f64],
154 k_rows: usize,
155 k_cols: usize,
156 v: &[f64],
157 v_rows: usize,
158 v_cols: usize,
159 ) -> Result<Vec<f64>, String> {
160 validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
161
162 let score_rows: Vec<Vec<f64>> = (0..q_rows)
163 .into_par_iter()
164 .map(|i| {
165 let q_row = &q[i * q_cols..(i + 1) * q_cols];
166 let mut row = vec![0.0_f64; k_rows];
167 for j in 0..k_rows {
168 let k_row = &k[j * k_cols..(j + 1) * k_cols];
169 row[j] = crate::simd::dot_f64_dispatch(q_row, k_row);
170 }
171 row
172 })
173 .collect();
174
175 let out_rows: Vec<Vec<f64>> = (0..q_rows)
176 .into_par_iter()
177 .map(|i| {
178 let scores = &score_rows[i];
179 let mut row_sum = scores.iter().sum::<f64>();
180 if row_sum == 0.0 {
181 row_sum = 1.0;
182 }
183
184 let mut out = vec![0.0_f64; v_cols];
185 for d in 0..v_cols {
186 let mut acc = 0.0_f64;
187 for j in 0..k_rows {
188 let weight = scores[j] / row_sum;
189 acc += weight * v[j * v_cols + d];
190 }
191 out[d] = acc;
192 }
193 out
194 })
195 .collect();
196
197 Ok(flatten_rows(out_rows, q_rows, v_cols))
198 }
199
200 #[allow(clippy::too_many_arguments)]
202 pub fn forward_sc(
203 &self,
204 q: &[f64],
205 q_rows: usize,
206 q_cols: usize,
207 k: &[f64],
208 k_rows: usize,
209 k_cols: usize,
210 v: &[f64],
211 v_rows: usize,
212 v_cols: usize,
213 length: usize,
214 seed: u64,
215 ) -> Result<Vec<f64>, String> {
216 validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
217 if length == 0 {
218 return Err("length must be > 0 for SC mode.".to_string());
219 }
220
221 let mut rng = ChaCha8Rng::seed_from_u64(seed);
222 let words = length.div_ceil(64);
223
224 let q_packed = crate::bitstream::encode_matrix_prob_to_packed(
225 q, q_rows, q_cols, length, words, &mut rng,
226 );
227 let k_packed = crate::bitstream::encode_matrix_prob_to_packed(
228 k, k_rows, k_cols, length, words, &mut rng,
229 );
230 let v_packed = crate::bitstream::encode_matrix_prob_to_packed(
231 v, v_rows, v_cols, length, words, &mut rng,
232 );
233
234 let mut score_rows = vec![vec![0.0_f64; k_rows]; q_rows];
235 for (i, score_row) in score_rows.iter_mut().enumerate().take(q_rows) {
236 for (j, score_value) in score_row.iter_mut().enumerate().take(k_rows) {
237 let mut pop_total = 0_u64;
238 for d in 0..q_cols {
239 let q_idx = i * q_cols + d;
240 let k_idx = j * k_cols + d;
241 let qa = &q_packed[q_idx];
242 let kb = &k_packed[k_idx];
243 for w in 0..words {
244 pop_total += crate::bitstream::swar_popcount_word(qa[w] & kb[w]);
245 }
246 }
247 *score_value = pop_total as f64 / length as f64;
248 }
249 }
250
251 let attn_weights: Vec<Vec<f64>> = score_rows
252 .iter()
253 .map(|row| {
254 let mut row_sum = row.iter().sum::<f64>();
255 if row_sum == 0.0 {
256 row_sum = 1.0;
257 }
258 row.iter().map(|x| x / row_sum).collect()
259 })
260 .collect();
261
262 let attn_flat: Vec<f64> = attn_weights.into_iter().flatten().collect();
263 let attn_packed = crate::bitstream::encode_matrix_prob_to_packed(
264 &attn_flat, q_rows, k_rows, length, words, &mut rng,
265 );
266
267 let out_rows: Vec<Vec<f64>> = (0..q_rows)
268 .into_par_iter()
269 .map(|i| {
270 let mut out = vec![0.0_f64; v_cols];
271 for d in 0..v_cols {
272 let mut pop_total = 0_u64;
273 for j in 0..k_rows {
274 let a = &attn_packed[i * k_rows + j];
275 let b = &v_packed[j * v_cols + d];
276 for w in 0..words {
277 pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
278 }
279 }
280 out[d] = pop_total as f64 / length as f64;
281 }
282 out
283 })
284 .collect();
285
286 Ok(flatten_rows(out_rows, q_rows, v_cols))
287 }
288
289 #[allow(clippy::too_many_arguments)]
291 pub fn forward_multihead(
292 &self,
293 q: &[f64],
294 q_rows: usize,
295 q_total_cols: usize,
296 k: &[f64],
297 k_rows: usize,
298 k_total_cols: usize,
299 v: &[f64],
300 v_rows: usize,
301 v_total_cols: usize,
302 n_heads: usize,
303 ) -> Result<Vec<f64>, String> {
304 validate_multihead_shapes(
305 q,
306 q_rows,
307 q_total_cols,
308 k,
309 k_rows,
310 k_total_cols,
311 v,
312 v_rows,
313 v_total_cols,
314 n_heads,
315 )?;
316
317 let dk = q_total_cols / n_heads;
318 let dv = v_total_cols / n_heads;
319
320 let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
321 .into_par_iter()
322 .map(|h| {
323 let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
324 let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
325 let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
326 self.forward(
327 &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
328 )
329 })
330 .collect();
331 let head_outputs = head_outputs?;
332
333 let out_cols = dv * n_heads;
334 let mut out = Vec::with_capacity(q_rows * out_cols);
335 for i in 0..q_rows {
336 for head in head_outputs.iter().take(n_heads) {
337 out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
338 }
339 }
340 Ok(out)
341 }
342}
343
/// Validates the inputs of a multi-head attention call.
///
/// Checks, in order:
/// 1. `n_heads` is non-zero;
/// 2. the column counts of Q, K and V are each divisible by `n_heads`;
/// 3. each buffer's length equals `rows * total_cols` (Q, then K, then V);
/// 4. the per-head widths of Q and K are equal.
///
/// # Errors
///
/// Returns a descriptive message for the first failed check. A `rows * cols`
/// product that does not fit in `usize` is reported as an error rather than
/// panicking (debug builds) or wrapping (release builds).
#[allow(clippy::too_many_arguments)]
fn validate_multihead_shapes(
    q: &[f64],
    q_rows: usize,
    q_total_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_total_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_total_cols: usize,
    n_heads: usize,
) -> Result<(), String> {
    if n_heads == 0 {
        return Err("n_heads must be > 0.".to_string());
    }
    if !q_total_cols.is_multiple_of(n_heads)
        || !k_total_cols.is_multiple_of(n_heads)
        || !v_total_cols.is_multiple_of(n_heads)
    {
        return Err(format!(
            "Total columns must be divisible by n_heads={}. Got Q={}, K={}, V={}.",
            n_heads, q_total_cols, k_total_cols, v_total_cols
        ));
    }

    for (name, data, rows, cols) in [
        ("Q", q, q_rows, q_total_cols),
        ("K", k, k_rows, k_total_cols),
        ("V", v, v_rows, v_total_cols),
    ] {
        // A wrapped product could equal `data.len()` by accident and let an
        // impossible shape through, so overflow is rejected explicitly.
        let expected = rows.checked_mul(cols).ok_or_else(|| {
            format!(
                "{} shape overflows usize: rows={}, cols={}.",
                name, rows, cols
            )
        })?;
        if data.len() != expected {
            return Err(format!(
                "{} data length mismatch: got {}, expected {}.",
                name,
                data.len(),
                expected
            ));
        }
    }

    let dk = q_total_cols / n_heads;
    let dk_k = k_total_cols / n_heads;
    if dk != dk_k {
        return Err(format!(
            "Q/K head dimensions must match: Q_head={}, K_head={}.",
            dk, dk_k
        ));
    }
    Ok(())
}
400
/// Validates the inputs of a single-head attention call.
///
/// Checks, in order:
/// 1. Q and K have the same number of columns;
/// 2. K and V have the same number of rows;
/// 3. each buffer's length equals `rows * cols` (Q, then K, then V).
///
/// # Errors
///
/// Returns a descriptive message for the first failed check. A `rows * cols`
/// product that does not fit in `usize` is reported as an error rather than
/// panicking (debug builds) or wrapping (release builds).
#[allow(clippy::too_many_arguments)]
fn validate_shapes(
    q: &[f64],
    q_rows: usize,
    q_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_cols: usize,
) -> Result<(), String> {
    if q_cols != k_cols {
        return Err(format!(
            "Q/K dimension mismatch: q_cols={}, k_cols={}.",
            q_cols, k_cols
        ));
    }
    if k_rows != v_rows {
        return Err(format!(
            "K/V row mismatch: k_rows={}, v_rows={}.",
            k_rows, v_rows
        ));
    }

    for (name, data, rows, cols) in [
        ("Q", q, q_rows, q_cols),
        ("K", k, k_rows, k_cols),
        ("V", v, v_rows, v_cols),
    ] {
        // A wrapped product could equal `data.len()` by accident and let an
        // impossible shape through, so overflow is rejected explicitly.
        let expected = rows.checked_mul(cols).ok_or_else(|| {
            format!(
                "{} shape overflows usize: rows={}, cols={}.",
                name, rows, cols
            )
        })?;
        if data.len() != expected {
            return Err(format!(
                "{} data length mismatch: got {}, expected {}.",
                name,
                data.len(),
                expected
            ));
        }
    }
    Ok(())
}
448
/// Concatenates `rows` in order into one contiguous buffer.
///
/// `n_rows * n_cols` is used only as a capacity hint for the output; the rows
/// themselves are appended whatever their actual lengths are.
fn flatten_rows(rows: Vec<Vec<f64>>, n_rows: usize, n_cols: usize) -> Vec<f64> {
    let mut flat = Vec::with_capacity(n_rows * n_cols);
    flat.extend(rows.into_iter().flatten());
    flat
}
456
/// Copies one head's column slice out of a row-major matrix.
///
/// For each of the `rows` rows (stride `total_cols`), takes the `head_cols`
/// consecutive values starting at column `head_idx * head_cols` and appends
/// them to the result, producing a row-major `rows x head_cols` buffer.
///
/// # Panics
///
/// Panics if any requested range falls outside `matrix`.
fn extract_head_columns(
    matrix: &[f64],
    rows: usize,
    total_cols: usize,
    head_idx: usize,
    head_cols: usize,
) -> Vec<f64> {
    let first_col = head_idx * head_cols;
    let mut head = Vec::with_capacity(rows * head_cols);
    head.extend((0..rows).flat_map(move |r| {
        let begin = r * total_cols + first_col;
        matrix[begin..begin + head_cols].iter().copied()
    }));
    head
}