Tutorial 82: Spiking Neural ODEs
Continuous-depth SNN layer: an adaptive ODE solver with event-driven spike
detection. It takes large steps while the membrane potential is far from
threshold, and repeatedly halves the step when a threshold crossing is detected
to obtain sub-timestep spike timing. To our knowledge, no other library offers
this as a reusable layer.
How It Works
- Take large steps while the membrane potential is far from threshold
- Shrink the step size near threshold crossings
- Halve the step (bisection) to localize spike times within a step
- Reset the membrane and continue integrating after emitting a spike
SpikingODELayer
import numpy as np
from sc_neurocore.spike_ode import SpikingODELayer, ODELIFDynamics

# LIF membrane dynamics: 20 ms time constant, threshold 1.0, reset to 0.0.
dynamics = ODELIFDynamics(
    tau_mem=20.0, v_rest=0.0, v_threshold=1.0, v_reset=0.0,
)
layer = SpikingODELayer(
    n_inputs=32, n_neurons=16,
    dynamics=dynamics, dt_init=0.1, dt_min=0.001,
)
# A seeded generator makes the printed spike count reproducible run to run.
rng = np.random.default_rng(0)
inputs = rng.standard_normal((100, 32)) * 0.5  # 100 intervals x 32 inputs
spike_counts = layer.forward(inputs, interval=1.0)
# shape: (100, 16) -- spike count per neuron per interval
print(f"Total spikes: {int(spike_counts.sum())}")
Single-Step (Online)
# Online use: clear the membrane state, then advance one interval at a time.
layer.reset()
sample = np.random.randn(32) * 0.5
step_counts = layer.step(sample, interval=1.0)
print(f"Voltage: {layer.voltage}")
Adaptive stepping can give roughly a 5x speedup over fixed small-step integration; the actual gain depends on how often neurons approach threshold.
API Reference
sc_neurocore.spike_ode.ode_layer
Continuous-depth SNN layer combining ODE solver with spike events.
Solves the LIF membrane ODE with adaptive step-size forward Euler, detects
threshold crossings as events, emits spikes, resets the membrane, and
continues integrating.
This combines Neural ODEs with SNNs; to our knowledge, no other library
offers it as a reusable layer.
Reference: EventProp (Wunderlich & Pehle 2021)
ODELIFDynamics
dataclass
LIF membrane ODE dynamics.
dv/dt = -(v - v_rest) / tau_mem + I(t) / C_mem
Parameters
tau_mem : float
Membrane time constant (ms).
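v_rest : float
Resting potential that the leak term relaxes towards.
v_threshold : float
Spike threshold; a neuron spikes when its membrane potential reaches this value.
v_reset : float
Membrane potential a neuron is set to after it spikes.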
C_mem : float
Membrane capacitance (normalized).
Source code in src/sc_neurocore/spike_ode/ode_layer.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
@dataclass
class ODELIFDynamics:
    """Leaky integrate-and-fire membrane dynamics written as an ODE.

    The membrane potential evolves as

        dv/dt = -(v - v_rest) / tau_mem + I(t) / C_mem

    Parameters
    ----------
    tau_mem : float
        Membrane time constant (ms).
    v_rest : float
        Resting potential that the leak term relaxes towards.
    v_threshold : float
        Spike threshold (stored here; not used by ``dvdt``).
    v_reset : float
        Post-spike reset potential (stored here; not used by ``dvdt``).
    C_mem : float
        Membrane capacitance (normalized) that scales the input current.
    """

    tau_mem: float = 20.0
    v_rest: float = 0.0
    v_threshold: float = 1.0
    v_reset: float = 0.0
    C_mem: float = 1.0

    def dvdt(self, v: np.ndarray, I: np.ndarray) -> np.ndarray:
        """Return dv/dt for membrane potential ``v`` driven by current ``I``."""
        leak = -(v - self.v_rest) / self.tau_mem
        drive = I / self.C_mem
        return leak + drive
|
dvdt(v, I)
Compute membrane voltage derivative.
Source code in src/sc_neurocore/spike_ode/ode_layer.py
def dvdt(self, v: np.ndarray, I: np.ndarray) -> np.ndarray:
    """Return dv/dt for membrane potential ``v`` driven by current ``I``."""
    leak = -(v - self.v_rest) / self.tau_mem
    drive = I / self.C_mem
    return leak + drive
|
SpikingODELayer
Spiking Neural ODE layer with event-driven integration.
Integrates the membrane ODE with adaptive Euler stepping.
Detects threshold crossings via bisection, emits spikes, resets.
Parameters
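n_inputs : int
Number of input features.
n_neurons : int
Number of neurons in the layer.
dynamics : ODELIFDynamics
Membrane dynamics; a default ODELIFDynamics() is used when None.
dt_init : float
Initial (and maximum) integration step size.
dt_min : float
Minimum step size.
max_steps_per_interval : int
Max ODE steps per simulation interval.
seed : int
Seed for the random weight initialization.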
Source code in src/sc_neurocore/spike_ode/ode_layer.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class SpikingODELayer:
    """Spiking Neural ODE layer with event-driven integration.

    Integrates the membrane ODE with adaptive forward-Euler stepping. When a
    step would carry any membrane across threshold, the step is repeatedly
    halved to localize the earliest crossing. Only neurons that are at or
    above threshold at the end of the accepted step spike and are reset.

    Parameters
    ----------
    n_inputs : int
        Number of input features.
    n_neurons : int
        Number of neurons in the layer.
    dynamics : ODELIFDynamics, optional
        Membrane dynamics; ``ODELIFDynamics()`` is used when None.
    dt_init : float
        Initial (and maximum) integration step size.
    dt_min : float
        Minimum step size. Integration of an interval stops once the next
        step would be shorter than this.
    max_steps_per_interval : int
        Max ODE steps per simulation interval. When the cap is reached, the
        remainder of the interval is not integrated.
    seed : int
        Seed for the random weight initialization.
    """

    def __init__(
        self,
        n_inputs: int,
        n_neurons: int,
        dynamics: ODELIFDynamics | None = None,
        dt_init: float = 0.1,
        dt_min: float = 0.001,
        max_steps_per_interval: int = 100,
        seed: int = 42,
    ):
        self.n_inputs = n_inputs
        self.n_neurons = n_neurons
        self.dynamics = dynamics if dynamics is not None else ODELIFDynamics()
        self.dt_init = dt_init
        self.dt_min = dt_min
        self.max_steps = max_steps_per_interval
        rng = np.random.RandomState(seed)
        # Gaussian weights with variance 2 / n_inputs.
        self.W = rng.randn(n_neurons, n_inputs) * np.sqrt(2.0 / n_inputs)
        self._v = np.full(n_neurons, self.dynamics.v_rest)

    def step(self, x: np.ndarray, interval: float = 1.0) -> np.ndarray:
        """Integrate ODE over one interval, return spike counts.

        Parameters
        ----------
        x : ndarray of shape (n_inputs,)
            Input (constant over interval).
        interval : float
            Duration of this interval (ms).

        Returns
        -------
        ndarray of shape (n_neurons,)
            Spike count per neuron during interval.
        """
        dyn = self.dynamics
        I = self.W @ x
        spike_counts = np.zeros(self.n_neurons)
        t = 0.0
        dt = self.dt_init
        steps = 0
        while t < interval and steps < self.max_steps:
            dt = min(dt, interval - t)
            if dt < self.dt_min:
                break
            # Euler step; the input current is held constant over the interval.
            dv = dyn.dvdt(self._v, I)
            v_new = self._v + dt * dv
            if (v_new >= dyn.v_threshold).any():
                # Halve the step while some neuron still crosses within it,
                # bracketing the earliest crossing to within a factor of two.
                for _ in range(5):
                    dt_half = dt / 2
                    v_mid = self._v + dt_half * dv
                    if not (v_mid >= dyn.v_threshold).any():
                        break
                    dt = dt_half
                    v_new = v_mid
                # Recompute the mask on the accepted (possibly shortened) step:
                # neurons that only crossed over the original, longer step have
                # not reached threshold yet and must not spike or reset.
                crossed = v_new >= dyn.v_threshold
                spike_counts[crossed] += 1
                v_new[crossed] = dyn.v_reset
            self._v = v_new
            t += dt
            steps += 1
            # Adaptive step: shrink near threshold, otherwise grow up to dt_init.
            distance_to_thresh = dyn.v_threshold - self._v
            # Guard against an empty layer, where .min() would raise.
            min_dist = distance_to_thresh.min() if distance_to_thresh.size else np.inf
            if min_dist < 0.1 * dyn.v_threshold:
                dt = max(dt * 0.5, self.dt_min)
            else:
                dt = min(dt * 1.5, self.dt_init)
        return spike_counts

    def forward(self, inputs: np.ndarray, interval: float = 1.0) -> np.ndarray:
        """Process a sequence of inputs, starting from a reset state.

        Parameters
        ----------
        inputs : ndarray of shape (T, n_inputs)
        interval : float
            Duration per input step.

        Returns
        -------
        ndarray of shape (T, n_neurons), spike counts per interval
        """
        self.reset()
        T = inputs.shape[0]
        outputs = np.zeros((T, self.n_neurons))
        for t in range(T):
            outputs[t] = self.step(inputs[t], interval)
        return outputs

    def reset(self):
        """Set every membrane potential back to ``dynamics.v_rest``."""
        self._v = np.full(self.n_neurons, self.dynamics.v_rest)

    @property
    def voltage(self) -> np.ndarray:
        """Copy of the current membrane potentials."""
        return self._v.copy()
|
step(x, interval=1.0)
Integrate ODE over one interval, return spike counts.
Parameters
x : ndarray of shape (n_inputs,)
Input (constant over interval).
interval : float
Duration of this interval (ms).
Returns
ndarray of shape (n_neurons,)
Spike count per neuron during interval.
Source code in src/sc_neurocore/spike_ode/ode_layer.py
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def step(self, x: np.ndarray, interval: float = 1.0) -> np.ndarray:
    """Integrate ODE over one interval, return spike counts.

    When a step would carry any membrane across threshold, the step is
    halved (up to 5 times) to bracket the earliest crossing. Only neurons at
    or above threshold at the end of the accepted step spike and are reset.

    Parameters
    ----------
    x : ndarray of shape (n_inputs,)
        Input (constant over interval).
    interval : float
        Duration of this interval (ms).

    Returns
    -------
    ndarray of shape (n_neurons,)
        Spike count per neuron during interval.
    """
    dyn = self.dynamics
    I = self.W @ x
    spike_counts = np.zeros(self.n_neurons)
    t = 0.0
    dt = self.dt_init
    steps = 0
    while t < interval and steps < self.max_steps:
        dt = min(dt, interval - t)
        if dt < self.dt_min:
            break
        # Euler step; the input current is held constant over the interval.
        dv = dyn.dvdt(self._v, I)
        v_new = self._v + dt * dv
        if (v_new >= dyn.v_threshold).any():
            # Halve the step while some neuron still crosses within it,
            # bracketing the earliest crossing to within a factor of two.
            for _ in range(5):
                dt_half = dt / 2
                v_mid = self._v + dt_half * dv
                if not (v_mid >= dyn.v_threshold).any():
                    break
                dt = dt_half
                v_new = v_mid
            # Recompute the mask on the accepted (possibly shortened) step:
            # neurons that only crossed over the original, longer step have
            # not reached threshold yet and must not spike or reset.
            crossed = v_new >= dyn.v_threshold
            spike_counts[crossed] += 1
            v_new[crossed] = dyn.v_reset
        self._v = v_new
        t += dt
        steps += 1
        # Adaptive step: shrink near threshold, otherwise grow up to dt_init.
        distance_to_thresh = dyn.v_threshold - self._v
        # Guard against an empty layer, where .min() would raise.
        min_dist = distance_to_thresh.min() if distance_to_thresh.size else np.inf
        if min_dist < 0.1 * dyn.v_threshold:
            dt = max(dt * 0.5, self.dt_min)
        else:
            dt = min(dt * 1.5, self.dt_init)
    return spike_counts
|
forward(inputs, interval=1.0)
Process a sequence of inputs.
Parameters
inputs : ndarray of shape (T, n_inputs)
interval : float
Duration per input step.
Returns
ndarray of shape (T, n_neurons), spike counts per interval
Source code in src/sc_neurocore/spike_ode/ode_layer.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def forward(self, inputs: np.ndarray, interval: float = 1.0) -> np.ndarray:
    """Run the layer over a whole input sequence from a freshly reset state.

    Parameters
    ----------
    inputs : ndarray of shape (T, n_inputs)
    interval : float
        Duration per input step.

    Returns
    -------
    ndarray of shape (T, n_neurons), spike counts per interval
    """
    self.reset()
    n_steps = inputs.shape[0]
    spikes = np.zeros((n_steps, self.n_neurons))
    for row in range(n_steps):
        # Membrane state carries over from one interval to the next.
        spikes[row] = self.step(inputs[row], interval)
    return spikes
|