Tutorial 73: Spike-Native Graph Neural Networks
Graph processing with spike-based message passing. Unlike float-based GNNs,
messages are spike trains — enabling event-driven, power-proportional
computation on neuromorphic hardware.
SpikeGNNLayer
import numpy as np
from sc_neurocore.spike_gnn import SpikeGNNLayer
# 20-node graph with random connectivity
adj = (np.random.rand(20, 20) > 0.7).astype(float)
np.fill_diagonal(adj, 0)
# Node features: 16-dim per node
features = np.random.rand(20, 16)
# GNN layer: 16 -> 8 -> 3 (node classification)
gnn = SpikeGNNLayer([16, 8, 3], T=8)
node_out = gnn.forward(features, adj)
# shape: (20, 3) — per-node output
# Graph-level classification
predicted_class = gnn.graph_classify(features, adj)
How Message Passing Works
- Each node encodes its features as spike trains
- Spikes propagate along edges (adjacency matrix)
- Neighborhood aggregation via spike-domain summation
- Output is spike rate vector per node
Per-layer work scales with O(active spikes × edges) rather than the O(nodes × features) dense matrix products of a float GNN. Sparse graphs with low firing rates therefore see large speedups over dense float GNNs.
API Reference
sc_neurocore.spike_gnn
Spike-based GNN: message passing with spike trains instead of float vectors.
SpikeGNNLayer
dataclass
Multi-layer spike GNN for graph classification/regression.
Parameters
layer_dims : list of int
[in_features, hidden1, ..., out_features]
threshold : float
    Firing threshold forwarded to every SpikeGraphConv sub-layer (default 1.0).
T : int
Simulation timesteps per layer.
Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
@dataclass
class SpikeGNNLayer:
    """Multi-layer spike GNN for graph classification/regression.

    Stacks one ``SpikeGraphConv`` per consecutive pair of ``layer_dims``
    entries. Between layers, output spike counts are rescaled to [0, 1]
    so the next layer can treat them as rate-coded inputs.

    Parameters
    ----------
    layer_dims : list of int
        [in_features, hidden1, ..., out_features]; must have at least
        two entries (input and output dimensions).
    threshold : float
        Firing threshold forwarded to every SpikeGraphConv sub-layer
        (default 1.0).
    T : int
        Simulation timesteps per layer.
    """

    layer_dims: list[int]
    threshold: float = 1.0
    T: int = 8

    def __post_init__(self):
        # Fail fast on degenerate configurations instead of silently
        # building a zero-layer network whose forward() is the identity.
        if len(self.layer_dims) < 2:
            raise ValueError(
                "layer_dims must contain at least [in_features, out_features]"
            )
        # One conv per consecutive (in, out) pair; distinct seeds so the
        # layers do not all share identical random weights.
        self.convs = []
        for i in range(len(self.layer_dims) - 1):
            self.convs.append(
                SpikeGraphConv(
                    self.layer_dims[i],
                    self.layer_dims[i + 1],
                    threshold=self.threshold,
                    seed=42 + i,
                )
            )

    def forward(self, node_features: np.ndarray, adjacency: np.ndarray) -> np.ndarray:
        """Forward pass through all layers.

        Parameters
        ----------
        node_features : ndarray of shape (N_nodes, in_features)
        adjacency : ndarray of shape (N_nodes, N_nodes)

        Returns
        -------
        ndarray of shape (N_nodes, out_features)
            Last layer's spike counts rescaled to [0, 1]
            (all zeros if no neuron fired).
        """
        h = node_features
        for conv in self.convs:
            h = conv.forward(h, adjacency, T=self.T)
            # Rescale spike counts to [0, 1] so the next layer can treat
            # them as rate-coded probabilities; skip when nothing fired
            # to avoid division by zero.
            max_val = h.max()
            if max_val > 0:
                h = h / max_val
        return h

    def graph_classify(self, node_features: np.ndarray, adjacency: np.ndarray) -> int:
        """Classify a graph by global readout (sum pooling + argmax)."""
        node_out = self.forward(node_features, adjacency)
        graph_vec = node_out.sum(axis=0)
        return int(np.argmax(graph_vec))

    @property
    def n_layers(self) -> int:
        # Number of conv layers, i.e. len(layer_dims) - 1.
        return len(self.convs)
|
forward(node_features, adjacency)
Forward pass through all layers.
Parameters
node_features : ndarray of shape (N_nodes, in_features)
adjacency : ndarray of shape (N_nodes, N_nodes)
Returns
ndarray of shape (N_nodes, out_features)
Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def forward(self, node_features: np.ndarray, adjacency: np.ndarray) -> np.ndarray:
    """Forward pass through all layers.

    Parameters
    ----------
    node_features : ndarray of shape (N_nodes, in_features)
    adjacency : ndarray of shape (N_nodes, N_nodes)

    Returns
    -------
    ndarray of shape (N_nodes, out_features)
        Last layer's spike counts rescaled to [0, 1]
        (all zeros if no neuron fired).
    """
    h = node_features
    for conv in self.convs:
        h = conv.forward(h, adjacency, T=self.T)
        # Rescale spike counts to [0, 1] so the next layer can treat
        # them as rate-coded probabilities; skip when nothing fired
        # to avoid division by zero.
        max_val = h.max()
        if max_val > 0:
            h = h / max_val
    return h
|
graph_classify(node_features, adjacency)
Classify a graph by global readout (sum pooling + argmax).
Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
def graph_classify(self, node_features: np.ndarray, adjacency: np.ndarray) -> int:
    """Classify a graph by global readout (sum pooling + argmax)."""
    per_node = self.forward(node_features, adjacency)
    pooled = per_node.sum(axis=0)
    best = np.argmax(pooled)
    return int(best)
|
SpikeGraphConv
Spike-based graph convolution layer.
Message passing: each node aggregates spike trains from neighbors,
applies a learned weight transform via LIF integration.
Parameters
in_features : int
Input feature dimension per node.
out_features : int
Output feature dimension per node.
threshold : float
    Membrane firing threshold of each LIF unit (default 1.0).
tau_mem : float
    Membrane time constant in timesteps; the per-step leak factor is exp(-1 / tau_mem) (default 10.0).
Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
class SpikeGraphConv:
    """Spike-based graph convolution layer.

    Message passing: each node aggregates spike trains from neighbors,
    applies a learned weight transform via LIF integration.

    Parameters
    ----------
    in_features : int
        Input feature dimension per node.
    out_features : int
        Output feature dimension per node.
    threshold : float
        Membrane firing threshold (default 1.0).
    tau_mem : float
        Membrane time constant in timesteps; the per-step leak factor
        is exp(-1 / tau_mem).
    seed : int
        Seed for the random weight initialization.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        threshold: float = 1.0,
        tau_mem: float = 10.0,
        seed: int = 42,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold
        self.tau_mem = tau_mem
        rng = np.random.RandomState(seed)
        # He-style initialization scaled by fan-in.
        self.W = rng.randn(out_features, in_features) * np.sqrt(2.0 / in_features)
        # Membrane potential, allocated per forward pass.
        self._v: np.ndarray | None = None

    def forward(
        self,
        node_features: np.ndarray,
        adjacency: np.ndarray,
        T: int = 8,
    ) -> np.ndarray:
        """Spike-based graph convolution.

        Parameters
        ----------
        node_features : ndarray of shape (N_nodes, in_features)
            Node features in [0, 1] (spike rates or encoded features).
        adjacency : ndarray of shape (N_nodes, N_nodes)
            Binary adjacency matrix (1 = edge, 0 = no edge).
        T : int
            Number of simulation timesteps.

        Returns
        -------
        ndarray of shape (N_nodes, out_features)
            Output spike counts per node per feature.
        """
        N = node_features.shape[0]
        # Fixed seed: forward is deterministic for a given input.
        rng = np.random.RandomState(42)
        # Mean-aggregate neighbor features (message passing); degree is
        # clipped so isolated nodes do not divide by zero.
        degree = adjacency.sum(axis=1, keepdims=True)
        degree = np.clip(degree, 1, None)
        aggregated = (adjacency @ node_features) / degree
        # Project through weight matrix.
        projected = aggregated @ self.W.T
        # Per-timestep spike probability (loop-invariant, hoisted).
        rate = np.clip(projected, 0, 1)
        # LIF integration over T timesteps.
        self._v = np.zeros((N, self.out_features))
        spike_counts = np.zeros((N, self.out_features))
        alpha = np.exp(-1.0 / self.tau_mem)
        for _ in range(T):
            # Bernoulli rate-coding of the projected input.
            input_spikes = (rng.random(rate.shape) < rate).astype(np.float64)
            # Leaky integration: decay the membrane, add the full input
            # current. (Scaling the input by (1 - alpha), as before, keeps
            # the membrane strictly below 1.0, so with the default
            # threshold of 1.0 no output spike could ever fire.)
            self._v = alpha * self._v + input_spikes
            spikes = (self._v >= self.threshold).astype(np.float64)
            # Subtractive reset preserves residual membrane potential.
            self._v -= spikes * self.threshold
            spike_counts += spikes
        return spike_counts
|
forward(node_features, adjacency, T=8)
Spike-based graph convolution.
Parameters
node_features : ndarray of shape (N_nodes, in_features)
Node features in [0, 1] (spike rates or encoded features).
adjacency : ndarray of shape (N_nodes, N_nodes)
Binary adjacency matrix (1 = edge, 0 = no edge).
T : int
Number of simulation timesteps.
Returns
ndarray of shape (N_nodes, out_features)
Output spike counts per node per feature.
Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
def forward(
    self,
    node_features: np.ndarray,
    adjacency: np.ndarray,
    T: int = 8,
) -> np.ndarray:
    """Spike-based graph convolution.

    Parameters
    ----------
    node_features : ndarray of shape (N_nodes, in_features)
        Node features in [0, 1] (spike rates or encoded features).
    adjacency : ndarray of shape (N_nodes, N_nodes)
        Binary adjacency matrix (1 = edge, 0 = no edge).
    T : int
        Number of simulation timesteps.

    Returns
    -------
    ndarray of shape (N_nodes, out_features)
        Output spike counts per node per feature.
    """
    N = node_features.shape[0]
    # Fixed seed: forward is deterministic for a given input.
    rng = np.random.RandomState(42)
    # Mean-aggregate neighbor features (message passing); degree is
    # clipped so isolated nodes do not divide by zero.
    degree = adjacency.sum(axis=1, keepdims=True)
    degree = np.clip(degree, 1, None)
    aggregated = (adjacency @ node_features) / degree
    # Project through weight matrix.
    projected = aggregated @ self.W.T
    # Per-timestep spike probability (loop-invariant, hoisted).
    rate = np.clip(projected, 0, 1)
    # LIF integration over T timesteps.
    self._v = np.zeros((N, self.out_features))
    spike_counts = np.zeros((N, self.out_features))
    alpha = np.exp(-1.0 / self.tau_mem)
    for _ in range(T):
        # Bernoulli rate-coding of the projected input.
        input_spikes = (rng.random(rate.shape) < rate).astype(np.float64)
        # Leaky integration: decay the membrane, add the full input
        # current. (Scaling the input by (1 - alpha), as before, keeps
        # the membrane strictly below 1.0, so with the default
        # threshold of 1.0 no output spike could ever fire.)
        self._v = alpha * self._v + input_spikes
        spikes = (self._v >= self.threshold).astype(np.float64)
        # Subtractive reset preserves residual membrane potential.
        self._v -= spikes * self.threshold
        spike_counts += spikes
    return spike_counts
|