Skip to content

Tutorial 81: SNN Transfer Learning

Pretrain, save, load, freeze, fine-tune. The standard ML workflow adapted for spiking neural networks.

Save and Load Checkpoints

from sc_neurocore.transfer import (
    save_checkpoint, load_checkpoint, SNNCheckpoint, TransferConfig,
)

# Save trained model
ckpt = SNNCheckpoint(
    weights=model_weights,
    layer_names=["h1", "out"],
    layer_sizes=[(784, 256), (256, 10)],
)
save_checkpoint(ckpt, "mnist_snn")

# Load checkpoint
ckpt = load_checkpoint("mnist_snn")

Freeze and Fine-Tune

from sc_neurocore.transfer.fine_tune import apply_transfer_config

# Freeze all layers except the last (readout)
config = TransferConfig(freeze_until=0, lr_head=0.001)
ckpt, per_layer_lr = apply_transfer_config(ckpt, config)

# per_layer_lr: [0.0, 0.001] — first layer frozen, second trains

Transfer Learning Workflow

  1. Pretrain on large dataset (e.g., MNIST 60K examples)
  2. Save checkpoint with save_checkpoint()
  3. Load on new task with load_checkpoint()
  4. Freeze feature extraction layers
  5. Fine-tune readout layer on new task (few examples suffice)

SNN transfer learning preserves learned temporal dynamics — spike timing patterns transfer across tasks, not just weight magnitudes.

API Reference

sc_neurocore.transfer

Save, load, freeze, fine-tune SNN models. The foundation of modern ML.

SNNCheckpoint dataclass

Complete SNN model checkpoint.

Source code in src/sc_neurocore/transfer/checkpoint.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@dataclass
class SNNCheckpoint:
    """Complete SNN model checkpoint.

    Bundles per-layer weight matrices together with the bookkeeping
    needed to restore, inspect, and selectively freeze a spiking network.
    """

    weights: list[np.ndarray]            # one weight matrix per layer
    layer_names: list[str]               # layer identifiers, parallel to `weights`
    layer_sizes: list[tuple[int, int]]   # per-layer (fan_in, fan_out)
    neuron_types: list[str] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)
    frozen_layers: list[str] = field(default_factory=list)

    @property
    def n_layers(self) -> int:
        """Number of weight matrices stored in this checkpoint."""
        return len(self.weights)

    @property
    def total_params(self) -> int:
        """Total count of scalar parameters across all layers."""
        count = 0
        for matrix in self.weights:
            count += matrix.size
        return count

TransferConfig dataclass

Configuration for transfer learning.

Parameters

  - freeze_until : str or int — Freeze all layers up to (and including) this layer name or index.
  - lr_backbone : float — Learning rate for frozen layers (usually 0 or very small).
  - lr_head : float — Learning rate for unfrozen layers.

Source code in src/sc_neurocore/transfer/fine_tune.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
@dataclass
class TransferConfig:
    """Configuration for transfer learning.

    Parameters
    ----------
    freeze_until : str or int
        Freeze all layers up to (and including) this layer name or index.
    lr_backbone : float
        Learning rate for frozen layers (usually 0 or very small).
    lr_head : float
        Learning rate for unfrozen layers.
    """

    freeze_until: str | int = -1
    lr_backbone: float = 0.0
    lr_head: float = 0.01

save_checkpoint(checkpoint, path)

Save SNN checkpoint to .npz + .json.

Parameters

  - checkpoint : SNNCheckpoint — The model checkpoint to save.
  - path : str or Path — Base path (without extension). Creates path.npz and path.json.

Source code in src/sc_neurocore/transfer/checkpoint.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def save_checkpoint(checkpoint: SNNCheckpoint, path: str | Path):
    """Save an SNN checkpoint as a weight archive plus a JSON sidecar.

    Writes two files: ``<path>.npz`` (compressed per-layer weights) and
    ``<path>.json`` (layer names/sizes and other metadata).

    Parameters
    ----------
    checkpoint : SNNCheckpoint
        The model checkpoint to serialize.
    path : str or Path
        Base path (without extension). Creates path.npz and path.json.
    """
    path = Path(path)
    # Robustness fix: saving into a not-yet-existing experiment directory
    # previously raised FileNotFoundError; create parents on demand.
    path.parent.mkdir(parents=True, exist_ok=True)

    # Weights: "layer_{i}" keys preserve layer order so load_checkpoint
    # can rebuild the list by index.
    weight_dict = {f"layer_{i}": w for i, w in enumerate(checkpoint.weights)}
    np.savez_compressed(str(path) + ".npz", **weight_dict)

    # Everything non-array goes into a human-readable JSON sidecar.
    meta = {
        "layer_names": checkpoint.layer_names,
        "layer_sizes": checkpoint.layer_sizes,
        "neuron_types": checkpoint.neuron_types,
        "frozen_layers": checkpoint.frozen_layers,
        "n_layers": checkpoint.n_layers,
        "total_params": checkpoint.total_params,
        "metadata": checkpoint.metadata,
    }
    with open(str(path) + ".json", "w") as f:
        json.dump(meta, f, indent=2)

load_checkpoint(path)

Load SNN checkpoint from .npz + .json.

Parameters

  - path : str or Path — Base path (without extension).

Returns

SNNCheckpoint

Source code in src/sc_neurocore/transfer/checkpoint.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def load_checkpoint(path: str | Path) -> SNNCheckpoint:
    """Load an SNN checkpoint saved by ``save_checkpoint``.

    Reads ``<path>.npz`` (weights) and ``<path>.json`` (metadata).

    Parameters
    ----------
    path : str or Path
        Base path (without extension).

    Returns
    -------
    SNNCheckpoint
        Reconstructed checkpoint. Optional metadata fields missing from
        older sidecars fall back to empty defaults.
    """
    path = Path(path)

    # Bug fix: np.load on an .npz returns an NpzFile that keeps the file
    # handle open; close it deterministically with a context manager.
    with np.load(str(path) + ".npz") as data:
        # Materialize arrays inside the with-block — NpzFile loads lazily,
        # so indexing after close would fail.
        weights = [data[f"layer_{i}"] for i in range(len(data.files))]

    # Load the JSON metadata sidecar.
    with open(str(path) + ".json") as f:
        meta = json.load(f)

    return SNNCheckpoint(
        weights=weights,
        layer_names=meta["layer_names"],
        # JSON round-trips tuples as lists; restore the tuple shape.
        layer_sizes=[tuple(s) for s in meta["layer_sizes"]],
        neuron_types=meta.get("neuron_types", []),
        metadata=meta.get("metadata", {}),
        frozen_layers=meta.get("frozen_layers", []),
    )

freeze_layers(checkpoint, layer_names=None, until_index=None)

Freeze layers (mark as non-trainable).

Parameters

  - checkpoint : SNNCheckpoint — The checkpoint whose layers should be frozen.
  - layer_names : list of str, optional — Specific layers to freeze.
  - until_index : int, optional — Freeze all layers with index <= until_index.

Returns

SNNCheckpoint with frozen_layers updated

Source code in src/sc_neurocore/transfer/fine_tune.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def freeze_layers(
    checkpoint: SNNCheckpoint,
    layer_names: list[str] | None = None,
    until_index: int | None = None,
) -> SNNCheckpoint:
    """Mark layers as non-trainable by adding them to ``frozen_layers``.

    Parameters
    ----------
    checkpoint : SNNCheckpoint
    layer_names : list of str, optional
        Specific layers to freeze.
    until_index : int, optional
        Freeze all layers with index <= until_index.

    Returns
    -------
    SNNCheckpoint with frozen_layers updated
    """
    # Accumulate into a set so repeated freezes stay deduplicated.
    to_freeze = set(checkpoint.frozen_layers)

    if layer_names is not None:
        to_freeze.update(layer_names)

    if until_index is not None:
        # Note: an index comparison (not a slice) so that a negative
        # until_index (e.g. -1) freezes nothing, matching the docs.
        to_freeze.update(
            name
            for idx, name in enumerate(checkpoint.layer_names)
            if idx <= until_index
        )

    # Stored sorted, mutating the checkpoint in place and returning it.
    checkpoint.frozen_layers = sorted(to_freeze)
    return checkpoint

unfreeze_layers(checkpoint, layer_names=None, all_layers=False)

Unfreeze layers (mark as trainable).

Parameters

  - checkpoint : SNNCheckpoint — The checkpoint whose layers should be unfrozen.
  - layer_names : list of str, optional — Specific layers to unfreeze.
  - all_layers : bool — If True, unfreeze everything.

Source code in src/sc_neurocore/transfer/fine_tune.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def unfreeze_layers(
    checkpoint: SNNCheckpoint,
    layer_names: list[str] | None = None,
    all_layers: bool = False,
) -> SNNCheckpoint:
    """Mark layers as trainable by removing them from ``frozen_layers``.

    Parameters
    ----------
    checkpoint : SNNCheckpoint
    layer_names : list of str, optional
        Specific layers to unfreeze.
    all_layers : bool
        If True, unfreeze everything.
    """
    if all_layers:
        # Blanket unfreeze takes precedence over any layer_names given.
        checkpoint.frozen_layers = []
    elif layer_names is not None:
        # Set membership keeps the filter O(1) per frozen layer.
        targets = set(layer_names)
        checkpoint.frozen_layers = [
            name for name in checkpoint.frozen_layers if name not in targets
        ]
    return checkpoint