// sc_neurocore_engine/analysis/stimulus.rs

/// Spike-triggered average (STA) of a stimulus.
///
/// For every index `t` in `window_steps..n` with `binary_train[t] > 0`, the
/// `window_steps` stimulus samples strictly preceding `t` are collected. Here `n` is
/// the shorter of the two input lengths, and the collected samples are
/// `stimulus[t - window_steps..t]`. The result is their element-wise mean, so entry
/// `j` corresponds to `stimulus[t - window_steps + j]`.
///
/// Any positive train value counts as exactly one spike; counts are not weighted.
/// Spikes earlier than `window_steps` are skipped because their window would run
/// off the start of the stimulus.
///
/// Returns a vector of length `window_steps`. It is all zeros when no usable spike
/// exists.
pub fn spike_triggered_average(
    stimulus: &[f64],
    binary_train: &[i32],
    window_steps: usize,
) -> Vec<f64> {
    let usable = stimulus.len().min(binary_train.len());
    let mut sum = vec![0.0f64; window_steps];
    let mut n_spikes = 0usize;

    for t in (window_steps..usable).filter(|&t| binary_train[t] > 0) {
        let window = &stimulus[t - window_steps..t];
        for (acc, &x) in sum.iter_mut().zip(window) {
            *acc += x;
        }
        n_spikes += 1;
    }

    if n_spikes == 0 {
        // Nothing was accumulated, so `sum` is still the all-zero vector.
        return sum;
    }
    let count = n_spikes as f64;
    sum.into_iter().map(|s| s / count).collect()
}

/// Spike-triggered covariance (STC) of a stimulus.
///
/// The function proceeds in three steps:
///
/// 1. For every index `t` in `window_steps..n` with `binary_train[t] > 0`, it
///    collects the `window_steps` stimulus samples strictly preceding `t`. Here `n` is
///    the shorter input length.
/// 2. It subtracts the column mean of those windows (the spike-triggered average).
/// 3. It returns the sample covariance with divisor `m - 1`, where `m` is the number of
///    usable spikes.
///
/// The result is a row-major `window_steps x window_steps` matrix flattened into a
/// `Vec`: element `(i, j)` is at index `i * window_steps + j`. The matrix is
/// symmetric.
///
/// With fewer than three usable spikes, the identity matrix is returned instead.
pub fn spike_triggered_covariance(
    stimulus: &[f64],
    binary_train: &[i32],
    window_steps: usize,
) -> Vec<f64> {
    let w = window_steps;
    let usable = stimulus.len().min(binary_train.len());
    let spike_times: Vec<usize> = (w..usable).filter(|&t| binary_train[t] > 0).collect();

    if spike_times.len() < 3 {
        // Fallback: identity matrix. Diagonal entries sit every `w + 1` elements.
        let mut identity = vec![0.0; w * w];
        identity.iter_mut().step_by(w + 1).for_each(|d| *d = 1.0);
        return identity;
    }

    // One pre-spike window per spike, borrowed directly from the stimulus.
    let windows: Vec<&[f64]> = spike_times.iter().map(|&t| &stimulus[t - w..t]).collect();
    let m = windows.len();

    // Column mean across spikes (this is the STA).
    let mut mean = vec![0.0f64; w];
    for win in &windows {
        for (acc, &x) in mean.iter_mut().zip(win.iter()) {
            *acc += x;
        }
    }
    let m_f = m as f64;
    mean.iter_mut().for_each(|mu| *mu /= m_f);

    // Mean-centred windows.
    let centred: Vec<Vec<f64>> = windows
        .iter()
        .map(|win| win.iter().zip(&mean).map(|(&x, &mu)| x - mu).collect())
        .collect();

    // Accumulate the upper triangle (diagonal included) of sum(dev * dev^T).
    let mut cov = vec![0.0f64; w * w];
    for dev in &centred {
        for (i, &di) in dev.iter().enumerate() {
            let row_tail = &mut cov[i * w + i..(i + 1) * w];
            for (acc, &dj) in row_tail.iter_mut().zip(&dev[i..]) {
                *acc += di * dj;
            }
        }
    }

    // Normalise, then mirror into the lower triangle.
    let denom = (m - 1) as f64;
    for i in 0..w {
        for j in i..w {
            let value = cov[i * w + j] / denom;
            cov[i * w + j] = value;
            cov[j * w + i] = value;
        }
    }
    cov
}

/// Spatial information of a spike train (Skaggs-style per-spike formula, log base 2).
///
/// Binning works as follows:
///
/// - Positions are split into `n_bins` equal-width bins spanning the observed range.
///   The maximum is padded by `1e-10`.
/// - A sample whose index lands on or past the last edge is clamped into the final
///   bin.
/// - Each of the first `n` paired samples adds `dt` to its bin's occupancy and its
///   train value to the bin's spike count.
///
/// Let `p_k` be the occupancy fraction of bin `k`, `r_k` its rate, and
/// `r = total_spikes / (n * dt)` the overall mean rate. The result is
/// `sum_k p_k * (r_k / r) * log2(r_k / r)` over bins with `r_k > 0`, floored at `0.0`.
///
/// Returns `0.0` when any of these hold:
///
/// - fewer than 10 paired samples are available;
/// - `n_bins` is zero;
/// - total occupancy is not positive (e.g. `dt <= 0`);
/// - the mean rate is not positive.
pub fn spatial_information(binary_train: &[i32], positions: &[f64], n_bins: usize, dt: f64) -> f64 {
    let n = binary_train.len().min(positions.len());
    // With zero bins there is nothing to bin into, and `n_bins - 1` below would underflow.
    if n < 10 || n_bins == 0 {
        return 0.0;
    }
    let pos = &positions[..n];
    let pos_min = pos.iter().copied().fold(f64::INFINITY, f64::min);
    // Pad the maximum so it usually falls strictly inside the last bin.
    let pos_max = pos.iter().copied().fold(f64::NEG_INFINITY, f64::max) + 1e-10;
    let bin_width = (pos_max - pos_min) / n_bins as f64;

    let mut occupancy = vec![0.0f64; n_bins];
    let mut spike_counts = vec![0.0f64; n_bins];
    for (&x, &spikes) in pos.iter().zip(binary_train) {
        // Clamp so the maximum position, whose raw index can reach `n_bins`, stays in range.
        let k = (((x - pos_min) / bin_width).floor() as usize).min(n_bins - 1);
        occupancy[k] += dt;
        spike_counts[k] += spikes as f64;
    }

    let total_occ: f64 = occupancy.iter().sum();
    if total_occ <= 0.0 {
        return 0.0;
    }
    let total_spikes: f64 = spike_counts.iter().sum();
    let mean_rate = total_spikes / (n as f64 * dt);
    if mean_rate <= 0.0 {
        return 0.0;
    }

    let mut si = 0.0;
    for (&occ, &spk) in occupancy.iter().zip(&spike_counts) {
        let p_occ = occ / total_occ;
        let rate = if occ > 0.0 { spk / occ } else { 0.0 };
        if rate > 0.0 && p_occ > 0.0 {
            // Dividing by ln(2) converts the natural log to log2 (bits).
            si += p_occ * rate / mean_rate * (rate / mean_rate).ln() / std::f64::consts::LN_2;
        }
    }
    si.max(0.0)
}

/// Detects contiguous runs of spatial bins whose firing rate exceeds a threshold.
///
/// Binning works as follows:
///
/// - Positions are split into `n_bins` equal-width bins over the observed range. The
///   maximum is padded by `1e-10`.
/// - Each sample goes to bin `floor((x - min) / width)`, clamped into the last bin. This
///   keeps the maximum position counted even when the pad is lost to floating-point
///   rounding, which happens for large-magnitude positions.
/// - NaN positions are ignored.
/// - Each sample adds `dt` to its bin's occupancy and its train value to the bin's
///   spike count. A bin's rate is `spikes / occupancy`, or `0.0` for unvisited bins.
///
/// A bin is "in field" when its rate is strictly above
/// `mean(rates) + threshold_std * std(rates)`, using the population std over bins.
/// Each maximal run of such bins is reported as `(left_edge, right_edge)` in position
/// units.
///
/// Returns an empty vector when fewer than 10 paired samples are available or when
/// `n_bins` is zero.
pub fn place_field_detection(
    binary_train: &[i32],
    positions: &[f64],
    n_bins: usize,
    threshold_std: f64,
    dt: f64,
) -> Vec<(f64, f64)> {
    let n = binary_train.len().min(positions.len());
    // Zero bins cannot hold a field; the guard also protects `n_bins - 1` below.
    if n < 10 || n_bins == 0 {
        return vec![];
    }
    let pos = &positions[..n];
    let pos_min = pos.iter().copied().fold(f64::INFINITY, f64::min);
    let pos_max = pos.iter().copied().fold(f64::NEG_INFINITY, f64::max) + 1e-10;
    let bin_width = (pos_max - pos_min) / n_bins as f64;
    let edges: Vec<f64> = (0..=n_bins)
        .map(|k| pos_min + k as f64 * bin_width)
        .collect();

    // Single O(n) pass instead of rescanning every sample for every bin.
    let mut occupancy = vec![0.0f64; n_bins];
    let mut spike_counts = vec![0.0f64; n_bins];
    for (&x, &spikes) in pos.iter().zip(binary_train) {
        if x.is_nan() {
            continue;
        }
        let k = (((x - pos_min) / bin_width).floor() as usize).min(n_bins - 1);
        occupancy[k] += dt;
        spike_counts[k] += spikes as f64;
    }
    let rates: Vec<f64> = occupancy
        .iter()
        .zip(&spike_counts)
        .map(|(&occ, &spk)| if occ > 0.0 { spk / occ } else { 0.0 })
        .collect();

    let nb = n_bins as f64;
    let mean_rate: f64 = rates.iter().sum::<f64>() / nb;
    let var: f64 = rates.iter().map(|&r| (r - mean_rate).powi(2)).sum::<f64>() / nb;
    let thresh = mean_rate + threshold_std * var.sqrt();

    let mut fields = vec![];
    // Left edge of the field currently being extended, if any.
    let mut open: Option<f64> = None;
    for (k, &rate) in rates.iter().enumerate() {
        if rate > thresh && open.is_none() {
            open = Some(edges[k]);
        } else if rate <= thresh {
            if let Some(start) = open.take() {
                fields.push((start, edges[k]));
            }
        }
    }
    // A field still open at the last bin ends at the right edge of the range.
    if let Some(start) = open {
        fields.push((start, edges[n_bins]));
    }
    fields
}

/// Firing-rate tuning curve of a spike train against a scalar stimulus.
///
/// Binning works as follows:
///
/// - Stimulus values are split into `n_bins` equal-width bins over the observed range.
///   The maximum is padded by `1e-10`.
/// - Each sample goes to bin `floor((x - min) / width)`, clamped into the last bin. This
///   keeps the maximum value counted even when the pad is lost to floating-point
///   rounding, which happens for large-magnitude stimuli.
/// - NaN stimulus values are ignored.
/// - Each sample adds `dt` to its bin's occupancy and its train value to the bin's
///   spike count.
///
/// Returns `(rates, centres)`, both of length `n_bins`:
///
/// - `rates[k]` is `spikes / occupancy`, or `0.0` for empty bins.
/// - `centres[k]` is the midpoint of bin `k`.
///
/// Both vectors are empty when fewer than 5 paired samples are available or when
/// `n_bins` is zero.
pub fn tuning_curve(
    binary_train: &[i32],
    stimulus_values: &[f64],
    n_bins: usize,
    dt: f64,
) -> (Vec<f64>, Vec<f64>) {
    let n = binary_train.len().min(stimulus_values.len());
    // Zero bins yields empty output anyway; the guard also protects `n_bins - 1` below.
    if n < 5 || n_bins == 0 {
        return (vec![], vec![]);
    }
    let stim = &stimulus_values[..n];
    let stim_min = stim.iter().copied().fold(f64::INFINITY, f64::min);
    let stim_max = stim.iter().copied().fold(f64::NEG_INFINITY, f64::max) + 1e-10;
    let bin_width = (stim_max - stim_min) / n_bins as f64;
    let edges: Vec<f64> = (0..=n_bins)
        .map(|k| stim_min + k as f64 * bin_width)
        .collect();
    let centres: Vec<f64> = edges.windows(2).map(|e| (e[0] + e[1]) / 2.0).collect();

    // Single O(n) pass instead of rescanning every sample for every bin.
    let mut occupancy = vec![0.0f64; n_bins];
    let mut spike_counts = vec![0.0f64; n_bins];
    for (&x, &spikes) in stim.iter().zip(binary_train) {
        if x.is_nan() {
            continue;
        }
        let k = (((x - stim_min) / bin_width).floor() as usize).min(n_bins - 1);
        occupancy[k] += dt;
        spike_counts[k] += spikes as f64;
    }
    let rates: Vec<f64> = occupancy
        .iter()
        .zip(&spike_counts)
        .map(|(&occ, &spk)| if occ > 0.0 { spk / occ } else { 0.0 })
        .collect();
    (rates, centres)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sta_basic() {
        let stim: Vec<f64> = (0..100).map(|i| (i as f64 * 0.1).sin()).collect();
        let mut train = vec![0i32; 100];
        train[50] = 1;
        train[70] = 1;
        let sta = spike_triggered_average(&stim, &train, 10);
        assert_eq!(sta.len(), 10);
        // Each lag is the mean of the two 10-sample windows preceding the spikes.
        for (j, &v) in sta.iter().enumerate() {
            let expected = (stim[40 + j] + stim[60 + j]) / 2.0;
            assert!((v - expected).abs() < 1e-12, "STA mismatch at lag {j}");
        }
    }

    #[test]
    fn test_sta_no_spikes() {
        let stim = vec![1.0; 100];
        let train = vec![0i32; 100];
        let sta = spike_triggered_average(&stim, &train, 10);
        assert_eq!(sta.len(), 10);
        assert!(sta.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_sta_all_ones_stimulus() {
        let stim = vec![1.0; 100];
        let mut train = vec![0i32; 100];
        train[30] = 1;
        train[60] = 1;
        let sta = spike_triggered_average(&stim, &train, 10);
        assert!(sta.iter().all(|&v| (v - 1.0).abs() < 1e-12));
    }

    #[test]
    fn test_stc_basic() {
        let stim: Vec<f64> = (0..200).map(|i| (i as f64 * 0.05).sin()).collect();
        let mut train = vec![0i32; 200];
        for i in (50..200).step_by(20) {
            train[i] = 1;
        }
        let cov = spike_triggered_covariance(&stim, &train, 10);
        assert_eq!(cov.len(), 100);
        for i in 0..10 {
            assert!(cov[i * 10 + i] >= 0.0);
        }
    }

    #[test]
    fn test_stc_few_spikes() {
        let stim = vec![1.0; 100];
        let train = vec![0i32; 100];
        let cov = spike_triggered_covariance(&stim, &train, 5);
        assert_eq!(cov.len(), 25);
        // The fallback must be the full identity, off-diagonal zeros included.
        for i in 0..5 {
            for j in 0..5 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((cov[i * 5 + j] - expected).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_stc_symmetric() {
        let stim: Vec<f64> = (0..200).map(|i| (i as f64 * 0.1).cos()).collect();
        let mut train = vec![0i32; 200];
        for i in (20..200).step_by(15) {
            train[i] = 1;
        }
        let w = 8;
        let cov = spike_triggered_covariance(&stim, &train, w);
        for i in 0..w {
            for j in 0..w {
                assert!(
                    (cov[i * w + j] - cov[j * w + i]).abs() < 1e-12,
                    "Covariance not symmetric at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_spatial_information_basic() {
        let mut train = vec![0i32; 200];
        let positions: Vec<f64> = (0..200).map(|i| i as f64 / 200.0 * 10.0).collect();
        // Fire only in the first quarter of the track.
        for i in 0..50 {
            if i % 2 == 0 {
                train[i] = 1;
            }
        }
        let si = spatial_information(&train, &positions, 20, 0.001);
        assert!(si > 0.0, "Spatial info should be positive for place cell");
    }

    #[test]
    fn test_spatial_information_uniform() {
        let mut train = vec![0i32; 200];
        let positions: Vec<f64> = (0..200).map(|i| i as f64).collect();
        for i in (0..200).step_by(5) {
            train[i] = 1;
        }
        let si = spatial_information(&train, &positions, 20, 0.001);
        assert!(si < 0.5, "SI={si} too high for uniform firing");
    }

    #[test]
    fn test_spatial_information_few_samples() {
        assert_eq!(
            spatial_information(&[0, 1, 0], &[1.0, 2.0, 3.0], 5, 0.001),
            0.0
        );
    }

    #[test]
    fn test_place_field_detection() {
        let mut train = vec![0i32; 1000];
        let positions: Vec<f64> = (0..1000).map(|i| i as f64 / 1000.0 * 20.0).collect();
        // Fire continuously between positions 5.0 and ~10.0.
        train[250..500].fill(1);
        let fields = place_field_detection(&train, &positions, 50, 1.0, 0.001);
        assert_eq!(fields.len(), 1, "Expected exactly one place field, got {fields:?}");
        let (start, end) = fields[0];
        assert!(
            (4.5..5.5).contains(&start) && (9.5..10.5).contains(&end),
            "Field ({start}, {end}) should be near 5-10"
        );
    }

    #[test]
    fn test_place_field_no_field() {
        let mut train = vec![0i32; 200];
        // Sparse, evenly spread firing: no bin rises 3 std above the mean.
        let positions: Vec<f64> = (0..200).map(|i| i as f64).collect();
        for i in (0..200).step_by(10) {
            train[i] = 1;
        }
        let fields = place_field_detection(&train, &positions, 50, 3.0, 0.001);
        assert!(fields.is_empty(), "Unexpected fields: {fields:?}");
    }

    #[test]
    fn test_tuning_curve_basic() {
        let mut train = vec![0i32; 200];
        let stim: Vec<f64> = (0..200)
            .map(|i| (i as f64 / 200.0 * 360.0) % 360.0)
            .collect();
        // Fire only around the middle of the stimulus range.
        for i in 90..110 {
            train[i] = 1;
        }
        let (rates, centres) = tuning_curve(&train, &stim, 10, 0.001);
        assert_eq!(rates.len(), 10);
        assert_eq!(centres.len(), 10);
        let peak_idx = rates
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert!((4..=6).contains(&peak_idx));
    }

    #[test]
    fn test_tuning_curve_few_samples() {
        let (r, c) = tuning_curve(&[0, 1], &[1.0, 2.0], 5, 0.001);
        assert!(r.is_empty());
        assert!(c.is_empty());
    }
}