// sc_neurocore_engine/grad/surrogate.rs

use rayon::prelude::*;
14
/// Families of surrogate (pseudo-)derivatives used to back-propagate through
/// the non-differentiable spike function. `grad` evaluates the chosen shape
/// at `x`, the normalized membrane distance to threshold.
#[derive(Clone, Debug)]
pub enum SurrogateType {
    /// 1 / (2k · (1 + k|x|)²)
    FastSigmoid { k: f32 },
    /// 1 / (1 + k|x|)²
    SuperSpike { k: f32 },
    /// 1 / (1 + (kx)²)
    ArcTan { k: f32 },
    /// 1 inside the window |x| < 0.5, else 0.
    StraightThrough,
    /// Triangle of half-base `width`, scaled by 1/width (peak 1/width at x = 0).
    Triangular { width: f32 },
    /// Triangle of half-base `width`, unscaled (peak 1 at x = 0).
    PiecewiseLinear { width: f32 },
}

impl SurrogateType {
    /// Evaluates the surrogate derivative at `x`.
    pub fn grad(&self, x: f32) -> f32 {
        match *self {
            Self::FastSigmoid { k } => {
                let d = 1.0 + k * x.abs();
                (2.0 * k * d * d).recip()
            }
            Self::SuperSpike { k } => {
                let d = 1.0 + k * x.abs();
                (d * d).recip()
            }
            Self::ArcTan { k } => {
                let u = k * x;
                (1.0 + u * u).recip()
            }
            // Unit gradient inside a fixed window around the threshold.
            Self::StraightThrough => {
                let inside = x.abs() < 0.5;
                if inside {
                    1.0
                } else {
                    0.0
                }
            }
            Self::Triangular { width } => {
                let slope = 1.0 - x.abs() / width;
                if slope > 0.0 {
                    slope / width
                } else {
                    0.0
                }
            }
            Self::PiecewiseLinear { width } => (1.0 - x.abs() / width).max(0.0),
        }
    }
}
69
/// A fixed-point LIF neuron paired with a surrogate derivative, recording a
/// per-step membrane trace so `backward()` can be called in reverse-time
/// (LIFO) order.
pub struct SurrogateLif {
    /// Underlying fixed-point LIF model (project type).
    pub lif: crate::neuron::FixedPointLif,
    /// Surrogate derivative applied during the backward pass.
    pub surrogate: SurrogateType,
    // One (normalized membrane distance-to-threshold, spike-as-f32) pair
    // pushed per forward() call; popped by backward().
    membrane_trace: Vec<(f32, f32)>,
}
76
77impl SurrogateLif {
78 pub fn new(
80 data_width: u32,
81 fraction: u32,
82 v_rest: i16,
83 v_reset: i16,
84 v_threshold: i16,
85 refractory_period: i32,
86 surrogate: SurrogateType,
87 ) -> Self {
88 Self {
89 lif: crate::neuron::FixedPointLif::new(
90 data_width,
91 fraction,
92 v_rest,
93 v_reset,
94 v_threshold,
95 refractory_period,
96 ),
97 surrogate,
98 membrane_trace: Vec::new(),
99 }
100 }
101
102 pub fn forward(&mut self, leak_k: i16, gain_k: i16, i_t: i16, noise_in: i16) -> (i32, i16) {
104 let (spike, v_out) = self.lif.step(leak_k, gain_k, i_t, noise_in);
105 let scale = (1_u32 << self.lif.fraction) as f32;
106 let v_norm = (v_out as f32 - self.lif.v_threshold as f32) / scale;
107 self.membrane_trace.push((v_norm, spike as f32));
108 (spike, v_out)
109 }
110
111 pub fn backward(&mut self, grad_output: f32) -> Result<f32, String> {
113 let (v_norm, _spike) = self
114 .membrane_trace
115 .pop()
116 .ok_or_else(|| "backward() called without matching forward()".to_string())?;
117 Ok(grad_output * self.surrogate.grad(v_norm))
118 }
119
120 pub fn clear_trace(&mut self) {
122 self.membrane_trace.clear();
123 }
124
125 pub fn reset(&mut self) {
127 self.lif.reset();
128 self.clear_trace();
129 }
130
131 pub fn trace_len(&self) -> usize {
133 self.membrane_trace.len()
134 }
135}
136
/// A `DenseLayer` wrapper that caches the last forward input/output so a
/// surrogate-gradient backward pass can be run afterwards.
pub struct DifferentiableDenseLayer {
    /// Underlying dense spiking layer (project type).
    pub layer: crate::layer::DenseLayer,
    /// Surrogate derivative used in place of the spike non-linearity.
    pub surrogate: SurrogateType,
    // Input of the most recent successful forward() call; read by backward().
    input_cache: Vec<f64>,
    // Output of the most recent successful forward() call; read by backward().
    output_cache: Vec<f64>,
}
144
145impl DifferentiableDenseLayer {
146 pub fn new(
148 n_inputs: usize,
149 n_neurons: usize,
150 length: usize,
151 seed: u64,
152 surrogate: SurrogateType,
153 ) -> Self {
154 Self {
155 layer: crate::layer::DenseLayer::new(n_inputs, n_neurons, length, seed),
156 surrogate,
157 input_cache: Vec::new(),
158 output_cache: Vec::new(),
159 }
160 }
161
162 pub fn forward(&mut self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
164 let out = self.layer.forward(input_values, seed)?;
165 self.input_cache = input_values.to_vec();
166 self.output_cache = out.clone();
167 Ok(out)
168 }
169
170 pub fn backward(&self, grad_output: &[f64]) -> Result<(Vec<f64>, Vec<Vec<f64>>), String> {
172 if self.input_cache.len() != self.layer.n_inputs {
173 return Err("backward() called before a valid forward() input cache.".to_string());
174 }
175 if self.output_cache.len() != self.layer.n_neurons {
176 return Err("backward() called before a valid forward() output cache.".to_string());
177 }
178 if grad_output.len() != self.layer.n_neurons {
179 return Err(format!(
180 "Expected grad_output length {}, got {}.",
181 self.layer.n_neurons,
182 grad_output.len()
183 ));
184 }
185
186 let mut grad_input = vec![0.0_f64; self.layer.n_inputs];
187 let mut grad_weights = vec![vec![0.0_f64; self.layer.n_inputs]; self.layer.n_neurons];
188
189 let local_grads: Vec<f64> = (0..self.layer.n_neurons)
191 .map(|j| {
192 let surr = self.surrogate.grad((self.output_cache[j] - 0.5) as f32) as f64;
193 grad_output[j] * surr
194 })
195 .collect();
196
197 grad_weights
199 .par_iter_mut()
200 .enumerate()
201 .for_each(|(j, row_grad_weights)| {
202 let local_grad = local_grads[j];
203 let mut chunks_gw = row_grad_weights.chunks_exact_mut(4);
204 let mut chunks_inp = self.input_cache.chunks_exact(4);
205 for (cgw, cinp) in chunks_gw.by_ref().zip(chunks_inp.by_ref()) {
206 cgw[0] = local_grad * cinp[0];
207 cgw[1] = local_grad * cinp[1];
208 cgw[2] = local_grad * cinp[2];
209 cgw[3] = local_grad * cinp[3];
210 }
211 for (gw, &inp) in chunks_gw
212 .into_remainder()
213 .iter_mut()
214 .zip(chunks_inp.remainder())
215 {
216 *gw = local_grad * inp;
217 }
218 });
219
220 for (j, &local_grad) in local_grads.iter().enumerate() {
222 let row_weights = &self.layer.weights[j];
223 let mut chunks_gi = grad_input.chunks_exact_mut(4);
224 let mut chunks_w = row_weights.chunks_exact(4);
225 for (cgi, cw) in chunks_gi.by_ref().zip(chunks_w.by_ref()) {
226 cgi[0] += local_grad * cw[0];
227 cgi[1] += local_grad * cw[1];
228 cgi[2] += local_grad * cw[2];
229 cgi[3] += local_grad * cw[3];
230 }
231 for (gi, &w) in chunks_gi
232 .into_remainder()
233 .iter_mut()
234 .zip(chunks_w.remainder())
235 {
236 *gi += local_grad * w;
237 }
238 }
239
240 Ok((grad_input, grad_weights))
241 }
242
243 pub fn update_weights(&mut self, weight_grads: &[Vec<f64>], lr: f64) {
245 if weight_grads.len() != self.layer.n_neurons {
246 return;
247 }
248 if weight_grads
249 .iter()
250 .any(|row| row.len() != self.layer.n_inputs)
251 {
252 return;
253 }
254
255 for (w_row, g_row) in self.layer.weights.iter_mut().zip(weight_grads.iter()) {
256 for (w, g) in w_row.iter_mut().zip(g_row.iter()) {
257 *w = (*w - lr * *g).clamp(0.0, 1.0);
258 }
259 }
260
261 self.layer.refresh_packed_weights();
262 }
263
264 pub fn clear_cache(&mut self) {
266 self.input_cache.clear();
267 self.output_cache.clear();
268 }
269}