Quantization-Aware Training — STE for Hardware Deployment

Train SNNs through quantization using straight-through estimators (STE). The missing link between training and FPGA deployment: weights are quantized in the forward pass but maintain full precision in the backward pass.

How STE Works

Standard quantization is non-differentiable (rounding has zero gradient almost everywhere). The straight-through estimator passes the gradient through quantization as if it weren't there:

  • Forward: W_q = round(W / scale) * scale (quantized)
  • Backward: ∂L/∂W = ∂L/∂W_q (identity, as if no quantization)

This trains weights to be robust to their own quantization noise. At export time, weights are already at target precision.
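The forward-pass quantizer above can be sketched in a few lines of NumPy. This is a hedged illustration, not the library's actual `_ste_quantize` (which also supports an asymmetric mode); it implements the symmetric scheme with `scale = max(|W|) / (2^(b-1) - 1)`:

```python
import numpy as np

def ste_quantize(W: np.ndarray, bits: int = 8) -> np.ndarray:
    """Symmetric uniform quantization: round(W / scale) * scale."""
    qmax = 2 ** (bits - 1) - 1           # e.g. 127 for 8-bit
    scale = np.max(np.abs(W)) / qmax
    if scale == 0.0:                     # all-zero weights: nothing to quantize
        return W.copy()
    return np.clip(np.round(W / scale), -qmax, qmax) * scale
```

In an autodiff framework the STE is typically expressed as `W + stop_gradient(ste_quantize(W) - W)`: the forward value is the quantized weight, while the gradient flows to the full-precision `W` unchanged.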

Components

  • QuantizedSNNLayer — SNN layer with quantization-aware forward pass.
Parameter Default Meaning
n_inputs (required) Input dimension
n_neurons (required) Output dimension
weight_bits 8 Target weight precision (2, 4, 8, 16)
threshold 1.0 LIF spike threshold
tau_mem 20.0 Membrane time constant
  • TernaryWeights — Ternary quantization: {-1, 0, +1}. ~94% memory reduction (2-bit storage vs. 32-bit floats). Weights with |w| < threshold_ratio * mean(|w|) become 0.
  • quantize_aware_train_step — One QAT training step with STE gradient flow. Returns {'output', 'loss'}.
  • _ste_quantize — Core quantization function. Supports symmetric and asymmetric modes.

Usage

Python
from sc_neurocore.qat import QuantizedSNNLayer, quantize_aware_train_step, TernaryWeights
import numpy as np

# Create QAT layer
layer = QuantizedSNNLayer(n_inputs=784, n_neurons=128, weight_bits=8)

# Training loop with STE
for epoch in range(100):
    result = quantize_aware_train_step(layer, x_train, y_target, lr=0.01)
    print(f"Loss: {result['loss']:.4f}")

# Export hardware-ready weights (already quantized to 8-bit)
hw_weights = layer.export_weights()

# Ternary quantization for extreme compression
tw = TernaryWeights(threshold_ratio=0.7)
ternary = tw.quantize(layer.W)
print(f"Sparsity: {tw.sparsity(layer.W):.1%}")  # ~50-70% zeros

References: QP-SNN (ICLR 2025), SpikeFit (EurIPS 2025).

See Tutorial 77: QAT.

sc_neurocore.qat.quantize

Train SNNs through quantization using straight-through estimators.

The missing link between training and hardware deployment: few SNN libraries ship QAT as a reusable module.

Reference: QP-SNN (ICLR 2025), SpikeFit (EurIPS 2025)

QuantizedSNNLayer dataclass

SNN layer with quantization-aware forward pass.

During training: weights quantized in forward, full-precision in backward (STE). At export: weights are already at target precision.

Parameters

n_inputs : int
n_neurons : int
weight_bits : int
    Target weight precision (2, 4, 8, 16).
threshold : float
tau_mem : float

Source code in src/sc_neurocore/qat/quantize.py
Python
@dataclass
class QuantizedSNNLayer:
    """SNN layer with quantization-aware forward pass.

    During training: weights quantized in forward, full-precision in backward (STE).
    At export: weights are already at target precision.

    Parameters
    ----------
    n_inputs : int
    n_neurons : int
    weight_bits : int
        Target weight precision (2, 4, 8, 16).
    threshold : float
    tau_mem : float
    """

    n_inputs: int
    n_neurons: int
    weight_bits: int = 8
    threshold: float = 1.0
    tau_mem: float = 20.0

    def __post_init__(self) -> None:
        rng = np.random.RandomState(42)
        self.W = rng.randn(self.n_neurons, self.n_inputs) * np.sqrt(2.0 / self.n_inputs)
        self._v = np.zeros(self.n_neurons)

    def forward(self, x: np.ndarray, dt: float = 1.0) -> np.ndarray:
        """Quantization-aware forward pass."""
        W_q = _ste_quantize(self.W, self.weight_bits)
        alpha = np.exp(-dt / self.tau_mem)
        current = W_q @ x
        self._v = alpha * self._v + (1 - alpha) * current
        spikes = (self._v >= self.threshold).astype(np.float64)
        self._v -= spikes * self.threshold
        return spikes

    def export_weights(self) -> np.ndarray:
        """Export quantized weights for hardware deployment."""
        return _ste_quantize(self.W, self.weight_bits)

    def reset(self) -> None:  # pragma: no cover
        self._v = np.zeros(self.n_neurons)

forward(x, dt=1.0)

Quantization-aware forward pass.

Source code in src/sc_neurocore/qat/quantize.py
Python
def forward(self, x: np.ndarray, dt: float = 1.0) -> np.ndarray:
    """Quantization-aware forward pass."""
    W_q = _ste_quantize(self.W, self.weight_bits)
    alpha = np.exp(-dt / self.tau_mem)
    current = W_q @ x
    self._v = alpha * self._v + (1 - alpha) * current
    spikes = (self._v >= self.threshold).astype(np.float64)
    self._v -= spikes * self.threshold
    return spikes

export_weights()

Export quantized weights for hardware deployment.

Source code in src/sc_neurocore/qat/quantize.py
Python
def export_weights(self) -> np.ndarray:
    """Export quantized weights for hardware deployment."""
    return _ste_quantize(self.W, self.weight_bits)
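Note that export_weights() returns dequantized floats (integer codes times scale). For a fixed-point FPGA memory you typically want the raw integer codes instead. A hedged sketch, assuming the symmetric scale described above and that at least one weight sits at full scale (the helper name to_int_codes is illustrative, not part of the library):

```python
import numpy as np

def to_int_codes(W_q: np.ndarray, bits: int = 8) -> tuple[np.ndarray, float]:
    """Recover integer codes and scale from dequantized weights."""
    qmax = 2 ** (bits - 1) - 1
    scale = float(np.max(np.abs(W_q))) / qmax   # assumes some weight saturates
    dtype = np.int8 if bits <= 8 else np.int16
    codes = np.round(W_q / scale).astype(dtype)
    return codes, scale
```

The hardware then stores only `codes` (one signed integer per weight) plus a single per-layer `scale` constant.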

TernaryWeights

Ternary weight quantization: {-1, 0, +1}.

~94% memory reduction (2 bits per weight vs. 32-bit floats). Each weight is one of three values. Threshold-based: weights with |w| below the threshold become 0.

Parameters

threshold_ratio : float
    Fraction of mean(|w|) below which weights are zeroed.

Source code in src/sc_neurocore/qat/quantize.py
Python
class TernaryWeights:
    """Ternary weight quantization: {-1, 0, +1}.

    94% memory reduction. Each weight is one of three values.
    Threshold-based: weights with |w| < threshold become 0.

    Parameters
    ----------
    threshold_ratio : float
        Fraction of mean(|w|) below which weights are zeroed.
    """

    def __init__(self, threshold_ratio: float = 0.7):
        self.threshold_ratio = threshold_ratio

    def quantize(self, weights: np.ndarray) -> np.ndarray:
        threshold = self.threshold_ratio * np.mean(np.abs(weights))
        ternary = np.zeros_like(weights)
        ternary[weights > threshold] = 1.0
        ternary[weights < -threshold] = -1.0
        return ternary

    def sparsity(self, weights: np.ndarray) -> float:
        t = self.quantize(weights)
        return float(np.mean(t == 0))

quantize_aware_train_step(layer, x, target, lr=0.01)

One QAT training step with STE.

Parameters

layer : QuantizedSNNLayer
x : ndarray of shape (n_inputs,)
target : ndarray of shape (n_neurons,)
lr : float

Returns

dict with 'output', 'loss'

Source code in src/sc_neurocore/qat/quantize.py
Python
def quantize_aware_train_step(
    layer: QuantizedSNNLayer,
    x: np.ndarray,
    target: np.ndarray,
    lr: float = 0.01,
) -> dict[str, object]:
    """One QAT training step with STE.

    Parameters
    ----------
    layer : QuantizedSNNLayer
    x : ndarray of shape (n_inputs,)
    target : ndarray of shape (n_neurons,)
    lr : float

    Returns
    -------
    dict with 'output', 'loss'
    """
    output = layer.forward(x)
    error = output - target
    loss = 0.5 * float(np.sum(error**2))

    # STE: gradient flows through quantization as if it weren't there
    grad_W = np.outer(error, x)
    layer.W -= lr * grad_W

    return {"output": output, "loss": loss}