
Few-Shot Meta-Learning — Hebbian Associative Memory

Learn from 1-5 examples using spike-timing plasticity instead of gradient descent. Two approaches: Hebbian weight storage and prototypical network classification.

HebbianFewShot — Associative Memory

Support patterns are stored via a one-shot Hebbian update: memory[label] += lr_hebbian * pattern. Queries are classified by cosine similarity to the stored memories. The few_shot_episode() method runs the full N-way K-shot protocol: reset → store support set → classify query set.
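One practical consequence of this update rule: because retrieval uses cosine similarity, a uniform lr_hebbian rescales every memory equally and cancels out of the score, so on its own it does not change predictions. A quick NumPy check, using an illustrative reimplementation of the store/query logic (the helper name `hebbian_predict` is ours, not part of the library API):

```python
import numpy as np

def hebbian_predict(support_x, support_y, query, lr, n_classes):
    """Mirror of the documented store/query logic in plain NumPy."""
    memory = np.zeros((n_classes, len(query)))
    for x, y in zip(support_x, support_y):
        memory[y] += lr * x  # one-shot Hebbian storage
    # cosine similarity of the query to each stored memory
    sims = memory @ query / (
        np.linalg.norm(memory, axis=1) * np.linalg.norm(query) + 1e-10
    )
    return int(np.argmax(sims))

rng = np.random.default_rng(0)
support = [rng.random(8) for _ in range(3)]
labels = [0, 1, 2]
q = support[1] + 0.01 * rng.random(8)  # query close to class-1 support

# cosine normalisation makes the prediction independent of lr_hebbian
assert hebbian_predict(support, labels, q, lr=0.1, n_classes=3) == \
       hebbian_predict(support, labels, q, lr=5.0, n_classes=3)
```

The learning rate does matter once memories are combined with other signals or inspected directly; it only drops out of the pure cosine-argmax decision.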

| Parameter | Default | Meaning |
| --- | --- | --- |
| `n_features` | (required) | Input feature dimension |
| `n_classes` | (required) | Number of classes |
| `lr_hebbian` | `0.1` | Hebbian learning rate for storage |

Accepts spike-rate vectors (n_features,) or raw spike trains (T, n_features) — automatically averaged over time.
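A short, standalone NumPy illustration of that reduction (the names `spikes` and `rates` are ours, not library identifiers): a (T, n_features) binary spike train collapses to a per-neuron firing-rate vector via the mean over the time axis, which is the same representation a plain rate vector provides directly.

```python
import numpy as np

rng = np.random.default_rng(42)
T, n_features = 200, 4
firing_probs = np.array([0.1, 0.3, 0.5, 0.7])

# Binary spike train: spikes[t, i] == 1.0 if neuron i fired at step t
spikes = (rng.random((T, n_features)) < firing_probs).astype(float)

# The reduction applied to (T, n_features) inputs: mean over time
rates = spikes.mean(axis=0)  # per-neuron firing rate in [0, 1]

assert rates.shape == (n_features,)          # now a rate vector
assert np.all((rates >= 0) & (rates <= 1))   # rates approximate firing_probs
```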

SpikePrototypeNet — Prototypical Network

Computes class prototypes as mean spike-rate vectors from the support set. Classifies queries by nearest prototype using cosine or Euclidean distance. Stateless — no internal weights to maintain.

| Parameter | Default | Meaning |
| --- | --- | --- |
| `n_features` | (required) | Feature dimension |
| `metric` | `"cosine"` | Distance metric: `"cosine"` or `"euclidean"` |
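The decision rule is plain nearest-prototype classification. A minimal NumPy sketch of both metrics (an illustrative helper, `nearest_prototype`, not part of the library API):

```python
import numpy as np

def nearest_prototype(prototypes, query, metric="cosine"):
    """Return the index of the closest prototype under the chosen metric."""
    if metric == "cosine":
        norms = np.linalg.norm(prototypes, axis=1) * np.linalg.norm(query)
        scores = prototypes @ query / np.maximum(norms, 1e-10)
    else:  # euclidean: negate distance so that higher score = closer
        scores = -np.linalg.norm(prototypes - query, axis=1)
    return int(np.argmax(scores))

protos = np.array([[1.0, 0.0], [0.0, 1.0]])
assert nearest_prototype(protos, np.array([0.9, 0.2]), "cosine") == 0
assert nearest_prototype(protos, np.array([0.1, 0.8]), "euclidean") == 1

# cosine ignores magnitude: a scaled query gives the same answer
assert nearest_prototype(protos, np.array([9.0, 2.0]), "cosine") == 0
```

Cosine compares firing-rate direction and ignores overall activity level, while Euclidean is sensitive to it; which is preferable depends on whether total spike count carries class information in your encoding.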

Usage

Python
from sc_neurocore.few_shot import HebbianFewShot, SpikePrototypeNet
import numpy as np

# 5-way 1-shot with Hebbian memory
learner = HebbianFewShot(n_features=64, n_classes=5)
support_x = [np.random.rand(64) for _ in range(5)]
support_y = [0, 1, 2, 3, 4]
query_x = [np.random.rand(64) for _ in range(10)]
predictions = learner.few_shot_episode(support_x, support_y, query_x)

# Prototypical network (no training needed)
proto = SpikePrototypeNet(n_features=64, metric="cosine")
predictions = proto.classify(support_x, support_y, query_x)

Reference: HAAM (BICS 2024).

See Tutorial 84: Few-Shot Meta-Learning.

sc_neurocore.few_shot.haam

Hebbian-Augmented Associative Memory for few-shot SNN learning.

Learn from 1-5 examples using spike-timing plasticity, not gradients. Store support examples as spike patterns, retrieve via cosine similarity.

Reference: HAAM (BICS 2024)

HebbianFewShot

Hebbian few-shot learner using associative memory.

Support examples stored via one-shot Hebbian weight update. Query classified by comparing spike-rate representation to stored prototypes.

Parameters

n_features : int
    Input feature dimension.
n_classes : int
    Number of classes to support.
lr_hebbian : float
    Hebbian learning rate for support storage.

Source code in src/sc_neurocore/few_shot/haam.py
Python
class HebbianFewShot:
    """Hebbian few-shot learner using associative memory.

    Support examples stored via one-shot Hebbian weight update.
    Query classified by comparing spike-rate representation to
    stored prototypes.

    Parameters
    ----------
    n_features : int
        Input feature dimension.
    n_classes : int
        Number of classes to support.
    lr_hebbian : float
        Hebbian learning rate for support storage.
    """

    def __init__(self, n_features: int, n_classes: int, lr_hebbian: float = 0.1):
        self.n_features = n_features
        self.n_classes = n_classes
        self.lr_hebbian = lr_hebbian
        # Associative memory: one weight vector per class
        self.memory = np.zeros((n_classes, n_features))
        self._counts = np.zeros(n_classes, dtype=int)

    def store(self, spike_pattern: np.ndarray, label: int) -> None:
        """Store one support example via Hebbian update.

        Parameters
        ----------
        spike_pattern : ndarray of shape (n_features,) or (T, n_features)
            Spike pattern or spike rate vector.
        label : int
            Class label.
        """
        if spike_pattern.ndim > 1:
            pattern = spike_pattern.mean(axis=0)
        else:
            pattern = spike_pattern.astype(np.float64)

        # Hebbian update: strengthen connections for this class
        self.memory[label] += self.lr_hebbian * pattern
        self._counts[label] += 1

    def query(self, spike_pattern: np.ndarray) -> int:
        """Classify a query pattern by cosine similarity to stored memories.

        Parameters
        ----------
        spike_pattern : ndarray of shape (n_features,) or (T, n_features)

        Returns
        -------
        int — predicted class
        """
        if spike_pattern.ndim > 1:
            pattern = spike_pattern.mean(axis=0)
        else:
            pattern = spike_pattern.astype(np.float64)

        similarities = np.zeros(self.n_classes)
        for c in range(self.n_classes):
            if self._counts[c] == 0:
                continue
            mem_norm = np.linalg.norm(self.memory[c])
            pat_norm = np.linalg.norm(pattern)
            if mem_norm > 1e-10 and pat_norm > 1e-10:
                similarities[c] = np.dot(self.memory[c], pattern) / (mem_norm * pat_norm)

        return int(np.argmax(similarities))

    def few_shot_episode(
        self,
        support_x: list[np.ndarray],
        support_y: list[int],
        query_x: list[np.ndarray],
    ) -> list[int]:
        """Run a complete few-shot episode.

        Parameters
        ----------
        support_x : list of ndarray
            Support set spike patterns.
        support_y : list of int
            Support set labels.
        query_x : list of ndarray
            Query set spike patterns.

        Returns
        -------
        list of int — predicted labels for query set
        """
        self.reset()
        for pattern, label in zip(support_x, support_y):
            self.store(pattern, label)
        return [self.query(q) for q in query_x]

    def reset(self) -> None:
        self.memory[:] = 0
        self._counts[:] = 0

store(spike_pattern, label)

Store one support example via Hebbian update.

Parameters

spike_pattern : ndarray of shape (n_features,) or (T, n_features)
    Spike pattern or spike-rate vector.
label : int
    Class label.

Source code in src/sc_neurocore/few_shot/haam.py
Python
def store(self, spike_pattern: np.ndarray, label: int) -> None:
    """Store one support example via Hebbian update.

    Parameters
    ----------
    spike_pattern : ndarray of shape (n_features,) or (T, n_features)
        Spike pattern or spike rate vector.
    label : int
        Class label.
    """
    if spike_pattern.ndim > 1:
        pattern = spike_pattern.mean(axis=0)
    else:
        pattern = spike_pattern.astype(np.float64)

    # Hebbian update: strengthen connections for this class
    self.memory[label] += self.lr_hebbian * pattern
    self._counts[label] += 1

query(spike_pattern)

Classify a query pattern by cosine similarity to stored memories.

Parameters

spike_pattern : ndarray of shape (n_features,) or (T, n_features)

Returns

int — predicted class

Source code in src/sc_neurocore/few_shot/haam.py
Python
def query(self, spike_pattern: np.ndarray) -> int:
    """Classify a query pattern by cosine similarity to stored memories.

    Parameters
    ----------
    spike_pattern : ndarray of shape (n_features,) or (T, n_features)

    Returns
    -------
    int — predicted class
    """
    if spike_pattern.ndim > 1:
        pattern = spike_pattern.mean(axis=0)
    else:
        pattern = spike_pattern.astype(np.float64)

    similarities = np.zeros(self.n_classes)
    for c in range(self.n_classes):
        if self._counts[c] == 0:
            continue
        mem_norm = np.linalg.norm(self.memory[c])
        pat_norm = np.linalg.norm(pattern)
        if mem_norm > 1e-10 and pat_norm > 1e-10:
            similarities[c] = np.dot(self.memory[c], pattern) / (mem_norm * pat_norm)

    return int(np.argmax(similarities))

few_shot_episode(support_x, support_y, query_x)

Run a complete few-shot episode.

Parameters

support_x : list of ndarray
    Support set spike patterns.
support_y : list of int
    Support set labels.
query_x : list of ndarray
    Query set spike patterns.

Returns

list of int — predicted labels for query set

Source code in src/sc_neurocore/few_shot/haam.py
Python
def few_shot_episode(
    self,
    support_x: list[np.ndarray],
    support_y: list[int],
    query_x: list[np.ndarray],
) -> list[int]:
    """Run a complete few-shot episode.

    Parameters
    ----------
    support_x : list of ndarray
        Support set spike patterns.
    support_y : list of int
        Support set labels.
    query_x : list of ndarray
        Query set spike patterns.

    Returns
    -------
    list of int — predicted labels for query set
    """
    self.reset()
    for pattern, label in zip(support_x, support_y):
        self.store(pattern, label)
    return [self.query(q) for q in query_x]

SpikePrototypeNet dataclass

Prototypical network in spike domain.

Compute class prototypes as mean spike-rate vectors from support set. Classify queries by nearest prototype (Euclidean or cosine).

Parameters

n_features : int
metric : str
    'cosine' or 'euclidean'

Source code in src/sc_neurocore/few_shot/haam.py
Python
@dataclass
class SpikePrototypeNet:
    """Prototypical network in spike domain.

    Compute class prototypes as mean spike-rate vectors from support set.
    Classify queries by nearest prototype (Euclidean or cosine).

    Parameters
    ----------
    n_features : int
    metric : str
        'cosine' or 'euclidean'
    """

    n_features: int
    metric: str = "cosine"

    def classify(
        self,
        support_x: list[np.ndarray],
        support_y: list[int],
        query_x: list[np.ndarray],
    ) -> list[int]:
        """Classify query set using support set prototypes.

        Parameters
        ----------
        support_x : list of ndarray, shape (n_features,) or (T, n_features)
        support_y : list of int
        query_x : list of ndarray

        Returns
        -------
        list of int
        """
        # Compute prototypes
        classes = sorted(set(support_y))
        prototypes = {}
        for c in classes:
            patterns = [
                s.mean(axis=0) if s.ndim > 1 else s.astype(np.float64)
                for s, y in zip(support_x, support_y)
                if y == c
            ]
            prototypes[c] = np.mean(patterns, axis=0)

        # Classify queries
        predictions = []
        for q in query_x:
            qv = q.mean(axis=0) if q.ndim > 1 else q.astype(np.float64)
            best_c = classes[0]
            best_score = -float("inf")
            for c, proto in prototypes.items():
                if self.metric == "cosine":
                    n1, n2 = np.linalg.norm(qv), np.linalg.norm(proto)
                    score = np.dot(qv, proto) / max(n1 * n2, 1e-10)
                else:
                    score = -np.linalg.norm(qv - proto)
                if score > best_score:
                    best_score = score
                    best_c = c
            predictions.append(best_c)

        return predictions

classify(support_x, support_y, query_x)

Classify query set using support set prototypes.

Parameters

support_x : list of ndarray, shape (n_features,) or (T, n_features)
support_y : list of int
query_x : list of ndarray

Returns

list of int

Source code in src/sc_neurocore/few_shot/haam.py
Python
def classify(
    self,
    support_x: list[np.ndarray],
    support_y: list[int],
    query_x: list[np.ndarray],
) -> list[int]:
    """Classify query set using support set prototypes.

    Parameters
    ----------
    support_x : list of ndarray, shape (n_features,) or (T, n_features)
    support_y : list of int
    query_x : list of ndarray

    Returns
    -------
    list of int
    """
    # Compute prototypes
    classes = sorted(set(support_y))
    prototypes = {}
    for c in classes:
        patterns = [
            s.mean(axis=0) if s.ndim > 1 else s.astype(np.float64)
            for s, y in zip(support_x, support_y)
            if y == c
        ]
        prototypes[c] = np.mean(patterns, axis=0)

    # Classify queries
    predictions = []
    for q in query_x:
        qv = q.mean(axis=0) if q.ndim > 1 else q.astype(np.float64)
        best_c = classes[0]
        best_score = -float("inf")
        for c, proto in prototypes.items():
            if self.metric == "cosine":
                n1, n2 = np.linalg.norm(qv), np.linalg.norm(proto)
                score = np.dot(qv, proto) / max(n1 * n2, 1e-10)
            else:
                score = -np.linalg.norm(qv - proto)
            if score > best_score:
                best_score = score
                best_c = c
        predictions.append(best_c)

    return predictions