
Tutorial 80: Contrastive Self-Supervised Learning

Self-supervised learning for SNNs without labeled data: an InfoNCE contrastive loss for spike representations, and CSDP, a biologically plausible local learning rule that generalizes Forward-Forward to spiking circuits.

SpikeContrastiveLoss (InfoNCE)

import numpy as np
from sc_neurocore.contrastive import SpikeContrastiveLoss

loss_fn = SpikeContrastiveLoss(temperature=0.5)

rng = np.random.RandomState(42)
view_a = np.abs(rng.randn(16, 128))  # augmentation A (non-negative spike rates)
view_b = np.clip(view_a + rng.randn(16, 128) * 0.1, 0.0, None)  # augmentation B: jittered copy, clipped to stay non-negative

loss = loss_fn.compute(view_a, view_b)
print(f"Contrastive loss: {loss:.3f}")
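As a sanity check, the same InfoNCE computation can be written in plain NumPy (a sketch mirroring the `compute` method shown in the API reference below; it does not require sc_neurocore). Matched views should give a much lower loss than views whose pairing has been broken:

```python
import numpy as np

def info_nce(view_a, view_b, temperature=0.5):
    """Plain-NumPy InfoNCE over a batch of paired rate vectors."""
    a = view_a / np.clip(np.linalg.norm(view_a, axis=1, keepdims=True), 1e-8, None)
    b = view_b / np.clip(np.linalg.norm(view_b, axis=1, keepdims=True), 1e-8, None)
    sim = a @ b.T / temperature                       # cosine similarities
    exp_sim = np.exp(sim - sim.max(axis=1, keepdims=True))
    log_prob = np.log(np.clip(np.diag(exp_sim) / exp_sim.sum(axis=1), 1e-10, None))
    return -float(log_prob.mean())

rng = np.random.RandomState(0)
view_a = np.abs(rng.randn(16, 128))
view_b = np.clip(view_a + rng.randn(16, 128) * 0.1, 0.0, None)  # matched pairing
shuffled = np.roll(view_b, 1, axis=0)                 # broken pairing (derangement)

loss_matched = info_nce(view_a, view_b)
loss_shuffled = info_nce(view_a, shuffled)
print(loss_matched, loss_shuffled)  # matched pairs give a lower loss
```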

CSDP: Biologically Plausible Learning

Positive phase (real data) → Hebbian. Negative phase (corrupted) → anti-Hebbian.

from sc_neurocore.contrastive import CSDPRule

csdp = CSDPRule(lr=0.01, decay=0.001)
W = np.random.randn(64, 128) * 0.01

pos_pre = (np.random.rand(128) > 0.5).astype(float)
pos_post = (np.random.rand(64) > 0.5).astype(float)
neg_pre = np.random.rand(128)
neg_post = (np.random.rand(64) > 0.5).astype(float)

W = csdp.contrastive_step(W, pos_pre, pos_post, neg_pre, neg_post)

# Goodness score (sum of squared activations); after training,
# positive data should score higher than corrupted data
print(f"Pos goodness: {csdp.goodness(pos_post):.2f}")

Reference: Ororbia 2024, Science Advances
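To see the two phases pull in opposite directions, here is a minimal plain-NumPy sketch of the rule (mirroring the updates implemented by CSDPRule, not the library itself): repeated contrastive steps strengthen the layer's response to the positive pattern.

```python
import numpy as np

rng = np.random.RandomState(0)
lr, decay = 0.01, 0.001

pos_pre = (rng.rand(128) > 0.5).astype(float)    # real input pattern
pos_post = (rng.rand(64) > 0.5).astype(float)
neg_pre = rng.rand(128)                          # corrupted input
neg_post = (rng.rand(64) > 0.5).astype(float)

W = rng.randn(64, 128) * 0.01
before = float(np.sum((W @ pos_pre) ** 2))       # goodness of the driven response

for _ in range(10):
    # positive phase: Hebbian update with weight decay
    W = W + lr * np.outer(pos_post, pos_pre) - decay * W
    # negative phase: anti-Hebbian update
    W = W - lr * np.outer(neg_post, neg_pre)

after = float(np.sum((W @ pos_pre) ** 2))
print(before, after)  # response to the positive pattern grows
```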

API Reference

sc_neurocore.contrastive.ssl

Contrastive self-supervised learning for SNNs.

SpikeContrastiveLoss: InfoNCE-style loss for spike representations.
CSDPRule: Contrastive Signal-Dependent Plasticity — biologically plausible local learning rule (Science Advances 2024).

Mainstream SNN libraries do not ship self-supervised learning utilities.

SpikeContrastiveLoss

InfoNCE contrastive loss adapted for spike representations.

Computes similarity between spike-rate vectors from two augmented views of the same input. Positive pairs = same input, different augmentation. Negative pairs = different inputs.
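Concretely, after L2-normalizing both batches, entry (i, j) of the similarity matrix is the cosine similarity between sample i of view A and sample j of view B, so the positive pairs sit on the diagonal. A quick plain-NumPy check with identical views:

```python
import numpy as np

rng = np.random.RandomState(1)
view_a = np.abs(rng.randn(4, 8))
view_b = view_a.copy()                 # identical views: perfect positives

a = view_a / np.linalg.norm(view_a, axis=1, keepdims=True)
b = view_b / np.linalg.norm(view_b, axis=1, keepdims=True)
sim = a @ b.T                          # cosine-similarity matrix

print(np.round(np.diag(sim), 6))       # diagonal (positive pairs) is all 1.0
```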

Parameters

temperature : float
    Contrastive temperature scaling.

Source code in src/sc_neurocore/contrastive/ssl.py
class SpikeContrastiveLoss:
    """InfoNCE contrastive loss adapted for spike representations.

    Computes similarity between spike-rate vectors from two augmented
    views of the same input. Positive pairs = same input, different
    augmentation. Negative pairs = different inputs.

    Parameters
    ----------
    temperature : float
        Contrastive temperature scaling.
    """

    def __init__(self, temperature: float = 0.5):
        self.temperature = temperature

    def compute(
        self,
        view_a: np.ndarray,
        view_b: np.ndarray,
    ) -> float:
        """Compute contrastive loss for a batch of spike-rate pairs.

        Parameters
        ----------
        view_a : ndarray of shape (batch, n_features)
            Spike rates from augmentation A.
        view_b : ndarray of shape (batch, n_features)
            Spike rates from augmentation B.

        Returns
        -------
        float — InfoNCE loss
        """
        batch = view_a.shape[0]
        if batch < 2:
            return 0.0

        # Normalize
        a_norm = view_a / np.clip(np.linalg.norm(view_a, axis=1, keepdims=True), 1e-8, None)
        b_norm = view_b / np.clip(np.linalg.norm(view_b, axis=1, keepdims=True), 1e-8, None)

        # Similarity matrix
        sim = a_norm @ b_norm.T / self.temperature

        # InfoNCE: positive = diagonal, negatives = off-diagonal
        # log softmax along rows
        exp_sim = np.exp(sim - sim.max(axis=1, keepdims=True))
        log_prob = np.log(
            np.clip(
                np.diag(exp_sim) / exp_sim.sum(axis=1),
                1e-10,
                None,
            )
        )
        return -float(log_prob.mean())

compute(view_a, view_b)

Compute contrastive loss for a batch of spike-rate pairs.

Parameters

view_a : ndarray of shape (batch, n_features)
    Spike rates from augmentation A.
view_b : ndarray of shape (batch, n_features)
    Spike rates from augmentation B.

Returns

float — InfoNCE loss

Source code in src/sc_neurocore/contrastive/ssl.py
def compute(
    self,
    view_a: np.ndarray,
    view_b: np.ndarray,
) -> float:
    """Compute contrastive loss for a batch of spike-rate pairs.

    Parameters
    ----------
    view_a : ndarray of shape (batch, n_features)
        Spike rates from augmentation A.
    view_b : ndarray of shape (batch, n_features)
        Spike rates from augmentation B.

    Returns
    -------
    float — InfoNCE loss
    """
    batch = view_a.shape[0]
    if batch < 2:
        return 0.0

    # Normalize
    a_norm = view_a / np.clip(np.linalg.norm(view_a, axis=1, keepdims=True), 1e-8, None)
    b_norm = view_b / np.clip(np.linalg.norm(view_b, axis=1, keepdims=True), 1e-8, None)

    # Similarity matrix
    sim = a_norm @ b_norm.T / self.temperature

    # InfoNCE: positive = diagonal, negatives = off-diagonal
    # log softmax along rows
    exp_sim = np.exp(sim - sim.max(axis=1, keepdims=True))
    log_prob = np.log(
        np.clip(
            np.diag(exp_sim) / exp_sim.sum(axis=1),
            1e-10,
            None,
        )
    )
    return -float(log_prob.mean())

CSDPRule dataclass

Contrastive Signal-Dependent Plasticity.

Local learning rule: weight update depends on (pre, post, contrastive_signal). Positive phase: present real data → Hebbian update. Negative phase: present corrupted data → anti-Hebbian update.

Generalizes Forward-Forward to spiking circuits.

Reference: Ororbia 2024, Science Advances

Parameters

lr : float
    Learning rate.
decay : float
    Weight decay for regularization.

Source code in src/sc_neurocore/contrastive/ssl.py
@dataclass
class CSDPRule:
    """Contrastive Signal-Dependent Plasticity.

    Local learning rule: weight update depends on (pre, post, contrastive_signal).
    Positive phase: present real data → Hebbian update.
    Negative phase: present corrupted data → anti-Hebbian update.

    Generalizes Forward-Forward to spiking circuits.

    Reference: Ororbia 2024, Science Advances

    Parameters
    ----------
    lr : float
        Learning rate.
    decay : float
        Weight decay for regularization.
    """

    lr: float = 0.01
    decay: float = 0.001

    def positive_update(
        self,
        weights: np.ndarray,
        pre_spikes: np.ndarray,
        post_spikes: np.ndarray,
    ) -> np.ndarray:
        """Hebbian update from positive (real) data.

        dW = lr * (post @ pre^T) - decay * W
        """
        dW = self.lr * np.outer(post_spikes, pre_spikes) - self.decay * weights
        return weights + dW

    def negative_update(
        self,
        weights: np.ndarray,
        pre_spikes: np.ndarray,
        post_spikes: np.ndarray,
    ) -> np.ndarray:
        """Anti-Hebbian update from negative (corrupted) data.

        dW = -lr * (post @ pre^T)
        """
        dW = -self.lr * np.outer(post_spikes, pre_spikes)
        return weights + dW

    def contrastive_step(
        self,
        weights: np.ndarray,
        pos_pre: np.ndarray,
        pos_post: np.ndarray,
        neg_pre: np.ndarray,
        neg_post: np.ndarray,
    ) -> np.ndarray:
        """Full contrastive update: positive + negative phase."""
        w = self.positive_update(weights, pos_pre, pos_post)
        w = self.negative_update(w, neg_pre, neg_post)
        return w

    def goodness(self, activations: np.ndarray) -> float:
        """Compute 'goodness' score (sum of squared activations).

        Positive data should have high goodness, negative data low.
        """
        return float(np.sum(activations**2))

positive_update(weights, pre_spikes, post_spikes)

Hebbian update from positive (real) data.

dW = lr * (post @ pre^T) - decay * W
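A tiny numeric check of this update (a sketch assuming 2 post- and 3 pre-synaptic units; with zero initial weights the decay term vanishes):

```python
import numpy as np

lr, decay = 0.1, 0.01
W = np.zeros((2, 3))
pre = np.array([1.0, 0.0, 1.0])
post = np.array([0.0, 1.0])

dW = lr * np.outer(post, pre) - decay * W
W = W + dW
print(W)
# only the row of the active post unit changes, at the active pre columns:
# [[0.  0.  0. ]
#  [0.1 0.  0.1]]
```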

Source code in src/sc_neurocore/contrastive/ssl.py
def positive_update(
    self,
    weights: np.ndarray,
    pre_spikes: np.ndarray,
    post_spikes: np.ndarray,
) -> np.ndarray:
    """Hebbian update from positive (real) data.

    dW = lr * (post @ pre^T) - decay * W
    """
    dW = self.lr * np.outer(post_spikes, pre_spikes) - self.decay * weights
    return weights + dW

negative_update(weights, pre_spikes, post_spikes)

Anti-Hebbian update from negative (corrupted) data.

dW = -lr * (post @ pre^T)

Source code in src/sc_neurocore/contrastive/ssl.py
def negative_update(
    self,
    weights: np.ndarray,
    pre_spikes: np.ndarray,
    post_spikes: np.ndarray,
) -> np.ndarray:
    """Anti-Hebbian update from negative (corrupted) data.

    dW = -lr * (post @ pre^T)
    """
    dW = -self.lr * np.outer(post_spikes, pre_spikes)
    return weights + dW

contrastive_step(weights, pos_pre, pos_post, neg_pre, neg_post)

Full contrastive update: positive + negative phase.

Source code in src/sc_neurocore/contrastive/ssl.py
def contrastive_step(
    self,
    weights: np.ndarray,
    pos_pre: np.ndarray,
    pos_post: np.ndarray,
    neg_pre: np.ndarray,
    neg_post: np.ndarray,
) -> np.ndarray:
    """Full contrastive update: positive + negative phase."""
    w = self.positive_update(weights, pos_pre, pos_post)
    w = self.negative_update(w, neg_pre, neg_post)
    return w

goodness(activations)

Compute 'goodness' score (sum of squared activations).

Positive data should have high goodness, negative data low.
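For binary spike vectors the goodness is simply the number of active units; for graded rates it rewards strong, dense activity. A quick illustration (standalone re-implementation of the sum-of-squares score):

```python
import numpy as np

def goodness(activations):
    # sum of squared activations, as in CSDPRule.goodness
    return float(np.sum(np.asarray(activations) ** 2))

dense = np.array([1.0, 1.0, 1.0, 1.0])    # strongly driven layer
sparse = np.array([1.0, 0.0, 0.0, 0.0])   # weakly driven layer
print(goodness(dense), goodness(sparse))  # 4.0 1.0
```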

Source code in src/sc_neurocore/contrastive/ssl.py
def goodness(self, activations: np.ndarray) -> float:
    """Compute 'goodness' score (sum of squared activations).

    Positive data should have high goodness, negative data low.
    """
    return float(np.sum(activations**2))