1use rand::SeedableRng;
14use rand_chacha::ChaCha8Rng;
15use rayon::prelude::*;
16
/// Attention operator with softmax, sum-normalised, multi-head and
/// stochastic-computing (bitstream) forward passes over dense row-major
/// `f64` matrices.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StochasticAttention {
    /// Key/query width this operator was configured for; `new` derives the
    /// default temperature from it.
    /// NOTE(review): the forward passes take their widths from their own
    /// arguments and do not read this field.
    pub dim_k: usize,
    /// Softmax temperature: the softmax paths multiply scores by
    /// `1 / temperature`, and a non-positive value disables that scaling.
    pub temperature: f64,
}
22
23impl StochasticAttention {
24 pub fn new(dim_k: usize) -> Self {
25 Self {
26 dim_k,
27 temperature: (dim_k as f64).sqrt(),
28 }
29 }
30
31 pub fn with_temperature(dim_k: usize, temperature: f64) -> Self {
32 Self { dim_k, temperature }
33 }
34
35 #[allow(clippy::too_many_arguments)]
39 pub fn forward_softmax(
40 &self,
41 q: &[f64],
42 q_rows: usize,
43 q_cols: usize,
44 k: &[f64],
45 k_rows: usize,
46 k_cols: usize,
47 v: &[f64],
48 v_rows: usize,
49 v_cols: usize,
50 ) -> Result<Vec<f64>, String> {
51 validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
52 let inv_temp = if self.temperature > 0.0 {
53 1.0 / self.temperature
54 } else {
55 1.0
56 };
57
58 let out_rows: Vec<Vec<f64>> = (0..q_rows)
59 .into_par_iter()
60 .map(|i| {
61 let q_row = &q[i * q_cols..(i + 1) * q_cols];
62
63 let mut scores = vec![0.0_f64; k_rows];
64 for j in 0..k_rows {
65 let k_row = &k[j * k_cols..(j + 1) * k_cols];
66 scores[j] = crate::simd::dot_f64_dispatch(q_row, k_row) * inv_temp;
67 }
68
69 crate::simd::softmax_inplace_f64_dispatch(&mut scores);
70
71 let mut out = vec![0.0_f64; v_cols];
73 for d in 0..v_cols {
74 let mut acc = 0.0_f64;
75 for j in 0..k_rows {
76 acc += scores[j] * v[j * v_cols + d];
77 }
78 out[d] = acc;
79 }
80 out
81 })
82 .collect();
83
84 Ok(flatten_rows(out_rows, q_rows, v_cols))
85 }
86
87 #[allow(clippy::too_many_arguments)]
89 pub fn forward_multihead_softmax(
90 &self,
91 q: &[f64],
92 q_rows: usize,
93 q_total_cols: usize,
94 k: &[f64],
95 k_rows: usize,
96 k_total_cols: usize,
97 v: &[f64],
98 v_rows: usize,
99 v_total_cols: usize,
100 n_heads: usize,
101 ) -> Result<Vec<f64>, String> {
102 validate_multihead_shapes(
103 q,
104 q_rows,
105 q_total_cols,
106 k,
107 k_rows,
108 k_total_cols,
109 v,
110 v_rows,
111 v_total_cols,
112 n_heads,
113 )?;
114
115 let dk = q_total_cols / n_heads;
116 let dv = v_total_cols / n_heads;
117 let head_attn = Self::with_temperature(dk, self.temperature);
118
119 let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
120 .into_par_iter()
121 .map(|h| {
122 let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
123 let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
124 let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
125 head_attn.forward_softmax(
126 &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
127 )
128 })
129 .collect();
130 let head_outputs = head_outputs?;
131
132 let out_cols = dv * n_heads;
133 let mut out = Vec::with_capacity(q_rows * out_cols);
134 for i in 0..q_rows {
135 for head in head_outputs.iter().take(n_heads) {
136 out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
137 }
138 }
139 Ok(out)
140 }
141
142 #[allow(clippy::too_many_arguments)]
149 pub fn forward(
150 &self,
151 q: &[f64],
152 q_rows: usize,
153 q_cols: usize,
154 k: &[f64],
155 k_rows: usize,
156 k_cols: usize,
157 v: &[f64],
158 v_rows: usize,
159 v_cols: usize,
160 ) -> Result<Vec<f64>, String> {
161 validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
162
163 let score_rows: Vec<Vec<f64>> = (0..q_rows)
164 .into_par_iter()
165 .map(|i| {
166 let q_row = &q[i * q_cols..(i + 1) * q_cols];
167 let mut row = vec![0.0_f64; k_rows];
168 for j in 0..k_rows {
169 let k_row = &k[j * k_cols..(j + 1) * k_cols];
170 row[j] = crate::simd::dot_f64_dispatch(q_row, k_row);
171 }
172 row
173 })
174 .collect();
175
176 let out_rows: Vec<Vec<f64>> = (0..q_rows)
177 .into_par_iter()
178 .map(|i| {
179 let scores = &score_rows[i];
180 let mut row_sum = scores.iter().sum::<f64>();
181 if row_sum == 0.0 {
182 row_sum = 1.0;
183 }
184
185 let mut out = vec![0.0_f64; v_cols];
186 for d in 0..v_cols {
187 let mut acc = 0.0_f64;
188 for j in 0..k_rows {
189 let weight = scores[j] / row_sum;
190 acc += weight * v[j * v_cols + d];
191 }
192 out[d] = acc;
193 }
194 out
195 })
196 .collect();
197
198 Ok(flatten_rows(out_rows, q_rows, v_cols))
199 }
200
201 #[allow(clippy::too_many_arguments)]
203 pub fn forward_sc(
204 &self,
205 q: &[f64],
206 q_rows: usize,
207 q_cols: usize,
208 k: &[f64],
209 k_rows: usize,
210 k_cols: usize,
211 v: &[f64],
212 v_rows: usize,
213 v_cols: usize,
214 length: usize,
215 seed: u64,
216 ) -> Result<Vec<f64>, String> {
217 validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
218 if length == 0 {
219 return Err("length must be > 0 for SC mode.".to_string());
220 }
221
222 let mut rng = ChaCha8Rng::seed_from_u64(seed);
223 let words = length.div_ceil(64);
224
225 let q_packed = crate::bitstream::encode_matrix_prob_to_packed(
226 q, q_rows, q_cols, length, words, &mut rng,
227 );
228 let k_packed = crate::bitstream::encode_matrix_prob_to_packed(
229 k, k_rows, k_cols, length, words, &mut rng,
230 );
231 let v_packed = crate::bitstream::encode_matrix_prob_to_packed(
232 v, v_rows, v_cols, length, words, &mut rng,
233 );
234
235 let mut score_rows = vec![vec![0.0_f64; k_rows]; q_rows];
236 for (i, score_row) in score_rows.iter_mut().enumerate().take(q_rows) {
237 for (j, score_value) in score_row.iter_mut().enumerate().take(k_rows) {
238 let mut pop_total = 0_u64;
239 for d in 0..q_cols {
240 let q_idx = i * q_cols + d;
241 let k_idx = j * k_cols + d;
242 let qa = &q_packed[q_idx];
243 let kb = &k_packed[k_idx];
244 for w in 0..words {
245 pop_total += crate::bitstream::swar_popcount_word(qa[w] & kb[w]);
246 }
247 }
248 *score_value = pop_total as f64 / length as f64;
249 }
250 }
251
252 let attn_weights: Vec<Vec<f64>> = score_rows
253 .iter()
254 .map(|row| {
255 let mut row_sum = row.iter().sum::<f64>();
256 if row_sum == 0.0 {
257 row_sum = 1.0;
258 }
259 row.iter().map(|x| x / row_sum).collect()
260 })
261 .collect();
262
263 let attn_flat: Vec<f64> = attn_weights.into_iter().flatten().collect();
264 let attn_packed = crate::bitstream::encode_matrix_prob_to_packed(
265 &attn_flat, q_rows, k_rows, length, words, &mut rng,
266 );
267
268 let out_rows: Vec<Vec<f64>> = (0..q_rows)
269 .into_par_iter()
270 .map(|i| {
271 let mut out = vec![0.0_f64; v_cols];
272 for d in 0..v_cols {
273 let mut pop_total = 0_u64;
274 for j in 0..k_rows {
275 let a = &attn_packed[i * k_rows + j];
276 let b = &v_packed[j * v_cols + d];
277 for w in 0..words {
278 pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
279 }
280 }
281 out[d] = pop_total as f64 / length as f64;
282 }
283 out
284 })
285 .collect();
286
287 Ok(flatten_rows(out_rows, q_rows, v_cols))
288 }
289
290 #[allow(clippy::too_many_arguments)]
292 pub fn forward_multihead(
293 &self,
294 q: &[f64],
295 q_rows: usize,
296 q_total_cols: usize,
297 k: &[f64],
298 k_rows: usize,
299 k_total_cols: usize,
300 v: &[f64],
301 v_rows: usize,
302 v_total_cols: usize,
303 n_heads: usize,
304 ) -> Result<Vec<f64>, String> {
305 validate_multihead_shapes(
306 q,
307 q_rows,
308 q_total_cols,
309 k,
310 k_rows,
311 k_total_cols,
312 v,
313 v_rows,
314 v_total_cols,
315 n_heads,
316 )?;
317
318 let dk = q_total_cols / n_heads;
319 let dv = v_total_cols / n_heads;
320
321 let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
322 .into_par_iter()
323 .map(|h| {
324 let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
325 let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
326 let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
327 self.forward(
328 &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
329 )
330 })
331 .collect();
332 let head_outputs = head_outputs?;
333
334 let out_cols = dv * n_heads;
335 let mut out = Vec::with_capacity(q_rows * out_cols);
336 for i in 0..q_rows {
337 for head in head_outputs.iter().take(n_heads) {
338 out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
339 }
340 }
341 Ok(out)
342 }
343}
344
/// Validates Q/K/V shapes for the multi-head entry points.
///
/// The checks run in this order:
/// 1. `n_heads > 0`.
/// 2. Every total column count is divisible by `n_heads`.
/// 3. Each slice length equals `rows * total_cols`. A product that overflows
///    `usize` is reported as an error instead of wrapping or panicking.
/// 4. Q and K have the same per-head width.
/// 5. K and V have the same row count.
///
/// # Errors
/// Returns a human-readable message for the first check that fails.
#[allow(clippy::too_many_arguments)] // mirrors the flat (data, rows, cols) public signatures
fn validate_multihead_shapes(
    q: &[f64],
    q_rows: usize,
    q_total_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_total_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_total_cols: usize,
    n_heads: usize,
) -> Result<(), String> {
    /// Checks `actual == rows * cols`, rejecting shapes whose element count
    /// overflows `usize`. In release builds that product would otherwise wrap
    /// and could let a bogus shape through.
    fn check_len(name: &str, actual: usize, rows: usize, cols: usize) -> Result<(), String> {
        let expected = rows.checked_mul(cols).ok_or_else(|| {
            format!("{} shape overflows usize: rows={}, cols={}.", name, rows, cols)
        })?;
        if actual != expected {
            return Err(format!(
                "{} data length mismatch: got {}, expected {}.",
                name, actual, expected
            ));
        }
        Ok(())
    }

    if n_heads == 0 {
        return Err("n_heads must be > 0.".to_string());
    }
    if !q_total_cols.is_multiple_of(n_heads)
        || !k_total_cols.is_multiple_of(n_heads)
        || !v_total_cols.is_multiple_of(n_heads)
    {
        return Err(format!(
            "Total columns must be divisible by n_heads={}. Got Q={}, K={}, V={}.",
            n_heads, q_total_cols, k_total_cols, v_total_cols
        ));
    }
    check_len("Q", q.len(), q_rows, q_total_cols)?;
    check_len("K", k.len(), k_rows, k_total_cols)?;
    check_len("V", v.len(), v_rows, v_total_cols)?;

    let dk = q_total_cols / n_heads;
    let dk_k = k_total_cols / n_heads;
    if dk != dk_k {
        return Err(format!(
            "Q/K head dimensions must match: Q_head={}, K_head={}.",
            dk, dk_k
        ));
    }
    // The per-head `validate_shapes` reports this same message. Checking here
    // fails fast, before every head's columns are copied out.
    if k_rows != v_rows {
        return Err(format!(
            "K/V row mismatch: k_rows={}, v_rows={}.",
            k_rows, v_rows
        ));
    }
    Ok(())
}
401
/// Validates single-head attention shapes.
///
/// The checks run in this order:
/// 1. `q_cols == k_cols`, so Q and K rows can be dotted.
/// 2. `k_rows == v_rows`, so there is one value row per key.
/// 3. Each slice length equals `rows * cols`. A product that overflows
///    `usize` is reported as an error instead of wrapping or panicking.
///
/// # Errors
/// Returns a human-readable message for the first check that fails.
#[allow(clippy::too_many_arguments)] // mirrors the flat (data, rows, cols) public signatures
fn validate_shapes(
    q: &[f64],
    q_rows: usize,
    q_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_cols: usize,
) -> Result<(), String> {
    /// Checks `actual == rows * cols`, rejecting shapes whose element count
    /// overflows `usize`. In release builds that product would otherwise wrap
    /// and could let a bogus shape through to the slicing code.
    fn check_len(name: &str, actual: usize, rows: usize, cols: usize) -> Result<(), String> {
        let expected = rows.checked_mul(cols).ok_or_else(|| {
            format!("{} shape overflows usize: rows={}, cols={}.", name, rows, cols)
        })?;
        if actual != expected {
            return Err(format!(
                "{} data length mismatch: got {}, expected {}.",
                name, actual, expected
            ));
        }
        Ok(())
    }

    if q_cols != k_cols {
        return Err(format!(
            "Q/K dimension mismatch: q_cols={}, k_cols={}.",
            q_cols, k_cols
        ));
    }
    if k_rows != v_rows {
        return Err(format!(
            "K/V row mismatch: k_rows={}, v_rows={}.",
            k_rows, v_rows
        ));
    }
    check_len("Q", q.len(), q_rows, q_cols)?;
    check_len("K", k.len(), k_rows, k_cols)?;
    check_len("V", v.len(), v_rows, v_cols)?;
    Ok(())
}
449
/// Concatenates `rows` in order into one row-major buffer.
///
/// `n_rows * n_cols` is only used to pre-size the result; rows are appended
/// as-is regardless of their actual lengths.
fn flatten_rows(rows: Vec<Vec<f64>>, n_rows: usize, n_cols: usize) -> Vec<f64> {
    let capacity = n_rows * n_cols;
    rows.into_iter()
        .fold(Vec::with_capacity(capacity), |mut buffer, row| {
            buffer.extend(row);
            buffer
        })
}
457
/// Copies the `head_idx`-th block of `head_cols` contiguous columns out of a
/// row-major `rows x total_cols` matrix. The result is a row-major
/// `rows x head_cols` matrix.
///
/// # Panics
/// Panics if the requested column block lies outside `matrix`.
fn extract_head_columns(
    matrix: &[f64],
    rows: usize,
    total_cols: usize,
    head_idx: usize,
    head_cols: usize,
) -> Vec<f64> {
    let first_col = head_idx * head_cols;
    let mut head = Vec::with_capacity(rows * head_cols);
    head.extend((0..rows).flat_map(|row| {
        let begin = row * total_cols + first_col;
        matrix[begin..begin + head_cols].iter().copied()
    }));
    head
}