Skip to main content

sc_neurocore_engine/analysis/
decoding.rs

1use rayon::prelude::*;
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial license available
4// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
5// © Code 2020–2026 Miroslav Šotek. All rights reserved.
6// ORCID: 0009-0009-3560-0851
7// Contact: www.anulum.li | protoscience@anulum.li
8// SC-NeuroCore — Neural population decoding algorithms
9
10use std::f64::consts::PI;
11
12/// Georgopoulos population vector decoding.
13/// trains: slice of binary spike trains (i32). preferred_directions: angle per neuron (radians).
14/// Returns decoded angle per time bin.
15pub fn population_vector_decode(
16    trains: &[&[i32]],
17    preferred_directions: &[f64],
18    window: usize,
19) -> Vec<f64> {
20    if trains.is_empty() || window == 0 {
21        return vec![];
22    }
23    let min_len = trains.iter().map(|t| t.len()).min().unwrap_or(0);
24    let n_bins = min_len / window;
25    if n_bins == 0 {
26        return vec![];
27    }
28    // Pre-calculate cos/sin for preferred directions
29    let dirs_cos: Vec<f64> = preferred_directions.iter().map(|&d| d.cos()).collect();
30    let dirs_sin: Vec<f64> = preferred_directions.iter().map(|&d| d.sin()).collect();
31
32    let decoded: Vec<f64> = (0..n_bins)
33        .into_par_iter()
34        .map(|b| {
35            let mut sx = 0.0_f64;
36            let mut sy = 0.0_f64;
37            let start = b * window;
38            let end = (b + 1) * window;
39            for (i, t) in trains.iter().enumerate() {
40                let count: i64 = t[start..end].iter().map(|&v| v as i64).sum();
41                let c = dirs_cos.get(i).copied().unwrap_or(1.0);
42                let s = dirs_sin.get(i).copied().unwrap_or(0.0);
43                sx += count as f64 * c;
44                sy += count as f64 * s;
45            }
46            sy.atan2(sx)
47        })
48        .collect();
49    decoded
50}
51
52/// Bayesian MAP decoder (Dayan & Abbott 2001).
53/// spike_counts: (n_neurons,). tuning_rates: (n_stimuli × n_neurons, row-major flat).
54/// prior: (n_stimuli,) or empty for uniform. Returns MAP stimulus index.
55pub fn bayesian_decode(
56    spike_counts: &[f64],
57    tuning_rates: &[f64],
58    n_stimuli: usize,
59    n_neurons: usize,
60    prior: &[f64],
61) -> usize {
62    if n_stimuli == 0 || n_neurons == 0 {
63        return 0;
64    }
65    let use_uniform = prior.is_empty();
66    let log_prior_uniform = -(n_stimuli as f64).ln();
67
68    let (best_s, _best_lp) = (0..n_stimuli)
69        .into_par_iter()
70        .map(|s| {
71            let mut lp = if use_uniform {
72                log_prior_uniform
73            } else {
74                (prior.get(s).copied().unwrap_or(1e-30) + 1e-30).ln()
75            };
76            let row_rates = &tuning_rates[s * n_neurons..(s + 1) * n_neurons];
77            let mut j = 0;
78            while j + 3 < n_neurons {
79                let lam0 = row_rates[j].max(1e-10);
80                let lam1 = row_rates[j + 1].max(1e-10);
81                let lam2 = row_rates[j + 2].max(1e-10);
82                let lam3 = row_rates[j + 3].max(1e-10);
83
84                lp += spike_counts[j] * lam0.ln() - lam0;
85                lp += spike_counts[j + 1] * lam1.ln() - lam1;
86                lp += spike_counts[j + 2] * lam2.ln() - lam2;
87                lp += spike_counts[j + 3] * lam3.ln() - lam3;
88                j += 4;
89            }
90            while j < n_neurons {
91                let lam = row_rates[j].max(1e-10);
92                lp += spike_counts[j] * lam.ln() - lam;
93                j += 1;
94            }
95            (s, lp)
96        })
97        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
98        .unwrap_or((0, f64::NEG_INFINITY));
99    best_s
100}
101
102/// Maximum likelihood stimulus decoder (Dayan & Abbott 2001). Uniform prior.
103pub fn maximum_likelihood_decode(
104    spike_counts: &[f64],
105    tuning_rates: &[f64],
106    n_stimuli: usize,
107    n_neurons: usize,
108) -> usize {
109    bayesian_decode(spike_counts, tuning_rates, n_stimuli, n_neurons, &[])
110}
111
112/// Fisher linear discriminant decoder (Fisher 1936).
113/// train_data: (n_samples × n_features, row-major flat). labels: (n_samples,).
114/// test_point: (n_features,). Returns predicted class label.
115pub fn linear_discriminant_decode(
116    train_data: &[f64],
117    n_samples: usize,
118    n_features: usize,
119    labels: &[i64],
120    test_point: &[f64],
121) -> i64 {
122    if n_samples == 0 || n_features == 0 {
123        return 0;
124    }
125
126    // Unique classes
127    let mut classes: Vec<i64> = labels[..n_samples].to_vec();
128    classes.sort();
129    classes.dedup();
130    if classes.len() < 2 {
131        return classes.first().copied().unwrap_or(0);
132    }
133
134    // Class means (parallelised)
135    let (class_means, class_indices): (Vec<Vec<f64>>, Vec<Vec<usize>>) = classes
136        .par_iter()
137        .map(|&c| {
138            let indices: Vec<usize> = (0..n_samples).filter(|&i| labels[i] == c).collect();
139            let mut mean = vec![0.0_f64; n_features];
140            for &idx in &indices {
141                let row = &train_data[idx * n_features..(idx + 1) * n_features];
142                for f in 0..n_features {
143                    mean[f] += row[f];
144                }
145            }
146            let n_c = indices.len() as f64;
147            for v in &mut mean {
148                *v /= n_c;
149            }
150            (mean, indices)
151        })
152        .unzip();
153
154    // Within-class scatter S_w (n_features × n_features)
155    let nf = n_features;
156    let mut s_w = vec![0.0_f64; nf * nf];
157    for (ci, indices) in class_indices.iter().enumerate() {
158        let mean = &class_means[ci];
159        for &idx in indices {
160            for i in 0..nf {
161                let di = train_data[idx * nf + i] - mean[i];
162                for j in 0..nf {
163                    let dj = train_data[idx * nf + j] - mean[j];
164                    s_w[i * nf + j] += di * dj;
165                }
166            }
167        }
168    }
169    // Regularise
170    for i in 0..nf {
171        s_w[i * nf + i] += 1e-8;
172    }
173
174    // Overall mean
175    let mut overall_mean = vec![0.0_f64; nf];
176    for i in 0..n_samples {
177        for f in 0..nf {
178            overall_mean[f] += train_data[i * nf + f];
179        }
180    }
181    for v in &mut overall_mean {
182        *v /= n_samples as f64;
183    }
184
185    // For each class: w = S_w^{-1} (mean_c - overall_mean), score = w . test_point
186    let mut best_class = classes[0];
187    let mut best_score = f64::NEG_INFINITY;
188
189    for (ci, &c) in classes.iter().enumerate() {
190        let diff: Vec<f64> = (0..nf)
191            .map(|f| class_means[ci][f] - overall_mean[f])
192            .collect();
193        let w = solve_linear(&s_w, &diff, nf);
194        let score: f64 = (0..nf).map(|f| w[f] * test_point[f]).sum();
195        if score > best_score {
196            best_score = score;
197            best_class = c;
198        }
199    }
200    best_class
201}
202
/// Gaussian naive Bayes decoder (Mitchell 1997).
///
/// `train_data` is a row-major flat `(n_samples × n_features)` matrix, `labels` holds one
/// class label per sample and `test_point` has `n_features` entries. Returns the class
/// label with the highest log-posterior.
///
/// For every class the log-prior is the log of its share of the samples, and each feature
/// contributes an independent Gaussian log-density built from the class mean and the
/// population variance (divided by the class size, plus a `1e-10` floor). When several
/// classes reach the same score the smallest label is kept.
///
/// Returns `0` when `n_samples` or `n_features` is zero.
///
/// # Panics
/// Panics if `labels` has fewer than `n_samples` entries, `train_data` fewer than
/// `n_samples * n_features`, or `test_point` fewer than `n_features`.
pub fn naive_bayes_decode(
    train_data: &[f64],
    n_samples: usize,
    n_features: usize,
    labels: &[i64],
    test_point: &[f64],
) -> i64 {
    if n_samples == 0 || n_features == 0 {
        return 0;
    }

    let sample_labels = &labels[..n_samples];
    let mut classes = sample_labels.to_vec();
    classes.sort();
    classes.dedup();

    let total = n_samples as f64;
    let mut winner = classes.first().copied().unwrap_or(0);
    let mut winner_log_p = f64::NEG_INFINITY;

    for &class in &classes {
        // Row indices of the training samples that belong to this class.
        let members: Vec<usize> = sample_labels
            .iter()
            .enumerate()
            .filter(|&(_, &label)| label == class)
            .map(|(row, _)| row)
            .collect();
        let count = members.len() as f64;
        let log_prior = (count / total).ln();

        // Sum of the per-feature Gaussian log-densities at the test point.
        let mut log_likelihood = 0.0_f64;
        for feature in 0..n_features {
            let mu = members
                .iter()
                .map(|&row| train_data[row * n_features + feature])
                .sum::<f64>()
                / count;
            let var = members
                .iter()
                .map(|&row| (train_data[row * n_features + feature] - mu).powi(2))
                .sum::<f64>()
                / count
                + 1e-10;
            let x = test_point[feature];
            log_likelihood += -0.5 * ((2.0 * PI * var).ln() + (x - mu).powi(2) / var);
        }

        let log_p = log_prior + log_likelihood;
        // Strict comparison: the first (smallest) label wins on equal scores.
        if log_p > winner_log_p {
            winner_log_p = log_p;
            winner = class;
        }
    }
    winner
}
250
/// Solve `A x = b` via Gaussian elimination with partial pivoting.
///
/// `a` is the row-major flat `n × n` coefficient matrix and `b` the right-hand side of
/// length `n`. A column whose best pivot has magnitude below `1e-30` is skipped during
/// elimination, and any unknown whose diagonal entry ends up no larger than `1e-30` in
/// magnitude is set to `0.0` during back-substitution.
fn solve_linear(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let width = n + 1;

    // Augmented matrix [A | b], stored row-major with `width` entries per row.
    let mut m: Vec<f64> = Vec::with_capacity(n * width);
    for r in 0..n {
        m.extend_from_slice(&a[r * n..(r + 1) * n]);
        m.push(b[r]);
    }

    // Forward elimination.
    for k in 0..n {
        // Partial pivoting: first row at or below `k` with the largest magnitude in column `k`.
        let mut pivot_row = k;
        for r in (k + 1)..n {
            if m[r * width + k].abs() > m[pivot_row * width + k].abs() {
                pivot_row = r;
            }
        }
        if pivot_row != k {
            let (upper, lower) = m.split_at_mut(pivot_row * width);
            upper[k * width..(k + 1) * width].swap_with_slice(&mut lower[..width]);
        }

        // Split so the pivot row can be read while the rows below it are updated.
        let (head, below) = m.split_at_mut((k + 1) * width);
        let pivot_line = &head[k * width..];
        let pivot = pivot_line[k];
        if pivot.abs() < 1e-30 {
            continue;
        }
        for target in below.chunks_exact_mut(width) {
            let factor = target[k] / pivot;
            for (t, &p) in target[k..].iter_mut().zip(&pivot_line[k..]) {
                *t -= factor * p;
            }
        }
    }

    // Back-substitution, from the last unknown upwards.
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let row = &m[i * width..(i + 1) * width];
        let mut rhs = row[n];
        for (&coef, &known) in row[i + 1..n].iter().zip(&x[i + 1..]) {
            rhs -= coef * known;
        }
        let diag = row[i];
        x[i] = if diag.abs() > 1e-30 { rhs / diag } else { 0.0 };
    }
    x
}
298
#[cfg(test)]
mod tests {
    use super::*;

    // ── population_vector_decode ────────────────────────────────────

    #[test]
    fn test_pv_single_neuron_right() {
        // A lone neuron tuned to 0 rad (pointing right) that spikes in every sample.
        let spikes = [1_i32; 100];
        let angles = population_vector_decode(&[&spikes[..]], &[0.0], 50);
        assert_eq!(angles.len(), 2);
        assert!((angles[0] - 0.0).abs() < 1e-10, "should decode to 0 rad");
    }

    #[test]
    fn test_pv_two_neurons_45deg() {
        // Neurons tuned to 0 and π/2 firing equally: the vector sum points at π/4.
        let spikes = [1_i32; 100];
        let population: [&[i32]; 2] = [&spikes, &spikes];
        let angles = population_vector_decode(&population, &[0.0, PI / 2.0], 100);
        assert_eq!(angles.len(), 1);
        assert!(
            (angles[0] - PI / 4.0).abs() < 1e-10,
            "equal firing at 0 and π/2 → π/4, got {}",
            angles[0]
        );
    }

    #[test]
    fn test_pv_empty() {
        assert!(population_vector_decode(&[], &[], 50).is_empty());
    }

    #[test]
    fn test_pv_no_bins() {
        // Ten samples cannot fill a single 100-sample window.
        let spikes = [1_i32; 10];
        let angles = population_vector_decode(&[&spikes[..]], &[0.0], 100);
        assert!(angles.is_empty(), "train shorter than window → empty");
    }

    // ── bayesian_decode ─────────────────────────────────────────────

    #[test]
    fn test_bayesian_obvious() {
        // 2 stimuli × 2 neurons: stimulus 0 drives neuron 0, stimulus 1 drives neuron 1.
        let tuning = [10.0, 0.1, 0.1, 10.0];
        // Only neuron 0 is active, which points at stimulus 0.
        let counts = [8.0, 0.0];
        let decoded = bayesian_decode(&counts, &tuning, 2, 2, &[]);
        assert_eq!(decoded, 0, "high neuron 0 firing → stimulus 0");
    }

    #[test]
    fn test_bayesian_with_prior() {
        // Identical tuning rows: only the prior can separate the stimuli.
        let tuning = [5.0; 4];
        let counts = [5.0; 2];
        let prior = [0.1, 0.9];
        let decoded = bayesian_decode(&counts, &tuning, 2, 2, &prior);
        assert_eq!(decoded, 1, "equal evidence + strong prior → stimulus 1");
    }

    #[test]
    fn test_bayesian_empty() {
        assert_eq!(bayesian_decode(&[], &[], 0, 0, &[]), 0);
    }

    // ── maximum_likelihood_decode ───────────────────────────────────

    #[test]
    fn test_ml_matches_bayesian_uniform() {
        let tuning = [10.0, 0.1, 0.1, 10.0];
        // Only neuron 1 is active, which points at stimulus 1.
        let counts = [0.0, 8.0];
        let from_ml = maximum_likelihood_decode(&counts, &tuning, 2, 2);
        let from_bayes = bayesian_decode(&counts, &tuning, 2, 2, &[]);
        assert_eq!(from_ml, from_bayes);
        assert_eq!(from_ml, 1);
    }

    // ── linear_discriminant_decode ──────────────────────────────────

    #[test]
    fn test_lda_separable() {
        // Class 0 clusters near (0, 0), class 1 near (10, 10).
        #[rustfmt::skip]
        let samples = [
            0.0, 0.0,
            0.1, 0.1,
            -0.1, 0.1,
            10.0, 10.0,
            10.1, 9.9,
            9.9, 10.1,
        ];
        let labels = [0_i64, 0, 0, 1, 1, 1];
        // A point on the class 1 cluster must decode to class 1.
        let at_class_1 = [10.0, 10.0];
        assert_eq!(
            linear_discriminant_decode(&samples, 6, 2, &labels, &at_class_1),
            1
        );
        // Points far out on opposite sides must not share a class.
        let low = linear_discriminant_decode(&samples, 6, 2, &labels, &[-5.0, -5.0]);
        let high = linear_discriminant_decode(&samples, 6, 2, &labels, &[15.0, 15.0]);
        assert_ne!(low, high, "distant points should decode to different classes");
    }

    #[test]
    fn test_lda_single_class() {
        let samples = [1.0, 2.0, 3.0, 4.0];
        let labels = [5_i64, 5];
        let probe = [2.0, 3.0];
        assert_eq!(linear_discriminant_decode(&samples, 2, 2, &labels, &probe), 5);
    }

    #[test]
    fn test_lda_empty() {
        assert_eq!(linear_discriminant_decode(&[], 0, 0, &[], &[]), 0);
    }

    // ── naive_bayes_decode ──────────────────────────────────────────

    #[test]
    fn test_nb_separable() {
        #[rustfmt::skip]
        let samples = [
            0.0, 0.0,
            0.1, 0.1,
            -0.1, -0.1,
            10.0, 10.0,
            10.1, 10.1,
            9.9, 9.9,
        ];
        let labels = [0_i64, 0, 0, 1, 1, 1];
        let near_class_0 = [0.2, 0.2];
        let near_class_1 = [9.8, 9.8];
        assert_eq!(naive_bayes_decode(&samples, 6, 2, &labels, &near_class_0), 0);
        assert_eq!(naive_bayes_decode(&samples, 6, 2, &labels, &near_class_1), 1);
    }

    #[test]
    fn test_nb_single_class() {
        let samples = [1.0, 2.0];
        let labels = [7_i64];
        assert_eq!(naive_bayes_decode(&samples, 1, 2, &labels, &[1.0, 2.0]), 7);
    }

    #[test]
    fn test_nb_agrees_with_lda_simple() {
        // Two tight, well-separated clusters: both decoders should pick the same class.
        #[rustfmt::skip]
        let samples = [
            -5.0, -5.0,
            -4.9, -5.1,
            5.0, 5.0,
            5.1, 4.9,
        ];
        let labels = [0_i64, 0, 1, 1];
        let probe = [4.0, 4.0];
        let from_lda = linear_discriminant_decode(&samples, 4, 2, &labels, &probe);
        let from_nb = naive_bayes_decode(&samples, 4, 2, &labels, &probe);
        assert_eq!(from_lda, from_nb, "well-separated → both predict same class");
    }

    // ── solve_linear ────────────────────────────────────────────────

    #[test]
    fn test_solve_2x2() {
        // 2x + y = 5 and x + 3y = 10 have the solution x = 1, y = 3.
        let coefficients = [2.0, 1.0, 1.0, 3.0];
        let rhs = [5.0, 10.0];
        let solution = solve_linear(&coefficients, &rhs, 2);
        assert!((solution[0] - 1.0).abs() < 1e-10);
        assert!((solution[1] - 3.0).abs() < 1e-10);
    }
}