# SNN Training with Surrogate Gradients
Spiking neurons fire when membrane voltage crosses a threshold — a Heaviside step function, whose derivative is zero everywhere except at the discontinuity (Dirac delta). Standard backpropagation cannot pass gradients through this.
Surrogate gradient methods replace the Heaviside derivative with a smooth approximation during the backward pass while keeping the hard threshold in the forward pass. The network learns with real spikes; only the gradient signal is relaxed.
Reference: Neftci, Mostafa & Zenke, "Surrogate Gradient Learning in Spiking Neural Networks", IEEE Signal Processing Magazine 36(6), 2019.
Requires: `pip install sc-neurocore[training]` (installs PyTorch)
```mermaid
flowchart LR
    subgraph FWD["Forward Pass"]
        A["v = β·v + I"] --> B{"v > θ?"}
        B -->|Yes| C["spike = 1<br/>v -= θ"]
        B -->|No| D["spike = 0"]
    end
    subgraph BWD["Backward Pass"]
        E["∂L/∂spike"] --> F["× surrogate'(v-θ)"]
        F --> G["∂L/∂v → ∂L/∂W"]
    end
    FWD -.->|"Heaviside replaced<br/>by smooth approx"| BWD
    style FWD fill:#e8f5e9
    style BWD fill:#fff3e0
```
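The forward/backward split can be sketched in plain PyTorch with a custom `autograd.Function` — an illustration of the mechanism, not SC-NeuroCore's own implementation:

```python
import math
import torch

class SpikeATan(torch.autograd.Function):
    """Hard threshold in forward, arctan surrogate gradient in backward (sketch)."""

    @staticmethod
    def forward(ctx, x, alpha=2.0):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x > 0).float()          # real 0/1 spikes

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        a = ctx.alpha
        # smooth stand-in for the Dirac delta: a / (2(1 + (pi*a*x/2)^2))
        surrogate = a / (2 * (1 + (math.pi * a * x / 2) ** 2))
        return grad_out * surrogate, None

x = torch.tensor([-0.5, 0.2, 1.3], requires_grad=True)
spikes = SpikeATan.apply(x)             # tensor([0., 1., 1.])
spikes.sum().backward()                 # gradients flow despite the step
```

The forward pass emits genuine binary spikes; only `backward` is relaxed, which is exactly the trick the diagram above describes.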
## Surrogate functions

SC-NeuroCore provides six surrogates. The three most commonly used:
| Function | Backward gradient | Default shape param | Citation |
|---|---|---|---|
| `atan_surrogate` | α / (2(1 + (παx/2)²)) | α=2.0 | Fang et al. 2021 |
| `fast_sigmoid` | slope / (1 + slope·\|x\|)² | slope=25.0 | Zenke & Ganguli 2018 |
| `superspike` | 1 / (1 + β·\|x\|)² | β=10.0 | Zenke & Vogels 2021 |
All accept `x = v - threshold` (pre-shifted). The forward pass returns `(x > 0).float()`.

```python
from sc_neurocore.training import atan_surrogate, fast_sigmoid, superspike
```

Also available: `sigmoid_surrogate`, `straight_through` (STE), `triangular`.
## Neuron model: LIFCell

```python
from sc_neurocore.training import LIFCell

lif = LIFCell(beta=0.9, threshold=1.0, surrogate_fn=atan_surrogate)
```
Single-step dynamics:

```
v[t] = beta * v[t-1] + I[t]
spike[t] = H(v[t] - threshold)   # Heaviside in forward, surrogate in backward
v[t] -= spike[t] * threshold     # subtract reset
```

`beta` controls the membrane leak (higher = longer memory). Set `learn_beta=True` or `learn_threshold=True` to make them trainable parameters.
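The three update lines are easy to trace by hand. A minimal pure-Python version of the forward dynamics (illustrative only, no gradients):

```python
def lif_step(v, I, beta=0.9, threshold=1.0):
    """One LIF timestep: leak-integrate, threshold, subtract-reset."""
    v = beta * v + I
    spike = float(v > threshold)
    v -= spike * threshold
    return v, spike

v, history = 0.0, []
for I in [0.6, 0.6, 0.6]:
    v, s = lif_step(v, I)
    history.append((round(v, 3), s))
# history == [(0.6, 0.0), (0.14, 1.0), (0.726, 0.0)]
```

A constant sub-threshold input still fires eventually: the leak shrinks the old voltage by `beta` each step, but the input keeps accumulating until `v` crosses the threshold, after which the subtract-reset preserves the overshoot.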
## All neuron cells

| Cell | States | Key feature | Citation |
|---|---|---|---|
| `LIFCell` | v | Leaky integrate-and-fire (default) | Lapicque 1907 |
| `IFCell` | v | No leak (beta=1), simplest model | — |
| `SynapticCell` | i_syn, v | Dual-exponential synaptic current | — |
| `ALIFCell` | v, a | Adaptive threshold (spike-frequency adaptation) | Bellec et al. 2020 |
| `ExpIFCell` | v | Exponential upstroke near threshold | Fourcaud-Trocmé et al. 2003 |
| `AdExCell` | v, w | Adaptive exponential IF (tonic/bursting/adapting) | Brette & Gerstner 2005 |
| `LapicqueCell` | v | Explicit RC parameters (tau, R, dt) | Lapicque 1907 |
| `AlphaCell` | i_exc, i_inh, v | Separate excitatory/inhibitory synapses | Rall 1967 |
| `SecondOrderLIFCell` | a, v | Inertial acceleration term | Dayan & Abbott 2001 |
| `RecurrentLIFCell` | v, spike_prev | Trainable recurrent weights | — |
All cells support `learn_beta=True` and `learn_threshold=True` where applicable. Import any cell from `sc_neurocore.training`.
## SpikingNet: full classifier

`SpikingNet` stacks `[Linear → LIFCell]` layers and accumulates output spikes over `T` timesteps.
```python
from sc_neurocore.training import SpikingNet

model = SpikingNet(
    n_input=784,     # flattened image
    n_hidden=128,    # hidden layer width
    n_output=10,     # digit classes
    n_layers=2,      # hidden layers
    beta=0.9,
    surrogate_fn=atan_surrogate,
)
```
Forward signature: `x: (T, batch, n_input)` → `(spike_counts, membrane_acc)`. Classification uses `spike_counts.argmax(dim=1)`.
## Spike encoding

Convert pixel intensities in [0, 1] to spike trains of shape `(T, *batch)`:

```python
from sc_neurocore.training import rate_encode, latency_encode, delta_encode
```
| Encoder | Method | When to use |
|---|---|---|
| `rate_encode(x, T)` | Poisson: P(spike) = x each timestep | Static images, general purpose |
| `latency_encode(x, T, tau=5)` | Time-to-first-spike: strong input → early spike | Low-latency inference |
| `delta_encode(x, threshold=0.1)` | Spike on temporal change | Temporal/event data |
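Rate coding is simple enough to sketch directly (plain NumPy, not the library's implementation): each timestep draws an independent Bernoulli spike with probability equal to the pixel intensity, so the empirical firing rate recovers the intensity.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([0.1, 0.5, 0.9])      # pixel intensities in [0, 1]
T = 1000

# Bernoulli draw per timestep and pixel: P(spike) = intensity
spikes = (rng.random((T, x.size)) < x).astype(np.uint8)   # shape (T, features)
rates = spikes.mean(axis=0)        # empirical rate ≈ x, up to sampling noise
```

The sampling noise shrinks as 1/√T, which is why short simulation windows trade accuracy for latency.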
## Loss functions

```python
from sc_neurocore.training import spike_count_loss, membrane_loss, spike_rate_loss
```

- `spike_count_loss(counts, targets)` — cross-entropy on spike counts (default)
- `membrane_loss(mem_acc, targets)` — cross-entropy on accumulated membrane potential
- `spike_rate_loss(counts, targets, T, target_rate=0.8)` — MSE on firing rates
Regularizers to encourage sparse firing:
```python
from sc_neurocore.training import spike_l1_loss, spike_l2_loss
```
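One plausible reading of the default loss — spike counts treated directly as logits for cross-entropy, plus an L1-style activity penalty — sketched in plain PyTorch (an illustration, not the library's exact code; the `1e-3` weight is a hypothetical choice):

```python
import torch
import torch.nn.functional as F

counts = torch.tensor([[12., 3., 1.],
                       [ 2., 9., 4.]])     # output spike counts per class
targets = torch.tensor([0, 1])             # true labels

ce = F.cross_entropy(counts, targets)      # counts act as logits
l1 = counts.abs().mean()                   # penalize overall firing activity
loss = ce + 1e-3 * l1                      # hypothetical regularizer weight
```

Because spike counts are non-negative integers rather than unbounded logits, the cross-entropy here saturates quickly once the correct class dominates the count vector.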
## Training loop

`train_epoch` and `evaluate` handle the timestep expansion and accumulation:
```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sc_neurocore.training import (
    SpikingNet, atan_surrogate, auto_device,
    train_epoch, evaluate, spike_count_loss,
)

digits = load_digits()
X = torch.tensor(digits.data / 16.0, dtype=torch.float32)  # normalize to [0,1]
y = torch.tensor(digits.target, dtype=torch.long)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

train_loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=64, shuffle=True)
test_loader = DataLoader(TensorDataset(X_te, y_te), batch_size=64)

device = auto_device()
model = SpikingNet(n_input=64, n_hidden=128, n_output=10, beta=0.9).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

T = 25
for epoch in range(1, 11):
    loss, acc = train_epoch(model, train_loader, optimizer, T, device=device)
    val_loss, val_acc = evaluate(model, test_loader, T, device=device)
    print(f"Epoch {epoch:2d}  train {acc:.1%}  val {val_acc:.1%}")
```
Expected: ~95% validation accuracy within 10 epochs on sklearn digits (8x8, 10 classes).
`train_epoch` internally reshapes `(batch, features)` → `(T, batch, features)` by repeating the input across timesteps. For rate-coded inputs, encode before the DataLoader and pass `(T, batch, features)` tensors directly.
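Assuming repeat-across-time semantics, the expansion itself is a cheap broadcast, roughly equivalent to this plain-PyTorch sketch:

```python
import torch

T, batch, feat = 25, 64, 784
x = torch.rand(batch, feat)                 # one static batch
x_seq = x.unsqueeze(0).expand(T, -1, -1)    # (T, batch, feat) view, no copy
```

`expand` creates a strided view, so the repeated input costs no extra memory; only a real encoder like `rate_encode` materializes distinct per-timestep data.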
## Switching surrogate functions

```python
model_fs = SpikingNet(n_input=64, n_hidden=128, n_output=10,
                      surrogate_fn=fast_sigmoid)
model_ss = SpikingNet(n_input=64, n_hidden=128, n_output=10,
                      surrogate_fn=superspike)
```
Effect on training: `atan_surrogate` (wider gradient window) converges smoothly on most tasks. `fast_sigmoid` trains faster on deep networks. `superspike` gives sharper gradients near the threshold — better for temporal coding, but it can be unstable with high learning rates.
## Bridge to stochastic computing: to_sc_weights()

```mermaid
flowchart LR
    A["Trained SpikingNet<br/>float32 weights"] --> B["to_sc_weights()<br/>min-max → [0,1]"]
    B --> C["SCDenseLayer<br/>bitstream simulation"]
    B --> D["equation_compiler<br/>→ Q8.8 Verilog"]
    D --> E["Yosys + nextpnr<br/>→ FPGA bitstream"]
    C --> F["Verify:<br/>float ≈ SC ≈ RTL"]
    E --> F
    style A fill:#fff3e0
    style B fill:#e8f5e9
    style E fill:#fce4ec
```
After training, export weight matrices normalized to [0,1] for SC bitstream deployment:
```python
sc_weights = model.to_sc_weights()
for i, w in enumerate(sc_weights):
    print(f"Layer {i}: {tuple(w.shape)}, range [{w.min():.4f}, {w.max():.4f}]")
```
Each weight matrix is min-max scaled per layer. These [0, 1] values map directly to bitstream probabilities in `SCDenseLayer` and the Verilog RTL dot-product units.
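The per-layer scaling amounts to ordinary min-max normalization, roughly (a sketch of the idea, not the library's exact code; the epsilon guarding a constant layer is an assumption):

```python
import torch

def minmax_01(w: torch.Tensor) -> torch.Tensor:
    """Scale one weight matrix into [0, 1]; eps keeps a constant layer finite."""
    lo, hi = w.min(), w.max()
    return (w - lo) / (hi - lo + 1e-12)

w = torch.randn(128, 64)    # a trained float32 layer, stand-in
w01 = minmax_01(w)          # min ≈ 0, max ≈ 1
```

Note that min-max scaling discards the sign and absolute magnitude of the weights, which is why the normalization is per layer: each layer's probabilities must be interpreted against its own stored range.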
## Float vs SC inference comparison

Run the trained float model, then simulate the same weights through the SC pipeline:

```python
import numpy as np
from sc_neurocore.layers.sc_dense_layer import SCDenseLayer

_, float_acc = evaluate(model, test_loader, T, device=device)
print(f"Float accuracy: {float_acc:.1%}")

sc_w = sc_weights[0].cpu().numpy()   # first layer weights [n_hidden, n_input]
sample = X_te[0].numpy()

layer = SCDenseLayer(
    n_neurons=sc_w.shape[0],
    x_inputs=sample.tolist(),
    weight_values=sc_w[0].tolist(),  # one neuron's weights
    x_min=0.0, x_max=1.0,
    w_min=0.0, w_max=1.0,
    length=2048,
    base_seed=42,
)
layer.run(T=500)
print(f"SC firing rates: {layer.summary()['avg_firing_rate_hz']:.1f} Hz")
```
The SC simulation uses bitstream arithmetic (XNOR + popcount) for the dot-product and `StochasticLIFNeuron` for integration. Bitstream length controls the precision/latency tradeoff: 2048 bits ≈ 11-bit effective precision.

Full end-to-end SC inference across all layers requires composing multiple `SCDenseLayer` instances and routing output spike trains as input bitstreams to the next layer. The `to_sc_weights()` bridge ensures weight compatibility.
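The core bitstream operation can be demonstrated in isolation. In bipolar stochastic computing, a value v ∈ [-1, 1] is encoded as a bitstream with P(1) = (v+1)/2, and the XNOR of two independent streams encodes their product — a plain-NumPy illustration, independent of the library's internals:

```python
import numpy as np

rng = np.random.default_rng(42)
L = 2048                           # bitstream length
a, b = 0.5, -0.25                  # bipolar values in [-1, 1]

A = rng.random(L) < (a + 1) / 2    # encode: P(1) = (a+1)/2
B = rng.random(L) < (b + 1) / 2
prod = ~(A ^ B)                    # XNOR: 1 where the bits agree
est = 2 * prod.mean() - 1          # decode back to [-1, 1]
# est ≈ a * b = -0.125, within O(1/sqrt(L)) sampling noise
```

A single XNOR gate thus replaces a full multiplier, which is the hardware appeal; the cost is the 1/√L sampling noise — the source of the "2048 bits ≈ 11-bit precision" rule of thumb above.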
## Full MNIST example

For 28×28 MNIST with torchvision (requires `pip install torchvision`):

```bash
python examples/mnist_surrogate/train.py --epochs 10 --device cuda
```
This trains a 784→128→128→10 `SpikingNet` and reports ~95% test accuracy. `ConvSpikingNet` is available for higher accuracy on full MNIST (Conv2d → LIF → pooling architecture).
## Summary
| Component | API |
|---|---|
| Surrogates | atan_surrogate, fast_sigmoid, superspike, sigmoid_surrogate, straight_through, triangular |
| Neurons | LIFCell, IFCell, ALIFCell, SynapticCell, RecurrentLIFCell, ExpIFCell, AdExCell, LapicqueCell, AlphaCell, SecondOrderLIFCell |
| Networks | SpikingNet, ConvSpikingNet |
| Encoding | rate_encode, latency_encode, delta_encode |
| Losses | spike_count_loss, membrane_loss, spike_rate_loss, spike_l1_loss, spike_l2_loss |
| Training | train_epoch, evaluate, auto_device |
| SC bridge | model.to_sc_weights() → SCDenseLayer / Verilog RTL |