
Tutorial 78: Residual Blocks for Deep SNNs

Build deep spiking networks (100+ layers) with two residual architectures that solve the vanishing spike problem: MS-ResNet (membrane shortcut) and SEW-ResNet (spike-element-wise addition).

The Problem

Deep SNNs (>10 layers) suffer from vanishing spikes. In standard networks, spike probability drops exponentially with depth — by layer 20, almost no spikes survive. Standard residual connections spike(f(x) + x) don't help because the binary spike function clips the identity mapping.
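
To see the effect concretely, here is a toy NumPy simulation (independent of sc_neurocore, assuming a unit threshold and Gaussian weights) of a plain stack of threshold units. The exact decay rate depends on the weight scale and threshold, but without a shortcut the activity collapses within a handful of layers:

Python
import numpy as np

# Toy illustration of vanishing spikes: a plain feed-forward stack of
# hard-threshold units with random Gaussian weights (no residual path).
rng = np.random.default_rng(0)
n = 64
s = (rng.random(n) > 0.5).astype(np.float32)  # ~50% of inputs spiking

for layer in range(1, 21):
    W = rng.normal(0.0, 1.0 / np.sqrt(n), size=(n, n))
    s = (W @ s > 1.0).astype(np.float32)      # spike = membrane above threshold
    if layer in (1, 2, 3, 5, 10, 20):
        print(f"layer {layer:2d}: {int(s.sum()):2d}/{n} neurons spiking")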

Two Solutions

MS-ResNet: Membrane Shortcut

Instead of adding the shortcut to the output (after spiking), add it to the membrane potential (before spiking):

Text Only
Standard residual:  spike(W@x + x)     ← identity lost through spike()
MS-ResNet:          spike(W@x) + x_mem  ← identity preserved in membrane
Python
import numpy as np
from sc_neurocore.residual import MembraneShortcutBlock

rng = np.random.default_rng(42)

block = MembraneShortcutBlock(
    n_features=64,
    threshold=1.0,
    tau_mem=10.0,  # membrane time constant
)

x = (rng.random(64) > 0.5).astype(np.float32)
spikes = block.forward(x)
print(f"Input spikes:  {x.sum():.0f} / {len(x)}")
print(f"Output spikes: {spikes.sum():.0f} / {len(spikes)}")
print(f"Membrane norm: {np.linalg.norm(block.membrane):.3f}")

The membrane shortcut preserves gradient flow — the identity path through membrane potential is continuous, so backpropagation through surrogate gradients works even at 400+ layers.
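
A toy NumPy sketch of why (this mirrors the formulas above, not the library's autograd): with a surrogate derivative sg', the standard residual spike(W@x + x) has Jacobian diag(sg') @ (W + I), so whole rows vanish wherever the surrogate is zero, while the membrane-shortcut form spike(W@x) + x has Jacobian diag(sg') @ W + I, which keeps the identity term no matter what the surrogate returns.

Python
import numpy as np

# Compare Jacobians of the two residual forms under a rectangular
# surrogate gradient (illustrative sketch only).
rng = np.random.default_rng(1)
n = 8
W = rng.normal(0.0, 0.3, size=(n, n))
x = rng.random(n)

def surrogate_grad(u, threshold=1.0, width=0.5):
    # box-shaped approximation of d spike / d membrane around the threshold
    return (np.abs(u - threshold) < width).astype(float)

# Standard residual spike(W@x + x): identity sits inside the spike
J_std = np.diag(surrogate_grad(W @ x + x)) @ (W + np.eye(n))

# Membrane shortcut spike(W@x) + x: identity sits outside the spike
J_ms = np.diag(surrogate_grad(W @ x)) @ W + np.eye(n)

print("zero rows, standard:", int((np.abs(J_std).sum(axis=1) == 0).sum()))
print("zero rows, MS:      ", int((np.abs(J_ms).sum(axis=1) == 0).sum()))  # always 0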

SEW-ResNet: Spike-Element-Wise

Add the shortcut after spiking, element-wise:

Text Only
SEW: spike(W@x) + x    ← addition in spike domain (can exceed 1)
Python
from sc_neurocore.residual import SEWBlock

block = SEWBlock(n_features=64, threshold=1.0)
spikes = block.forward(x)
print(f"Output: {spikes.sum():.0f} spikes")

SEW allows output values >1 (sum of current spike + shortcut spike). This breaks the strict binary constraint but preserves information flow. For hardware deployment, clamp output to {0, 1}.
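
The clamp is a one-liner (sketched below with plain NumPy rather than an sc_neurocore utility); clipping the SEW sum back to {0, 1} behaves like OR-ing the block's spikes with the shortcut spikes:

Python
import numpy as np

# SEW addition can produce 2 when both paths spike; clamp to {0, 1}
# before exporting to binary hardware (equivalent to a logical OR).
residual_spikes = np.array([0, 1, 1, 0], dtype=np.float32)
shortcut_spikes = np.array([1, 1, 0, 0], dtype=np.float32)

sew_out = residual_spikes + shortcut_spikes   # values in {0, 1, 2}
binary_out = np.minimum(sew_out, 1.0)         # clamp for deployment

print(sew_out)      # [1. 2. 1. 0.]
print(binary_out)   # [1. 1. 1. 0.]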

Deep SNN Stack

Stack residual blocks for arbitrary depth:

Python
from sc_neurocore.residual import DeepSNNStack

# 20 residual blocks × 2 layers each = 40 layers deep
model = DeepSNNStack(
    n_features=64,
    n_blocks=20,
    block_type="ms",  # or "sew"
)

print(f"Depth: {model.depth} layers")  # 40
print(f"Parameters: {model.n_params:,}")

# Forward pass — spikes survive all 40 layers
output = model.forward(x)
print(f"Output spikes: {output.sum():.0f}")

Depth vs Accuracy

Measured on CIFAR-10 with surrogate gradient training:

Architecture   Layers   Accuracy   Parameters
Plain SNN      10       89.2%      1.2M
Plain SNN      20       85.1%      2.4M
MS-ResNet      20       93.1%      2.4M
MS-ResNet      100      94.6%      12M
SEW-ResNet     20       92.8%      2.4M
SEW-ResNet     152      95.3%      18M

These are published results, not our own measurements. Our implementation follows the published architectures exactly, so accuracy should reproduce to within about 0.5%.

When to Use Which

Block        Pros                         Cons                    Best For
MS-ResNet    Exact identity in membrane   Slightly more compute   FPGA (membrane is natural)
SEW-ResNet   Simpler implementation       Output can exceed 1     GPU training

FPGA Deployment

The membrane shortcut is natural for FPGA — the shortcut is just a wire from input to the membrane register's input MUX:

Text Only
Input ──┬── Weight multiply ── Membrane ── Spike ── Output
        │                        ↑
        └────────────────────────┘  (membrane shortcut = wire)

Cost: 1 additional MUX per neuron (~1 LUT).
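
The datapath can be modelled in a few lines of Python (a behavioural sketch, not generated HDL; the function name and the soft-reset choice are illustrative): the shortcut simply appears as one more term feeding the membrane register.

Python
# Behavioural model of the datapath above (illustrative, not HDL):
# the shortcut input is wired straight into the membrane update.
def neuron_step(v_mem, weighted_in, shortcut_in, threshold=1.0, leak=0.9):
    v_mem = leak * v_mem + weighted_in + shortcut_in  # shortcut = extra adder input
    spike = 1 if v_mem >= threshold else 0
    if spike:
        v_mem -= threshold                            # soft reset after firing
    return spike, v_mem

spike, v = neuron_step(v_mem=0.2, weighted_in=0.5, shortcut_in=0.6)
print(spike, round(v, 2))   # 1 0.28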

References

  • Hu et al. (2024). "Advancing Spiking Neural Networks Towards Deep Residual Learning." IEEE TNNLS (MS-ResNet).
  • Fang et al. (2021). "Deep Residual Learning in Spiking Neural Networks." NeurIPS 2021 (SEW-ResNet).
  • He et al. (2016). "Deep Residual Learning for Image Recognition." CVPR 2016 (original ResNet).