1use super::basic;
10use nalgebra::{DMatrix, SymmetricEigen};
11
12fn symmetric_eigen(a: &[f64], n: usize) -> (Vec<f64>, Vec<f64>) {
21 let se = SymmetricEigen::new(DMatrix::<f64>::from_row_slice(n, n, a));
22 let mut idx: Vec<usize> = (0..n).collect();
23 idx.sort_by(|&i, &j| se.eigenvalues[j].partial_cmp(&se.eigenvalues[i]).unwrap());
24
25 let vals: Vec<f64> = idx.iter().map(|&i| se.eigenvalues[i]).collect();
26 let mut vecs = vec![0.0f64; n * n];
27 for (new_col, &old_col) in idx.iter().enumerate() {
28 let mut pivot = 0usize;
30 let mut max_abs = 0.0f64;
31 for r in 0..n {
32 let v = se.eigenvectors[(r, old_col)].abs();
33 if v > max_abs {
34 max_abs = v;
35 pivot = r;
36 }
37 }
38 let sign = if se.eigenvectors[(pivot, old_col)] < 0.0 {
39 -1.0
40 } else {
41 1.0
42 };
43 for r in 0..n {
44 vecs[r * n + new_col] = sign * se.eigenvectors[(r, old_col)];
45 }
46 }
47 (vals, vecs)
48}
49
50fn spd_solve(a: &[f64], n: usize, b: &[f64], k: usize) -> Vec<f64> {
53 let chol = DMatrix::<f64>::from_row_slice(n, n, a)
54 .cholesky()
55 .expect("factor-analysis system must be symmetric positive-definite");
56 let solved = chol.solve(&DMatrix::<f64>::from_row_slice(n, k, b));
57 let mut out = vec![0.0f64; n * k];
58 for i in 0..n {
59 for j in 0..k {
60 out[i * k + j] = solved[(i, j)];
61 }
62 }
63 out
64}
65
66fn spd_inverse(a: &[f64], n: usize) -> Vec<f64> {
68 let inv = DMatrix::<f64>::from_row_slice(n, n, a)
69 .cholesky()
70 .expect("factor-analysis system must be symmetric positive-definite")
71 .inverse();
72 let mut out = vec![0.0f64; n * n];
73 for i in 0..n {
74 for j in 0..n {
75 out[i * n + j] = inv[(i, j)];
76 }
77 }
78 out
79}
80
81pub fn pca_from_centered(
86 mat: &[f64],
87 d: usize,
88 t: usize,
89 n_components: usize,
90) -> (Vec<f64>, Vec<f64>) {
91 let denom = (t - 1).max(1) as f64;
92 let mut cov = vec![0.0f64; d * d];
93 for i in 0..d {
94 for j in i..d {
95 let mut s = 0.0;
96 for k in 0..t {
97 s += mat[i * t + k] * mat[j * t + k];
98 }
99 s /= denom;
100 cov[i * d + j] = s;
101 cov[j * d + i] = s;
102 }
103 }
104 let (eigvals, eigvecs) = symmetric_eigen(&cov, d);
105 let nc = n_components.min(d);
106 let total: f64 = eigvals.iter().sum();
107 let explained: Vec<f64> = eigvals[..nc]
108 .iter()
109 .map(|&v| if total > 0.0 { v / total } else { v })
110 .collect();
111 let mut projected = vec![0.0f64; nc * t];
112 for c in 0..nc {
113 for tt in 0..t {
114 let mut s = 0.0;
115 for i in 0..d {
116 s += eigvecs[i * d + c] * mat[i * t + tt];
117 }
118 projected[c * t + tt] = s;
119 }
120 }
121 (projected, explained)
122}
123
124pub fn demixed_from_centered(
127 mean_mat: &[f64],
128 n_cond: usize,
129 t: usize,
130 n_components: usize,
131) -> (Vec<f64>, Vec<f64>) {
132 let denom = n_cond as f64;
133 let mut cov = vec![0.0f64; t * t];
134 for i in 0..t {
135 for j in i..t {
136 let mut s = 0.0;
137 for c in 0..n_cond {
138 s += mean_mat[c * t + i] * mean_mat[c * t + j];
139 }
140 s /= denom;
141 cov[i * t + j] = s;
142 cov[j * t + i] = s;
143 }
144 }
145 let (eigvals, eigvecs) = symmetric_eigen(&cov, t);
146 let nc = n_components.min(t);
147 let total: f64 = eigvals.iter().sum();
148 let explained: Vec<f64> = eigvals[..nc]
149 .iter()
150 .map(|&v| if total > 0.0 { v / total } else { v })
151 .collect();
152 let mut projected = vec![0.0f64; n_cond * nc];
153 for c in 0..n_cond {
154 for k in 0..nc {
155 let mut s = 0.0;
156 for j in 0..t {
157 s += mean_mat[c * t + j] * eigvecs[j * t + k];
158 }
159 projected[c * nc + k] = s;
160 }
161 }
162 (projected, explained)
163}
164
165pub fn fa_from_centered(
172 mat: &[f64],
173 d: usize,
174 t: usize,
175 n_factors: usize,
176 n_iter: usize,
177) -> (Vec<f64>, Vec<f64>) {
178 let tf = t as f64;
179 let mut cov = vec![0.0f64; d * d];
180 for i in 0..d {
181 for j in i..d {
182 let mut s = 0.0;
183 for k in 0..t {
184 s += mat[i * t + k] * mat[j * t + k];
185 }
186 s /= tf;
187 cov[i * d + j] = s;
188 cov[j * d + i] = s;
189 }
190 }
191 let nf = n_factors.min(d);
192 let (eigvals, eigvecs) = symmetric_eigen(&cov, d);
193 let mut loadings = vec![0.0f64; d * nf];
195 for c in 0..nf {
196 let scale = eigvals[c].max(0.0).sqrt();
197 for i in 0..d {
198 loadings[i * nf + c] = eigvecs[i * d + c] * scale;
199 }
200 }
201 let mut psi: Vec<f64> = (0..d).map(|i| cov[i * d + i]).collect();
202
203 for _ in 0..n_iter {
204 let psi_inv: Vec<f64> = psi.iter().map(|&p| 1.0 / (p + 1e-10)).collect();
205
206 let mut m = vec![0.0f64; nf * nf];
208 for a in 0..nf {
209 for b in 0..nf {
210 let mut s = 0.0;
211 for i in 0..d {
212 s += loadings[i * nf + a] * psi_inv[i] * loadings[i * nf + b];
213 }
214 m[a * nf + b] = s + if a == b { 1.0 } else { 0.0 };
215 }
216 }
217 let m_inv = spd_inverse(&m, nf);
218
219 let mut beta = vec![0.0f64; nf * d];
221 for a in 0..nf {
222 for i in 0..d {
223 let mut s = 0.0;
224 for kk in 0..nf {
225 s += m_inv[a * nf + kk] * loadings[i * nf + kk] * psi_inv[i];
226 }
227 beta[a * d + i] = s;
228 }
229 }
230
231 let mut ez = vec![0.0f64; nf * t];
233 for a in 0..nf {
234 for tt in 0..t {
235 let mut s = 0.0;
236 for i in 0..d {
237 s += beta[a * d + i] * mat[i * t + tt];
238 }
239 ez[a * t + tt] = s;
240 }
241 }
242
243 let mut ezzt = vec![0.0f64; nf * nf];
245 for a in 0..nf {
246 for b in 0..nf {
247 let mut s = 0.0;
248 for tt in 0..t {
249 s += ez[a * t + tt] * ez[b * t + tt];
250 }
251 ezzt[a * nf + b] = nf as f64 * m_inv[a * nf + b] + s / tf;
252 }
253 }
254
255 let mut mat_ez_t = vec![0.0f64; d * nf];
257 for i in 0..d {
258 for a in 0..nf {
259 let mut s = 0.0;
260 for tt in 0..t {
261 s += mat[i * t + tt] * ez[a * t + tt];
262 }
263 mat_ez_t[i * nf + a] = s / tf;
264 }
265 }
266
267 let mut rhs = vec![0.0f64; nf * d];
269 for a in 0..nf {
270 for i in 0..d {
271 rhs[a * d + i] = mat_ez_t[i * nf + a];
272 }
273 }
274 let solved = spd_solve(&ezzt, nf, &rhs, d); for i in 0..d {
276 for a in 0..nf {
277 loadings[i * nf + a] = solved[a * d + i];
278 }
279 }
280
281 for i in 0..d {
283 let mut s = 0.0;
284 for tt in 0..t {
285 let mut l_ez = 0.0;
286 for a in 0..nf {
287 l_ez += loadings[i * nf + a] * ez[a * t + tt];
288 }
289 s += l_ez * mat[i * t + tt];
290 }
291 psi[i] = (cov[i * d + i] - s / tf).max(1e-6);
292 }
293 }
294
295 (loadings, psi)
296}
297
298fn binned_centred(trains: &[&[i32]], bin_size: usize) -> (Vec<f64>, usize, usize) {
301 let binned: Vec<Vec<f64>> = trains
302 .iter()
303 .map(|t| {
304 basic::bin_spike_train(t, bin_size)
305 .into_iter()
306 .map(|c| c as f64)
307 .collect()
308 })
309 .collect();
310 let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
311 let d = trains.len();
312 if min_bins == 0 {
313 return (vec![], d, 0);
314 }
315 let mut mat = vec![0.0f64; d * min_bins];
316 for i in 0..d {
317 let mean: f64 = binned[i][..min_bins].iter().sum::<f64>() / min_bins as f64;
318 for j in 0..min_bins {
319 mat[i * min_bins + j] = binned[i][j] - mean;
320 }
321 }
322 (mat, d, min_bins)
323}
324
325pub fn spike_train_pca(
327 trains: &[&[i32]],
328 n_components: usize,
329 bin_size: usize,
330) -> (Vec<f64>, Vec<f64>) {
331 if trains.is_empty() {
332 return (vec![], vec![]);
333 }
334 let (mat, d, min_bins) = binned_centred(trains, bin_size);
335 if min_bins == 0 {
336 return (vec![], vec![]);
337 }
338 if d < 2 {
339 return (mat[..min_bins].to_vec(), vec![1.0]);
340 }
341 pca_from_centered(&mat, d, min_bins, n_components)
342}
343
344pub fn demixed_pca(
346 conditions: &[Vec<&[i32]>],
347 n_components: usize,
348 bin_size: usize,
349) -> (Vec<f64>, Vec<f64>) {
350 if conditions.len() < 2 {
351 return (vec![], vec![]);
352 }
353 let mut all_means: Vec<Vec<f64>> = Vec::new();
354 for trains in conditions {
355 let binned: Vec<Vec<f64>> = trains
356 .iter()
357 .map(|t| {
358 basic::bin_spike_train(t, bin_size)
359 .into_iter()
360 .map(|c| c as f64)
361 .collect()
362 })
363 .collect();
364 let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
365 if min_bins == 0 {
366 continue;
367 }
368 let n = binned.len();
369 let mut mean = vec![0.0f64; min_bins];
370 for b in &binned {
371 for (j, m) in mean.iter_mut().enumerate() {
372 *m += b[j];
373 }
374 }
375 for v in &mut mean {
376 *v /= n as f64;
377 }
378 all_means.push(mean);
379 }
380 if all_means.len() < 2 {
381 return (vec![], vec![]);
382 }
383 let min_bins = all_means.iter().map(|m| m.len()).min().unwrap();
384 let n_cond = all_means.len();
385 let mut grand = vec![0.0f64; min_bins];
386 for m in &all_means {
387 for (j, g) in grand.iter_mut().enumerate() {
388 *g += m[j];
389 }
390 }
391 for v in &mut grand {
392 *v /= n_cond as f64;
393 }
394 let mut mean_mat = vec![0.0f64; n_cond * min_bins];
395 for (i, m) in all_means.iter().enumerate() {
396 for j in 0..min_bins {
397 mean_mat[i * min_bins + j] = m[j] - grand[j];
398 }
399 }
400 demixed_from_centered(&mean_mat, n_cond, min_bins, n_components)
401}
402
403pub fn factor_analysis(
405 trains: &[&[i32]],
406 n_factors: usize,
407 bin_size: usize,
408 n_iter: usize,
409) -> (Vec<f64>, Vec<f64>) {
410 let d = trains.len();
411 if d == 0 {
412 return (vec![], vec![]);
413 }
414 let (mat, _d, t) = binned_centred(trains, bin_size);
415 if t == 0 {
416 return (vec![0.0; d * n_factors.min(d)], vec![1.0; d]);
417 }
418 fa_from_centered(&mat, d, t, n_factors, n_iter)
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// A 200-sample binary train with a spike at every multiple of `period`.
    fn periodic_train(period: usize) -> Vec<i32> {
        (0..200).map(|i| i32::from(i % period == 0)).collect()
    }

    /// Five periodic trains with periods 5, 8, 11, 14 and 17 samples.
    fn make_trains() -> Vec<Vec<i32>> {
        (0..5).map(|n| periodic_train(5 + n * 3)).collect()
    }

    /// Borrows every owned train as a slice.
    fn as_refs(trains: &[Vec<i32>]) -> Vec<&[i32]> {
        trains.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_spike_train_pca_basic() {
        let trains = make_trains();
        let refs = as_refs(&trains);
        let (proj, explained) = spike_train_pca(&refs, 3, 10);
        assert_eq!(explained.len(), 3);
        let total: f64 = explained.iter().sum();
        assert!(total <= 1.0 + 1e-6, "Total explained {total} > 1");
        assert!(explained[0] >= explained[1]);
        assert!(!proj.is_empty());
    }

    #[test]
    fn test_spike_train_pca_empty() {
        let (proj, expl) = spike_train_pca(&[], 3, 10);
        assert!(proj.is_empty());
        assert!(expl.is_empty());
    }

    #[test]
    fn test_spike_train_pca_single_neuron() {
        let train: Vec<i32> = [1, 0].repeat(5);
        let (proj, expl) = spike_train_pca(&[train.as_slice()], 1, 2);
        assert_eq!(expl.len(), 1);
        assert!(!proj.is_empty());
    }

    #[test]
    fn test_demixed_pca_basic() {
        let trains_a = make_trains();
        let trains_b: Vec<Vec<i32>> = (0..5).map(|n| periodic_train(3 + n * 2)).collect();
        let conditions = vec![as_refs(&trains_a), as_refs(&trains_b)];
        let (proj, expl) = demixed_pca(&conditions, 2, 10);
        assert!(!expl.is_empty());
        assert!(!proj.is_empty());
    }

    #[test]
    fn test_demixed_pca_single_condition() {
        let trial = vec![1, 0, 1, 0];
        let conditions = [vec![trial.as_slice()]];
        let (proj, expl) = demixed_pca(&conditions, 2, 2);
        assert!(proj.is_empty());
        assert!(expl.is_empty());
    }

    #[test]
    fn test_factor_analysis_basic() {
        let trains = make_trains();
        let refs = as_refs(&trains);
        let (loadings, psi) = factor_analysis(&refs, 2, 10, 20);
        assert_eq!(loadings.len(), 5 * 2);
        assert_eq!(psi.len(), 5);
        assert!(psi.iter().all(|&p| p > 0.0));
    }

    #[test]
    fn test_factor_analysis_empty() {
        let (l, p) = factor_analysis(&[], 2, 10, 20);
        assert!(l.is_empty());
        assert!(p.is_empty());
    }

    #[test]
    fn test_symmetric_eigen_identity() {
        // 3x3 identity: diagonal entries sit at flat indices 0, 4 and 8.
        let eye: Vec<f64> = (0..9).map(|k| if k % 4 == 0 { 1.0 } else { 0.0 }).collect();
        let (vals, _) = symmetric_eigen(&eye, 3);
        assert!(vals.iter().all(|v| (v - 1.0).abs() < 1e-10));
    }

    #[test]
    fn test_symmetric_eigen_known() {
        // [[2, 1], [1, 2]] has eigenvalues 3 and 1.
        let (vals, _) = symmetric_eigen(&[2.0, 1.0, 1.0, 2.0], 2);
        assert!((vals[0] - 3.0).abs() < 1e-10);
        assert!((vals[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_symmetric_eigen_sign_canonical() {
        let (_, vecs) = symmetric_eigen(&[2.0, 1.0, 1.0, 2.0], 2);
        for c in 0..2 {
            let column = [vecs[c], vecs[2 + c]];
            let (pivot, _) =
                column
                    .iter()
                    .enumerate()
                    .fold((0usize, 0.0f64), |(best, best_abs), (r, &v)| {
                        if v.abs() > best_abs {
                            (r, v.abs())
                        } else {
                            (best, best_abs)
                        }
                    });
            assert!(column[pivot] > 0.0, "column {c} not sign-canonical");
        }
    }

    #[test]
    fn test_spd_solve_matches_inverse() {
        // [[4, 1], [1, 3]] x = [1, 2]; Cramer's rule gives x = [1/11, 7/11].
        let x = spd_solve(&[4.0, 1.0, 1.0, 3.0], 2, &[1.0, 2.0], 1);
        let expected = [1.0 / 11.0, 7.0 / 11.0];
        assert!((x[0] - expected[0]).abs() < 1e-12 && (x[1] - expected[1]).abs() < 1e-12);
    }

    #[test]
    fn test_pca_explains_variance() {
        let trains = make_trains();
        let refs = as_refs(&trains);
        let (_, explained) = spike_train_pca(&refs, 5, 10);
        let total: f64 = explained.iter().sum();
        assert!(
            (total - 1.0).abs() < 0.05,
            "Total explained {total} should be ~1.0"
        );
    }
}