
DelayLinear API

Trainable per-synapse delays for temporal coding in SNNs.

DelayLinear is a dense layer in which each synapse has a trainable weight AND a trainable delay. During the forward pass, the input spike history is queried at fractional delay positions via linear interpolation, making the delays differentiable. This implements the DCLS principle (Hammouamri et al. 2023) applied to fully-connected SNN layers.

Why trainable delays?

Rate coding discards temporal structure. Temporal coding (spike timing) carries more information per spike — a single precisely-timed spike encodes as much as hundreds of rate-coded spikes. Trainable delays let the network learn optimal spike timing relationships: which input spikes should arrive simultaneously (coincidence detection) and which should be staggered (sequence recognition).

Research shows delays can replace entire layers — same accuracy with fewer parameters (Hammouamri et al. 2023). SC-NeuroCore's DelayLinear makes this practical: train in PyTorch, export integer delays to FPGA via to_nir_delay_array().
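As a toy illustration of why delays matter for coincidence detection (plain Python with made-up spike times and delays, not the SC-NeuroCore API): two inputs that spike at different times can be made to arrive at the target neuron simultaneously once the right delay is learned.

```python
# Toy illustration of coincidence detection via delays (hypothetical values):
# input 0 spikes at t=1, input 1 at t=3. A learned delay of 2 on input 0
# makes both spikes arrive at the target at t=3.
spikes = {0: [1], 1: [3]}   # input index -> spike times
delays = {0: 2, 1: 0}       # per-synapse integer delays

arrivals = {}
for i, times in spikes.items():
    for t in times:
        arrivals.setdefault(t + delays[i], []).append(i)

print(arrivals)  # {3: [0, 1]} -> both inputs coincide at t=3
```

With delays fixed at zero, the target would see two isolated sub-threshold events; with the learned delay, it sees a coincidence it can fire on.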

Architecture

flowchart TB
    subgraph Buffer["Circular Spike History Buffer"]
        direction LR
        H0["t-0"] --- H1["t-1"] --- H2["t-2"] --- H3["t-3"] --- H4["..."] --- HN["t-max"]
    end
    subgraph Synapse["For each synapse (i → j)"]
        D["delay[j,i] = 2.3<br/>(continuous)"]
        D --> I["interp(t-2, t-3)<br/>0.7 × h[t-2] + 0.3 × h[t-3]"]
        I --> W["× weight[j,i]"]
        W --> O["output[j] += ..."]
    end
    Buffer --> Synapse

    style Buffer fill:#e1f5fe
    style Synapse fill:#fff3e0
Interpolation detail — differentiable delay readout:

    spike history:  ──┬──○──●──○──●──○──○──●──○──
                      t  t-1 t-2 t-3 t-4 ...

    delay d = 2.3:        │←─ 2.3 ─→│
                          ▼
    interp(t-2.3) = 0.7 × history[t-2] + 0.3 × history[t-3]
                    └── floor fraction ──┘   └── ceil fraction ──┘

Gradient flows through the interpolation weights (0.7 and 0.3), telling the optimizer whether to increase or decrease each delay.
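A minimal standalone sketch of this readout (illustrative history values, not the DelayLinear internals verbatim) shows the gradient path:

```python
import torch

# One synapse, one timestep: history[k] holds the input k steps in the past.
history = torch.tensor([0.0, 1.0, 0.5, 0.2])
d = torch.tensor(2.3, requires_grad=True)   # continuous delay

d_floor = d.detach().long()                 # 2 (no gradient through floor)
frac = d - d_floor                          # 0.3, carries the gradient
read = (1 - frac) * history[d_floor] + frac * history[d_floor + 1]
read.backward()

# read = 0.7 * 0.5 + 0.3 * 0.2 = 0.41
# d(read)/dd = history[3] - history[2] = -0.3, telling the optimizer
# whether shifting the delay later or earlier increases the readout.
print(read.item(), d.grad.item())
```

Note that the floor index is detached: only the fractional part `frac` is differentiable, which is exactly the interpolation-weight gradient described above.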

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| `in_features` | `int` | required | Number of input neurons |
| `out_features` | `int` | required | Number of output neurons |
| `max_delay` | `int` | `16` | Maximum delay in timesteps |
| `bias` | `bool` | `False` | Include bias term |
| `learn_delay` | `bool` | `True` | Make delays trainable |
| `init_delay` | `float` | `1.0` | Initial delay for all synapses |

Methods

| Method / Property | Returns | Description |
|---|---|---|
| `step(x)` | `Tensor` | Process one timestep. `x`: `(batch, in)` or `(in,)` |
| `reset()` | `None` | Clear spike history. Call between sequences |
| `delays_int` (property) | `LongTensor` | Quantized integer delays for hardware |
| `to_nir_delay_array()` | `ndarray` | Flat `float64` array for `Projection` |

Example: sequence classifier with delays

import torch
import torch.nn as nn
from sc_neurocore.training import LIFCell, DelayLinear

class DelayedSNN(nn.Module):
    def __init__(self, n_in, n_hidden, n_out, max_delay=8):
        super().__init__()
        self.delay1 = DelayLinear(n_in, n_hidden, max_delay=max_delay)
        self.lif1 = LIFCell(beta=0.9)
        self.fc2 = nn.Linear(n_hidden, n_out)
        self.lif2 = LIFCell(beta=0.9)

    def forward(self, x):
        """x: (T, batch, n_in)"""
        T, batch, _ = x.shape
        v1 = torch.zeros(batch, self.delay1.out_features, device=x.device)
        v2 = torch.zeros(batch, self.fc2.out_features, device=x.device)
        spike_sum = torch.zeros(batch, self.fc2.out_features, device=x.device)

        self.delay1.reset()
        for t in range(T):
            h = self.delay1.step(x[t])
            spike, v1 = self.lif1(h, v1)
            h = self.fc2(spike)
            spike, v2 = self.lif2(h, v2)
            spike_sum += spike

        return spike_sum

model = DelayedSNN(n_in=16, n_hidden=64, n_out=5, max_delay=8)
x = torch.randn(30, 4, 16)  # T=30, batch=4
out = model(x)  # (4, 5)
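To see delay learning converge end to end, here is a hedged toy experiment in plain PyTorch (hypothetical signal, shift, and hyperparameters; it reimplements the interpolated readout inline rather than using SC-NeuroCore): a single delay is trained by gradient descent until the delayed readout matches a target shifted by 5 steps.

```python
import torch

# Learn one delay so the interpolated readout matches a target shifted
# by 5 steps. Signal, shift, and hyperparameters are made up for the demo.
true_shift = 5
n = torch.arange(48.0)
signal = torch.sin(0.3 * n)                 # smooth toy input trace

d = torch.tensor(2.0, requires_grad=True)   # initial delay guess
opt = torch.optim.SGD([d], lr=0.5)

t = torch.arange(9, 41)                     # readout positions
for _ in range(200):
    d_floor = d.detach().long()
    frac = d - d_floor
    # Same interpolation rule as DelayLinear: blend the two neighbouring bins.
    read = (1 - frac) * signal[t - d_floor] + frac * signal[t - d_floor - 1]
    loss = ((read - signal[t - true_shift]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    with torch.no_grad():
        d.clamp_(0, 8)                      # keep delay in a valid range

print(round(d.item(), 2))  # converges near 5.0
```

The clamp after each step mirrors what DelayLinear does internally, keeping the delay inside the buffer range while the interpolation gradient walks it toward the correct shift.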

Hardware export

# After training
int_delays = model.delay1.delays_int  # (n_hidden, n_in) integer tensor
nir_delays = model.delay1.to_nir_delay_array()  # flat float64

# Use with network engine
from sc_neurocore.network.projection import Projection
proj = Projection(src_pop, tgt_pop, weight=0.1, delay=nir_delays)

References

  • Hammouamri, Khalfaoui-Hassani & Masquelier, "Learning Delays in Spiking Neural Networks using Dilated Convolutions with Learnable Spacings", ICLR 2024 (arXiv 2023)
  • Göltz et al., "DelGrad: Exact event-based gradients in spiking networks for training delays and weights", arXiv 2024
  • Sun et al., "Learnable Axonal Delay in Spiking Neural Networks for Adaptive Temporal Representation", AAAI 2025

See Tutorial 39: Learnable Delays and Training API Reference.

sc_neurocore.training.delay_linear

Learnable-delay linear layer for surrogate gradient SNN training.

Implements the DCLS (Dilated Convolutions with Learnable Spacings) principle applied to fully-connected SNN layers. Each synapse has a trainable weight AND a trainable delay. During the forward pass, the input spike history is queried at fractional delay positions via linear interpolation, making the delays differentiable.

References

Hammouamri et al. 2023 — "Learning Delays in SNNs using DCLS"
Göltz et al. 2024 — "DelGrad: exact gradients for delays on BrainScaleS-2"

DelayLinear

Bases: Module

Linear layer with trainable per-synapse delays.

Parameters

in_features : int Number of input neurons. out_features : int Number of output neurons. max_delay : int Maximum delay in timesteps. Delay buffer stores this many steps. bias : bool Include bias term (default False for SNN). learn_delay : bool Make delays trainable (default True). init_delay : float Initial delay value for all synapses (default 1.0).

Forward pass

Call step(input_spikes) at each timestep. The module maintains an internal spike history buffer. For each synapse (i, j):

delayed_input[j] = sum_i W[j,i] * interp(history, t - D[j,i])

where interp linearly interpolates between integer delay bins, making D differentiable.

Export

delays_int returns quantized integer delays for hardware deployment. to_nir_delay_array() returns delays in the CSR format expected by Projection(delay=array).
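The quantization behind this export is simple enough to sketch standalone (made-up delay values; this mirrors the clamp-round-flatten path rather than calling the module):

```python
import torch

# Mirrors the export path: clamp continuous trained delays to
# [0, max_delay], round to integers, then flatten row-major as float64.
max_delay = 8
delay = torch.tensor([[1.7, 0.2], [7.9, 9.4]])   # hypothetical trained delays

delays_int = delay.clamp(0, max_delay).round().long()
nir = delays_int.numpy().flatten().astype("float64")

print(delays_int.tolist())  # [[2, 0], [8, 8]]
print(nir.tolist())         # [2.0, 0.0, 8.0, 8.0]
```

Note that out-of-range delays (here 9.4) are clamped to `max_delay` before rounding, so the exported array is always hardware-valid.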

Source code in src/sc_neurocore/training/delay_linear.py
class DelayLinear(nn.Module):
    """Linear layer with trainable per-synapse delays.

    Parameters
    ----------
    in_features : int
        Number of input neurons.
    out_features : int
        Number of output neurons.
    max_delay : int
        Maximum delay in timesteps. Delay buffer stores this many steps.
    bias : bool
        Include bias term (default False for SNN).
    learn_delay : bool
        Make delays trainable (default True).
    init_delay : float
        Initial delay value for all synapses (default 1.0).

    Forward pass
    -------------
    Call ``step(input_spikes)`` at each timestep. The module maintains
    an internal spike history buffer. For each synapse (i, j):

        delayed_input[j] = sum_i W[j,i] * interp(history, t - D[j,i])

    where interp linearly interpolates between integer delay bins,
    making D differentiable.

    Export
    ------
    ``delays_int`` returns quantized integer delays for hardware deployment.
    ``to_nir_delay_array()`` returns delays in the CSR format expected by
    ``Projection(delay=array)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        max_delay: int = 16,
        bias: bool = False,
        learn_delay: bool = True,
        init_delay: float = 1.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.max_delay = max_delay

        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        # Delays stored as continuous values in [0, max_delay)
        delay_init = torch.full((out_features, in_features), init_delay)
        if learn_delay:
            self.delay = nn.Parameter(delay_init)
        else:
            self.register_buffer("delay", delay_init)

        # Spike history buffer: (max_delay + 1, in_features)
        self.register_buffer("_history", torch.zeros(max_delay + 1, in_features))
        self._t = 0

    def reset(self):
        """Clear spike history. Call between sequences."""
        self._history.zero_()
        self._t = 0

    def step(self, x: torch.Tensor) -> torch.Tensor:
        """Process one timestep.

        Parameters
        ----------
        x : Tensor of shape (batch, in_features) or (in_features,)
            Input spikes (binary or continuous).

        Returns
        -------
        Tensor of shape (batch, out_features) or (out_features,)
            Weighted, delayed input current.
        """
        squeeze = x.dim() == 1
        if squeeze:
            x = x.unsqueeze(0)

        batch_size = x.shape[0]
        buf_len = self.max_delay + 1

        # Store current input in history (use first batch element for buffer)
        write_idx = self._t % buf_len
        self._history[write_idx] = x[0].detach()

        # Clamp delays to valid range
        d = self.delay.clamp(0, self.max_delay - 1e-6)

        # Integer floor and ceil indices
        d_floor = d.long()
        d_ceil = (d_floor + 1).clamp(max=self.max_delay)
        frac = d - d_floor.float()

        # Read from history at delayed positions
        # idx_floor[j, i] = (current_t - d_floor[j, i]) % buf_len
        idx_floor = (self._t - d_floor) % buf_len
        idx_ceil = (self._t - d_ceil) % buf_len

        # Gather delayed spikes via linear interpolation
        # history shape: (buf_len, in_features)
        # We need history[idx[j,i], i] for each (j, i)
        hist_floor = self._history[idx_floor, torch.arange(self.in_features).unsqueeze(0)]
        hist_ceil = self._history[idx_ceil, torch.arange(self.in_features).unsqueeze(0)]
        delayed_x = (1 - frac) * hist_floor + frac * hist_ceil

        # Weighted sum: out[j] = sum_i W[j,i] * delayed_x[j,i]
        output = (self.weight * delayed_x).sum(dim=1)
        if self.bias is not None:
            output = output + self.bias

        self._t += 1

        # Broadcast to batch
        output = output.unsqueeze(0).expand(batch_size, -1)
        if squeeze:
            output = output.squeeze(0)
        return output

    @property
    def delays_int(self) -> torch.Tensor:
        """Quantized integer delays for hardware export."""
        with torch.no_grad():
            return self.delay.clamp(0, self.max_delay).round().long()

    def to_nir_delay_array(self):
        """Export delays as flat array matching CSR data order.

        Returns delays in the same order as weights flattened row-major,
        suitable for ``Projection(delay=array)``.
        """
        import numpy as np

        return self.delays_int.detach().cpu().numpy().flatten().astype(np.float64)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"max_delay={self.max_delay}, learn_delay={isinstance(self.delay, nn.Parameter)}"
        )

delays_int property

Quantized integer delays for hardware export.

reset()

Clear spike history. Call between sequences.

Source code in src/sc_neurocore/training/delay_linear.py
def reset(self):
    """Clear spike history. Call between sequences."""
    self._history.zero_()
    self._t = 0

step(x)

Process one timestep.

Parameters

x : Tensor of shape (batch, in_features) or (in_features,) Input spikes (binary or continuous).

Returns

Tensor of shape (batch, out_features) or (out_features,) Weighted, delayed input current.

Source code in src/sc_neurocore/training/delay_linear.py
def step(self, x: torch.Tensor) -> torch.Tensor:
    """Process one timestep.

    Parameters
    ----------
    x : Tensor of shape (batch, in_features) or (in_features,)
        Input spikes (binary or continuous).

    Returns
    -------
    Tensor of shape (batch, out_features) or (out_features,)
        Weighted, delayed input current.
    """
    squeeze = x.dim() == 1
    if squeeze:
        x = x.unsqueeze(0)

    batch_size = x.shape[0]
    buf_len = self.max_delay + 1

    # Store current input in history (use first batch element for buffer)
    write_idx = self._t % buf_len
    self._history[write_idx] = x[0].detach()

    # Clamp delays to valid range
    d = self.delay.clamp(0, self.max_delay - 1e-6)

    # Integer floor and ceil indices
    d_floor = d.long()
    d_ceil = (d_floor + 1).clamp(max=self.max_delay)
    frac = d - d_floor.float()

    # Read from history at delayed positions
    # idx_floor[j, i] = (current_t - d_floor[j, i]) % buf_len
    idx_floor = (self._t - d_floor) % buf_len
    idx_ceil = (self._t - d_ceil) % buf_len

    # Gather delayed spikes via linear interpolation
    # history shape: (buf_len, in_features)
    # We need history[idx[j,i], i] for each (j, i)
    hist_floor = self._history[idx_floor, torch.arange(self.in_features).unsqueeze(0)]
    hist_ceil = self._history[idx_ceil, torch.arange(self.in_features).unsqueeze(0)]
    delayed_x = (1 - frac) * hist_floor + frac * hist_ceil

    # Weighted sum: out[j] = sum_i W[j,i] * delayed_x[j,i]
    output = (self.weight * delayed_x).sum(dim=1)
    if self.bias is not None:
        output = output + self.bias

    self._t += 1

    # Broadcast to batch
    output = output.unsqueeze(0).expand(batch_size, -1)
    if squeeze:
        output = output.squeeze(0)
    return output

to_nir_delay_array()

Export delays as flat array matching CSR data order.

Returns delays in the same order as weights flattened row-major, suitable for Projection(delay=array).

Source code in src/sc_neurocore/training/delay_linear.py
def to_nir_delay_array(self):
    """Export delays as flat array matching CSR data order.

    Returns delays in the same order as weights flattened row-major,
    suitable for ``Projection(delay=array)``.
    """
    import numpy as np

    return self.delays_int.detach().cpu().numpy().flatten().astype(np.float64)