Skip to main content

sc_neurocore_engine/
neuron.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Neuron Models
8
9//! # Neuron Models
10//!
11//! Fixed-point LIF and Izhikevich neuron models for the v3 engine.
12
/// Mask and sign-interpret an integer to `width` bits (branchless).
///
/// Keeps the low `width` bits of `value`, sign-extends them (i.e. reduces
/// `value` into the signed range `[-2^(width-1), 2^(width-1) - 1]`), then
/// truncates the result to `i16`.
///
/// # Panics
///
/// Panics in *all* build profiles if `width` is not in `1..=32` — this is a
/// hard `assert!`, not a `debug_assert!`, because an out-of-range width
/// would overflow the `1_i64 << width` shift below.
#[inline]
pub fn mask(value: i32, width: u32) -> i16 {
    assert!(
        width > 0 && width <= 32,
        "mask width must be 1..=32, got {width}"
    );
    // Keep only the low `width` bits…
    let m = (1_i64 << width) - 1;
    let v = (value as i64) & m;
    // …then sign-extend branchlessly: move the kept bits to the top of the
    // i64 and arithmetic-shift them back down.
    let shift = 64 - width;
    ((v << shift) >> shift) as i16
}
28
29/// Fixed-point leaky-integrate-and-fire neuron state and parameters.
30#[derive(Clone, Debug)]
31pub struct FixedPointLif {
32    /// Membrane potential.
33    pub v: i16,
34    /// Refractory counter in simulation steps.
35    pub refractory_counter: i32,
36    /// Arithmetic data width.
37    pub data_width: u32,
38    /// Fraction bits for fixed-point scaling.
39    pub fraction: u32,
40    /// Resting potential.
41    pub v_rest: i16,
42    /// Reset potential after spike.
43    pub v_reset: i16,
44    /// Spike threshold.
45    pub v_threshold: i16,
46    /// Refractory period length in steps.
47    pub refractory_period: i32,
48}
49
50impl FixedPointLif {
51    /// Construct a fixed-point LIF neuron.
52    pub fn new(
53        data_width: u32,
54        fraction: u32,
55        v_rest: i16,
56        v_reset: i16,
57        v_threshold: i16,
58        refractory_period: i32,
59    ) -> Self {
60        Self {
61            v: v_rest,
62            refractory_counter: 0,
63            data_width,
64            fraction,
65            v_rest,
66            v_reset,
67            v_threshold,
68            refractory_period,
69        }
70    }
71
72    /// Advance one simulation step.
73    ///
74    /// Returns `(spike, membrane_voltage)`.
75    #[allow(non_snake_case)]
76    pub fn step(&mut self, leak_k: i16, gain_k: i16, i_t: i16, noise_in: i16) -> (i32, i16) {
77        let w = self.data_width;
78
79        // Refractory: check previous step's counter before any fire logic.
80        if self.refractory_counter > 0 {
81            self.refractory_counter -= 1;
82            self.v = self.v_rest;
83            return (0, mask(self.v_rest as i32, w));
84        }
85
86        let diff = mask((self.v_rest as i32) - (self.v as i32), 2 * w) as i32;
87        let dv_leak = mask((diff * (leak_k as i32)) >> self.fraction, self.data_width);
88        let dv_in = mask(
89            ((i_t as i32) * (gain_k as i32)) >> self.fraction,
90            self.data_width,
91        );
92
93        let v_next = mask(
94            (self.v as i32) + (dv_leak as i32) + (dv_in as i32) + (noise_in as i32),
95            self.data_width,
96        );
97
98        if v_next >= self.v_threshold {
99            self.v = self.v_reset;
100            self.refractory_counter = self.refractory_period;
101            (1, mask(self.v_reset as i32, w))
102        } else {
103            self.v = v_next;
104            (0, mask(v_next as i32, w))
105        }
106    }
107
108    /// Reset internal state to resting potential.
109    pub fn reset(&mut self) {
110        self.v = self.v_rest;
111        self.refractory_counter = 0;
112    }
113}
114
/// Izhikevich neuron (floating-point).
///
/// Standard model from IEEE TNN 14(6), 2003:
///   v' = 0.04*v² + 5*v + 140 - u + I
///   u' = a*(b*v - u)
///   if v >= 30: v ← c, u ← u + d
#[derive(Clone, Debug)]
pub struct Izhikevich {
    /// Membrane potential.
    pub v: f64,
    /// Recovery variable.
    pub u: f64,
    /// Recovery time-scale parameter.
    pub a: f64,
    /// Recovery sensitivity to `v`.
    pub b: f64,
    /// Post-spike reset value for `v`.
    pub c: f64,
    /// Post-spike increment for `u`.
    pub d: f64,
    /// Integration time step.
    pub dt: f64,
}

impl Izhikevich {
    /// Construct from the four Izhikevich parameters and a time step.
    ///
    /// Initial state is `v = c`, `u = b * c` (the post-spike point).
    /// For the regular-spiking preset see [`Izhikevich::regular_spiking`].
    pub fn new(a: f64, b: f64, c: f64, d: f64, dt: f64) -> Self {
        Self {
            v: c,
            u: b * c,
            a,
            b,
            c,
            d,
            dt,
        }
    }

    /// Regular-spiking preset: a=0.02, b=0.2, c=-65, d=8, dt=1.0.
    pub fn regular_spiking() -> Self {
        Self::new(0.02, 0.2, -65.0, 8.0, 1.0)
    }

    /// Advance one step. Returns 1 on spike, 0 otherwise.
    pub fn step(&mut self, current: f64) -> i32 {
        // Two half-steps for numerical stability on the 0.04v² term.
        let half = self.dt * 0.5;
        for _ in 0..2 {
            let dv = (0.04 * self.v * self.v + 5.0 * self.v + 140.0 - self.u + current) * half;
            let du = (self.a * (self.b * self.v - self.u)) * half;
            self.v += dv;
            self.u += du;
        }

        if self.v >= 30.0 {
            self.v = self.c;
            self.u += self.d;
            1
        } else {
            0
        }
    }

    /// Reset to the initial state (`v = c`, `u = b * c`).
    pub fn reset(&mut self) {
        self.v = self.c;
        self.u = self.b * self.c;
    }
}
177
/// Sliding-window bitstream probability estimator.
///
/// Mirrors Python's `BitstreamAverager`.
#[derive(Clone, Debug)]
pub struct BitstreamAverager {
    buffer: Vec<u8>,
    index: usize,
    filled: bool,
    running_sum: u64,
}

impl BitstreamAverager {
    /// Create an averager over a circular window of `window` bits.
    pub fn new(window: usize) -> Self {
        assert!(window > 0, "window must be > 0");
        BitstreamAverager {
            buffer: vec![0; window],
            index: 0,
            filled: false,
            running_sum: 0,
        }
    }

    /// Insert the next bit, evicting the oldest one once the window is full.
    pub fn push(&mut self, bit: u8) {
        debug_assert!(bit <= 1, "bit must be 0 or 1");
        // Before the first wrap every slot still holds its initial 0, so
        // subtracting the evicted value is a no-op then; one formula covers
        // both the filling and the steady-state phase.
        let evicted = std::mem::replace(&mut self.buffer[self.index], bit);
        self.running_sum = self.running_sum - u64::from(evicted) + u64::from(bit);

        self.index += 1;
        if self.index == self.buffer.len() {
            self.index = 0;
            self.filled = true;
        }
    }

    /// Current estimate of P(bit = 1) over the bits seen so far
    /// (the full window once it has wrapped). Returns 0.0 when empty.
    pub fn estimate(&self) -> f64 {
        let seen = if self.filled {
            self.buffer.len()
        } else {
            self.index
        };
        if seen == 0 {
            0.0
        } else {
            self.running_sum as f64 / seen as f64
        }
    }

    /// Clear all state back to an empty window.
    pub fn reset(&mut self) {
        for slot in self.buffer.iter_mut() {
            *slot = 0;
        }
        self.index = 0;
        self.filled = false;
        self.running_sum = 0;
    }

    /// Window length in bits.
    pub fn window(&self) -> usize {
        self.buffer.len()
    }
}
239
/// Homeostatic LIF neuron with adaptive threshold.
///
/// Threshold adapts via EMA of spike rate toward a target setpoint.
/// Turrigiano, Cold Spring Harb Perspect Biol 4:a005736, 2012.
#[derive(Clone, Debug)]
pub struct HomeostaticLif {
    /// Membrane potential.
    pub v: f64,
    /// Current (adaptive) firing threshold.
    pub v_threshold: f64,
    /// Resting potential.
    pub v_rest: f64,
    /// Post-spike reset potential.
    pub v_reset: f64,
    /// EMA estimate of the recent spike rate.
    pub rate_trace: f64,
    /// Target spike-rate setpoint.
    pub target_rate: f64,
    /// Threshold adaptation gain.
    pub adaptation_rate: f64,
    /// EMA decay factor for the rate trace.
    pub trace_decay: f64,
    initial_threshold: f64,
}

impl HomeostaticLif {
    /// Build a neuron with unit threshold and zero rest/reset potentials.
    pub fn new(target_rate: f64, adaptation_rate: f64, trace_decay: f64) -> Self {
        HomeostaticLif {
            v: 0.0,
            v_threshold: 1.0,
            v_rest: 0.0,
            v_reset: 0.0,
            rate_trace: 0.0,
            target_rate,
            adaptation_rate,
            trace_decay,
            initial_threshold: 1.0,
        }
    }

    /// Preset: target rate 0.1, adaptation 0.01, trace decay 0.95.
    pub fn with_defaults() -> Self {
        Self::new(0.1, 0.01, 0.95)
    }

    /// LIF step with threshold adaptation. Returns 1 on spike.
    pub fn step(&mut self, current: f64) -> i32 {
        // Leaky integration with a fixed membrane time constant.
        const TAU: f64 = 20.0;
        self.v += (current - (self.v - self.v_rest)) / TAU;

        let spiked = self.v >= self.v_threshold;
        if spiked {
            self.v = self.v_reset;
        }
        let spike = i32::from(spiked);

        // Exponential moving average of the spike train.
        self.rate_trace =
            self.trace_decay * self.rate_trace + (1.0 - self.trace_decay) * f64::from(spike);

        // Nudge the threshold toward the rate setpoint, then clamp it to a
        // sane band around the initial value.
        self.v_threshold += self.adaptation_rate * (self.rate_trace - self.target_rate);
        self.v_threshold = self.v_threshold.clamp(0.1, self.initial_threshold * 10.0);

        spike
    }

    /// Restore resting state and the initial threshold.
    pub fn reset(&mut self) {
        self.v = self.v_rest;
        self.rate_trace = 0.0;
        self.v_threshold = self.initial_threshold;
    }
}
307
/// XOR-nonlinearity dendritic neuron.
///
/// Koch, Biophysics of Computation, 1999, Ch. 12.
/// Output = 1 if (d1 + d2 - 2*d1*d2) > threshold.
#[derive(Clone, Debug)]
pub struct DendriticNeuron {
    /// Firing threshold on the dendritic drive.
    pub threshold: f64,
    last_current: f64,
}

impl DendriticNeuron {
    /// Build with an explicit threshold.
    pub fn new(threshold: f64) -> Self {
        DendriticNeuron {
            threshold,
            last_current: 0.0,
        }
    }

    /// Default threshold of 0.5 — midway between the XOR outputs 0 and 1.
    pub fn with_defaults() -> Self {
        Self::new(0.5)
    }

    /// One step: combine the two dendritic inputs with the soft-XOR
    /// `a + b - 2ab` (equals logical XOR for inputs in {0, 1}) and
    /// threshold the result.
    pub fn step(&mut self, input_a: f64, input_b: f64) -> i32 {
        self.last_current = input_a + input_b - 2.0 * input_a * input_b;
        i32::from(self.last_current > self.threshold)
    }

    /// Forget the last computed dendritic current.
    pub fn reset(&mut self) {
        self.last_current = 0.0;
    }
}
343
/// Adaptive Exponential IF neuron. Brette & Gerstner 2005.
/// PyO3 wrapper: `pyo3_neurons::PyAdExNeuron`
#[derive(Clone, Debug)]
pub struct AdExNeuron {
    /// Membrane potential.
    pub v: f64,
    /// Adaptation variable (incremented by `b` on each spike).
    pub w: f64,
    /// Resting potential.
    pub v_rest: f64,
    /// Post-spike reset potential.
    pub v_reset: f64,
    /// Hard spike-detection threshold.
    pub v_threshold: f64,
    /// Rheobase/soft threshold of the exponential term.
    pub v_rh: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Adaptation time constant.
    pub tau_w: f64,
    /// Subthreshold adaptation coupling.
    pub a: f64,
    /// Spike-triggered adaptation increment.
    pub b: f64,
    /// Membrane capacitance scale for the current term.
    pub c_m: f64,
    /// Integration time step.
    pub dt: f64,
}

impl Default for AdExNeuron {
    fn default() -> Self {
        AdExNeuron::new()
    }
}

impl AdExNeuron {
    /// Construct with the module's standard parameter set.
    pub fn new() -> Self {
        Self {
            v: -65.0,
            w: 0.0,
            v_rest: -65.0,
            v_reset: -68.0,
            v_threshold: -50.0,
            v_rh: -55.0,
            delta_t: 2.0,
            tau: 20.0,
            tau_w: 100.0,
            a: 0.5,
            b: 7.0,
            c_m: 200.0,
            dt: 0.1,
        }
    }

    /// Advance one step. Returns 1 on spike, 0 otherwise.
    pub fn step(&mut self, current: f64) -> i32 {
        // Brette & Gerstner 2005: C dV/dt = -g_L(V-E_L) + g_L ΔT exp((V-V_T)/ΔT) - w + I
        // The exponential argument is clamped so runaway membrane values
        // cannot overflow `exp`.
        let arg = ((self.v - self.v_rh) / self.delta_t).clamp(-20.0, 20.0);
        let spike_drive = self.delta_t * arg.exp();
        // Both derivatives are evaluated from the pre-update state.
        let dw = (self.a * (self.v - self.v_rest) - self.w) / self.tau_w * self.dt;
        let dv = ((-(self.v - self.v_rest) + spike_drive) / self.tau
            + (-self.w + current) / self.c_m)
            * self.dt;
        self.v += dv;
        self.w += dw;

        if self.v < self.v_threshold {
            return 0;
        }
        self.v = self.v_reset;
        self.w += self.b;
        1
    }

    /// Return to rest and clear the adaptation variable.
    pub fn reset(&mut self) {
        self.v = self.v_rest;
        self.w = 0.0;
    }
}
412
/// Exponential IF (no adaptation). Fourcaud-Trocmé et al. 2003.
#[derive(Clone, Debug)]
pub struct ExpIfNeuron {
    /// Membrane potential.
    pub v: f64,
    /// Resting potential.
    pub v_rest: f64,
    /// Post-spike reset potential.
    pub v_reset: f64,
    /// Hard spike-detection threshold.
    pub v_threshold: f64,
    /// Rheobase/soft threshold of the exponential term.
    pub v_rh: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Integration time step.
    pub dt: f64,
    /// Precomputed 1.0 / delta_t.
    pub inv_delta_t: f64,
    /// Precomputed dt / tau.
    pub dt_div_tau: f64,
}

impl Default for ExpIfNeuron {
    fn default() -> Self {
        Self::new()
    }
}

impl ExpIfNeuron {
    /// Construct with the module's standard parameter set.
    pub fn new() -> Self {
        // Single source of truth for parameters that also have precomputed
        // derived fields (previously the derived values repeated the
        // literals, so editing one could silently desynchronise the other).
        const DELTA_T: f64 = 2.0;
        const TAU: f64 = 20.0;
        const DT: f64 = 0.1;
        Self {
            v: -65.0,
            v_rest: -65.0,
            v_reset: -68.0,
            v_threshold: -50.0,
            v_rh: -55.0,
            delta_t: DELTA_T,
            tau: TAU,
            dt: DT,
            inv_delta_t: 1.0 / DELTA_T,
            dt_div_tau: DT / TAU,
        }
    }

    /// Advance one step. Returns 1 on spike, 0 otherwise.
    pub fn step(&mut self, current: f64) -> i32 {
        // Clamp the exponential argument so runaway values cannot overflow.
        let exp_arg = ((self.v - self.v_rh) * self.inv_delta_t).clamp(-20.0, 20.0);
        let exp_term = self.delta_t * exp_arg.exp();
        let dv = (-(self.v - self.v_rest) + exp_term + current) * self.dt_div_tau;
        self.v += dv;

        if self.v >= self.v_threshold {
            self.v = self.v_reset;
            1
        } else {
            0
        }
    }

    /// Return the membrane to rest.
    pub fn reset(&mut self) {
        self.v = self.v_rest;
    }
}
470
/// Lapicque 1907 — classical RC integrate-and-fire.
#[derive(Clone, Debug)]
pub struct LapicqueNeuron {
    /// Membrane potential.
    pub v: f64,
    /// Resting potential.
    pub v_rest: f64,
    /// Post-spike reset potential.
    pub v_reset: f64,
    /// Firing threshold.
    pub v_threshold: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Membrane resistance.
    pub resistance: f64,
    /// Integration time step.
    pub dt: f64,
}

impl LapicqueNeuron {
    /// Build an RC neuron with zero rest/reset potentials.
    pub fn new(tau: f64, resistance: f64, threshold: f64, dt: f64) -> Self {
        LapicqueNeuron {
            v: 0.0,
            v_rest: 0.0,
            v_reset: 0.0,
            v_threshold: threshold,
            tau,
            resistance,
            dt,
        }
    }

    /// Advance one step of the RC membrane. Returns 1 on spike.
    pub fn step(&mut self, current: f64) -> i32 {
        // dv = (-(v - v_rest) + R·I) / tau · dt
        self.v += (-(self.v - self.v_rest) + self.resistance * current) / self.tau * self.dt;

        if self.v >= self.v_threshold {
            self.v = self.v_reset;
            return 1;
        }
        0
    }

    /// Return the membrane to rest.
    pub fn reset(&mut self) {
        self.v = self.v_rest;
    }
}
512
513#[cfg(test)]
514mod tests {
515
516    #[test]
517    fn test_exp_if_optimisation_parity() {
518        let mut n = ExpIfNeuron::new();
519        n.v = -60.0;
520        let current = 10.0;
521
522        // Manual calculation of original formula
523        let exp_arg = ((-60.0_f64 - (-55.0)) / 2.0).clamp(-20.0, 20.0);
524        let exp_term = 2.0 * exp_arg.exp();
525        let expected_dv = (-(-60.0 - (-65.0)) + exp_term + current) / 20.0 * 0.1;
526
527        n.step(current);
528        let got_dv = n.v - (-60.0); // Simple check since we only did one step
529
530        // Use a small epsilon for float parity
531        assert!(
532            (got_dv - expected_dv).abs() < 1e-15,
533            "Logic mismatch in ExpIfNeuron: got {}, expected {}",
534            got_dv,
535            expected_dv
536        );
537    }
538
539    use super::{
540        mask, AdExNeuron, BitstreamAverager, DendriticNeuron, ExpIfNeuron, FixedPointLif,
541        HomeostaticLif, Izhikevich, LapicqueNeuron,
542    };
543
544    #[test]
545    fn mask_branchless_matches_original() {
546        for &width in &[16_u32, 32] {
547            for value in [
548                -32768_i32,
549                -1,
550                0,
551                1,
552                32767,
553                65535,
554                -65536,
555                i16::MAX as i32,
556                i16::MIN as i32,
557            ] {
558                let result = mask(value, width);
559
560                let m = (1_i64 << width) - 1;
561                let mut v = (value as i64) & m;
562                if v >= (1_i64 << (width - 1)) {
563                    v -= 1_i64 << width;
564                }
565                let expected = if width >= 32 {
566                    v as i32 as i16
567                } else {
568                    v as i16
569                };
570
571                assert_eq!(
572                    result, expected,
573                    "mask({value}, {width}): got {result}, expected {expected}"
574                );
575            }
576        }
577    }
578
579    #[test]
580    fn lif_fires_with_refractory_period() {
581        // Q8.8: threshold=1.0 → 256, matching Python default
582        let mut n = FixedPointLif::new(16, 8, 0, 0, 256, 2);
583        let mut spikes = Vec::new();
584        for _ in 0..30 {
585            let (s, _) = n.step(1, 256, 50, 0);
586            spikes.push(s);
587        }
588        let total: i32 = spikes.iter().sum();
589        assert!(total > 0, "neuron must fire with refractory_period=2");
590        // Refractory gap: after a spike, next 2 steps must be silent.
591        for (i, &s) in spikes.iter().enumerate() {
592            if s == 1 && i + 2 < spikes.len() {
593                assert_eq!(spikes[i + 1], 0, "step {} should be refractory", i + 1);
594                assert_eq!(spikes[i + 2], 0, "step {} should be refractory", i + 2);
595            }
596        }
597    }
598
599    #[test]
600    fn lif_fires_without_refractory() {
601        let mut n = FixedPointLif::new(16, 8, 0, 0, 256, 0);
602        let mut total = 0;
603        for _ in 0..20 {
604            let (s, _) = n.step(1, 256, 50, 0);
605            total += s;
606        }
607        assert!(total > 0, "neuron must fire with refractory_period=0");
608    }
609
610    // ── Izhikevich tests ──────────────────────────────────────────
611
612    #[test]
613    fn izhikevich_regular_spiking_fires() {
614        let mut n = Izhikevich::regular_spiking();
615        let mut total = 0;
616        for _ in 0..100 {
617            total += n.step(10.0);
618        }
619        assert!(total > 0, "RS neuron must fire with I=10");
620    }
621
622    #[test]
623    fn izhikevich_no_spike_without_input() {
624        let mut n = Izhikevich::regular_spiking();
625        let mut total = 0;
626        for _ in 0..100 {
627            total += n.step(0.0);
628        }
629        assert_eq!(total, 0, "no spikes without input");
630    }
631
632    #[test]
633    fn izhikevich_reset_clears_state() {
634        let mut n = Izhikevich::regular_spiking();
635        for _ in 0..50 {
636            n.step(10.0);
637        }
638        n.reset();
639        assert_eq!(n.v, n.c);
640        assert!((n.u - n.b * n.c).abs() < 1e-12);
641    }
642
643    #[test]
644    fn izhikevich_chattering_fires_more() {
645        // Chattering: a=0.02, b=0.2, c=-50, d=2
646        let mut ch = Izhikevich::new(0.02, 0.2, -50.0, 2.0, 1.0);
647        let mut rs = Izhikevich::regular_spiking();
648        let mut ch_spikes = 0;
649        let mut rs_spikes = 0;
650        for _ in 0..200 {
651            ch_spikes += ch.step(10.0);
652            rs_spikes += rs.step(10.0);
653        }
654        assert!(
655            ch_spikes > rs_spikes,
656            "chattering ({ch_spikes}) should fire more than RS ({rs_spikes})"
657        );
658    }
659
660    // ── BitstreamAverager tests ───────────────────────────────────
661
662    #[test]
663    fn averager_all_ones() {
664        let mut avg = BitstreamAverager::new(100);
665        for _ in 0..100 {
666            avg.push(1);
667        }
668        assert!((avg.estimate() - 1.0).abs() < 1e-12);
669    }
670
671    #[test]
672    fn averager_all_zeros() {
673        let mut avg = BitstreamAverager::new(50);
674        for _ in 0..50 {
675            avg.push(0);
676        }
677        assert!(avg.estimate().abs() < 1e-12);
678    }
679
680    #[test]
681    fn averager_half() {
682        let mut avg = BitstreamAverager::new(100);
683        for i in 0..100 {
684            avg.push((i % 2) as u8);
685        }
686        assert!((avg.estimate() - 0.5).abs() < 1e-12);
687    }
688
    #[test]
    fn averager_sliding_window() {
        let mut avg = BitstreamAverager::new(4);
        // Fill the window: [1, 1, 0, 0] → sum 2/4 = 0.5.
        for &b in &[1_u8, 1, 0, 0] {
            avg.push(b);
        }
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
        // Window is now full and the write index has wrapped to 0, so each
        // push evicts the oldest bit. Pushing 1 evicts the first 1:
        // sum stays 2 → estimate unchanged at 0.5.
        avg.push(1);
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
        // Pushing 1 evicts the second 1: sum still 2 → 0.5.
        avg.push(1);
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
        // Pushing 1 evicts a 0: sum becomes 3 → 0.75.
        avg.push(1);
        assert!((avg.estimate() - 0.75).abs() < 1e-12);
    }
710
711    #[test]
712    fn averager_partial_fill() {
713        let mut avg = BitstreamAverager::new(100);
714        avg.push(1);
715        avg.push(0);
716        assert!((avg.estimate() - 0.5).abs() < 1e-12);
717    }
718
719    #[test]
720    fn averager_empty_returns_zero() {
721        let avg = BitstreamAverager::new(10);
722        assert!(avg.estimate().abs() < 1e-12);
723    }
724
725    // ── HomeostaticLif tests ──────────────────────────────────────
726
727    #[test]
728    fn homeostatic_fires_with_strong_input() {
729        let mut n = HomeostaticLif::with_defaults();
730        let mut total = 0;
731        for _ in 0..200 {
732            total += n.step(25.0);
733        }
734        assert!(total > 0, "must fire with strong input");
735    }
736
737    #[test]
738    fn homeostatic_threshold_adapts() {
739        let mut n = HomeostaticLif::with_defaults();
740        let initial = n.v_threshold;
741        for _ in 0..500 {
742            n.step(25.0);
743        }
744        assert!(
745            (n.v_threshold - initial).abs() > 1e-6,
746            "threshold must adapt"
747        );
748    }
749
750    #[test]
751    fn homeostatic_no_fire_without_input() {
752        let mut n = HomeostaticLif::with_defaults();
753        let mut total = 0;
754        for _ in 0..100 {
755            total += n.step(0.0);
756        }
757        assert_eq!(total, 0);
758    }
759
760    #[test]
761    fn homeostatic_threshold_bounded() {
762        let mut n = HomeostaticLif::with_defaults();
763        for _ in 0..10000 {
764            n.step(50.0);
765        }
766        assert!(n.v_threshold >= 0.1);
767        assert!(n.v_threshold <= 10.0);
768    }
769
770    // ── DendriticNeuron tests ─────────────────────────────────────
771
772    #[test]
773    fn dendritic_xor_truth_table() {
774        let mut n = DendriticNeuron::new(0.5);
775        assert_eq!(n.step(0.0, 0.0), 0); // 0+0-0 = 0
776        assert_eq!(n.step(1.0, 0.0), 1); // 1+0-0 = 1
777        assert_eq!(n.step(0.0, 1.0), 1); // 0+1-0 = 1
778        assert_eq!(n.step(1.0, 1.0), 0); // 1+1-2 = 0
779    }
780
781    #[test]
782    fn dendritic_subthreshold() {
783        let mut n = DendriticNeuron::new(0.5);
784        assert_eq!(n.step(0.2, 0.1), 0);
785    }
786
787    #[test]
788    fn dendritic_reset() {
789        let mut n = DendriticNeuron::with_defaults();
790        n.step(1.0, 0.0);
791        n.reset();
792        assert!((n.last_current).abs() < 1e-12);
793    }
794
795    #[test]
796    fn averager_reset() {
797        let mut avg = BitstreamAverager::new(10);
798        for _ in 0..10 {
799            avg.push(1);
800        }
801        avg.reset();
802        assert!(avg.estimate().abs() < 1e-12);
803    }
804
805    // ── AdEx tests ────────────────────────────────────────────────
806
807    #[test]
808    fn adex_fires_with_input() {
809        let mut n = AdExNeuron::new();
810        let mut total = 0;
811        for _ in 0..2000 {
812            total += n.step(500.0);
813        }
814        assert!(total > 0, "AdEx must fire with strong input");
815    }
816
817    #[test]
818    fn adex_adaptation_reduces_rate() {
819        let mut n = AdExNeuron::new();
820        let first_100: i32 = (0..1000).map(|_| n.step(400.0)).sum();
821        let next_100: i32 = (0..1000).map(|_| n.step(400.0)).sum();
822        // Adaptation should reduce firing over time (w grows)
823        assert!(
824            next_100 <= first_100 + 5,
825            "adaptation should not increase rate: first={first_100}, next={next_100}"
826        );
827    }
828
829    // ── ExpIF tests ───────────────────────────────────────────────
830
831    #[test]
832    fn expif_fires() {
833        let mut n = ExpIfNeuron::new();
834        let mut total = 0;
835        for _ in 0..2000 {
836            total += n.step(500.0);
837        }
838        assert!(total > 0, "ExpIF must fire");
839    }
840
841    #[test]
842    fn expif_no_fire_without_input() {
843        let mut n = ExpIfNeuron::new();
844        let total: i32 = (0..500).map(|_| n.step(0.0)).sum();
845        assert_eq!(total, 0);
846    }
847
848    // ── Lapicque tests ────────────────────────────────────────────
849
850    #[test]
851    fn lapicque_fires() {
852        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
853        let mut total = 0;
854        for _ in 0..200 {
855            total += n.step(5.0);
856        }
857        assert!(total > 0, "Lapicque must fire with sustained input");
858    }
859
860    #[test]
861    fn lapicque_reset() {
862        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
863        for _ in 0..50 {
864            n.step(5.0);
865        }
866        n.reset();
867        assert!((n.v).abs() < 1e-12);
868    }
869
870    // ── AdEx coverage tests ────────────────────────────────────────
871
872    #[test]
873    fn adex_no_fire_without_input() {
874        let mut n = AdExNeuron::new();
875        let total: i32 = (0..1000).map(|_| n.step(0.0)).sum();
876        assert_eq!(total, 0);
877    }
878
879    #[test]
880    fn adex_negative_current_no_fire() {
881        let mut n = AdExNeuron::new();
882        let total: i32 = (0..500).map(|_| n.step(-100.0)).sum();
883        assert_eq!(total, 0, "negative current must not cause spikes");
884    }
885
886    #[test]
887    fn adex_reset_roundtrip() {
888        let mut n = AdExNeuron::new();
889        for _ in 0..200 {
890            n.step(500.0);
891        }
892        assert!(n.w > 0.0, "w must grow during spiking");
893        n.reset();
894        assert_eq!(n.v, n.v_rest);
895        assert_eq!(n.w, 0.0);
896        // Post-reset: should behave identically to fresh
897        let mut fresh = AdExNeuron::new();
898        let r1: i32 = (0..100).map(|_| n.step(500.0)).sum();
899        let r2: i32 = (0..100).map(|_| fresh.step(500.0)).sum();
900        assert_eq!(r1, r2, "reset neuron must match fresh neuron");
901    }
902
903    #[test]
904    fn adex_voltage_bounded() {
905        let mut n = AdExNeuron::new();
906        for _ in 0..5000 {
907            n.step(1000.0);
908        }
909        assert!(n.v.is_finite(), "voltage must stay finite");
910        assert!(n.w.is_finite(), "adaptation must stay finite");
911    }
912
913    #[test]
914    fn adex_pipeline_sustained_spiking() {
915        let mut n = AdExNeuron::new();
916        let spikes: i32 = (0..10000).map(|_| n.step(500.0)).sum();
917        assert!(
918            spikes > 100,
919            "sustained input should produce many spikes: got {spikes}"
920        );
921        assert!(n.v.is_finite());
922    }
923
924    #[test]
925    fn adex_performance_10k_steps() {
926        let mut n = AdExNeuron::new();
927        let start = std::time::Instant::now();
928        for _ in 0..10_000 {
929            n.step(500.0);
930        }
931        let elapsed = start.elapsed();
932        assert!(
933            elapsed.as_millis() < 50,
934            "10k steps took too long: {:?}",
935            elapsed
936        );
937    }
938
939    // ── ExpIF coverage tests ───────────────────────────────────────
940
941    #[test]
942    fn expif_negative_current_no_fire() {
943        let mut n = ExpIfNeuron::new();
944        let total: i32 = (0..500).map(|_| n.step(-100.0)).sum();
945        assert_eq!(total, 0);
946    }
947
948    #[test]
949    fn expif_reset_roundtrip() {
950        let mut n = ExpIfNeuron::new();
951        for _ in 0..200 {
952            n.step(500.0);
953        }
954        n.reset();
955        assert_eq!(n.v, n.v_rest);
956        let mut fresh = ExpIfNeuron::new();
957        let r1: i32 = (0..100).map(|_| n.step(500.0)).sum();
958        let r2: i32 = (0..100).map(|_| fresh.step(500.0)).sum();
959        assert_eq!(r1, r2);
960    }
961
962    #[test]
963    fn expif_voltage_bounded() {
964        let mut n = ExpIfNeuron::new();
965        for _ in 0..5000 {
966            n.step(1000.0);
967        }
968        assert!(n.v.is_finite());
969    }
970
971    #[test]
972    fn expif_fires_more_than_adex() {
973        // ExpIF has no adaptation, should fire at least as much as AdEx
974        let mut eif = ExpIfNeuron::new();
975        let mut adex = AdExNeuron::new();
976        let eif_spikes: i32 = (0..5000).map(|_| eif.step(500.0)).sum();
977        let adex_spikes: i32 = (0..5000).map(|_| adex.step(500.0)).sum();
978        assert!(
979            eif_spikes >= adex_spikes,
980            "ExpIF ({eif_spikes}) should fire >= AdEx ({adex_spikes}) due to no adaptation"
981        );
982    }
983
984    #[test]
985    fn expif_performance_10k_steps() {
986        let mut n = ExpIfNeuron::new();
987        let start = std::time::Instant::now();
988        for _ in 0..10_000 {
989            n.step(500.0);
990        }
991        let elapsed = start.elapsed();
992        assert!(
993            elapsed.as_millis() < 50,
994            "10k steps took too long: {:?}",
995            elapsed
996        );
997    }
998
999    // ── Lapicque coverage tests ────────────────────────────────────
1000
1001    #[test]
1002    fn lapicque_no_fire_without_input() {
1003        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
1004        let total: i32 = (0..500).map(|_| n.step(0.0)).sum();
1005        assert_eq!(total, 0);
1006    }
1007
1008    #[test]
1009    fn lapicque_negative_current_no_fire() {
1010        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
1011        let total: i32 = (0..500).map(|_| n.step(-5.0)).sum();
1012        assert_eq!(total, 0);
1013    }
1014
1015    #[test]
1016    fn lapicque_reset_roundtrip() {
1017        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
1018        for _ in 0..100 {
1019            n.step(5.0);
1020        }
1021        n.reset();
1022        assert_eq!(n.v, n.v_rest);
1023        let mut fresh = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
1024        let r1: i32 = (0..100).map(|_| n.step(5.0)).sum();
1025        let r2: i32 = (0..100).map(|_| fresh.step(5.0)).sum();
1026        assert_eq!(r1, r2);
1027    }
1028
1029    #[test]
1030    fn lapicque_voltage_bounded() {
1031        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
1032        for _ in 0..5000 {
1033            n.step(100.0);
1034        }
1035        assert!(n.v.is_finite());
1036    }
1037
1038    #[test]
1039    fn lapicque_higher_resistance_fires_faster() {
1040        let mut lo = LapicqueNeuron::new(20.0, 0.5, 1.0, 1.0);
1041        let mut hi = LapicqueNeuron::new(20.0, 2.0, 1.0, 1.0);
1042        let lo_spikes: i32 = (0..200).map(|_| lo.step(1.0)).sum();
1043        let hi_spikes: i32 = (0..200).map(|_| hi.step(1.0)).sum();
1044        assert!(
1045            hi_spikes >= lo_spikes,
1046            "higher R ({hi_spikes}) should fire >= lower R ({lo_spikes})"
1047        );
1048    }
1049
1050    #[test]
1051    fn lapicque_performance_10k_steps() {
1052        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
1053        let start = std::time::Instant::now();
1054        for _ in 0..10_000 {
1055            n.step(5.0);
1056        }
1057        let elapsed = start.elapsed();
1058        assert!(
1059            elapsed.as_millis() < 50,
1060            "10k steps took too long: {:?}",
1061            elapsed
1062        );
1063    }
1064
1065    #[test]
1066    fn lapicque_pipeline_sustained_spiking() {
1067        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
1068        let spikes: i32 = (0..10000).map(|_| n.step(5.0)).sum();
1069        assert!(
1070            spikes > 100,
1071            "sustained input should produce many spikes: got {spikes}"
1072        );
1073    }
1074}