
Tutorial 82: Spiking Neural ODEs

Continuous-depth SNN layer: an adaptive ODE solver with event-driven spike detection. It takes large steps while the membrane is far from threshold and repeatedly halves the step on a threshold crossing for sub-timestep spike timing. To our knowledge, no other library provides this as a reusable layer.

How It Works

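  1. Take large Euler steps (up to dt_init) while every membrane is far from threshold
  2. Shrink the step once any membrane comes within 10% of the threshold
  3. On a threshold crossing, halve the step (up to 5 times) to localize the spike time within it
  4. Reset the neurons that spiked and continue integrating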
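The four steps above can be sketched for a single neuron in plain Python. This is a simplified illustration, not the library's code: `simulate_interval` is a hypothetical helper, the input current is assumed constant over the interval, C_mem is taken as 1, and the step-size rules (halve within 10% of threshold, otherwise grow by 1.5x up to dt_init, at most five halvings on a crossing) follow SpikingODELayer.step. After a crossing, the step is grown back to at least dt_min so that the `dt < dt_min` end-of-interval check does not stop integration early.

```python
def simulate_interval(v, I, interval=1.0, tau_mem=20.0, v_rest=0.0,
                      v_threshold=1.0, v_reset=0.0, dt_init=0.1,
                      dt_min=0.001, max_steps=100):
    """Integrate one LIF neuron over `interval` ms with constant input I.

    Returns (final voltage, spike count, number of Euler steps taken).
    """
    t, dt, spikes, steps = 0.0, dt_init, 0, 0
    while t < interval and steps < max_steps:
        dt = min(dt, interval - t)
        if dt < dt_min:                      # only floating-point residue left
            break
        dv = -(v - v_rest) / tau_mem + I     # dv/dt with C_mem = 1
        v_new = v + dt * dv                  # step 1: explicit Euler
        if v_new >= v_threshold:
            # step 3: halve the step (at most 5 times) while it still crosses
            for _ in range(5):
                v_mid = v + 0.5 * dt * dv
                if v_mid < v_threshold:
                    break
                dt, v_new = 0.5 * dt, v_mid
            spikes += 1
            v_new = v_reset                  # step 4: reset and continue
        v = v_new
        t += dt
        steps += 1
        # step 2: shrink near threshold, otherwise grow back towards dt_init
        if v_threshold - v < 0.1 * v_threshold:
            dt = max(0.5 * dt, dt_min)
        else:
            dt = min(max(1.5 * dt, dt_min), dt_init)
    return v, spikes, steps


print(simulate_interval(0.0, I=0.0))   # membrane stays at rest
print(simulate_interval(0.0, I=2.0))   # strong drive: at least one spike
```

With zero input the membrane stays at rest and the sketch takes ten 0.1 ms steps per 1 ms interval, where a fixed-step integrator at dt_min = 0.001 would take 1000.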

SpikingODELayer

import numpy as np
from sc_neurocore.spike_ode import SpikingODELayer, ODELIFDynamics

dynamics = ODELIFDynamics(
    tau_mem=20.0, v_rest=0.0, v_threshold=1.0, v_reset=0.0,
)

layer = SpikingODELayer(
    n_inputs=32, n_neurons=16,
    dynamics=dynamics, dt_init=0.1, dt_min=0.001,
)

inputs = np.random.randn(100, 32) * 0.5
spike_counts = layer.forward(inputs, interval=1.0)
# shape: (100, 16)
print(f"Total spikes: {int(spike_counts.sum())}")

Single-Step (Online)

layer.reset()
x = np.random.randn(32) * 0.5
counts = layer.step(x, interval=1.0)
print(f"Voltage: {layer.voltage}")

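Adaptive stepping gives roughly a 5x speedup over fixed small-step integration. The gain depends on how much of each interval the membranes spend near threshold, where the solver falls back to small steps.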

API Reference

sc_neurocore.spike_ode.ode_layer

Continuous-depth SNN layer combining ODE solver with spike events.

Solves the LIF membrane ODE continuously, detects threshold crossings as events, emits spikes, resets, continues. Adaptive step-size Euler with event detection.

It sits at the intersection of Neural ODEs and SNNs; to our knowledge, no other library provides this as a reusable layer.

Reference: EventProp (Wunderlich & Pehle 2021)

ODELIFDynamics dataclass

LIF membrane ODE dynamics.

dv/dt = -(v - v_rest) / tau_mem + I(t) / C_mem

Parameters

tau_mem : float
    Membrane time constant (ms).
v_rest : float
    Resting membrane potential.
v_threshold : float
    Spike threshold.
v_reset : float
    Potential the membrane is reset to after a spike.
C_mem : float
    Membrane capacitance (normalized).
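For a constant input current I the ODE above has a closed-form solution, v(t) = v_inf + (v0 - v_inf) exp(-t / tau_mem) with v_inf = v_rest + tau_mem * I / C_mem. When v_inf > v_threshold, the membrane first reaches threshold at t* = tau_mem * ln((v_inf - v0) / (v_inf - v_threshold)); since the layer has no refractory period, restarting from v_reset after each spike gives a fixed interspike interval. The sketch below (hypothetical helpers `time_to_threshold` and `firing_rate_hz`, not part of the library) turns this into reference values for checking the Euler solver.

```python
import math


def time_to_threshold(I, v0=0.0, tau_mem=20.0, v_rest=0.0,
                      v_threshold=1.0, C_mem=1.0):
    """Exact time (ms) for dv/dt = -(v - v_rest)/tau_mem + I/C_mem, started at v0,
    to reach v_threshold under a constant input I; math.inf if it never does."""
    if v0 >= v_threshold:
        return 0.0
    v_inf = v_rest + tau_mem * I / C_mem      # steady-state voltage
    if v_inf <= v_threshold:
        return math.inf                        # drive too weak to ever spike
    return tau_mem * math.log((v_inf - v0) / (v_inf - v_threshold))


def firing_rate_hz(I, v_reset=0.0, **params):
    """Steady-state rate (Hz): the neuron restarts from v_reset after every spike
    and there is no refractory period, so each interspike interval is identical."""
    isi_ms = time_to_threshold(I, v0=v_reset, **params)
    if isi_ms == 0.0:
        return math.inf                        # v_reset >= v_threshold
    return 0.0 if math.isinf(isi_ms) else 1000.0 / isi_ms


print(time_to_threshold(0.1))   # 20 * ln 2 ms with the default parameters
print(firing_rate_hz(0.1))
```

With the defaults, any I at or below 0.05 (v_inf at or below 1) never produces a spike, and I = 0.1 gives an interspike interval of 20 ln 2, about 13.9 ms or roughly 72 Hz. Comparing these values with a simulation under the same constant current gives a quick check of the Euler discretization error.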

Source code in src/sc_neurocore/spike_ode/ode_layer.py
@dataclass
class ODELIFDynamics:
    """LIF membrane ODE dynamics.

    dv/dt = -(v - v_rest) / tau_mem + I(t) / C_mem

    Parameters
    ----------
    tau_mem : float
        Membrane time constant (ms).
    v_rest : float
    v_threshold : float
    v_reset : float
    C_mem : float
        Membrane capacitance (normalized).
    """

    tau_mem: float = 20.0
    v_rest: float = 0.0
    v_threshold: float = 1.0
    v_reset: float = 0.0
    C_mem: float = 1.0

    def dvdt(self, v: np.ndarray, I: np.ndarray) -> np.ndarray:
        """Compute membrane voltage derivative."""
        return -(v - self.v_rest) / self.tau_mem + I / self.C_mem

dvdt(v, I)

Compute membrane voltage derivative.


SpikingODELayer

Spiking Neural ODE layer with event-driven integration.

Integrates the membrane ODE with adaptive Euler stepping. Detects threshold crossings via bisection, emits spikes, resets.
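SpikingODELayer.step localizes a crossing by halving the step from one side only and stops as soon as the half step no longer crosses, so the earliest crossing is resolved only to within the accepted step length. A conventional two-sided bisection, sketched below as an alternative technique (not what the layer does), keeps a bracket [lo, hi] around the crossing and halves it on every iteration, so after n iterations the error is at most dt / 2**n. `bisect_crossing` is a hypothetical helper; it assumes the voltage within the step is given by a monotone increasing callable `v_of_t` that is below threshold at 0 and at or above it at dt.

```python
from typing import Callable


def bisect_crossing(v_of_t: Callable[[float], float], threshold: float,
                    dt: float, n_iter: int = 30) -> float:
    """Time in (0, dt] at which v_of_t reaches threshold, to within dt / 2**n_iter.

    Assumes v_of_t is monotone increasing on [0, dt],
    v_of_t(0) < threshold and v_of_t(dt) >= threshold.
    """
    lo, hi = 0.0, dt                   # invariant: v(lo) < threshold <= v(hi)
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        if v_of_t(mid) >= threshold:
            hi = mid                   # crossing lies in (lo, mid]
        else:
            lo = mid                   # crossing lies in (mid, hi]
    return hi


# Euler interpolant v(t) = v + t * dv with v = 0.6, dv = 5.0: crossing at t = 0.08
print(bisect_crossing(lambda t: 0.6 + 5.0 * t, threshold=1.0, dt=0.1))
```

For the linear Euler interpolant the crossing is also available in closed form as (v_threshold - v) / dv; bisection earns its keep when the within-step trajectory is nonlinear, for example the exact exponential LIF solution.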

Parameters

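n_inputs : int
    Number of input features (columns of W).
n_neurons : int
    Number of spiking neurons (rows of W).
dynamics : ODELIFDynamics
    Membrane dynamics; defaults to ODELIFDynamics().
dt_init : float
    Initial integration step size, which is also the largest step the solver takes.
dt_min : float
    Minimum step size.
max_steps_per_interval : int
    Max ODE steps per simulation interval. If the cap is reached, the rest of the interval is not integrated.
seed : int
    Seed for the random initialization of W.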
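The layer initializes W with He-style scaling, W ~ N(0, 2 / n_inputs), and holds the current I = W @ x constant over each interval. For inputs with per-feature variance s^2, each current then has a variance of about 2 s^2, so the tutorial's inputs (s = 0.5) give currents with a standard deviation near 0.71. The sketch below reproduces that initialization rule outside the class (`he_weights` is a hypothetical helper) to check the scale and to estimate how many currents exceed 0.05, the level above which the steady-state voltage v_rest + tau_mem * I / C_mem passes v_threshold with the default dynamics.

```python
import numpy as np


def he_weights(n_inputs: int, n_neurons: int, seed: int = 42) -> np.ndarray:
    """Same rule as SpikingODELayer.__init__: W ~ N(0, 2 / n_inputs)."""
    rng = np.random.RandomState(seed)
    return rng.randn(n_neurons, n_inputs) * np.sqrt(2.0 / n_inputs)


W = he_weights(n_inputs=32, n_neurons=16)
x = np.random.RandomState(0).randn(1000, 32) * 0.5   # tutorial-style inputs
I = x @ W.T                                          # currents, shape (1000, 16)

print(f"std of I: {I.std():.3f} (expected about {np.sqrt(2 * 0.5**2):.3f})")
# With the default dynamics, v_inf = tau_mem * I exceeds v_threshold = 1 when I > 0.05
print(f"fraction of currents above 0.05: {(I > 0.05).mean():.2f}")
```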

Source code in src/sc_neurocore/spike_ode/ode_layer.py
class SpikingODELayer:
    """Spiking Neural ODE layer with event-driven integration.

    Integrates the membrane ODE with adaptive Euler stepping.
    Detects threshold crossings via bisection, emits spikes, resets.

    Parameters
    ----------
    n_inputs : int
    n_neurons : int
    dynamics : ODELIFDynamics
    dt_init : float
        Initial integration step size.
    dt_min : float
        Minimum step size.
    max_steps_per_interval : int
        Max ODE steps per simulation interval.
    seed : int
    """

    def __init__(
        self,
        n_inputs: int,
        n_neurons: int,
        dynamics: ODELIFDynamics | None = None,
        dt_init: float = 0.1,
        dt_min: float = 0.001,
        max_steps_per_interval: int = 100,
        seed: int = 42,
    ):
        self.n_inputs = n_inputs
        self.n_neurons = n_neurons
        self.dynamics = dynamics or ODELIFDynamics()
        self.dt_init = dt_init
        self.dt_min = dt_min
        self.max_steps = max_steps_per_interval

        rng = np.random.RandomState(seed)
        self.W = rng.randn(n_neurons, n_inputs) * np.sqrt(2.0 / n_inputs)
        self._v = np.full(n_neurons, self.dynamics.v_rest)

    def step(self, x: np.ndarray, interval: float = 1.0) -> np.ndarray:
        """Integrate ODE over one interval, return spike counts.

        Parameters
        ----------
        x : ndarray of shape (n_inputs,)
            Input (constant over interval).
        interval : float
            Duration of this interval (ms).

        Returns
        -------
        ndarray of shape (n_neurons,)
            Spike count per neuron during interval.
        """
        I = self.W @ x
        spike_counts = np.zeros(self.n_neurons)
        t = 0.0
        dt = self.dt_init
        steps = 0

        while t < interval and steps < self.max_steps:
            dt = min(dt, interval - t)
            if dt < self.dt_min:
                break

            # Euler step
            dv = self.dynamics.dvdt(self._v, I)
            v_new = self._v + dt * dv

            # Event detection: threshold crossing
            crossed = v_new >= self.dynamics.v_threshold
            if crossed.any():
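                # Halve the step (up to 5 times) while some neuron still
                # crosses, to localize the earliest crossing within the step
                for _ in range(5):
                    dt_half = dt / 2
                    v_mid = self._v + dt_half * dv
                    still_crossed = v_mid >= self.dynamics.v_threshold
                    if still_crossed.any():
                        dt = dt_half
                        v_new = v_mid
                        # Only neurons that cross within the shortened step spike
                        crossed = still_crossed
                    else:
                        break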

                spike_counts[crossed] += 1
                v_new[crossed] = self.dynamics.v_reset

            self._v = v_new
            t += dt
            steps += 1

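            # Adaptive step: shrink if any membrane is within 10% of threshold,
            # otherwise grow back towards dt_init (never below dt_min, so a step
            # shortened by the crossing search does not trigger the early break)
            distance_to_thresh = self.dynamics.v_threshold - self._v
            min_dist = distance_to_thresh.min()
            if min_dist < 0.1 * self.dynamics.v_threshold:
                dt = max(dt * 0.5, self.dt_min)
            else:
                dt = min(max(dt * 1.5, self.dt_min), self.dt_init)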

        return spike_counts

    def forward(self, inputs: np.ndarray, interval: float = 1.0) -> np.ndarray:
        """Process a sequence of inputs.

        Parameters
        ----------
        inputs : ndarray of shape (T, n_inputs)
        interval : float
            Duration per input step.

        Returns
        -------
        ndarray of shape (T, n_neurons), spike counts per interval
        """
        self.reset()
        T = inputs.shape[0]
        outputs = np.zeros((T, self.n_neurons))
        for t in range(T):
            outputs[t] = self.step(inputs[t], interval)
        return outputs

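    def reset(self):
        """Reset every membrane voltage to v_rest."""
        self._v = np.full(self.n_neurons, self.dynamics.v_rest)

    @property
    def voltage(self) -> np.ndarray:
        """Copy of the current membrane voltages, shape (n_neurons,)."""
        return self._v.copy()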

step(x, interval=1.0)

Integrate ODE over one interval, return spike counts.

Parameters

x : ndarray of shape (n_inputs,)
    Input (constant over interval).
interval : float
    Duration of this interval (ms).

Returns

ndarray of shape (n_neurons,)
    Spike count per neuron during interval.


forward(inputs, interval=1.0)

Process a sequence of inputs.

Parameters

inputs : ndarray of shape (T, n_inputs)
interval : float
    Duration per input step.

Returns

ndarray of shape (T, n_neurons)
    Spike counts per interval.
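forward returns spike counts per interval, so mean firing rates follow from dividing by the total simulated time. A small sketch, assuming intervals are in milliseconds as documented for step; `rates_hz` is a hypothetical helper, not part of the library.

```python
import numpy as np


def rates_hz(spike_counts: np.ndarray, interval_ms: float = 1.0) -> np.ndarray:
    """Mean firing rate (Hz) per neuron from a (T, n_neurons) spike-count array."""
    T = spike_counts.shape[0]
    total_time_s = T * interval_ms / 1000.0
    return spike_counts.sum(axis=0) / total_time_s


counts = np.array([[1, 0], [0, 0], [1, 1], [0, 0]], dtype=float)  # T = 4, 2 neurons
print(rates_hz(counts, interval_ms=1.0))   # 2 and 1 spikes in 4 ms
```

Applied to the tutorial, this would be rates_hz(layer.forward(inputs, interval=1.0), interval_ms=1.0). Because forward calls reset() before processing, each call starts from v_rest; to carry membrane state across consecutive chunks of a longer stream, call step in a loop instead.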
