use std::collections::HashMap;

use super::basic::bin_spike_train;

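/// Digamma function psi(x) for x > 0, needed by `kozachenko_leonenko_mi`
/// below. Small arguments are shifted upward with the recurrence
/// psi(x) = psi(x + 1) - 1/x until x >= 6, where the asymptotic series
/// psi(x) ~ ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6) + 1/(240x^8)
/// is accurate to well below the tolerances used in the tests.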
fn digamma(mut x: f64) -> f64 {
    let mut result = 0.0;
    while x < 6.0 {
        result -= 1.0 / x;
        x += 1.0;
    }
    let inv_x = 1.0 / x;
    let inv_x2 = inv_x * inv_x;
    result += x.ln()
        - 0.5 * inv_x
        - inv_x2 * (1.0 / 12.0 - inv_x2 * (1.0 / 120.0 - inv_x2 * (1.0 / 252.0 - inv_x2 / 240.0)));
    result
}

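/// Plug-in (maximum-likelihood) Shannon entropy in bits from a histogram:
/// H = -sum_i p_i * log2(p_i) with p_i = counts[i] / total. Zero counts are
/// skipped. The plug-in estimator is biased downward for small samples.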
fn entropy_from_counts(counts: &[usize], total: usize) -> f64 {
    if total == 0 {
        return 0.0;
    }
    let n_inv = 1.0 / total as f64;
    let mut h = 0.0_f64;
    for &c in counts {
        if c > 0 {
            let p = c as f64 * n_inv;
            h -= p * p.log2();
        }
    }
    h
}

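/// Collapses `data` into a histogram of occurrence counts. Symbol identities
/// are discarded, since only the counts matter for entropy.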
fn count_values(data: &[i64]) -> Vec<usize> {
    let mut map: HashMap<i64, usize> = HashMap::new();
    for &v in data {
        *map.entry(v).or_insert(0) += 1;
    }
    map.into_values().collect()
}

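/// Plug-in mutual information (bits) between two spike trains, computed on
/// binned counts via I(A; B) = H(A) + H(B) - H(A, B). Joint symbols are
/// packed collision-free as `a * (max_b + 1) + b`. Plug-in MI is biased
/// upward for finite data, so small positive values on independent trains
/// are expected; the result is clamped at zero.
///
/// A minimal usage sketch (spike train and bin size are illustrative):
///
/// ```ignore
/// let a = vec![0, 1, 0, 0, 1, 1, 0, 1, 0, 0];
/// let mi = mutual_information(&a, &a, 2);
/// assert!(mi >= 0.0); // MI of a train with itself equals its binned entropy
/// ```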
pub fn mutual_information(train_a: &[i32], train_b: &[i32], bin_size: usize) -> f64 {
    let ca = bin_spike_train(train_a, bin_size);
    let cb = bin_spike_train(train_b, bin_size);
    let n = ca.len().min(cb.len());
    if n == 0 {
        return 0.0;
    }

    let ha = entropy_from_counts(&count_values(&ca[..n]), n);
    let hb = entropy_from_counts(&count_values(&cb[..n]), n);

    // Pack each (a, b) pair into one symbol; stride max_b + 1 avoids collisions.
    let max_b = cb[..n].iter().copied().max().unwrap_or(0);
    let joint: Vec<i64> = (0..n).map(|i| ca[i] * (max_b + 1) + cb[i]).collect();
    let hab = entropy_from_counts(&count_values(&joint), n);

    (ha + hb - hab).max(0.0)
}

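/// Transfer entropy (bits) from `source` to `target` with history length 1:
///
/// TE = H(T_fut | T_past) - H(T_fut | T_past, S_past)
///    = [H(T_fut, T_past) - H(T_past)] - [H(T_fut, T_past, S_past) - H(T_past, S_past)],
///
/// where T_fut is the target shifted forward by `lag` bins. Joint states are
/// packed into single integers with collision-free strides, and the result
/// is clamped at zero. Returns 0.0 when the binned length is <= `lag`.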
pub fn transfer_entropy(source: &[i32], target: &[i32], bin_size: usize, lag: usize) -> f64 {
    let cs = bin_spike_train(source, bin_size);
    let ct = bin_spike_train(target, bin_size);
    let n = cs.len().min(ct.len());
    if n <= lag {
        return 0.0;
    }

    let t_past = &ct[..n - lag];
    let t_future = &ct[lag..n];
    let s_past = &cs[..n - lag];
    let n_pts = t_past.len();

    // h1 = H(T_fut | T_past) = H(T_fut, T_past) - H(T_past).
    let max_tp = t_past.iter().copied().max().unwrap_or(0) + 1;
    let joint_ft: Vec<i64> = (0..n_pts)
        .map(|i| t_future[i] * max_tp + t_past[i])
        .collect();
    let h_ft = entropy_from_counts(&count_values(&joint_ft), n_pts);
    let h_tp = entropy_from_counts(&count_values(t_past), n_pts);
    let h1 = h_ft - h_tp;

    // h2 = H(T_fut | T_past, S_past) = H(T_fut, T_past, S_past) - H(T_past, S_past).
    let max_sp = s_past.iter().copied().max().unwrap_or(0) + 1;
    let past_joint: Vec<i64> = (0..n_pts).map(|i| t_past[i] * max_sp + s_past[i]).collect();
    let max_pj = past_joint.iter().copied().max().unwrap_or(0) + 1;
    let joint_fts: Vec<i64> = (0..n_pts)
        .map(|i| t_future[i] * max_pj + past_joint[i])
        .collect();
    let h_fts = entropy_from_counts(&count_values(&joint_fts), n_pts);
    let h_ps = entropy_from_counts(&count_values(&past_joint), n_pts);
    let h2 = h_fts - h_ps;

    (h1 - h2).max(0.0)
}

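/// Word entropy of a binarized spike train, in the spirit of the direct
/// method: the train is binned, each bin thresholded to 0/1, overlapping
/// words of `word_length` consecutive bins are read off as integers, and the
/// plug-in entropy of the word distribution is returned (bits per word).
/// Returns NaN when fewer than `word_length` bins are available. Overlapping
/// words are not independent, so treat short-train values as approximate.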
pub fn spike_train_entropy(binary_train: &[i32], bin_size: usize, word_length: usize) -> f64 {
    let binned: Vec<i64> = bin_spike_train(binary_train, bin_size)
        .iter()
        .map(|&v| if v > 0 { 1_i64 } else { 0_i64 })
        .collect();
    let n = binned.len();
    if n < word_length {
        return f64::NAN;
    }
    let n_words = n - word_length + 1;
    let mut words = Vec::with_capacity(n_words);
    for i in 0..n_words {
        let mut w = 0_i64;
        for j in 0..word_length {
            w = w * 2 + binned[i + j];
        }
        words.push(w);
    }
    entropy_from_counts(&count_values(&words), n_words)
}

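/// Average per-trial word entropy: the train is split into `n_trials` equal
/// contiguous segments and `spike_train_entropy` is averaged over them. This
/// is a simplified stand-in for the direct-method noise entropy, which is
/// usually computed across trials at fixed time points. Returns NaN when
/// `n_trials` is zero or each trial is shorter than one word.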
pub fn noise_entropy(
    binary_train: &[i32],
    n_trials: usize,
    bin_size: usize,
    word_length: usize,
) -> f64 {
    // Guard against division by zero below.
    if n_trials == 0 {
        return f64::NAN;
    }
    let n = binary_train.len();
    let trial_len = n / n_trials;
    if trial_len < bin_size * word_length {
        return f64::NAN;
    }
    let mut sum = 0.0_f64;
    let mut count = 0_usize;
    for t in 0..n_trials {
        let start = t * trial_len;
        let end = start + trial_len;
        let h = spike_train_entropy(&binary_train[start..end], bin_size, word_length);
        if !h.is_nan() {
            sum += h;
            count += 1;
        }
    }
    if count == 0 {
        return f64::NAN;
    }
    sum / count as f64
}

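/// Rate-based stimulus information per unit mean rate, matching the Skaggs
/// "bits per spike" form: SSI = sum_s p(s) * (r_s / r) * log2(r_s / r),
/// where r_s is the mean response to stimulus s and r is the overall mean.
/// Returns 0.0 for empty input or a non-positive overall mean.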
pub fn stimulus_specific_information(spike_counts: &[f64], stimulus_ids: &[i64]) -> f64 {
    let n_total = spike_counts.len().min(stimulus_ids.len());
    if n_total == 0 {
        return 0.0;
    }

    let overall_mean: f64 = spike_counts[..n_total].iter().sum::<f64>() / n_total as f64;
    if overall_mean <= 0.0 {
        return 0.0;
    }

    let mut groups: HashMap<i64, Vec<f64>> = HashMap::new();
    for i in 0..n_total {
        groups
            .entry(stimulus_ids[i])
            .or_default()
            .push(spike_counts[i]);
    }

    let mut ssi = 0.0_f64;
    for counts in groups.values() {
        let n_s = counts.len() as f64;
        let p_s = n_s / n_total as f64;
        let mean_s: f64 = counts.iter().sum::<f64>() / n_s;
        if mean_s > 0.0 {
            ssi += p_s * mean_s * (mean_s / overall_mean).log2() / overall_mean;
        }
    }
    ssi.max(0.0)
}

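/// k-nearest-neighbour mutual information estimator. Despite the name, this
/// follows the Kraskov-Stögbauer-Grassberger (KSG, algorithm 1) recipe built
/// on Kozachenko-Leonenko entropy estimates:
///
/// I(X; Y) ~= psi(k) + psi(n) - <psi(n_x + 1) + psi(n_y + 1)>,
///
/// where eps_i is the Chebyshev distance to the k-th joint neighbour of
/// point i and n_x, n_y count marginal neighbours strictly within eps_i.
/// The brute-force search is O(n^2 log n), fine for the sample sizes here.
///
/// A minimal usage sketch (signal and k are illustrative):
///
/// ```ignore
/// let x: Vec<f64> = (0..200).map(|i| (0.37 * i as f64).sin()).collect();
/// let mi = kozachenko_leonenko_mi(&x, &x, 3); // strong dependence → large MI
/// ```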
pub fn kozachenko_leonenko_mi(x: &[f64], y: &[f64], k: usize) -> f64 {
    let n = x.len().min(y.len());
    // k == 0 would underflow the dists index below; reject it up front.
    if k == 0 || n < k + 1 {
        return 0.0;
    }

    let psi_k = digamma(k as f64);
    let psi_n = digamma(n as f64);

    let mut nx_sum = 0.0_f64;
    let mut ny_sum = 0.0_f64;

    for i in 0..n {
        // Chebyshev (max-norm) distances in the joint (x, y) space.
        let mut dists: Vec<f64> = (0..n)
            .filter(|&j| j != i)
            .map(|j| (x[i] - x[j]).abs().max((y[i] - y[j]).abs()))
            .collect();
        dists.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let eps = dists[k - 1];

        // Marginal neighbour counts strictly inside the k-th joint distance.
        let nx = (0..n)
            .filter(|&j| j != i && (x[i] - x[j]).abs() < eps)
            .count();
        let ny = (0..n)
            .filter(|&j| j != i && (y[i] - y[j]).abs() < eps)
            .count();

        nx_sum += digamma((nx + 1) as f64);
        ny_sum += digamma((ny + 1) as f64);
    }

    (psi_k + psi_n - nx_sum / n as f64 - ny_sum / n as f64).max(0.0)
}

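/// Kolmogorov-Smirnov goodness-of-fit test via the time-rescaling theorem:
/// if spikes follow an inhomogeneous Poisson process with rate lambda(t),
/// the rescaled intervals r_i = integral of lambda(t) dt over (t_{i-1}, t_i]
/// are i.i.d. Exp(1), so z_i = 1 - exp(-r_i) are Uniform(0, 1). Each
/// integral is approximated with a 20-point trapezoid rule. Returns the KS
/// statistic against the uniform CDF and whether it falls below the
/// asymptotic 5% critical value 1.36 / sqrt(n). Fewer than 5 spikes in the
/// window returns (1.0, false).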
pub fn time_rescaling_ks_test(
    times: &[f64],
    rate_func: fn(f64) -> f64,
    t_start: f64,
    t_end: f64,
) -> (f64, bool) {
    let mut sorted: Vec<f64> = times
        .iter()
        .copied()
        .filter(|&t| t >= t_start && t <= t_end)
        .collect();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let n = sorted.len();
    if n < 5 {
        return (1.0, false);
    }

    // Rescale each inter-spike interval by the integrated rate (trapezoid rule).
    let n_quad = 20_usize;
    let mut rescaled = Vec::with_capacity(n);
    for i in 0..n {
        let lo = if i == 0 { t_start } else { sorted[i - 1] };
        let hi = sorted[i];
        let step = (hi - lo) / (n_quad - 1).max(1) as f64;
        let mut integral = 0.0_f64;
        for q in 0..n_quad {
            let t = lo + q as f64 * step;
            let w = if q == 0 || q == n_quad - 1 { 0.5 } else { 1.0 };
            integral += w * rate_func(t) * step;
        }
        rescaled.push(integral);
    }

    // Under the model the rescaled intervals are Exp(1), so these are Uniform(0, 1).
    let mut transformed: Vec<f64> = rescaled.iter().map(|&r| 1.0 - (-r).exp()).collect();
    transformed.sort_by(|a, b| a.partial_cmp(b).unwrap());

    // Two-sided KS statistic: check the ECDF step both above and below each point.
    let mut ks = 0.0_f64;
    for i in 0..n {
        let d_plus = (i + 1) as f64 / n as f64 - transformed[i];
        let d_minus = transformed[i] - i as f64 / n as f64;
        ks = ks.max(d_plus).max(d_minus);
    }

    let critical_95 = 1.36 / (n as f64).sqrt();
    (ks, ks < critical_95)
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
        let mut t = vec![0i32; len];
        for &s in spikes {
            t[s] = 1;
        }
        t
    }

    #[test]
    fn test_digamma_known_values() {
        assert!((digamma(1.0) - (-0.5772156649)).abs() < 1e-8);
        assert!((digamma(2.0) - 0.4227843351).abs() < 1e-8);
        assert!((digamma(0.5) - (-1.9635100260)).abs() < 1e-7);
    }

    #[test]
    fn test_mi_identical() {
        let train = make_train(&[0, 1, 2, 10, 11, 12, 20, 21, 22], 30);
        let mi = mutual_information(&train, &train, 5);
        assert!(mi > 0.0, "identical trains → positive MI, got {mi}");
    }

    #[test]
    fn test_mi_non_negative() {
        let a = make_train(&[5, 15, 25], 30);
        let b = make_train(&[0, 10, 20], 30);
        let mi = mutual_information(&a, &b, 5);
        assert!(mi >= 0.0, "MI must be non-negative");
    }

    #[test]
    fn test_mi_zero_constant() {
        let a = vec![0i32; 50];
        let b = vec![0i32; 50];
        let mi = mutual_information(&a, &b, 10);
        assert!(mi.abs() < 1e-10, "constant trains → MI ≈ 0, got {mi}");
    }

    #[test]
    fn test_te_non_negative() {
        let source = make_train(&[5, 15, 25, 35, 45], 50);
        let target = make_train(&[7, 17, 27, 37, 47], 50);
        let te = transfer_entropy(&source, &target, 5, 1);
        assert!(te >= 0.0, "TE must be non-negative");
    }

    #[test]
    fn test_te_short_returns_zero() {
        let source = make_train(&[1], 5);
        let target = make_train(&[2], 5);
        let te = transfer_entropy(&source, &target, 5, 10);
        assert_eq!(te, 0.0, "n <= lag → 0");
    }

    #[test]
    fn test_te_self_zero() {
        let train = make_train(&[5, 15, 25, 35, 45], 50);
        let te = transfer_entropy(&train, &train, 5, 1);
        assert!(te < 1e-10, "TE(X→X) should be ~0, got {te}");
    }

    #[test]
    fn test_entropy_constant() {
        let train = vec![0i32; 100];
        let h = spike_train_entropy(&train, 10, 4);
        assert!(h.abs() < 1e-10, "constant → entropy 0, got {h}");
    }

    #[test]
    fn test_entropy_all_ones_binary() {
        let train = vec![1i32; 100];
        let h = spike_train_entropy(&train, 10, 4);
        assert!(h.abs() < 1e-10, "uniform → entropy 0, got {h}");
    }

    #[test]
    fn test_entropy_non_negative() {
        let train = make_train(&[5, 15, 25, 45, 55, 85], 100);
        let h = spike_train_entropy(&train, 10, 4);
        assert!(h >= 0.0 || h.is_nan(), "entropy must be non-negative");
    }

    #[test]
    fn test_entropy_short_nan() {
        let train = make_train(&[0, 1], 5);
        let h = spike_train_entropy(&train, 10, 4);
        assert!(h.is_nan(), "too short → NaN");
    }

    #[test]
    fn test_noise_entropy_constant() {
        let train = vec![0i32; 500];
        let h = noise_entropy(&train, 5, 10, 4);
        assert!(h.abs() < 1e-10 || h.is_nan(), "constant → 0 or NaN");
    }

    #[test]
    fn test_noise_entropy_too_short() {
        let train = vec![0i32; 10];
        let h = noise_entropy(&train, 10, 10, 4);
        assert!(h.is_nan(), "too short → NaN");
    }

    #[test]
    fn test_ssi_uniform() {
        let counts = vec![5.0, 5.0, 5.0, 5.0];
        let stim = vec![0_i64, 1, 0, 1];
        let ssi = stimulus_specific_information(&counts, &stim);
        assert!(ssi.abs() < 1e-10, "uniform response → SSI 0, got {ssi}");
    }

    #[test]
    fn test_ssi_selective() {
        let counts = vec![10.0, 1.0, 10.0, 1.0];
        let stim = vec![0_i64, 1, 0, 1];
        let ssi = stimulus_specific_information(&counts, &stim);
        assert!(ssi > 0.0, "selective response → positive SSI, got {ssi}");
    }

    #[test]
    fn test_ssi_empty() {
        let ssi = stimulus_specific_information(&[], &[]);
        assert_eq!(ssi, 0.0);
    }

    #[test]
    fn test_kl_mi_identical() {
        let x: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
        let y = x.clone();
        let mi = kozachenko_leonenko_mi(&x, &y, 3);
        assert!(mi > 0.0, "identical signals → positive MI, got {mi}");
    }

    #[test]
    fn test_kl_mi_independent() {
        let x: Vec<f64> = (0..100).map(|i| (i % 7) as f64).collect();
        let y: Vec<f64> = (0..100).map(|i| (i % 11) as f64).collect();
        let mi = kozachenko_leonenko_mi(&x, &y, 3);
        assert!(mi < 1.0, "roughly independent → low MI, got {mi}");
    }

    #[test]
    fn test_kl_mi_too_few() {
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];
        assert_eq!(kozachenko_leonenko_mi(&x, &y, 3), 0.0, "n < k+1 → 0");
    }

    #[test]
    fn test_ks_constant_rate() {
        fn rate(_t: f64) -> f64 {
            100.0
        }
        let times: Vec<f64> = (0..50).map(|i| i as f64 * 0.02).collect();
        let (ks, _passes) = time_rescaling_ks_test(&times, rate, 0.0, 1.0);
        assert!((0.0..=1.0).contains(&ks), "KS stat in [0,1], got {ks}");
    }

    #[test]
    fn test_ks_too_few_spikes() {
        fn rate(_t: f64) -> f64 {
            100.0
        }
        let (ks, passes) = time_rescaling_ks_test(&[0.5], rate, 0.0, 1.0);
        assert_eq!(ks, 1.0);
        assert!(!passes);
    }

    #[test]
    fn test_entropy_single_symbol() {
        let h = entropy_from_counts(&[10], 10);
        assert!(h.abs() < 1e-10, "single symbol → entropy 0");
    }

    #[test]
    fn test_entropy_uniform_two() {
        let h = entropy_from_counts(&[5, 5], 10);
        assert!((h - 1.0).abs() < 1e-10, "uniform binary → 1 bit, got {h}");
    }

    #[test]
    fn test_entropy_uniform_four() {
        let h = entropy_from_counts(&[25, 25, 25, 25], 100);
        assert!((h - 2.0).abs() < 1e-10, "uniform 4-ary → 2 bits, got {h}");
    }
}