1use nalgebra::{Cholesky, Complex, DMatrix};
10use rayon::prelude::*;
11use std::f64::consts::PI;
12
13use super::basic::bin_spike_train;
14
15fn solve_spd(s: &[f64], b: &[f64], n: usize, m: usize) -> Vec<f64> {
28 let s_mat = DMatrix::<f64>::from_row_slice(n, n, s);
29 let b_mat = DMatrix::<f64>::from_row_slice(n, m, b);
30 match Cholesky::new(s_mat) {
31 Some(chol) => {
32 let x = chol.solve(&b_mat);
33 let mut out = vec![0.0_f64; n * m];
34 for i in 0..n {
35 for j in 0..m {
36 out[i * m + j] = x[(i, j)];
37 }
38 }
39 out
40 }
41 None => vec![0.0_f64; n * m],
42 }
43}
44
45fn spectral_matrix(beta: &[f64], d: usize, order: usize, f: f64) -> DMatrix<Complex<f64>> {
52 let mut a_f = DMatrix::<Complex<f64>>::identity(d, d);
53 for k in 0..order {
54 let angle = -2.0 * PI * f * (k + 1) as f64;
55 let exp_val = Complex::new(angle.cos(), angle.sin());
56 for i in 0..d {
57 for j in 0..d {
58 let coeff = beta[(k * d + j) * d + i];
59 a_f[(i, j)] -= Complex::new(coeff, 0.0) * exp_val;
60 }
61 }
62 }
63 a_f
64}
65
66fn spectral_transfer_inverse(a_f: DMatrix<Complex<f64>>) -> Option<DMatrix<Complex<f64>>> {
75 let lu = a_f.lu();
76 if lu.determinant().norm() < 1e-30 {
77 return None;
78 }
79 lu.try_inverse()
80}
81
82fn var_coefficients(trains_binned: &[Vec<f64>], order: usize) -> (Vec<f64>, Vec<f64>) {
86 let d = trains_binned.len();
87 let t = if d > 0 { trains_binned[0].len() } else { 0 };
88 if t <= order + 1 || d == 0 {
89 return (vec![0.0; order * d * d], identity_flat(d));
90 }
91 let n_pts = t - order;
92 let x_cols = order * d;
93
94 let mut y_cols = vec![vec![0.0_f64; n_pts]; d];
96 for ch in 0..d {
97 for i in 0..n_pts {
98 y_cols[ch][i] = trains_binned[ch][order + i];
99 }
100 }
101
102 let mut x_cols_data = vec![vec![0.0_f64; n_pts]; x_cols];
104 for i in 0..n_pts {
105 for k in 0..order {
106 for ch in 0..d {
107 x_cols_data[k * d + ch][i] = trains_binned[ch][order - k - 1 + i];
108 }
109 }
110 }
111
112 let mut xtx = vec![0.0_f64; x_cols * x_cols];
114 xtx.par_chunks_exact_mut(x_cols)
115 .enumerate()
116 .for_each(|(i, row)| {
117 for j in 0..=i {
118 let dot = crate::simd::dot_f64_dispatch(&x_cols_data[i], &x_cols_data[j]);
119 row[j] = dot + if i == j { 1e-8 } else { 0.0 };
120 }
121 });
122 for i in 0..x_cols {
124 for j in (i + 1)..x_cols {
125 xtx[i * x_cols + j] = xtx[j * x_cols + i];
126 }
127 }
128
129 let mut xty = vec![0.0_f64; x_cols * d];
131 xty.par_chunks_exact_mut(d)
132 .enumerate()
133 .for_each(|(i, row)| {
134 for j in 0..d {
135 row[j] = crate::simd::dot_f64_dispatch(&x_cols_data[i], &y_cols[j]);
136 }
137 });
138
139 let beta = solve_spd(&xtx, &xty, x_cols, d);
141
142 let mut sigma = vec![0.0_f64; d * d];
144 let n_norm = n_pts.max(1) as f64;
145
146 let res_cols: Vec<Vec<f64>> = (0..d)
148 .into_par_iter()
149 .map(|j| {
150 let mut res = vec![0.0_f64; n_pts];
151 for p in 0..n_pts {
152 let mut r = y_cols[j][p];
153 for c in 0..x_cols {
154 r -= x_cols_data[c][p] * beta[c * d + j];
155 }
156 res[p] = r;
157 }
158 res
159 })
160 .collect();
161
162 for i in 0..d {
163 for j in 0..=i {
164 let dot = crate::simd::dot_f64_dispatch(&res_cols[i], &res_cols[j]);
165 let val = dot / n_norm;
166 sigma[i * d + j] = val;
167 sigma[j * d + i] = val;
168 }
169 }
170
171 (beta, sigma)
172}
173
/// Returns the `d × d` identity matrix flattened row-major (`d * d` values).
fn identity_flat(d: usize) -> Vec<f64> {
    (0..d)
        .flat_map(|row| (0..d).map(move |col| if row == col { 1.0 } else { 0.0 }))
        .collect()
}
181
182fn sse_ols(x: &[f64], y: &[f64], n_pts: usize, x_cols: usize) -> f64 {
184 let mut xtx = vec![0.0_f64; x_cols * x_cols];
186 for i in 0..x_cols {
187 for j in 0..x_cols {
188 let mut s = 0.0;
189 for p in 0..n_pts {
190 s += x[p * x_cols + i] * x[p * x_cols + j];
191 }
192 xtx[i * x_cols + j] = s + if i == j { 1e-8 } else { 0.0 };
193 }
194 }
195 let mut xty = vec![0.0_f64; x_cols];
197 for i in 0..x_cols {
198 let mut s = 0.0;
199 for p in 0..n_pts {
200 s += x[p * x_cols + i] * y[p];
201 }
202 xty[i] = s;
203 }
204 let beta = solve_spd(&xtx, &xty, x_cols, 1);
206 let mut sse = 0.0_f64;
207 for p in 0..n_pts {
208 let mut pred = 0.0;
209 for c in 0..x_cols {
210 pred += x[p * x_cols + c] * beta[c];
211 }
212 let r = y[p] - pred;
213 sse += r * r;
214 }
215 sse
216}
217
218pub fn pairwise_granger_causality(
223 source: &[i32],
224 target: &[i32],
225 bin_size: usize,
226 order: usize,
227) -> f64 {
228 let cs: Vec<f64> = bin_spike_train(source, bin_size)
229 .iter()
230 .map(|&v| v as f64)
231 .collect();
232 let ct: Vec<f64> = bin_spike_train(target, bin_size)
233 .iter()
234 .map(|&v| v as f64)
235 .collect();
236 let n = cs.len().min(ct.len());
237 if n <= 2 * order {
238 return 0.0;
239 }
240
241 let n_pts = n - order;
242 let y: Vec<f64> = ct[order..n].to_vec();
243
244 let r_cols = order;
246 let mut x_r = vec![0.0_f64; n_pts * r_cols];
247 for p in 0..n_pts {
248 for k in 0..order {
249 x_r[p * r_cols + k] = ct[order - k - 1 + p];
250 }
251 }
252 let sse_r = sse_ols(&x_r, &y, n_pts, r_cols);
253
254 let f_cols = 2 * order;
256 let mut x_f = vec![0.0_f64; n_pts * f_cols];
257 for p in 0..n_pts {
258 for k in 0..order {
259 x_f[p * f_cols + k] = ct[order - k - 1 + p];
260 x_f[p * f_cols + order + k] = cs[order - k - 1 + p];
261 }
262 }
263 let sse_f = sse_ols(&x_f, &y, n_pts, f_cols);
264
265 if sse_f <= 0.0 {
266 return 0.0;
267 }
268 (sse_r.max(1e-30) / sse_f.max(1e-30)).ln()
269}
270
271pub fn conditional_granger_causality(
274 source: &[i32],
275 target: &[i32],
276 condition: &[i32],
277 bin_size: usize,
278 order: usize,
279) -> f64 {
280 let cs: Vec<f64> = bin_spike_train(source, bin_size)
281 .iter()
282 .map(|&v| v as f64)
283 .collect();
284 let ct: Vec<f64> = bin_spike_train(target, bin_size)
285 .iter()
286 .map(|&v| v as f64)
287 .collect();
288 let cc: Vec<f64> = bin_spike_train(condition, bin_size)
289 .iter()
290 .map(|&v| v as f64)
291 .collect();
292 let n = cs.len().min(ct.len()).min(cc.len());
293 if n <= 2 * order {
294 return 0.0;
295 }
296
297 let n_pts = n - order;
298 let y: Vec<f64> = ct[order..n].to_vec();
299
300 let c_cols = 2 * order;
302 let mut x_c = vec![0.0_f64; n_pts * c_cols];
303 for p in 0..n_pts {
304 for k in 0..order {
305 x_c[p * c_cols + k] = ct[order - k - 1 + p];
306 x_c[p * c_cols + order + k] = cc[order - k - 1 + p];
307 }
308 }
309 let sse_c = sse_ols(&x_c, &y, n_pts, c_cols);
310
311 let f_cols = 3 * order;
313 let mut x_f = vec![0.0_f64; n_pts * f_cols];
314 for p in 0..n_pts {
315 for k in 0..order {
316 x_f[p * f_cols + k] = ct[order - k - 1 + p];
317 x_f[p * f_cols + order + k] = cc[order - k - 1 + p];
318 x_f[p * f_cols + 2 * order + k] = cs[order - k - 1 + p];
319 }
320 }
321 let sse_f = sse_ols(&x_f, &y, n_pts, f_cols);
322
323 if sse_f <= 0.0 {
324 return 0.0;
325 }
326 (sse_c.max(1e-30) / sse_f.max(1e-30)).ln()
327}
328
329pub fn spectral_granger_causality(
332 trains: &[&[i32]],
333 bin_size: usize,
334 order: usize,
335 n_freqs: usize,
336) -> (Vec<f64>, usize) {
337 let binned: Vec<Vec<f64>> = trains
338 .iter()
339 .map(|t| {
340 bin_spike_train(t, bin_size)
341 .iter()
342 .map(|&v| v as f64)
343 .collect()
344 })
345 .collect();
346 let d = binned.len();
347 let (beta, sigma) = var_coefficients(&binned, order);
348
349 let mut gc = vec![0.0_f64; d * d * n_freqs];
350
351 for fi in 0..n_freqs {
352 let f = fi as f64 / (2 * n_freqs) as f64; let a_f = spectral_matrix(&beta, d, order, f);
356 let h = match spectral_transfer_inverse(a_f) {
357 Some(inv) => inv,
358 None => continue,
359 };
360
361 let sigma_c =
363 DMatrix::<Complex<f64>>::from_fn(d, d, |i, j| Complex::new(sigma[i * d + j], 0.0));
364 let s = &h * &sigma_c * h.adjoint();
365
366 for i in 0..d {
367 for j in 0..d {
368 if i == j {
369 continue;
370 }
371 let s_ii = s[(i, i)].norm();
372 if s_ii > 1e-30 {
373 let h_ij_sq = h[(i, j)].norm_sqr();
374 let reduced = s_ii - sigma[j * d + j] * h_ij_sq;
375 if reduced > 0.0 && reduced < s_ii {
376 gc[(i * d + j) * n_freqs + fi] = (s_ii / reduced).ln().max(0.0);
377 }
378 }
379 }
380 }
381 }
382 (gc, d)
383}
384
385pub fn partial_directed_coherence(
388 trains: &[&[i32]],
389 bin_size: usize,
390 order: usize,
391 n_freqs: usize,
392) -> (Vec<f64>, usize) {
393 let binned: Vec<Vec<f64>> = trains
394 .iter()
395 .map(|t| {
396 bin_spike_train(t, bin_size)
397 .iter()
398 .map(|&v| v as f64)
399 .collect()
400 })
401 .collect();
402 let d = binned.len();
403 let (beta, _) = var_coefficients(&binned, order);
404
405 let mut pdc = vec![0.0_f64; d * d * n_freqs];
406
407 for fi in 0..n_freqs {
408 let f = fi as f64 / (2 * n_freqs) as f64;
409
410 let a_f = spectral_matrix(&beta, d, order, f);
411
412 for j in 0..d {
413 let norm: f64 = (0..d).map(|i| a_f[(i, j)].norm_sqr()).sum::<f64>().sqrt();
414 if norm > 0.0 {
415 for i in 0..d {
416 pdc[(i * d + j) * n_freqs + fi] = a_f[(i, j)].norm() / norm;
417 }
418 }
419 }
420 }
421 (pdc, d)
422}
423
424pub fn directed_transfer_function(
427 trains: &[&[i32]],
428 bin_size: usize,
429 order: usize,
430 n_freqs: usize,
431) -> (Vec<f64>, usize) {
432 let binned: Vec<Vec<f64>> = trains
433 .iter()
434 .map(|t| {
435 bin_spike_train(t, bin_size)
436 .iter()
437 .map(|&v| v as f64)
438 .collect()
439 })
440 .collect();
441 let d = binned.len();
442 let (beta, _sigma) = var_coefficients(&binned, order);
443
444 let mut dtf = vec![0.0_f64; d * d * n_freqs];
445
446 for fi in 0..n_freqs {
447 let f = fi as f64 / (2 * n_freqs) as f64;
448
449 let a_f = spectral_matrix(&beta, d, order, f);
450 let h = match spectral_transfer_inverse(a_f) {
451 Some(inv) => inv,
452 None => continue,
453 };
454
455 for i in 0..d {
456 let norm: f64 = (0..d).map(|j| h[(i, j)].norm_sqr()).sum::<f64>().sqrt();
457 if norm > 0.0 {
458 for j in 0..d {
459 dtf[(i * d + j) * n_freqs + fi] = h[(i, j)].norm() / norm;
460 }
461 }
462 }
463 }
464 (dtf, d)
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 0/1 spike train of length `len` with a `1` at every index in `spikes`.
    fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
        let mut t = vec![0i32; len];
        for &s in spikes {
            t[s] = 1;
        }
        t
    }

    // ---- solve_spd ----

    #[test]
    fn test_solve_spd_identity() {
        // Identity system: the solution equals the right-hand side.
        let a = vec![1.0, 0.0, 0.0, 1.0];
        let b = vec![3.0, 7.0];
        let x = solve_spd(&a, &b, 2, 1);
        assert!((x[0] - 3.0).abs() < 1e-10);
        assert!((x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_solve_spd_2x2() {
        // [[2, 1], [1, 3]] · [1, 3] = [5, 10].
        let a = vec![2.0, 1.0, 1.0, 3.0];
        let b = vec![5.0, 10.0];
        let x = solve_spd(&a, &b, 2, 1);
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_solve_spd_multi_rhs() {
        // diag(2, 4) with a 2×2 row-major right-hand side; checks row-major output order.
        let a = vec![2.0, 0.0, 0.0, 4.0];
        let b = vec![2.0, 4.0, 4.0, 8.0];
        let x = solve_spd(&a, &b, 2, 2);
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
        assert!((x[2] - 1.0).abs() < 1e-10);
        assert!((x[3] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_solve_spd_non_pd_falls_back_to_zero() {
        // [[0, 1], [1, 0]] is indefinite, so Cholesky fails and zeros are returned.
        let a = vec![0.0, 1.0, 1.0, 0.0];
        let b = vec![1.0, 1.0];
        let x = solve_spd(&a, &b, 2, 1);
        assert_eq!(x, vec![0.0, 0.0]);
    }

    // ---- spectral_matrix / spectral_transfer_inverse ----

    #[test]
    fn test_spectral_matrix_dc_zero_beta() {
        // With all-zero coefficients, A(f) is the identity at any frequency.
        let beta = vec![0.0_f64; 2 * 2 * 2];
        let a = spectral_matrix(&beta, 2, 2, 0.0);
        assert!((a[(0, 0)].re - 1.0).abs() < 1e-12);
        assert!((a[(1, 1)].re - 1.0).abs() < 1e-12);
        assert!(a[(0, 1)].norm() < 1e-12);
        assert!(a[(1, 0)].norm() < 1e-12);
    }

    #[test]
    fn test_spectral_transfer_inverse_identity() {
        let a = DMatrix::<Complex<f64>>::identity(2, 2);
        let inv = spectral_transfer_inverse(a).unwrap();
        assert!((inv[(0, 0)].re - 1.0).abs() < 1e-10);
        assert!((inv[(1, 1)].re - 1.0).abs() < 1e-10);
        assert!(inv[(0, 1)].norm() < 1e-10);
        assert!(inv[(1, 0)].norm() < 1e-10);
    }

    #[test]
    fn test_spectral_transfer_inverse_roundtrip() {
        // A · A^{-1} should be the identity for a well-conditioned complex matrix.
        let a = DMatrix::from_row_slice(
            2,
            2,
            &[
                Complex::new(2.0, 1.0),
                Complex::new(1.0, 0.0),
                Complex::new(0.0, 1.0),
                Complex::new(3.0, 0.0),
            ],
        );
        let inv = spectral_transfer_inverse(a.clone()).unwrap();
        let prod = &a * &inv;
        assert!((prod[(0, 0)].re - 1.0).abs() < 1e-8);
        assert!((prod[(1, 1)].re - 1.0).abs() < 1e-8);
        assert!(prod[(0, 1)].norm() < 1e-8);
        assert!(prod[(1, 0)].norm() < 1e-8);
    }

    #[test]
    fn test_spectral_transfer_inverse_singular() {
        // The zero matrix has determinant 0, below the 1e-30 singularity threshold.
        let a = DMatrix::<Complex<f64>>::zeros(2, 2);
        assert!(spectral_transfer_inverse(a).is_none());
    }

    // ---- time-domain Granger causality ----

    #[test]
    fn test_gc_self_finite() {
        // Source == target: the full model only duplicates regressors.
        let train = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let gc = pairwise_granger_causality(&train, &train, 5, 3);
        assert!(gc.is_finite(), "self GC should be finite, got {gc}");
        assert!(gc >= 0.0, "GC should be non-negative, got {gc}");
    }

    #[test]
    fn test_gc_non_negative_typical() {
        // Target spikes trail source spikes by 2 samples.
        let source = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let target = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let gc = pairwise_granger_causality(&source, &target, 5, 3);
        assert!(gc.is_finite(), "GC should be finite, got {gc}");
    }

    #[test]
    fn test_gc_too_short() {
        // 10 samples yield at most 10 bins, which is <= 2 * order = 10, so the result is 0.
        let a = make_train(&[1], 10);
        let b = make_train(&[2], 10);
        let gc = pairwise_granger_causality(&a, &b, 5, 5);
        assert_eq!(gc, 0.0, "too short → 0");
    }

    #[test]
    fn test_cond_gc_finite() {
        let source = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let target = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let cond = make_train(&[3, 13, 23, 33, 43, 53, 63, 73, 83, 93], 100);
        let gc = conditional_granger_causality(&source, &target, &cond, 5, 3);
        assert!(gc.is_finite(), "conditional GC should be finite");
    }

    #[test]
    fn test_cond_gc_too_short() {
        // Same length guard as the pairwise case.
        let a = make_train(&[1], 10);
        let b = make_train(&[2], 10);
        let c = make_train(&[3], 10);
        assert_eq!(conditional_granger_causality(&a, &b, &c, 5, 5), 0.0);
    }

    // ---- spectral Granger causality ----

    #[test]
    fn test_spectral_gc_shape() {
        // Output is d * d * n_freqs values.
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (gc, d) = spectral_granger_causality(&trains, 5, 3, 16);
        assert_eq!(d, 2);
        assert_eq!(gc.len(), 2 * 2 * 16);
    }

    #[test]
    fn test_spectral_gc_diagonal_zero() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (gc, _) = spectral_granger_causality(&trains, 5, 3, 16);
        for fi in 0..16 {
            // Layout is (i * d + j) * n_freqs + fi: entry [0,0] starts at 0, [1,1] at 3 * 16.
            assert_eq!(gc[fi], 0.0, "GC[0,0] should be 0");
            assert_eq!(gc[3 * 16 + fi], 0.0, "GC[1,1] should be 0");
        }
    }

    #[test]
    fn test_spectral_gc_non_negative() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (gc, _) = spectral_granger_causality(&trains, 5, 3, 16);
        for &v in &gc {
            assert!(v >= 0.0, "spectral GC must be non-negative, got {v}");
        }
    }

    // ---- partial directed coherence ----

    #[test]
    fn test_pdc_shape() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (pdc, d) = partial_directed_coherence(&trains, 5, 3, 16);
        assert_eq!(d, 2);
        assert_eq!(pdc.len(), 2 * 2 * 16);
    }

    #[test]
    fn test_pdc_range() {
        // Column normalisation bounds every entry to [0, 1] (small float tolerance).
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (pdc, _) = partial_directed_coherence(&trains, 5, 3, 16);
        for &v in &pdc {
            assert!(
                (0.0..=1.0 + 1e-10).contains(&v),
                "PDC should be in [0,1], got {v}"
            );
        }
    }

    // ---- directed transfer function ----

    #[test]
    fn test_dtf_shape() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (dtf, d) = directed_transfer_function(&trains, 5, 3, 16);
        assert_eq!(d, 2);
        assert_eq!(dtf.len(), 2 * 2 * 16);
    }

    #[test]
    fn test_dtf_range() {
        // Row normalisation bounds every entry to [0, 1] (small float tolerance).
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (dtf, _) = directed_transfer_function(&trains, 5, 3, 16);
        for &v in &dtf {
            assert!(
                (0.0..=1.0 + 1e-10).contains(&v),
                "DTF should be in [0,1], got {v}"
            );
        }
    }

    // ---- var_coefficients ----

    #[test]
    fn test_var_too_short() {
        // t = 2 <= order + 1 = 6, so the degenerate (zero beta, identity sigma) result is used.
        let trains = vec![vec![1.0, 2.0]];
        let (beta, sigma) = var_coefficients(&trains, 5);
        assert!(beta.iter().all(|&v| v == 0.0), "too short → zero beta");
        assert!((sigma[0] - 1.0).abs() < 1e-10, "identity sigma");
    }
}