1use rayon::prelude::*;
10use std::f64::consts::PI;
11
12use super::basic::bin_spike_train;
13
/// Solves the dense linear system `a * x = b` for `x`, where `a` is an
/// `n`×`n` row-major matrix, using Gaussian elimination with partial
/// pivoting. Near-singular pivot columns are skipped and the matching
/// unknowns come back as 0.0 instead of an error.
fn solve_linear(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    // Build the augmented matrix [a | b]; each row has n + 1 entries.
    let width = n + 1;
    let mut aug = vec![0.0_f64; n * width];
    for row in 0..n {
        aug[row * width..row * width + n].copy_from_slice(&a[row * n..row * n + n]);
        aug[row * width + n] = b[row];
    }

    // Forward elimination with partial pivoting.
    for col in 0..n {
        // Pick the row with the largest absolute entry in this column.
        let mut best = col;
        for row in (col + 1)..n {
            if aug[row * width + col].abs() > aug[best * width + col].abs() {
                best = row;
            }
        }
        if best != col {
            for j in 0..width {
                aug.swap(col * width + j, best * width + j);
            }
        }
        let pivot = aug[col * width + col];
        // A (near-)singular column is skipped; back substitution zeroes it.
        if pivot.abs() < 1e-30 {
            continue;
        }
        for row in (col + 1)..n {
            let factor = aug[row * width + col] / pivot;
            for j in col..width {
                aug[row * width + j] -= factor * aug[col * width + j];
            }
        }
    }

    // Back substitution on the now upper-triangular system.
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let mut acc = aug[i * width + n];
        for j in (i + 1)..n {
            acc -= aug[i * width + j] * x[j];
        }
        let diag = aug[i * width + i];
        x[i] = if diag.abs() > 1e-30 { acc / diag } else { 0.0 };
    }
    x
}
79
80fn solve_matrix(a: &[f64], b: &[f64], n: usize, m: usize) -> Vec<f64> {
82 let result = vec![0.0_f64; n * m];
83 (0..m).into_par_iter().for_each(|col| {
84 let rhs: Vec<f64> = (0..n).map(|i| b[i * m + col]).collect();
85 let x = solve_linear(a, &rhs, n);
86 unsafe {
88 let ptr = result.as_ptr() as *mut f64;
89 for i in 0..n {
90 *ptr.add(i * m + col) = x[i];
91 }
92 }
93 });
94 result
95}
96
/// Minimal complex number in Cartesian form, used by the frequency-domain
/// computations below.
#[derive(Clone, Copy)]
struct C64 {
    re: f64,
    im: f64,
}

impl C64 {
    /// Builds a complex number from real and imaginary parts.
    fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// The additive identity, 0 + 0i.
    fn zero() -> Self {
        Self::new(0.0, 0.0)
    }

    /// The multiplicative identity, 1 + 0i.
    fn one() -> Self {
        Self::new(1.0, 0.0)
    }

    /// Squared magnitude |z|² = re² + im² (no sqrt; cheaper than `abs`).
    fn norm_sq(self) -> f64 {
        self.re * self.re + self.im * self.im
    }

    /// Magnitude |z|.
    fn abs(self) -> f64 {
        self.norm_sq().sqrt()
    }

    /// Complex conjugate re − im·i.
    fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }
}
128
129impl std::ops::Add for C64 {
130 type Output = Self;
131 fn add(self, rhs: Self) -> Self {
132 Self {
133 re: self.re + rhs.re,
134 im: self.im + rhs.im,
135 }
136 }
137}
138
139impl std::ops::Sub for C64 {
140 type Output = Self;
141 fn sub(self, rhs: Self) -> Self {
142 Self {
143 re: self.re - rhs.re,
144 im: self.im - rhs.im,
145 }
146 }
147}
148
149impl std::ops::Mul for C64 {
150 type Output = Self;
151 fn mul(self, rhs: Self) -> Self {
152 Self {
153 re: self.re * rhs.re - self.im * rhs.im,
154 im: self.re * rhs.im + self.im * rhs.re,
155 }
156 }
157}
158
159impl std::ops::Mul<f64> for C64 {
160 type Output = Self;
161 fn mul(self, rhs: f64) -> Self {
162 Self {
163 re: self.re * rhs,
164 im: self.im * rhs,
165 }
166 }
167}
168
169impl std::ops::AddAssign for C64 {
170 fn add_assign(&mut self, rhs: Self) {
171 self.re += rhs.re;
172 self.im += rhs.im;
173 }
174}
175
176impl std::ops::SubAssign for C64 {
177 fn sub_assign(&mut self, rhs: Self) {
178 self.re -= rhs.re;
179 self.im -= rhs.im;
180 }
181}
182
183fn cmat_mul(a: &[C64], b: &[C64], d: usize) -> Vec<C64> {
185 let mut c = vec![C64::zero(); d * d];
186 for i in 0..d {
187 for j in 0..d {
188 let mut s = C64::zero();
189 for k in 0..d {
190 s += a[i * d + k] * b[k * d + j];
191 }
192 c[i * d + j] = s;
193 }
194 }
195 c
196}
197
/// Inverts a `d`×`d` row-major complex matrix by Gauss–Jordan elimination
/// with partial pivoting on the augmented matrix `[a | I]`.
/// Returns `None` when some pivot column is numerically singular
/// (max absolute entry below 1e-30).
fn cmat_inv(a: &[C64], d: usize) -> Option<Vec<C64>> {
    // Augmented matrix [a | I]; each row is 2*d entries wide.
    let mut aug = vec![C64::zero(); d * 2 * d];
    for i in 0..d {
        for j in 0..d {
            aug[i * 2 * d + j] = a[i * d + j];
        }
        aug[i * 2 * d + d + i] = C64::one();
    }
    let w = 2 * d;
    for col in 0..d {
        // Partial pivoting: find the largest-magnitude entry in this column.
        let mut max_row = col;
        let mut max_val = aug[col * w + col].abs();
        for row in (col + 1)..d {
            let v = aug[row * w + col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_val < 1e-30 {
            return None;
        }
        if max_row != col {
            for j in 0..w {
                aug.swap(col * w + j, max_row * w + j);
            }
        }
        // Complex reciprocal of the pivot: 1/z = conj(z) / |z|².
        let pivot = aug[col * w + col];
        let inv_pivot = pivot.conj() * (1.0 / pivot.norm_sq());
        // Normalize the pivot row so the pivot becomes 1.
        for j in 0..w {
            aug[col * w + j] = aug[col * w + j] * inv_pivot;
        }
        // Eliminate this column from every other row (full Gauss–Jordan).
        for row in 0..d {
            if row == col {
                continue;
            }
            let factor = aug[row * w + col];
            for j in 0..w {
                let sub = factor * aug[col * w + j];
                aug[row * w + j] -= sub;
            }
        }
    }
    // The right half of the augmented matrix now holds the inverse.
    let mut result = vec![C64::zero(); d * d];
    for i in 0..d {
        for j in 0..d {
            result[i * d + j] = aug[i * w + d + j];
        }
    }
    Some(result)
}
251
/// Determinant of a `d`×`d` row-major complex matrix.
/// Sizes 1 and 2 are handled in closed form; larger matrices use LU
/// elimination with partial pivoting, accumulating the product of pivots
/// and flipping the sign on each row swap. Returns zero when a pivot
/// column is numerically singular.
fn cmat_det(a: &[C64], d: usize) -> C64 {
    if d == 1 {
        return a[0];
    }
    if d == 2 {
        // ad − bc for [[a, b], [c, d]].
        return a[0] * a[3] - a[1] * a[2];
    }
    // Work on a mutable copy; `a` stays untouched.
    let mut m = a.to_vec();
    let mut det = C64::one();
    for col in 0..d {
        // Partial pivoting: largest-magnitude entry in this column.
        let mut max_row = col;
        let mut max_val = m[col * d + col].abs();
        for row in (col + 1)..d {
            let v = m[row * d + col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        // Singular column ⇒ determinant is exactly zero.
        if max_val < 1e-30 {
            return C64::zero();
        }
        if max_row != col {
            for j in 0..d {
                m.swap(col * d + j, max_row * d + j);
            }
            // Each row swap flips the determinant's sign.
            det = det * (-1.0);
        }
        // Determinant is the product of the pivots.
        det = det * m[col * d + col];
        // Complex reciprocal of the pivot: 1/z = conj(z) / |z|².
        let pivot = m[col * d + col];
        let inv_pivot = pivot.conj() * (1.0 / pivot.norm_sq());
        // Eliminate entries below the pivot.
        for row in (col + 1)..d {
            let factor = m[row * d + col] * inv_pivot;
            for j in col..d {
                let sub = factor * m[col * d + j];
                m[row * d + j] -= sub;
            }
        }
    }
    det
}
295
296fn cmat_conj_t(a: &[C64], d: usize) -> Vec<C64> {
298 let mut r = vec![C64::zero(); d * d];
299 for i in 0..d {
300 for j in 0..d {
301 r[j * d + i] = a[i * d + j].conj();
302 }
303 }
304 r
305}
306
/// Fits a VAR(`order`) model to `d` binned channels by ridge-regularized
/// least squares and returns `(beta, sigma)`:
/// * `beta`  — row-major `(order * d) × d` coefficient matrix; entry
///   `beta[(k * d + src) * d + dst]` is the lag-(k+1) influence of channel
///   `src` on channel `dst`.
/// * `sigma` — row-major `d × d` residual covariance matrix.
///
/// Too-short input (`t <= order + 1`) or zero channels yields zero
/// coefficients and an identity covariance.
fn var_coefficients(trains_binned: &[Vec<f64>], order: usize) -> (Vec<f64>, Vec<f64>) {
    let d = trains_binned.len();
    // NOTE(review): only channel 0's length is inspected; the indexing below
    // assumes every channel has at least `t` samples — confirm callers bin
    // equal-length trains.
    let t = if d > 0 { trains_binned[0].len() } else { 0 };
    if t <= order + 1 || d == 0 {
        return (vec![0.0; order * d * d], identity_flat(d));
    }
    let n_pts = t - order;
    let x_cols = order * d;

    // Targets: y_cols[ch][i] = channel ch at time order + i.
    let mut y_cols = vec![vec![0.0_f64; n_pts]; d];
    for ch in 0..d {
        for i in 0..n_pts {
            y_cols[ch][i] = trains_binned[ch][order + i];
        }
    }

    // Design matrix stored column-wise: column k*d + ch is channel ch
    // lagged by k + 1 (value at time order − k − 1 + i).
    let mut x_cols_data = vec![vec![0.0_f64; n_pts]; x_cols];
    for i in 0..n_pts {
        for k in 0..order {
            for ch in 0..d {
                x_cols_data[k * d + ch][i] = trains_binned[ch][order - k - 1 + i];
            }
        }
    }

    // Normal equations: X^T X with a 1e-8 ridge on the diagonal. Only the
    // lower triangle is computed (rows in parallel)...
    let mut xtx = vec![0.0_f64; x_cols * x_cols];
    xtx.par_chunks_exact_mut(x_cols)
        .enumerate()
        .for_each(|(i, row)| {
            for j in 0..=i {
                let dot = crate::simd::dot_f64_dispatch(&x_cols_data[i], &x_cols_data[j]);
                row[j] = dot + if i == j { 1e-8 } else { 0.0 };
            }
        });
    // ...then mirrored into the upper triangle (X^T X is symmetric).
    for i in 0..x_cols {
        for j in (i + 1)..x_cols {
            xtx[i * x_cols + j] = xtx[j * x_cols + i];
        }
    }

    // X^T Y, one row per predictor column, computed in parallel.
    let mut xty = vec![0.0_f64; x_cols * d];
    xty.par_chunks_exact_mut(d)
        .enumerate()
        .for_each(|(i, row)| {
            for j in 0..d {
                row[j] = crate::simd::dot_f64_dispatch(&x_cols_data[i], &y_cols[j]);
            }
        });

    // Solve (X^T X) beta = X^T Y for all d target channels at once.
    let beta = solve_matrix(&xtx, &xty, x_cols, d);

    // Residual covariance sigma = R^T R / n_pts, where R = Y − X·beta.
    let mut sigma = vec![0.0_f64; d * d];
    let n_norm = n_pts.max(1) as f64;

    // Residual column for each target channel, in parallel.
    let res_cols: Vec<Vec<f64>> = (0..d)
        .into_par_iter()
        .map(|j| {
            let mut res = vec![0.0_f64; n_pts];
            for p in 0..n_pts {
                let mut r = y_cols[j][p];
                for c in 0..x_cols {
                    r -= x_cols_data[c][p] * beta[c * d + j];
                }
                res[p] = r;
            }
            res
        })
        .collect();

    // Fill both triangles of the symmetric covariance.
    for i in 0..d {
        for j in 0..=i {
            let dot = crate::simd::dot_f64_dispatch(&res_cols[i], &res_cols[j]);
            let val = dot / n_norm;
            sigma[i * d + j] = val;
            sigma[j * d + i] = val;
        }
    }

    (beta, sigma)
}
398
/// Returns the `d`×`d` identity matrix flattened row-major.
fn identity_flat(d: usize) -> Vec<f64> {
    // Diagonal entries sit at indices i*(d+1), i.e. the multiples of d+1.
    (0..d * d)
        .map(|idx| if idx % (d + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
406
407fn sse_ols(x: &[f64], y: &[f64], n_pts: usize, x_cols: usize) -> f64 {
409 let mut xtx = vec![0.0_f64; x_cols * x_cols];
411 for i in 0..x_cols {
412 for j in 0..x_cols {
413 let mut s = 0.0;
414 for p in 0..n_pts {
415 s += x[p * x_cols + i] * x[p * x_cols + j];
416 }
417 xtx[i * x_cols + j] = s + if i == j { 1e-8 } else { 0.0 };
418 }
419 }
420 let mut xty = vec![0.0_f64; x_cols];
422 for i in 0..x_cols {
423 let mut s = 0.0;
424 for p in 0..n_pts {
425 s += x[p * x_cols + i] * y[p];
426 }
427 xty[i] = s;
428 }
429 let beta = solve_linear(&xtx, &xty, x_cols);
430 let mut sse = 0.0_f64;
431 for p in 0..n_pts {
432 let mut pred = 0.0;
433 for c in 0..x_cols {
434 pred += x[p * x_cols + c] * beta[c];
435 }
436 let r = y[p] - pred;
437 sse += r * r;
438 }
439 sse
440}
441
/// Time-domain Granger causality from `source` to `target` after binning
/// both spike trains with `bin_size`.
///
/// Fits two autoregressive models of the binned target — restricted
/// (target's own past only) and full (target's past plus the source's
/// past) — and returns `ln(SSE_restricted / SSE_full)`. Returns 0.0 when
/// the common binned length is too short (`n <= 2 * order`) or the full
/// model's SSE is non-positive.
pub fn pairwise_granger_causality(
    source: &[i32],
    target: &[i32],
    bin_size: usize,
    order: usize,
) -> f64 {
    // Bin both trains and convert spike counts to f64.
    let cs: Vec<f64> = bin_spike_train(source, bin_size)
        .iter()
        .map(|&v| v as f64)
        .collect();
    let ct: Vec<f64> = bin_spike_train(target, bin_size)
        .iter()
        .map(|&v| v as f64)
        .collect();
    // Truncate to the shorter binned series.
    let n = cs.len().min(ct.len());
    if n <= 2 * order {
        return 0.0;
    }

    let n_pts = n - order;
    // Regression target: the binned target from time `order` on.
    let y: Vec<f64> = ct[order..n].to_vec();

    // Restricted design: lags 1..=order of the target itself.
    let r_cols = order;
    let mut x_r = vec![0.0_f64; n_pts * r_cols];
    for p in 0..n_pts {
        for k in 0..order {
            x_r[p * r_cols + k] = ct[order - k - 1 + p];
        }
    }
    let sse_r = sse_ols(&x_r, &y, n_pts, r_cols);

    // Full design: target lags followed by source lags.
    let f_cols = 2 * order;
    let mut x_f = vec![0.0_f64; n_pts * f_cols];
    for p in 0..n_pts {
        for k in 0..order {
            x_f[p * f_cols + k] = ct[order - k - 1 + p];
            x_f[p * f_cols + order + k] = cs[order - k - 1 + p];
        }
    }
    let sse_f = sse_ols(&x_f, &y, n_pts, f_cols);

    if sse_f <= 0.0 {
        return 0.0;
    }
    // Clamp both SSEs away from zero so the log ratio stays finite.
    (sse_r.max(1e-30) / sse_f.max(1e-30)).ln()
}
494
/// Conditional time-domain Granger causality from `source` to `target`
/// given `condition`, after binning all three spike trains with `bin_size`.
///
/// Compares a restricted model (target's past plus condition's past)
/// against a full model that additionally includes the source's past, and
/// returns `ln(SSE_restricted / SSE_full)`. Returns 0.0 when the common
/// binned length is too short (`n <= 2 * order`) or the full model's SSE
/// is non-positive.
pub fn conditional_granger_causality(
    source: &[i32],
    target: &[i32],
    condition: &[i32],
    bin_size: usize,
    order: usize,
) -> f64 {
    // Bin all three trains and convert spike counts to f64.
    let cs: Vec<f64> = bin_spike_train(source, bin_size)
        .iter()
        .map(|&v| v as f64)
        .collect();
    let ct: Vec<f64> = bin_spike_train(target, bin_size)
        .iter()
        .map(|&v| v as f64)
        .collect();
    let cc: Vec<f64> = bin_spike_train(condition, bin_size)
        .iter()
        .map(|&v| v as f64)
        .collect();
    // Truncate to the shortest binned series.
    let n = cs.len().min(ct.len()).min(cc.len());
    // NOTE(review): the guard uses 2 * order even though the full model has
    // 3 * order predictors; an under-determined fit is still solvable thanks
    // to the ridge term in sse_ols — confirm this threshold is intended.
    if n <= 2 * order {
        return 0.0;
    }

    let n_pts = n - order;
    let y: Vec<f64> = ct[order..n].to_vec();

    // Restricted design: target lags 1..=order, then condition lags.
    let c_cols = 2 * order;
    let mut x_c = vec![0.0_f64; n_pts * c_cols];
    for p in 0..n_pts {
        for k in 0..order {
            x_c[p * c_cols + k] = ct[order - k - 1 + p];
            x_c[p * c_cols + order + k] = cc[order - k - 1 + p];
        }
    }
    let sse_c = sse_ols(&x_c, &y, n_pts, c_cols);

    // Full design additionally includes the source lags.
    let f_cols = 3 * order;
    let mut x_f = vec![0.0_f64; n_pts * f_cols];
    for p in 0..n_pts {
        for k in 0..order {
            x_f[p * f_cols + k] = ct[order - k - 1 + p];
            x_f[p * f_cols + order + k] = cc[order - k - 1 + p];
            x_f[p * f_cols + 2 * order + k] = cs[order - k - 1 + p];
        }
    }
    let sse_f = sse_ols(&x_f, &y, n_pts, f_cols);

    if sse_f <= 0.0 {
        return 0.0;
    }
    // Clamp both SSEs away from zero so the log ratio stays finite.
    (sse_c.max(1e-30) / sse_f.max(1e-30)).ln()
}
552
/// Spectral (frequency-domain) Granger causality for all ordered channel
/// pairs, based on a VAR(`order`) model fitted to the binned trains.
///
/// Returns `(gc, d)` where `gc[(i * d + j) * n_freqs + fi]` holds the
/// causality from channel `j` to channel `i` at frequency bin `fi`, and
/// `d` is the number of channels. Diagonal entries stay 0; frequency bins
/// where A(f) is numerically singular are skipped (left at 0).
pub fn spectral_granger_causality(
    trains: &[&[i32]],
    bin_size: usize,
    order: usize,
    n_freqs: usize,
) -> (Vec<f64>, usize) {
    // Bin every spike train and convert counts to f64.
    let binned: Vec<Vec<f64>> = trains
        .iter()
        .map(|t| {
            bin_spike_train(t, bin_size)
                .iter()
                .map(|&v| v as f64)
                .collect()
        })
        .collect();
    let d = binned.len();
    let (beta, sigma) = var_coefficients(&binned, order);

    let mut gc = vec![0.0_f64; d * d * n_freqs];

    for fi in 0..n_freqs {
        // Normalized frequency in [0, 0.5), i.e. up to Nyquist.
        let f = fi as f64 / (2 * n_freqs) as f64;
        // A(f) = I − Σ_k A_k e^{−2πi f (k+1)} (VAR spectral polynomial).
        let mut a_f = vec![C64::zero(); d * d];
        for i in 0..d {
            a_f[i * d + i] = C64::one();
        }
        for k in 0..order {
            let angle = -2.0 * PI * f * (k + 1) as f64;
            let exp_val = C64::new(angle.cos(), angle.sin());
            for i in 0..d {
                for j in 0..d {
                    // beta is (order*d)×d row-major: lag k, source j, target i.
                    let coeff = beta[(k * d + j) * d + i];
                    a_f[i * d + j] -= C64::new(coeff, 0.0) * exp_val;
                }
            }
        }

        // Skip numerically singular frequency bins.
        let det = cmat_det(&a_f, d);
        if det.abs() < 1e-30 {
            continue;
        }
        // Transfer function H(f) = A(f)^{-1}.
        let h = match cmat_inv(&a_f, d) {
            Some(inv) => inv,
            None => continue,
        };

        // Cross-spectral density S(f) = H Σ H* (Σ lifted to complex).
        let sigma_c: Vec<C64> = sigma.iter().map(|&v| C64::new(v, 0.0)).collect();
        let h_conj_t = cmat_conj_t(&h, d);
        let tmp = cmat_mul(&h, &sigma_c, d);
        let s = cmat_mul(&tmp, &h_conj_t, d);

        for i in 0..d {
            for j in 0..d {
                if i == j {
                    continue;
                }
                let s_ii = s[i * d + i].abs();
                if s_ii > 1e-30 {
                    // Geweke-style measure ln(S_ii / (S_ii − Σ_jj |H_ij|²)).
                    // NOTE(review): this uses Σ_jj directly rather than the
                    // partialized Σ_jj − Σ_ij²/Σ_ii of Geweke's exact
                    // formula — looks like an approximation; confirm intended.
                    let h_ij_sq = h[i * d + j].norm_sq();
                    let reduced = s_ii - sigma[j * d + j] * h_ij_sq;
                    // Only record when the reduction is positive and strict,
                    // keeping the log finite and non-negative.
                    if reduced > 0.0 && reduced < s_ii {
                        gc[(i * d + j) * n_freqs + fi] = (s_ii / reduced).ln().max(0.0);
                    }
                }
            }
        }
    }
    (gc, d)
}
628
/// Partial directed coherence (PDC) for all ordered channel pairs from a
/// VAR(`order`) model fitted to the binned trains.
///
/// Returns `(pdc, d)` where `pdc[(i * d + j) * n_freqs + fi]` is the PDC
/// from channel `j` to channel `i` at frequency bin `fi`: |A(f)_ij|
/// divided by the Euclidean norm of column `j` of A(f), so each value lies
/// in [0, 1].
pub fn partial_directed_coherence(
    trains: &[&[i32]],
    bin_size: usize,
    order: usize,
    n_freqs: usize,
) -> (Vec<f64>, usize) {
    // Bin every spike train and convert counts to f64.
    let binned: Vec<Vec<f64>> = trains
        .iter()
        .map(|t| {
            bin_spike_train(t, bin_size)
                .iter()
                .map(|&v| v as f64)
                .collect()
        })
        .collect();
    let d = binned.len();
    let (beta, _) = var_coefficients(&binned, order);

    let mut pdc = vec![0.0_f64; d * d * n_freqs];

    for fi in 0..n_freqs {
        // Normalized frequency in [0, 0.5).
        let f = fi as f64 / (2 * n_freqs) as f64;

        // A(f) = I − Σ_k A_k e^{−2πi f (k+1)}.
        let mut a_f = vec![C64::zero(); d * d];
        for i in 0..d {
            a_f[i * d + i] = C64::one();
        }
        for k in 0..order {
            let angle = -2.0 * PI * f * (k + 1) as f64;
            let exp_val = C64::new(angle.cos(), angle.sin());
            for i in 0..d {
                for j in 0..d {
                    // beta is (order*d)×d row-major: lag k, source j, target i.
                    let coeff = beta[(k * d + j) * d + i];
                    a_f[i * d + j] -= C64::new(coeff, 0.0) * exp_val;
                }
            }
        }

        // Normalize each entry by the 2-norm of its column of A(f).
        for j in 0..d {
            let norm: f64 = (0..d).map(|i| a_f[i * d + j].norm_sq()).sum::<f64>().sqrt();
            if norm > 0.0 {
                for i in 0..d {
                    pdc[(i * d + j) * n_freqs + fi] = a_f[i * d + j].abs() / norm;
                }
            }
        }
    }
    (pdc, d)
}
680
/// Directed transfer function (DTF) for all ordered channel pairs from a
/// VAR(`order`) model fitted to the binned trains.
///
/// Returns `(dtf, d)` where `dtf[(i * d + j) * n_freqs + fi]` is |H(f)_ij|
/// divided by the Euclidean norm of row `i` of the transfer function
/// H(f) = A(f)^{-1}, so each value lies in [0, 1]. Frequency bins where
/// A(f) is numerically singular are skipped (left at 0).
pub fn directed_transfer_function(
    trains: &[&[i32]],
    bin_size: usize,
    order: usize,
    n_freqs: usize,
) -> (Vec<f64>, usize) {
    // Bin every spike train and convert counts to f64.
    let binned: Vec<Vec<f64>> = trains
        .iter()
        .map(|t| {
            bin_spike_train(t, bin_size)
                .iter()
                .map(|&v| v as f64)
                .collect()
        })
        .collect();
    let d = binned.len();
    let (beta, _sigma) = var_coefficients(&binned, order);

    let mut dtf = vec![0.0_f64; d * d * n_freqs];

    for fi in 0..n_freqs {
        // Normalized frequency in [0, 0.5).
        let f = fi as f64 / (2 * n_freqs) as f64;

        // A(f) = I − Σ_k A_k e^{−2πi f (k+1)}.
        let mut a_f = vec![C64::zero(); d * d];
        for i in 0..d {
            a_f[i * d + i] = C64::one();
        }
        for k in 0..order {
            let angle = -2.0 * PI * f * (k + 1) as f64;
            let exp_val = C64::new(angle.cos(), angle.sin());
            for i in 0..d {
                for j in 0..d {
                    // beta is (order*d)×d row-major: lag k, source j, target i.
                    let coeff = beta[(k * d + j) * d + i];
                    a_f[i * d + j] -= C64::new(coeff, 0.0) * exp_val;
                }
            }
        }

        // Skip numerically singular frequency bins.
        let det = cmat_det(&a_f, d);
        if det.abs() < 1e-30 {
            continue;
        }
        // Transfer function H(f) = A(f)^{-1}.
        let h = match cmat_inv(&a_f, d) {
            Some(inv) => inv,
            None => continue,
        };

        // Normalize each entry by the 2-norm of its row of H(f).
        for i in 0..d {
            let norm: f64 = (0..d).map(|j| h[i * d + j].norm_sq()).sum::<f64>().sqrt();
            if norm > 0.0 {
                for j in 0..d {
                    dtf[(i * d + j) * n_freqs + fi] = h[i * d + j].abs() / norm;
                }
            }
        }
    }
    (dtf, d)
}
741
742#[cfg(test)]
743mod tests {
744 use super::*;
745
746 fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
747 let mut t = vec![0i32; len];
748 for &s in spikes {
749 t[s] = 1;
750 }
751 t
752 }
753
    // --- Linear-algebra helper tests --------------------------------------

    #[test]
    fn test_solve_linear_identity() {
        // I * x = b must return b itself.
        let a = vec![1.0, 0.0, 0.0, 1.0];
        let b = vec![3.0, 7.0];
        let x = solve_linear(&a, &b, 2);
        assert!((x[0] - 3.0).abs() < 1e-10);
        assert!((x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_solve_linear_2x2() {
        // [2 1; 1 3] * [1, 3]^T = [5, 10]^T.
        let a = vec![2.0, 1.0, 1.0, 3.0];
        let b = vec![5.0, 10.0];
        let x = solve_linear(&a, &b, 2);
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_cmat_det_2x2() {
        // det([1 2; 3 4]) = 1*4 − 2*3 = −2 for a purely real matrix.
        let a = vec![
            C64::new(1.0, 0.0),
            C64::new(2.0, 0.0),
            C64::new(3.0, 0.0),
            C64::new(4.0, 0.0),
        ];
        let det = cmat_det(&a, 2);
        assert!((det.re - (-2.0)).abs() < 1e-10);
        assert!(det.im.abs() < 1e-10);
    }

    #[test]
    fn test_cmat_inv_identity() {
        // The identity matrix is its own inverse.
        let a = vec![C64::one(), C64::zero(), C64::zero(), C64::one()];
        let inv = cmat_inv(&a, 2).unwrap();
        assert!((inv[0].re - 1.0).abs() < 1e-10);
        assert!((inv[3].re - 1.0).abs() < 1e-10);
        assert!(inv[1].abs() < 1e-10);
        assert!(inv[2].abs() < 1e-10);
    }

    #[test]
    fn test_cmat_inv_roundtrip() {
        // a * a^{-1} should round-trip to the identity for a complex matrix.
        let a = vec![
            C64::new(2.0, 1.0),
            C64::new(1.0, 0.0),
            C64::new(0.0, 1.0),
            C64::new(3.0, 0.0),
        ];
        let inv = cmat_inv(&a, 2).unwrap();
        let prod = cmat_mul(&a, &inv, 2);
        assert!((prod[0].re - 1.0).abs() < 1e-8);
        assert!((prod[3].re - 1.0).abs() < 1e-8);
        assert!(prod[1].abs() < 1e-8);
        assert!(prod[2].abs() < 1e-8);
    }
815
    // --- Time-domain Granger causality tests ------------------------------

    #[test]
    fn test_gc_self_finite() {
        let train = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let gc = pairwise_granger_causality(&train, &train, 5, 3);
        assert!(gc.is_finite(), "self GC should be finite, got {gc}");
        assert!(gc >= 0.0, "GC should be non-negative, got {gc}");
    }

    #[test]
    fn test_gc_non_negative_typical() {
        let source = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let target = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let gc = pairwise_granger_causality(&source, &target, 5, 3);
        assert!(gc.is_finite(), "GC should be finite, got {gc}");
    }

    #[test]
    fn test_gc_too_short() {
        // 10 samples at bin_size 5 give 2 bins, and 2 <= 2 * order, so the
        // length guard returns 0.
        let a = make_train(&[1], 10);
        let b = make_train(&[2], 10);
        let gc = pairwise_granger_causality(&a, &b, 5, 5);
        assert_eq!(gc, 0.0, "too short → 0");
    }

    #[test]
    fn test_cond_gc_finite() {
        let source = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let target = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let cond = make_train(&[3, 13, 23, 33, 43, 53, 63, 73, 83, 93], 100);
        let gc = conditional_granger_causality(&source, &target, &cond, 5, 3);
        assert!(gc.is_finite(), "conditional GC should be finite");
    }

    #[test]
    fn test_cond_gc_too_short() {
        // Same length guard as the pairwise variant.
        let a = make_train(&[1], 10);
        let b = make_train(&[2], 10);
        let c = make_train(&[3], 10);
        assert_eq!(conditional_granger_causality(&a, &b, &c, 5, 5), 0.0);
    }
862
    // --- Spectral Granger causality tests ---------------------------------

    #[test]
    fn test_spectral_gc_shape() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (gc, d) = spectral_granger_causality(&trains, 5, 3, 16);
        // Output is d*d*n_freqs values for d channels.
        assert_eq!(d, 2);
        assert_eq!(gc.len(), 2 * 2 * 16);
    }

    #[test]
    fn test_spectral_gc_diagonal_zero() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (gc, _) = spectral_granger_causality(&trains, 5, 3, 16);
        // Layout is (i * d + j) * n_freqs + fi, so for d = 2 pair (0,0)
        // starts at 0 and pair (1,1) starts at 3 * 16.
        for fi in 0..16 {
            assert_eq!(gc[fi], 0.0, "GC[0,0] should be 0");
            assert_eq!(gc[3 * 16 + fi], 0.0, "GC[1,1] should be 0");
        }
    }

    #[test]
    fn test_spectral_gc_non_negative() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (gc, _) = spectral_granger_causality(&trains, 5, 3, 16);
        for &v in &gc {
            assert!(v >= 0.0, "spectral GC must be non-negative, got {v}");
        }
    }
898
    // --- PDC / DTF tests --------------------------------------------------

    #[test]
    fn test_pdc_shape() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (pdc, d) = partial_directed_coherence(&trains, 5, 3, 16);
        assert_eq!(d, 2);
        assert_eq!(pdc.len(), 2 * 2 * 16);
    }

    #[test]
    fn test_pdc_range() {
        // PDC is column-normalized, so every value must lie in [0, 1].
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (pdc, _) = partial_directed_coherence(&trains, 5, 3, 16);
        for &v in &pdc {
            assert!(
                (0.0..=1.0 + 1e-10).contains(&v),
                "PDC should be in [0,1], got {v}"
            );
        }
    }

    #[test]
    fn test_dtf_shape() {
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (dtf, d) = directed_transfer_function(&trains, 5, 3, 16);
        assert_eq!(d, 2);
        assert_eq!(dtf.len(), 2 * 2 * 16);
    }

    #[test]
    fn test_dtf_range() {
        // DTF is row-normalized, so every value must lie in [0, 1].
        let t1 = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let t2 = make_train(&[7, 17, 27, 37, 47, 57, 67, 77, 87, 97], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let (dtf, _) = directed_transfer_function(&trains, 5, 3, 16);
        for &v in &dtf {
            assert!(
                (0.0..=1.0 + 1e-10).contains(&v),
                "DTF should be in [0,1], got {v}"
            );
        }
    }
950
    // --- VAR fitting edge case --------------------------------------------

    #[test]
    fn test_var_too_short() {
        // Two samples with order 5 hit the early return: zero coefficients
        // and an identity residual covariance.
        let trains = vec![vec![1.0, 2.0]];
        let (beta, sigma) = var_coefficients(&trains, 5);
        assert!(beta.iter().all(|&v| v == 0.0), "too short → zero beta");
        assert!((sigma[0] - 1.0).abs() < 1e-10, "identity sigma");
    }
960}