
Tutorial 75: Spike Normalization

Five SNN-specific batch normalization variants. Standard BN fails in SNNs because spike activations are binary and their statistics shift across timesteps. These normalizers handle temporal dynamics, threshold interaction, and inference-time re-parameterization for zero-overhead deployment.

No other SNN library ships these as reusable modules.

The Problem

Standard batch normalization assumes continuous activations with stable statistics. SNNs violate both assumptions: activations are binary spikes, and the distribution changes at every timestep (temporal covariate shift). Naively applying BN to SNNs degrades accuracy by 5-15% on CIFAR-10 (Zheng 2021).
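The shift is easy to see in simulation. The sketch below is plain NumPy, independent of the library, and uses made-up drift parameters purely for illustration: the input distribution changes with t, so the per-timestep statistics diverge from the single set of statistics a shared BN layer would use.

```python
import numpy as np

# Toy demonstration of temporal covariate shift. All constants here are
# illustrative, not taken from any paper or from the library.
rng = np.random.RandomState(0)
T, batch, features = 8, 256, 64

# Simulate pre-activations whose mean and scale drift over time, as
# membrane potentials accumulate in an SNN.
x = np.stack([(t * 0.3) + (1.0 + 0.2 * t) * rng.randn(batch, features)
              for t in range(T)])           # shape (T, batch, features)

global_mean = x.mean()                      # what a shared BN effectively sees
per_t_mean = x.mean(axis=(1, 2))            # what each timestep actually has

print(per_t_mean)                           # drifts upward with t
print(abs(per_t_mean - global_mean).max())  # shared stats miss every timestep
```

A single (mean, var) pair cannot match all T distributions at once, which is exactly the gap the per-timestep and accumulated variants below close.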

Available Normalizers

Normalizer             Key Idea                                           Reference
---------------------  -------------------------------------------------  -------------------
ThresholdDependentBN   Incorporates firing threshold into normalization   Zheng 2021
PerTimestepBN          Separate BN statistics per timestep                Kim & Panda 2021
TemporalEffectiveBN    Per-timestep scaling factor on top of BN           Duan 2022 (NeurIPS)
MembranePotentialBN    BN on membrane, folds into threshold at inference  Guo 2023 (ICCV)
TemporalAccumulatedBN  Normalizes accumulated membrane across time        Jiang 2024 (ICLR)

Quick Start

import numpy as np
from sc_neurocore.spike_norm import (
    ThresholdDependentBN,
    PerTimestepBN,
    TemporalEffectiveBN,
    MembranePotentialBN,
    TemporalAccumulatedBN,
)

# Simulated batch of presynaptic currents: (batch=32, features=64)
rng = np.random.RandomState(42)
x = rng.randn(32, 64)

# tdBN: threshold-aware normalization
tdbn = ThresholdDependentBN(n_features=64, threshold=1.0)
x_norm = tdbn.forward(x, training=True)

# BNTT: different statistics per timestep
bntt = PerTimestepBN(n_features=64, T=10)
for t in range(10):
    x_t = rng.randn(32, 64)
    out_t = bntt.forward(x_t, t=t, training=True)

# MPBN: fuse into threshold at inference (zero overhead)
mpbn = MembranePotentialBN(n_features=64, threshold=1.0)
for _ in range(100):
    mpbn.forward(rng.randn(32, 64), training=True)
hw_thresholds = mpbn.fused_threshold()  # shape (64,)
# No BN computation at inference — threshold absorbs it

MPBN: Zero-Overhead Inference

MembranePotentialBN is recommended for hardware deployment. At inference, BN parameters fold into a per-neuron threshold:

new_threshold[i] = (V_th - beta[i]) * sqrt(var[i] + eps) / gamma[i] + mean[i]

Identical behavior to training BN with zero compute overhead.
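The equivalence can be checked numerically. The following standalone NumPy sketch reimplements the two formulas rather than calling the library, with arbitrary illustrative BN parameters: thresholding the normalized membrane at V_th produces the same spikes as thresholding the raw membrane at the fused value, provided gamma > 0.

```python
import numpy as np

rng = np.random.RandomState(0)
n = 64
v_th = 1.0
eps = 1e-5

# Frozen BN parameters after training (illustrative values).
gamma = 0.5 + rng.rand(n)   # keep gamma > 0 so the inequality fold is valid
beta = 0.1 * rng.randn(n)
mean = 0.2 * rng.randn(n)
var = 0.5 + rng.rand(n)

membrane = rng.randn(256, n)

# Path 1: apply BN, then threshold at V_th.
bn = gamma * (membrane - mean) / np.sqrt(var + eps) + beta
spikes_bn = bn >= v_th

# Path 2: fold BN into a per-neuron threshold, skip BN entirely.
fused = (v_th - beta) * np.sqrt(var + eps) / gamma + mean
spikes_fused = membrane >= fused

print(np.array_equal(spikes_bn, spikes_fused))  # True
```

The fold is a monotone rewrite of the spiking condition, so it is exact up to floating-point rounding at the decision boundary; negative gamma would flip the inequality, which is why the library clips gamma when fusing.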

API Reference

sc_neurocore.spike_norm.normalizers

Five SNN normalization variants. No framework ships these as reusable modules.

- tdBN: threshold-dependent BN (Zheng 2021)
- BNTT: per-timestep BN (Kim & Panda 2021)
- TEBN: temporal effective BN (Duan 2022, NeurIPS)
- MPBN: membrane potential BN with inference re-parameterization (Guo 2023, ICCV)
- TAB: temporal accumulated BN (Jiang 2024, ICLR)

ThresholdDependentBN dataclass

tdBN: incorporates firing threshold into normalization.

BN(x) = gamma * (x - mean) / sqrt(var + eps) + beta

where mean/var are computed across the batch and the normalized output is scaled by the firing threshold V_threshold.

Parameters

n_features : int
threshold : float
momentum : float

Source code in src/sc_neurocore/spike_norm/normalizers.py
@dataclass
class ThresholdDependentBN:
    """tdBN: incorporates firing threshold into normalization.

    BN(x) = gamma * (x - mean) / sqrt(var + eps) + beta
    where mean/var are computed across batch, adjusted by V_threshold.

    Parameters
    ----------
    n_features : int
    threshold : float
    momentum : float
    """

    n_features: int
    threshold: float = 1.0
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_norm * self.threshold + self.beta

PerTimestepBN dataclass

BNTT: separate BN statistics per timestep.

Each timestep t has its own mean_t, var_t, gamma_t, beta_t.

Parameters

n_features : int
T : int
    Number of timesteps.

Source code in src/sc_neurocore/spike_norm/normalizers.py
@dataclass
class PerTimestepBN:
    """BNTT: separate BN statistics per timestep.

    Each timestep t has its own mean_t, var_t, gamma_t, beta_t.

    Parameters
    ----------
    n_features : int
    T : int
        Number of timesteps.
    """

    n_features: int
    T: int
    eps: float = 1e-5

    def __post_init__(self):
        self.gammas = [np.ones(self.n_features) for _ in range(self.T)]
        self.betas = [np.zeros(self.n_features) for _ in range(self.T)]
        self.running_means = [np.zeros(self.n_features) for _ in range(self.T)]
        self.running_vars = [np.ones(self.n_features) for _ in range(self.T)]

    def forward(self, x: np.ndarray, t: int, training: bool = True) -> np.ndarray:
        t_idx = min(t, self.T - 1)
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            self.running_means[t_idx] = 0.9 * self.running_means[t_idx] + 0.1 * mean
            self.running_vars[t_idx] = 0.9 * self.running_vars[t_idx] + 0.1 * var
        else:  # pragma: no cover
            mean = self.running_means[t_idx]
            var = self.running_vars[t_idx]
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gammas[t_idx] * x_norm + self.betas[t_idx]

TemporalEffectiveBN dataclass

TEBN: rescales presynaptic inputs per timestep.

Applies BN then per-timestep scaling factor lambda_t.

Parameters

n_features : int
T : int

Source code in src/sc_neurocore/spike_norm/normalizers.py
@dataclass
class TemporalEffectiveBN:
    """TEBN: rescales presynaptic inputs per timestep.

    Applies BN then per-timestep scaling factor lambda_t.

    Parameters
    ----------
    n_features : int
    T : int
    """

    n_features: int
    T: int
    eps: float = 1e-5

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.lambdas = np.ones(self.T)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(self, x: np.ndarray, t: int, training: bool = True) -> np.ndarray:
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            self.running_mean = 0.9 * self.running_mean + 0.1 * mean
            self.running_var = 0.9 * self.running_var + 0.1 * var
        else:  # pragma: no cover
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        t_idx = min(t, self.T - 1)
        return self.lambdas[t_idx] * (self.gamma * x_norm + self.beta)
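The per-timestep scaling is the core of TEBN: one shared BN is followed by a learned scalar lambda_t that re-weights each timestep. A standalone NumPy sketch of that composition, using illustrative lambda values rather than the library class:

```python
import numpy as np

rng = np.random.RandomState(0)
T, batch, features = 4, 128, 8
eps = 1e-5

gamma = np.ones(features)
beta = np.zeros(features)
# Learned per-timestep scales (illustrative values, shape (T,)).
lambdas = np.array([1.2, 1.0, 0.9, 0.8])

outs = []
for t in range(T):
    x_t = rng.randn(batch, features)
    # Shared batch normalization, identical at every timestep.
    x_norm = (x_t - x_t.mean(axis=0)) / np.sqrt(x_t.var(axis=0) + eps)
    # Per-timestep rescaling: output variance becomes ~lambda_t**2.
    outs.append(lambdas[t] * (gamma * x_norm + beta))

print([float(o.var()) for o in outs])  # roughly lambdas**2
```

Because BN forces each timestep to unit variance first, lambda_t gives the network one degree of freedom per timestep to restore a useful temporal profile.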

MembranePotentialBN dataclass

MPBN: BN on membrane potential before spike function.

At inference: fold BN into threshold (zero overhead). new_threshold = (V_th - beta) * sqrt(var + eps) / gamma + mean

Parameters

n_features : int
threshold : float

Source code in src/sc_neurocore/spike_norm/normalizers.py
@dataclass
class MembranePotentialBN:
    """MPBN: BN on membrane potential before spike function.

    At inference: fold BN into threshold (zero overhead).
    new_threshold = (V_th - beta) * sqrt(var + eps) / gamma + mean

    Parameters
    ----------
    n_features : int
    threshold : float
    """

    n_features: int
    threshold: float = 1.0
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(self, membrane: np.ndarray, training: bool = True) -> np.ndarray:
        if training:
            mean = membrane.mean(axis=0) if membrane.ndim > 1 else membrane
            var = membrane.var(axis=0) if membrane.ndim > 1 else np.zeros_like(membrane)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            norm = (membrane - mean) / np.sqrt(var + self.eps)
            return self.gamma * norm + self.beta
        return membrane

    def fused_threshold(self) -> np.ndarray:
        """Compute per-neuron threshold that absorbs BN at inference.

        Returns ndarray of shape (n_features,) — use as per-neuron threshold
        instead of applying BN at inference (zero overhead).
        """
        return (self.threshold - self.beta) * np.sqrt(self.running_var + self.eps) / np.clip(
            self.gamma, 1e-8, None
        ) + self.running_mean

fused_threshold()

Compute per-neuron threshold that absorbs BN at inference.

Returns ndarray of shape (n_features,) — use as per-neuron threshold instead of applying BN at inference (zero overhead).

Source code in src/sc_neurocore/spike_norm/normalizers.py
def fused_threshold(self) -> np.ndarray:
    """Compute per-neuron threshold that absorbs BN at inference.

    Returns ndarray of shape (n_features,) — use as per-neuron threshold
    instead of applying BN at inference (zero overhead).
    """
    return (self.threshold - self.beta) * np.sqrt(self.running_var + self.eps) / np.clip(
        self.gamma, 1e-8, None
    ) + self.running_mean

TemporalAccumulatedBN dataclass

TAB: normalizes accumulated membrane potential.

Tracks running accumulated potential across timesteps. Addresses Temporal Covariate Shift directly.

Parameters

n_features : int

Source code in src/sc_neurocore/spike_norm/normalizers.py
@dataclass
class TemporalAccumulatedBN:
    """TAB: normalizes accumulated membrane potential.

    Tracks running accumulated potential across timesteps.
    Addresses Temporal Covariate Shift directly.

    Parameters
    ----------
    n_features : int
    """

    n_features: int
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)
        self._accumulated = np.zeros(self.n_features)

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        self._accumulated += x.mean(axis=0) if x.ndim > 1 else x
        if training:
            mean = self._accumulated
            # Variance estimated from current input
            var = x.var(axis=0) if x.ndim > 1 else np.zeros_like(x)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:  # pragma: no cover
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

    def reset(self):
        self._accumulated = np.zeros(self.n_features)
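Because TAB carries state across timesteps, reset() must be called between input sequences or the accumulator from one sample leaks into the next. A simplified standalone sketch of that lifecycle (a tiny illustrative reimplementation, not the library class):

```python
import numpy as np

class TinyTAB:
    """Minimal accumulated-mean normalizer, for illustration only."""

    def __init__(self, n_features, eps=1e-5):
        self.eps = eps
        self.accumulated = np.zeros(n_features)

    def forward(self, x):
        # Accumulate the batch-mean membrane contribution at each step.
        self.accumulated += x.mean(axis=0)
        return (x - self.accumulated) / np.sqrt(x.var(axis=0) + self.eps)

    def reset(self):
        # Clear accumulated state between input sequences.
        self.accumulated[:] = 0.0

rng = np.random.RandomState(1)
tab = TinyTAB(n_features=8)

for t in range(5):                      # one sequence of T=5 timesteps
    out = tab.forward(rng.randn(32, 8) + 0.5)

drift = tab.accumulated.mean()          # grew over the sequence
tab.reset()                             # fresh state before the next sample
```

Forgetting the reset turns the accumulated mean into a cross-sample running drift, which silently shifts every normalized output.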