World Model — Spike Prediction + Planning

Three components: (1) an online-trained spike predictor for codec integration, (2) a probabilistic state-transition model (a linear Gaussian state-space model), (3) a greedy action planner.

SpikePredictor — Online Autoregressive Codec

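The core workhorse. Predicts multi-channel spike patterns from recent history with an autoregressive model: a linear map over the last K timesteps of all channels, passed through a sigmoid to give per-channel firing probabilities. It is trained online with an LMS (Least Mean Squares) delta rule. No backprop, no batches: the weights update one sample at a time.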

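Codec integration: the encoder and decoder each maintain an identical SpikePredictor instance (same parameters and seed) and feed it the same history. The encoder transmits only the prediction error, i.e. the XOR of the actual and predicted spike patterns. The decoder XORs the received error with its own prediction to recover the original pattern, then updates its predictor with that pattern exactly as the encoder did. Because prediction is deterministic (same weights and same history give the same prediction), the roundtrip is lossless.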

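| Parameter | Default | Meaning |
|---|---|---|
| n_channels | (required) | Number of spike channels |
| history_len | 8 | Context window (K past timesteps) |
| lr | 0.01 | LMS learning rate |
| threshold | 0.5 | Probability threshold for binary prediction |
| seed | 42 | RNG seed for weight initialization |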

Codec functions:

  • predict_and_xor_world_model(spikes, n_channels, ...) → (errors, correct_count) — Encoder
  • xor_and_recover_world_model(errors, n_channels, ...) → spikes — Decoder
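The encoder/decoder contract can be illustrated with a self-contained sketch. It does not call sc_neurocore. `encode`, `decode` and `RepeatLastPredictor` are hypothetical names, and the stand-in predictor simply guesses that the previous frame repeats. The point is that any deterministic predictor works losslessly as long as both sides update it with the same frames. In this sketch the correct count is the number of zero error bits. Whether predict_and_xor_world_model counts correct predictions the same way is an assumption.

```python
import numpy as np


class RepeatLastPredictor:
    """Hypothetical stand-in for SpikePredictor: predicts the previous frame."""

    def __init__(self, n_channels: int) -> None:
        self.last = np.zeros(n_channels, dtype=np.int8)

    def predict(self) -> np.ndarray:
        return self.last.copy()

    def update(self, actual: np.ndarray) -> None:
        self.last = actual.astype(np.int8).copy()


def encode(spikes: np.ndarray) -> tuple[np.ndarray, int]:
    """Encoder: transmit XOR(actual, predicted) for every timestep."""
    pred = RepeatLastPredictor(spikes.shape[1])
    errors = np.empty_like(spikes)
    for t, frame in enumerate(spikes):
        errors[t] = frame ^ pred.predict()
        pred.update(frame)  # learn from the true frame
    correct = int(errors.size - errors.sum())  # zero error bit = correct prediction
    return errors, correct


def decode(errors: np.ndarray) -> np.ndarray:
    """Decoder: XOR each error frame with the same prediction to recover it."""
    pred = RepeatLastPredictor(errors.shape[1])
    spikes = np.empty_like(errors)
    for t, err in enumerate(errors):
        spikes[t] = err ^ pred.predict()
        pred.update(spikes[t])  # identical update keeps both sides in sync
    return spikes


rng = np.random.default_rng(0)
spikes = (rng.random((50, 8)) < 0.3).astype(np.int8)
errors, correct = encode(spikes)
assert np.array_equal(decode(errors), spikes)
```

Swapping RepeatLastPredictor for a learned predictor changes only how sparse `errors` is, not whether the roundtrip is lossless.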

PredictiveWorldModel — Linear Gaussian State-Space

Probabilistic predictive model implemented as a Linear Gaussian State-Space Model (LGSSM) with Kalman filter (forward), RTS smoother (backward), and EM parameter learner. References: Kalman 1960, Rauch-Tung-Striebel 1965, Shumway & Stoffer 1982, Bishop 2006 §13.3.
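The forward (Kalman) pass alternates a predict step and a measurement update. The sketch below shows one textbook Kalman step in plain NumPy. `kalman_step` is a hypothetical helper, not the library's KalmanFilter API. A, B and Q follow the names PredictiveWorldModel uses (model.A, model.B, model.Q). C and R are assumed names for the observation matrix and observation noise.

```python
import numpy as np


def kalman_step(mu, Sigma, y, u, A, B, Q, C, R):
    """One textbook Kalman filter step: predict with (A, B, Q), then correct with y."""
    # Predict: push the mean through the dynamics and add process noise Q.
    mu_pred = A @ mu + B @ u
    Sigma_pred = A @ Sigma @ A.T + Q
    # Update: the gain K = Sigma_pred C^T S^-1 weighs the innovation y - C mu_pred.
    S = C @ Sigma_pred @ C.T + R
    K = np.linalg.solve(S, C @ Sigma_pred).T  # relies on S and Sigma_pred being symmetric
    mu_new = mu_pred + K @ (y - C @ mu_pred)
    Sigma_new = (np.eye(len(mu)) - K @ C) @ Sigma_pred
    return mu_new, Sigma_new


A = np.array([[1.0, 0.1], [0.0, 1.0]])
B = np.array([[0.0], [0.1]])
Q = 0.01 * np.eye(2)
C = np.eye(2)  # observe the state directly, as PredictiveWorldModel does
R = 0.1 * np.eye(2)
mu, Sigma = kalman_step(np.zeros(2), np.eye(2), y=np.array([1.0, 0.5]),
                        u=np.array([1.0]), A=A, B=B, Q=Q, C=C, R=R)
```

Solving against S instead of forming its inverse explicitly is the usual choice for numerical stability.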

Provides predict_next_state() (deterministic mean), predict_next_state_with_cov() (mean + covariance), forecast() / forecast_with_cov() for multi-step rollouts.
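The mean/covariance rollout behind forecast_with_cov reduces to two recurrences, mu_{t+1} = A mu_t + B u_t and Sigma_{t+1} = A Sigma_t A^T + Q, matching predict_next_state_with_cov. The standalone sketch below uses a hypothetical `rollout` function and arbitrary example matrices.

```python
import numpy as np


def rollout(mu0, Sigma0, actions, A, B, Q):
    """Open-loop forecast: no observations, so the covariance is never corrected by data."""
    mu, Sigma = mu0.astype(float), Sigma0.astype(float)
    traj = []
    for u in actions:
        mu = A @ mu + B @ np.atleast_1d(u)  # mean recurrence
        Sigma = A @ Sigma @ A.T + Q         # covariance recurrence
        traj.append((mu.copy(), Sigma.copy()))
    return traj


A = np.array([[0.9, 0.0], [0.1, 0.8]])
B = np.array([[1.0], [0.0]])
Q = 0.05 * np.eye(2)
traj = rollout(np.zeros(2), np.zeros((2, 2)), [0.5, 0.5, 0.5], A, B, Q)
```

Starting from a known state (zero covariance), the trace of Sigma grows at every step of this example.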

The previous "linear transition matrix + clip-to-[0,1]" implementation was a placeholder masquerading as a world model; it was replaced on 2026-04-17 per feedback_sophisticated_from_start.md. See the Predictive Model detail page for the full LGSSM + Kalman + RTS + EM derivations, performance numbers, and the multi-language backend status.

SCPlanner — Greedy Action Selection

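Uses PredictiveWorldModel for random-shooting planning. It samples N candidate actions uniformly from [0, 1)^action_dim and predicts the mean next state for each. It then picks the action whose prediction is closest (Euclidean distance) to the goal state. Planning is one-step greedy: plan_sequence repeats this choice from each predicted state.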

  • propose_action(current_state, goal_state, n_candidates=10): best single action
  • plan_sequence(current_state, goal_state, horizon=5): greedy multi-step plan (uses the default n_candidates=10 at each step)
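The sketch below shows the same random-shooting idea in self-contained form. It uses a hypothetical linear mean model x' = A x + B u in place of PredictiveWorldModel. It also uses a seeded NumPy Generator instead of the global RNG that SCPlanner draws from. `shoot_action` and `greedy_plan` are hypothetical names. Candidates are drawn from [0, 1)^action_dim as in propose_action.

```python
import numpy as np


def shoot_action(x, goal, A, B, rng, n_candidates=10):
    """Random shooting: keep the candidate whose predicted next state is nearest the goal."""
    candidates = rng.uniform(0.0, 1.0, size=(n_candidates, B.shape[1]))
    best, best_dist = None, np.inf
    for u in candidates:
        dist = np.linalg.norm(A @ x + B @ u - goal)
        if dist < best_dist:
            best, best_dist = u, dist
    return best


def greedy_plan(x, goal, A, B, rng, horizon=5):
    """Greedy plan: re-run one-step random shooting from each predicted state."""
    plan = []
    for _ in range(horizon):
        u = shoot_action(x, goal, A, B, rng)
        plan.append(u)
        x = A @ x + B @ u  # roll the mean model forward with the chosen action
    return plan


A = np.eye(2)
B = 0.1 * np.eye(2)
rng = np.random.default_rng(0)
x0, goal = np.zeros(2), np.ones(2)
plan = greedy_plan(x0, goal, A, B, rng)
```

Each step optimizes only the next predicted state. Nothing looks further ahead, which is what makes the planner greedy.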

Usage

```python
import numpy as np

from sc_neurocore.world_model import PredictiveWorldModel, SCPlanner, SpikePredictor
from sc_neurocore.world_model.spike_predictor import (
    predict_and_xor_world_model,
    xor_and_recover_world_model,
)

np.random.seed(0)  # SCPlanner samples candidate actions from the global NumPy RNG

# Lossless codec roundtrip
spikes = (np.random.rand(100, 32) < 0.3).astype(np.int8)
errors, correct = predict_and_xor_world_model(spikes, n_channels=32)
recovered = xor_and_recover_world_model(errors, n_channels=32)
assert np.array_equal(spikes, recovered)  # lossless for any input
print(f"Prediction accuracy: {correct / spikes.size:.1%}")

# Planning with a randomly initialized LGSSM (LinearGaussianSSM.random, seed=42)
model = PredictiveWorldModel(state_dim=4, action_dim=2)
planner = SCPlanner(world_model=model)
plan = planner.plan_sequence(
    current_state=np.array([0.1, 0.2, 0.3, 0.4]),
    goal_state=np.array([0.9, 0.8, 0.7, 0.6]),
    horizon=5,
)
# plan: list of 5 action vectors, each of shape (action_dim,) == (2,)
```

sc_neurocore.world_model

sc_neurocore.world_model -- Tier: research (experimental / research).

SpikePredictor dataclass

Online autoregressive spike pattern predictor.

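Learns to predict spike[t] on every channel from spike[t-K:t] on all channels. A weight matrix W of shape (N, N*K) and a bias vector map the flattened history to per-channel firing probabilities through a sigmoid. The binary prediction thresholds those probabilities.

Training: LMS update after each timestep.

W += lr * outer(error, history)
bias += lr * error

where error = actual - predicted_prob and history is the flattened, oldest-first context vector.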
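The update rule can be traced end to end in a few lines. Below is a minimal re-implementation of the stated math: sigmoid output, a per-sample LMS step on W and the bias, and an oldest-first circular history. It is not the library class, and `TinySpikePredictor` is a hypothetical name. With a sigmoid output, this delta rule is the same step as stochastic gradient descent on the logistic (cross-entropy) loss. On a strictly periodic pattern whose period fits in the history window, accuracy should reach close to 100%.

```python
import numpy as np


class TinySpikePredictor:
    """Hypothetical minimal sketch of an online sigmoid/LMS spike predictor."""

    def __init__(self, n_channels, history_len=8, lr=0.01, threshold=0.5, seed=42):
        rng = np.random.RandomState(seed)
        self.K, self.lr, self.threshold = history_len, lr, threshold
        self.W = rng.randn(n_channels, n_channels * history_len) * 0.01
        self.b = np.zeros(n_channels)
        self.hist = np.zeros((history_len, n_channels))
        self.t = 0

    def features(self):
        # Slot t % K holds the oldest frame, so this order is oldest -> newest.
        idx = [(self.t + i) % self.K for i in range(self.K)]
        return self.hist[idx].ravel()

    def probs(self):
        logits = np.clip(self.W @ self.features() + self.b, -20, 20)
        return 1.0 / (1.0 + np.exp(-logits))

    def predict(self):
        return (self.probs() > self.threshold).astype(np.int8)

    def update(self, actual):
        x = self.features()
        err = actual - self.probs()           # error = actual - predicted_prob
        self.W += self.lr * np.outer(err, x)  # LMS step on the weights
        self.b += self.lr * err               # and on the bias
        self.hist[self.t % self.K] = actual   # overwrite the oldest slot
        self.t += 1


# Period-4 pattern: channel 0 fires when t % 4 == 0, channel 1 when t % 4 == 2.
pattern = np.array([[1, 0], [0, 0], [0, 1], [0, 0]])
p = TinySpikePredictor(n_channels=2, history_len=4, lr=0.5)
hits = []
for t in range(400):
    frame = pattern[t % 4]
    hits.append(np.array_equal(p.predict(), frame))  # predict before seeing the frame
    p.update(frame)
late_accuracy = np.mean(hits[-40:])
```

The demo uses lr=0.5 rather than the 0.01 default so that it converges within a few hundred steps.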

Parameters

  • n_channels (int): Number of spike channels.
  • history_len (int): Number of past timesteps to use as context (K).
  • lr (float): LMS learning rate.
  • threshold (float): Probability threshold for binary prediction.
  • seed (int): RNG seed for weight initialization.

Source code in src/sc_neurocore/world_model/spike_predictor.py
```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SpikePredictor:
    """Online autoregressive spike pattern predictor.

    Learns to predict spike[t] from spike[t-K:t] per channel.
    Weight matrix W of shape (N, N*K) maps flattened history to
    per-channel firing probabilities. Binary prediction via threshold.

    Training: LMS update after each timestep.
        W += lr * outer(error, history)
    where error = actual - predicted_prob.

    Parameters
    ----------
    n_channels : int
        Number of spike channels.
    history_len : int
        Number of past timesteps to use as context (K).
    lr : float
        LMS learning rate.
    threshold : float
        Probability threshold for binary prediction.
    seed : int
        RNG seed for weight initialization.
    """

    n_channels: int
    history_len: int = 8
    lr: float = 0.01
    threshold: float = 0.5
    seed: int = 42

    def __post_init__(self) -> None:
        rng = np.random.RandomState(self.seed)
        n_features = self.n_channels * self.history_len
        # Small random weights: predict from history
        self.W = rng.randn(self.n_channels, n_features) * 0.01
        self.bias = np.zeros(self.n_channels)
        # Circular buffer for history
        self._history = np.zeros((self.history_len, self.n_channels), dtype=np.float64)
        self._t = 0

    def _features(self) -> np.ndarray:
        """Flatten history buffer into feature vector."""
        # Ordered: oldest first
        indices = [(self._t + i) % self.history_len for i in range(self.history_len)]
        return self._history[indices].ravel()

    def predict_probs(self) -> np.ndarray:
        """Predict per-channel firing probabilities from history."""
        features = self._features()
        logits = self.W @ features + self.bias
        # Sigmoid activation
        probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -20, 20)))
        return probs

    def predict(self) -> np.ndarray:
        """Predict binary spike pattern."""
        return (self.predict_probs() > self.threshold).astype(np.int8)

    def update(self, actual: np.ndarray) -> None:
        """Update weights with observed spike pattern (LMS rule).

        Parameters
        ----------
        actual : ndarray of shape (n_channels,), binary
        """
        # The error is measured against the prediction from the current
        # history, i.e. before `actual` is pushed into the buffer.
        features = self._features()
        probs = self.predict_probs()
        error = actual.astype(np.float64) - probs

        # LMS weight update
        self.W += self.lr * np.outer(error, features)
        self.bias += self.lr * error

        # Push actual into history buffer
        self._history[self._t % self.history_len] = actual.astype(np.float64)
        self._t += 1

    def reset(self) -> None:
        """Reset to initial state (same seed → same weights)."""
        self.__post_init__()
```

predict_probs()

Predict per-channel firing probabilities from history.
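predict_probs clips the logits to [-20, 20] before applying the sigmoid. Without the clip, np.exp(-x) overflows for very negative x (float64 overflows above roughly 709), which emits a RuntimeWarning and collapses the probability to exactly 0. With the clip, every output stays strictly inside (0, 1), at worst about 2e-9 from either end. The helper names below are hypothetical.

```python
import numpy as np


def sigmoid_clipped(logits: np.ndarray) -> np.ndarray:
    """Sigmoid with the same [-20, 20] clip used by predict_probs."""
    return 1.0 / (1.0 + np.exp(-np.clip(logits, -20, 20)))


def sigmoid_unclipped(logits: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-logits))


logits = np.array([-1000.0, -20.0, 0.0, 20.0, 1000.0])
safe = sigmoid_clipped(logits)
with np.errstate(over="ignore"):  # exp(1000) overflows to inf in the unclipped version
    raw = sigmoid_unclipped(logits)
```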

predict()

Predict binary spike pattern.

update(actual)

Update weights with observed spike pattern (LMS rule).

Parameters

actual : ndarray of shape (n_channels,), binary
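update writes each observed frame into slot _t % history_len, overwriting the oldest entry. _features then reads the slots starting at _t % history_len, so the flattened vector is ordered oldest first. The standalone check below exercises that indexing with K = 3 and hypothetical variable names.

```python
import numpy as np

K, n_channels = 3, 1
history = np.zeros((K, n_channels))
t = 0
for value in [10, 20, 30, 40, 50]:  # push five frames into a 3-slot ring buffer
    history[t % K] = value          # overwrite the oldest slot
    t += 1

# Read starting at the slot that will be overwritten next, i.e. the oldest frame.
order = [(t + i) % K for i in range(K)]
features = history[order].ravel()  # → [30.0, 40.0, 50.0]
```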

reset()

Reset to initial state (same seed → same weights).
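reset() re-runs __post_init__, which re-seeds np.random.RandomState(seed). That is why two predictors with the same seed, such as the encoder and decoder copies, start from bit-identical weights. The check below is independent of sc_neurocore. It uses a hypothetical helper, `init_weights`, that mirrors the initialization shown in the source.

```python
import numpy as np


def init_weights(n_channels: int, history_len: int, seed: int) -> np.ndarray:
    """Same initialization recipe as SpikePredictor.__post_init__."""
    rng = np.random.RandomState(seed)
    return rng.randn(n_channels, n_channels * history_len) * 0.01


encoder_W = init_weights(4, 8, seed=42)
decoder_W = init_weights(4, 8, seed=42)
other_W = init_weights(4, 8, seed=7)
```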

SCPlanner dataclass

A planner that uses a PredictiveWorldModel to select actions.

Source code in src/sc_neurocore/world_model/planner.py
```python
from dataclasses import dataclass
from typing import Any, List

import numpy as np

from sc_neurocore.world_model.predictive_model import PredictiveWorldModel


@dataclass
class SCPlanner:
    """
    A planner that uses a PredictiveWorldModel to select actions.
    """

    world_model: PredictiveWorldModel

    def propose_action(
        self,
        current_state: np.ndarray[Any, Any],
        goal_state: np.ndarray[Any, Any],
        n_candidates: int = 10,
    ) -> np.ndarray[Any, Any]:
        """
        Propose the best action among n_candidates based on predicted outcome.
        """
        best_action = None
        min_dist = float("inf")

        for _ in range(n_candidates):
            # Sample a candidate action uniformly from [0, 1)^action_dim (global NumPy RNG)
            candidate_action = np.random.uniform(0, 1, self.world_model.action_dim)

            # Predict the mean next state under this action
            predicted_state = self.world_model.predict_next_state(current_state, candidate_action)

            # Euclidean distance to the goal
            dist = np.linalg.norm(predicted_state - goal_state)

            if dist < min_dist:
                min_dist = dist  # type: ignore[assignment]
                best_action = candidate_action

        # best_action is still None only if n_candidates < 1
        return best_action  # type: ignore[return-value]

    def plan_sequence(
        self,
        current_state: np.ndarray[Any, Any],
        goal_state: np.ndarray[Any, Any],
        horizon: int = 5,
    ) -> List[np.ndarray[Any, Any]]:
        """
        Simple greedy planning for a sequence of actions.
        """
        plan = []
        curr_s = current_state
        for _ in range(horizon):
            # One-step greedy choice (default n_candidates), then roll the mean model forward
            action = self.propose_action(curr_s, goal_state)
            plan.append(action)
            curr_s = self.world_model.predict_next_state(curr_s, action)
        return plan
```

propose_action(current_state, goal_state, n_candidates=10)

Propose the best action among n_candidates based on predicted outcome.
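propose_action scores candidates one at a time in a Python loop. When the mean model is linear in the action, as the LGSSM mean A x + B u is, all candidates can be scored with a single matrix product. This is an alternative design, not what SCPlanner does. `propose_action_batched` is a hypothetical name, and A and B are arbitrary example matrices.

```python
import numpy as np


def propose_action_batched(x, goal, A, B, rng, n_candidates=10):
    """Score all candidate actions at once with a linear mean model A x + B u."""
    U = rng.uniform(0.0, 1.0, size=(n_candidates, B.shape[1]))  # (n, action_dim)
    predicted = (A @ x)[None, :] + U @ B.T                        # (n, state_dim)
    dists = np.linalg.norm(predicted - goal, axis=1)
    return U[np.argmin(dists)]


rng = np.random.default_rng(1)
A = np.eye(3)
B = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
x, goal = np.zeros(3), np.array([0.2, 0.7, 0.45])
u = propose_action_batched(x, goal, A, B, rng, n_candidates=200)
```

Given the same candidate set, the result matches the loop version. Only the cost of scoring changes.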

plan_sequence(current_state, goal_state, horizon=5)

Simple greedy planning for a sequence of actions.

PredictiveWorldModel dataclass

Probabilistic predictive world model based on a Linear Gaussian SSM.

The legacy 65-LOC predict_next_state / forecast API is preserved as a thin wrapper over the LinearGaussianSSM + KalmanFilter infrastructure defined earlier in predictive_model.py. The previous deterministic linear matmul + clip placeholder was replaced on 2026-04-17 per feedback_sophisticated_from_start.md.

Source code in src/sc_neurocore/world_model/predictive_model.py
```python
from dataclasses import dataclass
from typing import Tuple

import numpy as np
import numpy.typing as npt

# LinearGaussianSSM is defined earlier in predictive_model.py (not shown here).


@dataclass
class PredictiveWorldModel:
    """Probabilistic predictive world model based on a Linear Gaussian SSM.

    The legacy 65-LOC `predict_next_state` / `forecast` API is
    preserved as a thin wrapper on the proper `LinearGaussianSSM`
    + `KalmanFilter` infrastructure above. The previous
    deterministic linear matmul + clip placeholder was replaced
    2026-04-17 per `feedback_sophisticated_from_start.md`.
    """

    state_dim: int
    action_dim: int
    seed: int = 42

    def __post_init__(self) -> None:
        self.model: LinearGaussianSSM = LinearGaussianSSM.random(
            state_dim=self.state_dim,
            obs_dim=self.state_dim,  # observe the state directly
            control_dim=self.action_dim,
            seed=self.seed,
        )
        # Filtered posterior moments, initialized from the prior; reset() restores them.
        # Note: predict_next_state below is stateless and does not modify them.
        self._mu: np.ndarray = self.model.mu_0.copy()
        self._Sigma: np.ndarray = self.model.Sigma_0.copy()

    def reset(self) -> None:
        self._mu = self.model.mu_0.copy()
        self._Sigma = self.model.Sigma_0.copy()

    def predict_next_state(
        self,
        current_state: np.ndarray,
        action: np.ndarray,
    ) -> np.ndarray:
        """Predict E[x_{t+1} | x_t, u_t] under the SSM dynamics.

        Returns the deterministic mean prediction; for a full
        probabilistic forecast use `predict_next_state_with_cov`.
        """
        u: npt.NDArray[np.float64] = action.astype(np.float64)
        if u.shape == ():
            u = u[np.newaxis]  # promote a scalar action to shape (1,)
        return self.model.A @ current_state + (
            self.model.B @ u if self.model.control_dim > 0 else 0.0
        )

    def predict_next_state_with_cov(
        self,
        current_state: np.ndarray,
        current_cov: np.ndarray,
        action: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Predict mean + covariance of x_{t+1} given (x_t, Σ_t, u_t)."""
        mu_next = self.predict_next_state(current_state, action)
        Sigma_next = self.model.A @ current_cov @ self.model.A.T + self.model.Q
        return mu_next, Sigma_next

    def forecast(
        self,
        initial_state: np.ndarray,
        actions: list[np.ndarray],
    ) -> list[np.ndarray]:
        """Multi-step deterministic forecast (mean trajectory)."""
        traj: list[np.ndarray] = []
        x: npt.NDArray[np.float64] = initial_state.astype(np.float64)
        for a in actions:
            x = self.predict_next_state(x, np.asarray(a, dtype=np.float64))
            traj.append(x.copy())
        return traj

    def forecast_with_cov(
        self,
        initial_state: np.ndarray,
        initial_cov: np.ndarray,
        actions: list[np.ndarray],
    ) -> list[Tuple[np.ndarray, np.ndarray]]:
        """Multi-step probabilistic forecast (mean + cov trajectory)."""
        traj: list[Tuple[np.ndarray, np.ndarray]] = []
        x: npt.NDArray[np.float64] = initial_state.astype(np.float64)
        P: npt.NDArray[np.float64] = initial_cov.astype(np.float64)
        for a in actions:
            x, P = self.predict_next_state_with_cov(
                x,
                P,
                np.asarray(a, dtype=np.float64),
            )
            traj.append((x.copy(), P.copy()))
        return traj
```

predict_next_state(current_state, action)

Predict E[x_{t+1} | x_t, u_t] under the SSM dynamics.

Returns the deterministic mean prediction; for a full probabilistic forecast use predict_next_state_with_cov.

predict_next_state_with_cov(current_state, current_cov, action)

Predict mean + covariance of x_{t+1} given (x_t, Σ_t, u_t).

forecast(initial_state, actions)

Multi-step deterministic forecast (mean trajectory).
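forecast only chains predict_next_state, so the LGSSM mean trajectory has a closed form: x_k = A^k x_0 + Σ_{j<k} A^(k-1-j) B u_j. The sketch below checks a loop rollout against that formula using np.linalg.matrix_power. The function names and matrices are hypothetical. It is a sanity check on the recurrence, not a library test.

```python
import numpy as np


def forecast_mean(x0, actions, A, B):
    """Chain x_{t+1} = A x_t + B u_t, as forecast does via predict_next_state."""
    x, traj = x0.astype(np.float64), []
    for u in actions:
        x = A @ x + B @ np.atleast_1d(np.asarray(u, dtype=np.float64))
        traj.append(x.copy())
    return traj


def closed_form(x0, actions, A, B, k):
    """x_k = A^k x0 + sum_{j<k} A^(k-1-j) B u_j."""
    x = np.linalg.matrix_power(A, k) @ x0
    for j in range(k):
        x = x + np.linalg.matrix_power(A, k - 1 - j) @ B @ np.atleast_1d(actions[j])
    return x


A = np.array([[0.5, 0.2], [0.0, 0.9]])
B = np.array([[1.0], [0.5]])
x0 = np.array([1.0, -1.0])
actions = [1.0, 0.0, -0.5, 2.0]
traj = forecast_mean(x0, actions, A, B)
```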

forecast_with_cov(initial_state, initial_cov, actions)

Multi-step probabilistic forecast (mean + cov trajectory).
