Skip to content

Spike-Level Training Profiler

Live training diagnostics: dead neurons, gradient pathology, saturated layers, energy bottlenecks. The first automated SNN training profiler in any framework.

Profiler

sc_neurocore.profiling.spike_profiler

Live profiler for SNN training: detects dead neurons, gradient pathology, saturated layers, temporal credit assignment failures, and energy bottlenecks.

No SNN framework provides automated training diagnostics. SNN debugging is manual and expertise-intensive. This profiler instruments the training loop and emits actionable fix suggestions.

Usage

profiler = SpikeProfiler()
profiler.record_step(layer="hidden", spikes=spike_tensor, voltages=v_tensor)
profiler.record_step(layer="hidden", spikes=spike_tensor2, voltages=v_tensor2)
report = profiler.report()
print(report.summary())
for p in report.pathologies:
    print(p.severity, p.message, p.suggestion)

SpikeProfiler

Instruments SNN training to detect pathologies and compute diagnostics.

Record spike tensors, voltage tensors, and optionally gradient tensors per layer per training step. Call report() to get a ProfileReport with detected pathologies and fix suggestions.

Parameters

dead_threshold : float
    Firing rate below which a neuron is considered dead (default 0.01).
saturated_threshold : float
    Firing rate above which a neuron is considered saturated (default 0.95).
gradient_explosion_ratio : float
    Max/mean gradient norm ratio above which gradient explosion is flagged (default 100.0).

Source code in src/sc_neurocore/profiling/spike_profiler.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
class SpikeProfiler:
    """Instruments SNN training to detect pathologies and compute diagnostics.

    Record spike tensors, voltage tensors, and optionally gradient tensors
    per layer per training step. Call report() to get a ProfileReport with
    detected pathologies and fix suggestions.

    Parameters
    ----------
    dead_threshold : float
        Firing rate below which a neuron is considered dead (default 0.01).
    saturated_threshold : float
        Firing rate above which a neuron is considered saturated (default 0.95).
    gradient_explosion_ratio : float
        Max/mean gradient norm ratio above which gradient explosion is flagged.
    """

    def __init__(
        self,
        dead_threshold: float = 0.01,
        saturated_threshold: float = 0.95,
        gradient_explosion_ratio: float = 100.0,
    ):
        # NOTE(review): dead_threshold and saturated_threshold are stored here
        # but never passed to _LayerAccumulator in this class — confirm that
        # compute_stats() honors these configured values rather than module
        # defaults.
        self.dead_threshold = dead_threshold
        self.saturated_threshold = saturated_threshold
        self.gradient_explosion_ratio = gradient_explosion_ratio

        # Per-layer accumulators keyed by layer name. Insertion order (first
        # record wins) is also used as the layer order for the cross-layer
        # gradient-vanishing check in _detect_pathologies.
        self._layers: dict[str, _LayerAccumulator] = {}

    def record_step(
        self,
        layer: str,
        spikes: np.ndarray,
        voltages: np.ndarray | None = None,
        gradients: np.ndarray | None = None,
    ):
        """Record one timestep of data for a layer.

        Parameters
        ----------
        layer : str
            Layer name.
        spikes : ndarray of shape (n_neurons,) or (batch, n_neurons)
            Binary spike tensor for this timestep.
        voltages : ndarray, optional
            Membrane voltages, same shape as spikes.
        gradients : ndarray, optional
            Gradient tensor (surrogate gradient magnitudes).
        """
        if layer not in self._layers:
            self._layers[layer] = _LayerAccumulator(layer)
        self._layers[layer].add(spikes, voltages, gradients)

    def reset(self):
        """Clear all accumulated data."""
        self._layers.clear()

    def report(self) -> ProfileReport:
        """Analyze accumulated data and return a ProfileReport."""
        report = ProfileReport()

        for name, acc in self._layers.items():
            stats = acc.compute_stats()
            report.layer_stats[name] = stats
            # total_steps is the longest-running layer; spikes/neurons sum.
            report.total_steps = max(report.total_steps, stats.n_steps)
            report.total_spikes += stats.total_spikes
            report.total_neurons += stats.n_neurons

        # Detect pathologies
        report.pathologies = self._detect_pathologies(report.layer_stats)
        return report

    def _detect_pathologies(self, layer_stats: dict[str, LayerStats]) -> list[Pathology]:
        """Scan per-layer stats and return all detected pathologies.

        Per-layer checks: dead neurons, saturated neurons, gradient
        explosion, silent network, voltage collapse. Cross-layer check:
        gradient vanishing from first to last layer (by recording order).
        """
        pathologies: list[Pathology] = []

        for name, stats in layer_stats.items():
            # Dead neurons: CRITICAL above 50% dead, WARNING above 10%.
            if stats.dead_neuron_fraction > 0.5:
                pathologies.append(
                    Pathology(
                        severity=Severity.CRITICAL,
                        category="dead_neurons",
                        layer=name,
                        message=f"{stats.dead_neuron_count}/{stats.n_neurons} neurons "
                        f"({stats.dead_neuron_fraction:.0%}) never fire",
                        suggestion="Lower firing threshold by ~20% or increase input current gain",
                        metric_value=stats.dead_neuron_fraction,
                    )
                )
            elif stats.dead_neuron_fraction > 0.1:
                pathologies.append(
                    Pathology(
                        severity=Severity.WARNING,
                        category="dead_neurons",
                        layer=name,
                        message=f"{stats.dead_neuron_count}/{stats.n_neurons} neurons "
                        f"({stats.dead_neuron_fraction:.0%}) never fire",
                        suggestion="Consider lowering threshold or adding noise",
                        metric_value=stats.dead_neuron_fraction,
                    )
                )

            # Saturated neurons: more than 30% firing nearly every step.
            if stats.saturated_neuron_fraction > 0.3:
                pathologies.append(
                    Pathology(
                        severity=Severity.WARNING,
                        category="saturated_neurons",
                        layer=name,
                        message=f"{stats.saturated_neuron_count}/{stats.n_neurons} neurons "
                        f"({stats.saturated_neuron_fraction:.0%}) fire almost every step",
                        suggestion="Raise threshold or reduce input gain to restore sparse coding",
                        metric_value=stats.saturated_neuron_fraction,
                    )
                )

            # Gradient explosion: max/mean norm ratio beyond the configured
            # threshold. The max(..., 1e-12) guards against division by zero.
            if stats.gradient_norm_mean > 0 and stats.gradient_norm_max > 0:
                ratio = stats.gradient_norm_max / max(stats.gradient_norm_mean, 1e-12)
                if ratio > self.gradient_explosion_ratio:
                    pathologies.append(
                        Pathology(
                            severity=Severity.CRITICAL,
                            category="gradient_explosion",
                            layer=name,
                            message=f"Gradient max/mean ratio = {ratio:.1f}x "
                            f"(threshold: {self.gradient_explosion_ratio}x)",
                            suggestion="Clip gradients, reduce learning rate, or add surrogate gradient damping",
                            metric_value=ratio,
                        )
                    )

            # Silent network (zero spikes across all neurons)
            if stats.firing_rates is not None and stats.firing_rates.max() < 0.001:
                pathologies.append(
                    Pathology(
                        severity=Severity.CRITICAL,
                        category="silent_network",
                        layer=name,
                        message="Layer produces almost no spikes (max rate < 0.001)",
                        suggestion="Check input encoding, lower all thresholds, or verify input data is non-zero",
                        metric_value=float(stats.firing_rates.max()),
                    )
                )

            # Voltage collapse (all voltages near rest). The n_steps > 10
            # guard avoids flagging runs too short to be meaningful.
            if stats.voltage_std < 1e-6 and stats.n_steps > 10:
                pathologies.append(
                    Pathology(
                        severity=Severity.WARNING,
                        category="voltage_collapse",
                        layer=name,
                        message=f"Voltage std = {stats.voltage_std:.2e} — neurons not integrating input",
                        suggestion="Increase input current or check connectivity",
                        metric_value=stats.voltage_std,
                    )
                )

        # Cross-layer: gradient vanishing — compare mean gradient norm of the
        # first recorded layer against the last (dict insertion order).
        if len(layer_stats) >= 2:
            grad_norms = [
                (name, s.gradient_norm_mean)
                for name, s in layer_stats.items()
                if s.gradient_norm_mean > 0
            ]
            if len(grad_norms) >= 2:
                first_norm = grad_norms[0][1]
                last_norm = grad_norms[-1][1]
                if first_norm > 0 and last_norm / max(first_norm, 1e-12) < 0.01:
                    pathologies.append(
                        Pathology(
                            severity=Severity.CRITICAL,
                            category="gradient_vanishing",
                            # Fix: separate the two layer names explicitly —
                            # previously they were concatenated with no
                            # separator (e.g. "hiddenoutput").
                            layer=f"{grad_norms[0][0]} -> {grad_norms[-1][0]}",
                            message=f"Gradient decays {first_norm / max(last_norm, 1e-12):.0f}x "
                            f"from first to last layer",
                            suggestion="Add skip connections, use adaptive surrogate gradient slope, "
                            "or reduce network depth",
                            metric_value=last_norm / max(first_norm, 1e-12),
                        )
                    )

        return pathologies

record_step(layer, spikes, voltages=None, gradients=None)

Record one timestep of data for a layer.

Parameters

layer : str
    Layer name.
spikes : ndarray of shape (n_neurons,) or (batch, n_neurons)
    Binary spike tensor for this timestep.
voltages : ndarray, optional
    Membrane voltages, same shape as spikes.
gradients : ndarray, optional
    Gradient tensor (surrogate gradient magnitudes).

Source code in src/sc_neurocore/profiling/spike_profiler.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def record_step(
    self,
    layer: str,
    spikes: np.ndarray,
    voltages: np.ndarray | None = None,
    gradients: np.ndarray | None = None,
):
    """Record one timestep of data for a layer.

    Parameters
    ----------
    layer : str
        Layer name.
    spikes : ndarray of shape (n_neurons,) or (batch, n_neurons)
        Binary spike tensor for this timestep.
    voltages : ndarray, optional
        Membrane voltages, same shape as spikes.
    gradients : ndarray, optional
        Gradient tensor (surrogate gradient magnitudes).
    """
    if layer not in self._layers:
        self._layers[layer] = _LayerAccumulator(layer)
    self._layers[layer].add(spikes, voltages, gradients)

reset()

Clear all accumulated data.

Source code in src/sc_neurocore/profiling/spike_profiler.py
183
184
185
def reset(self):
    """Discard every per-layer accumulator, emptying the profiler in place.

    The existing dict object is cleared (not rebound), so any external
    references to it observe the reset too.
    """
    self._layers.clear()

report()

Analyze accumulated data and return a ProfileReport.

Source code in src/sc_neurocore/profiling/spike_profiler.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def report(self) -> ProfileReport:
    """Build and return a ProfileReport from all recorded data.

    Aggregates per-layer statistics (total_steps is the longest-running
    layer; spikes and neuron counts are summed across layers), then runs
    pathology detection over the collected stats.
    """
    result = ProfileReport()

    for layer_name, accumulator in self._layers.items():
        layer_stats = accumulator.compute_stats()
        result.layer_stats[layer_name] = layer_stats
        if layer_stats.n_steps > result.total_steps:
            result.total_steps = layer_stats.n_steps
        result.total_spikes += layer_stats.total_spikes
        result.total_neurons += layer_stats.n_neurons

    result.pathologies = self._detect_pathologies(result.layer_stats)
    return result

ProfileReport dataclass

Complete profiling report with per-layer stats and detected pathologies.

Source code in src/sc_neurocore/profiling/spike_profiler.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@dataclass
class ProfileReport:
    """Complete profiling report with per-layer stats and detected pathologies."""

    layer_stats: dict[str, LayerStats] = field(default_factory=dict)
    pathologies: list[Pathology] = field(default_factory=list)
    total_steps: int = 0
    total_spikes: int = 0
    total_neurons: int = 0

    def summary(self) -> str:
        """Render the report as a human-readable multi-line string."""
        out = [
            f"SpikeProfiler Report: {self.total_steps} steps, "
            f"{self.total_neurons} neurons, {self.total_spikes} total spikes",
            "",
        ]
        for layer_name, stats in self.layer_stats.items():
            rates = stats.firing_rates
            avg_rate = 0.0 if rates is None else float(rates.mean())
            out.append(
                f"  {layer_name}: {stats.n_neurons}n, rate={avg_rate:.3f}, "
                f"dead={stats.dead_neuron_count}, sat={stats.saturated_neuron_count}, "
                f"V={stats.voltage_mean:.3f}+/-{stats.voltage_std:.3f}"
            )

        out.append("")
        if self.pathologies:
            out.append(f"Pathologies detected: {len(self.pathologies)}")
            for pathology in self.pathologies:
                out.append(
                    f"  [{pathology.severity.value}] {pathology.category} "
                    f"@ {pathology.layer}: {pathology.message}"
                )
                out.append(f"    Fix: {pathology.suggestion}")
        else:
            out.append("No pathologies detected.")

        return "\n".join(out)

    @property
    def has_critical(self) -> bool:
        """True when at least one detected pathology has CRITICAL severity."""
        for pathology in self.pathologies:
            if pathology.severity == Severity.CRITICAL:
                return True
        return False

LayerStats dataclass

Accumulated statistics for one layer across recorded steps.

Source code in src/sc_neurocore/profiling/spike_profiler.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@dataclass
class LayerStats:
    """Accumulated statistics for one layer across recorded steps."""

    name: str  # layer name, as passed to SpikeProfiler.record_step
    n_neurons: int = 0  # number of neurons observed in this layer
    n_steps: int = 0  # number of timesteps recorded for this layer

    # Spike statistics
    total_spikes: int = 0  # total spikes across all neurons and steps
    per_neuron_spikes: np.ndarray | None = None  # spike count per neuron
    firing_rates: np.ndarray | None = None  # per-neuron firing rate (spikes/step)

    # Voltage statistics
    voltage_mean: float = 0.0
    voltage_std: float = 0.0
    voltage_min: float = 0.0
    voltage_max: float = 0.0

    # Gradient statistics (if recorded)
    gradient_norm_mean: float = 0.0
    gradient_norm_max: float = 0.0

    # ISI statistics
    mean_isi: float = 0.0  # mean inter-spike interval
    cv_isi: float = 0.0  # coefficient of variation of the ISI

    # Derived
    dead_neuron_count: int = 0  # neurons firing below the dead threshold
    saturated_neuron_count: int = 0  # neurons firing above the saturated threshold
    dead_neuron_fraction: float = 0.0  # dead_neuron_count / n_neurons
    saturated_neuron_fraction: float = 0.0  # saturated_neuron_count / n_neurons

    # Energy estimate (synaptic operations)
    estimated_syn_ops: int = 0

Pathology dataclass

One detected training pathology.

Source code in src/sc_neurocore/profiling/spike_profiler.py
39
40
41
42
43
44
45
46
47
48
@dataclass
class Pathology:
    """One detected training pathology."""

    severity: Severity  # INFO, WARNING, or CRITICAL
    category: str  # machine-readable kind, e.g. "dead_neurons"
    layer: str  # layer (or layer range) the pathology was detected in
    message: str  # human-readable description of the problem
    suggestion: str  # actionable fix suggestion
    metric_value: float = 0.0  # the metric that triggered detection

Severity

Bases: Enum

Source code in src/sc_neurocore/profiling/spike_profiler.py
33
34
35
36
class Severity(Enum):
    """Severity level of a detected pathology, in increasing order of urgency."""

    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"