// sc_neurocore_engine/attention.rs
1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Stochastic Attention
8
9//! # Stochastic Attention
10//!
11//! Rate-mode and SC-mode attention primitives used by the Python bridge.
12
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rayon::prelude::*;
/// Attention primitive exposing softmax, rate-mode (linear) and SC-mode
/// (stochastic bitstream) forward passes.
pub struct StochasticAttention {
    /// Key dimensionality `d_k` (columns per Q/K row).
    pub dim_k: usize,
    /// Softmax temperature. Default: sqrt(dim_k) (Vaswani et al. 2017).
    pub temperature: f64,
}
22
23impl StochasticAttention {
24    pub fn new(dim_k: usize) -> Self {
25        Self {
26            dim_k,
27            temperature: (dim_k as f64).sqrt(),
28        }
29    }
30
31    pub fn with_temperature(dim_k: usize, temperature: f64) -> Self {
32        Self { dim_k, temperature }
33    }
34
35    /// Softmax attention: Q·K^T / temperature → softmax → · V.
36    ///
37    /// Numerically stable: subtract row max before exp().
38    #[allow(clippy::too_many_arguments)]
39    pub fn forward_softmax(
40        &self,
41        q: &[f64],
42        q_rows: usize,
43        q_cols: usize,
44        k: &[f64],
45        k_rows: usize,
46        k_cols: usize,
47        v: &[f64],
48        v_rows: usize,
49        v_cols: usize,
50    ) -> Result<Vec<f64>, String> {
51        validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
52        let inv_temp = if self.temperature > 0.0 {
53            1.0 / self.temperature
54        } else {
55            1.0
56        };
57
58        let out_rows: Vec<Vec<f64>> = (0..q_rows)
59            .into_par_iter()
60            .map(|i| {
61                let q_row = &q[i * q_cols..(i + 1) * q_cols];
62
63                let mut scores = vec![0.0_f64; k_rows];
64                for j in 0..k_rows {
65                    let k_row = &k[j * k_cols..(j + 1) * k_cols];
66                    scores[j] = crate::simd::dot_f64_dispatch(q_row, k_row) * inv_temp;
67                }
68
69                crate::simd::softmax_inplace_f64_dispatch(&mut scores);
70
71                // Weighted sum over V
72                let mut out = vec![0.0_f64; v_cols];
73                for d in 0..v_cols {
74                    let mut acc = 0.0_f64;
75                    for j in 0..k_rows {
76                        acc += scores[j] * v[j * v_cols + d];
77                    }
78                    out[d] = acc;
79                }
80                out
81            })
82            .collect();
83
84        Ok(flatten_rows(out_rows, q_rows, v_cols))
85    }
86
87    /// Multi-head softmax attention.
88    #[allow(clippy::too_many_arguments)]
89    pub fn forward_multihead_softmax(
90        &self,
91        q: &[f64],
92        q_rows: usize,
93        q_total_cols: usize,
94        k: &[f64],
95        k_rows: usize,
96        k_total_cols: usize,
97        v: &[f64],
98        v_rows: usize,
99        v_total_cols: usize,
100        n_heads: usize,
101    ) -> Result<Vec<f64>, String> {
102        validate_multihead_shapes(
103            q,
104            q_rows,
105            q_total_cols,
106            k,
107            k_rows,
108            k_total_cols,
109            v,
110            v_rows,
111            v_total_cols,
112            n_heads,
113        )?;
114
115        let dk = q_total_cols / n_heads;
116        let dv = v_total_cols / n_heads;
117        let head_attn = Self::with_temperature(dk, self.temperature);
118
119        let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
120            .into_par_iter()
121            .map(|h| {
122                let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
123                let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
124                let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
125                head_attn.forward_softmax(
126                    &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
127                )
128            })
129            .collect();
130        let head_outputs = head_outputs?;
131
132        let out_cols = dv * n_heads;
133        let mut out = Vec::with_capacity(q_rows * out_cols);
134        for i in 0..q_rows {
135            for head in head_outputs.iter().take(n_heads) {
136                out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
137            }
138        }
139        Ok(out)
140    }
141
142    /// Rate-mode attention forward pass.
143    ///
144    /// Shapes:
145    /// - `q`: `(q_rows, q_cols)`
146    /// - `k`: `(k_rows, k_cols)`
147    /// - `v`: `(v_rows, v_cols)`
148    #[allow(clippy::too_many_arguments)]
149    pub fn forward(
150        &self,
151        q: &[f64],
152        q_rows: usize,
153        q_cols: usize,
154        k: &[f64],
155        k_rows: usize,
156        k_cols: usize,
157        v: &[f64],
158        v_rows: usize,
159        v_cols: usize,
160    ) -> Result<Vec<f64>, String> {
161        validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
162
163        let score_rows: Vec<Vec<f64>> = (0..q_rows)
164            .into_par_iter()
165            .map(|i| {
166                let q_row = &q[i * q_cols..(i + 1) * q_cols];
167                let mut row = vec![0.0_f64; k_rows];
168                for j in 0..k_rows {
169                    let k_row = &k[j * k_cols..(j + 1) * k_cols];
170                    row[j] = crate::simd::dot_f64_dispatch(q_row, k_row);
171                }
172                row
173            })
174            .collect();
175
176        let out_rows: Vec<Vec<f64>> = (0..q_rows)
177            .into_par_iter()
178            .map(|i| {
179                let scores = &score_rows[i];
180                let mut row_sum = scores.iter().sum::<f64>();
181                if row_sum == 0.0 {
182                    row_sum = 1.0;
183                }
184
185                let mut out = vec![0.0_f64; v_cols];
186                for d in 0..v_cols {
187                    let mut acc = 0.0_f64;
188                    for j in 0..k_rows {
189                        let weight = scores[j] / row_sum;
190                        acc += weight * v[j * v_cols + d];
191                    }
192                    out[d] = acc;
193                }
194                out
195            })
196            .collect();
197
198        Ok(flatten_rows(out_rows, q_rows, v_cols))
199    }
200
    /// SC-mode attention forward pass with Bernoulli bitstream encoding.
    ///
    /// Q, K, V entries are encoded as packed `length`-bit Bernoulli
    /// bitstreams; dot products become AND + popcount over the packed words,
    /// so each score approximates the product-sum of the encoded values.
    ///
    /// - `length`: bitstream length in bits; must be > 0.
    /// - `seed`: seeds a single ChaCha8 stream shared by all encodings, so
    ///   results are reproducible per seed.
    ///
    /// NOTE(review): the probability encoding presumably expects entries in
    /// [0, 1] — confirm against `bitstream::encode_matrix_prob_to_packed`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward_sc(
        &self,
        q: &[f64],
        q_rows: usize,
        q_cols: usize,
        k: &[f64],
        k_rows: usize,
        k_cols: usize,
        v: &[f64],
        v_rows: usize,
        v_cols: usize,
        length: usize,
        seed: u64,
    ) -> Result<Vec<f64>, String> {
        validate_shapes(q, q_rows, q_cols, k, k_rows, k_cols, v, v_rows, v_cols)?;
        if length == 0 {
            return Err("length must be > 0 for SC mode.".to_string());
        }

        // One RNG stream; Q, K, V and attention-weight encodings draw from it
        // sequentially, so the draw order is part of the reproducible output.
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        // 64 bits per packed word.
        let words = length.div_ceil(64);

        let q_packed = crate::bitstream::encode_matrix_prob_to_packed(
            q, q_rows, q_cols, length, words, &mut rng,
        );
        let k_packed = crate::bitstream::encode_matrix_prob_to_packed(
            k, k_rows, k_cols, length, words, &mut rng,
        );
        let v_packed = crate::bitstream::encode_matrix_prob_to_packed(
            v, v_rows, v_cols, length, words, &mut rng,
        );

        // Scores: popcount(Q_bits & K_bits) / length, summed over the shared
        // feature dimension d.
        let mut score_rows = vec![vec![0.0_f64; k_rows]; q_rows];
        for (i, score_row) in score_rows.iter_mut().enumerate().take(q_rows) {
            for (j, score_value) in score_row.iter_mut().enumerate().take(k_rows) {
                let mut pop_total = 0_u64;
                for d in 0..q_cols {
                    let q_idx = i * q_cols + d;
                    let k_idx = j * k_cols + d;
                    let qa = &q_packed[q_idx];
                    let kb = &k_packed[k_idx];
                    for w in 0..words {
                        pop_total += crate::bitstream::swar_popcount_word(qa[w] & kb[w]);
                    }
                }
                *score_value = pop_total as f64 / length as f64;
            }
        }

        // Row-normalise the scores; an all-zero row falls back to divisor 1.0
        // to avoid division by zero.
        let attn_weights: Vec<Vec<f64>> = score_rows
            .iter()
            .map(|row| {
                let mut row_sum = row.iter().sum::<f64>();
                if row_sum == 0.0 {
                    row_sum = 1.0;
                }
                row.iter().map(|x| x / row_sum).collect()
            })
            .collect();

        // Re-encode the normalised weights as bitstreams for the A·V stage.
        let attn_flat: Vec<f64> = attn_weights.into_iter().flatten().collect();
        let attn_packed = crate::bitstream::encode_matrix_prob_to_packed(
            &attn_flat, q_rows, k_rows, length, words, &mut rng,
        );

        // Output: stochastic weighted sum over V, parallel over query rows
        // (all packed inputs are read-only here).
        let out_rows: Vec<Vec<f64>> = (0..q_rows)
            .into_par_iter()
            .map(|i| {
                let mut out = vec![0.0_f64; v_cols];
                for d in 0..v_cols {
                    let mut pop_total = 0_u64;
                    for j in 0..k_rows {
                        let a = &attn_packed[i * k_rows + j];
                        let b = &v_packed[j * v_cols + d];
                        for w in 0..words {
                            pop_total += crate::bitstream::swar_popcount_word(a[w] & b[w]);
                        }
                    }
                    out[d] = pop_total as f64 / length as f64;
                }
                out
            })
            .collect();

        Ok(flatten_rows(out_rows, q_rows, v_cols))
    }
289
290    /// Multi-head linear attention (backwards compat).
291    #[allow(clippy::too_many_arguments)]
292    pub fn forward_multihead(
293        &self,
294        q: &[f64],
295        q_rows: usize,
296        q_total_cols: usize,
297        k: &[f64],
298        k_rows: usize,
299        k_total_cols: usize,
300        v: &[f64],
301        v_rows: usize,
302        v_total_cols: usize,
303        n_heads: usize,
304    ) -> Result<Vec<f64>, String> {
305        validate_multihead_shapes(
306            q,
307            q_rows,
308            q_total_cols,
309            k,
310            k_rows,
311            k_total_cols,
312            v,
313            v_rows,
314            v_total_cols,
315            n_heads,
316        )?;
317
318        let dk = q_total_cols / n_heads;
319        let dv = v_total_cols / n_heads;
320
321        let head_outputs: Result<Vec<Vec<f64>>, String> = (0..n_heads)
322            .into_par_iter()
323            .map(|h| {
324                let q_head = extract_head_columns(q, q_rows, q_total_cols, h, dk);
325                let k_head = extract_head_columns(k, k_rows, k_total_cols, h, dk);
326                let v_head = extract_head_columns(v, v_rows, v_total_cols, h, dv);
327                self.forward(
328                    &q_head, q_rows, dk, &k_head, k_rows, dk, &v_head, v_rows, dv,
329                )
330            })
331            .collect();
332        let head_outputs = head_outputs?;
333
334        let out_cols = dv * n_heads;
335        let mut out = Vec::with_capacity(q_rows * out_cols);
336        for i in 0..q_rows {
337            for head in head_outputs.iter().take(n_heads) {
338                out.extend_from_slice(&head[i * dv..(i + 1) * dv]);
339            }
340        }
341        Ok(out)
342    }
343}
344
/// Validate multi-head Q/K/V shapes: head count, per-head divisibility,
/// buffer lengths, and matching Q/K head widths. Check order (and every
/// error message) matches the single-head validator's conventions.
#[allow(clippy::too_many_arguments)]
fn validate_multihead_shapes(
    q: &[f64],
    q_rows: usize,
    q_total_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_total_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_total_cols: usize,
    n_heads: usize,
) -> Result<(), String> {
    // Head count is used as a divisor below.
    if n_heads == 0 {
        return Err("n_heads must be > 0.".to_string());
    }
    let all_divisible = q_total_cols % n_heads == 0
        && k_total_cols % n_heads == 0
        && v_total_cols % n_heads == 0;
    if !all_divisible {
        return Err(format!(
            "Total columns must be divisible by n_heads={}. Got Q={}, K={}, V={}.",
            n_heads, q_total_cols, k_total_cols, v_total_cols
        ));
    }
    // Each buffer length must agree with its declared row-major shape.
    for (name, actual, expected) in [
        ("Q", q.len(), q_rows * q_total_cols),
        ("K", k.len(), k_rows * k_total_cols),
        ("V", v.len(), v_rows * v_total_cols),
    ] {
        if actual != expected {
            return Err(format!(
                "{} data length mismatch: got {}, expected {}.",
                name, actual, expected
            ));
        }
    }
    // Per-head Q and K widths must match for the dot products.
    let dk = q_total_cols / n_heads;
    let dk_k = k_total_cols / n_heads;
    if dk != dk_k {
        return Err(format!(
            "Q/K head dimensions must match: Q_head={}, K_head={}.",
            dk, dk_k
        ));
    }
    Ok(())
}
401
/// Validate single-head Q/K/V shapes: inner-dimension agreement for Q·K^T,
/// one V row per key row, and buffer lengths matching the declared
/// row-major shapes. Error messages are stable and caller-visible.
#[allow(clippy::too_many_arguments)]
fn validate_shapes(
    q: &[f64],
    q_rows: usize,
    q_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_cols: usize,
) -> Result<(), String> {
    // Inner dimensions must agree for the Q·K^T dot products.
    if q_cols != k_cols {
        return Err(format!(
            "Q/K dimension mismatch: q_cols={}, k_cols={}.",
            q_cols, k_cols
        ));
    }
    // Attention weights index V by key row: one V row per K row.
    if k_rows != v_rows {
        return Err(format!(
            "K/V row mismatch: k_rows={}, v_rows={}.",
            k_rows, v_rows
        ));
    }
    // Each buffer length must agree with its declared row-major shape.
    for (name, actual, expected) in [
        ("Q", q.len(), q_rows * q_cols),
        ("K", k.len(), k_rows * k_cols),
        ("V", v.len(), v_rows * v_cols),
    ] {
        if actual != expected {
            return Err(format!(
                "{} data length mismatch: got {}, expected {}.",
                name, actual, expected
            ));
        }
    }
    Ok(())
}
449
/// Concatenate per-row vectors into one row-major buffer, preallocated to
/// `n_rows * n_cols` elements.
fn flatten_rows(rows: Vec<Vec<f64>>, n_rows: usize, n_cols: usize) -> Vec<f64> {
    let mut flat = Vec::with_capacity(n_rows * n_cols);
    flat.extend(rows.into_iter().flatten());
    flat
}
457
/// Extract one head's column slice from a row-major matrix.
///
/// Returns a `(rows, head_cols)` row-major copy of columns
/// `[head_idx * head_cols, (head_idx + 1) * head_cols)`.
fn extract_head_columns(
    matrix: &[f64],
    rows: usize,
    total_cols: usize,
    head_idx: usize,
    head_cols: usize,
) -> Vec<f64> {
    let offset = head_idx * head_cols;
    (0..rows)
        .flat_map(|r| {
            let start = r * total_cols + offset;
            matrix[start..start + head_cols].iter().copied()
        })
        .collect()
}