Tutorial 53: Spike-Level Training Profiler

Diagnose SNN training problems automatically.

The Problem

SNN training often fails silently. Dead neurons, gradient collapse, and saturated layers all produce the same symptom: the loss doesn't decrease. Typical training frameworks report the loss but don't tell you which of these pathologies is responsible.

Quick Start

from sc_neurocore.profiling import SpikeProfiler
import numpy as np

profiler = SpikeProfiler()

# During training, record each layer's spikes and voltages
for epoch in range(100):
    for batch in dataloader:
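        # ... your forward pass, producing spikes_hidden, v_hidden,
        # spikes_output and v_output (dataloader is your own data source) ...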
        profiler.record_step("hidden", spikes_hidden, voltages=v_hidden)
        profiler.record_step("output", spikes_output, voltages=v_output)

# Get diagnostic report
report = profiler.report()
print(report.summary())
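To make the recording step concrete, here is a minimal sketch of what a profiler could accumulate across `record_step` calls. `MiniProfiler` is a hypothetical stand-in, not `sc_neurocore` code, and it assumes each call passes one time step of binary spikes (one entry per neuron); the firing rate is then the fraction of recorded steps in which each neuron spiked.

```python
from collections import defaultdict


class MiniProfiler:
    """Hypothetical, simplified stand-in for SpikeProfiler's recording side.

    Assumes each call passes one time step of 0/1 spikes, one entry per neuron.
    """

    def __init__(self):
        self.spikes = defaultdict(list)    # layer name -> list of spike vectors
        self.voltages = defaultdict(list)  # layer name -> list of voltage vectors

    def record_step(self, layer, spikes, voltages=None):
        self.spikes[layer].append(list(spikes))
        if voltages is not None:
            self.voltages[layer].append(list(voltages))

    def firing_rates(self, layer):
        # Fraction of recorded steps in which each neuron spiked.
        steps = self.spikes[layer]
        n_steps = len(steps)
        return [sum(column) / n_steps for column in zip(*steps)]


profiler = MiniProfiler()
for t in range(4):
    # Neuron 0 always fires, neuron 1 never fires, neuron 2 fires on even steps.
    profiler.record_step("hidden", [1, 0, 1 if t % 2 == 0 else 0])

print(profiler.firing_rates("hidden"))  # [1.0, 0.0, 0.5]
```

Neuron 1 in this toy run is exactly the kind of "dead" neuron the real report is meant to flag.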

What It Detects

| Pathology | Severity | Trigger | Suggested fix |
|---|---|---|---|
| Dead neurons | CRITICAL if >50% of a layer's neurons, WARNING if >10% | Per-neuron firing rate < 0.01 | Lower threshold, add noise |
| Saturated neurons | WARNING if >30% of a layer's neurons | Per-neuron firing rate > 0.95 | Raise threshold, reduce input |
| Silent network | CRITICAL | Maximum firing rate < 0.001 | Check input encoding, lower all thresholds |
| Voltage collapse | WARNING | Membrane voltage std < 1e-6 | Increase input current |
| Gradient explosion | CRITICAL | Max/mean gradient norm > 100x | Clip gradients, reduce learning rate |
| Gradient vanishing | CRITICAL | First/last layer gradient norm ratio > 100x | Add skip connections, use an adaptive surrogate slope |
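The activity-based rules in the table can be sketched as a single function over one layer's firing rates and recorded voltages. This is an illustrative reimplementation using the table's thresholds, not the profiler's actual code; the category strings are made up, and applying the "silent network" check per layer (rather than across the whole network) is a simplifying assumption.

```python
import statistics

DEAD_RATE = 0.01         # per-neuron rate below this counts as dead
SATURATED_RATE = 0.95    # per-neuron rate above this counts as saturated
SILENT_MAX_RATE = 0.001  # silent if even the busiest neuron is below this
VOLTAGE_STD_MIN = 1e-6   # voltage spread below this counts as collapsed


def diagnose_layer(rates, voltages):
    """Return (severity, category) pairs for one layer (illustrative sketch).

    rates: per-neuron firing rates; voltages: flat list of recorded voltages.
    """
    findings = []
    n = len(rates)
    if max(rates) < SILENT_MAX_RATE:
        findings.append(("CRITICAL", "silent_network"))
    dead_frac = sum(r < DEAD_RATE for r in rates) / n
    if dead_frac > 0.5:
        findings.append(("CRITICAL", "dead_neurons"))
    elif dead_frac > 0.1:
        findings.append(("WARNING", "dead_neurons"))
    if sum(r > SATURATED_RATE for r in rates) / n > 0.3:
        findings.append(("WARNING", "saturated_neurons"))
    if voltages and statistics.pstdev(voltages) < VOLTAGE_STD_MIN:
        findings.append(("WARNING", "voltage_collapse"))
    return findings


# 2 of 5 neurons dead (40%) and 2 of 5 saturated (40%).
print(diagnose_layer([0.0, 0.0, 0.2, 0.99, 0.98], [0.1, 0.5, -0.2]))
# [('WARNING', 'dead_neurons'), ('WARNING', 'saturated_neurons')]
```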

Recording Gradients

# Record surrogate gradient magnitudes for gradient health check
profiler.record_step(
    "hidden",
    spikes_hidden,
    voltages=v_hidden,
    gradients=surrogate_grad_hidden,
)
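The two gradient rules can be sketched the same way. The table does not say what the "max/mean norm" is taken over; since a maximum over a handful of layers can never exceed the mean by 100x, this sketch assumes it is over recorded steps within each layer. It also assumes the vanishing check compares mean norms with the last layer's norm exceeding the first's by more than 100x. `check_gradients` and its input layout are hypothetical.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of gradient values."""
    return math.sqrt(sum(v * v for v in values))


def check_gradients(history, limit=100.0):
    """history: {layer: [grad_vector_step0, grad_vector_step1, ...]},
    with layers listed from first (input side) to last (output side)."""
    findings = []
    mean_norms = {}
    for layer, steps in history.items():
        norms = [l2_norm(g) for g in steps]
        mean = sum(norms) / len(norms)
        mean_norms[layer] = mean
        # Explosion: one step's norm dwarfs this layer's typical norm.
        if mean > 0 and max(norms) / mean > limit:
            findings.append(("CRITICAL", "gradient_explosion", layer))
    layers = list(mean_norms)
    first, last = mean_norms[layers[0]], mean_norms[layers[-1]]
    # Vanishing: the input-side layer receives far less gradient than the output side.
    if last > 0 and (first == 0 or last / first > limit):
        findings.append(("CRITICAL", "gradient_vanishing", layers[0]))
    return findings


# One huge step among 199 ordinary ones in "hidden" triggers the explosion rule.
spiky = {"hidden": [[1.0, 0.0]] * 199 + [[1000.0, 0.0]], "output": [[3.0, 4.0]] * 200}
print(check_gradients(spiky))  # [('CRITICAL', 'gradient_explosion', 'hidden')]
```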

Checking for Problems

report = profiler.report()

if report.has_critical:
    print("CRITICAL issues found!")
    for p in report.pathologies:
        print(f"  [{p.severity.value}] {p.category} @ {p.layer}")
        print(f"    {p.message}")
        print(f"    Fix: {p.suggestion}")
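If you want to test report-handling code without running a training job, a stand-in object with the same attributes used above (`severity.value`, `category`, `layer`, `message`, `suggestion`) is enough. The names `Severity` and `FakePathology`, the enum values, and the example messages are all assumptions for illustration; the formatter reproduces the loop above but lists CRITICAL entries first.

```python
from dataclasses import dataclass
from enum import Enum


class Severity(Enum):
    # Assumed values; the snippet above only shows that severity has .value.
    WARNING = "WARNING"
    CRITICAL = "CRITICAL"


@dataclass
class FakePathology:
    severity: Severity
    category: str
    layer: str
    message: str
    suggestion: str


def format_pathologies(pathologies):
    """Render pathologies like the loop above, CRITICAL entries first."""
    order = {Severity.CRITICAL: 0, Severity.WARNING: 1}
    lines = []
    for p in sorted(pathologies, key=lambda p: order[p.severity]):
        lines.append(f"  [{p.severity.value}] {p.category} @ {p.layer}")
        lines.append(f"    {p.message}")
        lines.append(f"    Fix: {p.suggestion}")
    return lines


# Made-up example pathologies, purely for illustration.
found = [
    FakePathology(Severity.WARNING, "saturated_neurons", "output",
                  "40% of neurons saturated", "Raise threshold"),
    FakePathology(Severity.CRITICAL, "dead_neurons", "hidden",
                  "62% of neurons never fire", "Lower threshold"),
]
print("\n".join(format_pathologies(found)))
```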

Per-Layer Statistics

for name, stats in report.layer_stats.items():
    print(f"{name}:")
    print(f"  Neurons: {stats.n_neurons}")
    print(f"  Firing rate: {stats.firing_rates.mean():.3f}")
    print(f"  Dead: {stats.dead_neuron_count} ({stats.dead_neuron_fraction:.0%})")
    print(f"  Voltage: {stats.voltage_mean:.3f} +/- {stats.voltage_std:.3f}")
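The per-layer fields above can be derived directly from a recorded spike raster and voltage trace. The sketch below assumes one list per time step with one entry per neuron, and reuses the 0.01 dead-neuron threshold from the table. `LayerStatsSketch` and `layer_stats` are hypothetical names; unlike the real `firing_rates` (which supports `.mean()`), this version returns a plain list.

```python
import statistics
from dataclasses import dataclass


@dataclass
class LayerStatsSketch:
    n_neurons: int
    firing_rates: list
    dead_neuron_count: int
    dead_neuron_fraction: float
    voltage_mean: float
    voltage_std: float


def layer_stats(spike_steps, voltage_steps, dead_rate=0.01):
    """spike_steps / voltage_steps: one list per time step, one entry per neuron."""
    n_steps = len(spike_steps)
    # Per-neuron firing rate: fraction of steps in which the neuron spiked.
    rates = [sum(col) / n_steps for col in zip(*spike_steps)]
    dead = sum(r < dead_rate for r in rates)
    # Pool every recorded voltage in the layer for mean and spread.
    all_v = [v for step in voltage_steps for v in step]
    return LayerStatsSketch(
        n_neurons=len(rates),
        firing_rates=rates,
        dead_neuron_count=dead,
        dead_neuron_fraction=dead / len(rates),
        voltage_mean=statistics.fmean(all_v),
        voltage_std=statistics.pstdev(all_v),
    )


s = layer_stats([[1, 0, 0, 0], [1, 0, 1, 0]],
                [[0.0, 0.0, 0.5, 0.5], [1.0, 1.0, 1.5, 1.5]])
print(f"Dead: {s.dead_neuron_count} ({s.dead_neuron_fraction:.0%})")  # Dead: 2 (50%)
```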