Skip to content

Tutorial 78: Residual Blocks for Deep SNNs

Build 400+ layer deep spiking networks with two residual architectures: MS-ResNet (membrane shortcut) and SEW-ResNet (activation-before-addition).

The Problem

Deep SNNs (>10 layers) suffer from vanishing spikes. Standard residual connections fail because spike(f(x) + x) clips the identity mapping through the binary activation. MS-ResNet adds the shortcut to membrane potential instead, preserving gradient flow.

Membrane Shortcut Block (MS-ResNet)

import numpy as np
from sc_neurocore.residual import MembraneShortcutBlock

block = MembraneShortcutBlock(n_features=64, threshold=1.0, tau_mem=10.0)
x = (np.random.rand(64) > 0.5).astype(float)
spikes = block.forward(x)

SEW Block

spike(W@x) + x instead of spike(W@x + x):

from sc_neurocore.residual import SEWBlock

block = SEWBlock(n_features=64, threshold=1.0)
spikes = block.forward(x)

Deep SNN Stack

from sc_neurocore.residual import DeepSNNStack

model = DeepSNNStack(n_features=64, n_blocks=20, block_type="ms")
print(f"Depth: {model.depth} layers")  # 40
output = model.forward(x)
| Block | Residual Path | Reference |
| --- | --- | --- |
| MembraneShortcutBlock | Input → membrane potential | Hu 2024 (TNNLS) |
| SEWBlock | spike(W@x) + x | Fang 2021 (NeurIPS) |

MS-ResNet: 482-layer SNN on CIFAR-10 — among the deepest directly trained SNNs published to date.

API Reference

sc_neurocore.residual.blocks

SNN residual blocks enabling 400+ layer deep spiking networks.

MembraneShortcutBlock: MS-ResNet (Hu 2024, TNNLS). Bypasses inter-block LIF neuron. Block dynamical isometry ensures gradient norm equality.

SEWBlock: activation-before-addition (Fang 2021, NeurIPS).

Reference: MS-ResNet trained 482-layer SNN on CIFAR-10

MembraneShortcutBlock

MS-ResNet residual block with membrane shortcut.

Skips the inter-block LIF neuron. Residual connection adds directly to membrane potential, not to spikes.

Parameters

n_features : int threshold : float tau_mem : float

Source code in src/sc_neurocore/residual/blocks.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class MembraneShortcutBlock:
    """MS-ResNet residual block with membrane shortcut (Hu 2024, TNNLS).

    Skips the inter-block LIF neuron: the residual connection adds the
    block input directly to the output neuron's membrane potential,
    not to its spike output, so the identity path is not clipped by the
    binary activation.

    Parameters
    ----------
    n_features : int
        Input/output dimensionality (both weight matrices are square).
    threshold : float
        Firing threshold for both LIF stages.
    tau_mem : float
        Membrane time constant; the decay factor is exp(-1 / tau_mem).
    seed : int
        Seed for the He-style weight initialization.
    """

    def __init__(
        self, n_features: int, threshold: float = 1.0, tau_mem: float = 10.0, seed: int = 42
    ):
        self.n_features = n_features
        self.threshold = threshold
        self.tau_mem = tau_mem
        rng = np.random.RandomState(seed)
        # He initialization: std = sqrt(2 / fan_in).
        scale = np.sqrt(2.0 / n_features)
        self.W1 = rng.randn(n_features, n_features) * scale
        self.W2 = rng.randn(n_features, n_features) * scale
        # Persistent membrane potential of the output LIF stage.
        self._v = np.zeros(n_features)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Forward pass: x -> W1 -> LIF -> W2 -> add residual -> LIF -> spikes."""
        alpha = np.exp(-1.0 / self.tau_mem)

        # First transform.
        h = self.W1 @ x
        # Hidden LIF stage is stateless (its membrane starts from zero on
        # every call), so the update reduces to the input charge alone.
        # The original `alpha * np.zeros(...)` term was dead code (always
        # zero) and has been removed; the value is unchanged.
        v1 = (1 - alpha) * h
        s1 = (v1 >= self.threshold).astype(np.float64)

        # Second transform.
        h2 = self.W2 @ s1

        # Membrane shortcut: the input x is added to the membrane potential
        # (together with h2), not to the spike output.
        self._v = alpha * self._v + (1 - alpha) * (h2 + x)
        spikes = (self._v >= self.threshold).astype(np.float64)
        # Soft reset: subtract the threshold from neurons that fired.
        self._v -= spikes * self.threshold
        return spikes

    def reset(self):
        """Zero the persistent membrane potential of the output stage."""
        self._v = np.zeros(self.n_features)

forward(x)

Forward pass: x -> W1 -> LIF -> W2 -> add residual -> LIF -> spikes.

Source code in src/sc_neurocore/residual/blocks.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def forward(self, x: np.ndarray) -> np.ndarray:
    """Forward pass: x -> W1 -> LIF -> W2 -> add residual -> LIF -> spikes."""
    alpha = np.exp(-1.0 / self.tau_mem)

    # First transform.
    h = self.W1 @ x
    # Hidden LIF stage is stateless (membrane starts from zero each call);
    # the original decayed-zero term `alpha * np.zeros(...)` was dead code
    # and has been dropped — the value is unchanged.
    v1 = (1 - alpha) * h
    s1 = (v1 >= self.threshold).astype(np.float64)

    # Second transform.
    h2 = self.W2 @ s1

    # Membrane shortcut: add the block input to the membrane potential
    # (not to the spike output), then fire and soft-reset.
    self._v = alpha * self._v + (1 - alpha) * (h2 + x)
    spikes = (self._v >= self.threshold).astype(np.float64)
    self._v -= spikes * self.threshold
    return spikes

SEWBlock

SEW-ResNet block: activation-before-addition.

spike(W@x) + x instead of spike(W@x + x). Prevents identity mapping issues in spiking residual networks.

Parameters

n_features : int threshold : float

Source code in src/sc_neurocore/residual/blocks.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class SEWBlock:
    """SEW-ResNet block: activation-before-addition (Fang 2021, NeurIPS).

    Computes spike(W@x) + x rather than spike(W@x + x), which keeps the
    identity path outside the binary activation of the spiking residual
    network.

    Parameters
    ----------
    n_features : int
    threshold : float
    """

    def __init__(self, n_features: int, threshold: float = 1.0, seed: int = 42):
        self.n_features = n_features
        self.threshold = threshold
        generator = np.random.RandomState(seed)
        # He-style initialization: std = sqrt(2 / fan_in).
        self.W = generator.randn(n_features, n_features) * np.sqrt(2.0 / n_features)
        self._v = np.zeros(n_features)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Forward: spike(W@x) + x (element-wise, clamped to [0,1])."""
        # Integrate the projected input into the (non-leaky) membrane.
        self._v = self._v + self.W @ x
        fired = np.where(self._v >= self.threshold, 1.0, 0.0)
        # Soft reset: subtract the threshold from neurons that fired.
        self._v = self._v - fired * self.threshold
        # ADD-style shortcut, clamped so the result stays a valid spike train.
        return np.clip(fired + x, 0, 1)

    def reset(self):  # pragma: no cover
        self._v = np.zeros(self.n_features)

forward(x)

Forward: spike(W@x) + x (element-wise, clamped to [0,1]).

Source code in src/sc_neurocore/residual/blocks.py
91
92
93
94
95
96
97
def forward(self, x: np.ndarray) -> np.ndarray:
    """Forward: spike(W@x) + x (element-wise, clamped to [0,1])."""
    # Integrate the projected input into the (non-leaky) membrane.
    self._v = self._v + self.W @ x
    fired = np.where(self._v >= self.threshold, 1.0, 0.0)
    # Soft reset, then the ADD-style shortcut clamped to a valid spike train.
    self._v = self._v - fired * self.threshold
    return np.clip(fired + x, 0, 1)

DeepSNNStack

Stack of residual blocks for building deep SNNs.

Parameters

n_features : int n_blocks : int block_type : str 'ms' for MembraneShortcut, 'sew' for SEW.

Source code in src/sc_neurocore/residual/blocks.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class DeepSNNStack:
    """Stack of residual blocks for building deep SNNs.

    Parameters
    ----------
    n_features : int
        Width of every block (blocks are square, so widths must match).
    n_blocks : int
        Number of residual blocks; reported depth is 2 * n_blocks.
    block_type : str
        'ms' for MembraneShortcutBlock, 'sew' for SEWBlock.

    Raises
    ------
    ValueError
        If block_type is not 'ms' or 'sew'.  (Previously any unknown
        value silently fell through to SEW blocks, masking typos.)
    """

    def __init__(self, n_features: int, n_blocks: int = 10, block_type: str = "ms"):
        if block_type not in ("ms", "sew"):
            raise ValueError(f"block_type must be 'ms' or 'sew', got {block_type!r}")
        self.blocks = []
        for i in range(n_blocks):
            # Distinct seed per block so the stack is not one repeated layer.
            if block_type == "ms":
                self.blocks.append(MembraneShortcutBlock(n_features, seed=42 + i))
            else:
                self.blocks.append(SEWBlock(n_features, seed=42 + i))

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Run x through every block in sequence; return the final spikes."""
        h = x
        for block in self.blocks:
            h = block.forward(h)
        return h

    def reset(self):  # pragma: no cover
        """Reset the membrane state of every block."""
        for block in self.blocks:
            block.reset()

    @property
    def n_blocks(self) -> int:  # pragma: no cover
        """Number of residual blocks in the stack."""
        return len(self.blocks)

    @property
    def depth(self) -> int:
        """Depth in weight layers, counted as 2 per residual block.

        NOTE(review): accurate for MS blocks (W1 and W2); SEW blocks hold a
        single weight matrix, so this over-counts for block_type='sew' —
        confirm intended semantics.
        """
        return len(self.blocks) * 2  # 2 weight layers per block