Skip to content

Spiking Transformers & State-Space Models

Spike-driven attention, state-space models, and CPG positional encoding. Zero-multiplication attention via binary spike AND operations.

Spike-Driven Attention (SSA)

sc_neurocore.transformers.spikformer.SpikeDrivenAttention dataclass

Spike-Driven Self-Attention (SSA).

Replaces the softmax over $Q K^{\top}$ with spike-based masking:

$$\text{Attention} = \text{SpikeFn}\big(Q_{\text{linear}}(S)\big)\,\text{SpikeFn}\big(K_{\text{linear}}(S)\big)^{\top}\,V_{\text{linear}}(S)$$

All operations reduce to AND gates on binary spikes — zero multiplications, pure SC-compatible logic.

Parameters

- embed_dim : int — Embedding dimension.
- num_heads : int — Number of attention heads.
- T : int — Number of simulation timesteps.
- threshold : float — Spike threshold for Q/K projections.

Source code in src/sc_neurocore/transformers/spikformer.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@dataclass
class SpikeDrivenAttention:
    """Spike-Driven Self-Attention (SSA).

    Replaces Q*K^T softmax with spike-based masking:
      Attention = SpikeFn(Q_linear(S)) * SpikeFn(K_linear(S))^T * V_linear(S)

    On binary spikes the Q/K interaction reduces to AND gates —
    zero multiplications, pure SC-compatible logic. This NumPy model
    evaluates that interaction as a matrix product of 0/1 spike arrays.

    Parameters
    ----------
    embed_dim : int
        Embedding dimension.
    num_heads : int
        Number of attention heads. Only used to derive ``head_dim``, which
        sets the attention scale; ``forward`` does not split heads.
    T : int
        Number of simulation timesteps.
    threshold : float
        Spike threshold for Q/K projections.
    """

    embed_dim: int
    num_heads: int = 1
    T: int = 8
    threshold: float = 1.0

    def __post_init__(self):
        """Derive ``head_dim`` and draw the projection weights."""
        self.head_dim = self.embed_dim // self.num_heads
        # Fixed seed: instances with the same embed_dim share identical weights.
        rng = np.random.RandomState(42)
        # Linear projections (Q, K, V, output). The draw order defines the
        # weights, so it must stay Q -> K -> V -> out.
        scale = np.sqrt(2.0 / self.embed_dim)
        self.W_q = rng.randn(self.embed_dim, self.embed_dim) * scale
        self.W_k = rng.randn(self.embed_dim, self.embed_dim) * scale
        self.W_v = rng.randn(self.embed_dim, self.embed_dim) * scale
        self.W_out = rng.randn(self.embed_dim, self.embed_dim) * scale
        # Membrane state for Q/K spike generation (allocated in forward()).
        self._v_q = None
        self._v_k = None

    def _spike_fn(self, membrane: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Integrate-and-fire: returns (spikes, new_membrane).

        A unit fires (1.0) where ``membrane >= threshold``; firing units have
        ``threshold`` subtracted (soft reset), the rest are left unchanged.
        """
        spikes = (membrane >= self.threshold).astype(np.float64)
        membrane = membrane - spikes * self.threshold
        return spikes, membrane

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Forward pass: spike-driven attention over T timesteps.

        Parameters
        ----------
        x : array_like of shape (seq_len, embed_dim) or (embed_dim,)
            Input spike rates in [0, 1]. Any real dtype is accepted
            (including integer/boolean spike trains); it is converted to
            float64 before processing.

        Returns
        -------
        ndarray of float64, same shape as x
            Attention output.
        """
        # Coerce up front: an integer/bool input would otherwise produce an
        # integer accumulator and make the in-place float accumulation fail.
        x = np.asarray(x, dtype=np.float64)
        squeeze = x.ndim == 1
        if squeeze:
            x = x[np.newaxis]

        # Linear projections
        Q_proj = x @ self.W_q
        K_proj = x @ self.W_k
        V_proj = x @ self.W_v

        # Loop-invariant terms: rectified per-step membrane drive and scale.
        q_drive = np.clip(Q_proj, 0, None) / self.T
        k_drive = np.clip(K_proj, 0, None) / self.T
        scale = max(np.sqrt(self.head_dim), 1.0)

        # Accumulate over T timesteps with spike-driven attention
        output_acc = np.zeros_like(V_proj)
        self._v_q = np.zeros_like(Q_proj)
        self._v_k = np.zeros_like(K_proj)

        for _ in range(self.T):
            # Rate-code input: membrane charges in proportion to the projection
            self._v_q += q_drive
            self._v_k += k_drive

            Q_spikes, self._v_q = self._spike_fn(self._v_q)
            K_spikes, self._v_k = self._spike_fn(self._v_k)

            # SSA: spike AND instead of softmax.
            # attn[i, j] counts the dimensions where token i's Q spike and
            # token j's K spike coincide (dot product of binary vectors).
            attn = (Q_spikes @ K_spikes.T) / scale  # (seq, seq)

            # Weighted sum of V
            output_acc += attn @ V_proj

        # Average over timesteps, then apply the output projection.
        output = (output_acc / self.T) @ self.W_out

        if squeeze:
            output = output[0]
        return output

    @property
    def num_multiply_ops(self) -> int:
        """Zero multiplications in the attention core (AND gates only)."""
        return 0

num_multiply_ops property

Zero multiplications in the attention core (AND gates only).

forward(x)

Forward pass: spike-driven attention over T timesteps.

Parameters

x : ndarray of shape (seq_len, embed_dim) or (embed_dim,) Input spike rates in [0, 1].

Returns

ndarray, same shape as x Attention output.

Source code in src/sc_neurocore/transformers/spikformer.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def forward(self, x: np.ndarray) -> np.ndarray:
    """Forward pass: spike-driven attention over T timesteps.

    Parameters
    ----------
    x : array_like of shape (seq_len, embed_dim) or (embed_dim,)
        Input spike rates in [0, 1]. Any real dtype is accepted
        (including integer/boolean spike trains); it is converted to
        float64 before processing.

    Returns
    -------
    ndarray of float64, same shape as x
        Attention output.
    """
    # Coerce up front: an integer/bool input would otherwise produce an
    # integer accumulator and make the in-place float accumulation fail.
    x = np.asarray(x, dtype=np.float64)
    squeeze = x.ndim == 1
    if squeeze:
        x = x[np.newaxis]

    # Linear projections
    Q_proj = x @ self.W_q
    K_proj = x @ self.W_k
    V_proj = x @ self.W_v

    # Loop-invariant terms: rectified per-step membrane drive and scale.
    q_drive = np.clip(Q_proj, 0, None) / self.T
    k_drive = np.clip(K_proj, 0, None) / self.T
    scale = max(np.sqrt(self.head_dim), 1.0)

    # Accumulate over T timesteps with spike-driven attention
    output_acc = np.zeros_like(V_proj)
    self._v_q = np.zeros_like(Q_proj)
    self._v_k = np.zeros_like(K_proj)

    for _ in range(self.T):
        # Rate-code input: membrane charges in proportion to the projection
        self._v_q += q_drive
        self._v_k += k_drive

        Q_spikes, self._v_q = self._spike_fn(self._v_q)
        K_spikes, self._v_k = self._spike_fn(self._v_k)

        # SSA: spike AND instead of softmax.
        # attn[i, j] counts the dimensions where token i's Q spike and
        # token j's K spike coincide (dot product of binary vectors).
        attn = (Q_spikes @ K_spikes.T) / scale  # (seq, seq)

        # Weighted sum of V
        output_acc += attn @ V_proj

    # Average over timesteps, then apply the output projection.
    output = (output_acc / self.T) @ self.W_out

    if squeeze:
        output = output[0]
    return output

Spiking State-Space Model

sc_neurocore.transformers.spikformer.SpikyStateSpace dataclass

Spiking State-Space Model (S4-SNN hybrid).

Combines linear state-space dynamics with spiking nonlinearity:

$$h_t = A\,h_{t-1} + B\,\text{spike\_input}_t$$

$$y_t = C\,h_t$$

$$\text{spike}_t = \text{IF}(y_t > \text{threshold})$$

Runs in O(1) memory per timestep (no BPTT unrolling needed). Reference: SpikySpace (2025).

Parameters

- d_model : int — Input/output dimension.
- d_state : int — Hidden state dimension.
- threshold : float — Spiking threshold.
- dt : float — Discretization timestep.

Source code in src/sc_neurocore/transformers/spikformer.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
@dataclass
class SpikyStateSpace:
    """Spiking State-Space Model (S4-SNN hybrid).

    A diagonal linear state-space recurrence feeds an integrate-and-fire
    output stage:
      h_t = A * h_{t-1} + B * x_t
      y_t = C * h_t
      v_t = v_{t-1} + y_t;  spike where v_t >= threshold, then v_t -= threshold

    Only the current hidden state and membrane potential are stored between
    timesteps. Reference: SpikySpace (2025).

    Parameters
    ----------
    d_model : int
        Input/output dimension.
    d_state : int
        Hidden state dimension.
    threshold : float
        Spiking threshold.
    dt : float
        Discretization timestep.
    """

    d_model: int
    d_state: int = 64
    threshold: float = 1.0
    dt: float = 0.01

    def __post_init__(self):
        """Draw the fixed-seed A, B, C matrices and zero the running state."""
        gen = np.random.RandomState(42)
        # Diagonal state transition: exp(-dt * |n|) keeps every entry in (0, 1].
        decay = np.abs(gen.randn(self.d_state))
        self.A = np.exp(-self.dt * decay)
        # Input and readout maps; the draw order A -> B -> C fixes the weights.
        self.B = gen.randn(self.d_state, self.d_model) * np.sqrt(2.0 / self.d_model)
        self.C = gen.randn(self.d_model, self.d_state) * np.sqrt(2.0 / self.d_state)
        self._h = np.zeros(self.d_state)
        self._v = np.zeros(self.d_model)

    def reset(self):
        """Zero the hidden state ``_h`` and the membrane potential ``_v``."""
        self._h = np.zeros(shape=self.d_state)
        self._v = np.zeros(shape=self.d_model)

    def step(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Advance the model by a single timestep.

        Parameters
        ----------
        x : ndarray of shape (d_model,)
            Input (binary spikes or continuous).

        Returns
        -------
        (spikes, output) tuple
            spikes: binary spike output (d_model,)
            output: continuous pre-spike output (d_model,)
        """
        # Linear recurrence followed by the readout.
        hidden = self.A * self._h + self.B @ x
        self._h = hidden
        readout = self.C @ hidden

        # Integrate the readout; units at/above threshold fire and are
        # reduced by exactly one threshold (soft reset).
        self._v += readout
        fired = (self._v >= self.threshold).astype(np.float64)
        self._v -= fired * self.threshold

        return fired, readout

    def forward(self, x_seq: np.ndarray) -> np.ndarray:
        """Run the model over a whole sequence from a cleared state.

        Parameters
        ----------
        x_seq : ndarray of shape (T, d_model)

        Returns
        -------
        ndarray of shape (T, d_model)
            Spike output per timestep (stored with the dtype of ``x_seq``).
        """
        self.reset()
        n_steps = x_seq.shape[0]
        spike_train = np.zeros_like(x_seq)
        for idx in range(n_steps):
            # Only the spikes are kept; the continuous readout is dropped.
            spike_train[idx] = self.step(x_seq[idx])[0]
        return spike_train

reset()

Reset hidden state and membrane potential.

Source code in src/sc_neurocore/transformers/spikformer.py
171
172
173
174
def reset(self):
    """Zero the hidden state ``_h`` and the membrane potential ``_v``."""
    self._h = np.zeros(shape=self.d_state)
    self._v = np.zeros(shape=self.d_model)

step(x)

Process one timestep.

Parameters

x : ndarray of shape (d_model,) Input (binary spikes or continuous).

Returns

(spikes, output) tuple

- spikes: binary spike output, shape (d_model,)
- output: continuous pre-spike output, shape (d_model,)

Source code in src/sc_neurocore/transformers/spikformer.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def step(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Advance the model by a single timestep.

    Parameters
    ----------
    x : ndarray of shape (d_model,)
        Input (binary spikes or continuous).

    Returns
    -------
    (spikes, output) tuple
        spikes: binary spike output (d_model,)
        output: continuous pre-spike output (d_model,)
    """
    # Linear recurrence followed by the readout.
    hidden = self.A * self._h + self.B @ x
    self._h = hidden
    readout = self.C @ hidden

    # Integrate the readout; units at/above threshold fire and are
    # reduced by exactly one threshold (soft reset).
    self._v += readout
    fired = (self._v >= self.threshold).astype(np.float64)
    self._v -= fired * self.threshold

    return fired, readout

forward(x_seq)

Process a full sequence.

Parameters

x_seq : ndarray of shape (T, d_model)

Returns

ndarray of shape (T, d_model) Spike output per timestep.

Source code in src/sc_neurocore/transformers/spikformer.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def forward(self, x_seq: np.ndarray) -> np.ndarray:
    """Run the model over a whole sequence from a cleared state.

    Parameters
    ----------
    x_seq : ndarray of shape (T, d_model)

    Returns
    -------
    ndarray of shape (T, d_model)
        Spike output per timestep (stored with the dtype of ``x_seq``).
    """
    self.reset()
    n_steps = x_seq.shape[0]
    spike_train = np.zeros_like(x_seq)
    for idx in range(n_steps):
        # Only the spikes are kept; the continuous readout is dropped.
        spike_train[idx] = self.step(x_seq[idx])[0]
    return spike_train

CPG Positional Encoding

sc_neurocore.transformers.spikformer.CPGPositionalEncoding dataclass

Central Pattern Generator positional encoding.

Replaces sinusoidal positional encoding with biologically-inspired CPG oscillators. Each dimension has a different frequency and phase, generating spike-compatible temporal position signals.

Parameters

- d_model : int — Encoding dimension.
- max_len : int — Maximum sequence length.

Source code in src/sc_neurocore/transformers/spikformer.py
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
@dataclass
class CPGPositionalEncoding:
    """Central Pattern Generator positional encoding.

    Replaces sinusoidal positional encoding with biologically-inspired
    CPG oscillators. Each dimension has a different frequency and phase,
    generating spike-compatible temporal position signals.

    Parameters
    ----------
    d_model : int
        Encoding dimension.
    max_len : int
        Maximum sequence length.
    """

    d_model: int
    max_len: int = 1024

    def __post_init__(self):
        rng = np.random.RandomState(42)
        self.frequencies = np.exp(rng.randn(self.d_model) * 0.5)
        self.phases = rng.uniform(0, 2 * np.pi, self.d_model)

    def encode(self, seq_len: int) -> np.ndarray:
        """Generate positional encoding.

        Returns
        -------
        ndarray of shape (seq_len, d_model)
            Values in [0, 1] suitable for spike rate encoding.
        """
        t = np.arange(seq_len)[:, np.newaxis]
        angles = t * self.frequencies[np.newaxis, :] * 0.01 + self.phases[np.newaxis, :]
        return (np.sin(angles) + 1.0) / 2.0  # Map to [0, 1]

    def encode_spikes(self, seq_len: int, rng: np.random.RandomState | None = None) -> np.ndarray:
        """Generate spike-encoded positional encoding.

        Returns
        -------
        ndarray of shape (seq_len, d_model), binary
        """
        if rng is None:
            rng = np.random.RandomState(0)
        rates = self.encode(seq_len)
        return (rng.random(rates.shape) < rates).astype(np.int8)

encode(seq_len)

Generate positional encoding.

Returns

ndarray of shape (seq_len, d_model) Values in [0, 1] suitable for spike rate encoding.

Source code in src/sc_neurocore/transformers/spikformer.py
244
245
246
247
248
249
250
251
252
253
254
def encode(self, seq_len: int) -> np.ndarray:
    """Evaluate every oscillator at positions ``0 .. seq_len - 1``.

    Returns
    -------
    ndarray of shape (seq_len, d_model)
        Values in [0, 1] suitable for spike rate encoding.
    """
    positions = np.arange(seq_len)[:, None]
    freq_row = self.frequencies[None, :]
    phase_row = self.phases[None, :]
    angles = positions * freq_row * 0.01 + phase_row
    wave = np.sin(angles)
    # Shift and scale the sine from [-1, 1] into [0, 1].
    return (wave + 1.0) / 2.0

encode_spikes(seq_len, rng=None)

Generate spike-encoded positional encoding.

Returns

ndarray of shape (seq_len, d_model), binary

Source code in src/sc_neurocore/transformers/spikformer.py
256
257
258
259
260
261
262
263
264
265
266
def encode_spikes(self, seq_len: int, rng: np.random.RandomState | None = None) -> np.ndarray:
    """Generate spike-encoded positional encoding.

    Returns
    -------
    ndarray of shape (seq_len, d_model), binary
    """
    if rng is None:
        rng = np.random.RandomState(0)
    rates = self.encode(seq_len)
    return (rng.random(rates.shape) < rates).astype(np.int8)