World Model: Spike Prediction + Planning
Three components: (1) an online-learnable spike predictor for lossless codec integration (SpikePredictor), (2) a stochastic state-transition model (PredictiveWorldModel), and (3) a greedy action planner (SCPlanner).
SpikePredictor: Online Autoregressive Codec
SpikePredictor is the core workhorse. It predicts multi-channel spike patterns from recent history using a linear autoregressive model trained online with the LMS (Least Mean Squares) rule. There is no backpropagation and no batching: the weights are updated one timestep at a time.
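Codec integration: the encoder and the decoder each maintain an identical SpikePredictor instance, and both see the same spike history. Only the prediction error (the XOR of the actual and predicted patterns) is transmitted. The decoder XORs the received error with its own prediction to recover the original pattern, then updates its predictor with that pattern so the two histories stay in sync. Because prediction is deterministic (same history → same prediction), the roundtrip is lossless.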
| Parameter | Default | Meaning |
|---|---|---|
| n_channels | (required) | Number of spike channels |
| history_len | 8 | Context window (K past timesteps) |
| lr | 0.01 | LMS learning rate |
| threshold | 0.5 | Binary prediction threshold |
Codec functions:
- predict_and_xor_world_model(spikes, n_channels, ...) → (errors, correct_count): encoder
- xor_and_recover_world_model(errors, n_channels, ...) → spikes: decoder
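The encoder/decoder contract can be sketched in plain NumPy. This is a hypothetical reimplementation, not the library code: `TinyLMSPredictor`, `encode`, and `decode` are illustrative names, and the small random weight initialization, the clipping of probabilities to [0, 1], and the `>=` threshold test are assumptions. What it does mirror is the structure described above: both sides build the predictor from the same seed, predict before seeing each timestep, and update with the true pattern afterwards.

```python
import numpy as np


class TinyLMSPredictor:
    """Hypothetical stand-in for SpikePredictor (not the library class)."""

    def __init__(self, n_channels, history_len=8, lr=0.01, threshold=0.5, seed=0):
        rng = np.random.default_rng(seed)
        self.lr, self.threshold = lr, threshold
        # W maps the flattened (N*K,) history to N per-channel probabilities.
        self.W = rng.normal(0.0, 0.01, size=(n_channels, n_channels * history_len))
        self.history = np.zeros((history_len, n_channels))  # oldest row first

    def predict_probs(self):
        # Linear readout, clipped into [0, 1] (the clipping is an assumption).
        return np.clip(self.W @ self.history.ravel(), 0.0, 1.0)

    def predict(self):
        return (self.predict_probs() >= self.threshold).astype(np.int8)

    def update(self, actual):
        h = self.history.ravel()
        error = actual - self.predict_probs()
        self.W += self.lr * np.outer(error, h)  # LMS step
        # Slide the window: drop the oldest timestep, append the new one.
        self.history = np.vstack([self.history[1:], actual[None, :]])


def encode(spikes, n_channels, **kwargs):
    """Transmit only XOR residuals; count positions predicted correctly."""
    predictor = TinyLMSPredictor(n_channels, **kwargs)
    errors = np.empty_like(spikes)
    for t, actual in enumerate(spikes):
        errors[t] = actual ^ predictor.predict()
        predictor.update(actual)
    correct = int(np.sum(errors == 0))  # a 0 residual bit means a correct guess
    return errors, correct


def decode(errors, n_channels, **kwargs):
    """Rebuild the spikes by XORing residuals with identical predictions."""
    predictor = TinyLMSPredictor(n_channels, **kwargs)
    spikes = np.empty_like(errors)
    for t, err in enumerate(errors):
        spikes[t] = err ^ predictor.predict()
        predictor.update(spikes[t])  # decoder learns from the recovered pattern
    return spikes


rng = np.random.default_rng(0)
spikes = (rng.random((100, 32)) < 0.3).astype(np.int8)
errors, correct = encode(spikes, n_channels=32)
assert np.array_equal(decode(errors, n_channels=32), spikes)
print(f"accuracy: {correct / spikes.size:.1%}")
```

Because `decode` performs the same floating-point operations in the same order as `encode`, its predictions match exactly on the same platform, which is what makes the XOR residual invertible.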
PredictiveWorldModel: State Transitions
Linear transition model: state_next = clip(T @ [state; action], 0, 1), where T is a row-normalized transition matrix of shape (state_dim, state_dim + action_dim). It provides predict_next_state() for a single step and forecast() for multi-step rollouts.
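A minimal NumPy sketch of this transition rule follows. It assumes T is drawn at random and then row-normalized; how the library actually initializes or learns T is not documented here. The functions are standalone, hypothetical helpers whose names only mirror the methods.

```python
import numpy as np


def make_transition_matrix(state_dim, action_dim, seed=0):
    """Random non-negative T whose rows each sum to 1 (assumed initialization)."""
    rng = np.random.default_rng(seed)
    T = rng.random((state_dim, state_dim + action_dim))
    return T / T.sum(axis=1, keepdims=True)


def predict_next_state(T, current_state, action):
    """One step: state_next = clip(T @ [state; action], 0, 1)."""
    x = np.concatenate([current_state, action])
    return np.clip(T @ x, 0.0, 1.0)


def forecast(T, initial_state, actions):
    """Roll the model forward, returning one predicted state per action."""
    states, state = [], initial_state
    for action in actions:
        state = predict_next_state(T, state, action)
        states.append(state)
    return np.stack(states)


T = make_transition_matrix(state_dim=4, action_dim=2)
traj = forecast(T, np.array([0.1, 0.2, 0.3, 0.4]), np.full((3, 2), 0.5))
print(traj.shape)  # (3, 4)
```

With a non-negative, row-normalized T and inputs in [0, 1], every output entry is a weighted average of the inputs and therefore already lies in [0, 1]; the clip only matters if T or the inputs leave that range.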
SCPlanner: Greedy Action Selection
Uses PredictiveWorldModel for random-shooting planning: sample N candidate actions, predict outcomes, pick the one closest to the goal state.
- propose_action(current_state, goal_state, n_candidates): best single action
- plan_sequence(current_state, goal_state, horizon): greedy multi-step plan
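The random-shooting loop can be sketched as follows. This is an illustration under assumptions, not the SCPlanner source: candidate actions are drawn uniformly from [0, 1]^action_dim, "closest" is taken to mean smallest Euclidean distance, and the helpers take a bare transition matrix `T` (of the linear form described for PredictiveWorldModel) instead of a model object. `step` is a hypothetical helper.

```python
import numpy as np


def step(T, state, action):
    # Linear transition: clip(T @ [state; action], 0, 1).
    return np.clip(T @ np.concatenate([state, action]), 0.0, 1.0)


def propose_action(T, current_state, goal_state, n_candidates=10, rng=None):
    """Random shooting: sample candidates, keep the one whose predicted
    next state lands closest to the goal."""
    rng = np.random.default_rng() if rng is None else rng
    action_dim = T.shape[1] - current_state.shape[0]
    candidates = rng.random((n_candidates, action_dim))
    outcomes = np.array([step(T, current_state, a) for a in candidates])
    dists = np.linalg.norm(outcomes - goal_state, axis=1)
    return candidates[np.argmin(dists)]


def plan_sequence(T, current_state, goal_state, horizon=5, n_candidates=10, seed=0):
    """Greedy plan: pick the best one-step action, advance the predicted
    state with it, and repeat for `horizon` steps."""
    rng = np.random.default_rng(seed)
    plan, state = [], current_state
    for _ in range(horizon):
        action = propose_action(T, state, goal_state, n_candidates, rng)
        plan.append(action)
        state = step(T, state, action)
    return plan


rng = np.random.default_rng(0)
T = rng.random((4, 6))
T /= T.sum(axis=1, keepdims=True)  # row-normalize
plan = plan_sequence(T, np.array([0.1, 0.2, 0.3, 0.4]), np.array([0.9, 0.8, 0.7, 0.6]))
print(len(plan))  # 5
```

"Greedy" here means each action only optimizes the next predicted state; nothing looks further ahead, so the resulting sequence can be myopic compared with scoring whole action sequences.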
Usage

```python
import numpy as np

from sc_neurocore.world_model import PredictiveWorldModel, SCPlanner
from sc_neurocore.world_model.spike_predictor import (
    predict_and_xor_world_model,
    xor_and_recover_world_model,
)

# Lossless codec roundtrip on random spikes (100 timesteps x 32 channels)
rng = np.random.default_rng(0)
spikes = (rng.random((100, 32)) < 0.3).astype(np.int8)
errors, correct = predict_and_xor_world_model(spikes, n_channels=32)
recovered = xor_and_recover_world_model(errors, n_channels=32)
assert np.array_equal(spikes, recovered)  # lossless by construction
print(f"Prediction accuracy: {correct / spikes.size:.1%}")

# Planning: greedy multi-step plan toward goal_state
model = PredictiveWorldModel(state_dim=4, action_dim=2)
planner = SCPlanner(world_model=model)
plan = planner.plan_sequence(
    current_state=np.array([0.1, 0.2, 0.3, 0.4]),
    goal_state=np.array([0.9, 0.8, 0.7, 0.6]),
    horizon=5,
)
```
sc_neurocore.world_model
sc_neurocore.world_model -- Tier: research (experimental / research).
SCPlanner (dataclass)
A planner that uses a PredictiveWorldModel to select actions.
Source code in src/sc_neurocore/world_model/planner.py
propose_action(current_state, goal_state, n_candidates=10)
Propose the best action among n_candidates based on predicted outcome.
plan_sequence(current_state, goal_state, horizon=5)
Simple greedy planning for a sequence of actions.
PredictiveWorldModel (dataclass)
A stochastic predictive world model. Predicts state_next = f(state_curr, action).
Source code in src/sc_neurocore/world_model/predictive_model.py
predict_next_state(current_state, action)
Predicts the next state given the current state and action.
Inputs:
- current_state: (state_dim,) array of probabilities.
- action: (action_dim,) array of probabilities.
Returns:
- next_state: (state_dim,) array of predicted probabilities.
forecast(initial_state, actions)
Forecast multiple steps ahead given a sequence of actions.
SpikePredictor (dataclass)
Online autoregressive spike pattern predictor.
Learns to predict spike[t] from spike[t-K:t] for every channel, where N = n_channels and K = history_len. A weight matrix W of shape (N, N*K) maps the flattened history to per-channel firing probabilities, which are thresholded to give the binary prediction.
After each timestep, the weights are updated with the LMS rule W += lr * outer(error, history), where error = actual - predicted_prob is an (N,) vector and history is the flattened (N*K,) context.
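One LMS step can be traced by hand. The sizes (N = 2 channels, K = 3 timesteps) and the learning rate 0.125 are chosen only so the arithmetic stays exact. As assumptions of this sketch: W starts at zero instead of being initialized from `seed`, the history is flattened oldest-timestep-first, and the raw linear readout `W @ h` is used directly as the predicted probability.

```python
import numpy as np

N, K, lr = 2, 3, 0.125
W = np.zeros((N, N * K))            # (N, N*K); zero init for readability

# Context spike[t-K:t], oldest row first; shape (K, N).
history = np.array([[1, 0],
                    [0, 1],
                    [1, 1]])
h = history.ravel().astype(float)   # flattened context, shape (N*K,)

actual = np.array([1, 0])           # observed spike[t]
predicted_prob = W @ h              # all zeros before any learning
error = actual - predicted_prob     # channel 0 missed a spike, channel 1 correct
W += lr * np.outer(error, h)        # LMS: W += lr * outer(error, history)

new_prob = W @ h
# Channel 0 moves from 0.0 to 0.5: its error shrinks by the factor
# (1 - lr * ||h||^2) = 1 - 0.125 * 4 = 0.5. Channel 1 had zero error,
# so its row of W is unchanged.
print(new_prob)
```

The same factor explains why lr has to stay small: once lr * ||h||^2 exceeds 2, a single step overshoots far enough to grow the error, and a dense history (many recent spikes across many channels) makes ||h||^2 large.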
Parameters
- n_channels (int): Number of spike channels.
- history_len (int): Number of past timesteps to use as context (K).
- lr (float): LMS learning rate.
- threshold (float): Probability threshold for binary prediction.
- seed (int): RNG seed for weight initialization.
Source code in src/sc_neurocore/world_model/spike_predictor.py
predict_probs()
Predict per-channel firing probabilities from history.
predict()
Predict binary spike pattern.
update(actual)
Update weights with observed spike pattern (LMS rule).
Parameters
- actual (ndarray of shape (n_channels,), binary): the observed spike pattern for the current timestep.
reset()
Reset to initial state (same seed → same weights).