Skip to content

SNN Explainability

Multi-method spike-level explainability: attribution, saliency, causal importance.

Spike Attributor

sc_neurocore.explain.spike_explain.SpikeAttributor

Backward spike attribution via eligibility-trace chain.

Traces the contribution of each input spike to the output through intermediate layers using eligibility trace products. This is an approximation of temporal backpropagation attribution.

Parameters

- decay : float — Temporal decay factor for backward attribution (0-1).

Source code in src/sc_neurocore/explain/spike_explain.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class SpikeAttributor:
    """Backward spike attribution via eligibility-trace chain.

    Traces the contribution of each input spike to the output
    through intermediate layers using eligibility trace products.
    Approximation of temporal backpropagation attribution.

    Parameters
    ----------
    decay : float
        Temporal decay factor for backward attribution (0-1).
    """

    def __init__(self, decay: float = 0.9):
        self.decay = decay

    def attribute(
        self,
        spikes: np.ndarray,
        weights: list[np.ndarray],
        output_neuron: int = 0,
    ) -> ExplanationResult:
        """Compute per-input-spike attribution scores.

        The per-input relevance is the sum over all paths from the input
        to ``output_neuron`` of the product of absolute weights along the
        path, i.e. ``e_out @ |W_L| @ ... @ |W_1|``. Each spike is then
        scaled by that relevance and by ``decay ** (T - 1 - t)`` so that
        later spikes count more, and the map is normalized to a maximum
        of 1 when any score is positive.

        Parameters
        ----------
        spikes : ndarray of shape (T, N_input)
            Input spike trains.
        weights : list of ndarray
            Weight matrices [W1, W2, ...] where W_i is (n_out, n_in),
            ordered from the input layer to the output layer. An empty
            list yields uniform relevance across inputs.
        output_neuron : int
            Which neuron of the last layer to explain.

        Returns
        -------
        ExplanationResult with importance_map of shape (T, N_input)

        Raises
        ------
        IndexError
            If ``output_neuron`` is out of range for the last layer.
        ValueError
            If consecutive weight matrices do not chain, or the first
            layer's input width differs from ``N_input``.
        """
        T, N_in = spikes.shape

        # Pull the output neuron's relevance back through |W| layer by
        # layer. Selecting the row of the last layer is the same as
        # multiplying a one-hot vector by |W_L|.
        relevance = None
        for w in reversed(weights):
            w_abs = np.abs(np.asarray(w, dtype=np.float64))
            if relevance is None:
                relevance = w_abs[output_neuron]
            else:
                relevance = relevance @ w_abs

        if relevance is None:
            # No layers supplied: every input is treated as equally relevant.
            relevance = np.ones(N_in)
        elif relevance.shape != (N_in,):
            raise ValueError(
                f"weight chain maps to {relevance.shape} inputs, "
                f"but spikes have {N_in} input channels"
            )

        # Temporal attribution: the last timestep has weight 1 and earlier
        # timesteps are discounted by one factor of `decay` per step.
        time_weights = self.decay ** np.arange(T - 1, -1, -1)
        importance = spikes.astype(np.float64) * relevance * time_weights[:, np.newaxis]

        # Normalize to a peak of 1; skip for empty or all-zero maps.
        if importance.size:
            max_val = importance.max()
            if max_val > 0:
                importance /= max_val

        return ExplanationResult(
            method="spike_attribution",
            importance_map=importance,
        )

attribute(spikes, weights, output_neuron=0)

Compute per-input-spike attribution scores.

Parameters

- spikes : ndarray of shape (T, N_input) — Input spike trains.
- weights : list of ndarray — Weight matrices [W1, W2, ...] where W_i is (n_out, n_in).
- output_neuron : int — Which output neuron to explain.

Returns

ExplanationResult with importance_map of shape (T, N_input)

Source code in src/sc_neurocore/explain/spike_explain.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def attribute(
    self,
    spikes: np.ndarray,
    weights: list[np.ndarray],
    output_neuron: int = 0,
) -> ExplanationResult:
    """Compute per-input-spike attribution scores.

    The per-input relevance is the sum over all paths from the input
    to ``output_neuron`` of the product of absolute weights along the
    path, i.e. ``e_out @ |W_L| @ ... @ |W_1|``. Each spike is then
    scaled by that relevance and by ``decay ** (T - 1 - t)`` so that
    later spikes count more, and the map is normalized to a maximum
    of 1 when any score is positive.

    Parameters
    ----------
    spikes : ndarray of shape (T, N_input)
        Input spike trains.
    weights : list of ndarray
        Weight matrices [W1, W2, ...] where W_i is (n_out, n_in),
        ordered from the input layer to the output layer. An empty
        list yields uniform relevance across inputs.
    output_neuron : int
        Which neuron of the last layer to explain.

    Returns
    -------
    ExplanationResult with importance_map of shape (T, N_input)

    Raises
    ------
    IndexError
        If ``output_neuron`` is out of range for the last layer.
    ValueError
        If consecutive weight matrices do not chain, or the first
        layer's input width differs from ``N_input``.
    """
    T, N_in = spikes.shape

    # Pull the output neuron's relevance back through |W| layer by
    # layer. Selecting the row of the last layer is the same as
    # multiplying a one-hot vector by |W_L|.
    relevance = None
    for w in reversed(weights):
        w_abs = np.abs(np.asarray(w, dtype=np.float64))
        if relevance is None:
            relevance = w_abs[output_neuron]
        else:
            relevance = relevance @ w_abs

    if relevance is None:
        # No layers supplied: every input is treated as equally relevant.
        relevance = np.ones(N_in)
    elif relevance.shape != (N_in,):
        raise ValueError(
            f"weight chain maps to {relevance.shape} inputs, "
            f"but spikes have {N_in} input channels"
        )

    # Temporal attribution: the last timestep has weight 1 and earlier
    # timesteps are discounted by one factor of `decay` per step.
    time_weights = self.decay ** np.arange(T - 1, -1, -1)
    importance = spikes.astype(np.float64) * relevance * time_weights[:, np.newaxis]

    # Normalize to a peak of 1; skip for empty or all-zero maps.
    if importance.size:
        max_val = importance.max()
        if max_val > 0:
            importance /= max_val

    return ExplanationResult(
        method="spike_attribution",
        importance_map=importance,
    )

Temporal Saliency

sc_neurocore.explain.spike_explain.TemporalSaliency

Perturbation-based temporal saliency.

For each input spike, measure the change in output when that spike is removed. Spikes whose removal causes large output change are salient (important).

Parameters

- run_fn : callable — Function that takes input spikes (T, N) and returns output spike counts or rates (N_output,).

Source code in src/sc_neurocore/explain/spike_explain.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
class TemporalSaliency:
    """Perturbation-based temporal saliency.

    For each input spike, measure the change in output when that spike
    is removed. Spikes whose removal causes large output change are
    salient (important).

    Parameters
    ----------
    run_fn : callable
        Function that takes input spikes (T, N) and returns output
        spike counts or rates (N_output,). A scalar return value is
        also accepted and is used as-is.
    """

    def __init__(self, run_fn):
        self.run_fn = run_fn

    def explain(
        self,
        spikes: np.ndarray,
        output_neuron: int = 0,
    ) -> ExplanationResult:
        """Compute perturbation-based saliency for each input spike.

        ``run_fn`` is called once on the unmodified input and once per
        non-zero entry of ``spikes`` with that single entry zeroed. The
        saliency of an entry is the absolute change in the selected
        output; entries that are not spikes keep a score of 0. The map
        is normalized to a maximum of 1 when any score is positive.

        Parameters
        ----------
        spikes : ndarray of shape (T, N)
            Input spike trains; entries greater than 0 count as spikes.
        output_neuron : int
            Index into the output of ``run_fn``; ignored when ``run_fn``
            returns a scalar.

        Returns
        -------
        ExplanationResult with importance_map of shape (T, N)
        """

        def readout(inputs: np.ndarray) -> float:
            # np.asarray lets run_fn return a plain float or list as
            # well as an ndarray.
            out = np.asarray(self.run_fn(inputs))
            return float(out[output_neuron]) if out.ndim > 0 else float(out)

        T, N = spikes.shape
        baseline_val = readout(spikes)
        importance = np.zeros((T, N))

        for t, n in np.argwhere(spikes > 0):
            # Fresh copy per perturbation so run_fn never sees an array
            # that is modified again later.
            perturbed = spikes.copy()
            perturbed[t, n] = 0
            importance[t, n] = abs(baseline_val - readout(perturbed))

        # Normalize to a peak of 1; skip for empty or all-zero maps.
        if importance.size:
            max_val = importance.max()
            if max_val > 0:
                importance /= max_val

        return ExplanationResult(
            method="temporal_saliency",
            importance_map=importance,
        )

explain(spikes, output_neuron=0)

Compute perturbation-based saliency for each input spike.

Parameters

- spikes : ndarray of shape (T, N)
- output_neuron : int

Returns

ExplanationResult

Source code in src/sc_neurocore/explain/spike_explain.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def explain(
    self,
    spikes: np.ndarray,
    output_neuron: int = 0,
) -> ExplanationResult:
    """Compute perturbation-based saliency for each input spike.

    ``run_fn`` is called once on the unmodified input and once per
    non-zero entry of ``spikes`` with that single entry zeroed. The
    saliency of an entry is the absolute change in the selected
    output; entries that are not spikes keep a score of 0. The map
    is normalized to a maximum of 1 when any score is positive.

    Parameters
    ----------
    spikes : ndarray of shape (T, N)
        Input spike trains; entries greater than 0 count as spikes.
    output_neuron : int
        Index into the output of ``run_fn``; ignored when ``run_fn``
        returns a scalar.

    Returns
    -------
    ExplanationResult with importance_map of shape (T, N)
    """

    def readout(inputs: np.ndarray) -> float:
        # np.asarray lets run_fn return a plain float or list as
        # well as an ndarray.
        out = np.asarray(self.run_fn(inputs))
        return float(out[output_neuron]) if out.ndim > 0 else float(out)

    T, N = spikes.shape
    baseline_val = readout(spikes)
    importance = np.zeros((T, N))

    for t, n in np.argwhere(spikes > 0):
        # Fresh copy per perturbation so run_fn never sees an array
        # that is modified again later.
        perturbed = spikes.copy()
        perturbed[t, n] = 0
        importance[t, n] = abs(baseline_val - readout(perturbed))

    # Normalize to a peak of 1; skip for empty or all-zero maps.
    if importance.size:
        max_val = importance.max()
        if max_val > 0:
            importance /= max_val

    return ExplanationResult(
        method="temporal_saliency",
        importance_map=importance,
    )

Causal Importance

sc_neurocore.explain.spike_explain.CausalImportance

Causal importance via forward intervention.

Silence each neuron (clamp to zero) across all timesteps and measure the impact on classification output. Builds a per-neuron causal importance score.

Parameters

- run_fn : callable — Function that takes input spikes (T, N) and returns output (N_output,).

Source code in src/sc_neurocore/explain/spike_explain.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
class CausalImportance:
    """Causal importance via forward intervention.

    Silence each neuron (clamp to zero) across all timesteps and
    measure the impact on classification output. Builds a per-neuron
    causal importance score.

    Parameters
    ----------
    run_fn : callable
        Function that takes input spikes (T, N) and returns output
        (N_output,). A scalar return value is also accepted and is
        used as-is.
    """

    def __init__(self, run_fn):
        self.run_fn = run_fn

    def explain(
        self,
        spikes: np.ndarray,
        output_neuron: int = 0,
    ) -> ExplanationResult:
        """Compute causal importance by silencing each neuron.

        ``run_fn`` is called once on the unmodified input and once per
        input neuron with that neuron's whole column zeroed. A neuron's
        score is the absolute change in the selected output, normalized
        so the largest score is 1 when any score is positive.

        Parameters
        ----------
        spikes : ndarray of shape (T, N)
            Input spike trains.
        output_neuron : int
            Index into the output of ``run_fn``; ignored when ``run_fn``
            returns a scalar.

        Returns
        -------
        ExplanationResult with importance_map of shape (1, N)
        """

        def readout(inputs: np.ndarray) -> float:
            # np.asarray lets run_fn return a plain float or list as
            # well as an ndarray.
            out = np.asarray(self.run_fn(inputs))
            return float(out[output_neuron]) if out.ndim > 0 else float(out)

        T, N = spikes.shape
        baseline_val = readout(spikes)
        neuron_importance = np.zeros(N)

        for n in range(N):
            # Fresh copy per intervention so run_fn never sees an array
            # that is modified again later.
            silenced = spikes.copy()
            silenced[:, n] = 0
            neuron_importance[n] = abs(baseline_val - readout(silenced))

        # Normalize to a peak of 1; skip for empty or all-zero scores.
        if neuron_importance.size:
            max_val = neuron_importance.max()
            if max_val > 0:
                neuron_importance /= max_val

        # One row: the score is per neuron, not per timestep.
        importance_map = neuron_importance.reshape(1, N)

        return ExplanationResult(
            method="causal_importance",
            importance_map=importance_map,
        )

explain(spikes, output_neuron=0)

Compute causal importance by silencing each neuron.

Parameters

- spikes : ndarray of shape (T, N)
- output_neuron : int

Returns

ExplanationResult with importance_map of shape (1, N)

Source code in src/sc_neurocore/explain/spike_explain.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def explain(
    self,
    spikes: np.ndarray,
    output_neuron: int = 0,
) -> ExplanationResult:
    """Compute causal importance by silencing each neuron.

    ``run_fn`` is called once on the unmodified input and once per
    input neuron with that neuron's whole column zeroed. A neuron's
    score is the absolute change in the selected output, normalized
    so the largest score is 1 when any score is positive.

    Parameters
    ----------
    spikes : ndarray of shape (T, N)
        Input spike trains.
    output_neuron : int
        Index into the output of ``run_fn``; ignored when ``run_fn``
        returns a scalar.

    Returns
    -------
    ExplanationResult with importance_map of shape (1, N)
    """

    def readout(inputs: np.ndarray) -> float:
        # np.asarray lets run_fn return a plain float or list as
        # well as an ndarray.
        out = np.asarray(self.run_fn(inputs))
        return float(out[output_neuron]) if out.ndim > 0 else float(out)

    T, N = spikes.shape
    baseline_val = readout(spikes)
    neuron_importance = np.zeros(N)

    for n in range(N):
        # Fresh copy per intervention so run_fn never sees an array
        # that is modified again later.
        silenced = spikes.copy()
        silenced[:, n] = 0
        neuron_importance[n] = abs(baseline_val - readout(silenced))

    # Normalize to a peak of 1; skip for empty or all-zero scores.
    if neuron_importance.size:
        max_val = neuron_importance.max()
        if max_val > 0:
            neuron_importance /= max_val

    # One row: the score is per neuron, not per timestep.
    importance_map = neuron_importance.reshape(1, N)

    return ExplanationResult(
        method="causal_importance",
        importance_map=importance_map,
    )

Result

sc_neurocore.explain.spike_explain.ExplanationResult dataclass

Result of an explanation method.

Source code in src/sc_neurocore/explain/spike_explain.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@dataclass
class ExplanationResult:
    """Result of an explanation method."""

    # Identifier of the method that produced this result.
    method: str
    importance_map: np.ndarray  # (T, N) importance scores
    # Optional precomputed (timestep, neuron_id, score) tuples; empty by default.
    top_spikes: list[tuple[int, int, float]] = field(default_factory=list)
    # Optional free-form summary; empty by default.
    summary_text: str = ""

    def top_k(self, k: int = 10) -> list[tuple[int, int, float]]:
        """Return top-k most important (timestep, neuron_id, score) tuples.

        Results are ordered by descending score. Fewer than ``k`` tuples
        are returned when the map has fewer entries, and an empty list
        is returned for ``k <= 0``. ``importance_map`` must be 2-D.
        """
        if k <= 0:
            # A negative k would otherwise slice from the end and return
            # almost every entry.
            return []
        flat = self.importance_map.ravel()
        n_cols = self.importance_map.shape[1]
        order = np.argsort(flat)[::-1][:k]
        results = []
        for idx in order:
            # Row-major flat index -> (row, column) = (timestep, neuron).
            t, n = divmod(int(idx), n_cols)
            results.append((t, n, float(flat[idx])))
        return results

    def summary(self) -> str:
        """Return a multi-line text listing the five highest-scoring entries."""
        lines = [f"Explanation ({self.method}):"]
        lines.extend(
            f"  t={t}, neuron={n}: importance={score:.4f}"
            for t, n, score in self.top_k(5)
        )
        return "\n".join(lines)

top_k(k=10)

Return top-k most important (timestep, neuron_id, score) tuples.

Source code in src/sc_neurocore/explain/spike_explain.py
37
38
39
40
41
42
43
44
45
46
47
def top_k(self, k: int = 10) -> list[tuple[int, int, float]]:
    """Return top-k most important (timestep, neuron_id, score) tuples.

    Results are ordered by descending score. Fewer than ``k`` tuples
    are returned when the map has fewer entries, and an empty list
    is returned for ``k <= 0``. ``importance_map`` must be 2-D.
    """
    if k <= 0:
        # A negative k would otherwise slice from the end and return
        # almost every entry.
        return []
    flat = self.importance_map.ravel()
    n_cols = self.importance_map.shape[1]
    order = np.argsort(flat)[::-1][:k]
    results = []
    for idx in order:
        # Row-major flat index -> (row, column) = (timestep, neuron).
        t, n = divmod(int(idx), n_cols)
        results.append((t, n, float(flat[idx])))
    return results