//! Surrogate gradient utilities (sc_neurocore_engine/grad/surrogate.rs).
/// Surrogate gradient family used to approximate the derivative of the
/// non-differentiable spike (Heaviside) function during backprop.
///
/// Evaluated at a normalised membrane distance `x` (0 at threshold; see
/// `SurrogateLif::forward`).
///
/// All payload fields are `f32`, so the enum is cheap to copy; `Copy` and
/// `PartialEq` are derived so configurations can be passed by value and
/// compared in tests.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SurrogateType {
    /// 1 / (2k (1 + k|x|)^2) — fast-sigmoid derivative.
    FastSigmoid { k: f32 },
    /// 1 / (1 + k|x|)^2 — SuperSpike-style surrogate.
    SuperSpike { k: f32 },
    /// 1 / (1 + (kx)^2) — arctan-sigmoid derivative.
    ArcTan { k: f32 },
    /// Boxcar window: gradient 1 for |x| < 0.5, else 0.
    StraightThrough,
    /// Triangular kernel: peaks at 1/width, zero outside |x| >= width.
    Triangular { width: f32 },
    /// Triangular kernel with unit peak (not normalised by width).
    PiecewiseLinear { width: f32 },
}
27
28impl SurrogateType {
29 pub fn grad(&self, x: f32) -> f32 {
31 match self {
32 Self::FastSigmoid { k } => {
33 let denom = 1.0 + k * x.abs();
35 1.0 / (2.0 * k * denom * denom)
36 }
37 Self::SuperSpike { k } => {
38 let denom = 1.0 + k * x.abs();
40 1.0 / (denom * denom)
41 }
42 Self::ArcTan { k } => 1.0 / (1.0 + (k * x).powi(2)),
43 Self::StraightThrough => {
44 if x.abs() < 0.5 {
45 1.0
46 } else {
47 0.0
48 }
49 }
50 Self::Triangular { width } => {
51 let v = 1.0 - x.abs() / width;
52 if v > 0.0 {
53 v / width
54 } else {
55 0.0
56 }
57 }
58 Self::PiecewiseLinear { width } => {
59 let v = 1.0 - x.abs() / width;
60 v.max(0.0)
61 }
62 }
63 }
64}
65
/// A fixed-point LIF neuron wrapped with a surrogate gradient so its
/// spike output can be differentiated through.
pub struct SurrogateLif {
    // Underlying fixed-point LIF neuron (project type).
    pub lif: crate::neuron::FixedPointLif,
    // Surrogate used by backward() to approximate d(spike)/d(v).
    pub surrogate: SurrogateType,
    // Per-step history of (normalised membrane distance, spike as 0/1);
    // pushed by forward(), popped LIFO by backward().
    membrane_trace: Vec<(f32, f32)>,
}
72
73impl SurrogateLif {
74 pub fn new(
76 data_width: u32,
77 fraction: u32,
78 v_rest: i16,
79 v_reset: i16,
80 v_threshold: i16,
81 refractory_period: i32,
82 surrogate: SurrogateType,
83 ) -> Self {
84 Self {
85 lif: crate::neuron::FixedPointLif::new(
86 data_width,
87 fraction,
88 v_rest,
89 v_reset,
90 v_threshold,
91 refractory_period,
92 ),
93 surrogate,
94 membrane_trace: Vec::new(),
95 }
96 }
97
98 pub fn forward(&mut self, leak_k: i16, gain_k: i16, i_t: i16, noise_in: i16) -> (i32, i16) {
100 let (spike, v_out) = self.lif.step(leak_k, gain_k, i_t, noise_in);
101 let scale = (1_u32 << self.lif.fraction) as f32;
102 let v_norm = (v_out as f32 - self.lif.v_threshold as f32) / scale;
103 self.membrane_trace.push((v_norm, spike as f32));
104 (spike, v_out)
105 }
106
107 pub fn backward(&mut self, grad_output: f32) -> Result<f32, String> {
109 let (v_norm, _spike) = self
110 .membrane_trace
111 .pop()
112 .ok_or_else(|| "backward() called without matching forward()".to_string())?;
113 Ok(grad_output * self.surrogate.grad(v_norm))
114 }
115
116 pub fn clear_trace(&mut self) {
118 self.membrane_trace.clear();
119 }
120
121 pub fn reset(&mut self) {
123 self.lif.reset();
124 self.clear_trace();
125 }
126
127 pub fn trace_len(&self) -> usize {
129 self.membrane_trace.len()
130 }
131}
132
/// A dense spiking layer paired with a surrogate gradient and forward
/// caches, enabling gradients w.r.t. both inputs and weights.
pub struct DifferentiableDenseLayer {
    // Wrapped dense layer (project type) holding weights and geometry.
    pub layer: crate::layer::DenseLayer,
    // Surrogate used in backward(); evaluated at (output - 0.5).
    pub surrogate: SurrogateType,
    // Input of the most recent successful forward() call.
    input_cache: Vec<f64>,
    // Output of the most recent successful forward() call.
    output_cache: Vec<f64>,
}
140
141impl DifferentiableDenseLayer {
142 pub fn new(
144 n_inputs: usize,
145 n_neurons: usize,
146 length: usize,
147 seed: u64,
148 surrogate: SurrogateType,
149 ) -> Self {
150 Self {
151 layer: crate::layer::DenseLayer::new(n_inputs, n_neurons, length, seed),
152 surrogate,
153 input_cache: Vec::new(),
154 output_cache: Vec::new(),
155 }
156 }
157
158 pub fn forward(&mut self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
160 let out = self.layer.forward(input_values, seed)?;
161 self.input_cache = input_values.to_vec();
162 self.output_cache = out.clone();
163 Ok(out)
164 }
165
166 pub fn backward(&self, grad_output: &[f64]) -> Result<(Vec<f64>, Vec<Vec<f64>>), String> {
168 if self.input_cache.len() != self.layer.n_inputs {
169 return Err("backward() called before a valid forward() input cache.".to_string());
170 }
171 if self.output_cache.len() != self.layer.n_neurons {
172 return Err("backward() called before a valid forward() output cache.".to_string());
173 }
174 if grad_output.len() != self.layer.n_neurons {
175 return Err(format!(
176 "Expected grad_output length {}, got {}.",
177 self.layer.n_neurons,
178 grad_output.len()
179 ));
180 }
181
182 let mut grad_input = vec![0.0_f64; self.layer.n_inputs];
183 let mut grad_weights = vec![vec![0.0_f64; self.layer.n_inputs]; self.layer.n_neurons];
184
185 for j in 0..self.layer.n_neurons {
186 let surr = self.surrogate.grad((self.output_cache[j] - 0.5) as f32) as f64;
190 let local_grad = grad_output[j] * surr;
191 for i in 0..self.layer.n_inputs {
192 grad_weights[j][i] = local_grad * self.input_cache[i];
193 grad_input[i] += local_grad * self.layer.weights[j][i];
194 }
195 }
196
197 Ok((grad_input, grad_weights))
198 }
199
200 pub fn update_weights(&mut self, weight_grads: &[Vec<f64>], lr: f64) {
202 if weight_grads.len() != self.layer.n_neurons {
203 return;
204 }
205 if weight_grads
206 .iter()
207 .any(|row| row.len() != self.layer.n_inputs)
208 {
209 return;
210 }
211
212 for (w_row, g_row) in self.layer.weights.iter_mut().zip(weight_grads.iter()) {
213 for (w, g) in w_row.iter_mut().zip(g_row.iter()) {
214 *w = (*w - lr * *g).clamp(0.0, 1.0);
215 }
216 }
217
218 self.layer.refresh_packed_weights();
219 }
220
221 pub fn clear_cache(&mut self) {
223 self.input_cache.clear();
224 self.output_cache.clear();
225 }
226}