1use super::basic;
12use nalgebra::{Cholesky, DMatrix, Dyn};
13
/// Output of [`gpfa_em_from_init`], in order: posterior latent means, loading
/// matrix `c`, offsets `d`, diagonal noise variances `r`, and the
/// log-likelihood recorded after each EM iteration.
pub type GpfaEmOutput = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>);
16
/// Builds a dense `n x n` squared-exponential kernel in row-major order.
///
/// Entry `(i, j)` is `sigma^2 * exp(-0.5 * (i - j)^2 / (tau^2 + 1e-12))`, where
/// the lag `i - j` is measured in index units. The `1e-12` term keeps the
/// division finite when `tau` is zero.
fn gp_kernel(n: usize, tau: f64, sigma: f64) -> Vec<f64> {
    let length_scale_sq = tau * tau + 1e-12;
    let amplitude = sigma * sigma;
    (0..n)
        .flat_map(|row| (0..n).map(move |col| (row, col)))
        .map(|(row, col)| {
            let lag = row as f64 - col as f64;
            amplitude * (-0.5 * lag * lag / length_scale_sq).exp()
        })
        .collect()
}
32
33fn spd_cholesky(a: &[f64], n: usize) -> Cholesky<f64, Dyn> {
38 DMatrix::<f64>::from_row_slice(n, n, a)
39 .cholesky()
40 .expect("GPFA matrix must be symmetric positive-definite")
41}
42
43fn chol_logdet(chol: &Cholesky<f64, Dyn>, n: usize) -> f64 {
45 let l = chol.l();
46 2.0 * (0..n).map(|i| l[(i, i)].ln()).sum::<f64>()
47}
48
49fn spd_inverse(a: &[f64], n: usize) -> Vec<f64> {
51 let inv = spd_cholesky(a, n).inverse();
52 let mut out = vec![0.0f64; n * n];
53 for i in 0..n {
54 for j in 0..n {
55 out[i * n + j] = inv[(i, j)];
56 }
57 }
58 out
59}
60
61fn gpfa_precision(
71 c: &[f64],
72 r_diag: &[f64],
73 k_all: &[Vec<f64>],
74 n_neurons: usize,
75 n_bins: usize,
76 n_latents: usize,
77) -> (Vec<f64>, f64) {
78 let n_state = n_latents * n_bins;
79 let r_inv: Vec<f64> = r_diag.iter().map(|&r| 1.0 / r).collect();
80 let mut ctr_inv_c = vec![0.0f64; n_latents * n_latents];
81 for i in 0..n_latents {
82 for j in 0..n_latents {
83 let mut s = 0.0;
84 for k in 0..n_neurons {
85 s += c[k * n_latents + i] * r_inv[k] * c[k * n_latents + j];
86 }
87 ctr_inv_c[i * n_latents + j] = s;
88 }
89 }
90 let mut m = vec![0.0f64; n_state * n_state];
91 let mut logdet_k = 0.0f64;
92 for j in 0..n_latents {
93 let mut k_reg = k_all[j].clone();
94 for i in 0..n_bins {
95 k_reg[i * n_bins + i] += 1e-6;
96 }
97 let chol = spd_cholesky(&k_reg, n_bins);
98 logdet_k += chol_logdet(&chol, n_bins);
99 let k_inv = chol.inverse();
100 let slj = j * n_bins;
101 for i in 0..n_bins {
102 for jj in 0..n_bins {
103 m[(slj + i) * n_state + (slj + jj)] = k_inv[(i, jj)];
104 }
105 }
106 }
107 for j in 0..n_latents {
108 for k in 0..n_latents {
109 let v = ctr_inv_c[j * n_latents + k];
110 for t in 0..n_bins {
111 m[(j * n_bins + t) * n_state + (k * n_bins + t)] += v;
112 }
113 }
114 }
115 (m, logdet_k)
116}
117
118fn gpfa_e_step(
125 y: &[f64], c: &[f64], d: &[f64], r_diag: &[f64], k_all: &[Vec<f64>], n_neurons: usize,
131 n_bins: usize,
132 n_latents: usize,
133) -> (Vec<f64>, Vec<f64>) {
134 let n_state = n_latents * n_bins;
135 let r_inv: Vec<f64> = r_diag.iter().map(|&r| 1.0 / r).collect();
136 let (m, _) = gpfa_precision(c, r_diag, k_all, n_neurons, n_bins, n_latents);
137
138 let mut rhs = vec![0.0f64; n_state];
140 for t in 0..n_bins {
141 for j in 0..n_latents {
142 let mut s = 0.0;
143 for k in 0..n_neurons {
144 s += c[k * n_latents + j] * r_inv[k] * (y[k * n_bins + t] - d[k]);
145 }
146 rhs[j * n_bins + t] = s;
147 }
148 }
149
150 let chol = spd_cholesky(&m, n_state);
151 let x_solved = chol.solve(&DMatrix::from_row_slice(n_state, 1, &rhs));
152 let x_vec: Vec<f64> = (0..n_state).map(|i| x_solved[(i, 0)]).collect();
153 let sigma = chol.inverse();
154
155 let mut xx_post = vec![0.0f64; n_latents * n_latents];
158 for t in 0..n_bins {
159 for j in 0..n_latents {
160 let xj = x_vec[j * n_bins + t];
161 for k in 0..n_latents {
162 let xk = x_vec[k * n_bins + t];
163 xx_post[j * n_latents + k] += xj * xk + sigma[(j * n_bins + t, k * n_bins + t)];
164 }
165 }
166 }
167
168 (x_vec, xx_post)
169}
170
171fn gpfa_m_step(
173 y: &[f64],
174 x_post: &[f64],
175 xx_post: &[f64],
176 n_neurons: usize,
177 n_bins: usize,
178 n_latents: usize,
179) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
180 let mut d_new = vec![0.0f64; n_neurons];
182 for i in 0..n_neurons {
183 let s: f64 = (0..n_bins).map(|t| y[i * n_bins + t]).sum();
184 d_new[i] = s / n_bins as f64;
185 }
186
187 let mut yx = vec![0.0f64; n_neurons * n_latents];
190 for i in 0..n_neurons {
191 for j in 0..n_latents {
192 let mut s = 0.0;
193 for t in 0..n_bins {
194 s += (y[i * n_bins + t] - d_new[i]) * x_post[j * n_bins + t];
195 }
196 yx[i * n_latents + j] = s;
197 }
198 }
199
200 let mut xx_reg = xx_post.to_vec();
202 for i in 0..n_latents {
203 xx_reg[i * n_latents + i] += 1e-8;
204 }
205 let xx_inv = spd_inverse(&xx_reg, n_latents);
206 let mut c_new = vec![0.0f64; n_neurons * n_latents];
207 for i in 0..n_neurons {
208 for j in 0..n_latents {
209 let mut s = 0.0;
210 for k in 0..n_latents {
211 s += yx[i * n_latents + k] * xx_inv[k * n_latents + j];
212 }
213 c_new[i * n_latents + j] = s;
214 }
215 }
216
217 let mut r_new = vec![0.0f64; n_neurons];
219 for i in 0..n_neurons {
220 let yyt: f64 = (0..n_bins)
221 .map(|t| {
222 let v = y[i * n_bins + t] - d_new[i];
223 v * v
224 })
225 .sum::<f64>()
226 / n_bins as f64;
227 let mut cxy = 0.0;
229 for j in 0..n_latents {
230 for t in 0..n_bins {
231 cxy += c_new[i * n_latents + j]
232 * x_post[j * n_bins + t]
233 * (y[i * n_bins + t] - d_new[i]);
234 }
235 }
236 cxy /= n_bins as f64;
237 r_new[i] = (yyt - cxy).max(1e-6);
238 }
239
240 (c_new, d_new, r_new)
241}
242
243fn gpfa_log_likelihood(
259 y: &[f64],
260 c: &[f64],
261 d: &[f64],
262 r_diag: &[f64],
263 k_all: &[Vec<f64>],
264 n_neurons: usize,
265 n_bins: usize,
266 n_latents: usize,
267) -> f64 {
268 let n_obs = n_neurons * n_bins;
269 let n_state = n_latents * n_bins;
270 let r_inv: Vec<f64> = r_diag.iter().map(|&r| 1.0 / r).collect();
271 let (m, logdet_k) = gpfa_precision(c, r_diag, k_all, n_neurons, n_bins, n_latents);
272
273 let mut rhs = vec![0.0f64; n_state];
274 for t in 0..n_bins {
275 for j in 0..n_latents {
276 let mut s = 0.0;
277 for k in 0..n_neurons {
278 s += c[k * n_latents + j] * r_inv[k] * (y[k * n_bins + t] - d[k]);
279 }
280 rhs[j * n_bins + t] = s;
281 }
282 }
283
284 let chol = spd_cholesky(&m, n_state);
285 let logdet_m = chol_logdet(&chol, n_state);
286 let x_mean = chol.solve(&DMatrix::from_row_slice(n_state, 1, &rhs));
287 let rhs_x_mean: f64 = (0..n_state).map(|i| rhs[i] * x_mean[(i, 0)]).sum();
288
289 let mut y_rinv_y = 0.0f64;
290 for k in 0..n_neurons {
291 for t in 0..n_bins {
292 let v = y[k * n_bins + t] - d[k];
293 y_rinv_y += r_inv[k] * v * v;
294 }
295 }
296 let quad = y_rinv_y - rhs_x_mean;
297 let logdet_r_big = n_bins as f64 * r_diag.iter().map(|&r| r.ln()).sum::<f64>();
298 let logdet_sigma = logdet_m + logdet_k + logdet_r_big;
299 -0.5 * (quad + logdet_sigma + n_obs as f64 * (2.0 * std::f64::consts::PI).ln())
300}
301
302#[allow(clippy::too_many_arguments)]
307pub fn gpfa_em_from_init(
308 y: &[f64],
309 c0: &[f64],
310 d0: &[f64],
311 r0_diag: &[f64],
312 tau: &[f64],
313 n_neurons: usize,
314 n_bins: usize,
315 n_latents: usize,
316 max_iter: usize,
317 tol: f64,
318) -> GpfaEmOutput {
319 let k_all: Vec<Vec<f64>> = (0..n_latents)
320 .map(|j| gp_kernel(n_bins, tau[j], 1.0))
321 .collect();
322 let mut c = c0.to_vec();
323 let mut d = d0.to_vec();
324 let mut r = r0_diag.to_vec();
325 let mut log_liks: Vec<f64> = Vec::new();
326 let mut x_post = vec![0.0f64; n_latents * n_bins];
327
328 for _ in 0..max_iter {
329 let (xp, xx_post) = gpfa_e_step(y, &c, &d, &r, &k_all, n_neurons, n_bins, n_latents);
330 x_post = xp;
331 let (c_new, d_new, r_new) = gpfa_m_step(y, &x_post, &xx_post, n_neurons, n_bins, n_latents);
332 c = c_new;
333 d = d_new;
334 r = r_new;
335 let ll = gpfa_log_likelihood(y, &c, &d, &r, &k_all, n_neurons, n_bins, n_latents);
336 log_liks.push(ll);
337 if log_liks.len() > 1 {
338 let prev = log_liks[log_liks.len() - 2];
339 if (ll - prev).abs() < tol {
340 break;
341 }
342 }
343 }
344
345 (x_post, c, d, r, log_liks)
346}
347
/// Fitted GPFA model together with the latent trajectories of the training data.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct GpfaResult {
    /// Posterior latent means, latent-major: entry `latent * n_bins + bin`.
    pub trajectories: Vec<f64>,
    /// Loading matrix, row-major `n_neurons x n_latents`.
    pub c: Vec<f64>,
    /// Per-neuron offsets (length `n_neurons`).
    pub d: Vec<f64>,
    /// Per-neuron diagonal noise variances (length `n_neurons`).
    pub r: Vec<f64>,
    /// Kernel time constant used for each latent (length `n_latents`).
    pub tau: Vec<f64>,
    /// Log-likelihood recorded after each EM iteration.
    pub log_likelihoods: Vec<f64>,
    /// Number of latents actually fitted; may be smaller than requested.
    pub n_latents: usize,
    /// Number of time bins used for fitting.
    pub n_bins: usize,
    /// Number of spike trains supplied.
    pub n_neurons: usize,
}
368
369pub fn gpfa(
371 trains: &[&[i32]],
372 n_latents: usize,
373 bin_ms: f64,
374 dt: f64,
375 max_iter: usize,
376 tol: f64,
377 seed: u64,
378) -> GpfaResult {
379 let n_neurons = trains.len();
380 if n_neurons == 0 {
381 return GpfaResult {
382 trajectories: vec![],
383 c: vec![],
384 d: vec![],
385 r: vec![],
386 tau: vec![],
387 log_likelihoods: vec![],
388 n_latents: 0,
389 n_bins: 0,
390 n_neurons: 0,
391 };
392 }
393 let bin_steps = (bin_ms / (dt * 1000.0)).round().max(1.0) as usize;
394 let binned: Vec<Vec<f64>> = trains
395 .iter()
396 .map(|t| {
397 basic::bin_spike_train(t, bin_steps)
398 .into_iter()
399 .map(|c| c as f64)
400 .collect()
401 })
402 .collect();
403 let n_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
404 if n_bins == 0 {
405 return GpfaResult {
406 trajectories: vec![],
407 c: vec![],
408 d: vec![],
409 r: vec![],
410 tau: vec![],
411 log_likelihoods: vec![],
412 n_latents: 0,
413 n_bins: 0,
414 n_neurons,
415 };
416 }
417 let mut y = vec![0.0f64; n_neurons * n_bins];
419 for i in 0..n_neurons {
420 for j in 0..n_bins {
421 y[i * n_bins + j] = binned[i][j];
422 }
423 }
424 let nl = n_latents.min(n_neurons).min(n_bins);
425
426 let mut rng = seed;
428 let mut c = vec![0.0f64; n_neurons * nl];
429 for v in &mut c {
430 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
431 *v = ((rng >> 33) as f64 / (1u64 << 31) as f64 - 0.5) * 0.2;
432 }
433 let mut d_vec = vec![0.0f64; n_neurons];
434 for i in 0..n_neurons {
435 d_vec[i] = y[i * n_bins..i * n_bins + n_bins].iter().sum::<f64>() / n_bins as f64;
436 }
437 let mut r_diag = vec![0.0f64; n_neurons];
438 for i in 0..n_neurons {
439 let mean = d_vec[i];
440 let var: f64 = (0..n_bins)
441 .map(|t| (y[i * n_bins + t] - mean).powi(2))
442 .sum::<f64>()
443 / n_bins as f64;
444 r_diag[i] = var + 1e-4;
445 }
446 let tau = vec![bin_ms * 2.0; nl];
447
448 let (x_post, c, d_vec, r_diag, log_liks) = gpfa_em_from_init(
449 &y, &c, &d_vec, &r_diag, &tau, n_neurons, n_bins, nl, max_iter, tol,
450 );
451
452 GpfaResult {
453 trajectories: x_post,
454 c,
455 d: d_vec,
456 r: r_diag,
457 tau,
458 log_likelihoods: log_liks,
459 n_latents: nl,
460 n_bins,
461 n_neurons,
462 }
463}
464
465pub fn gpfa_transform(
467 new_trains: &[&[i32]],
468 c: &[f64],
469 d: &[f64],
470 r_diag: &[f64],
471 tau: &[f64],
472 n_latents: usize,
473 bin_ms: f64,
474 dt: f64,
475) -> Vec<f64> {
476 let n_neurons = new_trains.len();
477 if n_neurons == 0 || c.is_empty() {
478 return vec![];
479 }
480 let bin_steps = (bin_ms / (dt * 1000.0)).round().max(1.0) as usize;
481 let binned: Vec<Vec<f64>> = new_trains
482 .iter()
483 .map(|t| {
484 basic::bin_spike_train(t, bin_steps)
485 .into_iter()
486 .map(|v| v as f64)
487 .collect()
488 })
489 .collect();
490 let n_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
491 if n_bins == 0 {
492 return vec![];
493 }
494 let mut y = vec![0.0f64; n_neurons * n_bins];
495 for i in 0..n_neurons {
496 for j in 0..n_bins {
497 y[i * n_bins + j] = binned[i][j];
498 }
499 }
500 let k_all: Vec<Vec<f64>> = (0..n_latents)
501 .map(|j| gp_kernel(n_bins, tau[j], 1.0))
502 .collect();
503 let (x_post, _) = gpfa_e_step(&y, c, d, r_diag, &k_all, n_neurons, n_bins, n_latents);
504 x_post
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Four deterministic 100-sample trains; train `n` spikes every `3 + 2n` samples.
    fn make_trains() -> Vec<Vec<i32>> {
        (0..4usize)
            .map(|n| {
                let period = 3 + n * 2;
                (0..100usize)
                    .map(|sample| i32::from(sample % period == 0))
                    .collect()
            })
            .collect()
    }

    /// Borrows every train as a slice, the form the public API expects.
    fn as_refs(trains: &[Vec<i32>]) -> Vec<&[i32]> {
        trains.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_gpfa_basic() {
        let trains = make_trains();
        let fitted = gpfa(&as_refs(&trains), 2, 10.0, 0.001, 5, 1e-4, 42);
        assert_eq!(fitted.n_neurons, 4);
        assert_eq!(fitted.n_latents, 2);
        assert!(!fitted.trajectories.is_empty());
        assert!(!fitted.log_likelihoods.is_empty());
    }

    #[test]
    fn test_gpfa_empty() {
        let fitted = gpfa(&[], 2, 10.0, 0.001, 5, 1e-4, 42);
        assert_eq!(fitted.n_neurons, 0);
        assert!(fitted.trajectories.is_empty());
    }

    #[test]
    fn test_gpfa_single_neuron() {
        // Twenty samples alternating 1, 0, 1, 0, ...
        let train: Vec<i32> = (0..20).map(|sample| i32::from(sample % 2 == 0)).collect();
        let refs = vec![train.as_slice()];
        let fitted = gpfa(&refs, 1, 5.0, 0.001, 3, 1e-4, 42);
        assert_eq!(fitted.n_neurons, 1);
        assert_eq!(fitted.n_latents, 1);
    }

    #[test]
    fn test_gpfa_convergence() {
        let trains = make_trains();
        let fitted = gpfa(&as_refs(&trains), 2, 10.0, 0.001, 20, 1e-4, 42);
        // Only meaningful with at least three recorded iterations.
        if let [_, second, .., last] = fitted.log_likelihoods.as_slice() {
            assert!(
                *last >= *second - 1.0,
                "LL should generally increase: {second} -> {last}"
            );
        }
    }

    #[test]
    fn test_gpfa_transform() {
        let trains = make_trains();
        let fitted = gpfa(&as_refs(&trains), 2, 10.0, 0.001, 5, 1e-4, 42);

        let new_trains = make_trains();
        let projected = gpfa_transform(
            &as_refs(&new_trains),
            &fitted.c,
            &fitted.d,
            &fitted.r,
            &fitted.tau,
            fitted.n_latents,
            10.0,
            0.001,
        );
        assert!(!projected.is_empty());
        assert_eq!(projected.len(), fitted.n_latents * fitted.n_bins);
    }

    #[test]
    fn test_gpfa_transform_empty() {
        let projected = gpfa_transform(&[], &[], &[], &[], &[], 0, 10.0, 0.001);
        assert!(projected.is_empty());
    }

    #[test]
    fn test_gp_kernel_shape() {
        let n = 10;
        let kernel = gp_kernel(n, 5.0, 1.0);
        assert_eq!(kernel.len(), 100);
        for i in 0..n {
            // Unit amplitude puts ones on the diagonal.
            assert!((kernel[i * n + i] - 1.0).abs() < 1e-10);
            // The kernel is symmetric.
            for j in 0..n {
                assert!((kernel[i * n + j] - kernel[j * n + i]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_gp_kernel_decay() {
        let kernel = gp_kernel(20, 3.0, 1.0);
        // Correlation at lag 1 exceeds correlation at lag 10.
        assert!(kernel[1] > kernel[10]);
    }
}