1use crate::neuron::mask;
15
16#[derive(Clone, Copy, Debug)]
18pub struct StdpParams {
19 pub a_plus: i16,
20 pub a_minus: i16,
21 pub decay: i16,
22 pub w_min: i16,
23 pub w_max: i16,
24}
25
26#[derive(Clone, Debug)]
28pub struct StdpSynapse {
29 pub weight: i16,
31 pub trace_pre: i16,
33 pub trace_post: i16,
35 pub data_width: u32,
37 pub fraction: u32,
39}
40
41impl StdpSynapse {
42 pub fn new(initial_weight: i16, data_width: u32, fraction: u32) -> Self {
43 Self {
44 weight: initial_weight,
45 trace_pre: 0,
46 trace_post: 0,
47 data_width,
48 fraction,
49 }
50 }
51
52 pub fn step(&mut self, pre_spike: bool, post_spike: bool, params: &StdpParams) {
59 self.trace_pre = mask(
61 (self.trace_pre as i32 * params.decay as i32) >> self.fraction,
62 self.data_width,
63 );
64 self.trace_post = mask(
65 (self.trace_post as i32 * params.decay as i32) >> self.fraction,
66 self.data_width,
67 );
68
69 if pre_spike {
71 self.trace_pre = mask(
72 self.trace_pre as i32 + params.a_plus as i32,
73 self.data_width,
74 );
75 }
76 if post_spike {
77 self.trace_post = mask(
78 self.trace_post as i32 + params.a_minus as i32,
79 self.data_width,
80 );
81 }
82
83 if post_spike {
86 let dw = (self.trace_pre as i32 * params.a_plus.abs() as i32) >> self.fraction;
87 let new_w = (self.weight as i32 + dw).min(params.w_max as i32);
88 self.weight = mask(new_w, self.data_width);
89 } else if pre_spike {
90 let dw = (self.trace_post as i32 * params.a_minus.abs() as i32) >> self.fraction;
91 let new_w = (self.weight as i32 - dw).max(params.w_min as i32);
92 self.weight = mask(new_w, self.data_width);
93 }
94 }
95}
96
/// Reward-modulated STDP synapse driven by a single decaying eligibility trace.
///
/// [`step`](Self::step) only updates the trace. The weight moves exclusively
/// through [`apply_reward`](Self::apply_reward).
#[derive(Clone, Debug)]
pub struct RewardStdpSynapse {
    /// Current weight; `apply_reward` keeps it inside `[w_min, w_max]`.
    pub weight: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
    /// Eligibility trace accumulated from spike (co)activity.
    pub eligibility: f64,
    /// Multiplicative factor applied to the trace at the end of every step.
    pub trace_decay: f64,
    /// Amount subtracted from the trace when pre fires without post.
    pub anti_hebbian_scale: f64,
    /// Scale applied to `reward * eligibility` when forming a weight update.
    pub learning_rate: f64,
}

impl RewardStdpSynapse {
    /// Builds a synapse with weight `w`, bounds `[w_min, w_max]`, an empty
    /// trace, decay 0.95, anti-Hebbian scale 0.5 and learning rate 0.01.
    pub fn new(w: f64, w_min: f64, w_max: f64) -> Self {
        let (trace_decay, anti_hebbian_scale, learning_rate) = (0.95, 0.5, 0.01);
        Self {
            weight: w,
            w_min,
            w_max,
            eligibility: 0.0,
            trace_decay,
            anti_hebbian_scale,
            learning_rate,
        }
    }

    /// Records one time step of activity in the eligibility trace.
    ///
    /// Pre+post adds 1.0 and pre-only subtracts `anti_hebbian_scale`.
    /// Post-only or silence adds nothing. The trace is then multiplied by
    /// `trace_decay`, so this step's contribution is already decayed once.
    pub fn step(&mut self, pre: bool, post: bool) {
        match (pre, post) {
            (true, true) => self.eligibility += 1.0,
            (true, false) => self.eligibility -= self.anti_hebbian_scale,
            (false, _) => {}
        }
        self.eligibility *= self.trace_decay;
    }

    /// Applies `learning_rate * reward * eligibility` to the weight.
    ///
    /// The result is clamped to `[w_min, w_max]` via `f64::clamp`, which
    /// panics if `w_min > w_max` or either bound is NaN. The trace is left
    /// untouched.
    pub fn apply_reward(&mut self, reward: f64) {
        let delta = self.learning_rate * reward * self.eligibility;
        let (lo, hi) = (self.w_min, self.w_max);
        self.weight = (self.weight + delta).clamp(lo, hi);
    }
}
142
/// Non-plastic synapse: a fixed magnitude whose sign comes from `is_excitatory`.
#[derive(Clone, Debug)]
pub struct StaticSynapse {
    /// Magnitude of the transmitted value; `new` stores it non-negative.
    pub weight: f64,
    /// `true` transmits `+weight`, `false` transmits `-weight`.
    pub is_excitatory: bool,
    /// Transmission delay; set to 0 by `new` and not read by `transmit`.
    pub delay: u32,
}

impl StaticSynapse {
    /// Creates a zero-delay synapse.
    ///
    /// The sign of `weight` is discarded; only its absolute value is kept,
    /// and polarity is taken from `is_excitatory`.
    pub fn new(weight: f64, is_excitatory: bool) -> Self {
        let magnitude = weight.abs();
        Self {
            weight: magnitude,
            is_excitatory,
            delay: 0,
        }
    }

    /// Returns the value delivered this step.
    ///
    /// Without a presynaptic spike this is `0.0`. Otherwise it is `+weight`
    /// for excitatory synapses and `-weight` for inhibitory ones.
    pub fn transmit(&self, pre_spike: bool) -> f64 {
        match (pre_spike, self.is_excitatory) {
            (false, _) => 0.0,
            (true, true) => self.weight,
            (true, false) => -self.weight,
        }
    }
}
172
/// Triplet STDP synapse with fast and slow pre/post traces.
///
/// Field naming presumably follows the triplet-rule notation (r = presynaptic,
/// o = postsynaptic; 1 = fast, 2 = slow). A pre spike depresses by
/// `o1 * (a2_minus + a3_minus * r2)` and a post spike potentiates by
/// `r1 * (a2_plus + a3_plus * o2)`. The slow trace in each term is read
/// before the current spike increments it.
#[derive(Clone, Debug)]
pub struct TripletStdpSynapse {
    /// Current weight, kept inside `[w_min, w_max]`.
    pub weight: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
    /// Fast presynaptic trace (time constant `tau_plus`).
    pub r1: f64,
    /// Fast postsynaptic trace (time constant `tau_minus`).
    pub o1: f64,
    /// Slow presynaptic trace (time constant `tau_x`).
    pub r2: f64,
    /// Slow postsynaptic trace (time constant `tau_y`).
    pub o2: f64,
    /// Decay time constant of `r1`.
    pub tau_plus: f64,
    /// Decay time constant of `o1`.
    pub tau_minus: f64,
    /// Decay time constant of `r2`.
    pub tau_x: f64,
    /// Decay time constant of `o2`.
    pub tau_y: f64,
    /// Pair-term potentiation amplitude.
    pub a2_plus: f64,
    /// Pair-term depression amplitude.
    pub a2_minus: f64,
    /// Triplet-term potentiation amplitude (multiplies `o2`).
    pub a3_plus: f64,
    /// Triplet-term depression amplitude (multiplies `r2`).
    pub a3_minus: f64,
    /// Step size; each trace is multiplied by `exp(-dt / tau)` per step.
    pub dt: f64,
}

impl TripletStdpSynapse {
    /// Creates a synapse with zeroed traces and the module's default
    /// time constants and amplitudes.
    pub fn new(weight: f64, w_min: f64, w_max: f64) -> Self {
        Self {
            weight,
            w_min,
            w_max,
            r1: 0.0,
            o1: 0.0,
            r2: 0.0,
            o2: 0.0,
            tau_plus: 16.8,
            tau_minus: 33.7,
            tau_x: 101.0,
            tau_y: 125.0,
            a2_plus: 0.005,
            a2_minus: 0.007,
            a3_plus: 0.006,
            a3_minus: 0.002,
            dt: 1.0,
        }
    }

    /// Advances one step: decay all traces, then handle the pre spike, then
    /// the post spike.
    ///
    /// Because pre is handled first, a simultaneous pre+post step potentiates
    /// using an `r1` that already includes this step's pre spike.
    pub fn step(&mut self, pre_spike: bool, post_spike: bool) {
        self.decay_traces();
        if pre_spike {
            self.on_pre_spike();
        }
        if post_spike {
            self.on_post_spike();
        }
    }

    /// Multiplies every trace by `exp(-dt / tau)` for its own time constant.
    fn decay_traces(&mut self) {
        let dt = self.dt;
        let factor = |tau: f64| (-dt / tau).exp();
        self.r1 *= factor(self.tau_plus);
        self.o1 *= factor(self.tau_minus);
        self.r2 *= factor(self.tau_x);
        self.o2 *= factor(self.tau_y);
    }

    /// Depression on a presynaptic spike, then increment of the pre traces.
    fn on_pre_spike(&mut self) {
        let dw_minus = -(self.a2_minus + self.a3_minus * self.r2) * self.o1;
        self.weight = self.bounded(self.weight + dw_minus);
        self.r1 += 1.0;
        self.r2 += 1.0;
    }

    /// Potentiation on a postsynaptic spike, then increment of the post traces.
    fn on_post_spike(&mut self) {
        let dw_plus = (self.a2_plus + self.a3_plus * self.o2) * self.r1;
        self.weight = self.bounded(self.weight + dw_plus);
        self.o1 += 1.0;
        self.o2 += 1.0;
    }

    /// Clamps a candidate weight into `[w_min, w_max]`.
    fn bounded(&self, candidate: f64) -> f64 {
        candidate.clamp(self.w_min, self.w_max)
    }
}
254
/// Short-term plasticity synapse with resource (`x`) and utilisation (`u`)
/// variables, in the style of the Tsodyks–Markram model.
///
/// NOTE(review): `u` relaxes toward `u_base` *and* is incremented by
/// `u_base * (1 - u)` on each spike. An isolated spike therefore releases with
/// `u = u_base * (2 - u_base)` rather than `u_base`. Confirm this is intended.
#[derive(Clone, Debug)]
pub struct ShortTermPlasticitySynapse {
    /// Available resources; 1.0 when fully recovered.
    pub x: f64,
    /// Current utilisation used to scale release.
    pub u: f64,
    /// Resting utilisation and per-spike increment scale.
    pub u_base: f64,
    /// Recovery time constant of `x`.
    pub tau_d: f64,
    /// Relaxation time constant of `u`.
    pub tau_f: f64,
    /// Gain applied to `u * x` when forming the returned current.
    pub amplitude: f64,
    /// Forward-Euler step size.
    pub dt: f64,
}

impl ShortTermPlasticitySynapse {
    /// Shared constructor: fully recovered resources, `u` at rest, unit gain
    /// and unit step.
    fn with_dynamics(u_base: f64, tau_d: f64, tau_f: f64) -> Self {
        Self {
            x: 1.0,
            u: u_base,
            u_base,
            tau_d,
            tau_f,
            amplitude: 1.0,
            dt: 1.0,
        }
    }

    /// Depression-dominated preset: high utilisation (0.5), slow recovery
    /// (`tau_d` = 200), fast `u` relaxation (`tau_f` = 20).
    pub fn new_depressing() -> Self {
        Self::with_dynamics(0.5, 200.0, 20.0)
    }

    /// Facilitation-dominated preset: low utilisation (0.1), fast recovery
    /// (`tau_d` = 50), slow `u` relaxation (`tau_f` = 500).
    pub fn new_facilitating() -> Self {
        Self::with_dynamics(0.1, 50.0, 500.0)
    }

    /// Advances one step and returns the postsynaptic current.
    ///
    /// `x` and `u` first relax (forward Euler) toward 1 and `u_base`. Without a
    /// spike the result is `0.0`. With a spike, `u` is incremented, the current
    /// `amplitude * u * x` is returned, and `u * x` is consumed from `x`
    /// (floored at zero).
    pub fn step(&mut self, pre_spike: bool) -> f64 {
        self.x += (1.0 - self.x) / self.tau_d * self.dt;
        self.u += (self.u_base - self.u) / self.tau_f * self.dt;

        if !pre_spike {
            return 0.0;
        }

        self.u += self.u_base * (1.0 - self.u);
        let released = self.amplitude * self.u * self.x;
        self.x = (self.x - self.u * self.x).max(0.0);
        released
    }

    /// Restores fully recovered resources and resting utilisation.
    pub fn reset(&mut self) {
        self.u = self.u_base;
        self.x = 1.0;
    }
}
329
/// STDP synapse whose weight change is gated by a dopamine signal.
///
/// Spike pairings build an eligibility trace. The weight moves every step by
/// `lr * dopamine * eligibility * dt`, so no reward means no weight change.
#[derive(Clone, Debug)]
pub struct DopamineStdpSynapse {
    /// Current weight, kept inside `[w_min, w_max]`.
    pub weight: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
    /// Eligibility trace, decaying with `tau_e`.
    pub eligibility: f64,
    /// Dopamine level: leaky integrator of `reward` with time constant `tau_da`.
    pub dopamine: f64,
    /// Presynaptic spike trace (time constant `tau_pre`).
    pub trace_pre: f64,
    /// Postsynaptic spike trace (time constant `tau_post`).
    pub trace_post: f64,
    /// Eligibility decay time constant.
    pub tau_e: f64,
    /// Dopamine decay time constant.
    pub tau_da: f64,
    /// Presynaptic trace time constant.
    pub tau_pre: f64,
    /// Postsynaptic trace time constant.
    pub tau_post: f64,
    /// Scale of `trace_pre` added to eligibility on a post spike.
    pub a_plus: f64,
    /// Scale of `trace_post` added to eligibility on a pre spike (default is
    /// negative, so pre-after-post lowers eligibility).
    pub a_minus: f64,
    /// Learning rate of the dopamine-gated weight update.
    pub lr: f64,
    /// Integration step size.
    pub dt: f64,
}

impl DopamineStdpSynapse {
    /// Creates a synapse with empty traces, no dopamine and default constants.
    pub fn new(weight: f64, w_min: f64, w_max: f64) -> Self {
        Self {
            weight,
            w_min,
            w_max,
            eligibility: 0.0,
            dopamine: 0.0,
            trace_pre: 0.0,
            trace_post: 0.0,
            tau_e: 1000.0,
            tau_da: 200.0,
            tau_pre: 20.0,
            tau_post: 20.0,
            a_plus: 1.0,
            a_minus: -1.0,
            lr: 0.001,
            dt: 1.0,
        }
    }

    /// Advances one step with the given `reward` input.
    ///
    /// The traces and eligibility decay and dopamine integrates `reward`.
    /// The pre spike is processed before the post spike. The weight is then
    /// updated from the current dopamine and eligibility.
    pub fn step(&mut self, pre_spike: bool, post_spike: bool, reward: f64) {
        self.relax(reward);

        if pre_spike {
            // Pre arriving after earlier post activity.
            self.eligibility += self.a_minus * self.trace_post;
            self.trace_pre += 1.0;
        }
        if post_spike {
            // Post arriving after earlier (or same-step) pre activity.
            self.eligibility += self.a_plus * self.trace_pre;
            self.trace_post += 1.0;
        }

        self.consolidate();
    }

    /// Exponential decay of traces/eligibility plus Euler update of dopamine.
    fn relax(&mut self, reward: f64) {
        let dt = self.dt;
        let shrink = |tau: f64| (-dt / tau).exp();
        self.trace_pre *= shrink(self.tau_pre);
        self.trace_post *= shrink(self.tau_post);
        self.eligibility *= shrink(self.tau_e);
        self.dopamine += (-self.dopamine / self.tau_da + reward) * dt;
    }

    /// Applies the dopamine-gated weight change and clamps to the bounds.
    fn consolidate(&mut self) {
        let dw = self.lr * self.dopamine * self.eligibility * self.dt;
        self.weight = (self.weight + dw).clamp(self.w_min, self.w_max);
    }

    /// Clears eligibility, dopamine and both spike traces; the weight is kept.
    pub fn reset(&mut self) {
        self.trace_pre = 0.0;
        self.trace_post = 0.0;
        self.dopamine = 0.0;
        self.eligibility = 0.0;
    }
}
411
#[cfg(test)]
mod tests {
    //! Unit tests for every synapse model in this module.
    //! (`&params` restored — it had been mis-encoded as `¶ms`.)
    use super::*;

    /// Parameters shared by the fixed-point `StdpSynapse` tests.
    fn default_params() -> StdpParams {
        StdpParams {
            a_plus: 64,
            a_minus: 48,
            decay: 230,
            w_min: 0,
            w_max: 255,
        }
    }

    #[test]
    fn potentiation_increases_weight() {
        let mut syn = StdpSynapse::new(128, 16, 8);
        let params = default_params();
        for _ in 0..5 {
            syn.step(true, false, &params);
        }
        let w_before = syn.weight;
        syn.step(false, true, &params);
        assert!(syn.weight > w_before, "LTP must increase weight");
    }

    #[test]
    fn depression_decreases_weight() {
        let mut syn = StdpSynapse::new(128, 16, 8);
        let params = default_params();
        for _ in 0..5 {
            syn.step(false, true, &params);
        }
        let w_before = syn.weight;
        syn.step(true, false, &params);
        assert!(syn.weight < w_before, "LTD must decrease weight");
    }

    #[test]
    fn rstdp_positive_reward_potentiates() {
        let mut syn = RewardStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..10 {
            syn.step(true, true);
        }
        let w_before = syn.weight;
        syn.apply_reward(1.0);
        assert!(syn.weight > w_before);
    }

    #[test]
    fn rstdp_negative_reward_depresses() {
        let mut syn = RewardStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..10 {
            syn.step(true, true);
        }
        let w_before = syn.weight;
        syn.apply_reward(-1.0);
        assert!(syn.weight < w_before);
    }

    #[test]
    fn rstdp_weight_bounded() {
        let mut syn = RewardStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..100 {
            syn.step(true, true);
            syn.apply_reward(10.0);
        }
        assert!(syn.weight <= 1.0);
        assert!(syn.weight >= 0.0);
    }

    #[test]
    fn static_excitatory() {
        let syn = StaticSynapse::new(0.5, true);
        assert!((syn.transmit(true) - 0.5).abs() < 1e-12);
        assert!((syn.transmit(false)).abs() < 1e-12);
    }

    #[test]
    fn static_inhibitory() {
        let syn = StaticSynapse::new(0.5, false);
        assert!((syn.transmit(true) + 0.5).abs() < 1e-12);
    }

    #[test]
    fn weight_stays_in_bounds() {
        let mut syn = StdpSynapse::new(0, 16, 8);
        let params = default_params();
        for _ in 0..200 {
            syn.step(true, false, &params);
        }
        assert!(syn.weight >= params.w_min, "weight below w_min");
        assert!(syn.weight <= params.w_max, "weight above w_max");

        let mut syn2 = StdpSynapse::new(255, 16, 8);
        for _ in 0..200 {
            syn2.step(false, true, &params);
        }
        assert!(syn2.weight >= params.w_min);
        assert!(syn2.weight <= params.w_max);
    }

    #[test]
    fn triplet_ltp_pre_then_post() {
        let mut syn = TripletStdpSynapse::new(0.5, 0.0, 1.0);
        syn.step(true, false);
        syn.step(false, true);
        assert!(syn.weight > 0.5, "Pre-then-post must potentiate");
    }

    #[test]
    fn triplet_ltd_post_then_pre() {
        let mut syn = TripletStdpSynapse::new(0.5, 0.0, 1.0);
        syn.step(false, true);
        syn.step(true, false);
        assert!(syn.weight < 0.5, "Post-then-pre must depress");
    }

    #[test]
    fn triplet_bounded() {
        let mut syn = TripletStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..1000 {
            syn.step(true, true);
        }
        assert!(syn.weight >= 0.0 && syn.weight <= 1.0);
    }

    #[test]
    fn triplet_slow_trace_enhances() {
        let mut syn1 = TripletStdpSynapse::new(0.5, 0.0, 1.0);
        let mut syn2 = TripletStdpSynapse::new(0.5, 0.0, 1.0);
        // Prime syn2's slow postsynaptic trace, then let the fast traces fade.
        for _ in 0..5 {
            syn2.step(false, true);
        }
        for _ in 0..200 {
            syn2.step(false, false);
        }
        syn1.weight = 0.5;
        syn2.weight = 0.5;
        syn1.step(true, false);
        syn1.step(false, true);
        syn2.step(true, false);
        syn2.step(false, true);
        assert!(
            syn2.weight >= syn1.weight,
            "Triplet o2 trace should enhance LTP: syn2={:.6} >= syn1={:.6}",
            syn2.weight,
            syn1.weight
        );
    }

    #[test]
    fn stp_depressing_decreases_psc() {
        let mut syn = ShortTermPlasticitySynapse::new_depressing();
        let psc1 = syn.step(true);
        let psc2 = syn.step(true);
        assert!(
            psc2 < psc1,
            "Depression: 2nd PSC < 1st: {psc2:.4} < {psc1:.4}"
        );
    }

    #[test]
    fn stp_facilitating_increases_psc() {
        let mut syn = ShortTermPlasticitySynapse::new_facilitating();
        let psc1 = syn.step(true);
        let psc2 = syn.step(true);
        assert!(
            psc2 > psc1,
            "Facilitation: 2nd PSC > 1st: {psc2:.4} > {psc1:.4}"
        );
    }

    #[test]
    fn stp_recovers_after_silence() {
        let mut syn = ShortTermPlasticitySynapse::new_depressing();
        syn.step(true);
        syn.step(true);
        let depleted = syn.step(true);
        for _ in 0..500 {
            syn.step(false);
        }
        let recovered = syn.step(true);
        assert!(
            recovered > depleted,
            "Recovery: {recovered:.4} > {depleted:.4}"
        );
    }

    #[test]
    fn stp_no_spike_no_current() {
        let mut syn = ShortTermPlasticitySynapse::new_depressing();
        assert_eq!(syn.step(false), 0.0);
    }

    #[test]
    fn da_stdp_reward_potentiates() {
        let mut syn = DopamineStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..20 {
            syn.step(true, false, 0.0);
            syn.step(false, true, 0.0);
        }
        let w_before = syn.weight;
        for _ in 0..100 {
            syn.step(false, false, 1.0);
        }
        assert!(
            syn.weight > w_before,
            "Reward should potentiate: {:.4} > {:.4}",
            syn.weight,
            w_before
        );
    }

    #[test]
    fn da_stdp_no_reward_no_change() {
        let mut syn = DopamineStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..100 {
            syn.step(true, false, 0.0);
            syn.step(false, true, 0.0);
        }
        assert!(
            (syn.weight - 0.5).abs() < 0.01,
            "Without reward, weight should stay near initial: {:.4}",
            syn.weight
        );
    }

    #[test]
    fn da_stdp_bounded() {
        let mut syn = DopamineStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..1000 {
            syn.step(true, true, 10.0);
        }
        assert!(syn.weight >= 0.0 && syn.weight <= 1.0);
    }
}