
Knowledge Distillation — Temporal Spike Transfer

SNN-to-SNN and ANN-to-SNN knowledge transfer with temporal spike alignment. No SNN library ships distillation as a reusable module.

TemporalDistillationLoss

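Matches the student's temperature-softened output distribution to the teacher's, with entropy regularization to prevent learning erroneous knowledge. Logits may be passed per timestep, with shape (T, N_classes); the implementation averages them over the time axis before the softmax, so the KL term compares time-averaged distributions.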

L_total = α * L_distill + (1-α) * L_task + L_entropy

Where:

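  • L_distill = temperature² * KL(teacher_soft || student_soft): temperature-softened KL divergence (temperature is the softmax temperature, not the number of timesteps)
  • L_task = CrossEntropy(student, targets): standard task loss on the unsoftened student logits; optional, and 0 when targets is omitted (L_distill is still scaled by α)
  • L_entropy = -β * H(student_soft), with β = entropy_weight: entropy regularization; minimizing it raises the student's entropy and so discourages overconfident collapse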

| Parameter | Default | Meaning |
|---|---|---|
| temperature | 3.0 | Softmax temperature for logit matching |
| alpha | 0.5 | Balance: 0 = task only, 1 = distill only |
| entropy_weight | 0.1 | Entropy regularization strength (β) |

Returns dict: {'total_loss', 'distill_loss', 'task_loss', 'entropy_loss'}.
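To make the formula concrete, here is a minimal standalone NumPy sketch of L_total for 1-D logit vectors. It follows the equations above rather than reproducing the library's source, and the helper name distillation_total is hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D logit vector."""
    e = np.exp(x - x.max())
    return e / e.sum()


def distillation_total(student, teacher, targets=None,
                       temperature=3.0, alpha=0.5, entropy_weight=0.1):
    """Hypothetical helper: L_total = alpha*L_distill + (1-alpha)*L_task + L_entropy."""
    eps = 1e-10
    s_soft = softmax(student / temperature)
    t_soft = softmax(teacher / temperature)

    # L_distill: temperature^2 * KL(teacher_soft || student_soft)
    kl = np.sum(t_soft * np.log(np.clip(t_soft, eps, None) / np.clip(s_soft, eps, None)))
    distill = float(temperature**2 * kl)

    # L_entropy: -beta * H(student_soft), with beta = entropy_weight
    entropy = float(-np.sum(s_soft * np.log(np.clip(s_soft, eps, None))))
    entropy_loss = -entropy_weight * entropy

    # L_task: cross-entropy on the unsoftened student logits (0 if no targets)
    task = 0.0
    if targets is not None:
        task = float(-np.sum(targets * np.log(np.clip(softmax(student), eps, None))))

    total = alpha * distill + (1 - alpha) * task + entropy_loss
    return {"total_loss": total, "distill_loss": distill,
            "task_loss": task, "entropy_loss": entropy_loss}


logits = np.array([2.0, 0.5, -1.0])
same = distillation_total(logits, logits, targets=np.array([1.0, 0.0, 0.0]))
diff = distillation_total(logits, -logits)
print(same)
print(diff)
```

With identical student and teacher logits the KL term vanishes, so total_loss reduces to (1 - alpha) * task_loss + entropy_loss; without targets, task_loss is simply 0.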

SelfDistiller — Implicit Teacher

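Uses the same model, run for more timesteps, as an implicit teacher: the model is run for T_teacher steps (more accurate, slower) to generate soft targets for training at T_student steps (faster, less accurate). generate_targets performs only the teacher pass; T_student is stored on the object for the caller's own student pass.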

| Parameter | Default | Meaning |
|---|---|---|
| T_teacher | 32 | Timesteps for teacher pass |
| T_student | 8 | Timesteps for student pass |
| temperature | 3.0 | Softmax temperature |
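The teacher/student flow can be sketched end to end with a toy model. Everything here is an assumption for illustration: run_fn is a hypothetical rate-coded readout (Bernoulli spikes with rate sigmoid(inputs), returned as time-averaged rates scaled into a logit range), and the loss is written out by hand instead of calling the library.

```python
import numpy as np

T_TEACHER, T_STUDENT, TEMPERATURE = 32, 8, 3.0


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()


def run_fn(inputs: np.ndarray, T: int, seed: int = 0) -> np.ndarray:
    """Hypothetical SNN forward pass: Bernoulli spikes with rate sigmoid(inputs),
    simulated for T steps and returned as time-averaged rates, shape (N_classes,)."""
    rng = np.random.default_rng(seed)
    rates = 1.0 / (1.0 + np.exp(-inputs))
    spikes = rng.random((T, inputs.size)) < rates   # shape (T, N_classes)
    return spikes.mean(axis=0) * 10.0               # arbitrary scale into a logit range


x = np.array([1.5, 0.0, -1.5])

# Teacher pass: same model, more timesteps -> soft targets
soft_targets = softmax(run_fn(x, T_TEACHER) / TEMPERATURE)

# Student pass: same model, fewer timesteps
student_soft = softmax(run_fn(x, T_STUDENT, seed=1) / TEMPERATURE)

# Distillation signal: temperature^2-scaled KL(teacher || student)
kl = np.sum(soft_targets * np.log(soft_targets / student_soft))
distill_loss = TEMPERATURE**2 * kl
print(soft_targets, distill_loss)
```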

Usage

Python
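import numpy as np
from sc_neurocore.distillation import TemporalDistillationLoss, SelfDistiller

rng = np.random.default_rng(0)
T, n_classes = 8, 10
student_output = rng.normal(size=(T, n_classes))  # placeholder logits, shape (T, N_classes)
teacher_output = rng.normal(size=(T, n_classes))  # placeholder teacher logits
one_hot_labels = np.eye(n_classes)[3]             # ground truth: class 3

# Teacher-student distillation
loss_fn = TemporalDistillationLoss(temperature=3.0, alpha=0.7)
result = loss_fn.compute(
    student_logits=student_output,
    teacher_logits=teacher_output,
    targets=one_hot_labels,  # optional
)
print(f"Total: {result['total_loss']:.4f}, Distill: {result['distill_loss']:.4f}")

# Self-distillation: generate_targets calls run_fn(inputs, T) and expects
# logits of shape (N_classes,). Toy stand-in for a real SNN forward pass:
W = rng.normal(size=(4, n_classes))
def run_fn(inputs, T):
    return inputs @ W  # a real SNN would simulate T timesteps here

x = rng.normal(size=4)
sd = SelfDistiller(T_teacher=32, T_student=8)
soft_targets = sd.generate_targets(run_fn=run_fn, inputs=x)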

Reference: CVPR 2025 — temporal separation + entropy regularization.

See Tutorial 76: Knowledge Distillation.

sc_neurocore.distillation.distill

Knowledge distillation for SNNs with temporal spike alignment.

No SNN library ships distillation utilities. Every paper has its own script.

Reference: CVPR 2025 — temporal separation + entropy regularization

TemporalDistillationLoss

Temporal-aware distillation loss for SNN→SNN or ANN→SNN transfer.

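Matches the student's temperature-softened output distribution to the teacher's, with entropy regularization to prevent learning erroneous knowledge. Per-timestep logits of shape (T, N_classes) are averaged over the time axis before the softmax, so the KL term compares time-averaged distributions.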

Parameters

  • temperature (float): Softmax temperature for logit matching.
  • alpha (float): Weight of distillation loss vs. task loss (0 = task only, 1 = distill only).
  • entropy_weight (float): Entropy regularization strength.
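The effect of temperature can be seen with plain NumPy: dividing the logits by a larger temperature flattens the softmax (higher entropy) while keeping the class ranking, which is what exposes the teacher's relative scores for non-target classes. The helper names below are illustrative, not part of the library.

```python
import numpy as np


def softened(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Softmax of logits / temperature (numerically stable)."""
    z = logits / temperature
    e = np.exp(z - z.max())
    return e / e.sum()


def entropy(p: np.ndarray) -> float:
    """Shannon entropy in nats."""
    return float(-np.sum(p * np.log(p)))


logits = np.array([4.0, 1.0, 0.0])
for t in (1.0, 3.0, 10.0):
    p = softened(logits, t)
    print(f"temperature={t:>4}: p={np.round(p, 3)}, H={entropy(p):.3f}")
```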

Source code in src/sc_neurocore/distillation/distill.py
Python
from typing import Any

import numpy as np


class TemporalDistillationLoss:
    """Temporal-aware distillation loss for SNN→SNN or ANN→SNN transfer.

    Matches per-timestep output distributions from teacher to student,
    with entropy regularization to prevent learning erroneous knowledge.

    Parameters
    ----------
    temperature : float
        Softmax temperature for logit matching.
    alpha : float
        Weight of distillation loss vs task loss (0=task only, 1=distill only).
    entropy_weight : float
        Entropy regularization strength.
    """

    def __init__(self, temperature: float = 3.0, alpha: float = 0.5, entropy_weight: float = 0.1):
        self.temperature = temperature
        self.alpha = alpha
        self.entropy_weight = entropy_weight

    def compute(
        self,
        student_logits: np.ndarray[Any, Any],
        teacher_logits: np.ndarray[Any, Any],
        targets: np.ndarray[Any, Any] | None = None,
    ) -> dict[str, float]:
        """Compute distillation loss.

        Parameters
        ----------
        student_logits : ndarray of shape (T, N_classes) or (N_classes,)
            Student output per timestep.
        teacher_logits : ndarray of shape (T, N_classes) or (N_classes,)
        targets : ndarray of shape (N_classes,), optional
            Ground truth for task loss.

        Returns
        -------
        dict with 'total_loss', 'distill_loss', 'task_loss', 'entropy_loss'
        """
        # Soften logits
        s_soft = self._softmax(student_logits / self.temperature)
        t_soft = self._softmax(teacher_logits / self.temperature)

        # KL divergence: sum(t * log(t/s))
        kl = np.sum(t_soft * np.log(np.clip(t_soft / np.clip(s_soft, 1e-10, None), 1e-10, None)))
        distill_loss = float(kl * self.temperature**2)

        # Entropy regularization
        entropy = -float(np.sum(s_soft * np.log(np.clip(s_soft, 1e-10, None))))
        entropy_loss = -self.entropy_weight * entropy

        # Task loss (cross-entropy with targets)
        task_loss = 0.0
        if targets is not None:
            s_logits = student_logits if student_logits.ndim == 1 else student_logits.mean(axis=0)
            s_prob = self._softmax(s_logits)
            task_loss = -float(np.sum(targets * np.log(np.clip(s_prob, 1e-10, None))))

        total = self.alpha * distill_loss + (1 - self.alpha) * task_loss + entropy_loss

        return {
            "total_loss": total,
            "distill_loss": distill_loss,
            "task_loss": task_loss,
            "entropy_loss": entropy_loss,
        }

    @staticmethod
    def _softmax(x: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
        if x.ndim > 1:
            x = x.mean(axis=0)
        e = np.exp(x - x.max())
        probs: np.ndarray[Any, Any] = e / e.sum()
        return probs

compute(student_logits, teacher_logits, targets=None)

Compute distillation loss.

Parameters

  • student_logits (ndarray of shape (T, N_classes) or (N_classes,)): Student output per timestep.
  • teacher_logits (ndarray of shape (T, N_classes) or (N_classes,)): Teacher output per timestep.
  • targets (ndarray of shape (N_classes,), optional): Ground-truth label distribution (e.g. one-hot) for the task loss; if omitted, task_loss is 0.

Returns

dict with 'total_loss', 'distill_loss', 'task_loss', 'entropy_loss'
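The shape handling has a consequence worth checking. The sketch below restates the _softmax logic from the TemporalDistillationLoss listing as a standalone function (softmax_time_avg is a hypothetical name) and shows that a (T, N_classes) input yields the same distribution as its time average, and that shuffling the timesteps changes nothing. In other words, the loss as written compares time-averaged outputs, not the temporal order of the outputs.

```python
import numpy as np


def softmax_time_avg(x: np.ndarray) -> np.ndarray:
    """Same logic as TemporalDistillationLoss._softmax in the listing:
    a (T, N_classes) input is averaged over time before the softmax."""
    if x.ndim > 1:
        x = x.mean(axis=0)
    e = np.exp(x - x.max())
    return e / e.sum()


rng = np.random.default_rng(0)
seq = rng.normal(size=(8, 5))          # per-timestep logits, shape (T, N_classes)
shuffled = seq[rng.permutation(8)]     # same timesteps in a different order

p_seq = softmax_time_avg(seq)
p_avg = softmax_time_avg(seq.mean(axis=0))
p_shuf = softmax_time_avg(shuffled)
print(np.allclose(p_seq, p_avg), np.allclose(p_seq, p_shuf))
```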


SelfDistiller dataclass

Self-distillation: use extended-T model as implicit teacher.

Run the same model at T_teacher timesteps (more accurate) to generate soft targets for training at T_student timesteps (faster).
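Why more timesteps should be more accurate can be motivated under a rate-coding assumption that the source does not state: if an output neuron fires as Bernoulli(p) at each step and its rate is read out as spike count / T, the estimate's standard deviation is sqrt(p(1 - p) / T), so going from T = 8 to T = 32 roughly halves the error. The sketch below checks this empirically; p and the trial count are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(42)
p = 0.3           # assumed true per-step firing probability
trials = 5000     # Monte Carlo repetitions


def mean_abs_error(T: int) -> float:
    """Mean |rate estimate - p| when the rate is read out as spike count / T."""
    counts = rng.binomial(T, p, size=trials)   # spikes in T Bernoulli(p) steps
    return float(np.mean(np.abs(counts / T - p)))


err_student = mean_abs_error(8)    # T_student default
err_teacher = mean_abs_error(32)   # T_teacher default
print(f"T=8: {err_student:.3f}   T=32: {err_teacher:.3f}")
```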

Parameters

  • T_teacher (int): Timesteps for the teacher forward pass.
  • T_student (int): Timesteps for the student forward pass.
  • temperature (float): Softmax temperature applied to the teacher logits.

Source code in src/sc_neurocore/distillation/distill.py
Python
from dataclasses import dataclass
from typing import Any, Callable

import numpy as np


@dataclass
class SelfDistiller:
    """Self-distillation: use extended-T model as implicit teacher.

    Run the same model at T_teacher timesteps (more accurate) to
    generate soft targets for training at T_student timesteps (faster).

    Parameters
    ----------
    T_teacher : int
        Timesteps for teacher forward pass.
    T_student : int
        Timesteps for student forward pass.
    temperature : float
    """

    T_teacher: int = 32
    T_student: int = 8
    temperature: float = 3.0

    def generate_targets(
        self,
        run_fn: Callable[[np.ndarray[Any, Any], int], np.ndarray[Any, Any]],
        inputs: np.ndarray[Any, Any],
    ) -> np.ndarray[Any, Any]:
        """Run model at T_teacher steps to generate soft targets.

        Parameters
        ----------
        run_fn : callable(inputs, T) -> logits
        inputs : ndarray

        Returns
        -------
        ndarray — soft targets from teacher
        """
        teacher_logits = run_fn(inputs, self.T_teacher)
        return self._softmax(teacher_logits / self.temperature)

    @staticmethod
    def _softmax(x: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
        e = np.exp(x - x.max())
        probs: np.ndarray[Any, Any] = e / e.sum()
        return probs

generate_targets(run_fn, inputs)

Run model at T_teacher steps to generate soft targets.

Parameters

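  • run_fn (callable(inputs, T) -> logits): Forward pass, called as run_fn(inputs, T_teacher). It should return logits of shape (N_classes,), for example averaged over time; the internal softmax normalizes over every element of the returned array, so a (T, N_classes) output would be normalized jointly rather than per timestep.
  • inputs (ndarray): Passed to run_fn unchanged.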

Returns

ndarray — soft targets from teacher
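The run_fn contract matters because of how the softmax in the SelfDistiller listing normalizes. The sketch below restates generate_targets locally (the readout matrix W and both run_fn variants are hypothetical) and compares a run_fn that returns per-timestep logits with one that averages over time first.

```python
import numpy as np


def softmax_all(x: np.ndarray) -> np.ndarray:
    """Same logic as SelfDistiller._softmax in the listing: normalizes over every element."""
    e = np.exp(x - x.max())
    return e / e.sum()


def generate_targets(run_fn, inputs, T_teacher=32, temperature=3.0):
    """Local restatement of SelfDistiller.generate_targets for illustration."""
    return softmax_all(run_fn(inputs, T_teacher) / temperature)


W = np.array([[1.0, -0.5, 0.2],
              [0.3, 0.8, -1.0]])   # hypothetical 2-input, 3-class readout


def run_fn_per_step(inputs, T):
    # Hypothetical: same readout at each of T steps -> shape (T, N_classes)
    return np.tile(inputs @ W, (T, 1))


def run_fn_averaged(inputs, T):
    # Reduce over time so generate_targets receives shape (N_classes,)
    return run_fn_per_step(inputs, T).mean(axis=0)


x = np.array([0.5, -1.0])
good = generate_targets(run_fn_averaged, x)   # shape (3,), one distribution
bad = generate_targets(run_fn_per_step, x)    # shape (32, 3), normalized jointly
print(good.sum(), bad.sum(axis=1)[:3])
```

With the per-step variant, each row sums to 1/32 rather than 1, so the targets are not a per-timestep class distribution; averaging inside run_fn avoids this.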
