// sc_neurocore_engine/grad/surrogate.rs
// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — Surrogate Gradient Components

//! # Surrogate Gradient Components
//!
//! Differentiable wrappers around SC forward operators.
/// Surrogate derivative families used to smooth the non-differentiable
/// spike threshold during backpropagation.
///
/// Each variant is a pseudo-derivative of the Heaviside step, evaluated
/// at a membrane offset `x` (membrane potential minus threshold).
///
/// All payloads are plain `f32`, so the enum is cheap to copy and
/// compare: `Copy` and `PartialEq` are derived in addition to the
/// original `Clone` and `Debug`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SurrogateType {
    /// Fast Sigmoid: d/dx = 1 / (2k * (1 + k|x|)^2)
    FastSigmoid { k: f32 },
    /// SuperSpike: d/dx = 1 / (1 + k|x|)^2
    SuperSpike { k: f32 },
    /// ArcTan: d/dx = 1 / (1 + (kx)^2)
    ArcTan { k: f32 },
    /// Straight-through estimator: unit gradient for |x| < 0.5, else 0.
    StraightThrough,
    /// Triangular pulse: max(0, 1 - |x|/width) / width
    Triangular { width: f32 },
    /// Piecewise linear: max(0, 1 - |x|/width)
    PiecewiseLinear { width: f32 },
}

impl SurrogateType {
    /// Evaluate surrogate derivative at membrane offset `x`.
    ///
    /// `x` is the (normalized) distance of the membrane potential from
    /// threshold; the returned value gates an upstream gradient.
    /// NOTE(review): `k == 0` (FastSigmoid) or `width == 0` yields a
    /// non-finite result — callers are assumed to pass positive
    /// steepness/width parameters.
    pub fn grad(&self, x: f32) -> f32 {
        // |x| is shared by every branch except ArcTan; hoist it once.
        let ax = x.abs();
        match *self {
            Self::FastSigmoid { k } => {
                // Zenke & Vogels 2021 normalization includes 1/(2k).
                let denom = 1.0 + k * ax;
                1.0 / (2.0 * k * denom * denom)
            }
            Self::SuperSpike { k } => {
                // Zenke & Ganguli 2018 unnormalized surrogate.
                let denom = 1.0 + k * ax;
                1.0 / (denom * denom)
            }
            Self::ArcTan { k } => 1.0 / (1.0 + (k * x).powi(2)),
            // Pass gradient through unchanged inside a +/-0.5 window.
            Self::StraightThrough => {
                if ax < 0.5 {
                    1.0
                } else {
                    0.0
                }
            }
            Self::Triangular { width } => {
                let v = 1.0 - ax / width;
                if v > 0.0 {
                    v / width
                } else {
                    0.0
                }
            }
            Self::PiecewiseLinear { width } => {
                let v = 1.0 - ax / width;
                v.max(0.0)
            }
        }
    }
}
65
/// LIF neuron with surrogate gradient support.
///
/// Wraps a fixed-point LIF forward model and records one trace entry per
/// forward step so the backward pass can apply a surrogate derivative.
pub struct SurrogateLif {
    /// Underlying fixed-point LIF forward model (project type).
    pub lif: crate::neuron::FixedPointLif,
    /// Surrogate derivative family applied in `backward`.
    pub surrogate: SurrogateType,
    // Stack of (normalized membrane offset, spike-as-f32) pairs, one per
    // forward() call; popped LIFO by backward().
    membrane_trace: Vec<(f32, f32)>,
}
72
73impl SurrogateLif {
74    /// Construct a surrogate-enabled fixed-point LIF neuron.
75    pub fn new(
76        data_width: u32,
77        fraction: u32,
78        v_rest: i16,
79        v_reset: i16,
80        v_threshold: i16,
81        refractory_period: i32,
82        surrogate: SurrogateType,
83    ) -> Self {
84        Self {
85            lif: crate::neuron::FixedPointLif::new(
86                data_width,
87                fraction,
88                v_rest,
89                v_reset,
90                v_threshold,
91                refractory_period,
92            ),
93            surrogate,
94            membrane_trace: Vec::new(),
95        }
96    }
97
98    /// Forward LIF step while caching trace for backward pass.
99    pub fn forward(&mut self, leak_k: i16, gain_k: i16, i_t: i16, noise_in: i16) -> (i32, i16) {
100        let (spike, v_out) = self.lif.step(leak_k, gain_k, i_t, noise_in);
101        let scale = (1_u32 << self.lif.fraction) as f32;
102        let v_norm = (v_out as f32 - self.lif.v_threshold as f32) / scale;
103        self.membrane_trace.push((v_norm, spike as f32));
104        (spike, v_out)
105    }
106
107    /// Backward pass through last cached membrane value.
108    pub fn backward(&mut self, grad_output: f32) -> Result<f32, String> {
109        let (v_norm, _spike) = self
110            .membrane_trace
111            .pop()
112            .ok_or_else(|| "backward() called without matching forward()".to_string())?;
113        Ok(grad_output * self.surrogate.grad(v_norm))
114    }
115
116    /// Clear stored membrane trace.
117    pub fn clear_trace(&mut self) {
118        self.membrane_trace.clear();
119    }
120
121    /// Reset neuron state and clear trace.
122    pub fn reset(&mut self) {
123        self.lif.reset();
124        self.clear_trace();
125    }
126
127    /// Number of cached forward steps.
128    pub fn trace_len(&self) -> usize {
129        self.membrane_trace.len()
130    }
131}
132
/// Dense SC layer with surrogate gradient backward pass.
///
/// Wraps a stochastic-computing dense layer and caches the last forward
/// activations so gradients can be computed without re-running the layer.
pub struct DifferentiableDenseLayer {
    /// Underlying SC dense layer (project type).
    pub layer: crate::layer::DenseLayer,
    /// Surrogate derivative family used to gate upstream gradients.
    pub surrogate: SurrogateType,
    // Input vector from the most recent successful forward() call.
    input_cache: Vec<f64>,
    // Output vector from the most recent successful forward() call.
    output_cache: Vec<f64>,
}
140
141impl DifferentiableDenseLayer {
142    /// Construct a differentiable dense SC layer.
143    pub fn new(
144        n_inputs: usize,
145        n_neurons: usize,
146        length: usize,
147        seed: u64,
148        surrogate: SurrogateType,
149    ) -> Self {
150        Self {
151            layer: crate::layer::DenseLayer::new(n_inputs, n_neurons, length, seed),
152            surrogate,
153            input_cache: Vec::new(),
154            output_cache: Vec::new(),
155        }
156    }
157
158    /// Forward pass and cache activations for backward pass.
159    pub fn forward(&mut self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
160        let out = self.layer.forward(input_values, seed)?;
161        self.input_cache = input_values.to_vec();
162        self.output_cache = out.clone();
163        Ok(out)
164    }
165
166    /// Backward pass producing input and weight gradients.
167    pub fn backward(&self, grad_output: &[f64]) -> Result<(Vec<f64>, Vec<Vec<f64>>), String> {
168        if self.input_cache.len() != self.layer.n_inputs {
169            return Err("backward() called before a valid forward() input cache.".to_string());
170        }
171        if self.output_cache.len() != self.layer.n_neurons {
172            return Err("backward() called before a valid forward() output cache.".to_string());
173        }
174        if grad_output.len() != self.layer.n_neurons {
175            return Err(format!(
176                "Expected grad_output length {}, got {}.",
177                self.layer.n_neurons,
178                grad_output.len()
179            ));
180        }
181
182        let mut grad_input = vec![0.0_f64; self.layer.n_inputs];
183        let mut grad_weights = vec![vec![0.0_f64; self.layer.n_inputs]; self.layer.n_neurons];
184
185        for j in 0..self.layer.n_neurons {
186            // Gate upstream gradient through the surrogate derivative at the
187            // output activation (treating output as a soft proxy for the spike
188            // non-linearity). Centre around 0.5 (midpoint of [0,1] output range).
189            let surr = self.surrogate.grad((self.output_cache[j] - 0.5) as f32) as f64;
190            let local_grad = grad_output[j] * surr;
191            for i in 0..self.layer.n_inputs {
192                grad_weights[j][i] = local_grad * self.input_cache[i];
193                grad_input[i] += local_grad * self.layer.weights[j][i];
194            }
195        }
196
197        Ok((grad_input, grad_weights))
198    }
199
200    /// Apply gradient descent update and clamp weights to `[0, 1]`.
201    pub fn update_weights(&mut self, weight_grads: &[Vec<f64>], lr: f64) {
202        if weight_grads.len() != self.layer.n_neurons {
203            return;
204        }
205        if weight_grads
206            .iter()
207            .any(|row| row.len() != self.layer.n_inputs)
208        {
209            return;
210        }
211
212        for (w_row, g_row) in self.layer.weights.iter_mut().zip(weight_grads.iter()) {
213            for (w, g) in w_row.iter_mut().zip(g_row.iter()) {
214                *w = (*w - lr * *g).clamp(0.0, 1.0);
215            }
216        }
217
218        self.layer.refresh_packed_weights();
219    }
220
221    /// Clear cached forward tensors.
222    pub fn clear_cache(&mut self) {
223        self.input_cache.clear();
224        self.output_cache.clear();
225    }
226}