sc_neurocore_engine/analysis/neural_decoders.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — Foundation-model neural population decoder primitives

//! Core compute kernels for POYO+, POSSM, NDT3, and CEBRA decoders.
//!
//! These are the hot-path operations that benefit from Rust acceleration:
//! spike tokenisation, sinusoidal position encoding, scaled dot-product
//! attention, diagonal SSM step, and InfoNCE contrastive loss.

use rayon::prelude::*;

/// Token: (unit_id, timestamp_ms).
pub type SpikeToken = (usize, f64);

/// Convert binary spike trains to sorted (unit_id, timestamp) tokens.
///
/// Azabou et al. (2023), NeurIPS; Ryoo et al. (2025), ICLR.
/// Each spike in each train produces one token. Tokens are sorted by time.
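///
/// # Example
///
/// A minimal sketch with made-up spike trains; marked `ignore` because the
/// `use` path is assumed from the file layout and may not match the crate's
/// actual module tree.
///
/// ```ignore
/// use sc_neurocore_engine::analysis::neural_decoders::tokenise_spikes;
///
/// let t0 = vec![0, 1, 0, 1]; // unit 0 fires in bins 1 and 3
/// let t1 = vec![1, 0, 0, 0]; // unit 1 fires in bin 0
/// let tokens = tokenise_spikes(&[&t0, &t1], 1.0); // dt = 1.0 ms per bin
/// assert_eq!(tokens, vec![(1, 0.0), (0, 1.0), (0, 3.0)]);
/// ```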
pub fn tokenise_spikes(trains: &[&[i32]], dt: f64) -> Vec<SpikeToken> {
    let mut tokens: Vec<SpikeToken> = trains
        .par_iter()
        .enumerate()
        .flat_map_iter(|(uid, train)| {
            train
                .iter()
                .enumerate()
                .filter(|(_, &v)| v != 0)
                .map(move |(idx, _)| (uid, idx as f64 * dt))
        })
        .collect();
    tokens.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    tokens
}

/// Sinusoidal position encoding. Vaswani et al. (2017).
///
/// PE(t, 2i)   = sin(t / 10000^{2i/d})
/// PE(t, 2i+1) = cos(t / 10000^{2i/d})
///
/// Output: flat row-major [n_timestamps × d_model].
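///
/// # Example
///
/// A minimal sketch (illustrative timestamps; `ignore`d for the same path
/// assumption as the `tokenise_spikes` example).
///
/// ```ignore
/// use sc_neurocore_engine::analysis::neural_decoders::sinusoidal_position_encode;
///
/// let pe = sinusoidal_position_encode(&[0.0, 1.0], 4);
/// assert_eq!(pe.len(), 2 * 4);
/// // Row for t = 0: sin(0) = 0 in even columns, cos(0) = 1 in odd columns.
/// assert_eq!(&pe[..4], &[0.0, 1.0, 0.0, 1.0]);
/// ```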
pub fn sinusoidal_position_encode(timestamps: &[f64], d_model: usize) -> Vec<f64> {
    let n = timestamps.len();
    let mut pe = vec![0.0_f64; n * d_model];
    let half_d = d_model / 2 + d_model % 2;
    let divisors: Vec<f64> = (0..half_d)
        .map(|i| 10000.0_f64.powf(2.0 * i as f64 / d_model as f64))
        .collect();

    pe.par_chunks_mut(d_model)
        .enumerate()
        .for_each(|(row, pe_row)| {
            let t = timestamps[row];
            for (k, div) in divisors.iter().enumerate() {
                let col_sin = 2 * k;
                let col_cos = 2 * k + 1;
                let angle = t / div;
                pe_row[col_sin] = angle.sin();
                if col_cos < d_model {
                    pe_row[col_cos] = angle.cos();
                }
            }
        });
    pe
}

/// Scaled dot-product attention.
///
/// Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
///
/// All matrices are row-major flat: Q [nq × d], K [nk × d], V [nk × d].
/// Output: [nq × d].
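///
/// # Example
///
/// A single-query sketch with hand-picked values (`ignore`d for the path
/// assumption noted on `tokenise_spikes`).
///
/// ```ignore
/// use sc_neurocore_engine::analysis::neural_decoders::scaled_dot_product_attention;
///
/// // One query attending over two keys, d = 2.
/// let q = vec![1.0, 0.0];
/// let k = vec![1.0, 0.0, 0.0, 1.0];
/// let v = vec![1.0, 0.0, 0.0, 1.0];
/// let out = scaled_dot_product_attention(&q, &k, &v, 1, 2, 2);
/// // Raw scores are 1/sqrt(2) and 0, so the output leans toward row 0 of V.
/// assert!(out[0] > out[1]);
/// assert!((out[0] + out[1] - 1.0).abs() < 1e-12); // softmax weights sum to 1
/// ```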
pub fn scaled_dot_product_attention(
    queries: &[f64],
    keys: &[f64],
    values: &[f64],
    nq: usize,
    nk: usize,
    d: usize,
) -> Vec<f64> {
    let inv_sqrt_d = 1.0 / (d as f64).sqrt();
    let mut output = vec![0.0_f64; nq * d];

    output
        .par_chunks_mut(d)
        .enumerate()
        .for_each(|(i, out_row)| {
            let q_row = &queries[i * d..(i + 1) * d];
            // Compute scores
            let mut scores = vec![0.0_f64; nk];
            let mut max_score = f64::NEG_INFINITY;
            for j in 0..nk {
                let k_row = &keys[j * d..(j + 1) * d];
                let mut dot = 0.0;
                for f in 0..d {
                    dot += q_row[f] * k_row[f];
                }
                scores[j] = dot * inv_sqrt_d;
                if scores[j] > max_score {
                    max_score = scores[j];
                }
            }
            // Stable softmax
            let mut sum_exp = 0.0;
            for s in &mut scores {
                *s = (*s - max_score).exp();
                sum_exp += *s;
            }
            let inv_sum = 1.0 / (sum_exp + 1e-30);
            for s in &mut scores {
                *s *= inv_sum;
            }
            // Weighted sum of values
            for j in 0..nk {
                let w = scores[j];
                let v_row = &values[j * d..(j + 1) * d];
                for f in 0..d {
                    out_row[f] += w * v_row[f];
                }
            }
        });
    output
}

/// Gaussian attention. Li et al. (2025), scKGBERT.
///
/// α_ij = exp(-||q_i - k_j||² / (2σ²)) / Σ_m exp(-||q_i - k_m||² / (2σ²))
///
/// Q [nq × d], K [nk × d], V [nk × d]. Output: [nq × d].
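///
/// # Example
///
/// A locality sketch (illustrative values; `ignore`d for the path assumption
/// noted on `tokenise_spikes`): with a small sigma, the nearby key dominates.
///
/// ```ignore
/// use sc_neurocore_engine::analysis::neural_decoders::gaussian_attention;
///
/// let q = vec![0.0, 0.0];           // one query at the origin
/// let k = vec![0.1, 0.0, 3.0, 0.0]; // near key, far key
/// let v = vec![1.0, 0.0, 0.0, 1.0];
/// let out = gaussian_attention(&q, &k, &v, 1, 2, 2, 0.5);
/// assert!(out[0] > 0.9 && out[1] < 0.1); // mass concentrates on the near key
/// ```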
pub fn gaussian_attention(
    queries: &[f64],
    keys: &[f64],
    values: &[f64],
    nq: usize,
    nk: usize,
    d: usize,
    sigma: f64,
) -> Vec<f64> {
    let inv_2sigma2 = 1.0 / (2.0 * sigma * sigma);
    let mut output = vec![0.0_f64; nq * d];

    output
        .par_chunks_mut(d)
        .enumerate()
        .for_each(|(i, out_row)| {
            let q_row = &queries[i * d..(i + 1) * d];
            let mut log_weights = vec![0.0_f64; nk];
            let mut max_lw = f64::NEG_INFINITY;
            for j in 0..nk {
                let k_row = &keys[j * d..(j + 1) * d];
                let mut dist_sq = 0.0;
                for f in 0..d {
                    let diff = q_row[f] - k_row[f];
                    dist_sq += diff * diff;
                }
                log_weights[j] = -dist_sq * inv_2sigma2;
                if log_weights[j] > max_lw {
                    max_lw = log_weights[j];
                }
            }
            let mut sum_exp = 0.0;
            for lw in &mut log_weights {
                *lw = (*lw - max_lw).exp();
                sum_exp += *lw;
            }
            let inv_sum = 1.0 / (sum_exp + 1e-30);
            for j in 0..nk {
                let w = log_weights[j] * inv_sum;
                let v_row = &values[j * d..(j + 1) * d];
                for f in 0..d {
                    out_row[f] += w * v_row[f];
                }
            }
        });
    output
}

/// Diagonal SSM step. Gu et al. (2022), S4D; Ryoo et al. (2025), POSSM.
///
/// h_t = A_bar ⊙ h_{t-1} + B_bar x_t
/// y_t = Re(C h_t) + D x_t
///
/// Complex quantities are passed as separate real/imaginary slices:
/// A_bar is a complex diagonal [d_state] (a_bar_re, a_bar_im).
/// B_bar is [d_state × d_model] complex, row-major flat (b_bar_re, b_bar_im).
/// C is [d_model × d_state] complex, row-major flat (c_re, c_im).
/// D is [d_model × d_model] real, row-major flat.
/// h is the [d_state] complex state (h_re, h_im).
/// x is the [d_model] real input.
///
/// Returns y [d_model] and updates h in place.
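///
/// # Example
///
/// A scalar sketch (d_state = d_model = 1, real A_bar; `ignore`d for the path
/// assumption noted on `tokenise_spikes`). With A_bar = 0.5, B_bar = 1, C = 1,
/// D = 0 and constant input x = 1, the recurrence is h_t = 0.5 h_{t-1} + 1.
///
/// ```ignore
/// use sc_neurocore_engine::analysis::neural_decoders::ssm_step_diagonal;
///
/// let (mut h_re, mut h_im) = (vec![0.0], vec![0.0]);
/// let mut y = 0.0;
/// for _ in 0..3 {
///     y = ssm_step_diagonal(
///         &[0.5], &[0.0], // A_bar (real diagonal)
///         &[1.0], &[0.0], // B_bar
///         &[1.0], &[0.0], // C
///         &[0.0],         // D
///         &mut h_re, &mut h_im,
///         &[1.0],         // x_t
///         1, 1,
///     )[0];
/// }
/// // h: 1.0 -> 1.5 -> 1.75, and y_t = Re(C h_t) = h_t here.
/// assert!((y - 1.75).abs() < 1e-12);
/// ```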
pub fn ssm_step_diagonal(
    a_bar_re: &[f64],
    a_bar_im: &[f64],
    b_bar_re: &[f64],
    b_bar_im: &[f64],
    c_re: &[f64],
    c_im: &[f64],
    d_mat: &[f64],
    h_re: &mut [f64],
    h_im: &mut [f64],
    x: &[f64],
    d_state: usize,
    d_model: usize,
) -> Vec<f64> {
    // h_t = A_bar ⊙ h_{t-1} + B_bar x_t
    for s in 0..d_state {
        // Complex multiply: (a_re + i*a_im) * (h_re + i*h_im)
        let new_re = a_bar_re[s] * h_re[s] - a_bar_im[s] * h_im[s];
        let new_im = a_bar_re[s] * h_im[s] + a_bar_im[s] * h_re[s];
        // B_bar @ x: B_bar[s, :] . x
        let mut bx_re = 0.0;
        let mut bx_im = 0.0;
        for m in 0..d_model {
            bx_re += b_bar_re[s * d_model + m] * x[m];
            bx_im += b_bar_im[s * d_model + m] * x[m];
        }
        h_re[s] = new_re + bx_re;
        h_im[s] = new_im + bx_im;
    }

    // y_t = Re(C h_t) + D x_t
    let mut y = vec![0.0_f64; d_model];
    for m in 0..d_model {
        let mut ch_re = 0.0;
        for s in 0..d_state {
            // Re(C[m,s] * h[s]) = C_re*h_re - C_im*h_im
            ch_re += c_re[m * d_state + s] * h_re[s] - c_im[m * d_state + s] * h_im[s];
        }
        let mut dx = 0.0;
        for m2 in 0..d_model {
            dx += d_mat[m * d_model + m2] * x[m2];
        }
        y[m] = ch_re + dx;
    }
    y
}

/// InfoNCE contrastive loss. van den Oord et al. (2018); CEBRA.
///
/// L = -(1/N) Σ_i log( exp(sim(z_i, z_i^+)/τ) / Σ_j exp(sim(z_i, z_j^+)/τ) )
/// where sim(a, b) = cosine similarity and the denominator runs over all
/// positives (the diagonal term j = i is the positive pair).
///
/// anchors, positives: [N × d] row-major flat.
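///
/// # Example
///
/// A sketch pairing two orthogonal unit vectors with themselves (`ignore`d for
/// the path assumption noted on `tokenise_spikes`). Each positive similarity
/// is 1 and each negative is 0, so L = -ln(e^{1/τ} / (e^{1/τ} + e^0)).
///
/// ```ignore
/// use sc_neurocore_engine::analysis::neural_decoders::infonce_loss;
///
/// let z = vec![1.0, 0.0, 0.0, 1.0]; // N = 2 embeddings, d = 2
/// let loss = infonce_loss(&z, &z, 2, 2, 1.0);
/// let expected = -(1.0_f64.exp() / (1.0_f64.exp() + 1.0)).ln();
/// assert!((loss - expected).abs() < 1e-9);
/// ```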
pub fn infonce_loss(
    anchors: &[f64],
    positives: &[f64],
    n: usize,
    d: usize,
    temperature: f64,
) -> f64 {
    if n == 0 || d == 0 {
        return 0.0;
    }
    let inv_tau = 1.0 / temperature;

    // Normalise rows to unit length so dot products are cosine similarities
    let norm = |v: &[f64]| -> Vec<f64> {
        let mut out = v.to_vec();
        for i in 0..n {
            let row = &mut out[i * d..(i + 1) * d];
            let nrm: f64 = row.iter().map(|x| x * x).sum::<f64>().sqrt() + 1e-30;
            for x in row.iter_mut() {
                *x /= nrm;
            }
        }
        out
    };

    let a_norm = norm(anchors);
    let p_norm = norm(positives);

    let total_loss: f64 = (0..n)
        .into_par_iter()
        .map(|i| {
            let a_row = &a_norm[i * d..(i + 1) * d];
            // Positive similarity (diagonal)
            let p_row = &p_norm[i * d..(i + 1) * d];
            let pos_sim: f64 = a_row.iter().zip(p_row).map(|(a, p)| a * p).sum();

            // Similarities against all positives (the diagonal is the positive)
            let mut max_sim = f64::NEG_INFINITY;
            let mut sims = vec![0.0_f64; n];
            for j in 0..n {
                let pj = &p_norm[j * d..(j + 1) * d];
                let sim: f64 = a_row.iter().zip(pj).map(|(a, p)| a * p).sum();
                sims[j] = sim * inv_tau;
                if sims[j] > max_sim {
                    max_sim = sims[j];
                }
            }
            let sum_exp: f64 = sims.iter().map(|s| (s - max_sim).exp()).sum();
            let log_softmax = pos_sim * inv_tau - max_sim - sum_exp.ln();
            -log_softmax
        })
        .sum();

    total_loss / n as f64
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenise_empty() {
        let tokens = tokenise_spikes(&[], 1.0);
        assert!(tokens.is_empty());
    }

    #[test]
    fn test_tokenise_single() {
        let train = vec![0, 0, 1, 0, 0];
        let tokens = tokenise_spikes(&[&train], 0.5);
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].0, 0);
        assert!((tokens[0].1 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_tokenise_sorted() {
        let t0 = vec![0, 0, 0, 0, 1]; // spike at t=4
        let t1 = vec![0, 1, 0, 0, 0]; // spike at t=1
        let tokens = tokenise_spikes(&[&t0, &t1], 1.0);
        assert_eq!(tokens.len(), 2);
        assert!(tokens[0].1 <= tokens[1].1);
    }

    #[test]
    fn test_sinusoidal_pe_shape() {
        let ts = vec![0.0, 1.0, 2.0];
        let pe = sinusoidal_position_encode(&ts, 8);
        assert_eq!(pe.len(), 3 * 8);
    }

    #[test]
    fn test_sinusoidal_pe_zero() {
        let pe = sinusoidal_position_encode(&[0.0], 4);
        assert!((pe[0] - 0.0).abs() < 1e-10); // sin(0)
        assert!((pe[1] - 1.0).abs() < 1e-10); // cos(0)
    }

    #[test]
    fn test_attention_shape() {
        let q = vec![1.0, 0.0, 0.0, 1.0]; // 2×2
        let k = vec![1.0, 0.0, 0.0, 1.0, 0.5, 0.5]; // 3×2
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3×2
        let out = scaled_dot_product_attention(&q, &k, &v, 2, 3, 2);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_gaussian_attention_concentrates() {
        // Query at origin, one key at origin, one far away
        let q = vec![0.0, 0.0];
        let k = vec![0.0, 0.0, 100.0, 100.0];
        let v = vec![1.0, 0.0, 0.0, 1.0];
        let out = gaussian_attention(&q, &k, &v, 1, 2, 2, 0.01);
        // Should concentrate on first key (distance=0)
        assert!((out[0] - 1.0).abs() < 1e-3);
        assert!((out[1] - 0.0).abs() < 1e-3);
    }

    #[test]
    fn test_ssm_step_output_size() {
        let d_state = 2;
        let d_model = 3;
        let a_re = vec![0.9, 0.8];
        let a_im = vec![0.1, 0.2];
        let b_re = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]; // 2×3
        let b_im = vec![0.0; 6];
        let c_re = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]; // 3×2
        let c_im = vec![0.0; 6];
        let d_mat = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]; // 3×3 identity
        let mut h_re = vec![0.0; 2];
        let mut h_im = vec![0.0; 2];
        let x = vec![1.0, 0.0, 0.0];
        let y = ssm_step_diagonal(
            &a_re, &a_im, &b_re, &b_im, &c_re, &c_im, &d_mat, &mut h_re, &mut h_im, &x, d_state,
            d_model,
        );
        assert_eq!(y.len(), 3);
    }

    #[test]
    fn test_ssm_state_update() {
        let d_state = 1;
        let d_model = 1;
        let mut h_re = vec![0.0];
        let mut h_im = vec![0.0];
        ssm_step_diagonal(
            &[0.9],
            &[0.0],
            &[1.0],
            &[0.0],
            &[1.0],
            &[0.0],
            &[0.0],
            &mut h_re,
            &mut h_im,
            &[1.0],
            d_state,
            d_model,
        );
        // h_re should be 0.9 * 0 + 1.0 * 1.0 = 1.0
        assert!((h_re[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_infonce_identical_pairs() {
        let d = 4;
        let n = 3;
        let data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let loss = infonce_loss(&data, &data, n, d, 1.0);
        // Identical pairs: each positive is the anchor itself
        // cosine similarity = 1 for diagonal, < 1 for off-diag
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_infonce_temperature() {
        let d = 2;
        let n = 4;
        let a = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0];
        let p = a.clone();
        let loss_cold = infonce_loss(&a, &p, n, d, 0.1);
        let loss_hot = infonce_loss(&a, &p, n, d, 10.0);
        assert!(loss_cold < loss_hot);
    }
432    }
433}