
Few-Shot Meta-Learning — Hebbian Associative Memory

Learn from 1-5 examples using Hebbian plasticity on spike rates instead of gradient descent. Two approaches are provided: Hebbian weight storage (HebbianFewShot) and prototypical-network classification (SpikePrototypeNet).

HebbianFewShot — Associative Memory

Support patterns are stored via a one-shot Hebbian update, memory[label] += lr_hebbian * pattern, and queries are classified by cosine similarity to the stored memories. The few_shot_episode() method runs the full N-way K-shot protocol: reset → store support set → classify query set.
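The store-then-query rule can be illustrated without the library. This is a minimal standalone NumPy sketch, assuming rate vectors and integer labels in range(n_classes); the function names `hebbian_store` and `cosine_query` are hypothetical and not part of sc_neurocore.

```python
import numpy as np


def hebbian_store(memory: np.ndarray, pattern: np.ndarray, label: int, lr: float = 0.1) -> None:
    """One-shot Hebbian update: add the scaled pattern to its class row (in place)."""
    memory[label] += lr * pattern


def cosine_query(memory: np.ndarray, pattern: np.ndarray) -> int:
    """Return the index of the memory row with the highest cosine similarity."""
    mem_norms = np.linalg.norm(memory, axis=1)
    pat_norm = np.linalg.norm(pattern)
    # Guard against division by zero for empty rows or an all-zero query.
    denom = np.maximum(mem_norms * pat_norm, 1e-10)
    scores = memory @ pattern / denom
    return int(np.argmax(scores))


# 3-way 1-shot toy episode with hand-made rate vectors.
memory = np.zeros((3, 4))
support = [
    np.array([1.0, 0.0, 0.0, 0.0]),
    np.array([0.0, 1.0, 0.0, 0.0]),
    np.array([0.0, 0.0, 1.0, 1.0]),
]
for label, pattern in enumerate(support):
    hebbian_store(memory, pattern, label)

query = np.array([0.1, 0.0, 0.9, 0.8])
print(cosine_query(memory, query))  # 2
```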

Parameter | Default | Meaning
--- | --- | ---
n_features | (required) | Input feature dimension
n_classes | (required) | Number of classes
lr_hebbian | 0.1 | Hebbian learning rate for storage

Accepts spike-rate vectors of shape (n_features,) or raw spike trains of shape (T, n_features); any input with more than one dimension is averaged over its first (time) axis before use.
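The time-averaging step can be checked in isolation. This sketch assumes binary (Bernoulli) spike trains generated from hypothetical per-feature firing probabilities `rates`; averaging over axis 0, as described above, recovers an estimate of those probabilities.

```python
import numpy as np

rng = np.random.default_rng(0)
T, n_features = 2000, 5
rates = np.array([0.05, 0.1, 0.2, 0.4, 0.8])  # hypothetical firing probability per time step

# Binary spike train of shape (T, n_features): 1.0 = spike, 0.0 = silence.
spikes = (rng.random((T, n_features)) < rates).astype(np.float64)

# Same reduction as described above: average over the time axis.
rate_vector = spikes.mean(axis=0)
print(rate_vector.shape)  # (5,)
print(np.round(rate_vector, 2))
```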

SpikePrototypeNet — Prototypical Network

Computes class prototypes as the mean spike-rate vector of each class in the support set, then assigns each query to the nearest prototype by cosine similarity or Euclidean distance. Stateless: prototypes are recomputed on every classify() call, so there are no internal weights to maintain.

Parameter | Default | Meaning
--- | --- | ---
n_features | (required) | Feature dimension
metric | "cosine" | Similarity metric: "cosine" or "euclidean" (any value other than "cosine" is treated as Euclidean)
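The choice of metric matters when firing rates differ in overall magnitude. In this toy sketch (hand-picked vectors, not library output; `nearest_prototype` is a hypothetical helper), the query has the same rate profile as prototype A but fires at three times the rate: cosine similarity ignores the scale and picks A, while Euclidean distance picks B, whose magnitude happens to be closer.

```python
import numpy as np


def nearest_prototype(query, prototypes, metric="cosine"):
    """Return the key of the best-matching prototype (higher score wins)."""
    best_key, best_score = None, -np.inf
    for key, proto in prototypes.items():
        if metric == "cosine":
            denom = max(np.linalg.norm(query) * np.linalg.norm(proto), 1e-10)
            score = np.dot(query, proto) / denom
        else:
            # Negated distance so that "higher is better" holds for both metrics.
            score = -np.linalg.norm(query - proto)
        if score > best_score:
            best_key, best_score = key, score
    return best_key


prototypes = {
    "A": np.array([0.1, 0.2, 0.0]),
    "B": np.array([0.3, 0.3, 0.3]),
}
query = 3 * prototypes["A"]  # same profile as A, three times the firing rate

print(nearest_prototype(query, prototypes, "cosine"))     # A
print(nearest_prototype(query, prototypes, "euclidean"))  # B
```

Cosine is the natural default when class identity is carried by which features fire rather than by how fast the whole population fires.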

Usage

from sc_neurocore.few_shot import HebbianFewShot, SpikePrototypeNet
import numpy as np

# 5-way 1-shot with Hebbian memory
learner = HebbianFewShot(n_features=64, n_classes=5)
support_x = [np.random.rand(64) for _ in range(5)]
support_y = [0, 1, 2, 3, 4]
query_x = [np.random.rand(64) for _ in range(10)]
predictions = learner.few_shot_episode(support_x, support_y, query_x)

# Prototypical network (no training needed)
proto = SpikePrototypeNet(n_features=64, metric="cosine")
predictions = proto.classify(support_x, support_y, query_x)

Reference: HAAM (BICS 2024).

See Tutorial 84: Few-Shot Meta-Learning.

sc_neurocore.few_shot.haam

Hebbian-Augmented Associative Memory for few-shot SNN learning.

Learn from 1-5 examples using Hebbian plasticity rather than gradients. Support examples are stored as spike-rate patterns and retrieved via cosine similarity.

Reference: HAAM (BICS 2024)

HebbianFewShot

Hebbian few-shot learner using associative memory.

Support examples are stored via a one-shot Hebbian weight update. Queries are classified by the cosine similarity between their spike-rate representation and each class's stored memory.

Parameters

n_features : int
    Input feature dimension.
n_classes : int
    Number of classes to support.
lr_hebbian : float
    Hebbian learning rate for support storage.
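One consequence follows directly from the update rule memory[label] += lr_hebbian * pattern: after K stores, a class's memory equals lr_hebbian * K times the mean of its support patterns. Cosine similarity ignores positive scaling, so for any positive lr_hebbian (aside from the 1e-10 numerical guard) the cosine readout agrees with a class-mean prototype readout and does not depend on lr_hebbian. The sketch below checks this numerically on random data; it reimplements both readouts in NumPy rather than calling the library.

```python
import numpy as np

rng = np.random.default_rng(1)
n_classes, k_shot, n_features = 4, 3, 16

# Random support set: k_shot non-negative rate vectors per class.
support = rng.random((n_classes, k_shot, n_features))
queries = rng.random((20, n_features))


def cosine_scores(rows, q):
    """Cosine similarity between each row of `rows` and the vector `q`."""
    return rows @ q / (np.linalg.norm(rows, axis=1) * np.linalg.norm(q))


for lr in (0.01, 0.1, 5.0):
    # Hebbian memory after k_shot stores per class: lr * sum of that class's patterns.
    memory = lr * support.sum(axis=1)
    # Prototypes: plain class means.
    prototypes = support.mean(axis=1)
    assert np.allclose(memory, lr * k_shot * prototypes)
    hebb_pred = [int(np.argmax(cosine_scores(memory, q))) for q in queries]
    proto_pred = [int(np.argmax(cosine_scores(prototypes, q))) for q in queries]
    assert hebb_pred == proto_pred

print("Hebbian memory and prototype readouts agree for every lr tested")
```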

Source code in src/sc_neurocore/few_shot/haam.py
class HebbianFewShot:
    """Hebbian few-shot learner using associative memory.

    Support examples stored via one-shot Hebbian weight update.
    Query classified by comparing spike-rate representation to
    stored prototypes.

    Parameters
    ----------
    n_features : int
        Input feature dimension.
    n_classes : int
        Number of classes to support.
    lr_hebbian : float
        Hebbian learning rate for support storage.
    """

    def __init__(self, n_features: int, n_classes: int, lr_hebbian: float = 0.1):
        self.n_features = n_features
        self.n_classes = n_classes
        self.lr_hebbian = lr_hebbian
        # Associative memory: one weight vector per class
        self.memory = np.zeros((n_classes, n_features))
        self._counts = np.zeros(n_classes, dtype=int)

    def store(self, spike_pattern: np.ndarray, label: int):
        """Store one support example via Hebbian update.

        Parameters
        ----------
        spike_pattern : ndarray of shape (n_features,) or (T, n_features)
            Spike pattern or spike rate vector.
        label : int
            Class label.
        """
        if spike_pattern.ndim > 1:
            pattern = spike_pattern.mean(axis=0)
        else:
            pattern = spike_pattern.astype(np.float64)

        # Hebbian update: strengthen connections for this class
        self.memory[label] += self.lr_hebbian * pattern
        self._counts[label] += 1

    def query(self, spike_pattern: np.ndarray) -> int:
        """Classify a query pattern by cosine similarity to stored memories.

        Parameters
        ----------
        spike_pattern : ndarray of shape (n_features,) or (T, n_features)

        Returns
        -------
        int — predicted class
        """
        if spike_pattern.ndim > 1:
            pattern = spike_pattern.mean(axis=0)
        else:
            pattern = spike_pattern.astype(np.float64)

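        # Start at -inf so classes with no stored examples can never outrank
        # stored classes whose cosine similarity is negative.
        similarities = np.full(self.n_classes, -np.inf)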
        for c in range(self.n_classes):
            if self._counts[c] == 0:
                continue
            mem_norm = np.linalg.norm(self.memory[c])
            pat_norm = np.linalg.norm(pattern)
            if mem_norm > 1e-10 and pat_norm > 1e-10:
                similarities[c] = np.dot(self.memory[c], pattern) / (mem_norm * pat_norm)

        return int(np.argmax(similarities))

    def few_shot_episode(
        self,
        support_x: list[np.ndarray],
        support_y: list[int],
        query_x: list[np.ndarray],
    ) -> list[int]:
        """Run a complete few-shot episode.

        Parameters
        ----------
        support_x : list of ndarray
            Support set spike patterns.
        support_y : list of int
            Support set labels.
        query_x : list of ndarray
            Query set spike patterns.

        Returns
        -------
        list of int — predicted labels for query set
        """
        self.reset()
        for pattern, label in zip(support_x, support_y):
            self.store(pattern, label)
        return [self.query(q) for q in query_x]

    def reset(self):
        """Clear all stored memories and per-class example counts."""
        self.memory[:] = 0
        self._counts[:] = 0

store(spike_pattern, label)

Store one support example via Hebbian update.

Parameters

spike_pattern : ndarray of shape (n_features,) or (T, n_features)
    Spike pattern or spike rate vector.
label : int
    Class label.
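store() indexes self.memory[label] directly, so under NumPy indexing rules a label of n_classes or larger raises IndexError, while a negative label such as -1 silently updates the last class row. If labels come from an external dataset, a thin validating wrapper can catch this early. `checked_store` below is a hypothetical helper, not part of sc_neurocore, and operates on a plain memory matrix to stay self-contained.

```python
import numpy as np


def checked_store(memory: np.ndarray, counts: np.ndarray, pattern: np.ndarray,
                  label: int, lr: float = 0.1) -> None:
    """Hebbian store that rejects labels outside range(n_classes)."""
    n_classes = memory.shape[0]
    if not 0 <= label < n_classes:
        raise ValueError(f"label {label} outside range(0, {n_classes})")
    pattern = np.asarray(pattern, dtype=np.float64)
    if pattern.ndim > 1:
        pattern = pattern.mean(axis=0)  # average a (T, n_features) spike train over time
    memory[label] += lr * pattern
    counts[label] += 1


memory = np.zeros((3, 4))
counts = np.zeros(3, dtype=int)

# Plain NumPy indexing: -1 silently targets the last row.
memory[-1] += 0.1 * np.ones(4)
print(memory[2])  # [0.1 0.1 0.1 0.1]

memory[:] = 0
try:
    checked_store(memory, counts, np.ones(4), label=-1)
except ValueError as err:
    print(err)  # label -1 outside range(0, 3)
```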


query(spike_pattern)

Classify a query pattern by cosine similarity to stored memories. If no examples have been stored, the method returns class 0.

Parameters

spike_pattern : ndarray of shape (n_features,) or (T, n_features)

Returns

int — predicted class
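A detail of the readout worth knowing: if the score array starts at 0 and classes with no stored examples are skipped, those empty classes keep a score of 0 and can outrank stored classes whose cosine similarity is negative. Non-negative spike rates never produce a negative cosine similarity, so this only matters for signed features (for example, standardized inputs). The sketch below uses hypothetical toy vectors and a standalone `readout` function (not the library method) to compare a 0-initialized score array with one initialized to -inf.

```python
import numpy as np


def readout(memory, counts, pattern, init):
    """Cosine readout over stored classes; `init` is the score kept by skipped classes."""
    scores = np.full(memory.shape[0], init, dtype=np.float64)
    for c in range(memory.shape[0]):
        if counts[c] == 0:
            continue  # no stored examples for this class
        denom = np.linalg.norm(memory[c]) * np.linalg.norm(pattern)
        if denom > 1e-10:
            scores[c] = np.dot(memory[c], pattern) / denom
    return int(np.argmax(scores))


# Class 0 is never stored; classes 1 and 2 hold one pattern each.
memory = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
counts = np.array([0, 1, 1])
query = np.array([-1.0, -2.0])  # negative cosine with both stored classes

print(readout(memory, counts, query, init=0.0))      # 0: the empty class wins
print(readout(memory, counts, query, init=-np.inf))  # 1: the least-dissimilar stored class
```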


few_shot_episode(support_x, support_y, query_x)

Run a complete few-shot episode: reset the memory, store every support example, then classify every query.

Parameters

support_x : list of ndarray
    Support set spike patterns.
support_y : list of int
    Support set labels.
query_x : list of ndarray
    Query set spike patterns.

Returns

list of int — predicted labels for query set
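The random vectors in the usage example carry no class structure, so their predictions are arbitrary. To sanity-check the episode protocol, one can draw episodes from data that does have structure. This is a standalone sketch under stated assumptions: `make_episode` and `run_episode` are hypothetical helpers, each class is given its own random firing-rate profile, samples are Bernoulli spike trains, and the reset / store / classify steps are reimplemented in NumPy rather than calling the library.

```python
import numpy as np


def make_episode(rng, n_way=5, k_shot=1, n_query=10, n_features=64, T=50):
    """Hypothetical N-way K-shot episode of (T, n_features) Bernoulli spike trains."""
    class_rates = rng.uniform(0.0, 0.5, size=(n_way, n_features))  # one rate profile per class

    def sample(c):
        return (rng.random((T, n_features)) < class_rates[c]).astype(np.float64)

    support_y = [c for c in range(n_way) for _ in range(k_shot)]
    query_y = [int(c) for c in rng.integers(0, n_way, size=n_query)]
    support_x = [sample(c) for c in support_y]
    query_x = [sample(c) for c in query_y]
    return support_x, support_y, query_x, query_y


def run_episode(support_x, support_y, query_x, n_classes, lr=0.1):
    """Reset, Hebbian-store the support set, then cosine-classify each query."""
    memory = np.zeros((n_classes, support_x[0].shape[-1]))  # reset
    for x, y in zip(support_x, support_y):                   # store support set
        memory[y] += lr * x.mean(axis=0)
    preds = []
    for q in query_x:                                        # classify query set
        qv = q.mean(axis=0)
        denom = np.maximum(np.linalg.norm(memory, axis=1) * np.linalg.norm(qv), 1e-10)
        preds.append(int(np.argmax(memory @ qv / denom)))
    return preds


rng = np.random.default_rng(42)
support_x, support_y, query_x, query_y = make_episode(rng)
preds = run_episode(support_x, support_y, query_x, n_classes=5)
accuracy = np.mean(np.array(preds) == np.array(query_y))
print(f"episode accuracy: {accuracy:.2f}")
```

Averaging accuracy over many such episodes, rather than reading a single one, gives a more stable estimate.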


SpikePrototypeNet dataclass

Prototypical network in spike domain.

Compute class prototypes as mean spike-rate vectors from the support set. Classify queries by nearest prototype (cosine similarity or Euclidean distance).

Parameters

n_features : int
    Feature dimension.
metric : str
    'cosine' or 'euclidean'

Source code in src/sc_neurocore/few_shot/haam.py
@dataclass
class SpikePrototypeNet:
    """Prototypical network in spike domain.

    Compute class prototypes as mean spike-rate vectors from support set.
    Classify queries by nearest prototype (Euclidean or cosine).

    Parameters
    ----------
    n_features : int
    metric : str
        'cosine' or 'euclidean'
    """

    n_features: int
    metric: str = "cosine"

    def classify(
        self,
        support_x: list[np.ndarray],
        support_y: list[int],
        query_x: list[np.ndarray],
    ) -> list[int]:
        """Classify query set using support set prototypes.

        Parameters
        ----------
        support_x : list of ndarray, shape (n_features,) or (T, n_features)
        support_y : list of int
        query_x : list of ndarray

        Returns
        -------
        list of int
        """
        # Compute prototypes
        classes = sorted(set(support_y))
        prototypes = {}
        for c in classes:
            patterns = [
                s.mean(axis=0) if s.ndim > 1 else s.astype(np.float64)
                for s, y in zip(support_x, support_y)
                if y == c
            ]
            prototypes[c] = np.mean(patterns, axis=0)

        # Classify queries
        predictions = []
        for q in query_x:
            qv = q.mean(axis=0) if q.ndim > 1 else q.astype(np.float64)
            best_c = classes[0]
            best_score = -float("inf")
            for c, proto in prototypes.items():
                if self.metric == "cosine":
                    n1, n2 = np.linalg.norm(qv), np.linalg.norm(proto)
                    score = np.dot(qv, proto) / max(n1 * n2, 1e-10)
                else:
                    score = -np.linalg.norm(qv - proto)
                if score > best_score:
                    best_score = score
                    best_c = c
            predictions.append(best_c)

        return predictions

classify(support_x, support_y, query_x)

Classify query set using support set prototypes.

Parameters

support_x : list of ndarray, shape (n_features,) or (T, n_features)
support_y : list of int
query_x : list of ndarray

Returns

list of int
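classify() scores each query against each prototype in nested Python loops. For larger query sets the same computation can be expressed with matrix operations. `classify_vectorized` below is a hypothetical sketch, not part of sc_neurocore. It follows the conventions of the listing (sorted class order, 1e-10 guard on the cosine denominator, any metric other than "cosine" treated as Euclidean, first class wins ties since np.argmax returns the first maximum) and is checked against a loop transcription of the listing on random data.

```python
import numpy as np


def _rates(x):
    """Average (T, n_features) inputs over time; pass rate vectors through as float64."""
    x = np.asarray(x, dtype=np.float64)
    return x.mean(axis=0) if x.ndim > 1 else x


def classify_vectorized(support_x, support_y, query_x, metric="cosine"):
    classes = sorted(set(support_y))
    S = np.stack([_rates(s) for s in support_x])             # (n_support, n_features)
    y = np.asarray(support_y)
    P = np.stack([S[y == c].mean(axis=0) for c in classes])  # (n_classes, n_features)
    Q = np.stack([_rates(q) for q in query_x])               # (n_query, n_features)
    if metric == "cosine":
        norms = np.outer(np.linalg.norm(Q, axis=1), np.linalg.norm(P, axis=1))
        scores = (Q @ P.T) / np.maximum(norms, 1e-10)
    else:
        scores = -np.linalg.norm(Q[:, None, :] - P[None, :, :], axis=2)
    return [classes[i] for i in np.argmax(scores, axis=1)]


def classify_loop(support_x, support_y, query_x, metric="cosine"):
    """Loop version transcribed from the classify() listing, used as a reference."""
    classes = sorted(set(support_y))
    protos = {c: np.mean([_rates(s) for s, y in zip(support_x, support_y) if y == c], axis=0)
              for c in classes}
    preds = []
    for q in query_x:
        qv = _rates(q)
        best_c, best = classes[0], -np.inf
        for c, p in protos.items():
            if metric == "cosine":
                score = np.dot(qv, p) / max(np.linalg.norm(qv) * np.linalg.norm(p), 1e-10)
            else:
                score = -np.linalg.norm(qv - p)
            if score > best:
                best, best_c = score, c
        preds.append(best_c)
    return preds


rng = np.random.default_rng(7)
support_x = [rng.random(32) for _ in range(15)]
support_y = [i % 5 for i in range(15)]
query_x = [rng.random((20, 32)) for _ in range(30)]  # (T, n_features) inputs
for metric in ("cosine", "euclidean"):
    assert classify_vectorized(support_x, support_y, query_x, metric) == \
        classify_loop(support_x, support_y, query_x, metric)
```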
