Skip to content

ANN-to-SNN Conversion

Convert trained PyTorch ANNs to rate-coded spiking neural networks.

Converter

sc_neurocore.conversion.ann_to_snn

Convert trained PyTorch ANNs to rate-coded spiking neural networks.

The conversion replaces ReLU activations with IF (integrate-and-fire) neurons and uses weight/threshold normalization to preserve accuracy. Rate coding: an ANN activation a maps to a firing rate of a/theta spikes per timestep, measured over T timesteps.

Pipeline
  1. Extract weights and biases from PyTorch Sequential model
  2. Compute per-layer activation statistics (max, percentile)
  3. Normalize weights so that max activation = threshold
  4. Build an SNN with IF neurons that reproduces the ANN output as spike counts over T timesteps

Reference: Diehl et al. 2015 — "Fast-classifying, high-accuracy spiking deep networks through weight and threshold balancing"

ConvertedSNN dataclass

Rate-coded SNN converted from an ANN.

Attributes

- weights (list of ndarray): per-layer weight matrices.
- biases (list of ndarray or None): per-layer biases (None if absent).
- thresholds (list of float): per-layer firing thresholds after normalization.
- T (int): number of simulation timesteps.
- n_layers (int): number of layers.

Source code in src/sc_neurocore/conversion/ann_to_snn.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@dataclass
class ConvertedSNN:
    """Rate-coded SNN converted from an ANN.

    Attributes
    ----------
    weights : list of ndarray
        Per-layer weight matrices.
    biases : list of ndarray or None
        Per-layer biases (None if absent).
    thresholds : list of float
        Per-layer firing thresholds after normalization.
    T : int
        Number of simulation timesteps.
    n_layers : int
        Number of layers.
    """

    weights: list[np.ndarray]
    biases: list[np.ndarray | None]
    thresholds: list[float]
    T: int
    n_layers: int = field(init=False)

    def __post_init__(self):
        self.n_layers = len(self.weights)

    def run(self, x: np.ndarray) -> np.ndarray:
        """Run the converted SNN for T timesteps on input x.

        Parameters
        ----------
        x : ndarray of shape (n_input,) or (batch, n_input)
            Input values in [0, 1]. Converted to Poisson spike trains.

        Returns
        -------
        ndarray of shape (n_output,) or (batch, n_output)
            Output spike counts over T timesteps (unnormalized).
        """
        squeeze = x.ndim == 1
        if squeeze:
            x = x[np.newaxis]

        batch = x.shape[0]
        rng = np.random.RandomState(42)

        # Initialize membrane voltages
        voltages = [np.zeros((batch, w.shape[0])) for w in self.weights]
        spike_counts = np.zeros((batch, self.weights[-1].shape[0]))

        for t in range(self.T):
            # Rate-code input: spike with probability proportional to x
            input_spikes = (rng.random(x.shape) < x).astype(np.float64)

            layer_input = input_spikes
            for i, (w, b, theta) in enumerate(zip(self.weights, self.biases, self.thresholds)):
                current = layer_input @ w.T
                if b is not None:
                    current += b / self.T
                voltages[i] += current
                spikes = (voltages[i] >= theta).astype(np.float64)
                voltages[i] -= spikes * theta
                layer_input = spikes

                if i == self.n_layers - 1:
                    spike_counts += spikes

        if squeeze:
            spike_counts = spike_counts[0]
        return spike_counts

    def classify(self, x: np.ndarray) -> np.ndarray:
        """Run SNN and return predicted class indices."""
        counts = self.run(x)
        return np.argmax(counts, axis=-1)

run(x)

Run the converted SNN for T timesteps on input x.

Parameters

- x (ndarray of shape (n_input,) or (batch, n_input)): input values in [0, 1], converted to Poisson spike trains.

Returns

ndarray of shape (n_output,) or (batch, n_output): output spike counts over T timesteps (unnormalized).

Source code in src/sc_neurocore/conversion/ann_to_snn.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def run(self, x: np.ndarray) -> np.ndarray:
    """Run the converted SNN for T timesteps on input x.

    Parameters
    ----------
    x : ndarray of shape (n_input,) or (batch, n_input)
        Input values in [0, 1]. Converted to Poisson spike trains.

    Returns
    -------
    ndarray of shape (n_output,) or (batch, n_output)
        Output spike counts over T timesteps (unnormalized).
    """
    squeeze = x.ndim == 1
    if squeeze:
        x = x[np.newaxis]

    batch = x.shape[0]
    rng = np.random.RandomState(42)

    # Initialize membrane voltages
    voltages = [np.zeros((batch, w.shape[0])) for w in self.weights]
    spike_counts = np.zeros((batch, self.weights[-1].shape[0]))

    for t in range(self.T):
        # Rate-code input: spike with probability proportional to x
        input_spikes = (rng.random(x.shape) < x).astype(np.float64)

        layer_input = input_spikes
        for i, (w, b, theta) in enumerate(zip(self.weights, self.biases, self.thresholds)):
            current = layer_input @ w.T
            if b is not None:
                current += b / self.T
            voltages[i] += current
            spikes = (voltages[i] >= theta).astype(np.float64)
            voltages[i] -= spikes * theta
            layer_input = spikes

            if i == self.n_layers - 1:
                spike_counts += spikes

    if squeeze:
        spike_counts = spike_counts[0]
    return spike_counts

classify(x)

Run SNN and return predicted class indices.

Source code in src/sc_neurocore/conversion/ann_to_snn.py
112
113
114
115
def classify(self, x: np.ndarray) -> np.ndarray:
    """Return the index of the most active output neuron for each sample.

    The network is simulated via ``self.run``. Along the last axis, the
    output with the largest spike count wins (``np.argmax``: the first
    index wins ties).
    """
    output_counts = self.run(x)
    predicted = np.argmax(output_counts, axis=-1)
    return predicted

convert(model, calibration_data=None, T=16, percentile=99.9)

Convert a trained PyTorch ANN to a rate-coded SNN.

Parameters

- model (nn.Module): trained PyTorch model (Sequential with Linear + ReLU).
- calibration_data (Tensor, optional): sample input batch for threshold calibration. If None, a default threshold of 1.0 is used for every layer.
- T (int): number of simulation timesteps (higher = more accurate, slower).
- percentile (float): activation percentile used for threshold normalization.

Returns

ConvertedSNN: the converted spiking network, ready to run.

Source code in src/sc_neurocore/conversion/ann_to_snn.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def convert(
    model,
    calibration_data=None,
    T: int = 16,
    percentile: float = 99.9,
) -> ConvertedSNN:
    """Convert a trained PyTorch ANN to a rate-coded SNN.

    Applies data-based normalization (Diehl et al. 2015). Let lambda_l be
    the calibrated activation scale of layer l, with lambda_0 = 1 for the
    input. Weights become ``W_l * lambda_{l-1} / lambda_l`` and biases
    become ``b_l / lambda_l``, so every layer fires with threshold 1.0.

    Parameters
    ----------
    model : nn.Module
        Trained PyTorch model (Sequential with Linear + ReLU).
    calibration_data : Tensor, optional
        Sample input batch for threshold calibration. If None, uses
        default threshold of 1.0 per layer (no rescaling).
    T : int
        Number of simulation timesteps (higher = more accurate, slower).
    percentile : float
        Activation percentile for threshold normalization.

    Returns
    -------
    ConvertedSNN
        Converted spiking network ready to run.

    Raises
    ------
    ImportError
        If PyTorch is not available.
    ValueError
        If no Linear/Conv2d layers are found in ``model``.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for ANN-to-SNN conversion")

    layers = _extract_layers(model)
    if not layers:
        raise ValueError("No Linear/Conv2d layers found in model")

    weights = [w for w, _ in layers]
    biases = [b for _, b in layers]
    n_layers = len(weights)

    if calibration_data is not None:
        # Copy so that padding does not mutate the helper's return value.
        scales = list(_compute_max_activations(model, calibration_data, percentile))
        # Pad if fewer ReLUs than Linear layers (e.g. an output layer with
        # no ReLU gets no activation statistic).
        scales += [1.0] * (n_layers - len(scales))
    else:
        scales = [1.0] * n_layers

    # A dead layer (scale 0) or a NaN statistic would divide by zero below.
    # NOTE(review): assumes scalar entries (float / 0-d array); confirm
    # against _compute_max_activations.
    scales = [s if s > 0 else 1.0 for s in scales]

    # Normalize so that each layer's max activation maps to threshold 1.0.
    normalized_weights = []
    normalized_biases = []
    prev_scale = 1.0  # inputs are already rates in [0, 1]
    for w, b, scale in zip(weights, biases, scales):
        normalized_weights.append(w / (scale / prev_scale))
        # The bias feeds layer l directly, so only lambda_l rescales it.
        normalized_biases.append(None if b is None else b / scale)
        prev_scale = scale

    return ConvertedSNN(
        weights=normalized_weights,
        biases=normalized_biases,
        thresholds=[1.0] * n_layers,
        T=T,
    )

QCFS Activation

sc_neurocore.conversion.qcfs

QCFS (Quantization-Clip-Floor-Shift) activation function.

Replaces ReLU in the ANN during conversion-aware training or post-hoc conversion. QCFS approximates the rate-coded SNN firing rate as a quantized step function, minimizing conversion error.

Reference: Bu et al. 2022 — "Optimal ANN-SNN Conversion for High-accuracy and Ultra-low-latency Spiking Neural Networks"

QCFSActivation

Bases: Module

QCFS activation: quantized clip-floor-shift ReLU replacement.

For $T$ timesteps and threshold $\theta$:

$$\mathrm{QCFS}(x) = \frac{\theta}{T}\,\mathrm{clip}\!\left(\left\lfloor \frac{xT}{\theta} + \frac{1}{2} \right\rfloor,\ 0,\ T\right)$$

This quantizes activations to T+1 levels in [0, theta], matching the achievable spike rates of an IF neuron over T timesteps.

Parameters

- T (int): number of simulation timesteps.
- theta (float): firing threshold (default 1.0).
- learn_theta (bool): make the threshold trainable (default False).

Source code in src/sc_neurocore/conversion/qcfs.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class QCFSActivation(nn.Module):
    """QCFS activation: quantized clip-floor-shift ReLU replacement.

    For T timesteps and threshold theta:
        QCFS(x) = clip(floor(x * T / theta + 0.5), 0, T) * theta / T

    This quantizes activations to T+1 levels in [0, theta], matching
    the achievable spike rates of an IF neuron over T timesteps. The
    floor is differentiated with a straight-through estimator. Inside
    the clip range the gradient w.r.t. ``x`` is 1; outside it is 0.

    Parameters
    ----------
    T : int
        Number of simulation timesteps (must be >= 1).
    theta : float
        Firing threshold (default 1.0). Stored as a floating-point tensor.
    learn_theta : bool
        Make threshold trainable (default False). Otherwise the threshold
        is a registered buffer.

    Raises
    ------
    ValueError
        If ``T < 1`` (the quantization step theta / T would be undefined).
    """

    def __init__(self, T: int = 8, theta: float = 1.0, learn_theta: bool = False):
        super().__init__()
        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")
        self.T = T
        # float() so that an int threshold (e.g. theta=1) still yields a float
        # tensor: nn.Parameter cannot require grad on an integer tensor.
        theta_tensor = torch.tensor(float(theta))
        if learn_theta:
            self.theta = nn.Parameter(theta_tensor)
        else:
            self.register_buffer("theta", theta_tensor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply QCFS element-wise; the output has the same shape as ``x``."""
        scaled = x * self.T / self.theta + 0.5
        # STE: the forward value is floor(scaled), while the backward pass
        # treats the op as identity. The detached residual carries no
        # gradient, so d(quantized)/d(scaled) == 1.
        quantized = scaled + (scaled.floor() - scaled).detach()
        clipped = quantized.clamp(0, self.T)
        return clipped * self.theta / self.T

    def extra_repr(self) -> str:
        return f"T={self.T}, theta={self.theta.item():.2f}"