
World Model — Spike Prediction + Planning

Three components: (1) an online-learnable spike predictor for codec integration, (2) a linear state-transition model with a randomly initialized, row-normalized transition matrix, (3) a greedy action planner.

SpikePredictor — Online Autoregressive Codec

The core workhorse. Predicts multi-channel spike patterns from recent history using a linear autoregressive model with a sigmoid output, trained online with an LMS-style (Least Mean Squares) update. No backprop, no batches: weights are updated one sample at a time.

Codec integration: the encoder and decoder each maintain an identical SpikePredictor instance and feed it the same history. The encoder transmits the prediction error (XOR of actual and predicted spikes); the decoder XORs the received error with its own prediction to recover the original. Because the predictor is deterministic (same seed and same history → same prediction), the roundtrip is lossless.
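The XOR scheme described above can be sketched without the library. The block below is a minimal illustration, not the implementation behind predict_and_xor_world_model: LastFramePredictor, encode and decode are hypothetical names, and the predictor is a trivial stand-in (it assumes each channel repeats its previous value). Any deterministic predictor that is updated with the same frames on both sides gives the same lossless property.

```python
import numpy as np


class LastFramePredictor:
    """Hypothetical stand-in predictor: each channel repeats its last value."""

    def __init__(self, n_channels: int) -> None:
        self.prev = np.zeros(n_channels, dtype=np.int8)

    def predict(self) -> np.ndarray:
        return self.prev.copy()

    def update(self, actual: np.ndarray) -> None:
        self.prev = actual.astype(np.int8)


def encode(spikes: np.ndarray) -> np.ndarray:
    """Encoder side: transmit XOR(actual, predicted) for every timestep."""
    predictor = LastFramePredictor(spikes.shape[1])
    errors = np.empty_like(spikes)
    for t, frame in enumerate(spikes):
        errors[t] = frame ^ predictor.predict()
        predictor.update(frame)  # learn from the true frame
    return errors


def decode(errors: np.ndarray) -> np.ndarray:
    """Decoder side: XOR the error with the same prediction to recover the frame."""
    predictor = LastFramePredictor(errors.shape[1])
    spikes = np.empty_like(errors)
    for t, err in enumerate(errors):
        spikes[t] = err ^ predictor.predict()
        predictor.update(spikes[t])  # recovered frame equals the true frame
    return spikes


rng = np.random.default_rng(0)
spikes = (rng.random((100, 32)) < 0.3).astype(np.int8)
errors = encode(spikes)
recovered = decode(errors)
roundtrip_ok = bool(np.array_equal(spikes, recovered))
```

The roundtrip holds because the decoder updates its predictor with the recovered frame, which is identical to the frame the encoder used, so the two predictors never diverge. A better predictor does not change correctness; it only makes the error stream sparser.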

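| Parameter | Default | Meaning |
|---|---|---|
| n_channels | (required) | Number of spike channels |
| history_len | 8 | Context window (K past timesteps) |
| lr | 0.01 | LMS learning rate |
| threshold | 0.5 | Binary prediction threshold |
| seed | 42 | RNG seed for weight initialization |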

Codec functions:

  • predict_and_xor_world_model(spikes, n_channels, ...) → (errors, correct_count) — Encoder
  • xor_and_recover_world_model(errors, n_channels, ...) → spikes — Decoder

PredictiveWorldModel — State Transitions

Linear transition model: state_next = clip(T @ [state; action], 0, 1) where T is a row-normalized transition matrix. Provides predict_next_state() and forecast() for multi-step rollouts.
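As a rough illustration of the transition rule above, the following standalone sketch rebuilds a row-normalized matrix with NumPy and rolls it forward over a few actions. It does not use the library; the variable names and the seeded generator are assumptions made for the example. It also records whether the clip ever changes a value.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make the sketch reproducible
state_dim, action_dim = 4, 2

# Random non-negative weights, each row scaled so that it sums to 1.
T = rng.uniform(0, 1, (state_dim, state_dim + action_dim))
T /= T.sum(axis=1, keepdims=True)

state = np.array([0.1, 0.2, 0.3, 0.4])
actions = [rng.uniform(0, 1, action_dim) for _ in range(3)]

trajectory = []
clip_was_noop = True
for action in actions:
    raw = T @ np.concatenate([state, action])  # linear step on [state; action]
    state = np.clip(raw, 0, 1)                 # keep values in [0, 1]
    clip_was_noop = clip_was_noop and bool(np.array_equal(raw, state))
    trajectory.append(state)
```

Because every row of T is non-negative and sums to 1, each component of the next state is a weighted average of the inputs. When the state and action already lie in [0, 1], the result stays in that range, so the clip acts only as a safeguard against out-of-range inputs.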

SCPlanner — Greedy Action Selection

Uses PredictiveWorldModel for random-shooting planning: sample N candidate actions, predict outcomes, pick the one closest to the goal state.

  • propose_action(current, goal, n_candidates) — Best single action
  • plan_sequence(current, goal, horizon) — Greedy multi-step plan
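The two planner methods above can be approximated in a few lines of NumPy. This is a self-contained sketch under stated assumptions: step is a hypothetical stand-in for the world model's one-step prediction, and the seeded generator is used only so that runs are repeatable.

```python
import numpy as np

rng = np.random.default_rng(0)
state_dim, action_dim = 4, 2

# Stand-in world model: row-normalized random linear transition.
T = rng.uniform(0, 1, (state_dim, state_dim + action_dim))
T /= T.sum(axis=1, keepdims=True)


def step(state: np.ndarray, action: np.ndarray) -> np.ndarray:
    return np.clip(T @ np.concatenate([state, action]), 0, 1)


def propose_action(state, goal, n_candidates=10):
    """Random shooting: sample candidates, keep the one whose outcome is nearest the goal."""
    candidates = rng.uniform(0, 1, (n_candidates, action_dim))
    dists = [np.linalg.norm(step(state, a) - goal) for a in candidates]
    return candidates[int(np.argmin(dists))]


def plan_sequence(state, goal, horizon=5):
    """Greedy rollout: pick the best one-step action, advance the model, repeat."""
    plan = []
    for _ in range(horizon):
        action = propose_action(state, goal)
        plan.append(action)
        state = step(state, action)
    return plan, state


goal = np.array([0.9, 0.8, 0.7, 0.6])
plan, final_state = plan_sequence(np.array([0.1, 0.2, 0.3, 0.4]), goal)
```

Each step is scored only on its one-step distance to the goal, so the plan is greedy: whole trajectories are never compared, and an action that looks worse now but leads to a better state later will not be chosen.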

Usage

from sc_neurocore.world_model import SpikePredictor
from sc_neurocore.world_model.spike_predictor import (
    predict_and_xor_world_model,
    xor_and_recover_world_model,
)
import numpy as np

# Lossless codec roundtrip
spikes = (np.random.rand(100, 32) < 0.3).astype(np.int8)
errors, correct = predict_and_xor_world_model(spikes, n_channels=32)
recovered = xor_and_recover_world_model(errors, n_channels=32)
assert np.array_equal(spikes, recovered)  # Always true
print(f"Prediction accuracy: {correct / (100 * 32):.1%}")

# Planning
from sc_neurocore.world_model import PredictiveWorldModel, SCPlanner
model = PredictiveWorldModel(state_dim=4, action_dim=2)
planner = SCPlanner(world_model=model)
plan = planner.plan_sequence(
    current_state=np.array([0.1, 0.2, 0.3, 0.4]),
    goal_state=np.array([0.9, 0.8, 0.7, 0.6]),
    horizon=5,
)

sc_neurocore.world_model

sc_neurocore.world_model -- Tier: research (experimental / research).

SCPlanner dataclass

A planner that uses a PredictiveWorldModel to select actions.

Source code in src/sc_neurocore/world_model/planner.py
@dataclass
class SCPlanner:
    """
    A planner that uses a PredictiveWorldModel to select actions.
    """

    world_model: PredictiveWorldModel

    def propose_action(
        self,
        current_state: np.ndarray[Any, Any],
        goal_state: np.ndarray[Any, Any],
        n_candidates: int = 10,
    ) -> np.ndarray[Any, Any]:
        """
        Propose the best action among n_candidates based on predicted outcome.
        """
        best_action = None
        min_dist = float("inf")

        for _ in range(n_candidates):
            # Sample a random action
            candidate_action = np.random.uniform(0, 1, self.world_model.action_dim)

            # Predict next state
            predicted_state = self.world_model.predict_next_state(current_state, candidate_action)

            # Evaluate distance to goal
            dist = np.linalg.norm(predicted_state - goal_state)

            if dist < min_dist:
                min_dist = dist
                best_action = candidate_action

        return best_action

    def plan_sequence(
        self,
        current_state: np.ndarray[Any, Any],
        goal_state: np.ndarray[Any, Any],
        horizon: int = 5,
    ) -> List[np.ndarray[Any, Any]]:
        """
        Simple greedy planning for a sequence of actions.
        """
        plan = []
        curr_s = current_state
        for _ in range(horizon):
            action = self.propose_action(curr_s, goal_state)
            plan.append(action)
            curr_s = self.world_model.predict_next_state(curr_s, action)
        return plan

propose_action(current_state, goal_state, n_candidates=10)

Propose the best action among n_candidates based on predicted outcome.

Source code in src/sc_neurocore/world_model/planner.py
def propose_action(
    self,
    current_state: np.ndarray[Any, Any],
    goal_state: np.ndarray[Any, Any],
    n_candidates: int = 10,
) -> np.ndarray[Any, Any]:
    """
    Propose the best action among n_candidates based on predicted outcome.
    """
    best_action = None
    min_dist = float("inf")

    for _ in range(n_candidates):
        # Sample a random action
        candidate_action = np.random.uniform(0, 1, self.world_model.action_dim)

        # Predict next state
        predicted_state = self.world_model.predict_next_state(current_state, candidate_action)

        # Evaluate distance to goal
        dist = np.linalg.norm(predicted_state - goal_state)

        if dist < min_dist:
            min_dist = dist
            best_action = candidate_action

    return best_action

plan_sequence(current_state, goal_state, horizon=5)

Simple greedy planning for a sequence of actions.

Source code in src/sc_neurocore/world_model/planner.py
def plan_sequence(
    self,
    current_state: np.ndarray[Any, Any],
    goal_state: np.ndarray[Any, Any],
    horizon: int = 5,
) -> List[np.ndarray[Any, Any]]:
    """
    Simple greedy planning for a sequence of actions.
    """
    plan = []
    curr_s = current_state
    for _ in range(horizon):
        action = self.propose_action(curr_s, goal_state)
        plan.append(action)
        curr_s = self.world_model.predict_next_state(curr_s, action)
    return plan

PredictiveWorldModel dataclass

A stochastic predictive world model. Predicts state_next = f(state_curr, action).

Source code in src/sc_neurocore/world_model/predictive_model.py
@dataclass
class PredictiveWorldModel:
    """
    A stochastic predictive world model.
    Predicts state_next = f(state_curr, action).
    """

    state_dim: int
    action_dim: int

    def __post_init__(self) -> None:
        # Internal transition weights (simplified)
        self.transition_matrix = np.random.uniform(
            0, 1, (self.state_dim, self.state_dim + self.action_dim)
        )
        # Normalize rows to represent probabilities
        row_sums = self.transition_matrix.sum(axis=1)
        self.transition_matrix /= row_sums[:, np.newaxis]

    def predict_next_state(
        self, current_state: np.ndarray[Any, Any], action: np.ndarray[Any, Any]
    ) -> np.ndarray[Any, Any]:
        """
        Predicts the next state given current state and action.
        Inputs:
            current_state: (state_dim,) array of probabilities.
            action: (action_dim,) array of probabilities.
        Returns:
            next_state: (state_dim,) predicted probabilities.
        """
        # Concatenate state and action
        combined_input = np.concatenate([current_state, action])

        # Linear transition in probability domain
        next_state = np.dot(self.transition_matrix, combined_input)

        # Clip to ensure valid probabilities
        return np.clip(next_state, 0, 1)

    def forecast(
        self, initial_state: np.ndarray[Any, Any], actions: list[np.ndarray[Any, Any]]
    ) -> list[np.ndarray[Any, Any]]:
        """
        Forecast multiple steps ahead given a sequence of actions.
        """
        trajectory = []
        curr_state = initial_state
        for act in actions:
            curr_state = self.predict_next_state(curr_state, act)
            trajectory.append(curr_state)
        return trajectory

predict_next_state(current_state, action)

Predicts the next state given current state and action.

Inputs:
    current_state: (state_dim,) array of probabilities.
    action: (action_dim,) array of probabilities.

Returns:
    next_state: (state_dim,) predicted probabilities.

Source code in src/sc_neurocore/world_model/predictive_model.py
def predict_next_state(
    self, current_state: np.ndarray[Any, Any], action: np.ndarray[Any, Any]
) -> np.ndarray[Any, Any]:
    """
    Predicts the next state given current state and action.
    Inputs:
        current_state: (state_dim,) array of probabilities.
        action: (action_dim,) array of probabilities.
    Returns:
        next_state: (state_dim,) predicted probabilities.
    """
    # Concatenate state and action
    combined_input = np.concatenate([current_state, action])

    # Linear transition in probability domain
    next_state = np.dot(self.transition_matrix, combined_input)

    # Clip to ensure valid probabilities
    return np.clip(next_state, 0, 1)

forecast(initial_state, actions)

Forecast multiple steps ahead given a sequence of actions.

Source code in src/sc_neurocore/world_model/predictive_model.py
def forecast(
    self, initial_state: np.ndarray[Any, Any], actions: list[np.ndarray[Any, Any]]
) -> list[np.ndarray[Any, Any]]:
    """
    Forecast multiple steps ahead given a sequence of actions.
    """
    trajectory = []
    curr_state = initial_state
    for act in actions:
        curr_state = self.predict_next_state(curr_state, act)
        trajectory.append(curr_state)
    return trajectory

SpikePredictor dataclass

Online autoregressive spike pattern predictor.

Learns to predict spike[t] from spike[t-K:t] per channel. Weight matrix W of shape (N, N*K) maps flattened history to per-channel firing probabilities. Binary prediction via threshold.

Training: LMS update after each timestep:

    W += lr * outer(error, history)

where error = actual - predicted_prob.
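To make the update rule concrete, here is a single training step written out with plain NumPy. It is a sketch with made-up sizes and inputs (3 channels, K = 2), not a call into the library; sigmoid, history and actual are names chosen for the example.

```python
import numpy as np

n_channels, history_len, lr = 3, 2, 0.01
rng = np.random.default_rng(0)

# Small random weights and zero bias, mirroring the shapes (N, N*K) and (N,).
W = rng.standard_normal((n_channels, n_channels * history_len)) * 0.01
bias = np.zeros(n_channels)


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))


# Flattened history (oldest timestep first) and the frame actually observed.
history = np.array([1, 0, 1, 0, 1, 1], dtype=np.float64)
actual = np.array([1, 0, 1], dtype=np.float64)

probs_before = sigmoid(W @ history + bias)
error = actual - probs_before

# One LMS-style step: move each channel's weights along its own error.
W += lr * np.outer(error, history)
bias += lr * error

probs_after = sigmoid(W @ history + bias)
```

After the step, every channel's probability has moved slightly toward the observed value for this input. With a sigmoid output, this update has the same form as a stochastic gradient step on the per-sample cross-entropy loss, which is one way to read why the rule works for binary spikes.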

Parameters

n_channels : int
    Number of spike channels.
history_len : int
    Number of past timesteps to use as context (K).
lr : float
    LMS learning rate.
threshold : float
    Probability threshold for binary prediction.
seed : int
    RNG seed for weight initialization.

Source code in src/sc_neurocore/world_model/spike_predictor.py
@dataclass
class SpikePredictor:
    """Online autoregressive spike pattern predictor.

    Learns to predict spike[t] from spike[t-K:t] per channel.
    Weight matrix W of shape (N, N*K) maps flattened history to
    per-channel firing probabilities. Binary prediction via threshold.

    Training: LMS update after each timestep.
        W += lr * outer(error, history)
    where error = actual - predicted_prob.

    Parameters
    ----------
    n_channels : int
        Number of spike channels.
    history_len : int
        Number of past timesteps to use as context (K).
    lr : float
        LMS learning rate.
    threshold : float
        Probability threshold for binary prediction.
    seed : int
        RNG seed for weight initialization.
    """

    n_channels: int
    history_len: int = 8
    lr: float = 0.01
    threshold: float = 0.5
    seed: int = 42

    def __post_init__(self):
        rng = np.random.RandomState(self.seed)
        n_features = self.n_channels * self.history_len
        # Small random weights — predict from history
        self.W = rng.randn(self.n_channels, n_features) * 0.01
        self.bias = np.zeros(self.n_channels)
        # Circular buffer for history
        self._history = np.zeros((self.history_len, self.n_channels), dtype=np.float64)
        self._t = 0

    def _features(self) -> np.ndarray:
        """Flatten history buffer into feature vector."""
        # Ordered: oldest first
        indices = [(self._t + i) % self.history_len for i in range(self.history_len)]
        return self._history[indices].ravel()

    def predict_probs(self) -> np.ndarray:
        """Predict per-channel firing probabilities from history."""
        features = self._features()
        logits = self.W @ features + self.bias
        # Sigmoid activation
        probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -20, 20)))
        return probs

    def predict(self) -> np.ndarray:
        """Predict binary spike pattern."""
        return (self.predict_probs() > self.threshold).astype(np.int8)

    def update(self, actual: np.ndarray):
        """Update weights with observed spike pattern (LMS rule).

        Parameters
        ----------
        actual : ndarray of shape (n_channels,), binary
        """
        features = self._features()
        probs = self.predict_probs()
        error = actual.astype(np.float64) - probs

        # LMS weight update
        self.W += self.lr * np.outer(error, features)
        self.bias += self.lr * error

        # Push actual into history buffer
        self._history[self._t % self.history_len] = actual.astype(np.float64)
        self._t += 1

    def reset(self):
        """Reset to initial state (same seed → same weights)."""
        self.__post_init__()
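The circular-buffer indexing in _features and update can be checked in isolation. The sketch below copies only the index arithmetic into standalone code (K, history, t and features are local names for the example) and writes five frames into a buffer of length 3.

```python
import numpy as np

K, n_channels = 3, 2
history = np.zeros((K, n_channels))
t = 0


def features() -> np.ndarray:
    # Slot t % K holds the oldest frame: it is the next one to be overwritten.
    indices = [(t + i) % K for i in range(K)]
    return history[indices].ravel()


# Push frames whose values are 1, 2, 3, 4, 5 on every channel.
for value in range(1, 6):
    history[t % K] = value
    t += 1

ordered = features()
# ordered == [3., 3., 4., 4., 5., 5.]  (oldest surviving frame first)
```

Frames 1 and 2 have been overwritten, and the remaining three come out oldest first regardless of where the write pointer currently sits, which is what keeps the feature layout consistent between updates.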

predict_probs()

Predict per-channel firing probabilities from history.

Source code in src/sc_neurocore/world_model/spike_predictor.py
def predict_probs(self) -> np.ndarray:
    """Predict per-channel firing probabilities from history."""
    features = self._features()
    logits = self.W @ features + self.bias
    # Sigmoid activation
    probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -20, 20)))
    return probs

predict()

Predict binary spike pattern.

Source code in src/sc_neurocore/world_model/spike_predictor.py
def predict(self) -> np.ndarray:
    """Predict binary spike pattern."""
    return (self.predict_probs() > self.threshold).astype(np.int8)
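One consequence of combining a sigmoid with the default threshold of 0.5 is worth spelling out: the binary prediction depends only on the sign of the logit. The following standalone sketch, using example logits chosen for illustration, compares the two routes.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    # Same clipped form as predict_probs: logits are limited to [-20, 20].
    return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))


logits = np.array([-25.0, -1.0, 0.0, 0.5, 30.0])
probs = sigmoid(logits)

spikes_from_probs = (probs > 0.5).astype(np.int8)    # threshold the probability
spikes_from_logits = (logits > 0.0).astype(np.int8)  # threshold the raw logit
# Both give [0, 0, 0, 1, 1]
```

A logit of exactly 0 maps to a probability of exactly 0.5, which is not strictly greater than the threshold, so that channel is predicted silent. The clip keeps the probabilities strictly between 0 and 1 even for extreme logits. For thresholds other than 0.5 the equivalent logit cutoff shifts accordingly.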

update(actual)

Update weights with observed spike pattern (LMS rule).

Parameters

actual : ndarray of shape (n_channels,), binary

Source code in src/sc_neurocore/world_model/spike_predictor.py
def update(self, actual: np.ndarray):
    """Update weights with observed spike pattern (LMS rule).

    Parameters
    ----------
    actual : ndarray of shape (n_channels,), binary
    """
    features = self._features()
    probs = self.predict_probs()
    error = actual.astype(np.float64) - probs

    # LMS weight update
    self.W += self.lr * np.outer(error, features)
    self.bias += self.lr * error

    # Push actual into history buffer
    self._history[self._t % self.history_len] = actual.astype(np.float64)
    self._t += 1

reset()

Reset to initial state (same seed → same weights).

Source code in src/sc_neurocore/world_model/spike_predictor.py
def reset(self):
    """Reset to initial state (same seed → same weights)."""
    self.__post_init__()