Spike Normalization
Five SNN-specific batch-normalization variants that account for temporal dynamics and for interaction with the firing threshold.
| Class | Technique | Reference |
|---|---|---|
| ThresholdDependentBN | BN scaled by firing threshold | Zheng 2021 |
| PerTimestepBN | Separate statistics per timestep | Kim & Panda 2021 |
| TemporalEffectiveBN | Per-timestep scaling factor | Duan 2022 (NeurIPS) |
| MembranePotentialBN | BN on membrane, folds into threshold at inference | Guo 2023 (ICCV) |
| TemporalAccumulatedBN | Normalizes accumulated membrane | Jiang 2024 (ICLR) |
`MembranePotentialBN.fused_threshold()` returns per-neuron thresholds that absorb the BN transform at inference, so normalization adds zero compute overhead on hardware.
from sc_neurocore.spike_norm import (
ThresholdDependentBN, PerTimestepBN, TemporalEffectiveBN,
MembranePotentialBN, TemporalAccumulatedBN,
)
See Tutorial 75: Spike Normalization for usage examples.
sc_neurocore.spike_norm.normalizers
Five SNN normalization variants, packaged as reusable modules (no framework ships these as such).
tdBN: threshold-dependent BN (Zheng 2021)
BNTT: per-timestep BN (Kim & Panda 2021)
TEBN: temporal effective BN (Duan 2022, NeurIPS)
MPBN: membrane potential BN with inference re-parameterization (Guo 2023, ICCV)
TAB: temporal accumulated BN (Jiang 2024, ICLR)
ThresholdDependentBN
dataclass
tdBN: incorporates firing threshold into normalization.
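$$\mathrm{tdBN}(x) = \gamma \, V_{th} \, \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where $\mu$ and $\sigma^2$ are the batch mean and variance, and the normalized input is scaled by the firing threshold $V_{th}$.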
Parameters
n_features : int
threshold : float
momentum : float
Source code in src/sc_neurocore/spike_norm/normalizers.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@dataclass
class ThresholdDependentBN:
    """tdBN: threshold-dependent batch normalization (Zheng et al., 2021).

    Normalizes the input and scales it by the firing threshold::

        y = gamma * V_th * (x - mean) / sqrt(var + eps) + beta

    In training mode ``mean``/``var`` are batch statistics, folded into
    exponential running averages; in inference mode the running averages
    are used instead.

    Parameters
    ----------
    n_features : int
        Size of the last (feature) axis.
    threshold : float
        Firing threshold ``V_th`` that scales the normalized input.
    momentum : float
        Weight of the current batch in the running-statistics update.
    eps : float
        Added to the variance for numerical stability.
    """

    n_features: int
    threshold: float = 1.0
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        """Normalize ``x`` and scale it by the firing threshold.

        For inputs with two or more axes, statistics are pooled over every
        axis except the last one, so ``(batch, n_features)`` and
        ``(T, batch, n_features)`` both yield per-feature statistics of
        shape ``(n_features,)`` (tdBN pools over time and batch). A 1-D
        input is reduced over axis 0, as before.
        """
        if training:
            # Pool over all leading axes so running stats stay per-feature.
            axes = tuple(range(x.ndim - 1)) if x.ndim > 1 else 0
            mean = x.mean(axis=axes)
            var = x.var(axis=axes)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_norm * self.threshold + self.beta
|
PerTimestepBN
dataclass
BNTT: separate BN statistics per timestep.
Each timestep t has its own mean_t, var_t, gamma_t, beta_t; timesteps t ≥ T reuse the statistics and affine parameters of timestep T − 1.
Parameters
n_features : int
T : int
Number of timesteps.
Source code in src/sc_neurocore/spike_norm/normalizers.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@dataclass
class PerTimestepBN:
    """BNTT: separate BN statistics per timestep (Kim & Panda, 2021).

    Each timestep ``t`` has its own ``mean_t``, ``var_t``, ``gamma_t`` and
    ``beta_t``. Timesteps ``t >= T`` reuse the slot of timestep ``T - 1``.

    Parameters
    ----------
    n_features : int
        Size of the feature axis.
    T : int
        Number of timesteps (number of independent statistic slots).
    eps : float
        Added to the variance for numerical stability.
    momentum : float
        Weight of the current batch in the running-statistics update
        (default 0.1, matching the sibling normalizers).
    """

    n_features: int
    T: int
    eps: float = 1e-5
    momentum: float = 0.1

    def __post_init__(self):
        self.gammas = [np.ones(self.n_features) for _ in range(self.T)]
        self.betas = [np.zeros(self.n_features) for _ in range(self.T)]
        self.running_means = [np.zeros(self.n_features) for _ in range(self.T)]
        self.running_vars = [np.ones(self.n_features) for _ in range(self.T)]

    def forward(self, x: np.ndarray, t: int, training: bool = True) -> np.ndarray:
        """Normalize ``x`` (batch along axis 0) with timestep ``t``'s statistics."""
        # Timesteps beyond the configured horizon share the last slot.
        t_idx = min(t, self.T - 1)
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            keep = 1.0 - self.momentum
            self.running_means[t_idx] = keep * self.running_means[t_idx] + self.momentum * mean
            self.running_vars[t_idx] = keep * self.running_vars[t_idx] + self.momentum * var
        else:  # pragma: no cover
            mean = self.running_means[t_idx]
            var = self.running_vars[t_idx]
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gammas[t_idx] * x_norm + self.betas[t_idx]
|
TemporalEffectiveBN
dataclass
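TEBN: rescales presynaptic inputs per timestep.
Applies BN with statistics shared across all timesteps, then multiplies the result by a per-timestep scaling factor $\lambda_t$: $y_t = \lambda_t(\gamma \hat{x} + \beta)$.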
Parameters
n_features : int
T : int
Source code in src/sc_neurocore/spike_norm/normalizers.py
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
@dataclass
class TemporalEffectiveBN:
    """TEBN: temporal effective batch normalization (Duan et al., 2022).

    Applies BN with statistics shared across all timesteps, then rescales
    the result by a per-timestep factor::

        y_t = lambda_t * (gamma * x_hat + beta)

    Timesteps ``t >= T`` reuse ``lambda_{T-1}``.

    Parameters
    ----------
    n_features : int
        Size of the feature axis.
    T : int
        Number of timesteps (length of ``lambdas``).
    eps : float
        Added to the variance for numerical stability.
    momentum : float
        Weight of the current batch in the running-statistics update
        (default 0.1, matching the sibling normalizers).
    """

    n_features: int
    T: int
    eps: float = 1e-5
    momentum: float = 0.1

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.lambdas = np.ones(self.T)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(self, x: np.ndarray, t: int, training: bool = True) -> np.ndarray:
        """Normalize ``x`` (batch along axis 0) and scale by ``lambdas[t]``."""
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            keep = 1.0 - self.momentum
            self.running_mean = keep * self.running_mean + self.momentum * mean
            self.running_var = keep * self.running_var + self.momentum * var
        else:  # pragma: no cover
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        # Timesteps beyond the configured horizon share the last factor.
        t_idx = min(t, self.T - 1)
        return self.lambdas[t_idx] * (self.gamma * x_norm + self.beta)
|
MembranePotentialBN
dataclass
MPBN: BN on membrane potential before spike function.
At inference: fold BN into threshold (zero overhead).
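$$V_{th}' = (V_{th} - \beta)\,\frac{\sqrt{\sigma^2 + \epsilon}}{\gamma} + \mu$$ (a spike fires when the membrane is $\ge V_{th}'$ for $\gamma > 0$; the inequality flips for $\gamma < 0$)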
Parameters
n_features : int
threshold : float
Source code in src/sc_neurocore/spike_norm/normalizers.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
@dataclass
class MembranePotentialBN:
    """MPBN: batch normalization on the membrane potential (Guo et al., 2023).

    During training, BN is applied to the membrane potential right before the
    spike function. At inference the affine BN transform is folded into a
    per-neuron threshold (zero overhead)::

        new_threshold = (V_th - beta) * sqrt(var + eps) / gamma + mean

    Parameters
    ----------
    n_features : int
        Number of neurons (size of the feature axis).
    threshold : float
        Firing threshold ``V_th`` applied to the normalized membrane.
    momentum : float
        Weight of the current batch in the running-statistics update.
    eps : float
        Added to the variance for numerical stability.
    """

    n_features: int
    threshold: float = 1.0
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(self, membrane: np.ndarray, training: bool = True) -> np.ndarray:
        """Normalize the membrane potential in training mode.

        Batched input is reduced over axis 0. A 1-D membrane is treated as a
        single sample (mean = itself, variance 0), so it normalizes to
        ``beta``. In both cases the running statistics are updated.

        In inference mode the membrane is returned unchanged: compare it
        against :meth:`fused_threshold` instead of ``threshold``.
        """
        if training:
            mean = membrane.mean(axis=0) if membrane.ndim > 1 else membrane
            var = membrane.var(axis=0) if membrane.ndim > 1 else np.zeros_like(membrane)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            norm = (membrane - mean) / np.sqrt(var + self.eps)
            return self.gamma * norm + self.beta
        # Inference: BN is absorbed by fused_threshold(), so pass through.
        return membrane

    def fused_threshold(self) -> np.ndarray:
        """Compute the per-neuron threshold that absorbs BN at inference.

        Returns ndarray of shape (n_features,). Use it as a per-neuron
        threshold instead of applying BN at inference (zero overhead).

        For ``gamma > 0`` a neuron fires when its membrane is ``>=`` the
        returned value. For ``gamma < 0`` BN is decreasing in the membrane,
        so the returned value is still the crossing point, but the neuron
        fires when its membrane is ``<=`` it. ``|gamma|`` is clamped to at
        least 1e-8, keeping its sign, to avoid division by zero.
        """
        gamma = np.asarray(self.gamma, dtype=float)
        # Keep gamma's sign: forcing a negative gamma to a tiny positive
        # value would produce a meaningless, huge threshold.
        safe_gamma = np.where(gamma < 0, np.minimum(gamma, -1e-8), np.maximum(gamma, 1e-8))
        return (self.threshold - self.beta) * np.sqrt(self.running_var + self.eps) / safe_gamma + self.running_mean
|
fused_threshold()
Compute per-neuron threshold that absorbs BN at inference.
Returns ndarray of shape (n_features,) — use as per-neuron threshold
instead of applying BN at inference (zero overhead).
Source code in src/sc_neurocore/spike_norm/normalizers.py
170
171
172
173
174
175
176
177
def fused_threshold(self) -> np.ndarray:
    """Compute the per-neuron threshold that absorbs BN at inference.

    Returns ndarray of shape (n_features,). Use it as a per-neuron threshold
    instead of applying BN at inference (zero overhead).

    For ``gamma > 0`` a neuron fires when its membrane is ``>=`` the returned
    value. For ``gamma < 0`` BN is decreasing in the membrane, so the returned
    value is still the crossing point, but the neuron fires when its membrane
    is ``<=`` it. ``|gamma|`` is clamped to at least 1e-8, keeping its sign,
    to avoid division by zero.
    """
    gamma = np.asarray(self.gamma, dtype=float)
    # Keep gamma's sign: forcing a negative gamma to a tiny positive value
    # would produce a meaningless, huge threshold.
    safe_gamma = np.where(gamma < 0, np.minimum(gamma, -1e-8), np.maximum(gamma, 1e-8))
    return (self.threshold - self.beta) * np.sqrt(self.running_var + self.eps) / safe_gamma + self.running_mean
|
TemporalAccumulatedBN
dataclass
TAB: normalizes accumulated membrane potential.
Tracks running accumulated potential across timesteps.
Addresses Temporal Covariate Shift directly.
Parameters
n_features : int
Source code in src/sc_neurocore/spike_norm/normalizers.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
@dataclass
class TemporalAccumulatedBN:
    """TAB: temporal accumulated batch normalization (Jiang et al., 2024).

    Normalizes each timestep's input with the mean accumulated over all
    timesteps seen since the last :meth:`reset`. This is the average of the
    per-step means over steps ``1..t``, not their sum, so the statistic stays
    on the same scale as a single step. This targets Temporal Covariate
    Shift directly. The variance is estimated from the current input only.

    Parameters
    ----------
    n_features : int
        Size of the feature axis.
    momentum : float
        Weight of the current statistics in the running-statistics update.
    eps : float
        Added to the variance for numerical stability.
    """

    n_features: int
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self):
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)
        self._accumulated = np.zeros(self.n_features)
        self._steps = 0

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        """Normalize one timestep ``x`` (batch along axis 0, or a 1-D sample).

        The accumulated sum and step count are updated in both modes. Training
        uses the accumulated average as the mean; inference uses the running
        statistics.
        """
        self._accumulated += x.mean(axis=0) if x.ndim > 1 else x
        self._steps += 1
        if training:
            # Average over steps 1..t; the raw sum would grow linearly with t.
            mean = self._accumulated / self._steps
            # Variance estimated from current input
            var = x.var(axis=0) if x.ndim > 1 else np.zeros_like(x)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:  # pragma: no cover
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

    def reset(self):
        """Clear the accumulated statistics (call at the start of a sequence)."""
        self._accumulated = np.zeros(self.n_features)
        self._steps = 0
|