1use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Gamma, Poisson};
12
13pub fn surrogate_isi_shuffle(binary_train: &[i32], seed: u64) -> Vec<i32> {
15 let spike_idx: Vec<usize> = binary_train
16 .iter()
17 .enumerate()
18 .filter(|(_, &v)| v > 0)
19 .map(|(i, _)| i)
20 .collect();
21 if spike_idx.len() < 3 {
22 return binary_train.to_vec();
23 }
24 let mut intervals: Vec<usize> = spike_idx.windows(2).map(|w| w[1] - w[0]).collect();
25 let mut rng = ChaCha8Rng::seed_from_u64(seed);
26 intervals.shuffle(&mut rng);
27
28 let mut out = vec![0i32; binary_train.len()];
29 let mut idx = spike_idx[0];
30 out[idx] = 1;
31 for &gap in &intervals {
32 idx += gap;
33 if idx < out.len() {
34 out[idx] = 1;
35 }
36 }
37 out
38}
39
40pub fn surrogate_dither(binary_train: &[i32], dither_ms: f64, dt: f64, seed: u64) -> Vec<i32> {
42 let mut rng = ChaCha8Rng::seed_from_u64(seed);
43 let dither_steps = (dither_ms / (dt * 1000.0)) as i64;
44 let n = binary_train.len();
45 let mut out = vec![0i32; n];
46 for (i, &v) in binary_train.iter().enumerate() {
47 if v > 0 {
48 let jitter = rng.random_range(-dither_steps..=dither_steps);
49 let new_idx = (i as i64 + jitter).clamp(0, n as i64 - 1) as usize;
50 out[new_idx] = 1;
51 }
52 }
53 out
54}
55
56pub fn surrogate_trial_shuffle(n_trials: usize, seed: u64) -> Vec<usize> {
59 let mut indices: Vec<usize> = (0..n_trials).collect();
60 let mut rng = ChaCha8Rng::seed_from_u64(seed);
61 indices.shuffle(&mut rng);
62 indices
63}
64
65pub fn homogeneous_poisson(rate_hz: f64, duration_s: f64, dt: f64, seed: u64) -> Vec<f64> {
67 let mut rng = ChaCha8Rng::seed_from_u64(seed);
68 let n = (duration_s / dt) as usize;
69 let threshold = rate_hz * dt;
70 (0..n)
71 .map(|_| {
72 if rng.random::<f64>() < threshold {
73 1.0
74 } else {
75 0.0
76 }
77 })
78 .collect()
79}
80
81pub fn inhomogeneous_poisson(
83 rate_func: fn(f64) -> f64,
84 duration_s: f64,
85 dt: f64,
86 seed: u64,
87) -> Vec<f64> {
88 let mut rng = ChaCha8Rng::seed_from_u64(seed);
89 let n = (duration_s / dt) as usize;
90 let rates: Vec<f64> = (0..n).map(|i| rate_func(i as f64 * dt)).collect();
91 let max_rate = rates.iter().copied().fold(0.0_f64, f64::max);
92 if max_rate <= 0.0 {
93 return vec![0.0; n];
94 }
95 let threshold = max_rate * dt;
96 (0..n)
97 .map(|i| {
98 let candidate = rng.random::<f64>() < threshold;
99 let accept = rng.random::<f64>() < rates[i] / max_rate;
100 if candidate && accept {
101 1.0
102 } else {
103 0.0
104 }
105 })
106 .collect()
107}
108
109pub fn gamma_process(rate_hz: f64, shape: f64, duration_s: f64, dt: f64, seed: u64) -> Vec<f64> {
112 let n = (duration_s / dt) as usize;
113 let mut train = vec![0.0_f64; n];
114 if rate_hz <= 0.0 || shape <= 0.0 {
115 return train;
116 }
117 let scale = 1.0 / (rate_hz * shape);
118 let gamma = Gamma::new(shape, scale).unwrap();
119 let mut rng = ChaCha8Rng::seed_from_u64(seed);
120 let mut t = 0.0_f64;
121 loop {
122 let interval: f64 = rng.sample(gamma);
123 t += interval;
124 let idx = (t / dt) as usize;
125 if idx >= n {
126 break;
127 }
128 train[idx] = 1.0;
129 }
130 train
131}
132
133pub fn compound_poisson_process(
136 rate_hz: f64,
137 burst_mean: f64,
138 duration_s: f64,
139 dt: f64,
140 seed: u64,
141) -> Vec<f64> {
142 let mut rng = ChaCha8Rng::seed_from_u64(seed);
143 let n = (duration_s / dt) as usize;
144 let mut train = vec![0.0_f64; n];
145 let threshold = rate_hz * dt;
146 let poisson = Poisson::new(burst_mean.max(1e-10)).unwrap();
147 for i in 0..n {
148 if rng.random::<f64>() < threshold {
149 let n_spikes: usize = rng.sample(poisson) as usize;
150 for s in 0..n_spikes {
151 let offset = i + s;
152 if offset < n {
153 train[offset] = 1.0;
154 }
155 }
156 }
157 }
158 train
159}
160
161pub fn surrogate_joint_isi(binary_train: &[i32], seed: u64) -> Vec<i32> {
163 let spike_idx: Vec<usize> = binary_train
164 .iter()
165 .enumerate()
166 .filter(|(_, &v)| v > 0)
167 .map(|(i, _)| i)
168 .collect();
169 if spike_idx.len() < 4 {
170 return binary_train.to_vec();
171 }
172 let mut intervals: Vec<usize> = spike_idx.windows(2).map(|w| w[1] - w[0]).collect();
173 let mut rng = ChaCha8Rng::seed_from_u64(seed);
174 let ni = intervals.len();
175 for _ in 0..(2 * ni) {
176 let i = rng.random_range(0..ni);
177 let j = rng.random_range(0..ni);
178 if i != j {
179 intervals.swap(i, j);
180 }
181 }
182 let mut out = vec![0i32; binary_train.len()];
183 let mut pos = spike_idx[0];
184 out[pos] = 1;
185 for &gap in &intervals {
186 pos += gap;
187 if pos < out.len() {
188 out[pos] = 1;
189 }
190 }
191 out
192}
193
194pub fn surrogate_bin_shuffling(binary_train: &[i32], bin_size: usize, seed: u64) -> Vec<i32> {
196 let mut out = binary_train.to_vec();
197 let mut rng = ChaCha8Rng::seed_from_u64(seed);
198 let n = out.len();
199 let mut start = 0;
200 while start < n {
201 let end = (start + bin_size).min(n);
202 out[start..end].shuffle(&mut rng);
203 start = end;
204 }
205 out
206}
207
208pub fn surrogate_spike_train_shifting(
210 binary_train: &[i32],
211 max_shift: usize,
212 seed: u64,
213) -> Vec<i32> {
214 let n = binary_train.len();
215 if n == 0 {
216 return vec![];
217 }
218 let mut rng = ChaCha8Rng::seed_from_u64(seed);
219 let shift = rng.random_range(0..=(2 * max_shift)) as i64 - max_shift as i64;
220 let mut out = vec![0i32; n];
221 for i in 0..n {
222 let new_idx = ((i as i64 + shift).rem_euclid(n as i64)) as usize;
223 out[new_idx] = binary_train[i];
224 }
225 out
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a binary train of length `len` with a spike at every index in `spikes`.
    fn make_train(spikes: &[usize], len: usize) -> Vec<i32> {
        let mut t = vec![0i32; len];
        for &s in spikes {
            t[s] = 1;
        }
        t
    }

    /// Sum of an `i32` train, which equals the spike count for binary trains.
    fn spike_count(train: &[i32]) -> i64 {
        train.iter().map(|&v| v as i64).sum()
    }

    /// Number of bins above 0.5 in an `f64` train.
    fn spike_count_f64(train: &[f64]) -> i64 {
        train.iter().filter(|&&v| v > 0.5).count() as i64
    }

    /// Indices of the bins holding a spike (`v > 0`).
    fn spike_positions(train: &[i32]) -> Vec<usize> {
        train
            .iter()
            .enumerate()
            .filter(|(_, &v)| v > 0)
            .map(|(i, _)| i)
            .collect()
    }

    /// Sorted inter-spike intervals, i.e. the ISI multiset.
    fn sorted_isis(train: &[i32]) -> Vec<usize> {
        let pos = spike_positions(train);
        let mut isis: Vec<usize> = pos.windows(2).map(|w| w[1] - w[0]).collect();
        isis.sort_unstable();
        isis
    }

    #[test]
    fn test_isi_shuffle_preserves_count() {
        let train = make_train(&[5, 15, 25, 35, 45, 55, 65, 75, 85, 95], 100);
        let surr = surrogate_isi_shuffle(&train, 42);
        assert_eq!(spike_count(&surr), spike_count(&train));
    }

    #[test]
    fn test_isi_shuffle_preserves_isi_multiset_and_endpoints() {
        let train = make_train(&[2, 5, 11, 12, 30, 41, 70], 80);
        let surr = surrogate_isi_shuffle(&train, 3);
        assert_eq!(sorted_isis(&surr), sorted_isis(&train));
        let before = spike_positions(&train);
        let after = spike_positions(&surr);
        assert_eq!(after.first(), before.first());
        assert_eq!(after.last(), before.last());
    }

    #[test]
    fn test_isi_shuffle_deterministic() {
        let train = make_train(&[5, 15, 25, 35, 45], 100);
        let s1 = surrogate_isi_shuffle(&train, 42);
        let s2 = surrogate_isi_shuffle(&train, 42);
        assert_eq!(s1, s2, "same seed → same result");
    }

    #[test]
    fn test_isi_shuffle_few_spikes() {
        let train = make_train(&[50], 100);
        let surr = surrogate_isi_shuffle(&train, 0);
        assert_eq!(surr, train, "too few spikes → unchanged");
    }

    #[test]
    fn test_dither_preserves_count_approx() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let surr = surrogate_dither(&train, 2.0, 0.001, 42);
        assert!(spike_count(&surr) > 0);
        // Jittered spikes that collide in one bin merge, so the count can only drop.
        assert!(spike_count(&surr) <= spike_count(&train));
    }

    #[test]
    fn test_dither_deterministic() {
        let train = make_train(&[10, 50, 90], 100);
        let s1 = surrogate_dither(&train, 3.0, 0.001, 7);
        let s2 = surrogate_dither(&train, 3.0, 0.001, 7);
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_trial_shuffle_permutation() {
        let perm = surrogate_trial_shuffle(5, 42);
        assert_eq!(perm.len(), 5);
        let mut sorted = perm.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1, 2, 3, 4], "should be a permutation");
    }

    #[test]
    fn test_trial_shuffle_deterministic_and_empty() {
        assert_eq!(surrogate_trial_shuffle(8, 1), surrogate_trial_shuffle(8, 1));
        assert!(surrogate_trial_shuffle(0, 1).is_empty());
    }

    #[test]
    fn test_poisson_rate() {
        let train = homogeneous_poisson(100.0, 1.0, 0.001, 42);
        assert_eq!(train.len(), 1000);
        let count = spike_count_f64(&train);
        assert!(
            // Loose bounds: ~100 expected, the window only guards against gross errors.
            count > 50 && count < 200,
            "expected ~100 spikes, got {count}"
        );
    }

    #[test]
    fn test_poisson_deterministic() {
        let t1 = homogeneous_poisson(50.0, 0.5, 0.001, 99);
        let t2 = homogeneous_poisson(50.0, 0.5, 0.001, 99);
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_poisson_zero_rate() {
        let train = homogeneous_poisson(0.0, 1.0, 0.001, 0);
        assert_eq!(spike_count_f64(&train), 0);
    }

    #[test]
    fn test_poisson_saturated_rate_fills_every_bin() {
        // rate * dt = 2 > 1: every Bernoulli draw succeeds.
        let train = homogeneous_poisson(2000.0, 1.0, 0.001, 5);
        assert_eq!(spike_count_f64(&train), train.len() as i64);
    }

    #[test]
    fn test_inhom_poisson_constant_matches_homogeneous() {
        fn rate(_t: f64) -> f64 {
            50.0
        }
        let train = inhomogeneous_poisson(rate, 1.0, 0.001, 42);
        assert_eq!(train.len(), 1000);
        let count = spike_count_f64(&train);
        assert!(
            count > 10 && count < 150,
            "~50 spikes expected, got {count}"
        );
    }

    #[test]
    fn test_inhom_poisson_zero_rate() {
        fn rate(_t: f64) -> f64 {
            0.0
        }
        let train = inhomogeneous_poisson(rate, 0.5, 0.001, 3);
        assert_eq!(train.len(), 500);
        assert_eq!(spike_count_f64(&train), 0);
    }

    #[test]
    fn test_gamma_poisson_like() {
        let train = gamma_process(100.0, 1.0, 1.0, 0.001, 42);
        assert_eq!(train.len(), 1000);
        let count = spike_count_f64(&train);
        assert!(count > 30 && count < 200, "shape=1 ≈ Poisson, got {count}");
    }

    #[test]
    fn test_gamma_regular() {
        let train = gamma_process(50.0, 5.0, 1.0, 0.001, 42);
        let count = spike_count_f64(&train);
        assert!(count > 10, "should produce spikes, got {count}");
    }

    #[test]
    fn test_gamma_zero_rate() {
        let train = gamma_process(0.0, 1.0, 1.0, 0.001, 0);
        assert_eq!(spike_count_f64(&train), 0);
    }

    #[test]
    fn test_cpp_produces_spikes() {
        let train = compound_poisson_process(50.0, 3.0, 1.0, 0.001, 42);
        assert_eq!(train.len(), 1000);
        let count = spike_count_f64(&train);
        assert!(count > 10, "should produce bursts, got {count}");
    }

    #[test]
    fn test_cpp_deterministic() {
        let t1 = compound_poisson_process(30.0, 2.0, 0.5, 0.001, 7);
        let t2 = compound_poisson_process(30.0, 2.0, 0.5, 0.001, 7);
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_joint_isi_preserves_count() {
        let train = make_train(&[5, 15, 25, 35, 45, 55, 65, 75], 100);
        let surr = surrogate_joint_isi(&train, 42);
        assert_eq!(spike_count(&surr), spike_count(&train));
    }

    #[test]
    fn test_joint_isi_preserves_isi_multiset() {
        let train = make_train(&[1, 4, 9, 10, 22, 40, 41, 77], 90);
        let surr = surrogate_joint_isi(&train, 9);
        assert_eq!(sorted_isis(&surr), sorted_isis(&train));
    }

    #[test]
    fn test_joint_isi_few_spikes() {
        let train = make_train(&[10, 50], 100);
        let surr = surrogate_joint_isi(&train, 0);
        assert_eq!(surr, train, "< 4 spikes → unchanged");
    }

    #[test]
    fn test_bin_shuffle_preserves_count() {
        let train = make_train(&[0, 1, 2, 15, 16, 30, 31, 32, 33, 45], 50);
        let surr = surrogate_bin_shuffling(&train, 10, 42);
        assert_eq!(spike_count(&surr), spike_count(&train));
    }

    #[test]
    fn test_bin_shuffle_preserves_per_bin_counts() {
        // Length 47 leaves a short final window, which must be handled too.
        let train = make_train(&[0, 1, 2, 15, 16, 30, 31, 32, 33, 45], 47);
        let surr = surrogate_bin_shuffling(&train, 10, 5);
        for (before, after) in train.chunks(10).zip(surr.chunks(10)) {
            assert_eq!(spike_count(before), spike_count(after));
        }
    }

    #[test]
    fn test_bin_shuffle_deterministic() {
        let train = make_train(&[3, 7, 13, 27], 30);
        let s1 = surrogate_bin_shuffling(&train, 10, 42);
        let s2 = surrogate_bin_shuffling(&train, 10, 42);
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_shift_preserves_count() {
        let train = make_train(&[10, 30, 50, 70, 90], 100);
        let surr = surrogate_spike_train_shifting(&train, 20, 42);
        assert_eq!(spike_count(&surr), spike_count(&train));
    }

    #[test]
    fn test_shift_is_rotation() {
        let train = make_train(&[3, 4, 20, 61, 99], 100);
        let surr = surrogate_spike_train_shifting(&train, 30, 11);
        let is_rotation = (0..train.len()).any(|k| {
            let mut rotated = train.clone();
            rotated.rotate_right(k);
            rotated == surr
        });
        assert!(is_rotation, "shifting must be a circular rotation");
    }

    #[test]
    fn test_shift_circular() {
        let train = make_train(&[0, 99], 100);
        let surr = surrogate_spike_train_shifting(&train, 50, 42);
        assert_eq!(spike_count(&surr), 2, "circular shift preserves all spikes");
    }

    #[test]
    fn test_shift_empty() {
        assert!(surrogate_spike_train_shifting(&[], 10, 0).is_empty());
    }
}