
DelayLinear API

Trainable per-synapse delays for temporal coding in SNNs.

DelayLinear is a dense layer where each synapse has a trainable weight AND a trainable delay. During forward pass, the input spike history is queried at fractional delay positions via linear interpolation, making delays differentiable. This implements the DCLS principle (Hammouamri et al. 2023) applied to fully-connected SNN layers.

Why trainable delays?

Rate coding discards temporal structure. Temporal coding (spike timing) carries more information per spike: a single precisely timed spike can convey what would otherwise take many rate-coded spikes to express. Trainable delays let the network learn optimal spike-timing relationships: which input spikes should arrive simultaneously (coincidence detection) and which should be staggered (sequence recognition).

Research shows delays can replace entire layers — same accuracy with fewer parameters (Hammouamri et al. 2023). SC-NeuroCore's DelayLinear makes this practical: train in PyTorch, export integer delays to FPGA via to_nir_delay_array().

Architecture

flowchart TB
    subgraph Buffer["Circular Spike History Buffer"]
        direction LR
        H0["t-0"] --- H1["t-1"] --- H2["t-2"] --- H3["t-3"] --- H4["..."] --- HN["t-max"]
    end
    subgraph Synapse["For each synapse (i → j)"]
        D["delay[j,i] = 2.3<br/>(continuous)"]
        D --> I["interp(t-2, t-3)<br/>0.7 × h[t-2] + 0.3 × h[t-3]"]
        I --> W["× weight[j,i]"]
        W --> O["output[j] += ..."]
    end
    Buffer --> Synapse

    style Buffer fill:#e1f5fe
    style Synapse fill:#fff3e0
Text Only
Interpolation detail — differentiable delay readout:

    spike history:  ──┬──○──●──○──●──○──○──●──○──
                      t  t-1 t-2 t-3 t-4 ...

    delay d = 2.3:        │←─ 2.3 ─→│
                          ▼
    interp(t-2.3) = 0.7 × history[t-2] + 0.3 × history[t-3]
                    (1 - frac) = 0.7 weights the floor bin; frac = 0.3 weights the ceil bin

Gradient flows through the interpolation weights (0.7 and 0.3), telling the optimizer whether to increase or decrease each delay.
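A minimal standalone sketch (plain torch, not the library class) shows the delay gradient emerging from the interpolation:

```python
import torch

# Toy spike history at integer lags: h[k] = activity at time t-k
h = torch.tensor([0.0, 0.0, 1.0, 0.0])  # single spike at lag 2

d = torch.tensor(2.3, requires_grad=True)  # continuous delay
k = d.detach().floor().long()              # floor bin = 2
frac = d - d.floor()                       # fractional part = 0.3 (grad flows)

# Linear interpolation between lags k and k+1
interp = (1 - frac) * h[k] + frac * h[k + 1]
interp.backward()

# d(interp)/dd = h[k+1] - h[k] = 0 - 1 = -1: the optimizer is told
# to decrease d, pulling the readout toward the spike at lag 2
print(interp.item(), d.grad.item())  # ≈ 0.7, -1.0
```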

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `in_features` | `int` | required | Number of input neurons |
| `out_features` | `int` | required | Number of output neurons |
| `max_delay` | `int` | `16` | Maximum delay in timesteps |
| `bias` | `bool` | `False` | Include bias term |
| `learn_delay` | `bool` | `True` | Make delays trainable |
| `init_delay` | `float` | `1.0` | Initial delay for all synapses |

Methods

| Method | Returns | Description |
| --- | --- | --- |
| `step(x)` | `Tensor` | Process one timestep. `x`: `(batch, in)` or `(in,)` |
| `reset()` | `None` | Clear spike history; call between sequences |
| `delays_int` (property) | `LongTensor` | Quantized integer delays for hardware |
| `to_nir_delay_array()` | `ndarray` | Flat `float64` array for `Projection` |
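The step()/reset() contract can be illustrated with a small integer-delay reference class (a standalone sketch of the circular-buffer idea, not the library implementation):

```python
import torch

# Minimal integer-delay reference for the step()/reset() contract:
# a circular history buffer read back at per-synapse integer lags.
class IntDelayRef:
    def __init__(self, n_in, n_out, max_delay):
        self.buf = torch.zeros(max_delay + 1, n_in)  # circular history
        self.W = torch.randn(n_out, n_in) * 0.1
        self.D = torch.randint(0, max_delay + 1, (n_out, n_in))
        self.t = 0

    def reset(self):
        self.buf.zero_()
        self.t = 0

    def step(self, x):
        L = self.buf.shape[0]
        self.buf[self.t % L] = x                # write newest timestep
        idx = (self.t - self.D) % L             # per-synapse read position
        cols = torch.arange(self.buf.shape[1])
        delayed = self.buf[idx, cols]           # (n_out, n_in) delayed spikes
        self.t += 1
        return (self.W * delayed).sum(dim=1)    # (n_out,) weighted current

layer = IntDelayRef(n_in=8, n_out=4, max_delay=4)
layer.reset()                                   # call between sequences
out = layer.step((torch.rand(8) > 0.8).float())  # (4,)
```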

Example: sequence classifier with delays

Python
import torch
import torch.nn as nn
from sc_neurocore.training import LIFCell, DelayLinear, atan_surrogate

class DelayedSNN(nn.Module):
    def __init__(self, n_in, n_hidden, n_out, max_delay=8):
        super().__init__()
        self.delay1 = DelayLinear(n_in, n_hidden, max_delay=max_delay)
        self.lif1 = LIFCell(beta=0.9)
        self.fc2 = nn.Linear(n_hidden, n_out)
        self.lif2 = LIFCell(beta=0.9)

    def forward(self, x):
        """x: (T, batch, n_in)"""
        T, batch, _ = x.shape
        v1 = torch.zeros(batch, self.delay1.out_features, device=x.device)
        v2 = torch.zeros(batch, self.fc2.out_features, device=x.device)
        spike_sum = torch.zeros(batch, self.fc2.out_features, device=x.device)

        self.delay1.reset()
        for t in range(T):
            h = self.delay1.step(x[t])
            spike, v1 = self.lif1(h, v1)
            h = self.fc2(spike)
            spike, v2 = self.lif2(h, v2)
            spike_sum += spike

        return spike_sum

model = DelayedSNN(n_in=16, n_hidden=64, n_out=5, max_delay=8)
x = (torch.rand(30, 4, 16) > 0.8).float()  # T=30, batch=4, sparse binary spikes
out = model(x)  # (4, 5)
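A single optimizer step then follows the usual rate-decoding recipe, with the spike counts used as logits. Sketched below with a hypothetical stand-in count model so the snippet runs on its own; with the library installed, the DelayedSNN above drops in the same way:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for DelayedSNN: any module mapping
# (T, batch, n_in) -> (batch, n_out) non-negative "spike counts".
class CountModel(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.fc = nn.Linear(n_in, n_out)

    def forward(self, x):                      # x: (T, batch, n_in)
        return self.fc(x).relu().sum(dim=0)    # (batch, n_out) counts

model = CountModel(16, 5)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()   # counts act as logits (rate decoding)

x = torch.randn(30, 4, 16)        # T=30, batch=4
y = torch.randint(0, 5, (4,))     # class labels

opt.zero_grad()
loss = loss_fn(model(x), y)       # highest-count class = prediction
loss.backward()
opt.step()
```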

Hardware export

Python
# After training
int_delays = model.delay1.delays_int  # (n_hidden, n_in) integer tensor
nir_delays = model.delay1.to_nir_delay_array()  # flat float64

# Use with network engine
from sc_neurocore.network.projection import Projection
proj = Projection(src_pop, tgt_pop, weight=0.1, delay=nir_delays)
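What the export accessors compute can be reproduced with plain tensors (a sketch of the clamp-round-flatten behaviour documented above, not the library code itself):

```python
import torch
import numpy as np

# Continuous trained delays, shape (out_features, in_features);
# one value deliberately exceeds max_delay to show the clamp.
max_delay = 8
delay = torch.tensor([[0.2, 2.7, 5.5],
                      [7.9, 3.1, 9.4]])

ints = delay.clamp(0, max_delay).round().long()   # quantized, in [0, 8]
flat = ints.numpy().flatten().astype(np.float64)  # row-major, CSR data order

print(ints.tolist())  # [[0, 3, 6], [8, 3, 8]]
```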

Deployable checkpoint selection

For FPGA or ASIC deployment, the checkpoint-selection metric must match the hardware constraint. A delay model trained with fractional positions and a wide interpolation kernel can score well in native PyTorch validation while still being a poor hardware candidate after delays are rounded to integer timesteps. The deployable validation path therefore uses the same conditions as exported RTL:

  1. save the current training delay and kernel state;
  2. round trainable delay positions to integer timesteps;
  3. set the DCLS max kernel bandwidth (sigma) to 0;
  4. run validation and record the score as fpga_val_acc;
  5. restore the training state before continuing optimisation.
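The five steps can be sketched as a helper function (hypothetical hook names; the real script wires these to the SHD evaluator and DCLS kernel):

```python
import copy
import torch
import torch.nn as nn

def deployable_validation(model, validate, round_delays, set_sigma):
    """Score a checkpoint under export conditions, then restore.

    `validate`, `round_delays`, `set_sigma` are hypothetical hooks
    standing in for the evaluator and kernel controls.
    """
    saved = copy.deepcopy(model.state_dict())  # 1. save training state
    round_delays(model)                        # 2. integer timesteps
    set_sigma(model, 0.0)                      # 3. zero kernel bandwidth
    with torch.no_grad():
        fpga_val_acc = validate(model)         # 4. score as deployed
    model.load_state_dict(saved)               # 5. restore and continue
    return fpga_val_acc
```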

data/masquelier_shd/train_dcls_max.py follows this protocol for the SHD dcls_max experiments. The script writes both native validation accuracy and fpga_val_acc to training_log.csv; best.pth is selected by fpga_val_acc, while native validation remains a diagnostic ceiling.

The upstream SHD evaluator rounds model positions in place during evaluation, so the script saves and restores the full model state around native validation and test scoring unless persistent rounding is explicitly enabled. This keeps deployable scoring separate from the optimiser trajectory.

Relevant environment variables:

| Variable | Default | Purpose |
| --- | --- | --- |
| `SHD_SIGMA_INIT` | `15.0` | Start of the DCLS max bandwidth schedule. |
| `SHD_SIGMA_FINAL` | `0.0` | End of the bandwidth schedule. |
| `SHD_ROUND_EACH_EPOCH` | `0` | Set to `1` to persistently round delays after every epoch. |
| `SHD_SEED` | config seed | Override the deterministic training seed for sweeps. |
| `SHD_OUTPUT_SUBDIR` | `dcls_max` | Output directory below `exp/SHD/SNN_axonal_feedforward_delays/`. |
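For a seed sweep, the overrides can be built per run and passed to the launched process (an illustrative helper, not part of the script):

```python
import os

def sweep_env(seed: int) -> dict:
    """Environment overrides for one sweep run (hypothetical helper).

    The returned dict is meant to be passed as
    subprocess.run([...], env=...) for one training run.
    """
    return dict(os.environ,
                SHD_SEED=str(seed),
                SHD_OUTPUT_SUBDIR=f"dcls_max_seed{seed}")

envs = [sweep_env(s) for s in (1, 2, 3)]
```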

References

  • Hammouamri, Khalfaoui-Hassani & Masquelier, "Learning Delays in Spiking Neural Networks using Dilated Convolutions with Learnable Spacings", ICLR 2024
  • Göltz et al., "DelGrad: Exact Gradients in Spiking Networks for Learning Transmission Delays and Weights", arXiv 2024
  • Sun, Zeng, Fang & Li, "Learnable Axonal Delay in Spiking Neural Networks for Adaptive Temporal Representation", AAAI 2025

See Tutorial 39: Learnable Delays and Training API Reference.

sc_neurocore.training.delay_linear

Learnable-delay linear layer for surrogate gradient SNN training.

Implements the DCLS (Dilated Convolutions with Learnable Spacings) principle applied to fully-connected SNN layers. Each synapse has a trainable weight AND a trainable delay. During forward pass, the input spike history is queried at fractional delay positions via linear interpolation, making delays differentiable.

References

Hammouamri et al. 2023 — "Learning Delays in SNNs using DCLS"
Göltz et al. 2024 — "DelGrad: exact gradients for delays on BrainScaleS-2"

DelayLinear

Bases: Module

Linear layer with trainable per-synapse delays.

Parameters

in_features : int
    Number of input neurons.
out_features : int
    Number of output neurons.
max_delay : int
    Maximum delay in timesteps. Delay buffer stores this many steps.
bias : bool
    Include bias term (default False for SNN).
learn_delay : bool
    Make delays trainable (default True).
init_delay : float
    Initial delay value for all synapses (default 1.0).

Forward pass

Call step(input_spikes) at each timestep. The module maintains an internal spike history buffer. For each synapse (i, j):

Text Only
delayed_input[j] = sum_i W[j,i] * interp(history, t - D[j,i])

where interp linearly interpolates between integer delay bins, making D differentiable.

Export

delays_int returns quantized integer delays for hardware deployment. to_nir_delay_array() returns delays in the CSR format expected by Projection(delay=array).

Source code in src/sc_neurocore/training/delay_linear.py
Python
class DelayLinear(nn.Module):
    """Linear layer with trainable per-synapse delays.

    Parameters
    ----------
    in_features : int
        Number of input neurons.
    out_features : int
        Number of output neurons.
    max_delay : int
        Maximum delay in timesteps. Delay buffer stores this many steps.
    bias : bool
        Include bias term (default False for SNN).
    learn_delay : bool
        Make delays trainable (default True).
    init_delay : float
        Initial delay value for all synapses (default 1.0).

    Forward pass
    -------------
    Call ``step(input_spikes)`` at each timestep. The module maintains
    an internal spike history buffer. For each synapse (i, j):

        delayed_input[j] = sum_i W[j,i] * interp(history, t - D[j,i])

    where interp linearly interpolates between integer delay bins,
    making D differentiable.

    Export
    ------
    ``delays_int`` returns quantized integer delays for hardware deployment.
    ``to_nir_delay_array()`` returns delays in the CSR format expected by
    ``Projection(delay=array)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        max_delay: int = 16,
        bias: bool = False,
        learn_delay: bool = True,
        init_delay: float = 1.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.max_delay = max_delay

        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        # Delays stored as continuous values in [0, max_delay)
        delay_init = torch.full((out_features, in_features), init_delay)
        if learn_delay:
            self.delay = nn.Parameter(delay_init)
        else:
            self.register_buffer("delay", delay_init)

        # Spike history buffer: (max_delay + 1, in_features)
        self.register_buffer("_history", torch.zeros(max_delay + 1, in_features))
        self._t = 0

    def reset(self) -> None:
        """Clear spike history. Call between sequences."""
        self._history.zero_()  # type: ignore[operator]
        self._t = 0

    def step(self, x: torch.Tensor) -> torch.Tensor:
        """Process one timestep.

        Parameters
        ----------
        x : Tensor of shape (batch, in_features) or (in_features,)
            Input spikes (binary or continuous).

        Returns
        -------
        Tensor of shape (batch, out_features) or (out_features,)
            Weighted, delayed input current.
        """
        squeeze = x.dim() == 1
        if squeeze:
            x = x.unsqueeze(0)

        batch_size = x.shape[0]
        buf_len = self.max_delay + 1

        # Store current input in history (use first batch element for buffer)
        write_idx = self._t % buf_len
        self._history[write_idx] = x[0].detach()  # type: ignore[operator]

        # Clamp delays to valid range
        d = self.delay.clamp(0, self.max_delay - 1e-6)

        # Integer floor and ceil indices
        d_floor = d.long()
        d_ceil = (d_floor + 1).clamp(max=self.max_delay)
        frac = d - d_floor.float()

        # Read from history at delayed positions
        # idx_floor[j, i] = (current_t - d_floor[j, i]) % buf_len
        idx_floor = (self._t - d_floor) % buf_len
        idx_ceil = (self._t - d_ceil) % buf_len

        # Gather delayed spikes via linear interpolation
        # history shape: (buf_len, in_features)
        # We need history[idx[j,i], i] for each (j, i)
        hist_floor = self._history[idx_floor, torch.arange(self.in_features).unsqueeze(0)]  # type: ignore[index]
        hist_ceil = self._history[idx_ceil, torch.arange(self.in_features).unsqueeze(0)]  # type: ignore[index]
        delayed_x = (1 - frac) * hist_floor + frac * hist_ceil

        # Weighted sum: out[j] = sum_i W[j,i] * delayed_x[j,i]
        output = (self.weight * delayed_x).sum(dim=1)
        if self.bias is not None:
            output = output + self.bias

        self._t += 1

        # Broadcast to batch
        output = output.unsqueeze(0).expand(batch_size, -1)
        if squeeze:
            output = output.squeeze(0)
        return output

    @property
    def delays_int(self) -> torch.Tensor:
        """Quantized integer delays for hardware export."""
        with torch.no_grad():
            return self.delay.clamp(0, self.max_delay).round().long()

    def to_nir_delay_array(self) -> Any:
        """Export delays as flat array matching CSR data order.

        Returns delays in the same order as weights flattened row-major,
        suitable for ``Projection(delay=array)``.
        """
        import numpy as np

        return self.delays_int.detach().cpu().numpy().flatten().astype(np.float64)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"max_delay={self.max_delay}, learn_delay={isinstance(self.delay, nn.Parameter)}"
        )

delays_int property

Quantized integer delays for hardware export.

reset()

Clear spike history. Call between sequences.

Source code in src/sc_neurocore/training/delay_linear.py
Python
def reset(self) -> None:
    """Clear spike history. Call between sequences."""
    self._history.zero_()  # type: ignore[operator]
    self._t = 0

step(x)

Process one timestep.

Parameters

x : Tensor of shape (batch, in_features) or (in_features,)
    Input spikes (binary or continuous).

Returns

Tensor of shape (batch, out_features) or (out_features,)
    Weighted, delayed input current.

Source code in src/sc_neurocore/training/delay_linear.py
Python
def step(self, x: torch.Tensor) -> torch.Tensor:
    """Process one timestep.

    Parameters
    ----------
    x : Tensor of shape (batch, in_features) or (in_features,)
        Input spikes (binary or continuous).

    Returns
    -------
    Tensor of shape (batch, out_features) or (out_features,)
        Weighted, delayed input current.
    """
    squeeze = x.dim() == 1
    if squeeze:
        x = x.unsqueeze(0)

    batch_size = x.shape[0]
    buf_len = self.max_delay + 1

    # Store current input in history (use first batch element for buffer)
    write_idx = self._t % buf_len
    self._history[write_idx] = x[0].detach()  # type: ignore[operator]

    # Clamp delays to valid range
    d = self.delay.clamp(0, self.max_delay - 1e-6)

    # Integer floor and ceil indices
    d_floor = d.long()
    d_ceil = (d_floor + 1).clamp(max=self.max_delay)
    frac = d - d_floor.float()

    # Read from history at delayed positions
    # idx_floor[j, i] = (current_t - d_floor[j, i]) % buf_len
    idx_floor = (self._t - d_floor) % buf_len
    idx_ceil = (self._t - d_ceil) % buf_len

    # Gather delayed spikes via linear interpolation
    # history shape: (buf_len, in_features)
    # We need history[idx[j,i], i] for each (j, i)
    hist_floor = self._history[idx_floor, torch.arange(self.in_features).unsqueeze(0)]  # type: ignore[index]
    hist_ceil = self._history[idx_ceil, torch.arange(self.in_features).unsqueeze(0)]  # type: ignore[index]
    delayed_x = (1 - frac) * hist_floor + frac * hist_ceil

    # Weighted sum: out[j] = sum_i W[j,i] * delayed_x[j,i]
    output = (self.weight * delayed_x).sum(dim=1)
    if self.bias is not None:
        output = output + self.bias

    self._t += 1

    # Broadcast to batch
    output = output.unsqueeze(0).expand(batch_size, -1)
    if squeeze:
        output = output.squeeze(0)
    return output

to_nir_delay_array()

Export delays as flat array matching CSR data order.

Returns delays in the same order as weights flattened row-major, suitable for Projection(delay=array).

Source code in src/sc_neurocore/training/delay_linear.py
Python
def to_nir_delay_array(self) -> Any:
    """Export delays as flat array matching CSR data order.

    Returns delays in the same order as weights flattened row-major,
    suitable for ``Projection(delay=array)``.
    """
    import numpy as np

    return self.delays_int.detach().cpu().numpy().flatten().astype(np.float64)