Skip to content

Tutorial 76: Knowledge Distillation

Transfer knowledge from a large/slow teacher SNN to a small/fast student SNN. Temporal-aware distillation matches per-timestep output distributions. Self-distillation uses extended timesteps as an implicit teacher.

Why Distillation for SNNs

SNN accuracy scales with timesteps T: more timesteps = more spikes = better integration = higher accuracy. But hardware deployment needs small T for latency. Distillation transfers the accuracy of T=32 into a T=4 model.

Temporal Distillation Loss

import numpy as np
from sc_neurocore.distillation import TemporalDistillationLoss

# temperature softens both distributions before matching; alpha balances
# distillation vs hard-label loss; entropy_weight rewards a less
# over-confident student distribution.
loss_fn = TemporalDistillationLoss(
    temperature=3.0,
    alpha=0.5,
    entropy_weight=0.1,
)

teacher_logits = np.random.randn(32, 10)  # T=32, 10 classes
student_logits = np.random.randn(4, 10)   # T=4, 10 classes
targets = np.zeros(10); targets[3] = 1.0  # one-hot ground truth (class 3)

# Teacher and student may use different T; compute() accepts both shapes.
result = loss_fn.compute(student_logits, teacher_logits, targets)
print(f"Total: {result['total_loss']:.3f}, "
      f"Distill: {result['distill_loss']:.3f}, "
      f"Task: {result['task_loss']:.3f}")

Self-Distillation

No separate teacher model needed. Run the same model at more timesteps to generate soft targets:

from sc_neurocore.distillation import SelfDistiller

# Teacher = the same network run at T=32; student trains at T=8.
distiller = SelfDistiller(T_teacher=32, T_student=8, temperature=3.0)

def run_model(inputs, T):
    # Stand-in for a real forward pass: callable(inputs, T) -> logits.
    return np.random.randn(10)  # your SNN forward pass

# Soft targets come from a single T_teacher forward pass,
# temperature-softened into a probability vector.
soft_targets = distiller.generate_targets(run_model, inputs=np.zeros(784))

API Reference

sc_neurocore.distillation.distill

Knowledge distillation for SNNs with temporal spike alignment.

No SNN library ships distillation utilities. Every paper has its own script.

Reference: CVPR 2025 — temporal separation + entropy regularization

TemporalDistillationLoss

Temporal-aware distillation loss for SNN→SNN or ANN→SNN transfer.

Matches per-timestep output distributions from teacher to student, with entropy regularization to prevent learning erroneous knowledge.

Parameters

temperature : float — softmax temperature for logit matching. alpha : float — weight of distillation loss vs task loss (0 = task only, 1 = distill only). entropy_weight : float — entropy regularization strength.

Source code in src/sc_neurocore/distillation/distill.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class TemporalDistillationLoss:
    """Temporal-aware distillation loss for SNN→SNN or ANN→SNN transfer.

    Matches per-timestep output distributions from teacher to student,
    with entropy regularization to prevent learning erroneous knowledge.

    Parameters
    ----------
    temperature : float
        Softmax temperature for logit matching.
    alpha : float
        Weight of distillation loss vs task loss (0=task only, 1=distill only).
    entropy_weight : float
        Entropy regularization strength.
    """

    def __init__(self, temperature: float = 3.0, alpha: float = 0.5, entropy_weight: float = 0.1):
        self.temperature = temperature
        self.alpha = alpha
        self.entropy_weight = entropy_weight

    def compute(
        self,
        student_logits: np.ndarray,
        teacher_logits: np.ndarray,
        targets: np.ndarray | None = None,
    ) -> dict:
        """Compute distillation loss.

        Parameters
        ----------
        student_logits : ndarray of shape (T, N_classes) or (N_classes,)
            Student output per timestep.
        teacher_logits : ndarray of shape (T, N_classes) or (N_classes,)
        targets : ndarray of shape (N_classes,), optional
            Ground truth for task loss.

        Returns
        -------
        dict with 'total_loss', 'distill_loss', 'task_loss', 'entropy_loss'
        """
        # Soften logits
        s_soft = self._softmax(student_logits / self.temperature)
        t_soft = self._softmax(teacher_logits / self.temperature)

        # KL divergence: sum(t * log(t/s))
        kl = np.sum(t_soft * np.log(np.clip(t_soft / np.clip(s_soft, 1e-10, None), 1e-10, None)))
        distill_loss = float(kl * self.temperature**2)

        # Entropy regularization
        entropy = -float(np.sum(s_soft * np.log(np.clip(s_soft, 1e-10, None))))
        entropy_loss = -self.entropy_weight * entropy

        # Task loss (cross-entropy with targets)
        task_loss = 0.0
        if targets is not None:
            s_logits = student_logits if student_logits.ndim == 1 else student_logits.mean(axis=0)
            s_prob = self._softmax(s_logits)
            task_loss = -float(np.sum(targets * np.log(np.clip(s_prob, 1e-10, None))))

        total = self.alpha * distill_loss + (1 - self.alpha) * task_loss + entropy_loss

        return {
            "total_loss": total,
            "distill_loss": distill_loss,
            "task_loss": task_loss,
            "entropy_loss": entropy_loss,
        }

    @staticmethod
    def _softmax(x: np.ndarray) -> np.ndarray:
        if x.ndim > 1:
            x = x.mean(axis=0)
        e = np.exp(x - x.max())
        return e / e.sum()

compute(student_logits, teacher_logits, targets=None)

Compute distillation loss.

Parameters

student_logits : ndarray of shape (T, N_classes) or (N_classes,) Student output per timestep. teacher_logits : ndarray of shape (T, N_classes) or (N_classes,) targets : ndarray of shape (N_classes,), optional Ground truth for task loss.

Returns

dict with 'total_loss', 'distill_loss', 'task_loss', 'entropy_loss'

Source code in src/sc_neurocore/distillation/distill.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def compute(
    self,
    student_logits: np.ndarray,
    teacher_logits: np.ndarray,
    targets: np.ndarray | None = None,
) -> dict:
    """Compute distillation loss.

    Parameters
    ----------
    student_logits : ndarray of shape (T, N_classes) or (N_classes,)
        Student output per timestep.
    teacher_logits : ndarray of shape (T, N_classes) or (N_classes,)
    targets : ndarray of shape (N_classes,), optional
        Ground truth for task loss.

    Returns
    -------
    dict with 'total_loss', 'distill_loss', 'task_loss', 'entropy_loss'
    """
    # Soften logits at temperature T; note _softmax averages (T, N_classes)
    # inputs over the time axis, so T may differ between teacher and student.
    s_soft = self._softmax(student_logits / self.temperature)
    t_soft = self._softmax(teacher_logits / self.temperature)

    # KL divergence KL(teacher || student): sum(t * log(t/s)).
    # Clipping guards log(0); the temperature**2 factor rescales the term.
    kl = np.sum(t_soft * np.log(np.clip(t_soft / np.clip(s_soft, 1e-10, None), 1e-10, None)))
    distill_loss = float(kl * self.temperature**2)

    # Entropy regularization: entropy_loss is negated, so minimizing the
    # total rewards a higher-entropy (less over-confident) student.
    entropy = -float(np.sum(s_soft * np.log(np.clip(s_soft, 1e-10, None))))
    entropy_loss = -self.entropy_weight * entropy

    # Task loss (cross-entropy with targets) on the time-averaged student
    # logits, computed at temperature 1.
    task_loss = 0.0
    if targets is not None:
        s_logits = student_logits if student_logits.ndim == 1 else student_logits.mean(axis=0)
        s_prob = self._softmax(s_logits)
        task_loss = -float(np.sum(targets * np.log(np.clip(s_prob, 1e-10, None))))

    # alpha blends distillation vs task; the entropy term is added un-blended.
    total = self.alpha * distill_loss + (1 - self.alpha) * task_loss + entropy_loss

    return {
        "total_loss": total,
        "distill_loss": distill_loss,
        "task_loss": task_loss,
        "entropy_loss": entropy_loss,
    }

SelfDistiller dataclass

Self-distillation: use extended-T model as implicit teacher.

Run the same model at T_teacher timesteps (more accurate) to generate soft targets for training at T_student timesteps (faster).

Parameters

T_teacher : int — timesteps for teacher forward pass. T_student : int — timesteps for student forward pass. temperature : float — softmax temperature applied to teacher logits when forming soft targets.

Source code in src/sc_neurocore/distillation/distill.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
@dataclass
class SelfDistiller:
    """Self-distillation: use extended-T model as implicit teacher.

    Run the same model at T_teacher timesteps (more accurate) to
    generate soft targets for training at T_student timesteps (faster).

    Parameters
    ----------
    T_teacher : int
        Timesteps for teacher forward pass.
    T_student : int
        Timesteps for student forward pass.
    temperature : float
        Softmax temperature applied to teacher logits when forming
        soft targets.
    """

    T_teacher: int = 32
    T_student: int = 8
    temperature: float = 3.0

    def generate_targets(self, run_fn, inputs: np.ndarray) -> np.ndarray:
        """Run model at T_teacher steps to generate soft targets.

        Parameters
        ----------
        run_fn : callable(inputs, T) -> logits
            Forward pass; may return (N_classes,) or per-timestep
            (T, N_classes) logits.
        inputs : ndarray

        Returns
        -------
        ndarray — soft targets (probability vector over classes) from teacher
        """
        teacher_logits = run_fn(inputs, self.T_teacher)
        return self._softmax(teacher_logits / self.temperature)

    @staticmethod
    def _softmax(x: np.ndarray) -> np.ndarray:
        # Average per-timestep logits over the time axis first, consistent
        # with TemporalDistillationLoss._softmax. Without this, a
        # (T, N_classes) teacher output would be normalized over the
        # flattened array and the result would not be a class distribution.
        if x.ndim > 1:
            x = x.mean(axis=0)
        # Numerically stable softmax (shift by the max before exponentiating).
        e = np.exp(x - x.max())
        return e / e.sum()

generate_targets(run_fn, inputs)

Run model at T_teacher steps to generate soft targets.

Parameters

run_fn : callable(inputs, T) -> logits inputs : ndarray

Returns

ndarray — soft targets from teacher

Source code in src/sc_neurocore/distillation/distill.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def generate_targets(self, run_fn, inputs: np.ndarray) -> np.ndarray:
    """Run model at T_teacher steps to generate soft targets.

    Parameters
    ----------
    run_fn : callable(inputs, T) -> logits
    inputs : ndarray

    Returns
    -------
    ndarray — soft targets from teacher
    """
    # One forward pass at the extended timestep count; the result is
    # temperature-softened into soft targets via _softmax.
    teacher_logits = run_fn(inputs, self.T_teacher)
    return self._softmax(teacher_logits / self.temperature)