1use std::collections::HashSet;
13
/// A spatio-temporal spike pattern reported by [`spade_detect`].
///
/// `neurons` and `lags` are parallel vectors: `lags[i]` belongs to
/// `neurons[i]`. `PartialEq` is derived so whole patterns can be compared
/// directly (e.g. in reproducibility checks).
#[derive(Debug, Clone, PartialEq)]
pub struct SpadePattern {
    /// Indices (into the input `trains` slice) of the participating neurons.
    /// The first entry is the reference neuron of the pattern.
    pub neurons: Vec<usize>,
    /// Per-neuron lag in bins; the reference neuron has lag 0. A lag `L`
    /// means the neuron's bin `b - L` is matched against reference bin `b`.
    pub lags: Vec<i64>,
    /// Number of bins in which the lag-aligned pattern occurs in the data.
    pub count: usize,
    /// Surrogate-based p-value, computed as `(k + 1) / (n_surrogates + 1)`
    /// where `k` is the number of surrogates reaching at least `count`.
    pub p_value: f64,
}
26
/// Converts per-step spike trains into a binary neuron-by-bin matrix.
///
/// Row `i` describes `trains[i]`; entry `b` is 1 when any sample in the
/// half-open step range `[b * bin_steps, (b + 1) * bin_steps)` is positive,
/// and 0 otherwise. Bins that start at or beyond the end of a train are 0,
/// and a bin that is only partially covered by the train uses the samples
/// that exist.
///
/// * `trains` - one sample slice per neuron; any value `> 0` counts as a spike.
/// * `bin_steps` - number of samples per bin.
/// * `n_bins` - number of bins (columns) to produce for every row.
fn build_binary_matrix(trains: &[&[i32]], bin_steps: usize, n_bins: usize) -> Vec<Vec<u8>> {
    let mut matrix = Vec::with_capacity(trains.len());
    for train in trains {
        let len = train.len();
        let row: Vec<u8> = (0..n_bins)
            .map(|bin| {
                let lo = bin * bin_steps;
                let hi = ((bin + 1) * bin_steps).min(len);
                // `lo < len` is checked first so the slice is never built
                // with `lo > hi`.
                u8::from(lo < len && train[lo..hi].iter().any(|&sample| sample > 0))
            })
            .collect();
        matrix.push(row);
    }
    matrix
}
46
/// Mines sets of neurons that are active together in at least `min_support`
/// bins, using a level-wise (Apriori-style) search.
///
/// * `binary_matrix` - one row per neuron, one column per bin; a value `> 0`
///   means the neuron is active in that bin. All rows are expected to have
///   the same length as the first row.
/// * `min_support` - minimum number of bins in which a set must be jointly
///   active to be kept.
/// * `max_size` - largest set size to explore; sizes below 2 yield only
///   single-neuron sets.
///
/// Returns `(sorted neuron indices, support)` pairs. The output order is
/// deterministic: single neurons in index order, followed by each larger
/// size in turn, with sets of equal size in lexicographic order.
fn find_frequent_itemsets(
    binary_matrix: &[Vec<u8>],
    min_support: usize,
    max_size: usize,
) -> Vec<(Vec<usize>, usize)> {
    let Some(first_row) = binary_matrix.first() else {
        return Vec::new();
    };
    let n_bins = first_row.len();

    let mut freq: Vec<(Vec<usize>, usize)> = Vec::new();
    let mut prev_level: Vec<Vec<usize>> = Vec::new();

    // Level 1: every neuron that is active in enough bins on its own.
    for (nid, row) in binary_matrix.iter().enumerate() {
        let support: usize = row.iter().map(|&v| usize::from(v)).sum();
        if support >= min_support {
            freq.push((vec![nid], support));
            prev_level.push(vec![nid]);
        }
    }

    for k in 2..=max_size {
        if prev_level.len() < 2 {
            break;
        }

        // Join every pair of frequent (k-1)-sets and keep unions of size k.
        let mut unique: HashSet<Vec<usize>> = HashSet::new();
        for (i, a) in prev_level.iter().enumerate() {
            for b in &prev_level[i + 1..] {
                let mut union = a.clone();
                union.extend(b.iter().copied().filter(|v| !a.contains(v)));
                if union.len() == k {
                    union.sort_unstable();
                    unique.insert(union);
                }
            }
        }

        // `HashSet` iteration order is arbitrary and differs between set
        // instances, so fix a canonical order before counting. Otherwise the
        // order of the returned itemsets would change from call to call.
        let mut candidates: Vec<Vec<usize>> = unique.into_iter().collect();
        candidates.sort_unstable();

        prev_level = Vec::new();
        for set in candidates {
            let support = (0..n_bins)
                .filter(|&b| set.iter().all(|&nid| binary_matrix[nid][b] > 0))
                .count();
            if support >= min_support {
                freq.push((set.clone(), support));
                prev_level.push(set);
            }
        }
    }

    freq
}
113
/// Assigns a lag to every non-reference neuron of each itemset and counts how
/// often the resulting lag-aligned pattern occurs.
///
/// For each itemset with at least two neurons, the first neuron is the
/// reference (lag 0). The remaining neurons are processed in order: for each
/// one, lags `0..=max_lag_bins` are tried and the smallest lag with the
/// largest overlap against the running coincidence vector is kept (lag 0 if
/// nothing overlaps). A lag `L` pairs the neuron's bin `b - L` with reference
/// bin `b`. The coincidence vector is then narrowed by the chosen alignment.
///
/// * `trains` - one sample slice per neuron; any value `> 0` is a spike.
/// * `itemsets` - `(neuron indices, support)` pairs; the support is ignored.
/// * `bin_steps` - samples per bin.
/// * `n_bins` - number of bins considered.
/// * `max_lag_bins` - largest lag (in bins) to try.
///
/// Returns `(neurons, lags, count)` for every itemset whose final coincidence
/// count is non-zero, in the order of `itemsets`.
///
/// # Panics
///
/// Panics if an itemset references a neuron index outside `trains`.
fn extend_to_spatiotemporal(
    trains: &[&[i32]],
    itemsets: &[(Vec<usize>, usize)],
    bin_steps: usize,
    n_bins: usize,
    max_lag_bins: usize,
) -> Vec<(Vec<usize>, Vec<i64>, usize)> {
    // Same binning rule as `build_binary_matrix`, applied to a single train.
    // Binning once per neuron avoids re-deriving it for every candidate lag.
    let bin_train = |train: &[i32]| -> Vec<u8> {
        (0..n_bins)
            .map(|b| {
                let start = b * bin_steps;
                let end = ((b + 1) * bin_steps).min(train.len());
                u8::from(start < train.len() && train[start..end].iter().any(|&v| v > 0))
            })
            .collect()
    };

    let mut patterns = Vec::new();

    for (neurons, _sync_count) in itemsets {
        if neurons.len() < 2 {
            continue;
        }

        let mut coincidence = bin_train(trains[neurons[0]]);
        let mut lags: Vec<i64> = Vec::with_capacity(neurons.len());
        lags.push(0);

        for &nid in &neurons[1..] {
            let follower = bin_train(trains[nid]);

            let mut best_lag = 0usize;
            let mut best_overlap = 0usize;
            for lag in 0..=max_lag_bins {
                // Shifting the follower right by `lag` pairs coincidence[b]
                // with follower[b - lag]; bins with b < lag have no partner.
                let overlap = coincidence
                    .iter()
                    .skip(lag)
                    .zip(&follower)
                    .filter(|&(&c, &f)| c & f != 0)
                    .count();
                // Strict comparison keeps the smallest lag among ties.
                if overlap > best_overlap {
                    best_overlap = overlap;
                    best_lag = lag;
                }
            }
            lags.push(best_lag as i64);

            // Keep only bins where the follower, at its chosen lag, is active.
            for (b, c) in coincidence.iter_mut().enumerate() {
                *c &= if b >= best_lag { follower[b - best_lag] } else { 0 };
            }
        }

        let count: usize = coincidence.iter().map(|&v| usize::from(v)).sum();
        if count > 0 {
            patterns.push((neurons.clone(), lags, count));
        }
    }

    patterns
}
201
202pub fn spade_detect(
206 trains: &[&[i32]],
207 bin_ms: f64,
208 dt: f64,
209 min_support: usize,
210 max_pattern_size: usize,
211 n_surrogates: usize,
212 alpha: f64,
213 seed: u64,
214) -> Vec<SpadePattern> {
215 let n_neurons = trains.len();
216 if n_neurons < 2 {
217 return vec![];
218 }
219 let bin_steps = (bin_ms / (dt * 1000.0)).round().max(1.0) as usize;
220 let duration = trains.iter().map(|t| t.len()).max().unwrap_or(0);
221 let n_bins = duration / bin_steps;
222 if n_bins == 0 {
223 return vec![];
224 }
225
226 let binary_matrix = build_binary_matrix(trains, bin_steps, n_bins);
227 let itemsets = find_frequent_itemsets(&binary_matrix, min_support, max_pattern_size);
228 if itemsets.is_empty() {
229 return vec![];
230 }
231
232 let patterns = extend_to_spatiotemporal(trains, &itemsets, bin_steps, n_bins, 10);
233 if patterns.is_empty() {
234 return vec![];
235 }
236
237 let mut rng = seed;
239 let mut results = Vec::new();
240
241 for (neuron_list, lag_list, count) in &patterns {
242 let mut surr_counts = vec![0usize; n_surrogates];
243
244 for s in 0..n_surrogates {
245 let surr_trains: Vec<Vec<i32>> = (0..n_neurons)
247 .map(|i| {
248 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
249 let shift = (rng % (bin_steps as u64 * 10 + 1)) as i64 - (bin_steps as i64 * 5);
250 let n = trains[i].len();
251 if n == 0 {
252 return vec![];
253 }
254 let mut shifted = vec![0i32; n];
255 for j in 0..n {
256 let src = ((j as i64 - shift).rem_euclid(n as i64)) as usize;
257 shifted[j] = trains[i][src];
258 }
259 shifted
260 })
261 .collect();
262
263 let surr_binary = build_binary_matrix(
264 &surr_trains.iter().map(|v| v.as_slice()).collect::<Vec<_>>(),
265 bin_steps,
266 n_bins,
267 );
268
269 let mut coincidence = vec![1u8; n_bins];
271 for (idx, (&nid, &lag)) in neuron_list.iter().zip(lag_list.iter()).enumerate() {
272 let _ = idx;
273 for b in 0..n_bins {
274 let src_b = b as i64 - lag;
275 if src_b >= 0 && (src_b as usize) < n_bins {
276 coincidence[b] &= surr_binary[nid][src_b as usize];
277 } else {
278 coincidence[b] = 0;
279 }
280 }
281 }
282 surr_counts[s] = coincidence.iter().map(|&v| v as usize).sum();
283 }
284
285 let p_value = (surr_counts.iter().filter(|&&c| c >= *count).count() + 1) as f64
286 / (n_surrogates + 1) as f64;
287 if p_value <= alpha {
288 results.push(SpadePattern {
289 neurons: neuron_list.clone(),
290 lags: lag_list.clone(),
291 count: *count,
292 p_value,
293 });
294 }
295 }
296
297 results
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds three 500-step trains: neurons 0 and 1 spike together every 20
    /// steps, neuron 2 follows two steps later, and sparse pseudo-random
    /// background spikes are added to all three.
    fn make_correlated_trains() -> Vec<Vec<i32>> {
        let n = 500;
        let mut trains = vec![vec![0i32; n]; 3];
        for i in (0..n).step_by(20) {
            trains[0][i] = 1;
            trains[1][i] = 1;
            if i + 2 < n {
                trains[2][i + 2] = 1;
            }
        }
        let mut rng = 42u64;
        for t in &mut trains {
            for j in 0..n {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                if rng.is_multiple_of(50) && t[j] == 0 {
                    t[j] = 1;
                }
            }
        }
        trains
    }

    #[test]
    fn test_spade_detects_patterns() {
        let trains = make_correlated_trains();
        let refs: Vec<&[i32]> = trains.iter().map(|t| t.as_slice()).collect();
        let results = spade_detect(&refs, 5.0, 0.001, 3, 3, 50, 0.05, 42);
        assert!(
            !results.is_empty(),
            "SPADE should detect synchronous patterns"
        );
    }

    #[test]
    fn test_spade_pattern_fields() {
        let trains = make_correlated_trains();
        let refs: Vec<&[i32]> = trains.iter().map(|t| t.as_slice()).collect();
        let results = spade_detect(&refs, 5.0, 0.001, 3, 3, 50, 0.05, 42);
        for pat in &results {
            assert!(pat.neurons.len() >= 2);
            assert_eq!(pat.neurons.len(), pat.lags.len());
            assert!(pat.count > 0);
            assert!(pat.p_value <= 0.05);
            assert!(pat.p_value > 0.0);
        }
    }

    #[test]
    fn test_spade_empty() {
        let results = spade_detect(&[], 5.0, 0.001, 3, 3, 50, 0.05, 42);
        assert!(results.is_empty());
    }

    #[test]
    fn test_spade_single_neuron() {
        let train = vec![1, 0, 1, 0, 1];
        let results = spade_detect(&[&train], 5.0, 0.001, 1, 3, 50, 0.05, 42);
        assert!(results.is_empty());
    }

    #[test]
    fn test_spade_no_pattern() {
        let n = 200;
        let mut trains = Vec::new();
        let mut rng = 42u64;
        for _ in 0..3 {
            let mut t = vec![0i32; n];
            for j in 0..n {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                if rng.is_multiple_of(20) {
                    t[j] = 1;
                }
            }
            trains.push(t);
        }
        let refs: Vec<&[i32]> = trains.iter().map(|t| t.as_slice()).collect();
        let results = spade_detect(&refs, 5.0, 0.001, 5, 3, 100, 0.01, 42);
        // Unstructured input may or may not yield a pattern by chance, so
        // only check that the call completes and that anything it reports
        // respects the requested threshold and is well-formed.
        for pat in &results {
            assert!(pat.p_value <= 0.01);
            assert!(pat.count > 0);
            assert_eq!(pat.neurons.len(), pat.lags.len());
        }
    }

    #[test]
    fn test_find_frequent_itemsets() {
        let matrix = vec![
            vec![1, 1, 0, 1, 1, 0, 1, 0, 1, 1],
            vec![1, 1, 0, 1, 1, 0, 1, 0, 1, 1],
            vec![0, 0, 1, 0, 0, 1, 0, 1, 0, 0],
        ];
        let itemsets = find_frequent_itemsets(&matrix, 3, 3);
        let has_pair = itemsets
            .iter()
            .any(|(s, _)| s.len() == 2 && s.contains(&0) && s.contains(&1));
        assert!(has_pair, "Should find {{0,1}} as frequent pair");
    }

    #[test]
    fn test_build_binary_matrix() {
        let train = vec![0, 0, 1, 0, 0, 1, 0, 0, 0, 0];
        let mat = build_binary_matrix(&[&train], 5, 2);
        assert_eq!(mat.len(), 1);
        assert_eq!(mat[0], vec![1, 1]);
    }

    #[test]
    fn test_spade_deterministic() {
        let trains = make_correlated_trains();
        let refs: Vec<&[i32]> = trains.iter().map(|t| t.as_slice()).collect();
        let r1 = spade_detect(&refs, 5.0, 0.001, 3, 3, 30, 0.05, 42);
        let r2 = spade_detect(&refs, 5.0, 0.001, 3, 3, 30, 0.05, 42);
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.neurons, b.neurons);
            assert_eq!(a.lags, b.lags);
            assert_eq!(a.count, b.count);
            assert!((a.p_value - b.p_value).abs() < 1e-12);
        }
    }
}