
Training API Reference

GPU-accelerated SNN training with surrogate gradients and SC bitstream export.

Install: pip install sc-neurocore[training] (adds PyTorch ≥ 2.0)

from sc_neurocore.training import LIFCell, SpikingNet, train_epoch, evaluate

All modules are torch.nn.Module subclasses. Train with standard PyTorch optimizers and loss functions, then export weights to stochastic computing bitstreams via to_sc_weights().

End-to-end pipeline

flowchart LR
    A[Raw Data<br/>images, audio, events] --> B[Spike Encoder<br/>rate / latency / delta]
    B --> C[SpikingNet<br/>Linear → LIFCell × N]
    C --> D{Loss + Backward<br/>surrogate gradient}
    D --> |optimizer.step| C
    D --> E[to_sc_weights<br/>normalize to 0,1]
    E --> F[SC Bitstream<br/>SCDenseLayer]
    E --> G[Verilog RTL<br/>equation compiler]
    G --> H[FPGA Bitstream<br/>Yosys + nextpnr]

    style A fill:#e1f5fe
    style C fill:#fff3e0
    style E fill:#e8f5e9
    style H fill:#fce4ec

Module hierarchy

classDiagram
    class nn_Module["torch.nn.Module"]

    class LIFCell {
        beta: float
        threshold: float
        surrogate_fn: Callable
        forward(current, v) → spike, v
    }
    class IFCell {
        forward(current, v) → spike, v
    }
    class SynapticCell {
        alpha: float
        forward(current, i_syn, v) → spike, i_syn, v
    }
    class ALIFCell {
        rho: float
        beta_adapt: float
        forward(current, v, a) → spike, v, a
    }
    class ExpIFCell {
        delta_t: float
        v_rh: float
        forward(current, v) → spike, v
    }
    class AdExCell {
        a: float
        b: float
        forward(current, v, w) → spike, v, w
    }
    class SpikingNet {
        linears: ModuleList
        lifs: ModuleList
        forward(x) → spike_counts, mem_acc
        to_sc_weights() → List
    }
    class ConvSpikingNet {
        conv1, conv2: Conv2d
        fc1, fc2: Linear
        forward(x) → spike_counts, mem_acc
    }
    class DelayLinear {
        delay: Parameter
        step(x) → current
        delays_int → LongTensor
    }

    nn_Module <|-- LIFCell
    nn_Module <|-- IFCell
    nn_Module <|-- SynapticCell
    nn_Module <|-- ALIFCell
    nn_Module <|-- ExpIFCell
    nn_Module <|-- AdExCell
    nn_Module <|-- SpikingNet
    nn_Module <|-- ConvSpikingNet
    nn_Module <|-- DelayLinear

Surrogate Gradient Functions

Surrogate gradients solve the non-differentiability of spike generation. Forward pass: Heaviside step (x > 0) → {0, 1}. Backward pass: smooth approximation of the Dirac delta. All functions expect pre-shifted input x = v - threshold.

flowchart LR
    subgraph Forward["Forward Pass"]
        direction TB
        F1["v - threshold"] --> F2{"x > 0 ?"}
        F2 -->|Yes| F3["spike = 1"]
        F2 -->|No| F4["spike = 0"]
    end
    subgraph Backward["Backward Pass (surrogate)"]
        direction TB
        B1["grad_output"] --> B2["× surrogate'(x)"]
        B2 --> B3["grad_input"]
    end
    Forward -.->|"gradient path<br/>(smooth approx)"| Backward

    style Forward fill:#e8f5e9
    style Backward fill:#fff3e0

Surrogate gradient shapes (backward pass — gradient magnitude vs distance from threshold):

    gradient
    ▲
1.0 │  ╱╲   atan (wide, stable)
    │ ╱  ╲
0.5 │╱    ╲╱╲  fast_sigmoid (sharp, fast)
    │       ╲
0.0 │────────╲─────────────► x = v - threshold
   -3  -2  -1   0   1   2   3

    Wider gradient window = more neurons receive learning signal
    Narrower = sharper threshold, faster convergence but less stable
| Function | Backward formula | Default param | Citation |
|---|---|---|---|
| atan_surrogate(x, alpha=2.0) | α / (2(1 + (παx/2)²)) | α=2.0 | Fang et al. 2021 |
| fast_sigmoid(x, slope=25.0) | slope / (1 + slope·\|x\|)² | slope=25.0 | Zenke & Vogels 2021 |
| superspike(x, beta=10.0) | 1 / (1 + β·\|x\|)² | β=10.0 | Zenke & Ganguli 2018 |
| sigmoid_surrogate(x, slope=5.0) | slope · σ(sx)(1 − σ(sx)) | slope=5.0 | Standard |
| straight_through(x) | 1 (identity) | — | Bengio et al. 2013 |
| triangular(x, width=1.0) | max(0, 1 − \|x\|/w) / w | width=1.0 | Esser et al. 2016 |

Choosing a surrogate: atan_surrogate is the safest default — wide gradient window, stable convergence on most tasks. fast_sigmoid trains faster on deep networks (>3 spiking layers). superspike gives the sharpest gradients near threshold — useful for temporal coding but requires lower learning rates. straight_through passes gradients unchanged — works for simple architectures but is theoretically unprincipled.

from sc_neurocore.training import atan_surrogate, fast_sigmoid, superspike
from sc_neurocore.training import sigmoid_surrogate, straight_through, triangular

# All share the same signature
x = torch.tensor([-0.5, 0.0, 0.5], requires_grad=True)
spike = atan_surrogate(x)  # tensor([0., 0., 1.])
spike.sum().backward()      # x.grad is smooth, nonzero everywhere
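The backward formulas in the table can be checked numerically. A plain-Python sketch (formulas copied from the table above, default parameters assumed; function names here are illustrative stand-ins, not the library's implementations):

```python
import math

def atan_grad(x, alpha=2.0):
    # α / (2 · (1 + (π·α·x / 2)²))
    return alpha / (2 * (1 + (math.pi * alpha * x / 2) ** 2))

def superspike_grad(x, beta=10.0):
    # 1 / (1 + β·|x|)²
    return 1.0 / (1 + beta * abs(x)) ** 2

def triangular_grad(x, width=1.0):
    # max(0, 1 − |x|/w) / w
    return max(0.0, 1 - abs(x) / width) / width

# All peak at x = 0 (membrane exactly at threshold): each → 1.0
print(atan_grad(0.0), superspike_grad(0.0), triangular_grad(0.0))
# Gradient decays with distance from threshold
print(atan_grad(2.0) < atan_grad(0.5))
```

Note how quickly each formula decays away from x = 0: this is the "gradient window" the diagram above illustrates.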

Neuron Cells

All cells are torch.nn.Module instances. Forward pass takes input current and hidden state(s), returns (spike, *new_states). Spikes are {0, 1} tensors.

LIFCell

Leaky Integrate-and-Fire. The workhorse spiking neuron.

v[t] = beta * v[t-1] + I[t]
spike[t] = H(v[t] - threshold)
v[t] -= spike[t] * threshold
from sc_neurocore.training import LIFCell

cell = LIFCell(
    beta=0.9,              # membrane leak (higher = longer memory)
    threshold=1.0,         # spike threshold
    surrogate_fn=atan_surrogate,
    learn_beta=False,      # True → beta becomes a trainable parameter
    learn_threshold=False, # True → threshold becomes trainable
)

# Single-step forward
current = torch.randn(batch, n_neurons)
v = torch.zeros(batch, n_neurons)
spike, v_next = cell(current, v)

When learn_beta=True, beta is stored as log(p/(1-p)) (logit) and projected through sigmoid to stay in (0, 1). When learn_threshold=True, threshold is stored as log(threshold) and projected through exp to stay positive.

IFCell

Integrate-and-Fire without leak (beta = 1). Accumulates input until threshold. Simplest spiking model — useful for energy estimation and spike counting tasks.

from sc_neurocore.training import IFCell
cell = IFCell(threshold=1.0)
spike, v_next = cell(current, v)  # v_next = v + current (no decay)

SynapticCell

Dual-exponential synaptic current + membrane. Two state variables provide more realistic temporal filtering of synaptic input.

i_syn[t] = alpha * i_syn[t-1] + I[t]
v[t] = beta * v[t-1] + i_syn[t]
from sc_neurocore.training import SynapticCell
cell = SynapticCell(alpha=0.9, beta=0.8, threshold=1.0)
spike, i_syn_next, v_next = cell(current, i_syn, v)

ALIFCell

Adaptive LIF (Bellec et al., 2020). Threshold increases after each spike, implementing spike-frequency adaptation — the network learns when to suppress firing.

v[t] = beta * v[t-1] + I[t]
a[t] = rho * a[t-1] + spike[t-1]
theta[t] = theta_0 + beta_adapt * a[t]
from sc_neurocore.training import ALIFCell
cell = ALIFCell(beta=0.9, threshold=1.0, rho=0.99, beta_adapt=1.8)
spike, v_next, a_next = cell(current, v, a)

The adaptation variable a tracks recent spiking history. rho controls how quickly adaptation decays (0.99 = slow adaptation, 0.9 = fast). beta_adapt scales the threshold shift.
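A plain-Python simulation of a single ALIF neuron under constant drive, following the update equations above. The reset rule (subtract the current threshold on spike) is an assumption borrowed from LIFCell, not taken from the library source:

```python
beta, theta0, rho, beta_adapt = 0.9, 1.0, 0.99, 1.8

v, a, prev_spike = 0.0, 0.0, 0.0
thetas = []
for t in range(100):
    v = beta * v + 0.5               # constant input current
    a = rho * a + prev_spike         # adaptation tracks recent spikes
    theta = theta0 + beta_adapt * a  # effective threshold
    spike = 1.0 if v > theta else 0.0
    v -= spike * theta               # assumed subtract-threshold reset
    prev_spike = spike
    thetas.append(theta)

# Once the neuron spikes, the threshold sits above its resting value
print(thetas[0], thetas[-1])
```

After the first spike, a stays positive (decaying with rho), so the effective threshold remains elevated: the spike-frequency adaptation described above.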

ExpIFCell

Exponential IF (Fourcaud-Trocmé et al., 2003). An exponential term creates a sharp voltage upstroke near threshold, modelling the sodium channel activation in cortical neurons.

v[t] = beta * v[t-1] + delta_T * exp((v[t-1] - v_rh) / delta_T) + I[t]
from sc_neurocore.training import ExpIFCell
cell = ExpIFCell(beta=0.9, threshold=1.0, delta_t=0.5, v_rh=0.8)
spike, v_next = cell(current, v)

delta_t controls the sharpness of the upstroke. v_rh is the rheobase (voltage where exponential term activates). The exp term is clamped at 5.0 to prevent numerical overflow.

AdExCell

Adaptive Exponential IF (Brette & Gerstner, 2005). Combines the exponential upstroke with an adaptation current w that modulates firing patterns. Can reproduce tonic, adapting, bursting, and irregular spiking.

v[t] = beta * v[t-1] + delta_T * exp((v[t-1] - v_rh) / delta_T) - w[t-1] + I[t]
w[t] = rho * w[t-1] + a * (v[t-1] - v_rest) + b * spike[t]
from sc_neurocore.training import AdExCell
cell = AdExCell(beta=0.9, threshold=1.0, delta_t=0.5, v_rh=0.8,
                a=0.01, b=0.1, rho=0.99, v_rest=0.0)
spike, v_next, w_next = cell(current, v, w)

a couples membrane voltage to adaptation. b controls the spike-triggered adaptation increment. Together they determine the neuron's firing pattern class.

LapicqueCell

Lapicque IF with membrane resistance (Lapicque, 1907). The original integrate-and-fire model with explicit RC circuit parameters.

v[t] = (1 - dt/tau) * (v[t-1] - v_rest) + v_rest + (R * dt / tau) * I[t]
from sc_neurocore.training import LapicqueCell
cell = LapicqueCell(tau=20.0, r=1.0, dt=1.0, threshold=1.0, v_rest=0.0)
spike, v_next = cell(current, v)

tau is the membrane time constant (ms). r is the membrane resistance (MΩ). dt is the simulation timestep.
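With sub-threshold input the Lapicque update relaxes to the RC steady state v* = v_rest + R·I, which the equation above makes easy to verify directly (a plain-Python check, not library code):

```python
tau, r, dt, v_rest = 20.0, 1.0, 1.0, 0.0
I = 0.5   # R·I = 0.5 < threshold of 1.0, so the neuron never spikes

v = 0.0
for _ in range(300):
    v = (1 - dt / tau) * (v - v_rest) + v_rest + (r * dt / tau) * I

print(v)  # ≈ 0.5 = v_rest + R·I
```

Each step moves v a fraction dt/tau of the way toward the steady state, so convergence takes a few multiples of tau/dt timesteps.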

AlphaCell

Alpha synapse neuron (Rall, 1967). Separate excitatory and inhibitory synaptic currents with independent time constants. Models the biological separation of glutamatergic and GABAergic synapses.

i_exc[t] = alpha_exc * i_exc[t-1] + I_exc[t]
i_inh[t] = alpha_inh * i_inh[t-1] + I_inh[t]
v[t] = beta * v[t-1] + i_exc[t] - i_inh[t]
from sc_neurocore.training import AlphaCell
cell = AlphaCell(alpha_exc=0.9, alpha_inh=0.85, beta=0.9)
spike, i_exc_next, i_inh_next, v_next = cell(exc_current, inh_current, i_exc, i_inh, v)

SecondOrderLIFCell

Second-order LIF with inertial acceleration term (Dayan & Abbott, 2001). The acceleration a acts as a low-pass filter that smooths input current before reaching the membrane, producing smoother voltage trajectories.

a[t] = alpha * a[t-1] + I[t]
v[t] = beta * v[t-1] + a[t]
from sc_neurocore.training import SecondOrderLIFCell
cell = SecondOrderLIFCell(alpha=0.95, beta=0.9)
spike, a_next, v_next = cell(current, a, v)

RecurrentLIFCell

LIF with trainable recurrent weights. An orthogonal-initialized nn.Linear feeds previous spikes back as additional input.

from sc_neurocore.training import RecurrentLIFCell
cell = RecurrentLIFCell(n_neurons=128, beta=0.9)
spike, v_next = cell(current, v, spike_prev)

Recurrence adds temporal context without increasing timesteps. Useful for sequence classification (speech, gestures).


Network Architectures

SpikingNet

Multi-layer feedforward SNN: [Linear → LIFCell] × (n_layers + 1). Readout accumulates output spike counts and membrane potential over T timesteps.

flowchart LR
    subgraph Input
        X["x<br/>(T, batch, 784)"]
    end
    subgraph Hidden["Hidden Layers × n_layers"]
        L1[Linear<br/>784→128] --> LIF1[LIFCell<br/>β=0.9]
        LIF1 -->|spikes| L2[Linear<br/>128→128]
        L2 --> LIF2[LIFCell<br/>β=0.9]
    end
    subgraph Output
        L3[Linear<br/>128→10] --> LIF3[LIFCell<br/>β=0.9]
        LIF3 -->|accumulate T steps| SC["spike_counts<br/>(batch, 10)"]
    end
    X --> L1
    LIF2 -->|spikes| L3

    style Input fill:#e1f5fe
    style Hidden fill:#fff3e0
    style Output fill:#e8f5e9
from sc_neurocore.training import SpikingNet

net = SpikingNet(
    n_input=784,     # flattened 28×28 MNIST
    n_hidden=128,    # hidden layer width
    n_output=10,     # classes
    n_layers=2,      # number of hidden layers
    beta=0.9,
    surrogate_fn=atan_surrogate,
    learn_beta=False,
    learn_threshold=False,
)

# Forward: x is (T, batch, n_input) → (spike_counts, membrane_acc)
x = torch.randn(25, 64, 784)  # T=25, batch=64
spike_counts, mem_acc = net(x)
predicted = spike_counts.argmax(dim=1)  # (64,)
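The time loop inside the forward pass can be sketched in plain Python: a single hidden layer with random weights, illustrating how the readout accumulates spike counts over T steps. This is a conceptual sketch, not the library's implementation (which is batched PyTorch):

```python
import random

random.seed(0)
n_in, n_hid, n_out, T = 4, 8, 3, 25
W1 = [[random.gauss(0, 0.5) for _ in range(n_in)] for _ in range(n_hid)]
W2 = [[random.gauss(0, 0.5) for _ in range(n_hid)] for _ in range(n_out)]

def linear(W, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]

def lif(I, v, beta=0.9, thr=1.0):
    v = [beta * vj + Ij for vj, Ij in zip(v, I)]
    s = [1.0 if vj > thr else 0.0 for vj in v]
    return s, [vj - sj * thr for vj, sj in zip(v, s)]

v1, v2 = [0.0] * n_hid, [0.0] * n_out
spike_counts, mem_acc = [0.0] * n_out, [0.0] * n_out
for t in range(T):
    x = [random.random() for _ in range(n_in)]  # one timestep of input
    s1, v1 = lif(linear(W1, x), v1)             # Linear → LIFCell
    s2, v2 = lif(linear(W2, s1), v2)            # output layer
    spike_counts = [c + s for c, s in zip(spike_counts, s2)]
    mem_acc = [m + vj for m, vj in zip(mem_acc, v2)]

predicted = max(range(n_out), key=lambda i: spike_counts[i])
```

The class whose output neuron spiked most over the T steps wins, exactly as in the argmax readout above.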

to_sc_weights(include_bias=True)

Export trained weights to [0, 1] range for stochastic computing bitstream deployment. Each layer's weight matrix is min-max normalized independently.

sc_layers = net.to_sc_weights()
for i, layer in enumerate(sc_layers):
    w = layer["weight"]  # Tensor, values in [0, 1]
    print(f"Layer {i}: {tuple(w.shape)}, range [{w.min():.3f}, {w.max():.3f}]")
    if "bias" in layer:
        print(f"  bias: {tuple(layer['bias'].shape)}")

These weights map directly to bitstream probabilities in SCDenseLayer and the equation compiler's Verilog RTL.
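Per-layer min-max normalization is what the export does conceptually. A plain-Python sketch (an illustration, not the library source):

```python
def minmax_normalize(w):
    """Map a weight matrix into [0, 1] for use as bitstream probabilities."""
    flat = [x for row in w for x in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo if hi > lo else 1.0  # guard against constant matrices
    return [[(x - lo) / span for x in row] for row in w]

w = [[-0.3, 0.1], [0.7, 0.2]]
w01 = minmax_normalize(w)  # smallest weight → 0.0, largest → 1.0
```

Because each layer is normalized independently, relative weight ordering within a layer is preserved while the absolute scale is absorbed into the SC representation.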

ConvSpikingNet

Convolutional SNN for image classification:

flowchart LR
    I["Input<br/>28×28×1"] --> C1["Conv2d<br/>1→32, 5×5"]
    C1 --> S1["LIF<br/>24×24×32"]
    S1 --> P1["AvgPool<br/>12×12×32"]
    P1 --> C2["Conv2d<br/>32→64, 5×5"]
    C2 --> S2["LIF<br/>8×8×64"]
    S2 --> P2["AvgPool<br/>4×4×64"]
    P2 --> FL["Flatten<br/>1024"]
    FL --> F1["Linear<br/>1024→128"]
    F1 --> S3["LIF"]
    S3 --> F2["Linear<br/>128→10"]
    F2 --> S4["LIF"]
    S4 --> O["spike_counts<br/>(batch, 10)"]

    style I fill:#e1f5fe
    style O fill:#e8f5e9
from sc_neurocore.training import ConvSpikingNet

net = ConvSpikingNet(
    n_output=10,
    beta=0.9,
    learn_beta=True,
    learn_threshold=True,
)

# Forward: x is (T, batch, 1, 28, 28)
x = torch.randn(25, 32, 1, 28, 28)
spike_counts, mem_acc = net(x)

Designed for 28×28 grayscale images (MNIST, Fashion-MNIST). For other input sizes, modify self.fc1 input dimension.


Spike Encoding

Convert continuous values to binary spike trains for SNN input.

rate_encode

Poisson rate coding. Each timestep, spike with probability proportional to input value. Works for any static data.

from sc_neurocore.training import rate_encode

x = torch.rand(64, 784)         # batch of images in [0, 1]
spikes = rate_encode(x, n_timesteps=25)  # (25, 64, 784)
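Conceptually, rate coding draws an independent Bernoulli spike per timestep with probability equal to the input value. A plain-Python sketch (illustrative only, not the library's implementation):

```python
import random

def rate_encode_sketch(x, n_timesteps, seed=0):
    """Bernoulli spikes: P(spike at step t) = input value in [0, 1]."""
    rng = random.Random(seed)
    return [[1.0 if rng.random() < p else 0.0 for p in x]
            for _ in range(n_timesteps)]

x = [0.1, 0.5, 0.9]
spikes = rate_encode_sketch(x, n_timesteps=1000)
# Empirical firing rates recover the input values
rates = [sum(col) / len(spikes) for col in zip(*spikes)]
```

More timesteps means a lower-variance rate estimate, which is why n_timesteps trades accuracy against compute in the same way SC bitstream length does.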

latency_encode

Time-to-first-spike. Stronger inputs spike earlier. Each neuron fires exactly once. Information is in spike timing, not rate.

from sc_neurocore.training import latency_encode

spikes = latency_encode(x, n_timesteps=25, tau=5.0)
# Strong input (0.9) → spike at t≈0
# Weak input (0.1)  → spike at t≈4
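A minimal sketch of the idea, assuming the simple linear map t ≈ (1 − x)·tau, which reproduces the timings in the comments above; the library may use a different (e.g. logarithmic) mapping:

```python
def latency_encode_sketch(x, n_timesteps, tau=5.0):
    """One spike per neuron; stronger inputs fire earlier (assumed linear map)."""
    spikes = [[0.0] * len(x) for _ in range(n_timesteps)]
    for i, xi in enumerate(x):
        t = min(n_timesteps - 1, int(round((1.0 - xi) * tau)))
        spikes[t][i] = 1.0  # exactly one spike per neuron
    return spikes

spikes = latency_encode_sketch([0.9, 0.1], n_timesteps=25)
```

Because every neuron fires exactly once, the total spike count carries no information; only the timing does.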

delta_encode

Spike on temporal change. For event-based or streaming data where change matters more than absolute value.

from sc_neurocore.training import delta_encode

# x: (T, *batch) — temporal sequence
spikes = delta_encode(x, threshold=0.1)  # spike where |dx| > 0.1
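The thresholded-difference rule can be sketched for a 1-D sequence (an illustration; the first step is compared against zero, which is an assumption of this sketch):

```python
def delta_encode_sketch(x, threshold=0.1):
    """Spike wherever |x[t] - x[t-1]| exceeds threshold."""
    spikes, prev = [], 0.0
    for xt in x:
        spikes.append(1.0 if abs(xt - prev) > threshold else 0.0)
        prev = xt
    return spikes

sig = [0.0, 0.05, 0.5, 0.52, 0.1]
print(delta_encode_sketch(sig))  # [0.0, 0.0, 1.0, 0.0, 1.0]
```

Slowly drifting signals produce almost no spikes, which is what makes delta coding efficient for streaming data.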

Loss Functions

Classification losses

from sc_neurocore.training import spike_count_loss, membrane_loss, spike_rate_loss

# Cross-entropy on spike counts (default, recommended)
loss = spike_count_loss(spike_counts, targets)

# Cross-entropy on accumulated membrane potential
loss = membrane_loss(mem_acc, targets)

# MSE on spike rates vs target pattern
loss = spike_rate_loss(spike_counts, targets, n_timesteps=25, target_rate=0.8)

spike_count_loss is recommended for most tasks. membrane_loss can work better when spike counts are very sparse (few spikes per class). spike_rate_loss explicitly shapes firing rates — useful when you need precise rate control for SC bitstream deployment.
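The idea behind spike_count_loss is cross-entropy with spike counts treated as logits. A plain-Python sketch for one sample (illustrative; the library version operates on batched tensors):

```python
import math

def spike_count_ce(counts, target):
    """Cross-entropy on spike counts used as logits (single sample)."""
    m = max(counts)  # subtract max for numerical stability
    logsumexp = m + math.log(sum(math.exp(c - m) for c in counts))
    return logsumexp - counts[target]

# More spikes on the correct class → lower loss
good = spike_count_ce([2.0, 14.0, 3.0], target=1)
bad = spike_count_ce([9.0, 5.0, 8.0], target=1)
```

Gradients flow back through the counts into every timestep's spikes, which is where the surrogate gradient takes over.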

Regularizers

from sc_neurocore.training import spike_l1_loss, spike_l2_loss

# L1 on mean firing rate — encourages sparse firing
reg = spike_l1_loss(spike_counts, n_timesteps=25)

# L2 on mean firing rate — penalizes high-firing outliers
reg = spike_l2_loss(spike_counts, n_timesteps=25)

# Combined
total_loss = spike_count_loss(spike_counts, targets) + 0.01 * spike_l1_loss(spike_counts, 25)

Training and Evaluation

auto_device

Select the best available device: CUDA → MPS → CPU.

from sc_neurocore.training import auto_device
device = auto_device()  # torch.device('cuda'), ('mps'), or ('cpu')

train_epoch

One full pass through the training set. Handles timestep expansion, loss computation, gradient clipping.

from sc_neurocore.training import train_epoch

avg_loss, accuracy = train_epoch(
    model,
    train_loader,
    optimizer,
    n_timesteps=25,
    loss_fn=spike_count_loss,  # any (spike_counts, targets) → scalar
    device="cuda",
    max_grad_norm=1.0,         # None to disable clipping
    flatten_input=True,        # False for ConvSpikingNet
)

flatten_input=True reshapes (batch, C, H, W) to (batch, C*H*W) for feedforward SNNs. Set to False for convolutional models.

evaluate

Same as train_epoch but with torch.no_grad() and model.eval().

from sc_neurocore.training import evaluate

val_loss, val_acc = evaluate(model, test_loader, n_timesteps=25, device="cuda")

Utilities

SpikeMonitor

Record spike activity per layer during forward pass. Attach to any SpikingNet or ConvSpikingNet.

from sc_neurocore.training import SpikeMonitor

monitor = SpikeMonitor(model)
spike_counts, mem = model(x)

# Retrieve spikes for a specific layer
for name in monitor.layer_names:
    spikes = monitor.get(name)  # (T, batch, n_neurons) or None
    print(f"{name}: {spikes.shape if spikes is not None else 'no spikes'}")

monitor.reset()   # clear recorded data (keep hooks)
monitor.remove()  # remove hooks entirely

model_info

Architecture summary for SNN models.

from sc_neurocore.training import model_info

info = model_info(model)
# {
#   "total_params": 134922,
#   "trainable_params": 134922,
#   "spiking_cells": 3,
#   "cell_types": ["LIFCell"],
#   "learnable_dynamics": ["lifs.0._beta_logit", ...]
# }

population_decode

Weighted average readout instead of argmax. Computes the centroid of spike-count distribution over preferred values.

from sc_neurocore.training import population_decode

# Decode angle from population of 36 neurons (10° spacing)
preferred = torch.arange(0, 360, 10, dtype=torch.float32)
angle = population_decode(spike_counts, preferred)  # (batch,)
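For a single sample the decode is just a spike-count-weighted average of the preferred values. A plain-Python sketch (note: this simple mean ignores circular wrap-around near 0°/360°, which a real angular readout would need to handle, e.g. via a vector average):

```python
def population_decode_sketch(counts, preferred):
    """Spike-count-weighted average of preferred values (one sample)."""
    total = sum(counts)
    return sum(c * p for c, p in zip(counts, preferred)) / total

preferred = [80.0, 90.0, 100.0]
counts = [1.0, 6.0, 1.0]  # activity centered on the 90° neuron
angle = population_decode_sketch(counts, preferred)  # → 90.0
```

Compared with argmax, the weighted readout interpolates between neurons, giving sub-bin resolution from a coarse population.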

reset_states

Clear all SpikeMonitor instances in a list.

from sc_neurocore.training import reset_states
reset_states([monitor1, monitor2])

DelayLinear

Trainable per-synapse delays for temporal coding. Each synapse has a weight AND a continuous-valued delay, both optimized by backpropagation.

from sc_neurocore.training import DelayLinear

layer = DelayLinear(
    in_features=64,
    out_features=32,
    max_delay=16,      # delays in [0, 16) timesteps
    learn_delay=True,  # make delays trainable
    init_delay=1.0,    # initial delay for all synapses
    bias=False,
)

Usage: call step(x) at each timestep. The module maintains an internal spike history buffer.

layer.reset()  # clear history between sequences
for t in range(T):
    x = input_spikes[t]          # (batch, in_features)
    current = layer.step(x)      # (batch, out_features)

Hardware export:

int_delays = layer.delays_int          # quantized (out, in) integers
nir_array = layer.to_nir_delay_array() # flat float64 for Projection(delay=...)

See Tutorial 39: Learnable Delays.

sc_neurocore.training.delay_linear

Learnable-delay linear layer for surrogate gradient SNN training.

Implements the DCLS (Dilated Convolutions with Learnable Spacings) principle applied to fully-connected SNN layers. Each synapse has a trainable weight AND a trainable delay. During forward pass, the input spike history is queried at fractional delay positions via linear interpolation, making delays differentiable.

References

Hammouamri et al. 2023 — "Learning Delays in SNNs using DCLS"
Göltz et al. 2025 — "DelGrad: exact gradients for delays on BrainScaleS-2"

DelayLinear

Bases: Module

Linear layer with trainable per-synapse delays.

Parameters

in_features : int
    Number of input neurons.
out_features : int
    Number of output neurons.
max_delay : int
    Maximum delay in timesteps. Delay buffer stores this many steps.
bias : bool
    Include bias term (default False for SNN).
learn_delay : bool
    Make delays trainable (default True).
init_delay : float
    Initial delay value for all synapses (default 1.0).

Forward pass

Call step(input_spikes) at each timestep. The module maintains an internal spike history buffer. For each synapse (i, j):

delayed_input[j] = sum_i W[j,i] * interp(history, t - D[j,i])

where interp linearly interpolates between integer delay bins, making D differentiable.
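The interpolated read can be sketched for a single input neuron against a circular history buffer (a simplified, plain-Python version of the gather inside step(), with an illustrative function name):

```python
def read_delayed(history, t, d, buf_len):
    """Linearly interpolate a circular spike-history buffer at fractional delay d."""
    d_floor = int(d)
    frac = d - d_floor
    lo = history[(t - d_floor) % buf_len]      # value at delay floor(d)
    hi = history[(t - d_floor - 1) % buf_len]  # value at delay floor(d)+1
    return (1 - frac) * lo + frac * hi

buf_len = 8
history = [0.0] * buf_len
history[3 % buf_len] = 1.0  # a spike written at t=3
val = read_delayed(history, t=5, d=1.5, buf_len=buf_len)  # → 0.5
```

Because the output is a smooth function of frac (and hence of d), the gradient ∂val/∂d is well defined, which is exactly what makes the delays trainable.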

Export

delays_int returns quantized integer delays for hardware deployment. to_nir_delay_array() returns delays in the CSR format expected by Projection(delay=array).

Source code in src/sc_neurocore/training/delay_linear.py
class DelayLinear(nn.Module):
    """Linear layer with trainable per-synapse delays.

    Parameters
    ----------
    in_features : int
        Number of input neurons.
    out_features : int
        Number of output neurons.
    max_delay : int
        Maximum delay in timesteps. Delay buffer stores this many steps.
    bias : bool
        Include bias term (default False for SNN).
    learn_delay : bool
        Make delays trainable (default True).
    init_delay : float
        Initial delay value for all synapses (default 1.0).

    Forward pass
    -------------
    Call ``step(input_spikes)`` at each timestep. The module maintains
    an internal spike history buffer. For each synapse (i, j):

        delayed_input[j] = sum_i W[j,i] * interp(history, t - D[j,i])

    where interp linearly interpolates between integer delay bins,
    making D differentiable.

    Export
    ------
    ``delays_int`` returns quantized integer delays for hardware deployment.
    ``to_nir_delay_array()`` returns delays in the CSR format expected by
    ``Projection(delay=array)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        max_delay: int = 16,
        bias: bool = False,
        learn_delay: bool = True,
        init_delay: float = 1.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.max_delay = max_delay

        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        # Delays stored as continuous values in [0, max_delay)
        delay_init = torch.full((out_features, in_features), init_delay)
        if learn_delay:
            self.delay = nn.Parameter(delay_init)
        else:
            self.register_buffer("delay", delay_init)

        # Spike history buffer: (max_delay + 1, in_features)
        self.register_buffer("_history", torch.zeros(max_delay + 1, in_features))
        self._t = 0

    def reset(self):
        """Clear spike history. Call between sequences."""
        self._history.zero_()
        self._t = 0

    def step(self, x: torch.Tensor) -> torch.Tensor:
        """Process one timestep.

        Parameters
        ----------
        x : Tensor of shape (batch, in_features) or (in_features,)
            Input spikes (binary or continuous).

        Returns
        -------
        Tensor of shape (batch, out_features) or (out_features,)
            Weighted, delayed input current.
        """
        squeeze = x.dim() == 1
        if squeeze:
            x = x.unsqueeze(0)

        batch_size = x.shape[0]
        buf_len = self.max_delay + 1

        # Store current input in history (use first batch element for buffer)
        write_idx = self._t % buf_len
        self._history[write_idx] = x[0].detach()

        # Clamp delays to valid range
        d = self.delay.clamp(0, self.max_delay - 1e-6)

        # Integer floor and ceil indices
        d_floor = d.long()
        d_ceil = (d_floor + 1).clamp(max=self.max_delay)
        frac = d - d_floor.float()

        # Read from history at delayed positions
        # idx_floor[j, i] = (current_t - d_floor[j, i]) % buf_len
        idx_floor = (self._t - d_floor) % buf_len
        idx_ceil = (self._t - d_ceil) % buf_len

        # Gather delayed spikes via linear interpolation
        # history shape: (buf_len, in_features)
        # We need history[idx[j,i], i] for each (j, i)
        hist_floor = self._history[idx_floor, torch.arange(self.in_features).unsqueeze(0)]
        hist_ceil = self._history[idx_ceil, torch.arange(self.in_features).unsqueeze(0)]
        delayed_x = (1 - frac) * hist_floor + frac * hist_ceil

        # Weighted sum: out[j] = sum_i W[j,i] * delayed_x[j,i]
        output = (self.weight * delayed_x).sum(dim=1)
        if self.bias is not None:
            output = output + self.bias

        self._t += 1

        # Broadcast to batch
        output = output.unsqueeze(0).expand(batch_size, -1)
        if squeeze:
            output = output.squeeze(0)
        return output

    @property
    def delays_int(self) -> torch.Tensor:
        """Quantized integer delays for hardware export."""
        with torch.no_grad():
            return self.delay.clamp(0, self.max_delay).round().long()

    def to_nir_delay_array(self):
        """Export delays as flat array matching CSR data order.

        Returns delays in the same order as weights flattened row-major,
        suitable for ``Projection(delay=array)``.
        """
        import numpy as np

        return self.delays_int.detach().cpu().numpy().flatten().astype(np.float64)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"max_delay={self.max_delay}, learn_delay={isinstance(self.delay, nn.Parameter)}"
        )

delays_int property

Quantized integer delays for hardware export.

reset()

Clear spike history. Call between sequences.


step(x)

Process one timestep.

Parameters

x : Tensor of shape (batch, in_features) or (in_features,)
    Input spikes (binary or continuous).

Returns

Tensor of shape (batch, out_features) or (out_features,)
    Weighted, delayed input current.


to_nir_delay_array()

Export delays as flat array matching CSR data order.

Returns delays in the same order as weights flattened row-major, suitable for Projection(delay=array).
