Learn from 1-5 examples using Hebbian plasticity on spike rates, not gradients.
Two approaches: Hebbian-Augmented Associative Memory (HAAM) and Spike Prototypical Networks.
HebbianFewShot (HAAM)
import numpy as np
from sc_neurocore.few_shot import HebbianFewShot

learner = HebbianFewShot(n_features=128, n_classes=5, lr_hebbian=0.1)
rng = np.random.RandomState(42)

# 5-way, 1-shot: give each class its own random spike-rate pattern.
# Queries are scored by cosine similarity, which ignores overall scale,
# so classes must differ in the *shape* of their pattern, not its magnitude.
class_patterns = rng.rand(5, 128)
for c in range(5):
    learner.store(class_patterns[c], label=c)

# Query: a noisy copy of class 2's pattern.
query = class_patterns[2] + 0.1 * rng.randn(128)
predicted = learner.query(query)
print(f"Predicted: {predicted}")  # -> 2
Few-Shot Episode
# 5-way, 2-shot episode: two support examples per class, then five queries.
support_x = [rng.rand(128) for _ in range(10)]
support_y = [label for label in range(5) for _ in range(2)]  # [0, 0, 1, 1, ..., 4, 4]
query_x = [rng.rand(128) for _ in range(5)]
# few_shot_episode() resets the learner, stores the support set, then classifies.
predictions = learner.few_shot_episode(support_x, support_y, query_x)
SpikePrototypeNet
Nearest-prototype classification in spike domain:
from sc_neurocore.few_shot import SpikePrototypeNet

# Same episode, classified by the nearest per-class mean (prototype) vector.
proto_net = SpikePrototypeNet(n_features=128, metric="cosine")
predictions = proto_net.classify(
    support_x=support_x,
    support_y=support_y,
    query_x=query_x,
)
| Method | Mechanism | Hardware |
| --- | --- | --- |
| HebbianFewShot | Hebbian weight update | On-chip STDP |
| SpikePrototypeNet | Nearest prototype | Hamming distance |
Both accept spike-rate vectors of shape (n_features,) or raw spike trains of shape (T, n_features); spike trains are averaged over the time axis into rate vectors.
Reference: HAAM (BICS 2024)
API Reference
sc_neurocore.few_shot.haam
Hebbian-Augmented Associative Memory for few-shot SNN learning.
Learn from 1-5 examples using Hebbian plasticity on spike rates, not gradients.
Store support examples as spike patterns, retrieve via cosine similarity.
Reference: HAAM (BICS 2024)
HebbianFewShot
Hebbian few-shot learner using associative memory.
Support examples stored via one-shot Hebbian weight update.
Query classified by comparing spike-rate representation to
stored prototypes.
Parameters
n_features : int
Input feature dimension.
n_classes : int
Number of classes to support.
lr_hebbian : float
Hebbian learning rate for support storage.
Source code in src/sc_neurocore/few_shot/haam.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class HebbianFewShot:
    """Hebbian few-shot learner using associative memory.

    Each support example is stored via a one-shot Hebbian weight update:
    its spike-rate vector, scaled by ``lr_hebbian``, is added to the memory
    row of its class. Queries are classified by cosine similarity between
    their spike-rate vector and each stored class memory.

    Parameters
    ----------
    n_features : int
        Input feature dimension.
    n_classes : int
        Number of classes to support.
    lr_hebbian : float
        Hebbian learning rate for support storage. Because queries are
        scored by cosine similarity, any positive value yields the same
        predictions; it only rescales the memory rows.
    """

    def __init__(self, n_features: int, n_classes: int, lr_hebbian: float = 0.1):
        self.n_features = n_features
        self.n_classes = n_classes
        self.lr_hebbian = lr_hebbian
        # Associative memory: one accumulated weight vector per class.
        self.memory = np.zeros((n_classes, n_features))
        # Support examples stored per class; 0 means the class was never seen.
        self._counts = np.zeros(n_classes, dtype=int)

    def store(self, spike_pattern: np.ndarray, label: int) -> None:
        """Store one support example via Hebbian update.

        Parameters
        ----------
        spike_pattern : array-like of shape (n_features,) or (T, n_features)
            Spike-rate vector, or a spike train that is averaged over time.
        label : int
            Class label in ``[0, n_classes)``.

        Raises
        ------
        ValueError
            If the (time-averaged) pattern does not have shape ``(n_features,)``.
        IndexError
            If ``label`` is outside ``[0, n_classes)``.
        """
        pattern = np.asarray(spike_pattern, dtype=np.float64)
        if pattern.ndim > 1:
            # Collapse a spike train to per-feature firing rates.
            pattern = pattern.mean(axis=0)
        # Explicit check: numpy would otherwise broadcast a length-1 pattern.
        if pattern.shape != (self.n_features,):
            raise ValueError(
                f"expected pattern of shape ({self.n_features},), got {pattern.shape}"
            )
        # Explicit check: numpy would otherwise wrap negative labels around.
        if not 0 <= label < self.n_classes:
            raise IndexError(
                f"label {label} out of range for {self.n_classes} classes"
            )
        # Hebbian update: strengthen connections for this class.
        self.memory[label] += self.lr_hebbian * pattern
        self._counts[label] += 1

    def query(self, spike_pattern: np.ndarray) -> int:
        """Classify a query pattern by cosine similarity to stored memories.

        Only classes with at least one stored example can be predicted. If
        no class has been stored yet, 0 is returned.

        Parameters
        ----------
        spike_pattern : array-like of shape (n_features,) or (T, n_features)

        Returns
        -------
        int — predicted class
        """
        pattern = np.asarray(spike_pattern, dtype=np.float64)
        if pattern.ndim > 1:
            pattern = pattern.mean(axis=0)
        # Unstored classes start at -inf so they can never outscore a stored
        # class, even when every stored class has negative similarity.
        similarities = np.full(self.n_classes, -np.inf)
        pat_norm = np.linalg.norm(pattern)
        for c in range(self.n_classes):
            if self._counts[c] == 0:
                continue
            mem_norm = np.linalg.norm(self.memory[c])
            if mem_norm > 1e-10 and pat_norm > 1e-10:
                similarities[c] = np.dot(self.memory[c], pattern) / (mem_norm * pat_norm)
            else:
                # Degenerate (near-zero) vectors: treat as orthogonal.
                similarities[c] = 0.0
        # argmax returns the first maximum, so ties go to the lowest class id.
        return int(np.argmax(similarities))

    def few_shot_episode(
        self,
        support_x: list[np.ndarray],
        support_y: list[int],
        query_x: list[np.ndarray],
    ) -> list[int]:
        """Run a complete few-shot episode.

        Clears memory, stores every support example, then classifies each
        query.

        Parameters
        ----------
        support_x : list of ndarray
            Support set spike patterns.
        support_y : list of int
            Support set labels, one per entry of ``support_x``.
        query_x : list of ndarray
            Query set spike patterns.

        Returns
        -------
        list of int — predicted labels for query set

        Raises
        ------
        ValueError
            If ``support_x`` and ``support_y`` differ in length (checked
            before memory is cleared).
        """
        support_x = list(support_x)
        support_y = list(support_y)
        # zip() would silently drop the unmatched tail.
        if len(support_x) != len(support_y):
            raise ValueError(
                f"support_x has {len(support_x)} examples but support_y has "
                f"{len(support_y)} labels"
            )
        self.reset()
        for pattern, label in zip(support_x, support_y):
            self.store(pattern, label)
        return [self.query(q) for q in query_x]

    def reset(self) -> None:
        """Clear all stored class memories and per-class example counts."""
        self.memory[:] = 0
        self._counts[:] = 0
|
store(spike_pattern, label)
Store one support example via Hebbian update.
Parameters
spike_pattern : ndarray of shape (n_features,) or (T, n_features)
Spike pattern or spike rate vector.
label : int
Class label.
Source code in src/sc_neurocore/few_shot/haam.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def store(self, spike_pattern: np.ndarray, label: int) -> None:
    """Store one support example via Hebbian update.

    Parameters
    ----------
    spike_pattern : array-like of shape (n_features,) or (T, n_features)
        Spike-rate vector, or a spike train that is averaged over time.
    label : int
        Class label in ``[0, n_classes)``.

    Raises
    ------
    ValueError
        If the (time-averaged) pattern does not have shape ``(n_features,)``.
    IndexError
        If ``label`` is outside ``[0, n_classes)``.
    """
    pattern = np.asarray(spike_pattern, dtype=np.float64)
    if pattern.ndim > 1:
        # Collapse a spike train to per-feature firing rates.
        pattern = pattern.mean(axis=0)
    # Explicit check: numpy would otherwise broadcast a length-1 pattern.
    if pattern.shape != (self.n_features,):
        raise ValueError(
            f"expected pattern of shape ({self.n_features},), got {pattern.shape}"
        )
    # Explicit check: numpy would otherwise wrap negative labels around.
    if not 0 <= label < self.n_classes:
        raise IndexError(
            f"label {label} out of range for {self.n_classes} classes"
        )
    # Hebbian update: strengthen connections for this class.
    self.memory[label] += self.lr_hebbian * pattern
    self._counts[label] += 1
|
query(spike_pattern)
Classify a query pattern by cosine similarity to stored memories.
Parameters
spike_pattern : ndarray of shape (n_features,) or (T, n_features)
Returns
int — predicted class
Source code in src/sc_neurocore/few_shot/haam.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def query(self, spike_pattern: np.ndarray) -> int:
    """Classify a query pattern by cosine similarity to stored memories.

    Only classes with at least one stored example can be predicted. If
    no class has been stored yet, 0 is returned.

    Parameters
    ----------
    spike_pattern : array-like of shape (n_features,) or (T, n_features)

    Returns
    -------
    int — predicted class
    """
    pattern = np.asarray(spike_pattern, dtype=np.float64)
    if pattern.ndim > 1:
        pattern = pattern.mean(axis=0)
    # Unstored classes start at -inf so they can never outscore a stored
    # class, even when every stored class has negative similarity.
    similarities = np.full(self.n_classes, -np.inf)
    pat_norm = np.linalg.norm(pattern)
    for c in range(self.n_classes):
        if self._counts[c] == 0:
            continue
        mem_norm = np.linalg.norm(self.memory[c])
        if mem_norm > 1e-10 and pat_norm > 1e-10:
            similarities[c] = np.dot(self.memory[c], pattern) / (mem_norm * pat_norm)
        else:
            # Degenerate (near-zero) vectors: treat as orthogonal.
            similarities[c] = 0.0
    # argmax returns the first maximum, so ties go to the lowest class id.
    return int(np.argmax(similarities))
|
few_shot_episode(support_x, support_y, query_x)
Run a complete few-shot episode.
Parameters
support_x : list of ndarray
Support set spike patterns.
support_y : list of int
Support set labels.
query_x : list of ndarray
Query set spike patterns.
Returns
list of int — predicted labels for query set
Source code in src/sc_neurocore/few_shot/haam.py
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def few_shot_episode(
    self,
    support_x: list[np.ndarray],
    support_y: list[int],
    query_x: list[np.ndarray],
) -> list[int]:
    """Run a complete few-shot episode.

    Clears memory, stores every support example, then classifies each
    query.

    Parameters
    ----------
    support_x : list of ndarray
        Support set spike patterns.
    support_y : list of int
        Support set labels, one per entry of ``support_x``.
    query_x : list of ndarray
        Query set spike patterns.

    Returns
    -------
    list of int — predicted labels for query set

    Raises
    ------
    ValueError
        If ``support_x`` and ``support_y`` differ in length (checked
        before memory is cleared).
    """
    support_x = list(support_x)
    support_y = list(support_y)
    # zip() would silently drop the unmatched tail.
    if len(support_x) != len(support_y):
        raise ValueError(
            f"support_x has {len(support_x)} examples but support_y has "
            f"{len(support_y)} labels"
        )
    self.reset()
    for pattern, label in zip(support_x, support_y):
        self.store(pattern, label)
    return [self.query(q) for q in query_x]
|
SpikePrototypeNet
dataclass
Prototypical network in spike domain.
Compute class prototypes as mean spike-rate vectors from support set.
Classify queries by nearest prototype (Euclidean or cosine).
Parameters
n_features : int
Input feature dimension.
metric : str
Similarity metric: 'cosine' or 'euclidean'.
Source code in src/sc_neurocore/few_shot/haam.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
@dataclass
class SpikePrototypeNet:
    """Prototypical network in spike domain.

    Compute class prototypes as mean spike-rate vectors from the support
    set, then classify each query by its nearest prototype (cosine
    similarity or Euclidean distance).

    Parameters
    ----------
    n_features : int
        Input feature dimension (stored for reference; ``classify`` does
        not check it).
    metric : str
        'cosine' or 'euclidean'
    """

    n_features: int
    metric: str = "cosine"

    def classify(
        self,
        support_x: list[np.ndarray],
        support_y: list[int],
        query_x: list[np.ndarray],
    ) -> list[int]:
        """Classify query set using support set prototypes.

        Parameters
        ----------
        support_x : list of ndarray, shape (n_features,) or (T, n_features)
        support_y : list of int
        query_x : list of ndarray

        Returns
        -------
        list of int

        Raises
        ------
        ValueError
            If ``metric`` is not 'cosine' or 'euclidean', or if
            ``support_x`` and ``support_y`` differ in length.
        IndexError
            If the support set is empty while ``query_x`` is not.
        """
        # Reject unknown metrics instead of silently falling back to Euclidean.
        if self.metric not in ("cosine", "euclidean"):
            raise ValueError(
                f"unknown metric {self.metric!r}; expected 'cosine' or 'euclidean'"
            )
        support_x = list(support_x)
        support_y = list(support_y)
        # zip() would silently drop the unmatched tail.
        if len(support_x) != len(support_y):
            raise ValueError(
                f"support_x has {len(support_x)} examples but support_y has "
                f"{len(support_y)} labels"
            )

        # Convert a rate vector or (T, n_features) spike train to a rate vector.
        def to_rate(x):
            v = np.asarray(x, dtype=np.float64)
            return v.mean(axis=0) if v.ndim > 1 else v

        # Prototype = mean rate vector of each class's support examples.
        classes = sorted(set(support_y))
        prototypes = {
            c: np.mean(
                [to_rate(s) for s, y in zip(support_x, support_y) if y == c],
                axis=0,
            )
            for c in classes
        }
        use_cosine = self.metric == "cosine"
        # Prototype norms are query-independent; compute them once.
        proto_norms = {c: np.linalg.norm(p) for c, p in prototypes.items()}

        predictions = []
        for q in query_x:
            qv = to_rate(q)
            q_norm = np.linalg.norm(qv)
            best_c = classes[0]
            best_score = -float("inf")
            # Iterates in sorted class order; strict '>' keeps the lowest id on ties.
            for c, proto in prototypes.items():
                if use_cosine:
                    score = np.dot(qv, proto) / max(q_norm * proto_norms[c], 1e-10)
                else:
                    score = -np.linalg.norm(qv - proto)
                if score > best_score:
                    best_score = score
                    best_c = c
            predictions.append(best_c)
        return predictions
|
classify(support_x, support_y, query_x)
Classify query set using support set prototypes.
Parameters
support_x : list of ndarray, shape (n_features,) or (T, n_features)
support_y : list of int
query_x : list of ndarray
Returns
list of int
Source code in src/sc_neurocore/few_shot/haam.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def classify(
    self,
    support_x: list[np.ndarray],
    support_y: list[int],
    query_x: list[np.ndarray],
) -> list[int]:
    """Classify query set using support set prototypes.

    Parameters
    ----------
    support_x : list of ndarray, shape (n_features,) or (T, n_features)
    support_y : list of int
    query_x : list of ndarray

    Returns
    -------
    list of int

    Raises
    ------
    ValueError
        If ``metric`` is not 'cosine' or 'euclidean', or if
        ``support_x`` and ``support_y`` differ in length.
    IndexError
        If the support set is empty while ``query_x`` is not.
    """
    # Reject unknown metrics instead of silently falling back to Euclidean.
    if self.metric not in ("cosine", "euclidean"):
        raise ValueError(
            f"unknown metric {self.metric!r}; expected 'cosine' or 'euclidean'"
        )
    support_x = list(support_x)
    support_y = list(support_y)
    # zip() would silently drop the unmatched tail.
    if len(support_x) != len(support_y):
        raise ValueError(
            f"support_x has {len(support_x)} examples but support_y has "
            f"{len(support_y)} labels"
        )

    # Convert a rate vector or (T, n_features) spike train to a rate vector.
    def to_rate(x):
        v = np.asarray(x, dtype=np.float64)
        return v.mean(axis=0) if v.ndim > 1 else v

    # Prototype = mean rate vector of each class's support examples.
    classes = sorted(set(support_y))
    prototypes = {
        c: np.mean(
            [to_rate(s) for s, y in zip(support_x, support_y) if y == c],
            axis=0,
        )
        for c in classes
    }
    use_cosine = self.metric == "cosine"
    # Prototype norms are query-independent; compute them once.
    proto_norms = {c: np.linalg.norm(p) for c, p in prototypes.items()}

    predictions = []
    for q in query_x:
        qv = to_rate(q)
        q_norm = np.linalg.norm(qv)
        best_c = classes[0]
        best_score = -float("inf")
        # Iterates in sorted class order; strict '>' keeps the lowest id on ties.
        for c, proto in prototypes.items():
            if use_cosine:
                score = np.dot(qv, proto) / max(q_norm * proto_norms[c], 1e-10)
            else:
                score = -np.linalg.norm(qv - proto)
            if score > best_score:
                best_score = score
                best_c = c
        predictions.append(best_c)
    return predictions
|