1use super::basic::bin_spike_train;
10use super::correlation::cross_correlation;
11use nalgebra::{DMatrix, SymmetricEigen};
12
13pub fn functional_connectivity(trains: &[&[i32]], max_lag_ms: f64, dt: f64) -> Vec<f64> {
16 let n = trains.len();
17 let mut mat = vec![0.0_f64; n * n];
18 for i in 0..n {
19 mat[i * n + i] = 1.0;
20 for j in (i + 1)..n {
21 let (cc, _) = cross_correlation(trains[i], trains[j], max_lag_ms, dt);
22 let peak = cc.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
23 mat[i * n + j] = peak;
24 mat[j * n + i] = peak;
25 }
26 }
27 mat
28}
29
30pub fn unitary_events(trains: &[&[i32]], bin_size: usize, alpha: f64) -> Vec<usize> {
34 let n_trains = trains.len();
35 if n_trains < 2 {
36 return vec![];
37 }
38 let binned: Vec<Vec<i64>> = trains
39 .iter()
40 .map(|t| bin_spike_train(t, bin_size))
41 .collect();
42 let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
43 if min_bins == 0 {
44 return vec![];
45 }
46
47 let active: Vec<Vec<bool>> = binned
49 .iter()
50 .map(|b| b[..min_bins].iter().map(|&v| v > 0).collect())
51 .collect();
52
53 let rates: Vec<f64> = active
55 .iter()
56 .map(|row| row.iter().filter(|&&v| v).count() as f64 / min_bins as f64)
57 .collect();
58
59 let expected_rate: f64 = rates.iter().product::<f64>().powi(n_trains as i32);
60
61 let mut significant = Vec::new();
62 for k in 0..min_bins {
63 let all_active = (0..n_trains).all(|i| active[i][k]);
64 if all_active && expected_rate < alpha {
65 significant.push(k);
66 }
67 }
68 significant
69}
70
71pub fn cell_assembly_detection(
74 trains: &[&[i32]],
75 bin_size: usize,
76 threshold: f64,
77) -> Vec<Vec<usize>> {
78 let n = trains.len();
79 if n < 3 {
80 return vec![];
81 }
82 let binned: Vec<Vec<f64>> = trains
83 .iter()
84 .map(|t| {
85 bin_spike_train(t, bin_size)
86 .iter()
87 .map(|&v| v as f64)
88 .collect()
89 })
90 .collect();
91 let min_bins = binned.iter().map(|b| b.len()).min().unwrap_or(0);
92 if min_bins < 2 {
93 return vec![];
94 }
95
96 let mut mat: Vec<Vec<f64>> = binned.iter().map(|b| b[..min_bins].to_vec()).collect();
98 for row in &mut mat {
99 let mean = row.iter().sum::<f64>() / min_bins as f64;
100 let std = (row.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / min_bins as f64)
101 .sqrt()
102 .max(1e-30);
103 for v in row.iter_mut() {
104 *v = (*v - mean) / std;
105 }
106 }
107
108 let mut corr = vec![0.0_f64; n * n];
110 for i in 0..n {
111 for j in i..n {
112 let mut s = 0.0;
113 for k in 0..min_bins {
114 s += mat[i][k] * mat[j][k];
115 }
116 let c = s / min_bins as f64;
117 corr[i * n + j] = c;
118 corr[j * n + i] = c;
119 }
120 }
121
122 let (eigvals, eigvecs) = symmetric_eigen(&corr, n);
124
125 let q = n as f64 / min_bins as f64;
127 let mp_upper = (1.0 + q.sqrt()).powi(2);
128
129 let thresh_scaled = threshold / (n as f64).sqrt();
130 let mut assemblies = Vec::new();
131 for i in 0..n {
132 if eigvals[i] > mp_upper {
133 let members: Vec<usize> = (0..n)
134 .filter(|&j| eigvecs[j * n + i].abs() > thresh_scaled)
135 .collect();
136 if members.len() >= 2 {
137 assemblies.push(members);
138 }
139 }
140 }
141 assemblies
142}
143
144pub fn synfire_chain_detection(
147 trains: &[&[i32]],
148 dt: f64,
149 max_delay_ms: f64,
150 min_chain_length: usize,
151) -> Vec<Vec<usize>> {
152 let n = trains.len();
153 if n < min_chain_length {
154 return vec![];
155 }
156
157 let mut peak_lags = vec![0.0_f64; n * n];
159 for i in 0..n {
160 for j in 0..n {
161 if i == j {
162 continue;
163 }
164 let (cc, lags) = cross_correlation(trains[i], trains[j], max_delay_ms, dt);
165 if !cc.is_empty() {
166 let peak_idx = cc
167 .iter()
168 .enumerate()
169 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
170 .map(|(idx, _)| idx)
171 .unwrap_or(0);
172 peak_lags[i * n + j] = lags[peak_idx];
173 }
174 }
175 }
176
177 let mut chains = Vec::new();
178 let mut visited = vec![false; n];
179
180 for start in 0..n {
181 if visited[start] {
182 continue;
183 }
184 let mut chain = vec![start];
185 let mut current = start;
186 for _ in 0..n {
187 let mut candidates: Vec<(f64, usize)> = Vec::new();
188 for j in 0..n {
189 if chain.contains(&j) {
190 continue;
191 }
192 let lag = peak_lags[current * n + j];
193 if lag > 0.0 && lag <= max_delay_ms {
194 candidates.push((lag, j));
195 }
196 }
197 if candidates.is_empty() {
198 break;
199 }
200 candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
201 let nxt = candidates[0].1;
202 chain.push(nxt);
203 current = nxt;
204 }
205 if chain.len() >= min_chain_length {
206 for &idx in &chain {
207 visited[idx] = true;
208 }
209 chains.push(chain);
210 }
211 }
212 chains
213}
214
215fn symmetric_eigen(a: &[f64], n: usize) -> (Vec<f64>, Vec<f64>) {
222 let se = SymmetricEigen::new(DMatrix::<f64>::from_row_slice(n, n, a));
223 let mut idx: Vec<usize> = (0..n).collect();
224 idx.sort_by(|&i, &j| se.eigenvalues[j].partial_cmp(&se.eigenvalues[i]).unwrap());
225
226 let vals: Vec<f64> = idx.iter().map(|&i| se.eigenvalues[i]).collect();
227 let mut vecs = vec![0.0f64; n * n];
228 for (new_col, &old_col) in idx.iter().enumerate() {
229 let mut pivot = 0usize;
230 let mut max_abs = 0.0f64;
231 for r in 0..n {
232 let v = se.eigenvectors[(r, old_col)].abs();
233 if v > max_abs {
234 max_abs = v;
235 pivot = r;
236 }
237 }
238 let sign = if se.eigenvectors[(pivot, old_col)] < 0.0 {
239 -1.0
240 } else {
241 1.0
242 };
243 for r in 0..n {
244 vecs[r * n + new_col] = sign * se.eigenvectors[(r, old_col)];
245 }
246 }
247 (vals, vecs)
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a train of `len` zeros with a `1` at every index listed in `spikes`.
    fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
        let mut train = vec![0i32; len];
        spikes.iter().for_each(|&at| train[at] = 1);
        train
    }

    // ── functional_connectivity ─────────────────────────────────────────

    #[test]
    fn test_fc_diagonal_one() {
        let first = make_train(&[10, 30, 50], 100);
        let second = make_train(&[20, 40, 60], 100);
        let fc = functional_connectivity(&[first.as_slice(), second.as_slice()], 20.0, 0.001);
        // Entries 0 and 3 are the diagonal of the flattened 2 x 2 matrix.
        assert!((fc[0] - 1.0).abs() < 1e-10, "diagonal should be 1.0");
        assert!((fc[3] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_fc_symmetric() {
        let first = make_train(&[10, 30, 50], 100);
        let second = make_train(&[12, 32, 52], 100);
        let fc = functional_connectivity(&[first.as_slice(), second.as_slice()], 20.0, 0.001);
        let upper = fc[1];
        let lower = fc[2];
        assert!((upper - lower).abs() < 1e-10, "should be symmetric");
    }

    #[test]
    fn test_fc_identical_high() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let fc = functional_connectivity(&[train.as_slice(), train.as_slice()], 20.0, 0.001);
        assert!(
            fc[1] > 0.9,
            "identical trains → high connectivity, got {}",
            fc[1]
        );
    }

    // ── unitary_events ──────────────────────────────────────────────────

    #[test]
    fn test_ue_coincident() {
        // Two trains spiking at exactly the same, sparse positions.
        let first = make_train(&[5, 55], 200);
        let second = make_train(&[5, 55], 200);
        let events = unitary_events(&[first.as_slice(), second.as_slice()], 10, 0.05);
        assert!(
            !events.is_empty(),
            "sparse coincident spikes → significant bins"
        );
    }

    #[test]
    fn test_ue_single_train() {
        let only = make_train(&[5, 15], 50);
        let events = unitary_events(&[only.as_slice()], 5, 0.05);
        assert!(events.is_empty(), "need ≥2 trains");
    }

    #[test]
    fn test_ue_empty() {
        let silent_a = vec![0i32; 50];
        let silent_b = vec![0i32; 50];
        let events = unitary_events(&[silent_a.as_slice(), silent_b.as_slice()], 5, 0.05);
        assert!(events.is_empty(), "no spikes → no events");
    }

    // ── cell_assembly_detection ─────────────────────────────────────────

    #[test]
    fn test_assembly_too_few_neurons() {
        let first = make_train(&[5, 15], 50);
        let second = make_train(&[5, 15], 50);
        let found = cell_assembly_detection(&[first.as_slice(), second.as_slice()], 5, 2.0);
        assert!(found.is_empty(), "need ≥3 neurons");
    }

    #[test]
    fn test_assembly_correlated_group() {
        let sync = make_train(&[0, 1, 10, 11, 20, 21, 30, 31, 40, 41], 50);
        let indep = make_train(&[3, 7, 13, 17, 23, 27, 33, 37, 43, 47], 50);
        let group: [&[i32]; 4] = [&sync, &sync, &sync, &indep];
        let assemblies = cell_assembly_detection(&group, 5, 1.0);
        // Only checks that any reported member index refers to an input train.
        for &idx in assemblies.iter().flatten() {
            assert!(idx < 4, "index out of bounds");
        }
    }

    // ── synfire_chain_detection ─────────────────────────────────────────

    #[test]
    fn test_synfire_sequential() {
        // Three trains, each shifted by five samples relative to the previous one.
        let lead = make_train(&[10, 30, 50, 70, 90], 100);
        let middle = make_train(&[15, 35, 55, 75, 95], 100);
        let trail = make_train(&[20, 40, 60, 80], 100);
        let group: [&[i32]; 3] = [&lead, &middle, &trail];
        let chains = synfire_chain_detection(&group, 0.001, 10.0, 3);
        // Detection is not required; a reported chain must respect the minimum length.
        if let Some(first_chain) = chains.first() {
            assert!(first_chain.len() >= 3, "chain should have ≥3 neurons");
        }
    }

    #[test]
    fn test_synfire_too_few() {
        let train = make_train(&[10, 30], 50);
        let chains = synfire_chain_detection(&[train.as_slice(), train.as_slice()], 0.001, 10.0, 3);
        assert!(chains.is_empty(), "need ≥ min_chain_length neurons");
    }

    // ── symmetric_eigen ─────────────────────────────────────────────────

    #[test]
    fn test_symmetric_eigen_identity() {
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let (vals, _) = symmetric_eigen(&identity, 2);
        for &val in &vals[..2] {
            assert!((val - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_symmetric_eigen_descending() {
        let diagonal = vec![3.0, 0.0, 0.0, 7.0];
        let (vals, _) = symmetric_eigen(&diagonal, 2);
        // The larger eigenvalue must come first.
        assert!((vals[0] - 7.0).abs() < 1e-10);
        assert!((vals[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_symmetric_eigen_known() {
        // [[2, 1], [1, 2]] has eigenvalues 3 and 1.
        let matrix = vec![2.0, 1.0, 1.0, 2.0];
        let (vals, _) = symmetric_eigen(&matrix, 2);
        let largest = vals[0];
        let smallest = vals[1];
        assert!(
            (largest - 3.0).abs() < 1e-8,
            "eigenvalue 3, got {}",
            largest
        );
        assert!(
            (smallest - 1.0).abs() < 1e-8,
            "eigenvalue 1, got {}",
            smallest
        );
    }

    #[test]
    fn test_symmetric_eigen_eigenvectors_orthogonal() {
        let matrix = vec![2.0, 1.0, 1.0, 2.0];
        let (_, v) = symmetric_eigen(&matrix, 2);
        // Dot product of column 0 and column 1 of the row-major eigenvector buffer.
        let dot: f64 = (0..2).map(|row| v[row * 2] * v[row * 2 + 1]).sum();
        assert!(
            dot.abs() < 1e-8,
            "eigenvectors should be orthogonal, dot={dot}"
        );
    }

    #[test]
    fn test_symmetric_eigen_sign_canonical() {
        let matrix = vec![2.0, 1.0, 1.0, 2.0];
        let (_, v) = symmetric_eigen(&matrix, 2);
        for c in 0..2 {
            // The largest-magnitude component of each column must be positive.
            let pivot = if v[c].abs() >= v[2 + c].abs() { 0 } else { 1 };
            assert!(v[pivot * 2 + c] > 0.0, "column {c} not sign-canonical");
        }
    }
}