Skip to main content

sc_neurocore_engine/analysis/
causality.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Granger causality and directed connectivity measures
8
9use rayon::prelude::*;
10use std::f64::consts::PI;
11
12use super::basic::bin_spike_train;
13
14// ── Small-matrix linear algebra (real) ──────────────────────────────
15
/// Solve the dense linear system `A x = b` by Gaussian elimination with
/// partial pivoting.
///
/// `a` is an `n × n` matrix in row-major order and `b` holds the `n`
/// right-hand-side values.
///
/// A column whose best pivot has magnitude below `1e-30` is skipped during
/// elimination. During back substitution, any unknown whose diagonal entry is
/// not above `1e-30` is set to `0.0`. A (near-)singular system therefore
/// yields a finite vector instead of NaN.
fn solve_linear(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let width = n + 1;

    // Augmented matrix [A | b]; row r occupies m[r*width .. (r+1)*width].
    let mut m: Vec<f64> = Vec::with_capacity(n * width);
    for r in 0..n {
        m.extend_from_slice(&a[r * n..(r + 1) * n]);
        m.push(b[r]);
    }

    // Forward elimination to upper-triangular form.
    for col in 0..n {
        // Row with the largest |entry| in this column; the first one wins ties.
        let pivot_row = (col + 1..n).fold(col, |best, r| {
            if m[r * width + col].abs() > m[best * width + col].abs() {
                r
            } else {
                best
            }
        });
        if pivot_row != col {
            let (upper, lower) = m.split_at_mut(pivot_row * width);
            upper[col * width..(col + 1) * width].swap_with_slice(&mut lower[..width]);
        }

        let pivot = m[col * width + col];
        if pivot.abs() < 1e-30 {
            // Numerically empty column: nothing to eliminate with.
            continue;
        }

        let (done, rest) = m.split_at_mut((col + 1) * width);
        let pivot_tail = &done[col * width + col..];
        for row in rest.chunks_exact_mut(width) {
            let factor = row[col] / pivot;
            for (dst, &src) in row[col..].iter_mut().zip(pivot_tail) {
                *dst -= factor * src;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let row = &m[i * width..(i + 1) * width];
        let rhs = (i + 1..n).fold(row[n], |acc, j| acc - row[j] * x[j]);
        let diag = row[i];
        x[i] = if diag.abs() > 1e-30 { rhs / diag } else { 0.0 };
    }
    x
}
79
80/// Solve A X = B where A is n×n and B is n×m. Returns X (n×m, row-major).
81fn solve_matrix(a: &[f64], b: &[f64], n: usize, m: usize) -> Vec<f64> {
82    let result = vec![0.0_f64; n * m];
83    (0..m).into_par_iter().for_each(|col| {
84        let rhs: Vec<f64> = (0..n).map(|i| b[i * m + col]).collect();
85        let x = solve_linear(a, &rhs, n);
86        // SAFETY: Each thread writes to unique indices based on col.
87        unsafe {
88            let ptr = result.as_ptr() as *mut f64;
89            for i in 0..n {
90                *ptr.add(i * m + col) = x[i];
91            }
92        }
93    });
94    result
95}
96
97// ── Small-matrix linear algebra (complex) ───────────────────────────
98
/// Minimal complex number `re + i·im` used by the frequency-domain VAR
/// measures in this module.
#[derive(Clone, Copy, Debug, PartialEq)]
struct C64 {
    re: f64,
    im: f64,
}

impl C64 {
    /// Build `re + i·im`.
    fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Additive identity `0 + 0i`.
    fn zero() -> Self {
        Self { re: 0.0, im: 0.0 }
    }

    /// Multiplicative identity `1 + 0i`.
    fn one() -> Self {
        Self { re: 1.0, im: 0.0 }
    }

    /// Squared modulus `re² + im²` (no square root).
    fn norm_sq(self) -> f64 {
        self.re * self.re + self.im * self.im
    }

    /// Modulus `|z|`.
    ///
    /// Uses `f64::hypot`, which avoids the intermediate overflow or underflow
    /// of `sqrt(re² + im²)` for components beyond roughly 1e±154. Without it,
    /// such values would make the `1e-30` pivot/singularity thresholds used
    /// elsewhere misfire.
    fn abs(self) -> f64 {
        self.re.hypot(self.im)
    }

    /// Complex conjugate `re − i·im`.
    fn conj(self) -> Self {
        Self {
            re: self.re,
            im: -self.im,
        }
    }
}

impl std::ops::Add for C64 {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self {
            re: self.re + rhs.re,
            im: self.im + rhs.im,
        }
    }
}

impl std::ops::Sub for C64 {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self {
        Self {
            re: self.re - rhs.re,
            im: self.im - rhs.im,
        }
    }
}

impl std::ops::Mul for C64 {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        Self {
            re: self.re * rhs.re - self.im * rhs.im,
            im: self.re * rhs.im + self.im * rhs.re,
        }
    }
}

/// Scale both components by a real factor.
impl std::ops::Mul<f64> for C64 {
    type Output = Self;
    fn mul(self, rhs: f64) -> Self {
        Self {
            re: self.re * rhs,
            im: self.im * rhs,
        }
    }
}

impl std::ops::AddAssign for C64 {
    fn add_assign(&mut self, rhs: Self) {
        self.re += rhs.re;
        self.im += rhs.im;
    }
}

impl std::ops::SubAssign for C64 {
    fn sub_assign(&mut self, rhs: Self) {
        self.re -= rhs.re;
        self.im -= rhs.im;
    }
}
182
183/// Complex matrix multiply: C = A * B, all d×d row-major.
184fn cmat_mul(a: &[C64], b: &[C64], d: usize) -> Vec<C64> {
185    let mut c = vec![C64::zero(); d * d];
186    for i in 0..d {
187        for j in 0..d {
188            let mut s = C64::zero();
189            for k in 0..d {
190                s += a[i * d + k] * b[k * d + j];
191            }
192            c[i * d + j] = s;
193        }
194    }
195    c
196}
197
198/// Complex matrix inverse via Gauss-Jordan, d×d row-major.
199fn cmat_inv(a: &[C64], d: usize) -> Option<Vec<C64>> {
200    let mut aug = vec![C64::zero(); d * 2 * d];
201    for i in 0..d {
202        for j in 0..d {
203            aug[i * 2 * d + j] = a[i * d + j];
204        }
205        aug[i * 2 * d + d + i] = C64::one();
206    }
207    let w = 2 * d;
208    for col in 0..d {
209        // Pivot
210        let mut max_row = col;
211        let mut max_val = aug[col * w + col].abs();
212        for row in (col + 1)..d {
213            let v = aug[row * w + col].abs();
214            if v > max_val {
215                max_val = v;
216                max_row = row;
217            }
218        }
219        if max_val < 1e-30 {
220            return None;
221        }
222        if max_row != col {
223            for j in 0..w {
224                aug.swap(col * w + j, max_row * w + j);
225            }
226        }
227        let pivot = aug[col * w + col];
228        let inv_pivot = pivot.conj() * (1.0 / pivot.norm_sq());
229        for j in 0..w {
230            aug[col * w + j] = aug[col * w + j] * inv_pivot;
231        }
232        for row in 0..d {
233            if row == col {
234                continue;
235            }
236            let factor = aug[row * w + col];
237            for j in 0..w {
238                let sub = factor * aug[col * w + j];
239                aug[row * w + j] -= sub;
240            }
241        }
242    }
243    let mut result = vec![C64::zero(); d * d];
244    for i in 0..d {
245        for j in 0..d {
246            result[i * d + j] = aug[i * w + d + j];
247        }
248    }
249    Some(result)
250}
251
252/// Complex matrix determinant, d×d.
253fn cmat_det(a: &[C64], d: usize) -> C64 {
254    if d == 1 {
255        return a[0];
256    }
257    if d == 2 {
258        return a[0] * a[3] - a[1] * a[2];
259    }
260    // LU-based via Gaussian elimination
261    let mut m = a.to_vec();
262    let mut det = C64::one();
263    for col in 0..d {
264        let mut max_row = col;
265        let mut max_val = m[col * d + col].abs();
266        for row in (col + 1)..d {
267            let v = m[row * d + col].abs();
268            if v > max_val {
269                max_val = v;
270                max_row = row;
271            }
272        }
273        if max_val < 1e-30 {
274            return C64::zero();
275        }
276        if max_row != col {
277            for j in 0..d {
278                m.swap(col * d + j, max_row * d + j);
279            }
280            det = det * (-1.0);
281        }
282        det = det * m[col * d + col];
283        let pivot = m[col * d + col];
284        let inv_pivot = pivot.conj() * (1.0 / pivot.norm_sq());
285        for row in (col + 1)..d {
286            let factor = m[row * d + col] * inv_pivot;
287            for j in col..d {
288                let sub = factor * m[col * d + j];
289                m[row * d + j] -= sub;
290            }
291        }
292    }
293    det
294}
295
296/// Conjugate transpose of d×d complex matrix.
297fn cmat_conj_t(a: &[C64], d: usize) -> Vec<C64> {
298    let mut r = vec![C64::zero(); d * d];
299    for i in 0..d {
300        for j in 0..d {
301            r[j * d + i] = a[i * d + j].conj();
302        }
303    }
304    r
305}
306
307// ── VAR model ───────────────────────────────────────────────────────
308
309/// Fit VAR(order) model. Returns (beta [order*d × d, row-major], sigma [d×d, row-major]).
310fn var_coefficients(trains_binned: &[Vec<f64>], order: usize) -> (Vec<f64>, Vec<f64>) {
311    let d = trains_binned.len();
312    let t = if d > 0 { trains_binned[0].len() } else { 0 };
313    if t <= order + 1 || d == 0 {
314        return (vec![0.0; order * d * d], identity_flat(d));
315    }
316    let n_pts = t - order;
317    let x_cols = order * d;
318
319    // Build y_cols: (d × n_pts) column-major
320    let mut y_cols = vec![vec![0.0_f64; n_pts]; d];
321    for ch in 0..d {
322        for i in 0..n_pts {
323            y_cols[ch][i] = trains_binned[ch][order + i];
324        }
325    }
326
327    // Build x_cols_data: (x_cols × n_pts) column-major
328    let mut x_cols_data = vec![vec![0.0_f64; n_pts]; x_cols];
329    for i in 0..n_pts {
330        for k in 0..order {
331            for ch in 0..d {
332                x_cols_data[k * d + ch][i] = trains_binned[ch][order - k - 1 + i];
333            }
334        }
335    }
336
337    // X^T X + reg
338    let mut xtx = vec![0.0_f64; x_cols * x_cols];
339    xtx.par_chunks_exact_mut(x_cols)
340        .enumerate()
341        .for_each(|(i, row)| {
342            for j in 0..=i {
343                let dot = crate::simd::dot_f64_dispatch(&x_cols_data[i], &x_cols_data[j]);
344                row[j] = dot + if i == j { 1e-8 } else { 0.0 };
345            }
346        });
347    // Mirror the matrix (serial, small overhead)
348    for i in 0..x_cols {
349        for j in (i + 1)..x_cols {
350            xtx[i * x_cols + j] = xtx[j * x_cols + i];
351        }
352    }
353
354    // X^T Y
355    let mut xty = vec![0.0_f64; x_cols * d];
356    xty.par_chunks_exact_mut(d)
357        .enumerate()
358        .for_each(|(i, row)| {
359            for j in 0..d {
360                row[j] = crate::simd::dot_f64_dispatch(&x_cols_data[i], &y_cols[j]);
361            }
362        });
363
364    // beta = (X^T X)^{-1} X^T Y
365    let beta = solve_matrix(&xtx, &xty, x_cols, d);
366
367    // Residuals Sigma = (1/N) (Y - X beta)^T (Y - X beta)
368    let mut sigma = vec![0.0_f64; d * d];
369    let n_norm = n_pts.max(1) as f64;
370
371    // Precompute residuals: res_cols = y_cols - X_cols * beta (parallel)
372    let res_cols: Vec<Vec<f64>> = (0..d)
373        .into_par_iter()
374        .map(|j| {
375            let mut res = vec![0.0_f64; n_pts];
376            for p in 0..n_pts {
377                let mut r = y_cols[j][p];
378                for c in 0..x_cols {
379                    r -= x_cols_data[c][p] * beta[c * d + j];
380                }
381                res[p] = r;
382            }
383            res
384        })
385        .collect();
386
387    for i in 0..d {
388        for j in 0..=i {
389            let dot = crate::simd::dot_f64_dispatch(&res_cols[i], &res_cols[j]);
390            let val = dot / n_norm;
391            sigma[i * d + j] = val;
392            sigma[j * d + i] = val;
393        }
394    }
395
396    (beta, sigma)
397}
398
/// The `d × d` identity matrix as a flat row-major vector.
///
/// Diagonal positions are exactly the flat indices that are multiples of
/// `d + 1`.
fn identity_flat(d: usize) -> Vec<f64> {
    (0..d * d)
        .map(|idx| if idx % (d + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
406
407/// Sum of squared errors for OLS regression: SSE = ||y - X beta||^2.
408fn sse_ols(x: &[f64], y: &[f64], n_pts: usize, x_cols: usize) -> f64 {
409    // X^T X + reg
410    let mut xtx = vec![0.0_f64; x_cols * x_cols];
411    for i in 0..x_cols {
412        for j in 0..x_cols {
413            let mut s = 0.0;
414            for p in 0..n_pts {
415                s += x[p * x_cols + i] * x[p * x_cols + j];
416            }
417            xtx[i * x_cols + j] = s + if i == j { 1e-8 } else { 0.0 };
418        }
419    }
420    // X^T y
421    let mut xty = vec![0.0_f64; x_cols];
422    for i in 0..x_cols {
423        let mut s = 0.0;
424        for p in 0..n_pts {
425            s += x[p * x_cols + i] * y[p];
426        }
427        xty[i] = s;
428    }
429    let beta = solve_linear(&xtx, &xty, x_cols);
430    let mut sse = 0.0_f64;
431    for p in 0..n_pts {
432        let mut pred = 0.0;
433        for c in 0..x_cols {
434            pred += x[p * x_cols + c] * beta[c];
435        }
436        let r = y[p] - pred;
437        sse += r * r;
438    }
439    sse
440}
441
442// ── Public API ──────────────────────────────────────────────────────
443
444/// Pairwise Granger causality (Granger 1969).
445/// Returns log-likelihood ratio. Positive = source Granger-causes target.
446pub fn pairwise_granger_causality(
447    source: &[i32],
448    target: &[i32],
449    bin_size: usize,
450    order: usize,
451) -> f64 {
452    let cs: Vec<f64> = bin_spike_train(source, bin_size)
453        .iter()
454        .map(|&v| v as f64)
455        .collect();
456    let ct: Vec<f64> = bin_spike_train(target, bin_size)
457        .iter()
458        .map(|&v| v as f64)
459        .collect();
460    let n = cs.len().min(ct.len());
461    if n <= 2 * order {
462        return 0.0;
463    }
464
465    let n_pts = n - order;
466    let y: Vec<f64> = ct[order..n].to_vec();
467
468    // Restricted model: target past only
469    let r_cols = order;
470    let mut x_r = vec![0.0_f64; n_pts * r_cols];
471    for p in 0..n_pts {
472        for k in 0..order {
473            x_r[p * r_cols + k] = ct[order - k - 1 + p];
474        }
475    }
476    let sse_r = sse_ols(&x_r, &y, n_pts, r_cols);
477
478    // Full model: target past + source past
479    let f_cols = 2 * order;
480    let mut x_f = vec![0.0_f64; n_pts * f_cols];
481    for p in 0..n_pts {
482        for k in 0..order {
483            x_f[p * f_cols + k] = ct[order - k - 1 + p];
484            x_f[p * f_cols + order + k] = cs[order - k - 1 + p];
485        }
486    }
487    let sse_f = sse_ols(&x_f, &y, n_pts, f_cols);
488
489    if sse_f <= 0.0 {
490        return 0.0;
491    }
492    (sse_r.max(1e-30) / sse_f.max(1e-30)).ln()
493}
494
495/// Conditional Granger causality (Geweke 1984).
496/// Tests if source Granger-causes target controlling for condition.
497pub fn conditional_granger_causality(
498    source: &[i32],
499    target: &[i32],
500    condition: &[i32],
501    bin_size: usize,
502    order: usize,
503) -> f64 {
504    let cs: Vec<f64> = bin_spike_train(source, bin_size)
505        .iter()
506        .map(|&v| v as f64)
507        .collect();
508    let ct: Vec<f64> = bin_spike_train(target, bin_size)
509        .iter()
510        .map(|&v| v as f64)
511        .collect();
512    let cc: Vec<f64> = bin_spike_train(condition, bin_size)
513        .iter()
514        .map(|&v| v as f64)
515        .collect();
516    let n = cs.len().min(ct.len()).min(cc.len());
517    if n <= 2 * order {
518        return 0.0;
519    }
520
521    let n_pts = n - order;
522    let y: Vec<f64> = ct[order..n].to_vec();
523
524    // Conditioned model: target + condition past
525    let c_cols = 2 * order;
526    let mut x_c = vec![0.0_f64; n_pts * c_cols];
527    for p in 0..n_pts {
528        for k in 0..order {
529            x_c[p * c_cols + k] = ct[order - k - 1 + p];
530            x_c[p * c_cols + order + k] = cc[order - k - 1 + p];
531        }
532    }
533    let sse_c = sse_ols(&x_c, &y, n_pts, c_cols);
534
535    // Full model: target + condition + source past
536    let f_cols = 3 * order;
537    let mut x_f = vec![0.0_f64; n_pts * f_cols];
538    for p in 0..n_pts {
539        for k in 0..order {
540            x_f[p * f_cols + k] = ct[order - k - 1 + p];
541            x_f[p * f_cols + order + k] = cc[order - k - 1 + p];
542            x_f[p * f_cols + 2 * order + k] = cs[order - k - 1 + p];
543        }
544    }
545    let sse_f = sse_ols(&x_f, &y, n_pts, f_cols);
546
547    if sse_f <= 0.0 {
548        return 0.0;
549    }
550    (sse_c.max(1e-30) / sse_f.max(1e-30)).ln()
551}
552
553/// Spectral Granger causality (Geweke 1982).
554/// Returns (d × d × n_freqs) as flat Vec, row-major in [i][j][f] order.
555pub fn spectral_granger_causality(
556    trains: &[&[i32]],
557    bin_size: usize,
558    order: usize,
559    n_freqs: usize,
560) -> (Vec<f64>, usize) {
561    let binned: Vec<Vec<f64>> = trains
562        .iter()
563        .map(|t| {
564            bin_spike_train(t, bin_size)
565                .iter()
566                .map(|&v| v as f64)
567                .collect()
568        })
569        .collect();
570    let d = binned.len();
571    let (beta, sigma) = var_coefficients(&binned, order);
572
573    let mut gc = vec![0.0_f64; d * d * n_freqs];
574
575    for fi in 0..n_freqs {
576        let f = fi as f64 / (2 * n_freqs) as f64; // [0, 0.5)
577
578        // A(f) = I - sum_k coeff_k * exp(-2πi f (k+1))
579        let mut a_f = vec![C64::zero(); d * d];
580        for i in 0..d {
581            a_f[i * d + i] = C64::one();
582        }
583        for k in 0..order {
584            let angle = -2.0 * PI * f * (k + 1) as f64;
585            let exp_val = C64::new(angle.cos(), angle.sin());
586            for i in 0..d {
587                for j in 0..d {
588                    // beta is (order*d × d), block k is rows [k*d..(k+1)*d], transposed
589                    let coeff = beta[(k * d + j) * d + i]; // beta[k*d+j, i] → coeff_block.T[i,j]
590                    a_f[i * d + j] -= C64::new(coeff, 0.0) * exp_val;
591                }
592            }
593        }
594
595        let det = cmat_det(&a_f, d);
596        if det.abs() < 1e-30 {
597            continue;
598        }
599        let h = match cmat_inv(&a_f, d) {
600            Some(inv) => inv,
601            None => continue,
602        };
603
604        // S = H Σ H*
605        let sigma_c: Vec<C64> = sigma.iter().map(|&v| C64::new(v, 0.0)).collect();
606        let h_conj_t = cmat_conj_t(&h, d);
607        let tmp = cmat_mul(&h, &sigma_c, d);
608        let s = cmat_mul(&tmp, &h_conj_t, d);
609
610        for i in 0..d {
611            for j in 0..d {
612                if i == j {
613                    continue;
614                }
615                let s_ii = s[i * d + i].abs();
616                if s_ii > 1e-30 {
617                    let h_ij_sq = h[i * d + j].norm_sq();
618                    let reduced = s_ii - sigma[j * d + j] * h_ij_sq;
619                    if reduced > 0.0 && reduced < s_ii {
620                        gc[(i * d + j) * n_freqs + fi] = (s_ii / reduced).ln().max(0.0);
621                    }
622                }
623            }
624        }
625    }
626    (gc, d)
627}
628
629/// Partial directed coherence (Baccala & Sameshima 2001).
630/// Returns (d × d × n_freqs) flat Vec.
631pub fn partial_directed_coherence(
632    trains: &[&[i32]],
633    bin_size: usize,
634    order: usize,
635    n_freqs: usize,
636) -> (Vec<f64>, usize) {
637    let binned: Vec<Vec<f64>> = trains
638        .iter()
639        .map(|t| {
640            bin_spike_train(t, bin_size)
641                .iter()
642                .map(|&v| v as f64)
643                .collect()
644        })
645        .collect();
646    let d = binned.len();
647    let (beta, _) = var_coefficients(&binned, order);
648
649    let mut pdc = vec![0.0_f64; d * d * n_freqs];
650
651    for fi in 0..n_freqs {
652        let f = fi as f64 / (2 * n_freqs) as f64;
653
654        let mut a_f = vec![C64::zero(); d * d];
655        for i in 0..d {
656            a_f[i * d + i] = C64::one();
657        }
658        for k in 0..order {
659            let angle = -2.0 * PI * f * (k + 1) as f64;
660            let exp_val = C64::new(angle.cos(), angle.sin());
661            for i in 0..d {
662                for j in 0..d {
663                    let coeff = beta[(k * d + j) * d + i];
664                    a_f[i * d + j] -= C64::new(coeff, 0.0) * exp_val;
665                }
666            }
667        }
668
669        for j in 0..d {
670            let norm: f64 = (0..d).map(|i| a_f[i * d + j].norm_sq()).sum::<f64>().sqrt();
671            if norm > 0.0 {
672                for i in 0..d {
673                    pdc[(i * d + j) * n_freqs + fi] = a_f[i * d + j].abs() / norm;
674                }
675            }
676        }
677    }
678    (pdc, d)
679}
680
681/// Directed transfer function (Kaminski & Blinowska 1991).
682/// Returns (d × d × n_freqs) flat Vec.
683pub fn directed_transfer_function(
684    trains: &[&[i32]],
685    bin_size: usize,
686    order: usize,
687    n_freqs: usize,
688) -> (Vec<f64>, usize) {
689    let binned: Vec<Vec<f64>> = trains
690        .iter()
691        .map(|t| {
692            bin_spike_train(t, bin_size)
693                .iter()
694                .map(|&v| v as f64)
695                .collect()
696        })
697        .collect();
698    let d = binned.len();
699    let (beta, _sigma) = var_coefficients(&binned, order);
700
701    let mut dtf = vec![0.0_f64; d * d * n_freqs];
702
703    for fi in 0..n_freqs {
704        let f = fi as f64 / (2 * n_freqs) as f64;
705
706        let mut a_f = vec![C64::zero(); d * d];
707        for i in 0..d {
708            a_f[i * d + i] = C64::one();
709        }
710        for k in 0..order {
711            let angle = -2.0 * PI * f * (k + 1) as f64;
712            let exp_val = C64::new(angle.cos(), angle.sin());
713            for i in 0..d {
714                for j in 0..d {
715                    let coeff = beta[(k * d + j) * d + i];
716                    a_f[i * d + j] -= C64::new(coeff, 0.0) * exp_val;
717                }
718            }
719        }
720
721        let det = cmat_det(&a_f, d);
722        if det.abs() < 1e-30 {
723            continue;
724        }
725        let h = match cmat_inv(&a_f, d) {
726            Some(inv) => inv,
727            None => continue,
728        };
729
730        for i in 0..d {
731            let norm: f64 = (0..d).map(|j| h[i * d + j].norm_sq()).sum::<f64>().sqrt();
732            if norm > 0.0 {
733                for j in 0..d {
734                    dtf[(i * d + j) * n_freqs + fi] = h[i * d + j].abs() / norm;
735                }
736            }
737        }
738    }
739    (dtf, d)
740}
741
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of frequency bins used by the spectral tests.
    const N_FREQS: usize = 16;

    /// A 0/1 spike train of length `len` with a spike at every index in `spikes`.
    fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
        let mut train = vec![0i32; len];
        for idx in spikes.iter().copied() {
            train[idx] = 1;
        }
        train
    }

    /// 100-sample train with ten spikes spaced 10 apart, the first at `offset`.
    fn periodic(offset: usize) -> Vec<i32> {
        let spikes: Vec<usize> = (0..10).map(|k| offset + 10 * k).collect();
        make_train(&spikes, 100)
    }

    /// Two periodic trains offset by two samples, shared by the multichannel tests.
    fn train_pair() -> (Vec<i32>, Vec<i32>) {
        (periodic(5), periodic(7))
    }

    /// Assert every value lies in `[0, 1]` (with a tiny upper tolerance).
    fn assert_unit_range(values: &[f64], label: &str) {
        for &v in values {
            assert!(
                (0.0..=1.0 + 1e-10).contains(&v),
                "{label} should be in [0,1], got {v}"
            );
        }
    }

    // ── linear algebra helpers ──────────────────────────────────────

    #[test]
    fn test_solve_linear_identity() {
        // I x = b → x = b
        let x = solve_linear(&[1.0, 0.0, 0.0, 1.0], &[3.0, 7.0], 2);
        assert!((x[0] - 3.0).abs() < 1e-10);
        assert!((x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_solve_linear_2x2() {
        // [2 1; 1 3] x = [5; 10] → x = [1, 3]
        let x = solve_linear(&[2.0, 1.0, 1.0, 3.0], &[5.0, 10.0], 2);
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_cmat_det_2x2() {
        let a: Vec<C64> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&re| C64::new(re, 0.0))
            .collect();
        let det = cmat_det(&a, 2);
        assert!((det.re - (-2.0)).abs() < 1e-10);
        assert!(det.im.abs() < 1e-10);
    }

    #[test]
    fn test_cmat_inv_identity() {
        let a = [C64::one(), C64::zero(), C64::zero(), C64::one()];
        let inv = cmat_inv(&a, 2).unwrap();
        assert!((inv[0].re - 1.0).abs() < 1e-10);
        assert!((inv[3].re - 1.0).abs() < 1e-10);
        assert!(inv[1].abs() < 1e-10);
        assert!(inv[2].abs() < 1e-10);
    }

    #[test]
    fn test_cmat_inv_roundtrip() {
        let a = [
            C64::new(2.0, 1.0),
            C64::new(1.0, 0.0),
            C64::new(0.0, 1.0),
            C64::new(3.0, 0.0),
        ];
        let inv = cmat_inv(&a, 2).unwrap();
        let prod = cmat_mul(&a, &inv, 2);
        // Should be identity
        assert!((prod[0].re - 1.0).abs() < 1e-8);
        assert!((prod[3].re - 1.0).abs() < 1e-8);
        assert!(prod[1].abs() < 1e-8);
        assert!(prod[2].abs() < 1e-8);
    }

    // ── pairwise_granger_causality ──────────────────────────────────

    #[test]
    fn test_gc_self_finite() {
        let train = periodic(5);
        let gc = pairwise_granger_causality(&train, &train, 5, 3);
        // When source == target, duplicate regressors can reduce SSE via regularisation
        assert!(gc.is_finite(), "self GC should be finite, got {gc}");
        assert!(gc >= 0.0, "GC should be non-negative, got {gc}");
    }

    #[test]
    fn test_gc_non_negative_typical() {
        let (source, target) = train_pair();
        let gc = pairwise_granger_causality(&source, &target, 5, 3);
        // Just check it returns a finite value
        assert!(gc.is_finite(), "GC should be finite, got {gc}");
    }

    #[test]
    fn test_gc_too_short() {
        let a = make_train(&[1], 10);
        let b = make_train(&[2], 10);
        let gc = pairwise_granger_causality(&a, &b, 5, 5);
        assert_eq!(gc, 0.0, "too short → 0");
    }

    // ── conditional_granger_causality ────────────────────────────────

    #[test]
    fn test_cond_gc_finite() {
        let (source, target) = train_pair();
        let cond = periodic(3);
        let gc = conditional_granger_causality(&source, &target, &cond, 5, 3);
        assert!(gc.is_finite(), "conditional GC should be finite");
    }

    #[test]
    fn test_cond_gc_too_short() {
        let a = make_train(&[1], 10);
        let b = make_train(&[2], 10);
        let c = make_train(&[3], 10);
        assert_eq!(conditional_granger_causality(&a, &b, &c, 5, 5), 0.0);
    }

    // ── spectral_granger_causality ──────────────────────────────────

    #[test]
    fn test_spectral_gc_shape() {
        let (t1, t2) = train_pair();
        let trains: Vec<&[i32]> = vec![t1.as_slice(), t2.as_slice()];
        let (gc, d) = spectral_granger_causality(&trains, 5, 3, N_FREQS);
        assert_eq!(d, 2);
        assert_eq!(gc.len(), 2 * 2 * N_FREQS);
    }

    #[test]
    fn test_spectral_gc_diagonal_zero() {
        let (t1, t2) = train_pair();
        let trains: Vec<&[i32]> = vec![t1.as_slice(), t2.as_slice()];
        let (gc, _) = spectral_granger_causality(&trains, 5, 3, N_FREQS);
        // Diagonal entries (i==j) should be 0
        for fi in 0..N_FREQS {
            assert_eq!(gc[fi], 0.0, "GC[0,0] should be 0");
            assert_eq!(gc[3 * N_FREQS + fi], 0.0, "GC[1,1] should be 0");
        }
    }

    #[test]
    fn test_spectral_gc_non_negative() {
        let (t1, t2) = train_pair();
        let trains: Vec<&[i32]> = vec![t1.as_slice(), t2.as_slice()];
        let (gc, _) = spectral_granger_causality(&trains, 5, 3, N_FREQS);
        for &v in &gc {
            assert!(v >= 0.0, "spectral GC must be non-negative, got {v}");
        }
    }

    // ── partial_directed_coherence ──────────────────────────────────

    #[test]
    fn test_pdc_shape() {
        let (t1, t2) = train_pair();
        let trains: Vec<&[i32]> = vec![t1.as_slice(), t2.as_slice()];
        let (pdc, d) = partial_directed_coherence(&trains, 5, 3, N_FREQS);
        assert_eq!(d, 2);
        assert_eq!(pdc.len(), 2 * 2 * N_FREQS);
    }

    #[test]
    fn test_pdc_range() {
        let (t1, t2) = train_pair();
        let trains: Vec<&[i32]> = vec![t1.as_slice(), t2.as_slice()];
        let (pdc, _) = partial_directed_coherence(&trains, 5, 3, N_FREQS);
        assert_unit_range(&pdc, "PDC");
    }

    // ── directed_transfer_function ──────────────────────────────────

    #[test]
    fn test_dtf_shape() {
        let (t1, t2) = train_pair();
        let trains: Vec<&[i32]> = vec![t1.as_slice(), t2.as_slice()];
        let (dtf, d) = directed_transfer_function(&trains, 5, 3, N_FREQS);
        assert_eq!(d, 2);
        assert_eq!(dtf.len(), 2 * 2 * N_FREQS);
    }

    #[test]
    fn test_dtf_range() {
        let (t1, t2) = train_pair();
        let trains: Vec<&[i32]> = vec![t1.as_slice(), t2.as_slice()];
        let (dtf, _) = directed_transfer_function(&trains, 5, 3, N_FREQS);
        assert_unit_range(&dtf, "DTF");
    }

    // ── var_coefficients ────────────────────────────────────────────

    #[test]
    fn test_var_too_short() {
        let trains = vec![vec![1.0, 2.0]];
        let (beta, sigma) = var_coefficients(&trains, 5);
        assert!(beta.iter().all(|&v| v == 0.0), "too short → zero beta");
        assert!((sigma[0] - 1.0).abs() < 1e-10, "identity sigma");
    }
}