// sc_neurocore_engine/analysis/gpfa.rs
// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — GPFA: Gaussian Process Factor Analysis
//
// Yu, Cunningham et al. (2009) J. Neurophysiol. 102:614-635.

11use super::basic;
12
13// ── helpers ─────────────────────────────────────────────────────────
14
/// Squared-exponential GP kernel for `n` time points.
///
/// Returns a row-major `n x n` matrix with entries
/// `sigma^2 * exp(-0.5 * (i - j)^2 / tau^2)`, where `i`, `j` are bin
/// indices. `tau` is therefore expressed in units of bins.
fn gp_kernel(n: usize, tau: f64, sigma: f64) -> Vec<f64> {
    let mut k = vec![0.0f64; n * n];
    // Small jitter so tau == 0 does not divide by zero.
    let tau_sq = tau * tau + 1e-12;
    let sigma_sq = sigma * sigma;
    for i in 0..n {
        for j in 0..n {
            let diff = i as f64 - j as f64;
            k[i * n + j] = sigma_sq * (-0.5 * diff * diff / tau_sq).exp();
        }
    }
    k
}

/// Gauss-Jordan inverse for an `n x n` row-major matrix, with partial
/// pivoting.
///
/// If a pivot column is numerically singular (max |entry| < 1e-30) that
/// column is silently skipped, so the result for a singular input is a
/// best-effort value rather than a true inverse — callers regularise
/// inputs before calling.
fn mat_inv(a: &[f64], n: usize) -> Vec<f64> {
    // Augmented matrix [A | I], row-major, width 2n.
    let mut aug = vec![0.0f64; n * 2 * n];
    for i in 0..n {
        for j in 0..n {
            aug[i * 2 * n + j] = a[i * n + j];
        }
        aug[i * 2 * n + n + i] = 1.0;
    }
    for col in 0..n {
        // Partial pivoting: largest |entry| in this column.
        let mut max_row = col;
        let mut max_val = aug[col * 2 * n + col].abs();
        for row in col + 1..n {
            let v = aug[row * 2 * n + col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        // Numerically singular column: skip (best-effort behaviour).
        if max_val < 1e-30 {
            continue;
        }
        if max_row != col {
            for k in 0..2 * n {
                aug.swap(col * 2 * n + k, max_row * 2 * n + k);
            }
        }
        // Normalise the pivot row, then eliminate `col` everywhere else.
        let pivot = aug[col * 2 * n + col];
        for k in 0..2 * n {
            aug[col * 2 * n + k] /= pivot;
        }
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[row * 2 * n + col];
            for k in 0..2 * n {
                aug[row * 2 * n + k] -= factor * aug[col * 2 * n + k];
            }
        }
    }
    // Right half of [A | I] now holds A^{-1}.
    let mut inv = vec![0.0f64; n * n];
    for i in 0..n {
        for j in 0..n {
            inv[i * n + j] = aug[i * 2 * n + n + j];
        }
    }
    inv
}

/// Solve `A x = b` via Gauss-Jordan elimination with partial pivoting.
///
/// `a` is `n x n`, `b` is `n x m`, both row-major; returns `x` as `n x m`.
/// As in [`mat_inv`], a numerically singular pivot column (max |entry|
/// < 1e-30) is skipped rather than reported, so singular systems yield a
/// best-effort result.
fn mat_solve(a: &[f64], b: &[f64], n: usize, m: usize) -> Vec<f64> {
    // Augmented matrix [A | B], row-major, width n + m.
    let mut aug = vec![0.0f64; n * (n + m)];
    let w = n + m;
    for i in 0..n {
        for j in 0..n {
            aug[i * w + j] = a[i * n + j];
        }
        for j in 0..m {
            aug[i * w + n + j] = b[i * m + j];
        }
    }
    for col in 0..n {
        // Partial pivoting: largest |entry| in this column.
        let mut max_row = col;
        let mut max_val = aug[col * w + col].abs();
        for row in col + 1..n {
            let v = aug[row * w + col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        // Numerically singular column: skip (best-effort behaviour).
        if max_val < 1e-30 {
            continue;
        }
        if max_row != col {
            for k in 0..w {
                aug.swap(col * w + k, max_row * w + k);
            }
        }
        // Normalise the pivot row, then eliminate `col` everywhere else.
        let pivot = aug[col * w + col];
        for k in 0..w {
            aug[col * w + k] /= pivot;
        }
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[row * w + col];
            for k in 0..w {
                aug[row * w + k] -= factor * aug[col * w + k];
            }
        }
    }
    // Right block of [A | B] now holds the solution X.
    let mut x = vec![0.0f64; n * m];
    for i in 0..n {
        for j in 0..m {
            x[i * m + j] = aug[i * w + n + j];
        }
    }
    x
}

132/// E-step: compute posterior p(x|y).
133fn gpfa_e_step(
134    y: &[f64],          // n_neurons x n_bins (row-major)
135    c: &[f64],          // n_neurons x n_latents
136    d: &[f64],          // n_neurons
137    r_diag: &[f64],     // n_neurons
138    k_all: &[Vec<f64>], // n_latents kernels, each n_bins x n_bins
139    n_neurons: usize,
140    n_bins: usize,
141    n_latents: usize,
142) -> (Vec<f64>, Vec<f64>) {
143    // x_post (n_latents x n_bins), xx_post (n_latents x n_latents)
144    let kt = n_latents * n_bins;
145
146    // R^{-1}
147    let r_inv: Vec<f64> = r_diag.iter().map(|&r| 1.0 / (r + 1e-10)).collect();
148
149    // C^T R^{-1} C (n_latents x n_latents)
150    let mut ct_rinv_c = vec![0.0f64; n_latents * n_latents];
151    for i in 0..n_latents {
152        for j in 0..n_latents {
153            let mut s = 0.0;
154            for k in 0..n_neurons {
155                s += c[k * n_latents + i] * r_inv[k] * c[k * n_latents + j];
156            }
157            ct_rinv_c[i * n_latents + j] = s;
158        }
159    }
160
161    // C^T R^{-1} (n_latents x n_neurons)
162    let mut ct_rinv = vec![0.0f64; n_latents * n_neurons];
163    for i in 0..n_latents {
164        for k in 0..n_neurons {
165            ct_rinv[i * n_neurons + k] = c[k * n_latents + i] * r_inv[k];
166        }
167    }
168
169    // Build precision (kt x kt)
170    let mut prec = vec![0.0f64; kt * kt];
171    for j in 0..n_latents {
172        let slj = j * n_bins;
173        // K_j^{-1}
174        let mut k_reg = k_all[j].clone();
175        for i in 0..n_bins {
176            k_reg[i * n_bins + i] += 1e-6;
177        }
178        let k_eye = vec![0.0f64; n_bins * n_bins]
179            .iter()
180            .enumerate()
181            .map(|(idx, _)| {
182                if idx / n_bins == idx % n_bins {
183                    1.0
184                } else {
185                    0.0
186                }
187            })
188            .collect::<Vec<f64>>();
189        let k_inv = mat_solve(&k_reg, &k_eye, n_bins, n_bins);
190
191        for i in 0..n_bins {
192            for jj in 0..n_bins {
193                prec[(slj + i) * kt + (slj + jj)] = k_inv[i * n_bins + jj]
194                    + ct_rinv_c[j * n_latents + j] * if i == jj { 1.0 } else { 0.0 };
195            }
196        }
197        for k in 0..n_latents {
198            if k != j {
199                let slk = k * n_bins;
200                for i in 0..n_bins {
201                    prec[(slj + i) * kt + (slk + i)] = ct_rinv_c[j * n_latents + k];
202                }
203            }
204        }
205    }
206
207    // RHS
208    let mut rhs = vec![0.0f64; kt];
209    // Y_centered = Y - d[:, None]
210    for t in 0..n_bins {
211        // v = C^T R^{-1} (y_t - d)
212        for j in 0..n_latents {
213            let mut s = 0.0;
214            for k in 0..n_neurons {
215                s += ct_rinv[j * n_neurons + k] * (y[k * n_bins + t] - d[k]);
216            }
217            rhs[j * n_bins + t] = s;
218        }
219    }
220
221    // Regularise precision
222    for i in 0..kt {
223        prec[i * kt + i] += 1e-8;
224    }
225
226    // Solve prec * x_vec = rhs
227    let rhs_col: Vec<f64> = rhs.clone();
228    let x_vec = mat_solve(&prec, &rhs_col, kt, 1);
229
230    // Posterior covariance (for E[xx^T])
231    let eye_kt: Vec<f64> = (0..kt * kt)
232        .map(|idx| if idx / kt == idx % kt { 1.0 } else { 0.0 })
233        .collect();
234    let sigma_post = mat_solve(&prec, &eye_kt, kt, kt);
235
236    // E[xx^T] per timepoint
237    let mut xx_post = vec![0.0f64; n_latents * n_latents];
238    for t in 0..n_bins {
239        for j in 0..n_latents {
240            let xj = x_vec[j * n_bins + t];
241            for k in 0..n_latents {
242                let xk = x_vec[k * n_bins + t];
243                xx_post[j * n_latents + k] +=
244                    xj * xk + sigma_post[(j * n_bins + t) * kt + (k * n_bins + t)];
245            }
246        }
247    }
248
249    (x_vec, xx_post)
250}
251
252/// M-step: update C, d, R.
253fn gpfa_m_step(
254    y: &[f64],
255    x_post: &[f64],
256    xx_post: &[f64],
257    n_neurons: usize,
258    n_bins: usize,
259    n_latents: usize,
260) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
261    // d_new = Y.mean(axis=1)
262    let mut d_new = vec![0.0f64; n_neurons];
263    for i in 0..n_neurons {
264        let s: f64 = (0..n_bins).map(|t| y[i * n_bins + t]).sum();
265        d_new[i] = s / n_bins as f64;
266    }
267
268    // Y_centered
269    // Yx = Y_centered @ x_post^T (n_neurons x n_latents)
270    let mut yx = vec![0.0f64; n_neurons * n_latents];
271    for i in 0..n_neurons {
272        for j in 0..n_latents {
273            let mut s = 0.0;
274            for t in 0..n_bins {
275                s += (y[i * n_bins + t] - d_new[i]) * x_post[j * n_bins + t];
276            }
277            yx[i * n_latents + j] = s;
278        }
279    }
280
281    // C_new = Yx @ inv(xx_post + eps*I)
282    let mut xx_reg = xx_post.to_vec();
283    for i in 0..n_latents {
284        xx_reg[i * n_latents + i] += 1e-8;
285    }
286    let xx_inv = mat_inv(&xx_reg, n_latents);
287    let mut c_new = vec![0.0f64; n_neurons * n_latents];
288    for i in 0..n_neurons {
289        for j in 0..n_latents {
290            let mut s = 0.0;
291            for k in 0..n_latents {
292                s += yx[i * n_latents + k] * xx_inv[k * n_latents + j];
293            }
294            c_new[i * n_latents + j] = s;
295        }
296    }
297
298    // R_new = diag(YY^T/T - C E[x]Y^T/T)
299    let mut r_new = vec![0.0f64; n_neurons];
300    for i in 0..n_neurons {
301        let yyt: f64 = (0..n_bins)
302            .map(|t| {
303                let v = y[i * n_bins + t] - d_new[i];
304                v * v
305            })
306            .sum::<f64>()
307            / n_bins as f64;
308        // C[i,:] @ x_post @ Y_centered[i,:]^T / T
309        let mut cxy = 0.0;
310        for j in 0..n_latents {
311            for t in 0..n_bins {
312                cxy += c_new[i * n_latents + j]
313                    * x_post[j * n_bins + t]
314                    * (y[i * n_bins + t] - d_new[i]);
315            }
316        }
317        cxy /= n_bins as f64;
318        r_new[i] = (yyt - cxy).max(1e-6);
319    }
320
321    (c_new, d_new, r_new)
322}
323
// ── public API ──────────────────────────────────────────────────────

/// GPFA result.
pub struct GpfaResult {
    /// Latent trajectories, row-major `(n_latents, n_bins)`.
    pub trajectories: Vec<f64>,
    /// Loading matrix, row-major `(n_neurons, n_latents)`.
    pub c: Vec<f64>,
    /// Mean vector `(n_neurons)`.
    pub d: Vec<f64>,
    /// Noise diagonal `(n_neurons)`.
    pub r: Vec<f64>,
    /// GP timescales `(n_latents)`; set at initialisation, not optimised.
    pub tau: Vec<f64>,
    /// Log-likelihoods per iteration.
    pub log_likelihoods: Vec<f64>,
    /// Number of latent dimensions actually used (may be clamped below the
    /// requested count).
    pub n_latents: usize,
    /// Number of time bins.
    pub n_bins: usize,
    /// Number of neurons.
    pub n_neurons: usize,
}

345/// Extract smooth latent trajectories from parallel spike trains via EM.
346pub fn gpfa(
347    trains: &[&[i32]],
348    n_latents: usize,
349    bin_ms: f64,
350    dt: f64,
351    max_iter: usize,
352    tol: f64,
353    seed: u64,
354) -> GpfaResult {
355    let n_neurons = trains.len();
356    if n_neurons == 0 {
357        return GpfaResult {
358            trajectories: vec![],
359            c: vec![],
360            d: vec![],
361            r: vec![],
362            tau: vec![],
363            log_likelihoods: vec![],
364            n_latents: 0,
365            n_bins: 0,
366            n_neurons: 0,
367        };
368    }
369    let bin_steps = (bin_ms / (dt * 1000.0)).round().max(1.0) as usize;
370    let binned: Vec<Vec<f64>> = trains
371        .iter()
372        .map(|t| {
373            basic::bin_spike_train(t, bin_steps)
374                .into_iter()
375                .map(|c| c as f64)
376                .collect()
377        })
378        .collect();
379    let n_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
380    if n_bins == 0 {
381        return GpfaResult {
382            trajectories: vec![],
383            c: vec![],
384            d: vec![],
385            r: vec![],
386            tau: vec![],
387            log_likelihoods: vec![],
388            n_latents: 0,
389            n_bins: 0,
390            n_neurons,
391        };
392    }
393    // Y: n_neurons x n_bins
394    let mut y = vec![0.0f64; n_neurons * n_bins];
395    for i in 0..n_neurons {
396        for j in 0..n_bins {
397            y[i * n_bins + j] = binned[i][j];
398        }
399    }
400    let nl = n_latents.min(n_neurons).min(n_bins);
401
402    // Initialise
403    let mut rng = seed;
404    let mut c = vec![0.0f64; n_neurons * nl];
405    for v in &mut c {
406        rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
407        *v = ((rng >> 33) as f64 / (1u64 << 31) as f64 - 0.5) * 0.2;
408    }
409    let mut d_vec = vec![0.0f64; n_neurons];
410    for i in 0..n_neurons {
411        d_vec[i] = y[i * n_bins..i * n_bins + n_bins].iter().sum::<f64>() / n_bins as f64;
412    }
413    let mut r_diag = vec![0.0f64; n_neurons];
414    for i in 0..n_neurons {
415        let mean = d_vec[i];
416        let var: f64 = (0..n_bins)
417            .map(|t| (y[i * n_bins + t] - mean).powi(2))
418            .sum::<f64>()
419            / n_bins as f64;
420        r_diag[i] = var + 1e-4;
421    }
422    let tau = vec![bin_ms * 2.0; nl];
423
424    let mut log_liks = Vec::new();
425    let mut x_post = vec![0.0f64; nl * n_bins];
426
427    for _ in 0..max_iter {
428        let k_all: Vec<Vec<f64>> = (0..nl).map(|j| gp_kernel(n_bins, tau[j], 1.0)).collect();
429
430        let (xp, xx_post) = gpfa_e_step(&y, &c, &d_vec, &r_diag, &k_all, n_neurons, n_bins, nl);
431        x_post = xp;
432
433        let (c_new, d_new, r_new) = gpfa_m_step(&y, &x_post, &xx_post, n_neurons, n_bins, nl);
434        c = c_new;
435        d_vec = d_new;
436        r_diag = r_new;
437
438        // Approximate log-likelihood
439        let mut ll = 0.0f64;
440        for i in 0..n_neurons {
441            for t in 0..n_bins {
442                let mut pred = d_vec[i];
443                for j in 0..nl {
444                    pred += c[i * nl + j] * x_post[j * n_bins + t];
445                }
446                let resid = y[i * n_bins + t] - pred;
447                ll -= 0.5 * resid * resid / (r_diag[i] + 1e-10);
448            }
449        }
450        ll -= 0.5 * n_bins as f64 * r_diag.iter().map(|&r| (r + 1e-10).ln()).sum::<f64>();
451        log_liks.push(ll);
452
453        if log_liks.len() > 1 {
454            let prev = log_liks[log_liks.len() - 2];
455            if (ll - prev).abs() < tol {
456                break;
457            }
458        }
459    }
460
461    GpfaResult {
462        trajectories: x_post,
463        c,
464        d: d_vec,
465        r: r_diag,
466        tau,
467        log_likelihoods: log_liks,
468        n_latents: nl,
469        n_bins,
470        n_neurons,
471    }
472}
473
474/// Project new spike trains using learned GPFA parameters.
475pub fn gpfa_transform(
476    new_trains: &[&[i32]],
477    c: &[f64],
478    d: &[f64],
479    r_diag: &[f64],
480    tau: &[f64],
481    n_latents: usize,
482    bin_ms: f64,
483    dt: f64,
484) -> Vec<f64> {
485    let n_neurons = new_trains.len();
486    if n_neurons == 0 || c.is_empty() {
487        return vec![];
488    }
489    let bin_steps = (bin_ms / (dt * 1000.0)).round().max(1.0) as usize;
490    let binned: Vec<Vec<f64>> = new_trains
491        .iter()
492        .map(|t| {
493            basic::bin_spike_train(t, bin_steps)
494                .into_iter()
495                .map(|v| v as f64)
496                .collect()
497        })
498        .collect();
499    let n_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
500    if n_bins == 0 {
501        return vec![];
502    }
503    let mut y = vec![0.0f64; n_neurons * n_bins];
504    for i in 0..n_neurons {
505        for j in 0..n_bins {
506            y[i * n_bins + j] = binned[i][j];
507        }
508    }
509    let k_all: Vec<Vec<f64>> = (0..n_latents)
510        .map(|j| gp_kernel(n_bins, tau[j], 1.0))
511        .collect();
512    let (x_post, _) = gpfa_e_step(&y, c, d, r_diag, &k_all, n_neurons, n_bins, n_latents);
513    x_post
514}
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Four regular trains with different inter-spike intervals.
    fn make_trains() -> Vec<Vec<i32>> {
        let mut trains = Vec::new();
        for n in 0..4 {
            let mut t = vec![0i32; 100];
            let step = 3 + n * 2;
            for i in (0..100).step_by(step) {
                t[i] = 1;
            }
            trains.push(t);
        }
        trains
    }

    #[test]
    fn test_gpfa_basic() {
        let trains = make_trains();
        let refs: Vec<&[i32]> = trains.iter().map(|t| t.as_slice()).collect();
        let result = gpfa(&refs, 2, 10.0, 0.001, 5, 1e-4, 42);
        assert_eq!(result.n_neurons, 4);
        assert_eq!(result.n_latents, 2);
        assert!(!result.trajectories.is_empty());
        assert!(!result.log_likelihoods.is_empty());
    }

    #[test]
    fn test_gpfa_empty() {
        let result = gpfa(&[], 2, 10.0, 0.001, 5, 1e-4, 42);
        assert_eq!(result.n_neurons, 0);
        assert!(result.trajectories.is_empty());
    }

    #[test]
    fn test_gpfa_single_neuron() {
        let train = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0];
        let refs = vec![train.as_slice()];
        let result = gpfa(&refs, 1, 5.0, 0.001, 3, 1e-4, 42);
        assert_eq!(result.n_neurons, 1);
        assert_eq!(result.n_latents, 1);
    }

    #[test]
    fn test_gpfa_convergence() {
        let trains = make_trains();
        let refs: Vec<&[i32]> = trains.iter().map(|t| t.as_slice()).collect();
        let result = gpfa(&refs, 2, 10.0, 0.001, 20, 1e-4, 42);
        // Log-likelihood should generally increase (small tolerance for
        // the approximate LL used as the convergence criterion).
        if result.log_likelihoods.len() > 2 {
            let last = result.log_likelihoods[result.log_likelihoods.len() - 1];
            let second = result.log_likelihoods[1];
            assert!(
                last >= second - 1.0,
                "LL should generally increase: {second} -> {last}"
            );
        }
    }

    #[test]
    fn test_gpfa_transform() {
        let trains = make_trains();
        let refs: Vec<&[i32]> = trains.iter().map(|t| t.as_slice()).collect();
        let result = gpfa(&refs, 2, 10.0, 0.001, 5, 1e-4, 42);

        // Projecting the training data back must yield a full trajectory set.
        let new_trains = make_trains();
        let new_refs: Vec<&[i32]> = new_trains.iter().map(|t| t.as_slice()).collect();
        let projected = gpfa_transform(
            &new_refs,
            &result.c,
            &result.d,
            &result.r,
            &result.tau,
            result.n_latents,
            10.0,
            0.001,
        );
        assert!(!projected.is_empty());
        assert_eq!(projected.len(), result.n_latents * result.n_bins);
    }

    #[test]
    fn test_gpfa_transform_empty() {
        let proj = gpfa_transform(&[], &[], &[], &[], &[], 0, 10.0, 0.001);
        assert!(proj.is_empty());
    }

    #[test]
    fn test_gp_kernel_shape() {
        let k = gp_kernel(10, 5.0, 1.0);
        assert_eq!(k.len(), 100);
        // Diagonal should be sigma^2.
        for i in 0..10 {
            assert!((k[i * 10 + i] - 1.0).abs() < 1e-10);
        }
        // Should be symmetric.
        for i in 0..10 {
            for j in 0..10 {
                assert!((k[i * 10 + j] - k[j * 10 + i]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_gp_kernel_decay() {
        let k = gp_kernel(20, 3.0, 1.0);
        // Off-diagonal should decay with distance.
        assert!(k[1] > k[10]);
    }
}