Skip to main content

sc_neurocore_engine/analysis/
correlation.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Cross-correlation, synchrony, and covariance measures
8
9use rayon::prelude::*;
10use rustfft::{num_complex::Complex, FftPlanner};
11
12use super::basic::{bin_spike_train, spike_times};
13
14/// Cross-correlogram between two binary spike trains.
15/// Returns (correlation, lags_ms).
16pub fn cross_correlation(
17    train_a: &[i32],
18    train_b: &[i32],
19    max_lag_ms: f64,
20    dt: f64,
21) -> (Vec<f64>, Vec<f64>) {
22    let max_lag = (max_lag_ms / (dt * 1000.0)) as isize;
23    let n = train_a.len().min(train_b.len());
24    if n == 0 {
25        return (vec![], vec![]);
26    }
27
28    let mean_a: f64 = train_a[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
29    let mean_b: f64 = train_b[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
30    let a: Vec<f64> = train_a[..n].iter().map(|&v| v as f64 - mean_a).collect();
31    let b: Vec<f64> = train_b[..n].iter().map(|&v| v as f64 - mean_b).collect();
32
33    let norm = (a.iter().map(|x| x * x).sum::<f64>() * b.iter().map(|x| x * x).sum::<f64>()).sqrt();
34
35    let n_lags = (2 * max_lag + 1) as usize;
36    let mut cc = vec![0.0_f64; n_lags];
37    let mut lags_ms = Vec::with_capacity(n_lags);
38    for l in -max_lag..=max_lag {
39        lags_ms.push(l as f64 * dt * 1000.0);
40    }
41
42    if norm == 0.0 {
43        return (cc, lags_ms);
44    }
45
46    for (i, lag) in (-max_lag..=max_lag).enumerate() {
47        let sum = if lag >= 0 {
48            let l = lag as usize;
49            crate::simd::dot_f64_dispatch(&a[..n - l], &b[l..n])
50        } else {
51            let l = (-lag) as usize;
52            crate::simd::dot_f64_dispatch(&a[l..n], &b[..n - l])
53        };
54        cc[i] = sum / norm;
55    }
56
57    (cc, lags_ms)
58}
59
/// Pairwise Pearson correlation matrix across neurons.
///
/// All trains are truncated to the shortest length. Pairs involving a
/// zero-variance train keep 0.0 off-diagonal entries; the diagonal is
/// always 1.0. `dt` is accepted for API symmetry but unused.
pub fn pairwise_correlation(trains: &[&[i32]], dt: f64) -> Vec<Vec<f64>> {
    let _ = dt;
    let n_neurons = trains.len();
    if n_neurons == 0 {
        return vec![vec![]];
    }
    let len = trains.iter().map(|t| t.len()).min().unwrap_or(0);
    if len == 0 {
        return vec![vec![0.0; n_neurons]; n_neurons];
    }

    // Truncate and promote to f64 once.
    let rows: Vec<Vec<f64>> = trains
        .iter()
        .map(|t| t[..len].iter().map(|&v| f64::from(v)).collect())
        .collect();

    let mu: Vec<f64> = rows
        .iter()
        .map(|r| r.iter().sum::<f64>() / len as f64)
        .collect();
    // Population standard deviation (divide by N, not N-1).
    let sigma: Vec<f64> = rows
        .iter()
        .zip(&mu)
        .map(|(r, m)| (r.iter().map(|v| (v - m) * (v - m)).sum::<f64>() / len as f64).sqrt())
        .collect();

    let mut out = vec![vec![0.0_f64; n_neurons]; n_neurons];
    for i in 0..n_neurons {
        out[i][i] = 1.0;
        for j in (i + 1)..n_neurons {
            if sigma[i] > 0.0 && sigma[j] > 0.0 {
                let cov = rows[i]
                    .iter()
                    .zip(rows[j].iter())
                    .map(|(a, b)| (a - mu[i]) * (b - mu[j]))
                    .sum::<f64>()
                    / len as f64;
                let r = cov / (sigma[i] * sigma[j]);
                out[i][j] = r;
                out[j][i] = r;
            }
        }
    }
    out
}
106
107/// Event synchronisation (Quian Quiroga et al. 2002).
108/// Returns synchrony score in [0, 1].
109pub fn event_synchronization(train_a: &[i32], train_b: &[i32], dt: f64, tau_ms: f64) -> f64 {
110    let ta = spike_times(train_a, dt);
111    let tb = spike_times(train_b, dt);
112    let na = ta.len();
113    let nb = tb.len();
114    if na == 0 || nb == 0 {
115        return 0.0;
116    }
117    let tau = tau_ms / 1000.0;
118    let mut count = 0_usize;
119    for &ti in &ta {
120        for &tj in &tb {
121            if (ti - tj).abs() < tau {
122                count += 1;
123            }
124        }
125    }
126    count as f64 / (na as f64 * nb as f64).sqrt()
127}
128
/// Magnitude-squared coherence between two binary spike trains.
/// Returns (coherence, freqs_hz).
///
/// Both trains are truncated to the common length `n`, mean-subtracted,
/// and transformed with a single full-length FFT. Coherence is reported
/// for the non-negative frequency bins `0..=n/2`, with
/// `freqs_hz[i] = i / (n * dt)`.
///
/// NOTE(review): with a single segment and no spectral averaging,
/// `|Pab|^2 == Paa * Pbb` holds algebraically, so every bin where both
/// trains have non-zero power evaluates to exactly 1.0 (the unit test
/// below relies on this). If a discriminative coherence estimate is
/// intended, Welch-style segment averaging would be required — confirm.
pub fn spike_train_coherence(train_a: &[i32], train_b: &[i32], dt: f64) -> (Vec<f64>, Vec<f64>) {
    let n = train_a.len().min(train_b.len());
    if n < 2 {
        return (vec![], vec![]);
    }

    let mean_a: f64 = train_a[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
    let mean_b: f64 = train_b[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;

    let mut planner = FftPlanner::<f64>::new();
    let fft = planner.plan_fft_forward(n);

    // Mean subtraction removes the DC component before the transform.
    let mut buf_a: Vec<Complex<f64>> = train_a[..n]
        .iter()
        .map(|&v| Complex::new(v as f64 - mean_a, 0.0))
        .collect();
    let mut buf_b: Vec<Complex<f64>> = train_b[..n]
        .iter()
        .map(|&v| Complex::new(v as f64 - mean_b, 0.0))
        .collect();

    fft.process(&mut buf_a);
    fft.process(&mut buf_b);

    // Only positive frequencies (rfft equivalent): indices 0..=n/2
    let n_freqs = n / 2 + 1;
    let mut coh = Vec::with_capacity(n_freqs);
    let mut freqs = Vec::with_capacity(n_freqs);

    for i in 0..n_freqs {
        let fa = buf_a[i];
        let fb = buf_b[i];
        // Cross-spectrum and the two auto-spectra for this bin.
        let pab = fa * fb.conj();
        let paa = fa.norm_sqr();
        let pbb = fb.norm_sqr();
        let denom = paa * pbb;
        if denom == 0.0 {
            // Zero power in either train at this bin → define coherence as 0.
            coh.push(0.0);
        } else {
            coh.push(pab.norm_sqr() / denom);
        }
        freqs.push(i as f64 / (n as f64 * dt));
    }

    (coh, freqs)
}
177
178/// Spike Time Tiling Coefficient (Cutts & Eglen 2014).
179pub fn spike_time_tiling_coefficient(
180    train_a: &[i32],
181    train_b: &[i32],
182    dt: f64,
183    delta_ms: f64,
184) -> f64 {
185    let delta = delta_ms / 1000.0;
186    let ta = spike_times(train_a, dt);
187    let tb = spike_times(train_b, dt);
188    let duration = train_a.len().max(train_b.len()) as f64 * dt;
189
190    if ta.is_empty() || tb.is_empty() {
191        return 0.0;
192    }
193
194    let pa = coincidence_fraction(&ta, &tb, delta);
195    let pb = coincidence_fraction(&tb, &ta, delta);
196    let ta_frac = tile_fraction(&ta, delta, duration);
197    let tb_frac = tile_fraction(&tb, delta, duration);
198
199    0.5 * (sttc_term(pa, tb_frac) + sttc_term(pb, ta_frac))
200}
201
/// Fraction of `[0, duration]` covered by the union of ±delta windows
/// around each spike time, clipped to [0, 1].
fn tile_fraction(times: &[f64], delta: f64, duration: f64) -> f64 {
    if times.is_empty() || duration <= 0.0 {
        return 0.0;
    }
    let mut windows: Vec<(f64, f64)> = times.iter().map(|&t| (t - delta, t + delta)).collect();
    windows.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap());

    // Merge overlapping windows into disjoint intervals.
    let mut merged: Vec<(f64, f64)> = Vec::with_capacity(windows.len());
    for (lo, hi) in windows {
        match merged.last_mut() {
            Some(prev) if lo <= prev.1 => prev.1 = prev.1.max(hi),
            _ => merged.push((lo, hi)),
        }
    }

    // Sum interval lengths clipped to the recording bounds.
    let covered: f64 = merged
        .into_iter()
        .map(|(lo, hi)| (hi.min(duration) - lo.max(0.0)).max(0.0))
        .sum();

    (covered / duration).min(1.0)
}
234
/// Fraction of `times_ref` spikes that have at least one spike in
/// `times_target` within ±delta. Returns 0.0 for an empty reference.
fn coincidence_fraction(times_ref: &[f64], times_target: &[f64], delta: f64) -> f64 {
    if times_ref.is_empty() {
        return 0.0;
    }
    let mut matched = 0_usize;
    for &t in times_ref {
        if times_target.iter().any(|&u| (u - t).abs() <= delta) {
            matched += 1;
        }
    }
    matched as f64 / times_ref.len() as f64
}
245
/// One STTC half-term: (p - t) / (1 - p*t), returning 0.0 at the two
/// singular denominators t == 1 and p*t == 1.
fn sttc_term(p: f64, t: f64) -> f64 {
    let near_zero = |x: f64| x.abs() < 1e-15;
    if near_zero(1.0 - t) || near_zero(1.0 - p * t) {
        0.0
    } else {
        (p - t) / (1.0 - p * t)
    }
}
255
256/// Spike count covariance matrix (de la Rocha et al. 2007).
257pub fn covariance_matrix(trains: &[&[i32]], bin_size: usize) -> Vec<Vec<f64>> {
258    let binned: Vec<Vec<i64>> = trains
259        .iter()
260        .map(|t| bin_spike_train(t, bin_size))
261        .collect();
262    let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
263    let n = trains.len();
264
265    if n == 0 || min_bins == 0 {
266        return vec![vec![]];
267    }
268
269    let mat: Vec<Vec<f64>> = binned
270        .iter()
271        .map(|b| b[..min_bins].iter().map(|&v| v as f64).collect())
272        .collect();
273    let means: Vec<f64> = mat
274        .iter()
275        .map(|row| row.iter().sum::<f64>() / min_bins as f64)
276        .collect();
277
278    if n == 1 {
279        let var = mat[0].iter().map(|v| (v - means[0]).powi(2)).sum::<f64>()
280            / (min_bins as f64 - 1.0).max(1.0);
281        return vec![vec![var]];
282    }
283
284    let ddof = (min_bins as f64 - 1.0).max(1.0);
285    let min_bins_f = min_bins as f64;
286    let mut cov = vec![vec![0.0_f64; n]; n];
287    cov.par_iter_mut().enumerate().for_each(|(i, row)| {
288        for j in i..n {
289            let dot = crate::simd::dot_f64_dispatch(&mat[i], &mat[j]);
290            row[j] = (dot - min_bins_f * means[i] * means[j]) / ddof;
291        }
292    });
293    // Mirror
294    for i in 0..n {
295        for j in (i + 1)..n {
296            cov[j][i] = cov[i][j];
297        }
298    }
299    cov
300}
301
/// Autocorrelation time (seconds).
/// Integral of normalised autocorrelation until first zero crossing.
///
/// The lag-0 term is excluded; each positive normalised autocorrelation
/// value contributes `ac * dt` until the first negative value stops the
/// sum. Constant (zero-variance) trains return 0.0.
pub fn autocorrelation_time(binary_train: &[i32], dt: f64, max_lag_ms: f64) -> f64 {
    let n = binary_train.len();
    let max_lag = (max_lag_ms / (dt * 1000.0)) as usize;
    let mean: f64 = binary_train.iter().map(|&v| v as f64).sum::<f64>() / n as f64;
    let centred: Vec<f64> = binary_train.iter().map(|&v| v as f64 - mean).collect();
    // Lag-0 energy used as the normaliser for every lag.
    let energy: f64 = centred.iter().map(|v| v * v).sum();
    if energy == 0.0 {
        return 0.0;
    }
    let mut tau = 0.0_f64;
    for lag in 1..max_lag.min(n) {
        let ac = centred[..n - lag]
            .iter()
            .zip(&centred[lag..])
            .map(|(x, y)| x * y)
            .sum::<f64>()
            / energy;
        if ac < 0.0 {
            // First zero crossing terminates the integral.
            break;
        }
        tau += ac * dt;
    }
    tau
}
323
324/// Noise correlation (Averbeck & Lee 2006).
325/// Residuals after subtracting mean across neurons.
326pub fn noise_correlation(trains: &[&[i32]], bin_size: usize) -> Vec<Vec<f64>> {
327    let binned: Vec<Vec<i64>> = trains
328        .iter()
329        .map(|t| bin_spike_train(t, bin_size))
330        .collect();
331    let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
332    let n = trains.len();
333    if n == 0 || min_bins == 0 {
334        return vec![vec![]];
335    }
336
337    let mat: Vec<Vec<f64>> = binned
338        .iter()
339        .map(|b| b[..min_bins].iter().map(|&v| v as f64).collect())
340        .collect();
341
342    // Mean across neurons for each time bin
343    let bin_means: Vec<f64> = (0..min_bins)
344        .map(|k| mat.iter().map(|row| row[k]).sum::<f64>() / n as f64)
345        .collect();
346
347    // Residuals = mat - mean across time (axis=0)
348    // Python: residuals = mat - mat.mean(axis=0, keepdims=True)
349    let residuals: Vec<Vec<f64>> = mat
350        .iter()
351        .map(|row| {
352            row.iter()
353                .enumerate()
354                .map(|(k, &v)| v - bin_means[k])
355                .collect()
356        })
357        .collect();
358
359    let mut corr = vec![vec![0.0_f64; n]; n];
360    for i in 0..n {
361        corr[i][i] = 1.0;
362        let std_i = (residuals[i].iter().map(|v| v * v).sum::<f64>() / min_bins as f64).sqrt();
363        for j in (i + 1)..n {
364            let std_j = (residuals[j].iter().map(|v| v * v).sum::<f64>() / min_bins as f64).sqrt();
365            if std_i > 0.0 && std_j > 0.0 {
366                let r = residuals[i]
367                    .iter()
368                    .zip(residuals[j].iter())
369                    .map(|(a, b)| a * b)
370                    .sum::<f64>()
371                    / min_bins as f64
372                    / (std_i * std_j);
373                corr[i][j] = r;
374                corr[j][i] = r;
375            }
376        }
377    }
378    corr
379}
380
381/// Signal correlation (tuning similarity).
382/// Pearson correlation of mean responses.
383pub fn signal_correlation(trains: &[&[i32]], bin_size: usize) -> Vec<Vec<f64>> {
384    let binned: Vec<Vec<i64>> = trains
385        .iter()
386        .map(|t| bin_spike_train(t, bin_size))
387        .collect();
388    let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
389    let n = trains.len();
390    if n == 0 || min_bins == 0 {
391        return vec![vec![]];
392    }
393
394    let mat: Vec<Vec<f64>> = binned
395        .iter()
396        .map(|b| b[..min_bins].iter().map(|&v| v as f64).collect())
397        .collect();
398    let means: Vec<f64> = mat
399        .iter()
400        .map(|row| row.iter().sum::<f64>() / min_bins as f64)
401        .collect();
402    let stds: Vec<f64> = mat
403        .iter()
404        .enumerate()
405        .map(|(i, row)| {
406            (row.iter().map(|v| (v - means[i]).powi(2)).sum::<f64>() / min_bins as f64).sqrt()
407        })
408        .collect();
409
410    let mut corr = vec![vec![0.0_f64; n]; n];
411    for i in 0..n {
412        corr[i][i] = 1.0;
413        for j in (i + 1)..n {
414            if stds[i] > 0.0 && stds[j] > 0.0 {
415                let c: f64 = (0..min_bins)
416                    .map(|k| (mat[i][k] - means[i]) * (mat[j][k] - means[j]))
417                    .sum::<f64>()
418                    / min_bins as f64;
419                let r = c / (stds[i] * stds[j]);
420                corr[i][j] = r;
421                corr[j][i] = r;
422            }
423        }
424    }
425    corr
426}
427
/// Windowed spike count covariance (Kohn & Smith 2005).
///
/// Thin alias: bins spikes in windows of `window` bins and returns the
/// same matrix as [`covariance_matrix`] with `bin_size = window`.
pub fn spike_count_covariance(trains: &[&[i32]], window: usize) -> Vec<Vec<f64>> {
    covariance_matrix(trains, window)
}
432
433/// Joint PSTH matrix (Aertsen et al. 1989).
434/// Returns flattened n×n outer product of mean-subtracted binned counts.
435pub fn joint_psth(train_a: &[i32], train_b: &[i32], bin_size: usize) -> (Vec<f64>, usize) {
436    let ca_raw = bin_spike_train(train_a, bin_size);
437    let cb_raw = bin_spike_train(train_b, bin_size);
438    let n = ca_raw.len().min(cb_raw.len());
439    if n == 0 {
440        return (vec![], 0);
441    }
442    let mean_a = ca_raw[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
443    let mean_b = cb_raw[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
444    let ca: Vec<f64> = ca_raw[..n].iter().map(|&v| v as f64 - mean_a).collect();
445    let cb: Vec<f64> = cb_raw[..n].iter().map(|&v| v as f64 - mean_b).collect();
446
447    let mut result = Vec::with_capacity(n * n);
448    for &ai in &ca {
449        for &bj in &cb {
450            result.push(ai * bj / n as f64);
451        }
452    }
453    (result, n)
454}
455
456/// Coincidence index / kappa (Joris et al. 2006).
457/// Corrects raw coincidence count for expected coincidences from rate.
458pub fn coincidence_index(train_a: &[i32], train_b: &[i32], dt: f64, delta_ms: f64) -> f64 {
459    let ta = spike_times(train_a, dt);
460    let tb = spike_times(train_b, dt);
461    if ta.is_empty() || tb.is_empty() {
462        return 0.0;
463    }
464    let delta = delta_ms / 1000.0;
465    let duration = train_a.len().max(train_b.len()) as f64 * dt;
466    let mut raw_coinc = 0_usize;
467    for &t in &ta {
468        if tb.iter().any(|&tt| (tt - t).abs() <= delta) {
469            raw_coinc += 1;
470        }
471    }
472    let expected = if duration > 0.0 {
473        2.0 * delta * ta.len() as f64 * tb.len() as f64 / duration
474    } else {
475        0.0
476    };
477    let norm = 0.5 * (ta.len() + tb.len()) as f64;
478    if norm <= expected {
479        return 0.0;
480    }
481    (raw_coinc as f64 - expected) / (norm - expected)
482}
483
#[cfg(test)]
mod tests {
    // Unit tests for the correlation / synchrony / covariance measures above.
    use super::*;

    /// Build a binary train of `len` bins with spikes at the given indices.
    /// Panics if an index is out of range (intentional in tests).
    fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
        let mut t = vec![0i32; len];
        for &s in spikes {
            t[s] = 1;
        }
        t
    }

    // ── cross_correlation ───────────────────────────────────────────

    #[test]
    fn test_cross_correlation_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let (cc, lags) = cross_correlation(&train, &train, 5.0, 0.001);
        // Peak at lag=0
        let zero_idx = lags.iter().position(|&l| l.abs() < 1e-10).unwrap();
        assert!(
            (cc[zero_idx] - 1.0).abs() < 1e-10,
            "autocorrelation peak should be 1.0"
        );
        // Symmetric
        for i in 0..cc.len() / 2 {
            assert!(
                (cc[i] - cc[cc.len() - 1 - i]).abs() < 1e-10,
                "autocorrelation should be symmetric"
            );
        }
    }

    #[test]
    fn test_cross_correlation_shifted() {
        let a = make_train(&[10, 30, 50], 100);
        let b = make_train(&[12, 32, 52], 100);
        let (cc, lags) = cross_correlation(&a, &b, 5.0, 0.001);
        // Peak should be near lag=+2ms (b lags a by 2 steps = 2ms)
        let peak_idx = cc
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap()
            .0;
        assert!(
            (lags[peak_idx] - 2.0).abs() < 1.5,
            "peak lag should be near 2ms, got {}",
            lags[peak_idx]
        );
    }

    #[test]
    fn test_cross_correlation_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[10, 50], 100);
        let (cc, _) = cross_correlation(&a, &b, 5.0, 0.001);
        assert!(
            cc.iter().all(|&v| v == 0.0),
            "zero train → zero correlation"
        );
    }

    // ── pairwise_correlation ────────────────────────────────────────

    #[test]
    fn test_pairwise_correlation_identity() {
        let t1 = make_train(&[10, 30, 50], 100);
        let t2 = make_train(&[10, 30, 50], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = pairwise_correlation(&trains, 0.001);
        assert!((corr[0][0] - 1.0).abs() < 1e-10);
        assert!((corr[0][1] - 1.0).abs() < 1e-10);
        assert!((corr[1][0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pairwise_correlation_anticorrelated() {
        let t1 = make_train(&[0, 2, 4, 6, 8], 10);
        let t2 = make_train(&[1, 3, 5, 7, 9], 10);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = pairwise_correlation(&trains, 0.001);
        assert!(
            corr[0][1] < 0.0,
            "alternating trains should be negatively correlated"
        );
    }

    #[test]
    fn test_pairwise_correlation_empty() {
        let corr = pairwise_correlation(&[], 0.001);
        let expected: Vec<Vec<f64>> = vec![vec![]];
        assert_eq!(corr, expected);
    }

    // ── event_synchronization ───────────────────────────────────────

    #[test]
    fn test_event_sync_identical() {
        let train = make_train(&[10, 30, 50, 70], 100);
        let score = event_synchronization(&train, &train, 0.001, 5.0);
        // Only self-matches within tau: each spike matches itself → count=4, sqrt(4*4)=4 → 1.0
        assert!(
            (score - 1.0).abs() < 1e-10,
            "identical trains: count=4, sqrt(16)=4, score=1.0, got {}",
            score
        );
    }

    #[test]
    fn test_event_sync_no_overlap() {
        let a = make_train(&[10], 100);
        let b = make_train(&[90], 100);
        let score = event_synchronization(&a, &b, 0.001, 2.0);
        assert_eq!(score, 0.0, "far apart spikes → zero sync");
    }

    #[test]
    fn test_event_sync_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[50], 100);
        assert_eq!(event_synchronization(&a, &b, 0.001, 5.0), 0.0);
    }

    // ── spike_train_coherence ───────────────────────────────────────

    #[test]
    fn test_coherence_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 128);
        let (coh, freqs) = spike_train_coherence(&train, &train, 0.001);
        assert!(!coh.is_empty());
        assert_eq!(coh.len(), freqs.len());
        // Self-coherence should be 1.0 at non-DC frequencies (DC is 0/0 after mean subtraction)
        for (i, &c) in coh.iter().enumerate() {
            if i == 0 {
                continue; // DC bin is zero after mean subtraction
            }
            assert!(
                (c - 1.0).abs() < 1e-8,
                "self-coherence at freq idx {i} should be 1.0, got {c}"
            );
        }
    }

    #[test]
    fn test_coherence_short() {
        let a = vec![1i32];
        let b = vec![0i32];
        let (coh, _) = spike_train_coherence(&a, &b, 0.001);
        assert!(coh.is_empty(), "n<2 → empty");
    }

    // ── spike_time_tiling_coefficient ───────────────────────────────

    #[test]
    fn test_sttc_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let sttc = spike_time_tiling_coefficient(&train, &train, 0.001, 5.0);
        assert!(sttc > 0.8, "identical trains → high STTC, got {sttc}");
    }

    #[test]
    fn test_sttc_no_overlap() {
        let a = make_train(&[5], 1000);
        let b = make_train(&[995], 1000);
        let sttc = spike_time_tiling_coefficient(&a, &b, 0.001, 1.0);
        assert!(sttc < 0.1, "far apart spikes → low STTC, got {sttc}");
    }

    #[test]
    fn test_sttc_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[50], 100);
        assert_eq!(spike_time_tiling_coefficient(&a, &b, 0.001, 5.0), 0.0);
    }

    // ── covariance_matrix ───────────────────────────────────────────

    #[test]
    fn test_covariance_identical() {
        let train = make_train(&[0, 1, 5, 6, 10, 11, 15, 16, 20, 21], 25);
        let trains: Vec<&[i32]> = vec![&train, &train];
        let cov = covariance_matrix(&trains, 5);
        assert!(
            (cov[0][0] - cov[0][1]).abs() < 1e-10,
            "identical trains → equal diagonal and off-diagonal"
        );
    }

    #[test]
    fn test_covariance_single() {
        let train = make_train(&[0, 1, 2, 5, 6, 10, 11, 12, 13, 14], 20);
        let trains: Vec<&[i32]> = vec![&train];
        let cov = covariance_matrix(&trains, 5);
        assert_eq!(cov.len(), 1);
        assert!(cov[0][0] > 0.0, "non-constant train → positive variance");
    }

    // ── autocorrelation_time ────────────────────────────────────────

    #[test]
    fn test_autocorr_time_bursty() {
        // Burst pattern: consecutive spikes have positive lag-1 autocorrelation
        let train = make_train(&[0, 1, 2, 10, 11, 12, 20, 21, 22, 30, 31, 32], 40);
        let tau = autocorrelation_time(&train, 0.001, 50.0);
        assert!(
            tau > 0.0,
            "bursty train should have positive autocorrelation time, got {tau}"
        );
    }

    #[test]
    fn test_autocorr_time_silent() {
        let train = vec![0i32; 100];
        assert_eq!(autocorrelation_time(&train, 0.001, 50.0), 0.0);
    }

    // ── noise_correlation ───────────────────────────────────────────

    #[test]
    fn test_noise_corr_identical() {
        let t1 = make_train(&[5, 15, 25, 35, 45], 50);
        let t2 = t1.clone();
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = noise_correlation(&trains, 10);
        assert!((corr[0][0] - 1.0).abs() < 1e-10);
        // Identical trains: residuals are zero → correlation undefined but set to 1 if std>0
    }

    #[test]
    fn test_noise_corr_diagonal() {
        let t1 = make_train(&[2, 12, 22], 30);
        let t2 = make_train(&[7, 17, 27], 30);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = noise_correlation(&trains, 10);
        assert!((corr[0][0] - 1.0).abs() < 1e-10);
        assert!((corr[1][1] - 1.0).abs() < 1e-10);
    }

    // ── signal_correlation ──────────────────────────────────────────

    #[test]
    fn test_signal_corr_identical() {
        // bins: [1, 3, 0] — non-constant so std > 0
        let t1 = make_train(&[5, 10, 11, 12], 30);
        let t2 = t1.clone();
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = signal_correlation(&trains, 10);
        assert!(
            (corr[0][1] - 1.0).abs() < 1e-10,
            "identical trains → r=1.0, got {}",
            corr[0][1]
        );
    }

    // ── spike_count_covariance ──────────────────────────────────────

    #[test]
    fn test_spike_count_cov_delegates() {
        let t1 = make_train(&[0, 1, 5, 6, 10, 11], 15);
        let trains: Vec<&[i32]> = vec![&t1];
        let cov1 = covariance_matrix(&trains, 5);
        let cov2 = spike_count_covariance(&trains, 5);
        assert_eq!(cov1, cov2);
    }

    // ── joint_psth ──────────────────────────────────────────────────

    #[test]
    fn test_joint_psth_shape() {
        let a = make_train(&[0, 1, 5, 6, 10, 11, 15, 16, 20, 21], 25);
        let b = make_train(&[2, 3, 7, 8, 12, 13, 17, 18, 22, 23], 25);
        let (result, n) = joint_psth(&a, &b, 5);
        assert_eq!(n, 5);
        assert_eq!(result.len(), 25);
    }

    #[test]
    fn test_joint_psth_symmetry() {
        let train = make_train(&[0, 1, 5, 6, 10, 11, 15, 16, 20, 21], 25);
        let (result, n) = joint_psth(&train, &train, 5);
        // Outer product of x with itself is symmetric
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (result[i * n + j] - result[j * n + i]).abs() < 1e-10,
                    "JPSTH of identical trains should be symmetric"
                );
            }
        }
    }

    // ── coincidence_index ───────────────────────────────────────────

    #[test]
    fn test_coincidence_index_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let ci = coincidence_index(&train, &train, 0.001, 2.0);
        assert!(
            ci > 0.5,
            "identical trains → high coincidence index, got {ci}"
        );
    }

    #[test]
    fn test_coincidence_index_no_overlap() {
        let a = make_train(&[5], 1000);
        let b = make_train(&[995], 1000);
        let ci = coincidence_index(&a, &b, 0.001, 1.0);
        assert!(ci <= 0.0, "far apart → zero or negative kappa, got {ci}");
    }

    #[test]
    fn test_coincidence_index_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[50], 100);
        assert_eq!(coincidence_index(&a, &b, 0.001, 2.0), 0.0);
    }

    // ── helpers ─────────────────────────────────────────────────────

    #[test]
    fn test_tile_fraction_single_spike() {
        let times = vec![0.05];
        let frac = tile_fraction(&times, 0.005, 0.1);
        // Window: [0.045, 0.055] → 0.01 / 0.1 = 0.1
        assert!((frac - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_tile_fraction_overlapping() {
        let times = vec![0.05, 0.052];
        let frac = tile_fraction(&times, 0.005, 0.1);
        // Windows [0.045, 0.055] and [0.047, 0.057] merge to [0.045, 0.057] → 0.012
        assert!((frac - 0.12).abs() < 1e-10);
    }

    #[test]
    fn test_sttc_term_edge_cases() {
        assert_eq!(sttc_term(0.5, 1.0), 0.0); // t=1
        assert_eq!(sttc_term(0.0, 0.0), 0.0); // p=t=0
    }
}