// sc_neurocore_engine/synapses/mod.rs
use crate::neuron::mask;

/// Learning parameters consumed by `StdpSynapse::step`.
///
/// All values are raw fixed-point integers in the same representation as the
/// synapse state; `decay` and the amplitudes are applied as
/// `(x * param) >> fraction`, where `fraction` is the synapse's fractional-bit count.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct StdpParams {
    /// Added to the presynaptic trace on a pre spike; its magnitude also scales LTP.
    pub a_plus: i16,
    /// Added to the postsynaptic trace on a post spike; its magnitude also scales LTD.
    pub a_minus: i16,
    /// Per-step trace multiplier, applied as `(trace * decay) >> fraction`.
    pub decay: i16,
    /// Lower weight bound.
    pub w_min: i16,
    /// Upper weight bound.
    pub w_max: i16,
}

25#[derive(Clone, Debug)]
27pub struct StdpSynapse {
28 pub weight: i16,
30 pub trace_pre: i16,
32 pub trace_post: i16,
34 pub data_width: u32,
36 pub fraction: u32,
38}
39
40impl StdpSynapse {
41 pub fn new(initial_weight: i16, data_width: u32, fraction: u32) -> Self {
42 Self {
43 weight: initial_weight,
44 trace_pre: 0,
45 trace_post: 0,
46 data_width,
47 fraction,
48 }
49 }
50
51 pub fn step(&mut self, pre_spike: bool, post_spike: bool, params: &StdpParams) {
58 self.trace_pre = mask(
60 (self.trace_pre as i32 * params.decay as i32) >> self.fraction,
61 self.data_width,
62 );
63 self.trace_post = mask(
64 (self.trace_post as i32 * params.decay as i32) >> self.fraction,
65 self.data_width,
66 );
67
68 if pre_spike {
70 self.trace_pre = mask(
71 self.trace_pre as i32 + params.a_plus as i32,
72 self.data_width,
73 );
74 }
75 if post_spike {
76 self.trace_post = mask(
77 self.trace_post as i32 + params.a_minus as i32,
78 self.data_width,
79 );
80 }
81
82 if post_spike {
85 let dw = (self.trace_pre as i32 * params.a_plus.abs() as i32) >> self.fraction;
86 let new_w = (self.weight as i32 + dw).min(params.w_max as i32);
87 self.weight = mask(new_w, self.data_width);
88 } else if pre_spike {
89 let dw = (self.trace_post as i32 * params.a_minus.abs() as i32) >> self.fraction;
90 let new_w = (self.weight as i32 - dw).max(params.w_min as i32);
91 self.weight = mask(new_w, self.data_width);
92 }
93 }
94}
95
/// Reward-modulated STDP synapse with floating-point state.
///
/// `step` only updates an eligibility trace; the weight changes when
/// `apply_reward` is called.
#[derive(Clone, Debug)]
pub struct RewardStdpSynapse {
    /// Current synaptic weight.
    pub weight: f64,
    /// Lower bound enforced by `apply_reward`.
    pub w_min: f64,
    /// Upper bound enforced by `apply_reward`.
    pub w_max: f64,
    /// Eligibility trace accumulated by `step`.
    pub eligibility: f64,
    /// Multiplier applied to the eligibility at the end of every `step` (0.95 from `new`).
    pub trace_decay: f64,
    /// Subtracted from the eligibility on a pre spike without a post spike (0.5 from `new`).
    pub anti_hebbian_scale: f64,
    /// Scales `reward * eligibility` in `apply_reward` (0.01 from `new`).
    pub learning_rate: f64,
}

impl RewardStdpSynapse {
    /// Builds a synapse with weight `w`, bounds `[w_min, w_max]`, zero
    /// eligibility, trace decay 0.95, anti-Hebbian scale 0.5 and learning rate 0.01.
    ///
    /// The initial weight is stored as given; it is not clamped here.
    pub fn new(w: f64, w_min: f64, w_max: f64) -> Self {
        let eligibility = 0.0;
        let trace_decay = 0.95;
        let anti_hebbian_scale = 0.5;
        let learning_rate = 0.01;
        Self {
            weight: w,
            w_min,
            w_max,
            eligibility,
            trace_decay,
            anti_hebbian_scale,
            learning_rate,
        }
    }

    /// Updates the eligibility trace for one time step.
    ///
    /// A coincident pre/post spike adds 1.0. A pre spike without a post spike
    /// subtracts `anti_hebbian_scale`. Post-only and silent steps add nothing.
    /// The trace is then multiplied by `trace_decay`.
    pub fn step(&mut self, pre: bool, post: bool) {
        match (pre, post) {
            (true, true) => self.eligibility += 1.0,
            (true, false) => self.eligibility -= self.anti_hebbian_scale,
            (false, _) => {}
        }
        self.eligibility *= self.trace_decay;
    }

    /// Moves the weight by `learning_rate * reward * eligibility` and clamps
    /// the result to `[w_min, w_max]`. The eligibility trace is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `w_min > w_max` or either bound is NaN (behaviour of `f64::clamp`).
    pub fn apply_reward(&mut self, reward: f64) {
        let gain = self.learning_rate * reward;
        let target = self.weight + gain * self.eligibility;
        self.weight = target.clamp(self.w_min, self.w_max);
    }
}

/// Synapse with a fixed weight and no plasticity.
///
/// `weight` holds a magnitude; the sign of the transmitted value comes from
/// `is_excitatory`.
#[derive(Clone, Debug)]
pub struct StaticSynapse {
    /// Weight magnitude; `new` stores the absolute value of its argument.
    pub weight: f64,
    /// `true` makes `transmit` return `+weight`, `false` returns `-weight`.
    pub is_excitatory: bool,
    /// Delay setting (presumably in time steps — unconfirmed); set to 0 by
    /// `new` and not read by `transmit`.
    pub delay: u32,
}

impl StaticSynapse {
    /// Creates a synapse whose stored weight is `weight.abs()`, with zero delay.
    pub fn new(weight: f64, is_excitatory: bool) -> Self {
        let magnitude = weight.abs();
        Self {
            weight: magnitude,
            is_excitatory,
            delay: 0,
        }
    }

    /// Returns the signed weight when `pre_spike` is set, otherwise `0.0`.
    pub fn transmit(&self, pre_spike: bool) -> f64 {
        match (pre_spike, self.is_excitatory) {
            (false, _) => 0.0,
            (true, true) => self.weight,
            (true, false) => -self.weight,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Default fixed-point parameters (8 fractional bits assumed by the tests).
    fn default_params() -> StdpParams {
        StdpParams {
            a_plus: 64,
            a_minus: 48,
            decay: 230,
            w_min: 0,
            w_max: 255,
        }
    }

    #[test]
    fn potentiation_increases_weight() {
        let mut syn = StdpSynapse::new(128, 16, 8);
        let params = default_params();
        // Charge the presynaptic trace with pre-only spikes.
        for _ in 0..5 {
            syn.step(true, false, &params);
        }
        let w_before = syn.weight;
        // A post spike now reads the charged pre trace (pre before post -> LTP).
        syn.step(false, true, &params);
        assert!(syn.weight > w_before, "LTP must increase weight");
    }

    #[test]
    fn depression_decreases_weight() {
        let mut syn = StdpSynapse::new(128, 16, 8);
        let params = default_params();
        // Charge the postsynaptic trace with post-only spikes.
        for _ in 0..5 {
            syn.step(false, true, &params);
        }
        let w_before = syn.weight;
        // A pre spike now reads the charged post trace (post before pre -> LTD).
        syn.step(true, false, &params);
        assert!(syn.weight < w_before, "LTD must decrease weight");
    }

    #[test]
    fn rstdp_positive_reward_potentiates() {
        let mut syn = RewardStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..10 {
            syn.step(true, true);
        }
        let w_before = syn.weight;
        syn.apply_reward(1.0);
        assert!(syn.weight > w_before);
    }

    #[test]
    fn rstdp_negative_reward_depresses() {
        let mut syn = RewardStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..10 {
            syn.step(true, true);
        }
        let w_before = syn.weight;
        syn.apply_reward(-1.0);
        assert!(syn.weight < w_before);
    }

    #[test]
    fn rstdp_weight_bounded() {
        let mut syn = RewardStdpSynapse::new(0.5, 0.0, 1.0);
        for _ in 0..100 {
            syn.step(true, true);
            syn.apply_reward(10.0);
        }
        assert!(syn.weight <= 1.0);
        assert!(syn.weight >= 0.0);
    }

    #[test]
    fn static_excitatory() {
        let syn = StaticSynapse::new(0.5, true);
        assert!((syn.transmit(true) - 0.5).abs() < 1e-12);
        assert!((syn.transmit(false)).abs() < 1e-12);
    }

    #[test]
    fn static_inhibitory() {
        let syn = StaticSynapse::new(0.5, false);
        assert!((syn.transmit(true) + 0.5).abs() < 1e-12);
    }

    #[test]
    fn weight_stays_in_bounds() {
        let params = default_params();

        // Alternating pre->post pairings with a_plus > a_minus potentiate more
        // than they depress, driving the weight up to w_max.
        let mut syn = StdpSynapse::new(128, 16, 8);
        for _ in 0..200 {
            syn.step(true, false, &params);
            syn.step(false, true, &params);
            assert!(syn.weight >= params.w_min, "weight below w_min");
            assert!(syn.weight <= params.w_max, "weight above w_max");
        }
        assert_eq!(syn.weight, params.w_max);

        // Depression-dominated parameters drive the weight down to w_min.
        let ltd_params = StdpParams {
            a_plus: 16,
            a_minus: 96,
            ..params
        };
        let mut syn2 = StdpSynapse::new(128, 16, 8);
        for _ in 0..200 {
            syn2.step(false, true, &ltd_params);
            syn2.step(true, false, &ltd_params);
            assert!(syn2.weight >= ltd_params.w_min, "weight below w_min");
            assert!(syn2.weight <= ltd_params.w_max, "weight above w_max");
        }
        assert_eq!(syn2.weight, ltd_params.w_min);
    }
}