// sc_neurocore_engine/attention.rs
1// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
2// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
3// © Code 2020–2026 Miroslav Šotek. All rights reserved.
4// ORCID: 0009-0009-3560-0851
5// Contact: www.anulum.li | protoscience@anulum.li
6// SC-NeuroCore — Stochastic Attention
7
8//! # Stochastic Attention
9//!
10//! Rate-mode and SC-mode attention primitives used by the Python bridge.
11
12use rand::SeedableRng;
13use rand_chacha::ChaCha8Rng;
14use rayon::prelude::*;
15
/// Attention module operating in rate mode (real-valued scores) or
/// SC mode (stochastic bitstreams) — see the `forward*` methods.
#[derive(Debug, Clone, PartialEq)]
pub struct StochasticAttention {
    /// Key dimensionality; used to derive the default temperature.
    pub dim_k: usize,
    /// Softmax temperature. Default: sqrt(dim_k) (Vaswani et al. 2017).
    pub temperature: f64,
}
21
22impl StochasticAttention {
23    pub fn new(dim_k: usize) -> Self {
24        Self {
25            dim_k,
26            temperature: (dim_k as f64).sqrt(),
27        }
28    }
29
30    pub fn with_temperature(dim_k: usize, temperature: f64) -> Self {
31        Self { dim_k, temperature }
32    }
33
34    /// Softmax attention: Q·K^T / temperature → softmax → · V.
35    ///
36    /// Numerically stable: subtract row max before exp().
37    #[allow(clippy::too_many_arguments)]
38    pub fn forward_softmax(
39        &self,
40        q: &[f64],
41        q_rows: usize,
42        q_cols: usize,
43        k: &[f64],
44        k_rows: usize,
45        k_cols: usize,
46        v: &[f64],
47        v_rows: usize,
48        v_cols: usize,
49    ) -> Result<Vec<f64>, String> {
50        validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
51        let inv_temp = if self.temperature > 0.0 {
52            1.0 / self.temperature
53        } else {
54            1.0
55        };
56
57        let out_rows: Vec<Vec<f64>> = (0..q_rows)
58            .into_par_iter()
59            .map(|i| {
60                let q_row = &q[i * q_cols..(i + 1) * q_cols];
61
62                let mut scores = vec![0.0_f64; k_rows];
63                for j in 0..k_rows {
64                    let k_row = &k[j * k_cols..(j + 1) * k_cols];
65                    scores[j] = crate::simd::dot_f64_dispatch(q_row, k_row) * inv_temp;
66                }
67
68                crate::simd::softmax_inplace_f64_dispatch(&mut scores);
69
70                // Weighted sum over V
71                let mut out = vec![0.0_f64; v_cols];
72                for d in 0..v_cols {
73                    let mut acc = 0.0_f64;
74                    for j in 0..k_rows {
75                        acc += scores[j] * v[j * v_cols + d];
76                    }
77                    out[d] = acc;
78                }
79                out
80            })
81            .collect();
82
83        Ok(flatten_rows(out_rows, q_rows, v_cols))
84    }
85
86    /// Multi-head softmax attention.
87    #[allow(clippy::too_many_arguments)]
88    pub fn forward_multihead_softmax(
89        &self,
90        q: &[f64],
91        q_rows: usize,
92        q_total_cols: usize,
93        k: &[f64],
94        k_rows: usize,
95        k_total_cols: usize,
96        v: &[f64],
97        v_rows: usize,
98        v_total_cols: usize,
99        n_heads: usize,
100    ) -> Result<Vec<f64>, String> {
101        validate_multihead_shapes(
102            q,
103            q_rows,
104            q_total_cols,
105            k,
106            k_rows,
107            k_total_cols,
108            v,
109            v_rows,
110            v_total_cols,
111            n_heads,
112        )?;
113
114        let dk = q_total_cols / n_heads;
115        let dv = v_total_cols / n_heads;
116        let head_attn = Self::with_temperature(dk, self.temperature);
117
118        let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
119            .into_par_iter()
120            .map(|h| {
121                let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
122                let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
123                let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
124                head_attn.forward_softmax(
125                    &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
126                )
127            })
128            .collect();
129        let head_outputs = head_outputs?;
130
131        let out_cols = dv * n_heads;
132        let mut out = Vec::with_capacity(q_rows * out_cols);
133        for i in 0..q_rows {
134            for head in head_outputs.iter().take(n_heads) {
135                out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
136            }
137        }
138        Ok(out)
139    }
140
141    /// Rate-mode attention forward pass.
142    ///
143    /// Shapes:
144    /// - `q`: `(q_rows, q_cols)`
145    /// - `k`: `(k_rows, k_cols)`
146    /// - `v`: `(v_rows, v_cols)`
147    #[allow(clippy::too_many_arguments)]
148    pub fn forward(
149        &self,
150        q: &[f64],
151        q_rows: usize,
152        q_cols: usize,
153        k: &[f64],
154        k_rows: usize,
155        k_cols: usize,
156        v: &[f64],
157        v_rows: usize,
158        v_cols: usize,
159    ) -> Result<Vec<f64>, String> {
160        validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
161
162        let score_rows: Vec<Vec<f64>> = (0..q_rows)
163            .into_par_iter()
164            .map(|i| {
165                let q_row = &q[i * q_cols..(i + 1) * q_cols];
166                let mut row = vec![0.0_f64; k_rows];
167                for j in 0..k_rows {
168                    let k_row = &k[j * k_cols..(j + 1) * k_cols];
169                    row[j] = crate::simd::dot_f64_dispatch(q_row, k_row);
170                }
171                row
172            })
173            .collect();
174
175        let out_rows: Vec<Vec<f64>> = (0..q_rows)
176            .into_par_iter()
177            .map(|i| {
178                let scores = &score_rows[i];
179                let mut row_sum = scores.iter().sum::<f64>();
180                if row_sum == 0.0 {
181                    row_sum = 1.0;
182                }
183
184                let mut out = vec![0.0_f64; v_cols];
185                for d in 0..v_cols {
186                    let mut acc = 0.0_f64;
187                    for j in 0..k_rows {
188                        let weight = scores[j] / row_sum;
189                        acc += weight * v[j * v_cols + d];
190                    }
191                    out[d] = acc;
192                }
193                out
194            })
195            .collect();
196
197        Ok(flatten_rows(out_rows, q_rows, v_cols))
198    }
199
    /// SC-mode attention forward pass with Bernoulli bitstream encoding.
    ///
    /// Entries of Q, K, V are encoded as `length`-bit Bernoulli streams
    /// (packed into 64-bit words); dot products are then estimated as
    /// popcount(AND) / length. Attention rows are normalised by their
    /// sum, re-encoded as streams, and applied to V the same way.
    ///
    /// Deterministic for a given `seed`: one RNG drives every encoding,
    /// consumed in a fixed order (Q, then K, then V, then attention
    /// weights) — do not reorder the encoding calls.
    ///
    /// NOTE(review): the `_prob_` encoder name suggests entries are
    /// treated as probabilities in [0, 1] — confirm upstream clamping.
    ///
    /// # Errors
    /// Returns `Err` on any shape mismatch (see `validate_shapes`) or if
    /// `length == 0`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward_sc(
        &self,
        q: &[f64],
        q_rows: usize,
        q_cols: usize,
        k: &[f64],
        k_rows: usize,
        k_cols: usize,
        v: &[f64],
        v_rows: usize,
        v_cols: usize,
        length: usize,
        seed: u64,
    ) -> Result<Vec<f64>, String> {
        validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
        if length == 0 {
            return Err("length must be > 0 for SC mode.".to_string());
        }

        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        // 64-bit words needed to hold one `length`-bit stream.
        let words = length.div_ceil(64);

        let q_packed = crate::bitstream::encode_matrix_prob_to_packed(
            q, q_rows, q_cols, length, words, &mut rng,
        );
        let k_packed = crate::bitstream::encode_matrix_prob_to_packed(
            k, k_rows, k_cols, length, words, &mut rng,
        );
        let v_packed = crate::bitstream::encode_matrix_prob_to_packed(
            v, v_rows, v_cols, length, words, &mut rng,
        );

        // Scores: popcount(Q AND K) summed over the shared dimension,
        // scaled by 1/length — a stochastic estimate of the dot product.
        let mut score_rows = vec![vec![0.0_f64; k_rows]; q_rows];
        for (i, score_row) in score_rows.iter_mut().enumerate().take(q_rows) {
            for (j, score_value) in score_row.iter_mut().enumerate().take(k_rows) {
                let mut pop_total = 0_u64;
                for d in 0..q_cols {
                    let q_idx = i * q_cols + d;
                    let k_idx = j * k_cols + d;
                    let qa = &q_packed[q_idx];
                    let kb = &k_packed[k_idx];
                    for w in 0..words {
                        pop_total += crate::bitstream::swar_popcount_word(qa[w] & kb[w]);
                    }
                }
                *score_value = pop_total as f64 / length as f64;
            }
        }

        // Row-normalise scores into attention weights; an all-zero row
        // divides by 1.0 (stays zero) instead of producing NaN.
        let attn_weights: Vec<Vec<f64>> = score_rows
            .iter()
            .map(|row| {
                let mut row_sum = row.iter().sum::<f64>();
                if row_sum == 0.0 {
                    row_sum = 1.0;
                }
                row.iter().map(|x| x / row_sum).collect()
            })
            .collect();

        // Re-encode the attention weights as bitstreams so the weighted
        // sum over V uses the same AND/popcount estimator.
        let attn_flat: Vec<f64> = attn_weights.into_iter().flatten().collect();
        let attn_packed = crate::bitstream::encode_matrix_prob_to_packed(
            &attn_flat, q_rows, k_rows, length, words, &mut rng,
        );

        let out_rows: Vec<Vec<f64>> = (0..q_rows)
            .into_par_iter()
            .map(|i| {
                let mut out = vec![0.0_f64; v_cols];
                for d in 0..v_cols {
                    let mut pop_total = 0_u64;
                    for j in 0..k_rows {
                        let a = &attn_packed[i * k_rows + j];
                        let b = &v_packed[j * v_cols + d];
                        for w in 0..words {
                            pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
                        }
                    }
                    out[d] = pop_total as f64 / length as f64;
                }
                out
            })
            .collect();

        Ok(flatten_rows(out_rows, q_rows, v_cols))
    }
288
289    /// Multi-head linear attention (backwards compat).
290    #[allow(clippy::too_many_arguments)]
291    pub fn forward_multihead(
292        &self,
293        q: &[f64],
294        q_rows: usize,
295        q_total_cols: usize,
296        k: &[f64],
297        k_rows: usize,
298        k_total_cols: usize,
299        v: &[f64],
300        v_rows: usize,
301        v_total_cols: usize,
302        n_heads: usize,
303    ) -> Result<Vec<f64>, String> {
304        validate_multihead_shapes(
305            q,
306            q_rows,
307            q_total_cols,
308            k,
309            k_rows,
310            k_total_cols,
311            v,
312            v_rows,
313            v_total_cols,
314            n_heads,
315        )?;
316
317        let dk = q_total_cols / n_heads;
318        let dv = v_total_cols / n_heads;
319
320        let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
321            .into_par_iter()
322            .map(|h| {
323                let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
324                let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
325                let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
326                self.forward(
327                    &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
328                )
329            })
330            .collect();
331        let head_outputs = head_outputs?;
332
333        let out_cols = dv * n_heads;
334        let mut out = Vec::with_capacity(q_rows * out_cols);
335        for i in 0..q_rows {
336            for head in head_outputs.iter().take(n_heads) {
337                out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
338            }
339        }
340        Ok(out)
341    }
342}
343
/// Validate multi-head Q/K/V buffer shapes.
///
/// Checks, in order: `n_heads > 0`; every total column count divisible by
/// `n_heads`; `k_rows == v_rows`; each buffer's length equals
/// `rows * total_cols`; Q and K head widths agree.
///
/// # Errors
/// Returns a human-readable message describing the first violated check.
#[allow(clippy::too_many_arguments)]
fn validate_multihead_shapes(
    q: &[f64],
    q_rows: usize,
    q_total_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_total_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_total_cols: usize,
    n_heads: usize,
) -> Result<(), String> {
    if n_heads == 0 {
        return Err("n_heads must be > 0.".to_string());
    }
    if !q_total_cols.is_multiple_of(n_heads)
        || !k_total_cols.is_multiple_of(n_heads)
        || !v_total_cols.is_multiple_of(n_heads)
    {
        return Err(format!(
            "Total columns must be divisible by n_heads={}. Got Q={}, K={}, V={}.",
            n_heads, q_total_cols, k_total_cols, v_total_cols
        ));
    }
    // Fail fast on K/V row mismatch. Previously this was only caught by the
    // per-head validate_shapes call, after head extraction had already run;
    // the message matches validate_shapes' for consistency.
    if k_rows != v_rows {
        return Err(format!(
            "K/V row mismatch: k_rows={}, v_rows={}.",
            k_rows, v_rows
        ));
    }
    if q.len() != q_rows * q_total_cols {
        return Err(format!(
            "Q data length mismatch: got {}, expected {}.",
            q.len(),
            q_rows * q_total_cols
        ));
    }
    if k.len() != k_rows * k_total_cols {
        return Err(format!(
            "K data length mismatch: got {}, expected {}.",
            k.len(),
            k_rows * k_total_cols
        ));
    }
    if v.len() != v_rows * v_total_cols {
        return Err(format!(
            "V data length mismatch: got {}, expected {}.",
            v.len(),
            v_rows * v_total_cols
        ));
    }
    let dk = q_total_cols / n_heads;
    let dk_k = k_total_cols / n_heads;
    if dk != dk_k {
        return Err(format!(
            "Q/K head dimensions must match: Q_head={}, K_head={}.",
            dk, dk_k
        ));
    }
    Ok(())
}
400
/// Validate single-head Q/K/V buffer shapes.
///
/// Checks `q_cols == k_cols`, `k_rows == v_rows`, and that every buffer's
/// length equals `rows * cols`; returns the first violation as an error
/// message.
#[allow(clippy::too_many_arguments)]
fn validate_shapes(
    q: &[f64],
    q_rows: usize,
    q_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_cols: usize,
) -> Result<(), String> {
    if q_cols != k_cols {
        return Err(format!(
            "Q/K dimension mismatch: q_cols={}, k_cols={}.",
            q_cols, k_cols
        ));
    }
    if k_rows != v_rows {
        return Err(format!(
            "K/V row mismatch: k_rows={}, v_rows={}.",
            k_rows, v_rows
        ));
    }
    // One shared length check for all three buffers.
    let check_len = |name: &str, data: &[f64], rows: usize, cols: usize| -> Result<(), String> {
        if data.len() == rows * cols {
            Ok(())
        } else {
            Err(format!(
                "{} data length mismatch: got {}, expected {}.",
                name,
                data.len(),
                rows * cols
            ))
        }
    };
    check_len("Q", q, q_rows, q_cols)?;
    check_len("K", k, k_rows, k_cols)?;
    check_len("V", v, v_rows, v_cols)?;
    Ok(())
}
448
/// Concatenate row vectors into one flat row-major buffer of
/// `n_rows * n_cols` elements, preallocated up front.
fn flatten_rows(rows: Vec<Vec<f64>>, n_rows: usize, n_cols: usize) -> Vec<f64> {
    let mut flat = Vec::with_capacity(n_rows * n_cols);
    flat.extend(rows.into_iter().flatten());
    flat
}
456
/// Extract one head slice from a row-major matrix.
///
/// Copies columns `[head_idx * head_cols, (head_idx + 1) * head_cols)` of
/// every row into a new `(rows, head_cols)` row-major buffer.
fn extract_head_columns(
    matrix: &[f64],
    rows: usize,
    total_cols: usize,
    head_idx: usize,
    head_cols: usize,
) -> Vec<f64> {
    let first_col = head_idx * head_cols;
    (0..rows)
        .flat_map(|row| {
            let begin = row * total_cols + first_col;
            matrix[begin..begin + head_cols].iter().copied()
        })
        .collect()
}