Spike Normalization

Five SNN-specific batch normalization variants that handle temporal dynamics and the interaction between normalization and the firing threshold.

| Class | Technique | Reference |
| --- | --- | --- |
| ThresholdDependentBN | BN scaled by firing threshold | Zheng 2021 |
| PerTimestepBN | Separate statistics per timestep | Kim & Panda 2021 |
| TemporalEffectiveBN | Per-timestep scaling factor | Duan 2022 (NeurIPS) |
| MembranePotentialBN | BN on membrane, folds into threshold at inference | Guo 2023 (ICCV) |
| TemporalAccumulatedBN | Normalizes accumulated membrane | Jiang 2024 (ICLR) |

MembranePotentialBN.fused_threshold() returns per-neuron thresholds that absorb BN at inference — zero compute overhead on hardware.

Python
from sc_neurocore.spike_norm import (
    ThresholdDependentBN, PerTimestepBN, TemporalEffectiveBN,
    MembranePotentialBN, TemporalAccumulatedBN,
)

See Tutorial 75: Spike Normalization for usage examples.

sc_neurocore.spike_norm.normalizers

Five SNN normalization variants. These are described in the literature but are not generally shipped as reusable modules by existing SNN frameworks.

tdBN: threshold-dependent BN (Zheng 2021)
BNTT: per-timestep BN (Kim & Panda 2021)
TEBN: temporal effective BN (Duan 2022, NeurIPS)
MPBN: membrane potential BN with inference re-parameterization (Guo 2023, ICCV)
TAB: temporal accumulated BN (Jiang 2024, ICLR)

ThresholdDependentBN dataclass

tdBN: incorporates firing threshold into normalization.

tdBN(x) = V_th * gamma * (x - mean) / sqrt(var + eps) + beta, where mean and var are computed across the batch and the normalized value is rescaled by the firing threshold V_th.

Parameters

n_features : int
threshold : float
momentum : float
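
A minimal usage sketch (the batch size and feature count are illustrative assumptions): normalize a batch of pre-activations during training, then switch to the running statistics at inference.

Python
import numpy as np
from sc_neurocore.spike_norm import ThresholdDependentBN

tdbn = ThresholdDependentBN(n_features=16, threshold=1.0, momentum=0.1)
x = np.random.randn(32, 16)               # (batch, features) pre-activations

y_train = tdbn.forward(x, training=True)  # batch stats; running stats updated
y_eval = tdbn.forward(x, training=False)  # running stats; internal state untouched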

Source code in src/sc_neurocore/spike_norm/normalizers.py
Python
@dataclass
class ThresholdDependentBN:
    """tdBN: incorporates firing threshold into normalization.

    tdBN(x) = V_th * gamma * (x - mean) / sqrt(var + eps) + beta
    where mean/var are computed across the batch; the normalized
    value is rescaled by the firing threshold V_th.

    Parameters
    ----------
    n_features : int
    threshold : float
    momentum : float
    """

    n_features: int
    threshold: float = 1.0
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self) -> None:
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(self, x: np.ndarray[Any, Any], training: bool = True) -> np.ndarray[Any, Any]:
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        result: np.ndarray[Any, Any] = self.gamma * x_norm * self.threshold + self.beta
        return result

PerTimestepBN dataclass

BNTT: separate BN statistics per timestep.

Each timestep t has its own mean_t, var_t, gamma_t, beta_t.

Parameters

n_features : int
T : int
    Number of timesteps.
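
A sketch of the per-timestep loop (shapes are illustrative assumptions): each timestep t gets its own statistics and affine parameters, and t is clamped to T - 1 for longer sequences.

Python
import numpy as np
from sc_neurocore.spike_norm import PerTimestepBN

bntt = PerTimestepBN(n_features=16, T=4)
for t in range(4):
    x_t = np.random.randn(32, 16)                # (batch, features) at timestep t
    y_t = bntt.forward(x_t, t=t, training=True)  # stats and params indexed by t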

Source code in src/sc_neurocore/spike_norm/normalizers.py
Python
@dataclass
class PerTimestepBN:
    """BNTT: separate BN statistics per timestep.

    Each timestep t has its own mean_t, var_t, gamma_t, beta_t.

    Parameters
    ----------
    n_features : int
    T : int
        Number of timesteps.
    """

    n_features: int
    T: int
    eps: float = 1e-5

    def __post_init__(self) -> None:
        self.gammas = [np.ones(self.n_features) for _ in range(self.T)]
        self.betas = [np.zeros(self.n_features) for _ in range(self.T)]
        self.running_means = [np.zeros(self.n_features) for _ in range(self.T)]
        self.running_vars = [np.ones(self.n_features) for _ in range(self.T)]

    def forward(
        self, x: np.ndarray[Any, Any], t: int, training: bool = True
    ) -> np.ndarray[Any, Any]:
        t_idx = min(t, self.T - 1)
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            self.running_means[t_idx] = 0.9 * self.running_means[t_idx] + 0.1 * mean
            self.running_vars[t_idx] = 0.9 * self.running_vars[t_idx] + 0.1 * var
        else:  # pragma: no cover
            mean = self.running_means[t_idx]
            var = self.running_vars[t_idx]
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        result: np.ndarray[Any, Any] = self.gammas[t_idx] * x_norm + self.betas[t_idx]
        return result

TemporalEffectiveBN dataclass

TEBN: rescales presynaptic inputs per timestep.

Applies BN then per-timestep scaling factor lambda_t.

Parameters

n_features : int
T : int
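
A sketch (shapes are illustrative assumptions): TEBN shares one set of BN statistics and affine parameters across timesteps, then rescales each timestep's output by its scaling factor lambdas[t].

Python
import numpy as np
from sc_neurocore.spike_norm import TemporalEffectiveBN

tebn = TemporalEffectiveBN(n_features=16, T=4)
for t in range(4):
    x_t = np.random.randn(32, 16)
    y_t = tebn.forward(x_t, t=t, training=True)  # shared stats, scaled by lambdas[t]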

Source code in src/sc_neurocore/spike_norm/normalizers.py
Python
@dataclass
class TemporalEffectiveBN:
    """TEBN: rescales presynaptic inputs per timestep.

    Applies BN then per-timestep scaling factor lambda_t.

    Parameters
    ----------
    n_features : int
    T : int
    """

    n_features: int
    T: int
    eps: float = 1e-5

    def __post_init__(self) -> None:
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.lambdas = np.ones(self.T)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(
        self, x: np.ndarray[Any, Any], t: int, training: bool = True
    ) -> np.ndarray[Any, Any]:
        if training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            self.running_mean = 0.9 * self.running_mean + 0.1 * mean
            self.running_var = 0.9 * self.running_var + 0.1 * var
        else:  # pragma: no cover
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        t_idx = min(t, self.T - 1)
        result: np.ndarray[Any, Any] = self.lambdas[t_idx] * (self.gamma * x_norm + self.beta)
        return result

MembranePotentialBN dataclass

MPBN: BN on membrane potential before spike function.

At inference, the BN affine transform folds into the firing threshold (zero overhead): the spike condition gamma * (u - mean) / sqrt(var + eps) + beta >= V_th rearranges, for gamma > 0, to u >= new_threshold, where new_threshold = (V_th - beta) * sqrt(var + eps) / gamma + mean.

Parameters

n_features : int
threshold : float
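
A sketch of the train-then-fold workflow (shapes and the final comparison are illustrative assumptions): apply BN to membrane potentials during training, then threshold raw membranes against fused_threshold() at inference. Note that forward(..., training=False) returns the membrane unchanged in this implementation, so the fused thresholds carry the whole BN effect.

Python
import numpy as np
from sc_neurocore.spike_norm import MembranePotentialBN

mpbn = MembranePotentialBN(n_features=16, threshold=1.0)
u = np.random.randn(32, 16)            # (batch, neurons) membrane potentials
_ = mpbn.forward(u, training=True)     # BN on membranes; running stats updated

theta = mpbn.fused_threshold()         # per-neuron thresholds, shape (16,)
spikes = u >= theta                    # inference: no BN compute at all

# Matches BN with running stats followed by the original threshold
# (for gamma > 0, up to floating-point rounding at the boundary):
u_bn = mpbn.gamma * (u - mpbn.running_mean) / np.sqrt(mpbn.running_var + mpbn.eps) + mpbn.beta
assert np.array_equal(u_bn >= mpbn.threshold, spikes)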

Source code in src/sc_neurocore/spike_norm/normalizers.py
Python
@dataclass
class MembranePotentialBN:
    """MPBN: BN on membrane potential before spike function.

    At inference: fold BN into threshold (zero overhead).
    new_threshold = (V_th - beta) * sqrt(var + eps) / gamma + mean

    Parameters
    ----------
    n_features : int
    threshold : float
    """

    n_features: int
    threshold: float = 1.0
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self) -> None:
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)

    def forward(
        self, membrane: np.ndarray[Any, Any], training: bool = True
    ) -> np.ndarray[Any, Any]:
        if training:
            mean = membrane.mean(axis=0) if membrane.ndim > 1 else membrane
            var = membrane.var(axis=0) if membrane.ndim > 1 else np.zeros_like(membrane)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            norm = (membrane - mean) / np.sqrt(var + self.eps)
            result: np.ndarray[Any, Any] = self.gamma * norm + self.beta
            return result
        return membrane

    def fused_threshold(self) -> np.ndarray[Any, Any]:
        """Compute per-neuron threshold that absorbs BN at inference.

        Returns ndarray of shape (n_features,) — use as per-neuron threshold
        instead of applying BN at inference (zero overhead).
        """
        result: np.ndarray[Any, Any] = (self.threshold - self.beta) * np.sqrt(
            self.running_var + self.eps
        ) / np.clip(self.gamma, 1e-8, None) + self.running_mean
        return result

fused_threshold()

Compute per-neuron threshold that absorbs BN at inference.

Returns ndarray of shape (n_features,) — use as per-neuron threshold instead of applying BN at inference (zero overhead).

Source code in src/sc_neurocore/spike_norm/normalizers.py
Python
def fused_threshold(self) -> np.ndarray[Any, Any]:
    """Compute per-neuron threshold that absorbs BN at inference.

    Returns ndarray of shape (n_features,) — use as per-neuron threshold
    instead of applying BN at inference (zero overhead).
    """
    result: np.ndarray[Any, Any] = (self.threshold - self.beta) * np.sqrt(
        self.running_var + self.eps
    ) / np.clip(self.gamma, 1e-8, None) + self.running_mean
    return result

TemporalAccumulatedBN dataclass

TAB: normalizes accumulated membrane potential.

Tracks running accumulated potential across timesteps. Addresses Temporal Covariate Shift directly.

Parameters

n_features : int
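
A sketch of sequence processing (shapes are illustrative assumptions): forward() centers each input on the accumulated potential, and reset() must be called between sequences to clear the accumulator.

Python
import numpy as np
from sc_neurocore.spike_norm import TemporalAccumulatedBN

tab = TemporalAccumulatedBN(n_features=16)
for t in range(4):
    x_t = np.random.randn(32, 16)           # (batch, features) at timestep t
    y_t = tab.forward(x_t, training=True)   # centered on accumulated potential

tab.reset()  # clear the accumulator before the next sequence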

Source code in src/sc_neurocore/spike_norm/normalizers.py
Python
@dataclass
class TemporalAccumulatedBN:
    """TAB: normalizes accumulated membrane potential.

    Tracks running accumulated potential across timesteps.
    Addresses Temporal Covariate Shift directly.

    Parameters
    ----------
    n_features : int
    """

    n_features: int
    momentum: float = 0.1
    eps: float = 1e-5

    def __post_init__(self) -> None:
        self.gamma = np.ones(self.n_features)
        self.beta = np.zeros(self.n_features)
        self.running_mean = np.zeros(self.n_features)
        self.running_var = np.ones(self.n_features)
        self._accumulated = np.zeros(self.n_features)

    def forward(self, x: np.ndarray[Any, Any], training: bool = True) -> np.ndarray[Any, Any]:
        increment: np.ndarray[Any, Any] = x.mean(axis=0) if x.ndim > 1 else x
        self._accumulated = self._accumulated + increment
        if training:
            mean = self._accumulated
            # Variance estimated from current input
            var = x.var(axis=0) if x.ndim > 1 else np.zeros_like(x)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean  # type: ignore[assignment]
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:  # pragma: no cover
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        result: np.ndarray[Any, Any] = self.gamma * x_norm + self.beta
        return result

    def reset(self) -> None:
        self._accumulated = np.zeros(self.n_features)