Skip to content

Tutorial 77: Quantization-Aware Training

Train SNNs through quantization using straight-through estimators (STE). Closes the gap between full-precision training and fixed-point hardware. Includes ternary weight quantization (94% memory reduction).

The Problem

SNNs trained in float64 lose accuracy when deployed to fixed-point hardware. Post-training quantization drops 3-8% accuracy. QAT simulates quantization during training so the model learns to compensate.

Quantized SNN Layer

import numpy as np
from sc_neurocore.qat import QuantizedSNNLayer, quantize_aware_train_step

# Build a LIF layer whose weights are quantized to 8 bits in the forward pass.
layer = QuantizedSNNLayer(
    n_inputs=784, n_neurons=128,
    weight_bits=8, threshold=1.0, tau_mem=20.0,
)

# Random input; one-hot target asking neuron 42 to fire.
x = np.random.randn(784)
target = np.zeros(128); target[42] = 1.0
# One QAT step: quantized forward, STE gradient update on full-precision weights.
result = quantize_aware_train_step(layer, x, target, lr=0.01)
print(f"Loss: {result['loss']:.4f}")

hw_weights = layer.export_weights()  # already at 8-bit precision

Ternary Weights

Each weight is {-1, 0, +1}. 94% memory reduction:

from sc_neurocore.qat import TernaryWeights

# Zero out weights whose magnitude falls below 0.7 * mean(|w|).
ternary = TernaryWeights(threshold_ratio=0.7)
weights = np.random.randn(128, 784) * 0.1
t_weights = ternary.quantize(weights)  # each entry is -1.0, 0.0, or +1.0
print(f"Sparsity: {ternary.sparsity(weights):.1%}")
| Bits        | Memory vs Float32 | Use Case           |
|-------------|-------------------|--------------------|
| 2 (ternary) | 16x reduction     | Extreme edge       |
| 4           | 8x reduction      | FPGA LUT-based     |
| 8           | 4x reduction      | Standard ASIC/FPGA |

API Reference

sc_neurocore.qat.quantize

Train SNNs through quantization using straight-through estimators.

QAT is the missing link between training and hardware deployment: no other SNN library ships it as a reusable module.

Reference: QP-SNN (ICLR 2025), SpikeFit (EurIPS 2025)

TernaryWeights

Ternary weight quantization: {-1, 0, +1}.

94% memory reduction. Each weight is one of three values. Threshold-based: weights with |w| < threshold become 0.

Parameters

threshold_ratio : float
    Fraction of mean(|w|) below which weights are zeroed.

Source code in src/sc_neurocore/qat/quantize.py
Lines 36-60
class TernaryWeights:
    """Ternary weight quantization: {-1, 0, +1}.

    94% memory reduction. Each weight is one of three values.
    Threshold-based: weights with |w| < threshold become 0.

    Parameters
    ----------
    threshold_ratio : float
        Fraction of mean(|w|) below which weights are zeroed
        (the threshold is ``threshold_ratio * mean(|w|)``).
    """

    def __init__(self, threshold_ratio: float = 0.7):
        self.threshold_ratio = threshold_ratio

    def quantize(self, weights: np.ndarray) -> np.ndarray:
        """Map each weight to {-1, 0, +1}.

        Parameters
        ----------
        weights : np.ndarray
            Full-precision weights of any shape.

        Returns
        -------
        np.ndarray
            Same shape as ``weights``, containing only -1.0, 0.0, +1.0.
        """
        # Data-dependent threshold: scales with the typical weight magnitude.
        threshold = self.threshold_ratio * np.mean(np.abs(weights))
        ternary = np.zeros_like(weights)
        ternary[weights > threshold] = 1.0
        ternary[weights < -threshold] = -1.0
        return ternary

    def sparsity(self, weights: np.ndarray) -> float:
        """Return the fraction of weights that quantize to exactly zero."""
        t = self.quantize(weights)
        return float(np.mean(t == 0))

QuantizedSNNLayer dataclass

SNN layer with quantization-aware forward pass.

During training: weights quantized in forward, full-precision in backward (STE). At export: weights are already at target precision.

Parameters

n_inputs : int
n_neurons : int
weight_bits : int
    Target weight precision (2, 4, 8, 16).
threshold : float
tau_mem : float

Source code in src/sc_neurocore/qat/quantize.py
Lines 63-106
@dataclass
class QuantizedSNNLayer:
    """SNN layer with quantization-aware forward pass.

    During training the forward pass sees quantized weights while the
    caller's STE update is applied to the full-precision copy ``W``.
    At export, the weights are therefore already at target precision.

    Parameters
    ----------
    n_inputs : int
    n_neurons : int
    weight_bits : int
        Target weight precision (2, 4, 8, 16).
    threshold : float
    tau_mem : float
    """

    n_inputs: int
    n_neurons: int
    weight_bits: int = 8
    threshold: float = 1.0
    tau_mem: float = 20.0

    def __post_init__(self):
        # He-style initialization; fixed seed keeps runs reproducible.
        rng = np.random.RandomState(42)
        scale = np.sqrt(2.0 / self.n_inputs)
        self.W = rng.randn(self.n_neurons, self.n_inputs) * scale
        self._v = np.zeros(self.n_neurons)

    def forward(self, x: np.ndarray, dt: float = 1.0) -> np.ndarray:
        """Quantization-aware forward pass."""
        quantized = _ste_quantize(self.W, self.weight_bits)
        decay = np.exp(-dt / self.tau_mem)
        # Leaky integration of the input current from quantized weights.
        self._v = decay * self._v + (1 - decay) * (quantized @ x)
        fired = (self._v >= self.threshold).astype(np.float64)
        # Soft reset: neurons that spiked lose one threshold's worth of potential.
        self._v = self._v - fired * self.threshold
        return fired

    def export_weights(self) -> np.ndarray:
        """Export quantized weights for hardware deployment."""
        return _ste_quantize(self.W, self.weight_bits)

    def reset(self):  # pragma: no cover
        """Zero the membrane potentials."""
        self._v = np.zeros(self.n_neurons)

forward(x, dt=1.0)

Quantization-aware forward pass.

Source code in src/sc_neurocore/qat/quantize.py
Lines 91-99
def forward(self, x: np.ndarray, dt: float = 1.0) -> np.ndarray:
    """Quantization-aware forward pass.

    Parameters
    ----------
    x : np.ndarray
        Input vector of shape (n_inputs,).
    dt : float
        Simulation time step (same units as ``tau_mem``).

    Returns
    -------
    np.ndarray
        0/1 spike vector of shape (n_neurons,).
    """
    # Quantize for the forward pass only; self.W stays full precision (STE).
    W_q = _ste_quantize(self.W, self.weight_bits)
    alpha = np.exp(-dt / self.tau_mem)  # per-step membrane decay factor
    current = W_q @ x
    # Leaky integration of the input current.
    self._v = alpha * self._v + (1 - alpha) * current
    spikes = (self._v >= self.threshold).astype(np.float64)
    # Soft reset: subtract the threshold from neurons that fired.
    self._v -= spikes * self.threshold
    return spikes

export_weights()

Export quantized weights for hardware deployment.

Source code in src/sc_neurocore/qat/quantize.py
Lines 101-103
def export_weights(self) -> np.ndarray:
    """Export quantized weights for hardware deployment.

    Returns
    -------
    np.ndarray
        Weights quantized to ``weight_bits`` precision, same shape as ``W``.
    """
    return _ste_quantize(self.W, self.weight_bits)

quantize_aware_train_step(layer, x, target, lr=0.01)

One QAT training step with STE.

Parameters

layer : QuantizedSNNLayer
x : ndarray of shape (n_inputs,)
target : ndarray of shape (n_neurons,)
lr : float

Returns

dict with 'output', 'loss'

Source code in src/sc_neurocore/qat/quantize.py
Lines 109-136
def quantize_aware_train_step(
    layer: QuantizedSNNLayer,
    x: np.ndarray,
    target: np.ndarray,
    lr: float = 0.01,
) -> dict:
    """One QAT training step with STE.

    Runs a quantized forward pass, measures squared error against
    ``target``, and applies a gradient step directly to the layer's
    full-precision weights.

    Parameters
    ----------
    layer : QuantizedSNNLayer
    x : ndarray of shape (n_inputs,)
    target : ndarray of shape (n_neurons,)
    lr : float

    Returns
    -------
    dict with 'output', 'loss'
    """
    spikes = layer.forward(x)
    residual = spikes - target
    loss = 0.5 * float(np.sum(residual ** 2))

    # STE: gradient flows through quantization as if it weren't there,
    # so the update lands on the full-precision weights.
    layer.W -= lr * np.outer(residual, x)

    return {"output": spikes, "loss": loss}