
sc_neurocore_engine/grad/surrogate.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — Surrogate Gradient Components

//! # Surrogate Gradient Components
//!
//! Differentiable wrappers around SC forward operators.

use rayon::prelude::*;

#[derive(Clone, Debug)]
pub enum SurrogateType {
    /// Fast Sigmoid: d/dx = 1 / (2k * (1 + k|x|)^2)
    FastSigmoid { k: f32 },
    /// SuperSpike: d/dx = 1 / (1 + k|x|)^2
    SuperSpike { k: f32 },
    /// ArcTan: d/dx = 1 / (1 + (kx)^2)
    ArcTan { k: f32 },
    /// Straight-through estimator: unit gradient inside the window |x| < 0.5.
    StraightThrough,
    /// Triangular pulse: max(0, 1 - |x|/width) / width
    Triangular { width: f32 },
    /// Piecewise linear: max(0, 1 - |x|/width)
    PiecewiseLinear { width: f32 },
}

impl SurrogateType {
    /// Evaluate surrogate derivative at membrane offset `x`.
    pub fn grad(&self, x: f32) -> f32 {
        match self {
            Self::FastSigmoid { k } => {
                // Zenke & Vogels 2021 normalization includes 1/(2k).
                let denom = 1.0 + k * x.abs();
                1.0 / (2.0 * k * denom * denom)
            }
            Self::SuperSpike { k } => {
                // Zenke & Ganguli 2018 unnormalized surrogate.
                let denom = 1.0 + k * x.abs();
                1.0 / (denom * denom)
            }
            Self::ArcTan { k } => 1.0 / (1.0 + (k * x).powi(2)),
            Self::StraightThrough => {
                if x.abs() < 0.5 {
                    1.0
                } else {
                    0.0
                }
            }
            Self::Triangular { width } => {
                let v = 1.0 - x.abs() / width;
                if v > 0.0 {
                    v / width
                } else {
                    0.0
                }
            }
            Self::PiecewiseLinear { width } => {
                let v = 1.0 - x.abs() / width;
                v.max(0.0)
            }
        }
    }
}
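
// Minimal illustrative check (not part of the original API surface): every
// surrogate above peaks at the threshold crossing x = 0 and decays with |x|,
// which is the property the backward passes below rely on. The peak constants
// noted per variant follow directly from the formulas in `grad`.
#[cfg(test)]
mod surrogate_shape_examples {
    use super::SurrogateType;

    #[test]
    fn surrogates_peak_at_zero_and_decay() {
        let variants = [
            SurrogateType::FastSigmoid { k: 1.0 },    // peak 1/(2k) = 0.5
            SurrogateType::SuperSpike { k: 10.0 },    // peak 1.0
            SurrogateType::ArcTan { k: 2.0 },         // peak 1.0
            SurrogateType::StraightThrough,           // window |x| < 0.5
            SurrogateType::Triangular { width: 1.0 }, // peak 1/width = 1.0
            SurrogateType::PiecewiseLinear { width: 1.0 },
        ];
        for s in &variants {
            assert!(s.grad(0.0) >= s.grad(0.5));
            assert!(s.grad(0.5) >= s.grad(5.0));
        }
    }
}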

/// LIF neuron with surrogate gradient support.
pub struct SurrogateLif {
    pub lif: crate::neuron::FixedPointLif,
    pub surrogate: SurrogateType,
    /// Per-step cache of (normalized membrane offset, spike) pairs.
    membrane_trace: Vec<(f32, f32)>,
}

impl SurrogateLif {
    /// Construct a surrogate-enabled fixed-point LIF neuron.
    pub fn new(
        data_width: u32,
        fraction: u32,
        v_rest: i16,
        v_reset: i16,
        v_threshold: i16,
        refractory_period: i32,
        surrogate: SurrogateType,
    ) -> Self {
        Self {
            lif: crate::neuron::FixedPointLif::new(
                data_width,
                fraction,
                v_rest,
                v_reset,
                v_threshold,
                refractory_period,
            ),
            surrogate,
            membrane_trace: Vec::new(),
        }
    }

    /// Forward LIF step while caching the trace for the backward pass.
    pub fn forward(&mut self, leak_k: i16, gain_k: i16, i_t: i16, noise_in: i16) -> (i32, i16) {
        let (spike, v_out) = self.lif.step(leak_k, gain_k, i_t, noise_in);
        // Convert the fixed-point membrane value into a real-valued offset
        // from threshold, in units of 2^fraction.
        let scale = (1_u32 << self.lif.fraction) as f32;
        let v_norm = (v_out as f32 - self.lif.v_threshold as f32) / scale;
        self.membrane_trace.push((v_norm, spike as f32));
        (spike, v_out)
    }

    /// Backward pass through the last cached membrane value (most recent
    /// step first, i.e. reverse-time order).
    pub fn backward(&mut self, grad_output: f32) -> Result<f32, String> {
        let (v_norm, _spike) = self
            .membrane_trace
            .pop()
            .ok_or_else(|| "backward() called without matching forward()".to_string())?;
        Ok(grad_output * self.surrogate.grad(v_norm))
    }

    /// Clear stored membrane trace.
    pub fn clear_trace(&mut self) {
        self.membrane_trace.clear();
    }

    /// Reset neuron state and clear trace.
    pub fn reset(&mut self) {
        self.lif.reset();
        self.clear_trace();
    }

    /// Number of cached forward steps.
    pub fn trace_len(&self) -> usize {
        self.membrane_trace.len()
    }
}
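
// Usage sketch (illustrative): exercises only the trace bookkeeping above.
// The fixed-point constructor arguments are placeholder settings, and spike
// behaviour depends on crate::neuron::FixedPointLif, so it is not asserted.
#[cfg(test)]
mod surrogate_lif_trace_examples {
    use super::{SurrogateLif, SurrogateType};

    #[test]
    fn forward_caches_and_backward_pops_in_reverse() {
        let mut neuron =
            SurrogateLif::new(16, 8, 0, 0, 256, 2, SurrogateType::SuperSpike { k: 10.0 });
        for _ in 0..3 {
            let _ = neuron.forward(200, 256, 50, 0);
        }
        assert_eq!(neuron.trace_len(), 3);

        // backward() consumes the most recent step first (reverse-time order).
        assert!(neuron.backward(1.0).is_ok());
        assert_eq!(neuron.trace_len(), 2);

        neuron.reset();
        assert_eq!(neuron.trace_len(), 0);
        assert!(neuron.backward(1.0).is_err());
    }
}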

/// Dense SC layer with surrogate gradient backward pass.
pub struct DifferentiableDenseLayer {
    pub layer: crate::layer::DenseLayer,
    pub surrogate: SurrogateType,
    input_cache: Vec<f64>,
    output_cache: Vec<f64>,
}

impl DifferentiableDenseLayer {
    /// Construct a differentiable dense SC layer.
    pub fn new(
        n_inputs: usize,
        n_neurons: usize,
        length: usize,
        seed: u64,
        surrogate: SurrogateType,
    ) -> Self {
        Self {
            layer: crate::layer::DenseLayer::new(n_inputs, n_neurons, length, seed),
            surrogate,
            input_cache: Vec::new(),
            output_cache: Vec::new(),
        }
    }

    /// Forward pass; caches activations for the backward pass.
    pub fn forward(&mut self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
        let out = self.layer.forward(input_values, seed)?;
        self.input_cache = input_values.to_vec();
        self.output_cache = out.clone();
        Ok(out)
    }

    /// Backward pass producing input and weight gradients.
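    ///
    /// With local gradient `g_j = grad_output[j] * surrogate.grad(y_j - 0.5)`
    /// for neuron `j` (where `y_j` is the cached output rate), this computes
    /// `grad_weights[j][i] = g_j * input[i]` and
    /// `grad_input[i] = sum_j g_j * weights[j][i]`.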
    pub fn backward(&self, grad_output: &[f64]) -> Result<(Vec<f64>, Vec<Vec<f64>>), String> {
        if self.input_cache.len() != self.layer.n_inputs {
            return Err("backward() called before a valid forward() input cache.".to_string());
        }
        if self.output_cache.len() != self.layer.n_neurons {
            return Err("backward() called before a valid forward() output cache.".to_string());
        }
        if grad_output.len() != self.layer.n_neurons {
            return Err(format!(
                "Expected grad_output length {}, got {}.",
                self.layer.n_neurons,
                grad_output.len()
            ));
        }

        let mut grad_input = vec![0.0_f64; self.layer.n_inputs];
        let mut grad_weights = vec![vec![0.0_f64; self.layer.n_inputs]; self.layer.n_neurons];

        // Local gradient per neuron: upstream gradient times the surrogate
        // derivative, with output rates centred on the 0.5 decision point.
        let local_grads: Vec<f64> = (0..self.layer.n_neurons)
            .map(|j| {
                let surr = self.surrogate.grad((self.output_cache[j] - 0.5) as f32) as f64;
                grad_output[j] * surr
            })
            .collect();

        // Weight gradients, one row per neuron, in parallel across neurons.
        // The 4-wide chunks are a manual unroll to help autovectorization.
        grad_weights
            .par_iter_mut()
            .enumerate()
            .for_each(|(j, row_grad_weights)| {
                let local_grad = local_grads[j];
                let mut chunks_gw = row_grad_weights.chunks_exact_mut(4);
                let mut chunks_inp = self.input_cache.chunks_exact(4);
                for (cgw, cinp) in chunks_gw.by_ref().zip(chunks_inp.by_ref()) {
                    cgw[0] = local_grad * cinp[0];
                    cgw[1] = local_grad * cinp[1];
                    cgw[2] = local_grad * cinp[2];
                    cgw[3] = local_grad * cinp[3];
                }
                for (gw, &inp) in chunks_gw
                    .into_remainder()
                    .iter_mut()
                    .zip(chunks_inp.remainder())
                {
                    *gw = local_grad * inp;
                }
            });

        // Serial input-gradient accumulation: grad_input += W^T * local_grads.
        for (j, &local_grad) in local_grads.iter().enumerate() {
            let row_weights = &self.layer.weights[j];
            let mut chunks_gi = grad_input.chunks_exact_mut(4);
            let mut chunks_w = row_weights.chunks_exact(4);
            for (cgi, cw) in chunks_gi.by_ref().zip(chunks_w.by_ref()) {
                cgi[0] += local_grad * cw[0];
                cgi[1] += local_grad * cw[1];
                cgi[2] += local_grad * cw[2];
                cgi[3] += local_grad * cw[3];
            }
            for (gi, &w) in chunks_gi
                .into_remainder()
                .iter_mut()
                .zip(chunks_w.remainder())
            {
                *gi += local_grad * w;
            }
        }

        Ok((grad_input, grad_weights))
    }

    /// Apply gradient descent update and clamp weights to `[0, 1]`.
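    ///
    /// Elementwise update `w <- clamp(w - lr * g, 0, 1)`. Gradient rows with
    /// mismatched shapes cause the whole update to be skipped.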
    pub fn update_weights(&mut self, weight_grads: &[Vec<f64>], lr: f64) {
        if weight_grads.len() != self.layer.n_neurons {
            return;
        }
        if weight_grads
            .iter()
            .any(|row| row.len() != self.layer.n_inputs)
        {
            return;
        }

        for (w_row, g_row) in self.layer.weights.iter_mut().zip(weight_grads.iter()) {
            for (w, g) in w_row.iter_mut().zip(g_row.iter()) {
                *w = (*w - lr * *g).clamp(0.0, 1.0);
            }
        }

        // Keep the packed SC representation in sync with the updated weights.
        self.layer.refresh_packed_weights();
    }

    /// Clear cached forward tensors.
    pub fn clear_cache(&mut self) {
        self.input_cache.clear();
        self.output_cache.clear();
    }
}
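
// End-to-end usage sketch (illustrative, not part of the original API).
// Constructor arguments (bitstream `length`, seeds, learning rate) are
// placeholder values; SC forward outputs are stochastic, so only gradient
// shapes and the [0, 1] weight clamp are asserted here.
#[cfg(test)]
mod dense_layer_grad_examples {
    use super::{DifferentiableDenseLayer, SurrogateType};

    #[test]
    fn backward_shapes_and_weight_clamp() {
        let mut layer =
            DifferentiableDenseLayer::new(4, 3, 1024, 42, SurrogateType::FastSigmoid { k: 1.0 });
        let out = layer.forward(&[0.2, 0.4, 0.6, 0.8], 7).expect("forward");
        assert_eq!(out.len(), 3);

        let (grad_in, grad_w) = layer.backward(&vec![1.0; 3]).expect("backward");
        assert_eq!(grad_in.len(), 4);
        assert_eq!(grad_w.len(), 3);
        assert!(grad_w.iter().all(|row| row.len() == 4));

        // An oversized learning rate exercises the [0, 1] clamp.
        layer.update_weights(&grad_w, 100.0);
        assert!(layer
            .layer
            .weights
            .iter()
            .flatten()
            .all(|&w| (0.0..=1.0).contains(&w)));
    }
}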