
ANN-to-SNN Conversion

Convert trained PyTorch ANNs to rate-coded spiking neural networks.

Converter

sc_neurocore.conversion.ann_to_snn

Convert trained PyTorch ANNs to rate-coded spiking neural networks.

The conversion replaces ReLU activations with IF (integrate-and-fire) neurons and uses weight/threshold normalization to preserve accuracy. Rate coding: ANN activation a maps to spike rate a/theta over T steps.
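The rate-coding relation can be checked with a standalone NumPy sketch (an illustration, not part of the module API): a unit driven by activation a with threshold theta = 1 spikes with probability a each step, so its spike count over T steps concentrates around a * T.

```python
import numpy as np

# Rate-code a scalar activation a as a Bernoulli spike train over T steps,
# then verify the mean spike count approaches a * T (theta = 1 here).
rng = np.random.default_rng(0)
a, T, trials = 0.3, 100, 2000

# Each step spikes with probability a; count spikes per trial.
counts = (rng.random((trials, T)) < a).sum(axis=1)

mean_count = counts.mean()
print(mean_count)  # close to a * T = 30
```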

Pipeline
  1. Extract weights and biases from PyTorch Sequential model
  2. Compute per-layer activation statistics (max, percentile)
  3. Normalize weights so that max activation = threshold
  4. Build an SNN with IF neurons that reproduces the ANN output as spike counts over T timesteps

Reference: Diehl et al. 2015 — "Fast-classifying, high-accuracy spiking deep networks through weight and threshold balancing"

ConvertedSNN dataclass

Rate-coded SNN converted from an ANN.

Attributes

weights : list of ndarray
    Per-layer weight matrices.
biases : list of ndarray or None
    Per-layer biases (None if absent).
thresholds : list of float
    Per-layer firing thresholds after normalization.
T : int
    Number of simulation timesteps.
n_layers : int
    Number of layers.

Source code in src/sc_neurocore/conversion/ann_to_snn.py
Python
@dataclass
class ConvertedSNN:
    """Rate-coded SNN converted from an ANN.

    Attributes
    ----------
    weights : list of ndarray
        Per-layer weight matrices.
    biases : list of ndarray or None
        Per-layer biases (None if absent).
    thresholds : list of float
        Per-layer firing thresholds after normalization.
    T : int
        Number of simulation timesteps.
    n_layers : int
        Number of layers.
    """

    weights: list[np.ndarray]
    biases: list[np.ndarray | None]
    thresholds: list[float]
    T: int
    n_layers: int = field(init=False)

    def __post_init__(self) -> None:
        self.n_layers = len(self.weights)

    def run(self, x: np.ndarray) -> np.ndarray:
        """Run the converted SNN for T timesteps on input x.

        Parameters
        ----------
        x : ndarray of shape (n_input,) or (batch, n_input)
            Input values in [0, 1]. Converted to Poisson spike trains.

        Returns
        -------
        ndarray of shape (n_output,) or (batch, n_output)
            Output spike counts over T timesteps (unnormalized).
        """
        squeeze = x.ndim == 1
        if squeeze:
            x = x[np.newaxis]

        batch = x.shape[0]
        rng = np.random.RandomState(42)

        # Initialize membrane voltages
        voltages = [np.zeros((batch, w.shape[0])) for w in self.weights]
        spike_counts = np.zeros((batch, self.weights[-1].shape[0]))

        for t in range(self.T):
            # Rate-code input: spike with probability proportional to x
            input_spikes = (rng.random(x.shape) < x).astype(np.float64)

            layer_input = input_spikes
            for i, (w, b, theta) in enumerate(zip(self.weights, self.biases, self.thresholds)):
                current = layer_input @ w.T
                if b is not None:
                    current += b / self.T
                voltages[i] += current
                spikes = (voltages[i] >= theta).astype(np.float64)
                voltages[i] -= spikes * theta
                layer_input = spikes

                if i == self.n_layers - 1:
                    spike_counts += spikes

        if squeeze:
            spike_counts = spike_counts[0]
        return spike_counts

    def classify(self, x: np.ndarray) -> np.ndarray:
        """Run SNN and return predicted class indices."""
        counts = self.run(x)
        return np.argmax(counts, axis=-1)

run(x)

Run the converted SNN for T timesteps on input x.

Parameters

x : ndarray of shape (n_input,) or (batch, n_input)
    Input values in [0, 1]. Converted to Poisson spike trains.

Returns

ndarray of shape (n_output,) or (batch, n_output)
    Output spike counts over T timesteps (unnormalized).

Source code in src/sc_neurocore/conversion/ann_to_snn.py
Python
def run(self, x: np.ndarray) -> np.ndarray:
    """Run the converted SNN for T timesteps on input x.

    Parameters
    ----------
    x : ndarray of shape (n_input,) or (batch, n_input)
        Input values in [0, 1]. Converted to Poisson spike trains.

    Returns
    -------
    ndarray of shape (n_output,) or (batch, n_output)
        Output spike counts over T timesteps (unnormalized).
    """
    squeeze = x.ndim == 1
    if squeeze:
        x = x[np.newaxis]

    batch = x.shape[0]
    rng = np.random.RandomState(42)

    # Initialize membrane voltages
    voltages = [np.zeros((batch, w.shape[0])) for w in self.weights]
    spike_counts = np.zeros((batch, self.weights[-1].shape[0]))

    for t in range(self.T):
        # Rate-code input: spike with probability proportional to x
        input_spikes = (rng.random(x.shape) < x).astype(np.float64)

        layer_input = input_spikes
        for i, (w, b, theta) in enumerate(zip(self.weights, self.biases, self.thresholds)):
            current = layer_input @ w.T
            if b is not None:
                current += b / self.T
            voltages[i] += current
            spikes = (voltages[i] >= theta).astype(np.float64)
            voltages[i] -= spikes * theta
            layer_input = spikes

            if i == self.n_layers - 1:
                spike_counts += spikes

    if squeeze:
        spike_counts = spike_counts[0]
    return spike_counts

classify(x)

Run SNN and return predicted class indices.

Source code in src/sc_neurocore/conversion/ann_to_snn.py
Python
def classify(self, x: np.ndarray) -> np.ndarray:
    """Run SNN and return predicted class indices."""
    counts = self.run(x)
    return np.argmax(counts, axis=-1)

convert(model, calibration_data=None, T=16, percentile=99.9)

Convert a trained PyTorch ANN to a rate-coded SNN.

Parameters

model : nn.Module
    Trained PyTorch model (Sequential with Linear + ReLU).
calibration_data : Tensor, optional
    Sample input batch for threshold calibration. If None, uses a default threshold of 1.0 per layer.
T : int
    Number of simulation timesteps (higher = more accurate, slower).
percentile : float
    Activation percentile for threshold normalization.

Returns

ConvertedSNN
    Converted spiking network ready to run.

Source code in src/sc_neurocore/conversion/ann_to_snn.py
Python
def convert(
    model: object,
    calibration_data: object = None,
    T: int = 16,
    percentile: float = 99.9,
) -> ConvertedSNN:
    """Convert a trained PyTorch ANN to a rate-coded SNN.

    Parameters
    ----------
    model : nn.Module
        Trained PyTorch model (Sequential with Linear + ReLU).
    calibration_data : Tensor, optional
        Sample input batch for threshold calibration. If None, uses
        default threshold of 1.0 per layer.
    T : int
        Number of simulation timesteps (higher = more accurate, slower).
    percentile : float
        Activation percentile for threshold normalization.

    Returns
    -------
    ConvertedSNN
        Converted spiking network ready to run.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for ANN-to-SNN conversion")

    layers = _extract_layers(model)
    if not layers:
        raise ValueError("No Linear/Conv2d layers found in model")

    weights = [w for w, _ in layers]
    biases = [b for _, b in layers]

    if calibration_data is not None:
        max_acts = _compute_max_activations(model, calibration_data, percentile)  # type: ignore[arg-type]
        # Pad if fewer ReLUs than Linear layers
        while len(max_acts) < len(weights):
            max_acts.append(1.0)
        thresholds = max_acts
    else:
        thresholds = [1.0] * len(weights)

    # Normalize weights: scale so that max activation maps to threshold
    normalized_weights = []
    prev_scale = 1.0
    for i, (w, theta) in enumerate(zip(weights, thresholds)):
        scale = theta / prev_scale if i > 0 else theta
        normalized_weights.append(w / scale)
        prev_scale = theta

    return ConvertedSNN(
        weights=normalized_weights,
        biases=biases,
        thresholds=[1.0] * len(weights),
        T=T,
    )
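The normalization loop in `convert` rescales each layer by the ratio of consecutive thresholds, so every layer can then run with a unit threshold. A standalone NumPy sketch of that step (the thresholds here are made-up calibration values, not real statistics):

```python
import numpy as np

# Hypothetical per-layer calibration thresholds (e.g. 99.9th-percentile activations).
weights = [np.eye(2) * 4.0, np.eye(2) * 6.0]
thresholds = [2.0, 3.0]

# Scale layer i by theta_i / theta_{i-1}; the first layer by theta_0 alone.
normalized, prev = [], 1.0
for i, (w, theta) in enumerate(zip(weights, thresholds)):
    scale = theta / prev if i > 0 else theta
    normalized.append(w / scale)
    prev = theta

print(normalized[0])  # 4.0 / 2.0 = 2.0 on the diagonal
print(normalized[1])  # 6.0 / (3.0 / 2.0) = 4.0 on the diagonal
```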

QCFS Activation

sc_neurocore.conversion.qcfs

QCFS (Quantization-Clip-Floor-Shift) activation function.

Replaces ReLU in the ANN during conversion-aware training or post-hoc conversion. QCFS approximates the rate-coded SNN firing rate as a quantized step function, minimizing conversion error.

Reference: Bu et al. 2022 — "Optimal ANN-SNN Conversion for High-accuracy and Ultra-low-latency Spiking Neural Networks"

QCFSActivation

Bases: Module

QCFS activation: quantized clip-floor-shift ReLU replacement.

For T timesteps and threshold theta:

QCFS(x) = clip(floor(x * T / theta + 0.5), 0, T) * theta / T

This quantizes activations to T+1 levels in [0, theta], matching the achievable spike rates of an IF neuron over T timesteps.

Parameters

T : int
    Number of simulation timesteps.
theta : float
    Firing threshold (default 1.0).
learn_theta : bool
    Make threshold trainable (default False).
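As a quick check of the formula, here is a NumPy version of the forward pass only (without the straight-through gradient trick the module uses for training):

```python
import numpy as np

def qcfs(x: np.ndarray, T: int = 8, theta: float = 1.0) -> np.ndarray:
    """Quantize activations to T + 1 levels in [0, theta]."""
    q = np.floor(x * T / theta + 0.5)    # quantize with a half-step shift
    return np.clip(q, 0, T) * theta / T  # clip to [0, T] and rescale

# Negative inputs clip to 0 (ReLU-like); large inputs saturate at theta.
x = np.array([-0.5, 0.3, 0.5, 2.0])
print(qcfs(x))  # values: 0.0, 0.25, 0.5, 1.0
```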

Source code in src/sc_neurocore/conversion/qcfs.py
Python
class QCFSActivation(nn.Module):
    """QCFS activation: quantized clip-floor-shift ReLU replacement.

    For T timesteps and threshold theta:
        QCFS(x) = clip(floor(x * T / theta + 0.5), 0, T) * theta / T

    This quantizes activations to T+1 levels in [0, theta], matching
    the achievable spike rates of an IF neuron over T timesteps.

    Parameters
    ----------
    T : int
        Number of simulation timesteps.
    theta : float
        Firing threshold (default 1.0).
    learn_theta : bool
        Make threshold trainable (default False).
    """

    def __init__(self, T: int = 8, theta: float = 1.0, learn_theta: bool = False):
        super().__init__()
        self.T = T
        if learn_theta:
            self.theta = nn.Parameter(torch.tensor(theta))
        else:
            self.register_buffer("theta", torch.tensor(theta))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x * self.T / self.theta + 0.5
        # STE: floor in forward, pass gradient straight through
        quantized = scaled.floor() - (scaled.floor() - scaled).detach()
        clipped = quantized.clamp(0, self.T)
        return clipped * self.theta / self.T

    def extra_repr(self) -> str:
        return f"T={self.T}, theta={self.theta.item():.2f}"