1use super::basic;
12
/// Builds a dense, row-major `n x n` squared-exponential covariance matrix.
///
/// Entry `(i, j)` is `sigma^2 * exp(-0.5 * (i - j)^2 / (tau^2 + 1e-12))`, where the
/// lag `i - j` is measured in index units. The `1e-12` term keeps the division
/// finite when `tau` is zero.
fn gp_kernel(n: usize, tau: f64, sigma: f64) -> Vec<f64> {
    let length_sq = tau * tau + 1e-12;
    let amplitude = sigma * sigma;
    // Walk the flattened matrix once; `cell / n` is the row and `cell % n` the column.
    // The closure never runs when `n == 0`, so the division is safe.
    (0..n * n)
        .map(|cell| {
            let lag = (cell / n) as f64 - (cell % n) as f64;
            amplitude * (-0.5 * lag * lag / length_sq).exp()
        })
        .collect()
}
28
/// Inverts the row-major `n x n` matrix `a` by Gauss-Jordan elimination with
/// partial pivoting on the augmented matrix `[A | I]`.
///
/// A column whose largest available pivot magnitude is below `1e-30` is skipped
/// rather than reported, so for a (numerically) singular input the returned values
/// are whatever the right half of the work matrix holds and are not a true inverse.
///
/// Panics if `a` holds fewer than `n * n` elements.
fn mat_inv(a: &[f64], n: usize) -> Vec<f64> {
    let width = 2 * n;
    let mut work = vec![0.0f64; n * width];
    for r in 0..n {
        // Left half: row r of A. Right half: row r of the identity.
        work[r * width..r * width + n].copy_from_slice(&a[r * n..(r + 1) * n]);
        work[r * width + n + r] = 1.0;
    }

    // Scratch copy of the normalised pivot row, reused for every column.
    let mut pivot_line = vec![0.0f64; width];

    for col in 0..n {
        // Partial pivoting: first row at or below the diagonal with the largest magnitude.
        let (pivot_row, pivot_mag) = (col + 1..n).fold(
            (col, work[col * width + col].abs()),
            |(best_row, best_mag), r| {
                let mag = work[r * width + col].abs();
                if mag > best_mag {
                    (r, mag)
                } else {
                    (best_row, best_mag)
                }
            },
        );
        if pivot_mag < 1e-30 {
            continue;
        }

        if pivot_row != col {
            // `pivot_row > col`, so the two rows sit on opposite sides of the split.
            let (upper, lower) = work.split_at_mut(pivot_row * width);
            upper[col * width..(col + 1) * width].swap_with_slice(&mut lower[..width]);
        }

        let pivot = work[col * width + col];
        for value in &mut work[col * width..(col + 1) * width] {
            *value /= pivot;
        }
        pivot_line.copy_from_slice(&work[col * width..(col + 1) * width]);

        // Clear this column from every other row.
        for r in (0..n).filter(|&r| r != col) {
            let target = &mut work[r * width..(r + 1) * width];
            let factor = target[col];
            for (cell, &p) in target.iter_mut().zip(&pivot_line) {
                *cell -= factor * p;
            }
        }
    }

    // The right half of the reduced matrix is the result.
    let mut inverse = Vec::with_capacity(n * n);
    for r in 0..n {
        inverse.extend_from_slice(&work[r * width + n..(r + 1) * width]);
    }
    inverse
}
78
/// Solves `A X = B` for `X` by Gauss-Jordan elimination with partial pivoting.
///
/// `a` is row-major `n x n`, `b` is row-major `n x m`, and the returned `X` is
/// row-major `n x m`. A column whose largest available pivot magnitude is below
/// `1e-30` is skipped, so a (numerically) singular `a` yields unreduced values
/// instead of an error.
///
/// Panics if `a` holds fewer than `n * n` or `b` fewer than `n * m` elements.
fn mat_solve(a: &[f64], b: &[f64], n: usize, m: usize) -> Vec<f64> {
    let width = n + m;
    let mut tableau = vec![0.0f64; n * width];
    for r in 0..n {
        // Each tableau row is [row r of A | row r of B].
        let row = &mut tableau[r * width..(r + 1) * width];
        row[..n].copy_from_slice(&a[r * n..(r + 1) * n]);
        row[n..].copy_from_slice(&b[r * m..(r + 1) * m]);
    }

    // Scratch copy of the normalised pivot row, reused for every column.
    let mut pivot_line = vec![0.0f64; width];

    for col in 0..n {
        // Partial pivoting: first row at or below the diagonal with the largest magnitude.
        let mut pivot_row = col;
        let mut pivot_mag = tableau[col * width + col].abs();
        for r in col + 1..n {
            let mag = tableau[r * width + col].abs();
            if mag > pivot_mag {
                pivot_mag = mag;
                pivot_row = r;
            }
        }
        if pivot_mag < 1e-30 {
            continue;
        }

        if pivot_row != col {
            // `pivot_row > col`, so the two rows sit on opposite sides of the split.
            let (upper, lower) = tableau.split_at_mut(pivot_row * width);
            upper[col * width..(col + 1) * width].swap_with_slice(&mut lower[..width]);
        }

        let pivot = tableau[col * width + col];
        tableau[col * width..(col + 1) * width]
            .iter_mut()
            .for_each(|value| *value /= pivot);
        pivot_line.copy_from_slice(&tableau[col * width..(col + 1) * width]);

        // Clear this column from every other row.
        for r in 0..n {
            if r != col {
                let factor = tableau[r * width + col];
                let target = &mut tableau[r * width..(r + 1) * width];
                for (cell, &p) in target.iter_mut().zip(&pivot_line) {
                    *cell -= factor * p;
                }
            }
        }
    }

    // The trailing `m` columns of each row now hold the solution.
    let mut solution = Vec::with_capacity(n * m);
    for r in 0..n {
        solution.extend_from_slice(&tableau[r * width + n..(r + 1) * width]);
    }
    solution
}
131
132fn gpfa_e_step(
134 y: &[f64], c: &[f64], d: &[f64], r_diag: &[f64], k_all: &[Vec<f64>], n_neurons: usize,
140 n_bins: usize,
141 n_latents: usize,
142) -> (Vec<f64>, Vec<f64>) {
143 let kt = n_latents * n_bins;
145
146 let r_inv: Vec<f64> = r_diag.iter().map(|&r| 1.0 / (r + 1e-10)).collect();
148
149 let mut ct_rinv_c = vec![0.0f64; n_latents * n_latents];
151 for i in 0..n_latents {
152 for j in 0..n_latents {
153 let mut s = 0.0;
154 for k in 0..n_neurons {
155 s += c[k * n_latents + i] * r_inv[k] * c[k * n_latents + j];
156 }
157 ct_rinv_c[i * n_latents + j] = s;
158 }
159 }
160
161 let mut ct_rinv = vec![0.0f64; n_latents * n_neurons];
163 for i in 0..n_latents {
164 for k in 0..n_neurons {
165 ct_rinv[i * n_neurons + k] = c[k * n_latents + i] * r_inv[k];
166 }
167 }
168
169 let mut prec = vec![0.0f64; kt * kt];
171 for j in 0..n_latents {
172 let slj = j * n_bins;
173 let mut k_reg = k_all[j].clone();
175 for i in 0..n_bins {
176 k_reg[i * n_bins + i] += 1e-6;
177 }
178 let k_eye = vec![0.0f64; n_bins * n_bins]
179 .iter()
180 .enumerate()
181 .map(|(idx, _)| {
182 if idx / n_bins == idx % n_bins {
183 1.0
184 } else {
185 0.0
186 }
187 })
188 .collect::<Vec<f64>>();
189 let k_inv = mat_solve(&k_reg, &k_eye, n_bins, n_bins);
190
191 for i in 0..n_bins {
192 for jj in 0..n_bins {
193 prec[(slj + i) * kt + (slj + jj)] = k_inv[i * n_bins + jj]
194 + ct_rinv_c[j * n_latents + j] * if i == jj { 1.0 } else { 0.0 };
195 }
196 }
197 for k in 0..n_latents {
198 if k != j {
199 let slk = k * n_bins;
200 for i in 0..n_bins {
201 prec[(slj + i) * kt + (slk + i)] = ct_rinv_c[j * n_latents + k];
202 }
203 }
204 }
205 }
206
207 let mut rhs = vec![0.0f64; kt];
209 for t in 0..n_bins {
211 for j in 0..n_latents {
213 let mut s = 0.0;
214 for k in 0..n_neurons {
215 s += ct_rinv[j * n_neurons + k] * (y[k * n_bins + t] - d[k]);
216 }
217 rhs[j * n_bins + t] = s;
218 }
219 }
220
221 for i in 0..kt {
223 prec[i * kt + i] += 1e-8;
224 }
225
226 let rhs_col: Vec<f64> = rhs.clone();
228 let x_vec = mat_solve(&prec, &rhs_col, kt, 1);
229
230 let eye_kt: Vec<f64> = (0..kt * kt)
232 .map(|idx| if idx / kt == idx % kt { 1.0 } else { 0.0 })
233 .collect();
234 let sigma_post = mat_solve(&prec, &eye_kt, kt, kt);
235
236 let mut xx_post = vec![0.0f64; n_latents * n_latents];
238 for t in 0..n_bins {
239 for j in 0..n_latents {
240 let xj = x_vec[j * n_bins + t];
241 for k in 0..n_latents {
242 let xk = x_vec[k * n_bins + t];
243 xx_post[j * n_latents + k] +=
244 xj * xk + sigma_post[(j * n_bins + t) * kt + (k * n_bins + t)];
245 }
246 }
247 }
248
249 (x_vec, xx_post)
250}
251
252fn gpfa_m_step(
254 y: &[f64],
255 x_post: &[f64],
256 xx_post: &[f64],
257 n_neurons: usize,
258 n_bins: usize,
259 n_latents: usize,
260) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
261 let mut d_new = vec![0.0f64; n_neurons];
263 for i in 0..n_neurons {
264 let s: f64 = (0..n_bins).map(|t| y[i * n_bins + t]).sum();
265 d_new[i] = s / n_bins as f64;
266 }
267
268 let mut yx = vec![0.0f64; n_neurons * n_latents];
271 for i in 0..n_neurons {
272 for j in 0..n_latents {
273 let mut s = 0.0;
274 for t in 0..n_bins {
275 s += (y[i * n_bins + t] - d_new[i]) * x_post[j * n_bins + t];
276 }
277 yx[i * n_latents + j] = s;
278 }
279 }
280
281 let mut xx_reg = xx_post.to_vec();
283 for i in 0..n_latents {
284 xx_reg[i * n_latents + i] += 1e-8;
285 }
286 let xx_inv = mat_inv(&xx_reg, n_latents);
287 let mut c_new = vec![0.0f64; n_neurons * n_latents];
288 for i in 0..n_neurons {
289 for j in 0..n_latents {
290 let mut s = 0.0;
291 for k in 0..n_latents {
292 s += yx[i * n_latents + k] * xx_inv[k * n_latents + j];
293 }
294 c_new[i * n_latents + j] = s;
295 }
296 }
297
298 let mut r_new = vec![0.0f64; n_neurons];
300 for i in 0..n_neurons {
301 let yyt: f64 = (0..n_bins)
302 .map(|t| {
303 let v = y[i * n_bins + t] - d_new[i];
304 v * v
305 })
306 .sum::<f64>()
307 / n_bins as f64;
308 let mut cxy = 0.0;
310 for j in 0..n_latents {
311 for t in 0..n_bins {
312 cxy += c_new[i * n_latents + j]
313 * x_post[j * n_bins + t]
314 * (y[i * n_bins + t] - d_new[i]);
315 }
316 }
317 cxy /= n_bins as f64;
318 r_new[i] = (yyt - cxy).max(1e-6);
319 }
320
321 (c_new, d_new, r_new)
322}
323
/// Fitted model and diagnostics returned by [`gpfa`].
///
/// `Default` yields the all-empty value with zero dimensions.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct GpfaResult {
    /// Posterior latent means from the last E-step, indexed
    /// `trajectories[latent * n_bins + t]`.
    pub trajectories: Vec<f64>,
    /// Loading matrix, indexed `c[neuron * n_latents + latent]`.
    pub c: Vec<f64>,
    /// Per-neuron offset.
    pub d: Vec<f64>,
    /// Per-neuron diagonal noise variance.
    pub r: Vec<f64>,
    /// Per-latent kernel timescale passed to the covariance kernel.
    pub tau: Vec<f64>,
    /// One log-likelihood value per completed EM iteration.
    pub log_likelihoods: Vec<f64>,
    /// Number of latents actually fitted (the request clamped to the number of
    /// neurons and bins).
    pub n_latents: usize,
    /// Number of time bins used for every neuron.
    pub n_bins: usize,
    /// Number of input spike trains.
    pub n_neurons: usize,
}
344
345pub fn gpfa(
347 trains: &[&[i32]],
348 n_latents: usize,
349 bin_ms: f64,
350 dt: f64,
351 max_iter: usize,
352 tol: f64,
353 seed: u64,
354) -> GpfaResult {
355 let n_neurons = trains.len();
356 if n_neurons == 0 {
357 return GpfaResult {
358 trajectories: vec![],
359 c: vec![],
360 d: vec![],
361 r: vec![],
362 tau: vec![],
363 log_likelihoods: vec![],
364 n_latents: 0,
365 n_bins: 0,
366 n_neurons: 0,
367 };
368 }
369 let bin_steps = (bin_ms / (dt * 1000.0)).round().max(1.0) as usize;
370 let binned: Vec<Vec<f64>> = trains
371 .iter()
372 .map(|t| {
373 basic::bin_spike_train(t, bin_steps)
374 .into_iter()
375 .map(|c| c as f64)
376 .collect()
377 })
378 .collect();
379 let n_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
380 if n_bins == 0 {
381 return GpfaResult {
382 trajectories: vec![],
383 c: vec![],
384 d: vec![],
385 r: vec![],
386 tau: vec![],
387 log_likelihoods: vec![],
388 n_latents: 0,
389 n_bins: 0,
390 n_neurons,
391 };
392 }
393 let mut y = vec![0.0f64; n_neurons * n_bins];
395 for i in 0..n_neurons {
396 for j in 0..n_bins {
397 y[i * n_bins + j] = binned[i][j];
398 }
399 }
400 let nl = n_latents.min(n_neurons).min(n_bins);
401
402 let mut rng = seed;
404 let mut c = vec![0.0f64; n_neurons * nl];
405 for v in &mut c {
406 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
407 *v = ((rng >> 33) as f64 / (1u64 << 31) as f64 - 0.5) * 0.2;
408 }
409 let mut d_vec = vec![0.0f64; n_neurons];
410 for i in 0..n_neurons {
411 d_vec[i] = y[i * n_bins..i * n_bins + n_bins].iter().sum::<f64>() / n_bins as f64;
412 }
413 let mut r_diag = vec![0.0f64; n_neurons];
414 for i in 0..n_neurons {
415 let mean = d_vec[i];
416 let var: f64 = (0..n_bins)
417 .map(|t| (y[i * n_bins + t] - mean).powi(2))
418 .sum::<f64>()
419 / n_bins as f64;
420 r_diag[i] = var + 1e-4;
421 }
422 let tau = vec![bin_ms * 2.0; nl];
423
424 let mut log_liks = Vec::new();
425 let mut x_post = vec![0.0f64; nl * n_bins];
426
427 for _ in 0..max_iter {
428 let k_all: Vec<Vec<f64>> = (0..nl).map(|j| gp_kernel(n_bins, tau[j], 1.0)).collect();
429
430 let (xp, xx_post) = gpfa_e_step(&y, &c, &d_vec, &r_diag, &k_all, n_neurons, n_bins, nl);
431 x_post = xp;
432
433 let (c_new, d_new, r_new) = gpfa_m_step(&y, &x_post, &xx_post, n_neurons, n_bins, nl);
434 c = c_new;
435 d_vec = d_new;
436 r_diag = r_new;
437
438 let mut ll = 0.0f64;
440 for i in 0..n_neurons {
441 for t in 0..n_bins {
442 let mut pred = d_vec[i];
443 for j in 0..nl {
444 pred += c[i * nl + j] * x_post[j * n_bins + t];
445 }
446 let resid = y[i * n_bins + t] - pred;
447 ll -= 0.5 * resid * resid / (r_diag[i] + 1e-10);
448 }
449 }
450 ll -= 0.5 * n_bins as f64 * r_diag.iter().map(|&r| (r + 1e-10).ln()).sum::<f64>();
451 log_liks.push(ll);
452
453 if log_liks.len() > 1 {
454 let prev = log_liks[log_liks.len() - 2];
455 if (ll - prev).abs() < tol {
456 break;
457 }
458 }
459 }
460
461 GpfaResult {
462 trajectories: x_post,
463 c,
464 d: d_vec,
465 r: r_diag,
466 tau,
467 log_likelihoods: log_liks,
468 n_latents: nl,
469 n_bins,
470 n_neurons,
471 }
472}
473
474pub fn gpfa_transform(
476 new_trains: &[&[i32]],
477 c: &[f64],
478 d: &[f64],
479 r_diag: &[f64],
480 tau: &[f64],
481 n_latents: usize,
482 bin_ms: f64,
483 dt: f64,
484) -> Vec<f64> {
485 let n_neurons = new_trains.len();
486 if n_neurons == 0 || c.is_empty() {
487 return vec![];
488 }
489 let bin_steps = (bin_ms / (dt * 1000.0)).round().max(1.0) as usize;
490 let binned: Vec<Vec<f64>> = new_trains
491 .iter()
492 .map(|t| {
493 basic::bin_spike_train(t, bin_steps)
494 .into_iter()
495 .map(|v| v as f64)
496 .collect()
497 })
498 .collect();
499 let n_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
500 if n_bins == 0 {
501 return vec![];
502 }
503 let mut y = vec![0.0f64; n_neurons * n_bins];
504 for i in 0..n_neurons {
505 for j in 0..n_bins {
506 y[i * n_bins + j] = binned[i][j];
507 }
508 }
509 let k_all: Vec<Vec<f64>> = (0..n_latents)
510 .map(|j| gp_kernel(n_bins, tau[j], 1.0))
511 .collect();
512 let (x_post, _) = gpfa_e_step(&y, c, d, r_diag, &k_all, n_neurons, n_bins, n_latents);
513 x_post
514}
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Four deterministic 100-sample trains; train `n` spikes every `3 + 2n` samples.
    fn make_trains() -> Vec<Vec<i32>> {
        (0..4usize)
            .map(|n| {
                let stride = 3 + n * 2;
                (0..100usize).map(|i| i32::from(i % stride == 0)).collect()
            })
            .collect()
    }

    /// Borrows every train as a slice, the shape the public API expects.
    fn as_refs(trains: &[Vec<i32>]) -> Vec<&[i32]> {
        trains.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_gpfa_basic() {
        let trains = make_trains();
        let result = gpfa(&as_refs(&trains), 2, 10.0, 0.001, 5, 1e-4, 42);
        assert_eq!(result.n_neurons, 4);
        assert_eq!(result.n_latents, 2);
        assert!(!result.trajectories.is_empty());
        assert!(!result.log_likelihoods.is_empty());
    }

    #[test]
    fn test_gpfa_empty() {
        let result = gpfa(&[], 2, 10.0, 0.001, 5, 1e-4, 42);
        assert_eq!(result.n_neurons, 0);
        assert!(result.trajectories.is_empty());
    }

    #[test]
    fn test_gpfa_single_neuron() {
        // Alternating spike / no-spike, 20 samples, starting with a spike.
        let train: Vec<i32> = (0..20).map(|i| i32::from(i % 2 == 0)).collect();
        let refs = vec![train.as_slice()];
        let result = gpfa(&refs, 1, 5.0, 0.001, 3, 1e-4, 42);
        assert_eq!(result.n_neurons, 1);
        assert_eq!(result.n_latents, 1);
    }

    #[test]
    fn test_gpfa_convergence() {
        let trains = make_trains();
        let result = gpfa(&as_refs(&trains), 2, 10.0, 0.001, 20, 1e-4, 42);
        let lls = &result.log_likelihoods;
        // Only meaningful when at least three iterations ran.
        if lls.len() > 2 {
            let last = lls[lls.len() - 1];
            let second = lls[1];
            assert!(
                last >= second - 1.0,
                "LL should generally increase: {second} -> {last}"
            );
        }
    }

    #[test]
    fn test_gpfa_transform() {
        let trains = make_trains();
        let result = gpfa(&as_refs(&trains), 2, 10.0, 0.001, 5, 1e-4, 42);

        let new_trains = make_trains();
        let projected = gpfa_transform(
            &as_refs(&new_trains),
            &result.c,
            &result.d,
            &result.r,
            &result.tau,
            result.n_latents,
            10.0,
            0.001,
        );
        assert!(!projected.is_empty());
        assert_eq!(projected.len(), result.n_latents * result.n_bins);
    }

    #[test]
    fn test_gpfa_transform_empty() {
        let proj = gpfa_transform(&[], &[], &[], &[], &[], 0, 10.0, 0.001);
        assert!(proj.is_empty());
    }

    #[test]
    fn test_gp_kernel_shape() {
        let side = 10;
        let k = gp_kernel(side, 5.0, 1.0);
        assert_eq!(k.len(), 100);
        // Unit variance on the diagonal.
        for i in 0..side {
            assert!((k[i * side + i] - 1.0).abs() < 1e-10);
        }
        // Symmetric.
        for i in 0..side {
            for j in 0..side {
                assert!((k[i * side + j] - k[j * side + i]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_gp_kernel_decay() {
        let k = gp_kernel(20, 3.0, 1.0);
        // Covariance with lag 1 exceeds covariance with lag 10.
        assert!(k[1] > k[10]);
    }
}