1use rayon::prelude::*;
10use rustfft::{num_complex::Complex, FftPlanner};
11
12use super::basic::{bin_spike_train, spike_times};
13
14pub fn cross_correlation(
17 train_a: &[i32],
18 train_b: &[i32],
19 max_lag_ms: f64,
20 dt: f64,
21) -> (Vec<f64>, Vec<f64>) {
22 let max_lag = (max_lag_ms / (dt * 1000.0)) as isize;
23 let n = train_a.len().min(train_b.len());
24 if n == 0 {
25 return (vec![], vec![]);
26 }
27
28 let mean_a: f64 = train_a[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
29 let mean_b: f64 = train_b[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
30 let a: Vec<f64> = train_a[..n].iter().map(|&v| v as f64 - mean_a).collect();
31 let b: Vec<f64> = train_b[..n].iter().map(|&v| v as f64 - mean_b).collect();
32
33 let norm = (a.iter().map(|x| x * x).sum::<f64>() * b.iter().map(|x| x * x).sum::<f64>()).sqrt();
34
35 let n_lags = (2 * max_lag + 1) as usize;
36 let mut cc = vec![0.0_f64; n_lags];
37 let mut lags_ms = Vec::with_capacity(n_lags);
38 for l in -max_lag..=max_lag {
39 lags_ms.push(l as f64 * dt * 1000.0);
40 }
41
42 if norm == 0.0 {
43 return (cc, lags_ms);
44 }
45
46 for (i, lag) in (-max_lag..=max_lag).enumerate() {
47 let sum = if lag >= 0 {
48 let l = lag as usize;
49 crate::simd::dot_f64_dispatch(&a[..n - l], &b[l..n])
50 } else {
51 let l = (-lag) as usize;
52 crate::simd::dot_f64_dispatch(&a[l..n], &b[..n - l])
53 };
54 cc[i] = sum / norm;
55 }
56
57 (cc, lags_ms)
58}
59
/// Pearson correlation matrix of raw (unbinned) spike trains.
///
/// Every train is truncated to the shortest train's length, mean-centred, and compared
/// pairwise using population (divide-by-length) covariance and standard deviations.
/// The diagonal is always `1.0`; an off-diagonal entry stays `0.0` when either train of
/// the pair has zero variance.
///
/// `dt` is accepted for signature parity with the other functions but is not used.
///
/// Returns `vec![vec![]]` for no trains, and an all-zero `n × n` matrix when the
/// shortest train is empty.
pub fn pairwise_correlation(trains: &[&[i32]], dt: f64) -> Vec<Vec<f64>> {
    // `dt` does not enter the computation.
    let _ = dt;
    let count = trains.len();
    if count == 0 {
        return vec![vec![]];
    }
    let len = trains.iter().map(|t| t.len()).min().unwrap_or(0);
    if len == 0 {
        return vec![vec![0.0; count]; count];
    }
    let len_f = len as f64;

    // One (deviation-from-mean row, population std) pair per train, computed once so the
    // pair loop only multiplies deviations.
    let centred: Vec<(Vec<f64>, f64)> = trains
        .iter()
        .map(|train| {
            let values: Vec<f64> = train[..len].iter().map(|&v| v as f64).collect();
            let mu = values.iter().sum::<f64>() / len_f;
            let dev: Vec<f64> = values.into_iter().map(|v| v - mu).collect();
            let sigma = (dev.iter().map(|d| d.powi(2)).sum::<f64>() / len_f).sqrt();
            (dev, sigma)
        })
        .collect();

    let mut corr = vec![vec![0.0_f64; count]; count];
    for (i, (dev_i, sigma_i)) in centred.iter().enumerate() {
        corr[i][i] = 1.0;
        for (j, (dev_j, sigma_j)) in centred.iter().enumerate().skip(i + 1) {
            if *sigma_i > 0.0 && *sigma_j > 0.0 {
                let cov = dev_i.iter().zip(dev_j).map(|(x, y)| x * y).sum::<f64>() / len_f;
                let r = cov / (sigma_i * sigma_j);
                corr[i][j] = r;
                corr[j][i] = r;
            }
        }
    }
    corr
}
106
107pub fn event_synchronization(train_a: &[i32], train_b: &[i32], dt: f64, tau_ms: f64) -> f64 {
110 let ta = spike_times(train_a, dt);
111 let tb = spike_times(train_b, dt);
112 let na = ta.len();
113 let nb = tb.len();
114 if na == 0 || nb == 0 {
115 return 0.0;
116 }
117 let tau = tau_ms / 1000.0;
118 let mut count = 0_usize;
119 for &ti in &ta {
120 for &tj in &tb {
121 if (ti - tj).abs() < tau {
122 count += 1;
123 }
124 }
125 }
126 count as f64 / (na as f64 * nb as f64).sqrt()
127}
128
129pub fn spike_train_coherence(train_a: &[i32], train_b: &[i32], dt: f64) -> (Vec<f64>, Vec<f64>) {
132 let n = train_a.len().min(train_b.len());
133 if n < 2 {
134 return (vec![], vec![]);
135 }
136
137 let mean_a: f64 = train_a[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
138 let mean_b: f64 = train_b[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
139
140 let mut planner = FftPlanner::<f64>::new();
141 let fft = planner.plan_fft_forward(n);
142
143 let mut buf_a: Vec<Complex<f64>> = train_a[..n]
144 .iter()
145 .map(|&v| Complex::new(v as f64 - mean_a, 0.0))
146 .collect();
147 let mut buf_b: Vec<Complex<f64>> = train_b[..n]
148 .iter()
149 .map(|&v| Complex::new(v as f64 - mean_b, 0.0))
150 .collect();
151
152 fft.process(&mut buf_a);
153 fft.process(&mut buf_b);
154
155 let n_freqs = n / 2 + 1;
157 let mut coh = Vec::with_capacity(n_freqs);
158 let mut freqs = Vec::with_capacity(n_freqs);
159
160 for i in 0..n_freqs {
161 let fa = buf_a[i];
162 let fb = buf_b[i];
163 let pab = fa * fb.conj();
164 let paa = fa.norm_sqr();
165 let pbb = fb.norm_sqr();
166 let denom = paa * pbb;
167 if denom == 0.0 {
168 coh.push(0.0);
169 } else {
170 coh.push(pab.norm_sqr() / denom);
171 }
172 freqs.push(i as f64 / (n as f64 * dt));
173 }
174
175 (coh, freqs)
176}
177
178pub fn spike_time_tiling_coefficient(
180 train_a: &[i32],
181 train_b: &[i32],
182 dt: f64,
183 delta_ms: f64,
184) -> f64 {
185 let delta = delta_ms / 1000.0;
186 let ta = spike_times(train_a, dt);
187 let tb = spike_times(train_b, dt);
188 let duration = train_a.len().max(train_b.len()) as f64 * dt;
189
190 if ta.is_empty() || tb.is_empty() {
191 return 0.0;
192 }
193
194 let pa = coincidence_fraction(&ta, &tb, delta);
195 let pb = coincidence_fraction(&tb, &ta, delta);
196 let ta_frac = tile_fraction(&ta, delta, duration);
197 let tb_frac = tile_fraction(&tb, delta, duration);
198
199 0.5 * (sttc_term(pa, tb_frac) + sttc_term(pb, ta_frac))
200}
201
/// Fraction of `[0, duration]` covered by the union of `[t − delta, t + delta]` windows.
///
/// Windows are sorted by start, overlapping or touching windows are merged, each merged
/// window is clipped to `[0, duration]`, and the covered length is divided by `duration`
/// (capped at `1.0`).
///
/// Returns `0.0` for no spike times or a non-positive duration.
///
/// # Panics
/// Panics if a window start is NaN (the sort comparator unwraps `partial_cmp`).
fn tile_fraction(times: &[f64], delta: f64, duration: f64) -> f64 {
    if duration <= 0.0 || times.is_empty() {
        return 0.0;
    }

    let mut windows: Vec<(f64, f64)> = Vec::with_capacity(times.len());
    windows.extend(times.iter().map(|&t| (t - delta, t + delta)));
    windows.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap());

    // Sweep in start order, growing the last merged window while the next one overlaps it.
    let mut union: Vec<(f64, f64)> = Vec::with_capacity(windows.len());
    for (start, end) in windows {
        match union.last_mut() {
            Some(prev) if start <= prev.1 => prev.1 = prev.1.max(end),
            _ => union.push((start, end)),
        }
    }

    let covered: f64 = union
        .iter()
        .map(|&(start, end)| {
            let from = start.max(0.0);
            let to = end.min(duration);
            if to > from {
                to - from
            } else {
                0.0
            }
        })
        .sum();

    (covered / duration).min(1.0)
}
234
/// Fraction of `times_ref` spikes that have at least one `times_target` spike within
/// `±delta` (inclusive).
///
/// Returns `0.0` when `times_ref` is empty.
fn coincidence_fraction(times_ref: &[f64], times_target: &[f64], delta: f64) -> f64 {
    let total = times_ref.len();
    if total == 0 {
        return 0.0;
    }
    let mut matched = 0_usize;
    for &t in times_ref {
        let has_partner = times_target.iter().any(|&other| (other - t).abs() <= delta);
        if has_partner {
            matched += 1;
        }
    }
    matched as f64 / total as f64
}
245
/// One half-term of the STTC: `(p − t) / (1 − p·t)`.
///
/// Returns `0.0` when `t` is (within 1e-15 of) 1 or when the denominator is (within
/// 1e-15 of) zero, instead of dividing by a vanishing value.
fn sttc_term(p: f64, t: f64) -> f64 {
    let denom = 1.0 - p * t;
    let degenerate = (1.0 - t).abs() < 1e-15 || denom.abs() < 1e-15;
    if degenerate {
        0.0
    } else {
        (p - t) / denom
    }
}
255
256pub fn covariance_matrix(trains: &[&[i32]], bin_size: usize) -> Vec<Vec<f64>> {
258 let binned: Vec<Vec<i64>> = trains
259 .iter()
260 .map(|t| bin_spike_train(t, bin_size))
261 .collect();
262 let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
263 let n = trains.len();
264
265 if n == 0 || min_bins == 0 {
266 return vec![vec![]];
267 }
268
269 let mat: Vec<Vec<f64>> = binned
270 .iter()
271 .map(|b| b[..min_bins].iter().map(|&v| v as f64).collect())
272 .collect();
273 let means: Vec<f64> = mat
274 .iter()
275 .map(|row| row.iter().sum::<f64>() / min_bins as f64)
276 .collect();
277
278 if n == 1 {
279 let var = mat[0].iter().map(|v| (v - means[0]).powi(2)).sum::<f64>()
280 / (min_bins as f64 - 1.0).max(1.0);
281 return vec![vec![var]];
282 }
283
284 let ddof = (min_bins as f64 - 1.0).max(1.0);
285 let min_bins_f = min_bins as f64;
286 let mut cov = vec![vec![0.0_f64; n]; n];
287 cov.par_iter_mut().enumerate().for_each(|(i, row)| {
288 for j in i..n {
289 let dot = crate::simd::dot_f64_dispatch(&mat[i], &mat[j]);
290 row[j] = (dot - min_bins_f * means[i] * means[j]) / ddof;
291 }
292 });
293 for i in 0..n {
295 for j in (i + 1)..n {
296 cov[j][i] = cov[i][j];
297 }
298 }
299 cov
300}
301
/// Integrated autocorrelation time of a spike train, in the units of `dt`.
///
/// The train is mean-centred and its autocorrelation at lag `k` is normalised by the
/// lag-0 sum of squares. `ac(k) · dt` is accumulated for lags `1..=max_lag`, where
/// `max_lag = max_lag_ms / (dt · 1000)` samples (truncated; negative or NaN windows give
/// no lags). The requested maximum lag is included, and the range is clipped to the
/// longest lag with any overlap (`n − 1`). Accumulation stops at the first lag whose
/// autocorrelation is negative. Lag 0 is not included.
///
/// Returns `0.0` for an empty or constant train.
pub fn autocorrelation_time(binary_train: &[i32], dt: f64, max_lag_ms: f64) -> f64 {
    let n = binary_train.len();
    if n == 0 {
        return 0.0;
    }
    let max_lag = (max_lag_ms / (dt * 1000.0)) as usize;
    let mean: f64 = binary_train.iter().map(|&v| v as f64).sum::<f64>() / n as f64;
    let x: Vec<f64> = binary_train.iter().map(|&v| v as f64 - mean).collect();
    let var: f64 = x.iter().map(|v| v * v).sum();
    if var == 0.0 {
        return 0.0;
    }

    let mut tau = 0.0_f64;
    // Inclusive bound: a window of exactly `max_lag_ms` must include that lag, matching
    // the `-max_lag..=max_lag` convention used by `cross_correlation`.
    for lag in 1..=max_lag.min(n - 1) {
        let ac = x[..n - lag]
            .iter()
            .zip(&x[lag..])
            .map(|(a, b)| a * b)
            .sum::<f64>()
            / var;
        if ac < 0.0 {
            break;
        }
        tau += ac * dt;
    }
    tau
}
323
324pub fn noise_correlation(trains: &[&[i32]], bin_size: usize) -> Vec<Vec<f64>> {
327 let binned: Vec<Vec<i64>> = trains
328 .iter()
329 .map(|t| bin_spike_train(t, bin_size))
330 .collect();
331 let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
332 let n = trains.len();
333 if n == 0 || min_bins == 0 {
334 return vec![vec![]];
335 }
336
337 let mat: Vec<Vec<f64>> = binned
338 .iter()
339 .map(|b| b[..min_bins].iter().map(|&v| v as f64).collect())
340 .collect();
341
342 let bin_means: Vec<f64> = (0..min_bins)
344 .map(|k| mat.iter().map(|row| row[k]).sum::<f64>() / n as f64)
345 .collect();
346
347 let residuals: Vec<Vec<f64>> = mat
350 .iter()
351 .map(|row| {
352 row.iter()
353 .enumerate()
354 .map(|(k, &v)| v - bin_means[k])
355 .collect()
356 })
357 .collect();
358
359 let mut corr = vec![vec![0.0_f64; n]; n];
360 for i in 0..n {
361 corr[i][i] = 1.0;
362 let std_i = (residuals[i].iter().map(|v| v * v).sum::<f64>() / min_bins as f64).sqrt();
363 for j in (i + 1)..n {
364 let std_j = (residuals[j].iter().map(|v| v * v).sum::<f64>() / min_bins as f64).sqrt();
365 if std_i > 0.0 && std_j > 0.0 {
366 let r = residuals[i]
367 .iter()
368 .zip(residuals[j].iter())
369 .map(|(a, b)| a * b)
370 .sum::<f64>()
371 / min_bins as f64
372 / (std_i * std_j);
373 corr[i][j] = r;
374 corr[j][i] = r;
375 }
376 }
377 }
378 corr
379}
380
381pub fn signal_correlation(trains: &[&[i32]], bin_size: usize) -> Vec<Vec<f64>> {
384 let binned: Vec<Vec<i64>> = trains
385 .iter()
386 .map(|t| bin_spike_train(t, bin_size))
387 .collect();
388 let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
389 let n = trains.len();
390 if n == 0 || min_bins == 0 {
391 return vec![vec![]];
392 }
393
394 let mat: Vec<Vec<f64>> = binned
395 .iter()
396 .map(|b| b[..min_bins].iter().map(|&v| v as f64).collect())
397 .collect();
398 let means: Vec<f64> = mat
399 .iter()
400 .map(|row| row.iter().sum::<f64>() / min_bins as f64)
401 .collect();
402 let stds: Vec<f64> = mat
403 .iter()
404 .enumerate()
405 .map(|(i, row)| {
406 (row.iter().map(|v| (v - means[i]).powi(2)).sum::<f64>() / min_bins as f64).sqrt()
407 })
408 .collect();
409
410 let mut corr = vec![vec![0.0_f64; n]; n];
411 for i in 0..n {
412 corr[i][i] = 1.0;
413 for j in (i + 1)..n {
414 if stds[i] > 0.0 && stds[j] > 0.0 {
415 let c: f64 = (0..min_bins)
416 .map(|k| (mat[i][k] - means[i]) * (mat[j][k] - means[j]))
417 .sum::<f64>()
418 / min_bins as f64;
419 let r = c / (stds[i] * stds[j]);
420 corr[i][j] = r;
421 corr[j][i] = r;
422 }
423 }
424 }
425 corr
426}
427
428pub fn spike_count_covariance(trains: &[&[i32]], window: usize) -> Vec<Vec<f64>> {
430 covariance_matrix(trains, window)
431}
432
433pub fn joint_psth(train_a: &[i32], train_b: &[i32], bin_size: usize) -> (Vec<f64>, usize) {
436 let ca_raw = bin_spike_train(train_a, bin_size);
437 let cb_raw = bin_spike_train(train_b, bin_size);
438 let n = ca_raw.len().min(cb_raw.len());
439 if n == 0 {
440 return (vec![], 0);
441 }
442 let mean_a = ca_raw[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
443 let mean_b = cb_raw[..n].iter().map(|&v| v as f64).sum::<f64>() / n as f64;
444 let ca: Vec<f64> = ca_raw[..n].iter().map(|&v| v as f64 - mean_a).collect();
445 let cb: Vec<f64> = cb_raw[..n].iter().map(|&v| v as f64 - mean_b).collect();
446
447 let mut result = Vec::with_capacity(n * n);
448 for &ai in &ca {
449 for &bj in &cb {
450 result.push(ai * bj / n as f64);
451 }
452 }
453 (result, n)
454}
455
456pub fn coincidence_index(train_a: &[i32], train_b: &[i32], dt: f64, delta_ms: f64) -> f64 {
459 let ta = spike_times(train_a, dt);
460 let tb = spike_times(train_b, dt);
461 if ta.is_empty() || tb.is_empty() {
462 return 0.0;
463 }
464 let delta = delta_ms / 1000.0;
465 let duration = train_a.len().max(train_b.len()) as f64 * dt;
466 let mut raw_coinc = 0_usize;
467 for &t in &ta {
468 if tb.iter().any(|&tt| (tt - t).abs() <= delta) {
469 raw_coinc += 1;
470 }
471 }
472 let expected = if duration > 0.0 {
473 2.0 * delta * ta.len() as f64 * tb.len() as f64 / duration
474 } else {
475 0.0
476 };
477 let norm = 0.5 * (ta.len() + tb.len()) as f64;
478 if norm <= expected {
479 return 0.0;
480 }
481 (raw_coinc as f64 - expected) / (norm - expected)
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Binary train of length `len` with a 1 at every index in `spikes`.
    fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
        let mut t = vec![0i32; len];
        for &s in spikes {
            t[s] = 1;
        }
        t
    }

    #[test]
    fn test_cross_correlation_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let (cc, lags) = cross_correlation(&train, &train, 5.0, 0.001);
        let zero_idx = lags.iter().position(|&l| l.abs() < 1e-10).unwrap();
        assert!(
            (cc[zero_idx] - 1.0).abs() < 1e-10,
            "autocorrelation peak should be 1.0"
        );
        for i in 0..cc.len() / 2 {
            assert!(
                (cc[i] - cc[cc.len() - 1 - i]).abs() < 1e-10,
                "autocorrelation should be symmetric"
            );
        }
    }

    #[test]
    fn test_cross_correlation_shifted() {
        let a = make_train(&[10, 30, 50], 100);
        let b = make_train(&[12, 32, 52], 100);
        let (cc, lags) = cross_correlation(&a, &b, 5.0, 0.001);
        let peak_idx = cc
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap()
            .0;
        assert!(
            (lags[peak_idx] - 2.0).abs() < 1.5,
            "peak lag should be near 2ms, got {}",
            lags[peak_idx]
        );
    }

    #[test]
    fn test_cross_correlation_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[10, 50], 100);
        let (cc, _) = cross_correlation(&a, &b, 5.0, 0.001);
        assert!(
            cc.iter().all(|&v| v == 0.0),
            "zero train → zero correlation"
        );
    }

    #[test]
    fn test_pairwise_correlation_identity() {
        let t1 = make_train(&[10, 30, 50], 100);
        let t2 = make_train(&[10, 30, 50], 100);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = pairwise_correlation(&trains, 0.001);
        assert!((corr[0][0] - 1.0).abs() < 1e-10);
        assert!((corr[0][1] - 1.0).abs() < 1e-10);
        assert!((corr[1][0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pairwise_correlation_anticorrelated() {
        let t1 = make_train(&[0, 2, 4, 6, 8], 10);
        let t2 = make_train(&[1, 3, 5, 7, 9], 10);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = pairwise_correlation(&trains, 0.001);
        assert!(
            corr[0][1] < 0.0,
            "alternating trains should be negatively correlated"
        );
    }

    #[test]
    fn test_pairwise_correlation_empty() {
        let corr = pairwise_correlation(&[], 0.001);
        let expected: Vec<Vec<f64>> = vec![vec![]];
        assert_eq!(corr, expected);
    }

    #[test]
    fn test_event_sync_identical() {
        let train = make_train(&[10, 30, 50, 70], 100);
        let score = event_synchronization(&train, &train, 0.001, 5.0);
        assert!(
            (score - 1.0).abs() < 1e-10,
            "identical trains: count=4, sqrt(16)=4, score=1.0, got {}",
            score
        );
    }

    #[test]
    fn test_event_sync_no_overlap() {
        let a = make_train(&[10], 100);
        let b = make_train(&[90], 100);
        let score = event_synchronization(&a, &b, 0.001, 2.0);
        assert_eq!(score, 0.0, "far apart spikes → zero sync");
    }

    #[test]
    fn test_event_sync_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[50], 100);
        assert_eq!(event_synchronization(&a, &b, 0.001, 5.0), 0.0);
    }

    #[test]
    fn test_coherence_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 128);
        let (coh, freqs) = spike_train_coherence(&train, &train, 0.001);
        assert!(!coh.is_empty());
        assert_eq!(coh.len(), freqs.len());
        for (i, &c) in coh.iter().enumerate() {
            // The DC bin only carries rounding residue after mean removal; skip it.
            if i == 0 {
                continue;
            }
            assert!(
                (c - 1.0).abs() < 1e-8,
                "self-coherence at freq idx {i} should be 1.0, got {c}"
            );
        }
    }

    #[test]
    fn test_coherence_short() {
        let a = vec![1i32];
        let b = vec![0i32];
        let (coh, _) = spike_train_coherence(&a, &b, 0.001);
        assert!(coh.is_empty(), "n<2 → empty");
    }

    #[test]
    fn test_sttc_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let sttc = spike_time_tiling_coefficient(&train, &train, 0.001, 5.0);
        assert!(sttc > 0.8, "identical trains → high STTC, got {sttc}");
    }

    #[test]
    fn test_sttc_no_overlap() {
        let a = make_train(&[5], 1000);
        let b = make_train(&[995], 1000);
        let sttc = spike_time_tiling_coefficient(&a, &b, 0.001, 1.0);
        assert!(sttc < 0.1, "far apart spikes → low STTC, got {sttc}");
    }

    #[test]
    fn test_sttc_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[50], 100);
        assert_eq!(spike_time_tiling_coefficient(&a, &b, 0.001, 5.0), 0.0);
    }

    #[test]
    fn test_covariance_identical() {
        let train = make_train(&[0, 1, 5, 6, 10, 11, 15, 16, 20, 21], 25);
        let trains: Vec<&[i32]> = vec![&train, &train];
        let cov = covariance_matrix(&trains, 5);
        assert!(
            (cov[0][0] - cov[0][1]).abs() < 1e-10,
            "identical trains → equal diagonal and off-diagonal"
        );
    }

    #[test]
    fn test_covariance_single() {
        let train = make_train(&[0, 1, 2, 5, 6, 10, 11, 12, 13, 14], 20);
        let trains: Vec<&[i32]> = vec![&train];
        let cov = covariance_matrix(&trains, 5);
        assert_eq!(cov.len(), 1);
        assert!(cov[0][0] > 0.0, "non-constant train → positive variance");
    }

    #[test]
    fn test_autocorr_time_bursty() {
        let train = make_train(&[0, 1, 2, 10, 11, 12, 20, 21, 22, 30, 31, 32], 40);
        let tau = autocorrelation_time(&train, 0.001, 50.0);
        assert!(
            tau > 0.0,
            "bursty train should have positive autocorrelation time, got {tau}"
        );
    }

    #[test]
    fn test_autocorr_time_silent() {
        let train = vec![0i32; 100];
        assert_eq!(autocorrelation_time(&train, 0.001, 50.0), 0.0);
    }

    #[test]
    fn test_noise_corr_identical() {
        let t1 = make_train(&[5, 15, 25, 35, 45], 50);
        let t2 = t1.clone();
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = noise_correlation(&trains, 10);
        assert!((corr[0][0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_noise_corr_diagonal() {
        let t1 = make_train(&[2, 12, 22], 30);
        let t2 = make_train(&[7, 17, 27], 30);
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = noise_correlation(&trains, 10);
        assert!((corr[0][0] - 1.0).abs() < 1e-10);
        assert!((corr[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_signal_corr_identical() {
        let t1 = make_train(&[5, 10, 11, 12], 30);
        let t2 = t1.clone();
        let trains: Vec<&[i32]> = vec![&t1, &t2];
        let corr = signal_correlation(&trains, 10);
        assert!(
            (corr[0][1] - 1.0).abs() < 1e-10,
            "identical trains → r=1.0, got {}",
            corr[0][1]
        );
    }

    #[test]
    fn test_spike_count_cov_delegates() {
        let t1 = make_train(&[0, 1, 5, 6, 10, 11], 15);
        let trains: Vec<&[i32]> = vec![&t1];
        let cov1 = covariance_matrix(&trains, 5);
        let cov2 = spike_count_covariance(&trains, 5);
        assert_eq!(cov1, cov2);
    }

    #[test]
    fn test_joint_psth_shape() {
        let a = make_train(&[0, 1, 5, 6, 10, 11, 15, 16, 20, 21], 25);
        let b = make_train(&[2, 3, 7, 8, 12, 13, 17, 18, 22, 23], 25);
        let (result, n) = joint_psth(&a, &b, 5);
        assert_eq!(n, 5);
        assert_eq!(result.len(), 25);
    }

    #[test]
    fn test_joint_psth_symmetry() {
        let train = make_train(&[0, 1, 5, 6, 10, 11, 15, 16, 20, 21], 25);
        let (result, n) = joint_psth(&train, &train, 5);
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (result[i * n + j] - result[j * n + i]).abs() < 1e-10,
                    "JPSTH of identical trains should be symmetric"
                );
            }
        }
    }

    #[test]
    fn test_coincidence_index_identical() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let ci = coincidence_index(&train, &train, 0.001, 2.0);
        assert!(
            ci > 0.5,
            "identical trains → high coincidence index, got {ci}"
        );
    }

    #[test]
    fn test_coincidence_index_no_overlap() {
        let a = make_train(&[5], 1000);
        let b = make_train(&[995], 1000);
        let ci = coincidence_index(&a, &b, 0.001, 1.0);
        assert!(ci <= 0.0, "far apart → zero or negative kappa, got {ci}");
    }

    #[test]
    fn test_coincidence_index_empty() {
        let a = vec![0i32; 100];
        let b = make_train(&[50], 100);
        assert_eq!(coincidence_index(&a, &b, 0.001, 2.0), 0.0);
    }

    #[test]
    fn test_tile_fraction_single_spike() {
        let times = vec![0.05];
        let frac = tile_fraction(&times, 0.005, 0.1);
        assert!((frac - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_tile_fraction_overlapping() {
        let times = vec![0.05, 0.052];
        let frac = tile_fraction(&times, 0.005, 0.1);
        assert!((frac - 0.12).abs() < 1e-10);
    }

    #[test]
    fn test_sttc_term_edge_cases() {
        assert_eq!(sttc_term(0.5, 1.0), 0.0);
        assert_eq!(sttc_term(0.0, 0.0), 0.0);
    }
}