
Spike GNN — Graph Neural Networks with Spike Messages

Graph neural networks where messages are spike trains instead of float vectors. Nodes are spiking neuron populations; aggregation is a degree-normalized neighborhood sum, followed by LIF integration.

Architecture

Each SpikeGraphConv layer performs:

  1. Message passing: h_agg = (A @ X) / deg — aggregate neighbor features
  2. Linear projection: h_proj = h_agg @ W^T — learned weight transform
  3. LIF integration: Over T timesteps, rate-coded input drives LIF neurons. Output = spike counts per node.
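These three steps can be sketched standalone with NumPy (a minimal sketch with made-up shapes and seeds; not the library code itself):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph: 3 nodes, 4 input features, 2 output features
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
X = rng.random((3, 4))
W = rng.standard_normal((2, 4)) * np.sqrt(2.0 / 4)

# 1. Message passing: degree-normalized sum over neighbors
deg = np.clip(A.sum(axis=1, keepdims=True), 1, None)
h_agg = (A @ X) / deg

# 2. Linear projection
h_proj = h_agg @ W.T

# 3. LIF integration over T timesteps with rate-coded input
T, threshold, tau_mem = 8, 1.0, 10.0
alpha = np.exp(-1.0 / tau_mem)
v = np.zeros_like(h_proj)
counts = np.zeros_like(h_proj)
for _ in range(T):
    # Each projected value, clipped to [0, 1], is a per-step spike probability
    inp = (rng.random(h_proj.shape) < np.clip(h_proj, 0, 1)).astype(float)
    v = alpha * v + (1 - alpha) * inp
    spikes = (v >= threshold).astype(float)
    v -= spikes * threshold  # soft reset: subtract threshold on spike
    counts += spikes

print(counts.shape)  # (3, 2)
```

The output is a (N_nodes, out_features) array of spike counts, each bounded by T.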

SpikeGNNLayer stacks multiple SpikeGraphConv layers with inter-layer spike count normalization.
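The inter-layer normalization can be illustrated standalone (a sketch with made-up counts, not library code): each layer's spike counts are divided by their maximum so the next layer again sees rates in [0, 1].

```python
import numpy as np

# Made-up spike counts from one layer over T = 8 timesteps
counts = np.array([[0.0, 4.0], [2.0, 8.0], [1.0, 0.0]])

max_val = counts.max()
rates = counts / max_val if max_val > 0 else counts  # guard against all-zero output
print(rates)  # values now in [0, 1], max exactly 1.0
```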

Components

  • SpikeGraphConv — Single spike-based graph convolution layer.
Parameter     Default     Meaning
in_features   (required)  Input dimension per node
out_features  (required)  Output dimension per node
threshold     1.0         LIF spike threshold
tau_mem       10.0        Membrane time constant
seed          42          RNG seed for weight initialization
  • SpikeGNNLayer — Multi-layer spike GNN for graph classification.
Parameter     Default     Meaning
layer_dims    (required)  [in, hidden, ..., out] dimensions
threshold     1.0         LIF threshold for all layers
T             8           Simulation timesteps per layer

Methods: forward(node_features, adjacency), graph_classify(node_features, adjacency).

Usage

from sc_neurocore.spike_gnn.spike_gnn import SpikeGraphConv, SpikeGNNLayer
import numpy as np

# Single layer
conv = SpikeGraphConv(in_features=16, out_features=8)
adj = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]])
features = np.random.rand(3, 16)
output = conv.forward(features, adj, T=8)  # (3, 8) spike counts

# Multi-layer graph classifier
gnn = SpikeGNNLayer(layer_dims=[16, 8, 4], threshold=1.0, T=8)
label = gnn.graph_classify(features, adj)
print(f"Predicted class: {label}")

Reference: SGNNBench (2025) — 9 SGNN architectures benchmarked.

See Tutorial 73: Spike GNN.

sc_neurocore.spike_gnn

Spike-based GNN: message passing with spike trains instead of float vectors.

SpikeGNNLayer dataclass

Multi-layer spike GNN for graph classification/regression.

Parameters

layer_dims : list of int
    [in_features, hidden1, ..., out_features]
threshold : float
T : int
    Simulation timesteps per layer.

Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
@dataclass
class SpikeGNNLayer:
    """Multi-layer spike GNN for graph classification/regression.

    Parameters
    ----------
    layer_dims : list of int
        [in_features, hidden1, ..., out_features]
    threshold : float
    T : int
        Simulation timesteps per layer.
    """

    layer_dims: list[int]
    threshold: float = 1.0
    T: int = 8

    def __post_init__(self):
        self.convs = []
        for i in range(len(self.layer_dims) - 1):
            self.convs.append(
                SpikeGraphConv(
                    self.layer_dims[i],
                    self.layer_dims[i + 1],
                    threshold=self.threshold,
                    seed=42 + i,
                )
            )

    def forward(self, node_features: np.ndarray, adjacency: np.ndarray) -> np.ndarray:
        """Forward pass through all layers.

        Parameters
        ----------
        node_features : ndarray of shape (N_nodes, in_features)
        adjacency : ndarray of shape (N_nodes, N_nodes)

        Returns
        -------
        ndarray of shape (N_nodes, out_features)
        """
        h = node_features
        for conv in self.convs:
            h = conv.forward(h, adjacency, T=self.T)
            # Normalize spike counts to [0, 1] for next layer
            max_val = h.max()
            if max_val > 0:  # pragma: no cover
                h = h / max_val
        return h

    def graph_classify(self, node_features: np.ndarray, adjacency: np.ndarray) -> int:
        """Classify a graph by global readout (sum pooling + argmax)."""
        node_out = self.forward(node_features, adjacency)
        graph_vec = node_out.sum(axis=0)
        return int(np.argmax(graph_vec))

    @property
    def n_layers(self) -> int:
        return len(self.convs)

forward(node_features, adjacency)

Forward pass through all layers.

Parameters

node_features : ndarray of shape (N_nodes, in_features)
adjacency : ndarray of shape (N_nodes, N_nodes)

Returns

ndarray of shape (N_nodes, out_features)

Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
def forward(self, node_features: np.ndarray, adjacency: np.ndarray) -> np.ndarray:
    """Forward pass through all layers.

    Parameters
    ----------
    node_features : ndarray of shape (N_nodes, in_features)
    adjacency : ndarray of shape (N_nodes, N_nodes)

    Returns
    -------
    ndarray of shape (N_nodes, out_features)
    """
    h = node_features
    for conv in self.convs:
        h = conv.forward(h, adjacency, T=self.T)
        # Normalize spike counts to [0, 1] for next layer
        max_val = h.max()
        if max_val > 0:  # pragma: no cover
            h = h / max_val
    return h

graph_classify(node_features, adjacency)

Classify a graph by global readout (sum pooling + argmax).

Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
def graph_classify(self, node_features: np.ndarray, adjacency: np.ndarray) -> int:
    """Classify a graph by global readout (sum pooling + argmax)."""
    node_out = self.forward(node_features, adjacency)
    graph_vec = node_out.sum(axis=0)
    return int(np.argmax(graph_vec))
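The readout (sum pooling, then argmax) can be checked independently with a tiny sketch using made-up per-node spike counts:

```python
import numpy as np

# Hypothetical final-layer outputs: 3 nodes, 4 classes
node_out = np.array([
    [2.0, 0.0, 1.0, 0.0],
    [1.0, 3.0, 0.0, 0.0],
    [0.0, 0.0, 2.0, 1.0],
])

graph_vec = node_out.sum(axis=0)   # sum pooling over nodes -> shape (4,)
label = int(np.argmax(graph_vec))  # class with the largest total count
print(graph_vec, label)  # [3. 3. 3. 1.] 0
```

Note that np.argmax breaks ties by returning the first maximal index, so equal totals resolve to the lowest class label.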

SpikeGraphConv

Spike-based graph convolution layer.

Message passing: each node aggregates spike trains from neighbors, applies a learned weight transform via LIF integration.

Parameters

in_features : int
    Input feature dimension per node.
out_features : int
    Output feature dimension per node.
threshold : float
tau_mem : float

Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
class SpikeGraphConv:
    """Spike-based graph convolution layer.

    Message passing: each node aggregates spike trains from neighbors,
    applies a learned weight transform via LIF integration.

    Parameters
    ----------
    in_features : int
        Input feature dimension per node.
    out_features : int
        Output feature dimension per node.
    threshold : float
    tau_mem : float
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        threshold: float = 1.0,
        tau_mem: float = 10.0,
        seed: int = 42,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold
        self.tau_mem = tau_mem

        rng = np.random.RandomState(seed)
        self.W = rng.randn(out_features, in_features) * np.sqrt(2.0 / in_features)
        self._v: np.ndarray | None = None

    def forward(
        self,
        node_features: np.ndarray,
        adjacency: np.ndarray,
        T: int = 8,
    ) -> np.ndarray:
        """Spike-based graph convolution.

        Parameters
        ----------
        node_features : ndarray of shape (N_nodes, in_features)
            Node features in [0, 1] (spike rates or encoded features).
        adjacency : ndarray of shape (N_nodes, N_nodes)
            Binary adjacency matrix (1 = edge, 0 = no edge).
        T : int
            Number of simulation timesteps.

        Returns
        -------
        ndarray of shape (N_nodes, out_features)
            Output spike counts per node per feature.
        """
        N = node_features.shape[0]
        rng = np.random.RandomState(42)

        # Aggregate neighbor features (message passing)
        degree = adjacency.sum(axis=1, keepdims=True)
        degree = np.clip(degree, 1, None)
        aggregated = (adjacency @ node_features) / degree

        # Project through weight matrix
        projected = aggregated @ self.W.T

        # LIF integration over T timesteps
        self._v = np.zeros((N, self.out_features))
        spike_counts = np.zeros((N, self.out_features))
        alpha = np.exp(-1.0 / self.tau_mem)

        for t in range(T):
            # Rate-code input: spike with probability proportional to projected value
            input_spikes = (rng.random(projected.shape) < np.clip(projected, 0, 1)).astype(
                np.float64
            )
            self._v = alpha * self._v + (1 - alpha) * input_spikes
            spikes = (self._v >= self.threshold).astype(np.float64)
            self._v -= spikes * self.threshold
            spike_counts += spikes

        return spike_counts

forward(node_features, adjacency, T=8)

Spike-based graph convolution.

Parameters

node_features : ndarray of shape (N_nodes, in_features)
    Node features in [0, 1] (spike rates or encoded features).
adjacency : ndarray of shape (N_nodes, N_nodes)
    Binary adjacency matrix (1 = edge, 0 = no edge).
T : int
    Number of simulation timesteps.

Returns

ndarray of shape (N_nodes, out_features)
    Output spike counts per node per feature.
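On the rate coding used inside the timestep loop: a projected value clipped to [0, 1] acts as a per-timestep Bernoulli spike probability, so the expected spike count over T steps is T * p. A quick standalone check of that interpretation (made-up p and T, not library code):

```python
import numpy as np

rng = np.random.default_rng(1)
p = 0.3       # a clipped projected value
T = 10_000    # many timesteps so the empirical rate converges

spikes = (rng.random(T) < p).astype(float)  # Bernoulli(p) spike train
print(spikes.mean())  # close to 0.3
```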

Source code in src/sc_neurocore/spike_gnn/spike_gnn.py
def forward(
    self,
    node_features: np.ndarray,
    adjacency: np.ndarray,
    T: int = 8,
) -> np.ndarray:
    """Spike-based graph convolution.

    Parameters
    ----------
    node_features : ndarray of shape (N_nodes, in_features)
        Node features in [0, 1] (spike rates or encoded features).
    adjacency : ndarray of shape (N_nodes, N_nodes)
        Binary adjacency matrix (1 = edge, 0 = no edge).
    T : int
        Number of simulation timesteps.

    Returns
    -------
    ndarray of shape (N_nodes, out_features)
        Output spike counts per node per feature.
    """
    N = node_features.shape[0]
    rng = np.random.RandomState(42)

    # Aggregate neighbor features (message passing)
    degree = adjacency.sum(axis=1, keepdims=True)
    degree = np.clip(degree, 1, None)
    aggregated = (adjacency @ node_features) / degree

    # Project through weight matrix
    projected = aggregated @ self.W.T

    # LIF integration over T timesteps
    self._v = np.zeros((N, self.out_features))
    spike_counts = np.zeros((N, self.out_features))
    alpha = np.exp(-1.0 / self.tau_mem)

    for t in range(T):
        # Rate-code input: spike with probability proportional to projected value
        input_spikes = (rng.random(projected.shape) < np.clip(projected, 0, 1)).astype(
            np.float64
        )
        self._v = alpha * self._v + (1 - alpha) * input_spikes
        spikes = (self._v >= self.threshold).astype(np.float64)
        self._v -= spikes * self.threshold
        spike_counts += spikes

    return spike_counts