Tutorial 78: Residual Blocks for Deep SNNs¶
Build deep spiking networks (100+ layers) with two residual architectures that solve the vanishing spike problem: MS-ResNet (membrane shortcut) and SEW-ResNet (spike-element-wise addition).
The Problem¶
Deep SNNs (more than ~10 layers) suffer from vanishing spikes. In a plain
network, spike probability drops roughly exponentially with depth, and by
layer 20 almost no spikes survive. Standard residual connections of the form
spike(f(x) + x) don't help, because the binary spike function clips the
identity mapping along with the residual.
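A crude numpy toy makes the collapse visible. This is an assumed setup for illustration, not the library's layers, and with these arbitrary weight and threshold choices the activity dies even faster than the layer-20 figure above:

import numpy as np

rng = np.random.default_rng(0)
n, threshold = 64, 1.0
s = (rng.random(n) > 0.5).astype(np.float32)  # ~50% initial firing

for layer in range(1, 21):
    # Plain layer: spike(W @ s), modest positive weights, fixed threshold.
    W = rng.uniform(0.0, 0.06, size=(n, n))
    s = (W @ s >= threshold).astype(np.float32)
    if layer in (1, 2, 5, 10, 20):
        print(f"layer {layer:2d}: {s.sum():3.0f} / {n} spikes")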
Two Solutions¶
MS-ResNet: Membrane Shortcut¶
Instead of adding the shortcut to the output (after spiking), add it to the membrane potential (before spiking):
Standard residual: spike(W@x + x)      ← identity lost through spike()
MS-ResNet:         spike(W@x + x_mem)  ← identity preserved in the membrane
import numpy as np

from sc_neurocore.residual import MembraneShortcutBlock

rng = np.random.default_rng(42)

block = MembraneShortcutBlock(
    n_features=64,
    threshold=1.0,   # spike threshold
    tau_mem=10.0,    # membrane time constant
)

# Random binary input spike pattern (~50% firing).
x = (rng.random(64) > 0.5).astype(np.float32)
spikes = block.forward(x)

print(f"Input spikes: {x.sum():.0f} / {len(x)}")
print(f"Output spikes: {spikes.sum():.0f} / {len(spikes)}")
print(f"Membrane norm: {np.linalg.norm(block.membrane):.3f}")
The membrane shortcut preserves gradient flow: the identity path through the membrane potential is continuous, so backpropagation with surrogate gradients works even at 400+ layers.
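To see why the forward pass survives, here is a from-scratch sketch in plain numpy, not the library implementation; the weight scale and the 1.5× initial drive are arbitrary choices. It contrasts a plain stack with a membrane-shortcut stack at depth 100:

import numpy as np

rng = np.random.default_rng(1)
n, depth, threshold = 64, 100, 1.0

x = (rng.random(n) > 0.5).astype(np.float32)
s_plain = x.copy()
u_ms = 1.5 * x  # initial membrane drive, carried by the shortcut
s_ms = (u_ms >= threshold).astype(np.float32)

for _ in range(depth):
    W = rng.normal(0.0, 0.005, size=(n, n))  # near-zero residual branch
    # Plain: the identity must pass through the threshold -- it doesn't survive.
    s_plain = (W @ s_plain >= threshold).astype(np.float32)
    # Membrane shortcut: u_{l+1} = F(s_l) + u_l, spike AFTER the addition.
    u_ms = W @ s_ms + u_ms
    s_ms = (u_ms >= threshold).astype(np.float32)

print(f"plain stack after {depth} layers: {s_plain.sum():.0f} spikes")
print(f"MS stack after {depth} layers:    {s_ms.sum():.0f} spikes")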
SEW-ResNet: Spike-Element-Wise¶
Add the shortcut after spiking, element-wise:
SEW: spike(W@x) + x ← addition in spike domain (can exceed 1)
from sc_neurocore.residual import SEWBlock

block = SEWBlock(n_features=64, threshold=1.0)
spikes = block.forward(x)  # same input pattern x as above
print(f"Output: {spikes.sum():.0f} spikes")
SEW allows output values greater than 1 (the branch spike plus the shortcut spike can sum to 2). This breaks the strict binary constraint but preserves information flow. For hardware deployment, clamp the output back to {0, 1}, as sketched below.
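A one-line clamp is enough (a sketch; capping the ADD output at 1 is equivalent to a logical OR of the two spike trains):

import numpy as np

# SEW ADD output is in {0, 1, 2}; cap at 1 for binary hardware.
binary_spikes = np.minimum(spikes, 1.0)
print(f"Clamped: {binary_spikes.sum():.0f} spikes, max = {binary_spikes.max():.0f}")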
Deep SNN Stack¶
Stack residual blocks for arbitrary depth:
from sc_neurocore.residual import DeepSNNStack

# 20 residual blocks × 2 layers each = 40 layers deep
model = DeepSNNStack(
    n_features=64,
    n_blocks=20,
    block_type="ms",  # or "sew"
)

print(f"Depth: {model.depth} layers")  # 40
print(f"Parameters: {model.n_params:,}")

# Forward pass -- spikes survive all 40 layers
output = model.forward(x)
print(f"Output spikes: {output.sum():.0f}")
Depth vs Accuracy¶
Published results on CIFAR-10 with surrogate gradient training:
| Architecture | Layers | Accuracy | Parameters |
|---|---|---|---|
| Plain SNN | 10 | 89.2% | 1.2M |
| Plain SNN | 20 | 85.1% | 2.4M |
| MS-ResNet | 20 | 93.1% | 2.4M |
| MS-ResNet | 100 | 94.6% | 12M |
| SEW-ResNet | 20 | 92.8% | 2.4M |
| SEW-ResNet | 152 | 95.3% | 18M |
These are published results, not our own measurements. Our implementation follows the published architectures exactly, so accuracy should reproduce to within ~0.5%.
When to Use Which¶
| Block | Pros | Cons | Best For |
|---|---|---|---|
| MS-ResNet | Exact identity in membrane | Slightly more compute | FPGA (membrane is natural) |
| SEW-ResNet | Simpler implementation | Output can exceed 1 | GPU training |
FPGA Deployment¶
The membrane shortcut is natural for FPGA — the shortcut is just a wire from input to the membrane register's input MUX:
Input ──┬── Weight multiply ── Membrane ── Spike ── Output
        │                         ↑
        └─────────────────────────┘  (membrane shortcut = wire)
Cost: 1 additional MUX per neuron (~1 LUT).
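A behavioural model of that per-neuron datapath, in plain Python (the Q4.12 width and the subtractive reset are assumptions made for the sketch, not requirements of the design):

SCALE = 1 << 12  # Q4.12 fixed point (assumed width)

def neuron_step(v_reg, weighted_in, shortcut_in, threshold=1 * SCALE):
    # One clock cycle: the membrane shortcut is just a second addend into
    # the membrane register's input adder -- the "wire" in the diagram.
    v_next = v_reg + weighted_in + shortcut_in
    spike = 1 if v_next >= threshold else 0
    if spike:
        v_next -= threshold  # subtractive reset
    return spike, v_next

spike, v = neuron_step(v_reg=0,
                       weighted_in=int(0.6 * SCALE),
                       shortcut_in=int(0.5 * SCALE))
print(spike, v / SCALE)  # 1, with ~0.1 left in the membrane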
References¶
- Hu et al. (2024). "Advancing Spiking Neural Networks Toward Deep Residual Learning." IEEE TNNLS (MS-ResNet).
- Fang et al. (2021). "Deep Residual Learning in Spiking Neural Networks." NeurIPS 2021 (SEW-ResNet).
- He et al. (2016). "Deep Residual Learning for Image Recognition." CVPR 2016 (original ResNet).