Skip to main content

sc_neurocore_engine/analysis/
stimulus.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — Spike-triggered analysis and receptive field estimation

/// Spike-triggered average (STA) of a stimulus signal.
///
/// For every spike at index `t` with `t >= window_steps`, the snippet
/// `stimulus[t - window_steps..t]` (the samples strictly before the spike) is
/// taken, and the snippets are averaged element-wise.
///
/// Only the common prefix of `stimulus` and `binary_train` is examined, and any
/// positive entry of `binary_train` counts as a single spike. Spikes without a
/// full window of history are ignored. When no spike qualifies, a vector of
/// `window_steps` zeros is returned.
pub fn spike_triggered_average(
    stimulus: &[f64],
    binary_train: &[i32],
    window_steps: usize,
) -> Vec<f64> {
    let len = stimulus.len().min(binary_train.len());
    let mut sum = vec![0.0f64; window_steps];
    let mut n_spikes: usize = 0;

    // Accumulate each qualifying snippet in spike order.
    for t in (window_steps..len).filter(|&t| binary_train[t] > 0) {
        let snippet = &stimulus[t - window_steps..t];
        for (acc, &x) in sum.iter_mut().zip(snippet) {
            *acc += x;
        }
        n_spikes += 1;
    }

    if n_spikes == 0 {
        // Nothing was accumulated, so `sum` is still all zeros.
        return sum;
    }
    let denom = n_spikes as f64;
    sum.iter().map(|&s| s / denom).collect()
}

/// Spike-triggered covariance (STC). Schwartz et al. 2006.
///
/// Collects the snippet `stimulus[t - window_steps..t]` preceding every spike
/// at `t >= window_steps` (any positive `binary_train` entry is a spike). It then
/// returns the sample covariance of those snippets, `S^T S / (m - 1)`, where `S`
/// holds the mean-centred snippets as rows and `m` is the number of spikes.
///
/// The result is a flattened, row-major, symmetric `[window_steps x window_steps]`
/// matrix. With fewer than 3 qualifying spikes, the identity matrix is returned
/// instead.
pub fn spike_triggered_covariance(
    stimulus: &[f64],
    binary_train: &[i32],
    window_steps: usize,
) -> Vec<f64> {
    let w = window_steps;
    let len = stimulus.len().min(binary_train.len());
    let onsets: Vec<usize> = (w..len).filter(|&t| binary_train[t] > 0).collect();

    // Too few spikes for a meaningful estimate: fall back to the identity.
    if onsets.len() < 3 {
        return (0..w * w)
            .map(|idx| if idx / w == idx % w { 1.0 } else { 0.0 })
            .collect();
    }

    let count = onsets.len();

    // Element-wise mean of the raw snippets, accumulated in spike order.
    let mut mean = vec![0.0f64; w];
    for &t in &onsets {
        for (m, &x) in mean.iter_mut().zip(&stimulus[t - w..t]) {
            *m += x;
        }
    }
    mean.iter_mut().for_each(|m| *m /= count as f64);

    // One mean-centred snippet per spike.
    let centred: Vec<Vec<f64>> = onsets
        .iter()
        .map(|&t| {
            stimulus[t - w..t]
                .iter()
                .zip(&mean)
                .map(|(&x, &m)| x - m)
                .collect::<Vec<f64>>()
        })
        .collect();

    // Fill the upper triangle of S^T S / (count - 1) and mirror it below.
    let denom = (count - 1) as f64;
    let mut cov = vec![0.0f64; w * w];
    for i in 0..w {
        for j in i..w {
            let mut acc = 0.0f64;
            for row in &centred {
                acc += row[i] * row[j];
            }
            let value = acc / denom;
            cov[i * w + j] = value;
            cov[j * w + i] = value;
        }
    }
    cov
}

/// Spatial information (bits/spike). Skaggs et al. 1993.
///
/// Computes `SI = Σ_k p_k (λ_k / λ) log2(λ_k / λ)`, where `p_k` is the fraction
/// of occupancy time spent in position bin `k`, `λ_k` is the firing rate in that
/// bin, and `λ` is the overall mean rate (`total spikes / (n * dt)`).
///
/// * `binary_train` — spike counts per time step (only the common prefix with
///   `positions` is used).
/// * `positions` — 1D position values, one per time step.
/// * `n_bins` — number of equal-width position bins spanning the observed range.
/// * `dt` — duration of one time step.
///
/// Returns 0.0 in these cases:
/// * fewer than 10 samples;
/// * `n_bins == 0`;
/// * non-positive total occupancy;
/// * non-positive mean rate.
///
/// Otherwise the result is clamped at 0.0 from below.
pub fn spatial_information(
    binary_train: &[i32],
    positions: &[f64],
    n_bins: usize,
    dt: f64,
) -> f64 {
    let n = binary_train.len().min(positions.len());
    // `n_bins == 0` would underflow `n_bins - 1` below (panic / out-of-bounds).
    if n < 10 || n_bins == 0 {
        return 0.0;
    }
    let pos = &positions[..n];
    let pos_min = pos.iter().copied().fold(f64::INFINITY, f64::min);
    // Small padding so the maximum falls inside the last bin; the clamp below
    // covers magnitudes where the padding is lost to rounding.
    let pos_max = pos.iter().copied().fold(f64::NEG_INFINITY, f64::max) + 1e-10;
    let bin_width = (pos_max - pos_min) / n_bins as f64;

    let mut occupancy = vec![0.0f64; n_bins];
    let mut spike_counts = vec![0.0f64; n_bins];
    for (&x, &s) in pos.iter().zip(&binary_train[..n]) {
        let k = (((x - pos_min) / bin_width).floor() as usize).min(n_bins - 1);
        occupancy[k] += dt;
        spike_counts[k] += f64::from(s);
    }

    let total_occ: f64 = occupancy.iter().sum();
    if total_occ <= 0.0 {
        return 0.0;
    }
    let total_spikes: f64 = spike_counts.iter().sum();
    let mean_rate = total_spikes / (n as f64 * dt);
    if mean_rate <= 0.0 {
        return 0.0;
    }

    let mut si = 0.0f64;
    for (&occ, &spk) in occupancy.iter().zip(&spike_counts) {
        let p_occ = occ / total_occ;
        let rate = if occ > 0.0 { spk / occ } else { 0.0 };
        // Bins with zero rate contribute 0 (limit of x·ln x as x → 0).
        if rate > 0.0 && p_occ > 0.0 {
            si += p_occ * rate / mean_rate * (rate / mean_rate).ln() / std::f64::consts::LN_2;
        }
    }
    si.max(0.0)
}

/// Detect place fields as contiguous bins with rate > mean + threshold_std * std.
/// O'Keefe & Dostrovsky 1971.
///
/// Positions are split into `n_bins` equal-width bins over the observed range.
/// The firing rate of each bin is `spikes / (samples * dt)`, or 0 for unvisited
/// bins. The threshold is the mean of the bin rates plus `threshold_std` times
/// their population standard deviation. Runs of consecutive supra-threshold bins
/// are reported as fields.
///
/// Each sample goes to bin `floor((x - min) / width)`, clamped to the last bin,
/// so the maximum position is never dropped. This is the same rule as
/// `spatial_information`. Non-finite positions are skipped.
///
/// Returns a list of `(field_start_pos, field_end_pos)` bin-edge pairs. The list
/// is empty when there are fewer than 10 samples or `n_bins == 0`.
pub fn place_field_detection(
    binary_train: &[i32],
    positions: &[f64],
    n_bins: usize,
    threshold_std: f64,
    dt: f64,
) -> Vec<(f64, f64)> {
    let n = binary_train.len().min(positions.len());
    // Too little data, or no bins (`n_bins - 1` below would underflow).
    if n < 10 || n_bins == 0 {
        return vec![];
    }
    let pos = &positions[..n];
    let pos_min = pos.iter().copied().fold(f64::INFINITY, f64::min);
    let pos_max = pos.iter().copied().fold(f64::NEG_INFINITY, f64::max) + 1e-10;
    let bin_width = (pos_max - pos_min) / n_bins as f64;
    let edges: Vec<f64> = (0..=n_bins)
        .map(|k| pos_min + k as f64 * bin_width)
        .collect();

    // Single O(n) pass: occupancy time and spike count per bin.
    let mut occupancy = vec![0.0f64; n_bins];
    let mut spikes = vec![0.0f64; n_bins];
    for (&x, &s) in pos.iter().zip(&binary_train[..n]) {
        if !x.is_finite() {
            continue;
        }
        // Clamp: when the 1e-10 padding is absorbed by rounding (large
        // coordinates), the maximum sample maps to `n_bins` and belongs in the
        // last bin rather than being discarded.
        let k = (((x - pos_min) / bin_width).floor() as usize).min(n_bins - 1);
        occupancy[k] += dt;
        spikes[k] += f64::from(s);
    }
    let rates: Vec<f64> = occupancy
        .iter()
        .zip(&spikes)
        .map(|(&occ, &spk)| if occ > 0.0 { spk / occ } else { 0.0 })
        .collect();

    let mean_rate = rates.iter().sum::<f64>() / n_bins as f64;
    let var = rates.iter().map(|&r| (r - mean_rate).powi(2)).sum::<f64>() / n_bins as f64;
    let thresh = mean_rate + threshold_std * var.sqrt();

    // Scan for maximal runs of supra-threshold bins.
    let mut fields = Vec::new();
    let mut open_start: Option<f64> = None;
    for (k, &rate) in rates.iter().enumerate() {
        if rate > thresh && open_start.is_none() {
            open_start = Some(edges[k]);
        } else if rate <= thresh {
            if let Some(start) = open_start.take() {
                fields.push((start, edges[k]));
            }
        }
    }
    if let Some(start) = open_start {
        fields.push((start, edges[n_bins]));
    }
    fields
}

/// Tuning curve: mean firing rate vs stimulus value. Dayan & Abbott 2001.
///
/// Stimulus values are split into `n_bins` equal-width bins over the observed
/// range. The rate of each bin is `spikes / (samples * dt)`, or 0 for bins the
/// stimulus never visited.
///
/// Each sample goes to bin `floor((x - min) / width)`, clamped to the last bin,
/// so the maximum stimulus value is never dropped. Non-finite stimulus values
/// are skipped.
///
/// Returns `(mean_rates, bin_centres)`. Both vectors are empty when there are
/// fewer than 5 samples or `n_bins == 0`.
pub fn tuning_curve(
    binary_train: &[i32],
    stimulus_values: &[f64],
    n_bins: usize,
    dt: f64,
) -> (Vec<f64>, Vec<f64>) {
    let n = binary_train.len().min(stimulus_values.len());
    // Too little data, or no bins (`n_bins - 1` below would underflow).
    if n < 5 || n_bins == 0 {
        return (vec![], vec![]);
    }
    let stim = &stimulus_values[..n];
    let stim_min = stim.iter().copied().fold(f64::INFINITY, f64::min);
    let stim_max = stim.iter().copied().fold(f64::NEG_INFINITY, f64::max) + 1e-10;
    let bin_width = (stim_max - stim_min) / n_bins as f64;
    let edges: Vec<f64> = (0..=n_bins)
        .map(|k| stim_min + k as f64 * bin_width)
        .collect();
    let centres: Vec<f64> = edges.windows(2).map(|e| (e[0] + e[1]) / 2.0).collect();

    // Single O(n) pass: occupancy time and spike count per bin.
    let mut occupancy = vec![0.0f64; n_bins];
    let mut spikes = vec![0.0f64; n_bins];
    for (&x, &s) in stim.iter().zip(&binary_train[..n]) {
        if !x.is_finite() {
            continue;
        }
        // Clamp so the maximum value lands in the last bin even when the
        // 1e-10 padding is lost to rounding.
        let k = (((x - stim_min) / bin_width).floor() as usize).min(n_bins - 1);
        occupancy[k] += dt;
        spikes[k] += f64::from(s);
    }
    let rates: Vec<f64> = occupancy
        .iter()
        .zip(&spikes)
        .map(|(&occ, &spk)| if occ > 0.0 { spk / occ } else { 0.0 })
        .collect();
    (rates, centres)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sta_basic() {
        let stim: Vec<f64> = (0..100).map(|i| (i as f64 * 0.1).sin()).collect();
        let mut train = vec![0i32; 100];
        train[50] = 1;
        train[70] = 1;
        let sta = spike_triggered_average(&stim, &train, 10);
        assert_eq!(sta.len(), 10);
    }

    #[test]
    fn test_sta_no_spikes() {
        let stim = vec![1.0; 100];
        let train = vec![0i32; 100];
        let sta = spike_triggered_average(&stim, &train, 10);
        assert_eq!(sta.len(), 10);
        assert!(sta.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_sta_all_ones_stimulus() {
        let stim = vec![1.0; 100];
        let mut train = vec![0i32; 100];
        train[30] = 1;
        train[60] = 1;
        let sta = spike_triggered_average(&stim, &train, 10);
        assert!(sta.iter().all(|&v| (v - 1.0).abs() < 1e-12));
    }

    #[test]
    fn test_stc_basic() {
        let stim: Vec<f64> = (0..200).map(|i| (i as f64 * 0.05).sin()).collect();
        let mut train = vec![0i32; 200];
        for i in (50..200).step_by(20) {
            train[i] = 1;
        }
        let cov = spike_triggered_covariance(&stim, &train, 10);
        assert_eq!(cov.len(), 100); // 10x10
        // Diagonal (variances) should be non-negative.
        for i in 0..10 {
            assert!(cov[i * 10 + i] >= 0.0);
        }
    }

    #[test]
    fn test_stc_few_spikes() {
        let stim = vec![1.0; 100];
        let train = vec![0i32; 100]; // no spikes -> identity
        let cov = spike_triggered_covariance(&stim, &train, 5);
        assert_eq!(cov.len(), 25);
        // Should be identity
        for i in 0..5 {
            assert!((cov[i * 5 + i] - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_stc_symmetric() {
        let stim: Vec<f64> = (0..200).map(|i| (i as f64 * 0.1).cos()).collect();
        let mut train = vec![0i32; 200];
        for i in (20..200).step_by(15) {
            train[i] = 1;
        }
        let w = 8;
        let cov = spike_triggered_covariance(&stim, &train, w);
        for i in 0..w {
            for j in 0..w {
                assert!(
                    (cov[i * w + j] - cov[j * w + i]).abs() < 1e-12,
                    "Covariance not symmetric at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_spatial_information_basic() {
        let mut train = vec![0i32; 200];
        let positions: Vec<f64> = (0..200).map(|i| i as f64 / 200.0 * 10.0).collect();
        // Place field: high firing in first quarter
        for i in 0..50 {
            if i % 2 == 0 {
                train[i] = 1;
            }
        }
        let si = spatial_information(&train, &positions, 20, 0.001);
        assert!(si > 0.0, "Spatial info should be positive for place cell");
    }

    #[test]
    fn test_spatial_information_uniform() {
        // Uniform firing -> low spatial info
        let mut train = vec![0i32; 200];
        let positions: Vec<f64> = (0..200).map(|i| i as f64).collect();
        for i in (0..200).step_by(5) {
            train[i] = 1;
        }
        let si = spatial_information(&train, &positions, 20, 0.001);
        // Should be near zero for uniform
        assert!(si < 0.5, "SI={si} too high for uniform firing");
    }

    #[test]
    fn test_spatial_information_few_samples() {
        assert_eq!(
            spatial_information(&[0, 1, 0], &[1.0, 2.0, 3.0], 5, 0.001),
            0.0
        );
    }

    #[test]
    fn test_place_field_detection() {
        let mut train = vec![0i32; 1000];
        let positions: Vec<f64> = (0..1000).map(|i| i as f64 / 1000.0 * 20.0).collect();
        // Create dense place field at positions 5-10 (indices 250-500)
        for i in 250..500 {
            train[i] = 1; // every step fires
        }
        let fields = place_field_detection(&train, &positions, 50, 1.0, 0.001);
        assert!(!fields.is_empty(), "Should detect at least one place field");
        // Field should overlap the 5-10 range
        let (start, end) = fields[0];
        assert!(
            start < 12.0 && end > 4.0,
            "Field ({start}, {end}) should be near 5-10"
        );
    }

    #[test]
    fn test_place_field_no_field() {
        // Sparse, evenly spaced firing: each spike lands in its own bin, so the
        // bin rates (<= ~333 Hz) never clear mean + 3 std (>= ~467 Hz).
        let mut train = vec![0i32; 200];
        let positions: Vec<f64> = (0..200).map(|i| i as f64).collect();
        for i in (0..200).step_by(10) {
            train[i] = 1;
        }
        let fields = place_field_detection(&train, &positions, 50, 3.0, 0.001);
        assert!(fields.is_empty(), "unexpected place fields: {fields:?}");
    }

    #[test]
    fn test_tuning_curve_basic() {
        let mut train = vec![0i32; 200];
        let stim: Vec<f64> = (0..200)
            .map(|i| (i as f64 / 200.0 * 360.0) % 360.0)
            .collect();
        // Tuned to ~180 degrees
        for i in 90..110 {
            train[i] = 1;
        }
        let (rates, centres) = tuning_curve(&train, &stim, 10, 0.001);
        assert_eq!(rates.len(), 10);
        assert_eq!(centres.len(), 10);
        // Peak should be in the middle bins
        let peak_idx = rates
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert!((4..=6).contains(&peak_idx));
    }

    #[test]
    fn test_tuning_curve_few_samples() {
        let (r, c) = tuning_curve(&[0, 1], &[1.0, 2.0], 5, 0.001);
        assert!(r.is_empty());
        assert!(c.is_empty());
    }
}
417}