/// Wraps `value` to a `width`-bit two's-complement integer and narrows it to `i16`.
///
/// The low `width` bits of `value` are kept and sign-extended from bit `width - 1`, which
/// emulates overflow of a `width`-bit signed register. The sign-extended value is then
/// narrowed with an `as i16` cast, so for `width > 16` only its low 16 bits survive.
///
/// # Panics
///
/// Panics if `width` is not in `1..=32`.
#[inline]
pub fn mask(value: i32, width: u32) -> i16 {
    assert!(
        (1..=32).contains(&width),
        "mask width must be 1..=32, got {width}"
    );
    // Shifting the kept bits to the top of an i64 discards everything above bit
    // `width - 1`; the arithmetic shift back replicates that bit as the sign.
    let unused_bits = 64 - width;
    let wrapped = (i64::from(value) << unused_bits) >> unused_bits;
    wrapped as i16
}
27
28#[derive(Clone, Debug)]
30pub struct FixedPointLif {
31 pub v: i16,
33 pub refractory_counter: i32,
35 pub data_width: u32,
37 pub fraction: u32,
39 pub v_rest: i16,
41 pub v_reset: i16,
43 pub v_threshold: i16,
45 pub refractory_period: i32,
47}
48
49impl FixedPointLif {
50 pub fn new(
52 data_width: u32,
53 fraction: u32,
54 v_rest: i16,
55 v_reset: i16,
56 v_threshold: i16,
57 refractory_period: i32,
58 ) -> Self {
59 Self {
60 v: v_rest,
61 refractory_counter: 0,
62 data_width,
63 fraction,
64 v_rest,
65 v_reset,
66 v_threshold,
67 refractory_period,
68 }
69 }
70
71 pub fn step(&mut self, leak_k: i16, gain_k: i16, i_t: i16, noise_in: i16) -> (i32, i16) {
75 let w = self.data_width;
76
77 if self.refractory_counter > 0 {
79 self.refractory_counter -= 1;
80 self.v = self.v_rest;
81 return (0, mask(self.v_rest as i32, w));
82 }
83
84 let diff = mask((self.v_rest as i32) - (self.v as i32), 2 * w) as i32;
85 let dv_leak = mask((diff * (leak_k as i32)) >> self.fraction, self.data_width);
86 let dv_in = mask(
87 ((i_t as i32) * (gain_k as i32)) >> self.fraction,
88 self.data_width,
89 );
90
91 let v_next = mask(
92 (self.v as i32) + (dv_leak as i32) + (dv_in as i32) + (noise_in as i32),
93 self.data_width,
94 );
95
96 if v_next >= self.v_threshold {
97 self.v = self.v_reset;
98 self.refractory_counter = self.refractory_period;
99 (1, mask(self.v_reset as i32, w))
100 } else {
101 self.v = v_next;
102 (0, mask(v_next as i32, w))
103 }
104 }
105
106 pub fn reset(&mut self) {
108 self.v = self.v_rest;
109 self.refractory_counter = 0;
110 }
111}
112
/// Izhikevich spiking neuron (two-variable quadratic integrate-and-fire model).
///
/// ```text
/// dv/dt = 0.04 v^2 + 5 v + 140 - u + I
/// du/dt = a (b v - u)
/// ```
///
/// When `v >= 30` after a step the neuron spikes and is reset with `v <- c`, `u <- u + d`.
#[derive(Clone, Debug)]
pub struct Izhikevich {
    /// Membrane potential.
    pub v: f64,
    /// Recovery variable.
    pub u: f64,
    /// Rate constant of the recovery variable (`du/dt = a (b v - u)`).
    pub a: f64,
    /// Coupling of the recovery variable to `v`.
    pub b: f64,
    /// Value `v` is reset to after a spike (also the initial `v`).
    pub c: f64,
    /// Increment added to `u` after a spike.
    pub d: f64,
    /// Integration step; each call to `step` takes two Euler sub-steps of `dt / 2`.
    pub dt: f64,
}

impl Izhikevich {
    /// Creates a neuron starting at `v = c` with the recovery variable at `u = b * c`.
    pub fn new(a: f64, b: f64, c: f64, d: f64, dt: f64) -> Self {
        let v = c;
        let u = b * c;
        Self { a, b, c, d, dt, v, u }
    }

    /// Regular-spiking parameter set: `a = 0.02`, `b = 0.2`, `c = -65`, `d = 8`, `dt = 1`.
    pub fn regular_spiking() -> Self {
        let (a, b, c, d) = (0.02, 0.2, -65.0, 8.0);
        let dt = 1.0;
        Self::new(a, b, c, d, dt)
    }

    /// Integrates the model over one `dt` with input `current`.
    ///
    /// The interval is split into two forward-Euler sub-steps of `dt / 2`; within each
    /// sub-step both derivatives are evaluated from the state at the start of that sub-step.
    /// Returns `1` if the neuron spiked, otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let sub_dt = self.dt * 0.5;
        for _ in 0..2 {
            let (v, u) = (self.v, self.u);
            let dv_dt = 0.04 * v * v + 5.0 * v + 140.0 - u + current;
            let du_dt = self.a * (self.b * v - u);
            self.v = v + dv_dt * sub_dt;
            self.u = u + du_dt * sub_dt;
        }

        let fired = self.v >= 30.0;
        if fired {
            self.v = self.c;
            self.u += self.d;
        }
        i32::from(fired)
    }

    /// Restores the initial state `v = c`, `u = b * c`.
    pub fn reset(&mut self) {
        self.u = self.b * self.c;
        self.v = self.c;
    }
}
175
/// Sliding-window mean of a stream of bits.
///
/// The last `window` pushed bits are kept in a ring buffer together with their running sum,
/// so both `push` and `estimate` are O(1).
#[derive(Clone, Debug)]
pub struct BitstreamAverager {
    /// Ring buffer with the most recent bits.
    buffer: Vec<u8>,
    /// Slot the next `push` writes to.
    index: usize,
    /// `true` once the buffer has wrapped, i.e. every slot holds a pushed sample.
    filled: bool,
    /// Sum of the values currently stored in `buffer`.
    running_sum: u64,
}

impl BitstreamAverager {
    /// Creates an averager over the last `window` bits.
    ///
    /// # Panics
    ///
    /// Panics if `window` is zero.
    pub fn new(window: usize) -> Self {
        assert!(window > 0, "window must be > 0");
        Self {
            running_sum: 0,
            filled: false,
            index: 0,
            buffer: vec![0; window],
        }
    }

    /// Appends one bit, evicting the oldest one once the window is full.
    ///
    /// `bit` is expected to be `0` or `1`; this is only checked in debug builds.
    pub fn push(&mut self, bit: u8) {
        debug_assert!(bit <= 1, "bit must be 0 or 1");
        // Slots not written since construction/reset are still 0, so subtracting the
        // evicted value changes nothing until the buffer has wrapped once.
        let evicted = std::mem::replace(&mut self.buffer[self.index], bit);
        self.running_sum = self.running_sum - u64::from(evicted) + u64::from(bit);

        self.index = (self.index + 1) % self.buffer.len();
        if self.index == 0 {
            self.filled = true;
        }
    }

    /// Fraction of ones among the samples currently in the window; `0.0` before any push.
    pub fn estimate(&self) -> f64 {
        let samples = if self.filled {
            self.buffer.len()
        } else {
            self.index
        };
        match samples {
            0 => 0.0,
            n => self.running_sum as f64 / n as f64,
        }
    }

    /// Discards all samples, keeping the window size.
    pub fn reset(&mut self) {
        self.running_sum = 0;
        self.filled = false;
        self.index = 0;
        self.buffer.iter_mut().for_each(|slot| *slot = 0);
    }

    /// Number of bits the window spans.
    pub fn window(&self) -> usize {
        self.buffer.len()
    }
}
237
/// Leaky integrate-and-fire neuron with a homeostatically adapted threshold.
///
/// An exponential trace of the spike train (`rate_trace`) is compared with `target_rate`
/// on every step, and the threshold is moved by `adaptation_rate * (rate_trace -
/// target_rate)`. For a positive `adaptation_rate` it therefore rises when the neuron fires
/// more than the target and falls otherwise. It is clamped to `[0.1, 10 * initial threshold]`.
#[derive(Clone, Debug)]
pub struct HomeostaticLif {
    /// Membrane potential.
    pub v: f64,
    /// Current (adapted) firing threshold.
    pub v_threshold: f64,
    /// Potential the membrane relaxes toward.
    pub v_rest: f64,
    /// Potential after a spike.
    pub v_reset: f64,
    /// Exponentially smoothed spike indicator.
    pub rate_trace: f64,
    /// Desired value of `rate_trace`.
    pub target_rate: f64,
    /// Threshold change per unit of rate error per step.
    pub adaptation_rate: f64,
    /// Per-step decay factor of `rate_trace`.
    pub trace_decay: f64,
    /// Threshold restored by `reset`; also scales the upper clamp bound.
    initial_threshold: f64,
}

impl HomeostaticLif {
    /// Creates a neuron at `v = 0` with threshold `1.0` and an empty rate trace.
    pub fn new(target_rate: f64, adaptation_rate: f64, trace_decay: f64) -> Self {
        const INITIAL_THRESHOLD: f64 = 1.0;
        Self {
            v: 0.0,
            v_rest: 0.0,
            v_reset: 0.0,
            v_threshold: INITIAL_THRESHOLD,
            initial_threshold: INITIAL_THRESHOLD,
            rate_trace: 0.0,
            target_rate,
            adaptation_rate,
            trace_decay,
        }
    }

    /// Defaults: `target_rate = 0.1`, `adaptation_rate = 0.01`, `trace_decay = 0.95`.
    pub fn with_defaults() -> Self {
        Self::new(0.1, 0.01, 0.95)
    }

    /// Advances one step with input `current`; returns `1` on a spike, otherwise `0`.
    ///
    /// The membrane relaxes toward `v_rest` with a fixed time constant of 20 steps. The rate
    /// trace and the threshold are then updated on every step, spiking or not.
    pub fn step(&mut self, current: f64) -> i32 {
        const TAU: f64 = 20.0;
        let leak = -(self.v - self.v_rest);
        self.v += (leak + current) / TAU;

        let fired = self.v >= self.v_threshold;
        if fired {
            self.v = self.v_reset;
        }
        let spike = i32::from(fired);

        let decay = self.trace_decay;
        self.rate_trace = self.rate_trace * decay + f64::from(spike) * (1.0 - decay);

        let rate_error = self.rate_trace - self.target_rate;
        let upper_bound = self.initial_threshold * 10.0;
        let adapted = self.v_threshold + self.adaptation_rate * rate_error;
        self.v_threshold = adapted.clamp(0.1, upper_bound);

        spike
    }

    /// Restores the potential, the rate trace and the initial threshold.
    pub fn reset(&mut self) {
        self.v_threshold = self.initial_threshold;
        self.rate_trace = 0.0;
        self.v = self.v_rest;
    }
}
305
/// Two-input neuron whose dendritic current is `a + b - 2ab`.
///
/// It fires when that current is strictly greater than `threshold`. For binary inputs the
/// current is 1 exactly when the inputs differ, so any threshold in `[0, 1)` yields XOR.
#[derive(Clone, Debug)]
pub struct DendriticNeuron {
    /// Firing threshold (strict `>` comparison).
    pub threshold: f64,
    /// Current computed by the most recent `step` (0 after construction or `reset`).
    last_current: f64,
}

impl DendriticNeuron {
    /// Creates a neuron with the given firing threshold.
    pub fn new(threshold: f64) -> Self {
        Self {
            last_current: 0.0,
            threshold,
        }
    }

    /// Creates a neuron with threshold `0.5`.
    pub fn with_defaults() -> Self {
        Self::new(0.5)
    }

    /// Combines the two inputs and returns `1` if the result exceeds `threshold`, else `0`.
    pub fn step(&mut self, input_a: f64, input_b: f64) -> i32 {
        let (a, b) = (input_a, input_b);
        let current = a + b - 2.0 * a * b;
        self.last_current = current;
        i32::from(current > self.threshold)
    }

    /// Clears the stored dendritic current.
    pub fn reset(&mut self) {
        self.last_current = 0.0;
    }
}
341
/// Adaptive exponential integrate-and-fire (AdEx) neuron, integrated with forward Euler.
///
/// ```text
/// tau   * dv/dt = -(v - v_rest) + delta_t * exp((v - v_rh) / delta_t) - w + I
/// tau_w * dw/dt = a * (v - v_rest) - w
/// ```
///
/// The exponent is clamped to `[-20, 20]` before `exp`. A spike is registered when
/// `v >= v_threshold` after an update; then `v` is set to `v_reset` and `w` grows by `b`.
#[derive(Clone, Debug)]
pub struct AdExNeuron {
    /// Membrane potential.
    pub v: f64,
    /// Adaptation variable (subtracted from the membrane drive).
    pub w: f64,
    /// Potential the linear leak pulls toward.
    pub v_rest: f64,
    /// Potential after a spike.
    pub v_reset: f64,
    /// Spike detection level for `v`.
    pub v_threshold: f64,
    /// Offset of the exponential term, `exp((v - v_rh) / delta_t)`.
    pub v_rh: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Adaptation time constant.
    pub tau_w: f64,
    /// Coupling of `w` to `v - v_rest`.
    pub a: f64,
    /// Increment of `w` on each spike.
    pub b: f64,
    /// Integration step.
    pub dt: f64,
}

impl Default for AdExNeuron {
    fn default() -> Self {
        Self::new()
    }
}

impl AdExNeuron {
    /// Creates a neuron with the built-in parameter set, starting at `v = v_rest`, `w = 0`.
    pub fn new() -> Self {
        Self {
            v_rest: -65.0,
            v_reset: -68.0,
            v_threshold: -50.0,
            v_rh: -55.0,
            delta_t: 2.0,
            tau: 20.0,
            tau_w: 100.0,
            a: 0.5,
            b: 7.0,
            dt: 0.1,
            v: -65.0,
            w: 0.0,
        }
    }

    /// Takes one Euler step of length `dt` with input `current`; returns `1` on a spike.
    ///
    /// Both derivatives are evaluated from the state at the start of the step.
    pub fn step(&mut self, current: f64) -> i32 {
        let v = self.v;
        let w = self.w;
        let offset = v - self.v_rest;

        let exponent = ((v - self.v_rh) / self.delta_t).clamp(-20.0, 20.0);
        let exp_drive = self.delta_t * exponent.exp();
        let dv = (-offset + exp_drive - w + current) / self.tau * self.dt;
        let dw = (self.a * offset - w) / self.tau_w * self.dt;
        self.v = v + dv;
        self.w = w + dw;

        let fired = self.v >= self.v_threshold;
        if fired {
            self.v = self.v_reset;
            self.w += self.b;
        }
        i32::from(fired)
    }

    /// Returns to `v = v_rest` and clears the adaptation variable.
    pub fn reset(&mut self) {
        self.w = 0.0;
        self.v = self.v_rest;
    }
}
405
/// Exponential integrate-and-fire neuron, integrated with forward Euler.
///
/// ```text
/// tau * dv/dt = -(v - v_rest) + delta_t * exp((v - v_rh) / delta_t) + I
/// ```
///
/// The exponent is clamped to `[-20, 20]` before `exp`. A spike is registered when
/// `v >= v_threshold` after an update, and `v` is then set to `v_reset`.
#[derive(Clone, Debug)]
pub struct ExpIfNeuron {
    /// Membrane potential.
    pub v: f64,
    /// Potential the linear leak pulls toward.
    pub v_rest: f64,
    /// Potential after a spike.
    pub v_reset: f64,
    /// Spike detection level for `v`.
    pub v_threshold: f64,
    /// Offset of the exponential term, `exp((v - v_rh) / delta_t)`.
    pub v_rh: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Integration step.
    pub dt: f64,
}

impl Default for ExpIfNeuron {
    fn default() -> Self {
        Self::new()
    }
}

impl ExpIfNeuron {
    /// Creates a neuron with the built-in parameter set, starting at `v = v_rest`.
    pub fn new() -> Self {
        Self {
            v_rest: -65.0,
            v_reset: -68.0,
            v_threshold: -50.0,
            v_rh: -55.0,
            delta_t: 2.0,
            tau: 20.0,
            dt: 0.1,
            v: -65.0,
        }
    }

    /// Takes one Euler step of length `dt` with input `current`; returns `1` on a spike.
    pub fn step(&mut self, current: f64) -> i32 {
        let v = self.v;
        let exponent = ((v - self.v_rh) / self.delta_t).clamp(-20.0, 20.0);
        let exp_drive = self.delta_t * exponent.exp();
        let dv = (-(v - self.v_rest) + exp_drive + current) / self.tau * self.dt;
        self.v = v + dv;

        let fired = self.v >= self.v_threshold;
        if fired {
            self.v = self.v_reset;
        }
        i32::from(fired)
    }

    /// Returns the membrane to `v_rest`.
    pub fn reset(&mut self) {
        self.v = self.v_rest;
    }
}
457
/// Lapicque (classic leaky integrate-and-fire) neuron, integrated with forward Euler.
///
/// ```text
/// tau * dv/dt = -(v - v_rest) + resistance * I
/// ```
///
/// A spike is registered when `v >= v_threshold` after an update, and `v` is then set to
/// `v_reset`. Rest and reset potentials are both `0`.
#[derive(Clone, Debug)]
pub struct LapicqueNeuron {
    /// Membrane potential.
    pub v: f64,
    /// Potential the leak pulls toward.
    pub v_rest: f64,
    /// Potential after a spike.
    pub v_reset: f64,
    /// Spike threshold for `v`.
    pub v_threshold: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Scale factor applied to the input current.
    pub resistance: f64,
    /// Integration step.
    pub dt: f64,
}

impl LapicqueNeuron {
    /// Creates a neuron at `v = 0` with the given time constant, resistance, threshold and step.
    pub fn new(tau: f64, resistance: f64, threshold: f64, dt: f64) -> Self {
        Self {
            tau,
            resistance,
            dt,
            v_threshold: threshold,
            v: 0.0,
            v_rest: 0.0,
            v_reset: 0.0,
        }
    }

    /// Takes one Euler step of length `dt` with input `current`; returns `1` on a spike.
    pub fn step(&mut self, current: f64) -> i32 {
        let v = self.v;
        let drive = self.resistance * current;
        let dv = (-(v - self.v_rest) + drive) / self.tau * self.dt;
        self.v = v + dv;

        let fired = self.v >= self.v_threshold;
        if fired {
            self.v = self.v_reset;
        }
        i32::from(fired)
    }

    /// Returns the membrane to `v_rest`.
    pub fn reset(&mut self) {
        self.v = self.v_rest;
    }
}
499
#[cfg(test)]
mod tests {
    use super::{
        mask, AdExNeuron, BitstreamAverager, DendriticNeuron, ExpIfNeuron, FixedPointLif,
        HomeostaticLif, Izhikevich, LapicqueNeuron,
    };

    // ------------------------------------------------------------------------------ mask

    /// `mask` must agree with the explicit "mask, then subtract 2^width if the sign bit is
    /// set" formulation for 16- and 32-bit widths.
    #[test]
    fn mask_branchless_matches_original() {
        for &width in &[16_u32, 32] {
            for value in [
                -32768_i32,
                -1,
                0,
                1,
                32767,
                65535,
                -65536,
                i16::MAX as i32,
                i16::MIN as i32,
            ] {
                let result = mask(value, width);

                let m = (1_i64 << width) - 1;
                let mut v = (value as i64) & m;
                if v >= (1_i64 << (width - 1)) {
                    v -= 1_i64 << width;
                }
                let expected = if width >= 32 {
                    v as i32 as i16
                } else {
                    v as i16
                };

                assert_eq!(
                    result, expected,
                    "mask({value}, {width}): got {result}, expected {expected}"
                );
            }
        }
    }

    // --------------------------------------------------------------------- FixedPointLif

    #[test]
    fn lif_fires_with_refractory_period() {
        let mut n = FixedPointLif::new(16, 8, 0, 0, 256, 2);
        let spikes: Vec<i32> = (0..30).map(|_| n.step(1, 256, 50, 0).0).collect();
        let total: i32 = spikes.iter().sum();
        assert!(total > 0, "neuron must fire with refractory_period=2");
        // Every spike must be followed by `refractory_period` (= 2) silent steps. Steps that
        // would lie past the end of the recording cannot be checked; all others are.
        for (i, _) in spikes.iter().enumerate().filter(|&(_, &s)| s == 1) {
            for (j, &s) in spikes.iter().enumerate().skip(i + 1).take(2) {
                assert_eq!(s, 0, "step {j} should be refractory");
            }
        }
    }

    #[test]
    fn lif_fires_without_refractory() {
        let mut n = FixedPointLif::new(16, 8, 0, 0, 256, 0);
        let total: i32 = (0..20).map(|_| n.step(1, 256, 50, 0).0).sum();
        assert!(total > 0, "neuron must fire with refractory_period=0");
    }

    // ------------------------------------------------------------------------ Izhikevich

    #[test]
    fn izhikevich_regular_spiking_fires() {
        let mut n = Izhikevich::regular_spiking();
        let total: i32 = (0..100).map(|_| n.step(10.0)).sum();
        assert!(total > 0, "RS neuron must fire with I=10");
    }

    #[test]
    fn izhikevich_no_spike_without_input() {
        let mut n = Izhikevich::regular_spiking();
        let total: i32 = (0..100).map(|_| n.step(0.0)).sum();
        assert_eq!(total, 0, "no spikes without input");
    }

    #[test]
    fn izhikevich_reset_clears_state() {
        let mut n = Izhikevich::regular_spiking();
        for _ in 0..50 {
            n.step(10.0);
        }
        n.reset();
        assert_eq!(n.v, n.c);
        assert!((n.u - n.b * n.c).abs() < 1e-12);
    }

    #[test]
    fn izhikevich_chattering_fires_more() {
        // Higher reset potential (c = -50) and smaller reset kick (d = 2) than RS.
        let mut ch = Izhikevich::new(0.02, 0.2, -50.0, 2.0, 1.0);
        let mut rs = Izhikevich::regular_spiking();
        let mut ch_spikes = 0;
        let mut rs_spikes = 0;
        for _ in 0..200 {
            ch_spikes += ch.step(10.0);
            rs_spikes += rs.step(10.0);
        }
        assert!(
            ch_spikes > rs_spikes,
            "chattering ({ch_spikes}) should fire more than RS ({rs_spikes})"
        );
    }

    // ----------------------------------------------------------------- BitstreamAverager

    #[test]
    fn averager_all_ones() {
        let mut avg = BitstreamAverager::new(100);
        for _ in 0..100 {
            avg.push(1);
        }
        assert!((avg.estimate() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn averager_all_zeros() {
        let mut avg = BitstreamAverager::new(50);
        for _ in 0..50 {
            avg.push(0);
        }
        assert!(avg.estimate().abs() < 1e-12);
    }

    #[test]
    fn averager_half() {
        let mut avg = BitstreamAverager::new(100);
        for i in 0..100 {
            avg.push((i % 2) as u8);
        }
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn averager_sliding_window() {
        let mut avg = BitstreamAverager::new(4);
        // Window: [1, 1, 0, 0] -> 0.5
        for &b in &[1_u8, 1, 0, 0] {
            avg.push(b);
        }
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
        // Evicts a 1, adds a 1: [1, 0, 0, 1] -> 0.5
        avg.push(1);
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
        // Evicts a 1, adds a 1: [0, 0, 1, 1] -> 0.5
        avg.push(1);
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
        // Evicts a 0, adds a 1: [0, 1, 1, 1] -> 0.75
        avg.push(1);
        assert!((avg.estimate() - 0.75).abs() < 1e-12);
    }

    #[test]
    fn averager_partial_fill() {
        let mut avg = BitstreamAverager::new(100);
        avg.push(1);
        avg.push(0);
        assert!((avg.estimate() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn averager_empty_returns_zero() {
        let avg = BitstreamAverager::new(10);
        assert!(avg.estimate().abs() < 1e-12);
    }

    #[test]
    fn averager_reset() {
        let mut avg = BitstreamAverager::new(10);
        for _ in 0..10 {
            avg.push(1);
        }
        avg.reset();
        assert!(avg.estimate().abs() < 1e-12);
    }

    // -------------------------------------------------------------------- HomeostaticLif

    #[test]
    fn homeostatic_fires_with_strong_input() {
        let mut n = HomeostaticLif::with_defaults();
        let total: i32 = (0..200).map(|_| n.step(25.0)).sum();
        assert!(total > 0, "must fire with strong input");
    }

    #[test]
    fn homeostatic_threshold_adapts() {
        let mut n = HomeostaticLif::with_defaults();
        let initial = n.v_threshold;
        for _ in 0..500 {
            n.step(25.0);
        }
        assert!(
            (n.v_threshold - initial).abs() > 1e-6,
            "threshold must adapt"
        );
    }

    #[test]
    fn homeostatic_no_fire_without_input() {
        let mut n = HomeostaticLif::with_defaults();
        let total: i32 = (0..100).map(|_| n.step(0.0)).sum();
        assert_eq!(total, 0);
    }

    #[test]
    fn homeostatic_threshold_bounded() {
        let mut n = HomeostaticLif::with_defaults();
        for _ in 0..10000 {
            n.step(50.0);
        }
        assert!(n.v_threshold >= 0.1);
        assert!(n.v_threshold <= 10.0);
    }

    // ------------------------------------------------------------------- DendriticNeuron

    #[test]
    fn dendritic_xor_truth_table() {
        let mut n = DendriticNeuron::new(0.5);
        assert_eq!(n.step(0.0, 0.0), 0);
        assert_eq!(n.step(1.0, 0.0), 1);
        assert_eq!(n.step(0.0, 1.0), 1);
        assert_eq!(n.step(1.0, 1.0), 0);
    }

    #[test]
    fn dendritic_subthreshold() {
        let mut n = DendriticNeuron::new(0.5);
        assert_eq!(n.step(0.2, 0.1), 0);
    }

    #[test]
    fn dendritic_reset() {
        let mut n = DendriticNeuron::with_defaults();
        n.step(1.0, 0.0);
        n.reset();
        assert!((n.last_current).abs() < 1e-12);
    }

    // ------------------------------------------------------------------------ AdExNeuron

    #[test]
    fn adex_fires_with_input() {
        let mut n = AdExNeuron::new();
        let total: i32 = (0..2000).map(|_| n.step(500.0)).sum();
        assert!(total > 0, "AdEx must fire with strong input");
    }

    #[test]
    fn adex_adaptation_reduces_rate() {
        let mut n = AdExNeuron::new();
        // Two consecutive 1000-step windows under the same constant drive.
        let first_window: i32 = (0..1000).map(|_| n.step(400.0)).sum();
        let second_window: i32 = (0..1000).map(|_| n.step(400.0)).sum();
        // Small slack so that boundary effects of a finite window do not make the test
        // flaky; adaptation should never make the later window fire noticeably more.
        assert!(
            second_window <= first_window + 5,
            "adaptation should not increase rate: first={first_window}, next={second_window}"
        );
    }

    // ----------------------------------------------------------------------- ExpIfNeuron

    #[test]
    fn expif_fires() {
        let mut n = ExpIfNeuron::new();
        let total: i32 = (0..2000).map(|_| n.step(500.0)).sum();
        assert!(total > 0, "ExpIF must fire");
    }

    #[test]
    fn expif_no_fire_without_input() {
        let mut n = ExpIfNeuron::new();
        let total: i32 = (0..500).map(|_| n.step(0.0)).sum();
        assert_eq!(total, 0);
    }

    // -------------------------------------------------------------------- LapicqueNeuron

    #[test]
    fn lapicque_fires() {
        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
        let total: i32 = (0..200).map(|_| n.step(5.0)).sum();
        assert!(total > 0, "Lapicque must fire with sustained input");
    }

    #[test]
    fn lapicque_reset() {
        let mut n = LapicqueNeuron::new(20.0, 1.0, 1.0, 1.0);
        for _ in 0..50 {
            n.step(5.0);
        }
        n.reset();
        assert!((n.v).abs() < 1e-12);
    }
}