// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — Rust LGSSM Kalman filter (parity with src/sc_neurocore/world_model/predictive_model.py)

//! Rust implementation of the Kalman filter forward pass for a
//! linear Gaussian state-space model.
//!
//! Counterpart of `KalmanFilter.filter()` in
//! `src/sc_neurocore/world_model/predictive_model.py`: the Python and
//! Rust paths are meant to return the same means, covariances, and
//! log-likelihood, up to float64 round-off.
//!
//! References (same as the Python module):
//! Kalman 1960; Bishop 2006 §13.3.1.
//!
//! Algorithm per timestep `t`, where `p` is the observation dimension:
//!
//! ```text
//! x_pred = A x_filt + B u_{t-1}          (x_pred = mu_0 at t = 0)
//! P_pred = A P_filt A^T + Q              (P_pred = Sigma_0 at t = 0)
//! e_t    = y_t - C x_pred - D u_t
//! S      = C P_pred C^T + R
//! K      = P_pred C^T S^{-1}
//! x_filt = x_pred + K e_t
//! P_filt = (I - K C) P_pred (I - K C)^T + K R K^T   (Joseph form)
//! log-lik += -0.5 * (p log 2pi + log |S| + e_t^T S^{-1} e_t)
//! ```
//!
//! `S^{-1}` is never formed explicitly; it is applied through a Cholesky
//! factorisation of `S`.

use nalgebra::{Cholesky as NaCholesky, DMatrix, DVector};
use ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};

33/// Result of the Kalman filter forward pass.
34pub struct KalmanResult {
35    pub means: Array2<f64>,            // (T, d) filtered means
36    pub covariances: Array3<f64>,      // (T, d, d) filtered covariances
37    pub pred_means: Array2<f64>,       // (T, d) predicted means
38    pub pred_covariances: Array3<f64>, // (T, d, d) predicted covariances
39    pub log_likelihood: f64,
40}
41
42/// Run the forward Kalman filter on a sequence of observations.
43///
44/// All input matrices are dense `Array2<f64>`. Controls may be
45/// empty (shape (T, 0)) when the model has no control input.
46///
47/// # Panics
48///
49/// Panics if dimensions are inconsistent. The caller (PyO3
50/// wrapper) validates shapes before invoking.
51pub fn kalman_filter(
52    observations: ArrayView2<f64>, // (T, p)
53    controls: ArrayView2<f64>,     // (T, m) or (T, 0)
54    a: ArrayView2<f64>,            // (d, d)
55    b: ArrayView2<f64>,            // (d, m) or (d, 0)
56    c: ArrayView2<f64>,            // (p, d)
57    d: ArrayView2<f64>,            // (p, m) or (p, 0)
58    q: ArrayView2<f64>,            // (d, d)
59    r: ArrayView2<f64>,            // (p, p)
60    mu_0: ArrayView1<f64>,         // (d,)
61    sigma_0: ArrayView2<f64>,      // (d, d)
62) -> KalmanResult {
63    let t_len = observations.nrows();
64    let p_dim = observations.ncols();
65    let d_dim = a.nrows();
66    let m_dim = b.ncols();
67
68    let has_control = m_dim > 0;
69
70    let mut means = Array2::<f64>::zeros((t_len, d_dim));
71    let mut covs = Array3::<f64>::zeros((t_len, d_dim, d_dim));
72    let mut pred_means = Array2::<f64>::zeros((t_len, d_dim));
73    let mut pred_covs = Array3::<f64>::zeros((t_len, d_dim, d_dim));
74
75    let mut x_pred: Array1<f64> = mu_0.to_owned();
76    let mut p_pred: Array2<f64> = sigma_0.to_owned();
77
78    let mut log_lik = 0.0_f64;
79    let two_pi_log = (2.0 * std::f64::consts::PI).ln();
80    let i_d = Array2::<f64>::eye(d_dim);
81
82    for t in 0..t_len {
83        // Record predicted state for this step
84        pred_means.slice_mut(s![t, ..]).assign(&x_pred);
85        pred_covs.slice_mut(s![t, .., ..]).assign(&p_pred);
86
87        let y_t = observations.slice(s![t, ..]);
88
89        // Innovation: e = y - C x_pred - D u
90        let mut y_hat = c.dot(&x_pred);
91        if has_control {
92            let u_t = controls.slice(s![t, ..]);
93            y_hat = y_hat + d.dot(&u_t);
94        }
95        let innov = &y_t - &y_hat;
96
97        // Innovation covariance: S = C P_pred C^T + R
98        let s_mat = c.dot(&p_pred).dot(&c.t()) + r;
99
100        // S is symmetric positive-definite; factor it once with a Cholesky
101        // decomposition (nalgebra, LAPACK-grade) and reuse the single factor for the
102        // log-determinant, the innovation quadratic form, and the Kalman gain —
103        // never forming S^{-1} explicitly.
104        let s_na = DMatrix::<f64>::from_fn(p_dim, p_dim, |i, j| s_mat[(i, j)]);
105        let (logdet_s, s_inv_innov, k_gain) = match NaCholesky::new(s_na) {
106            Some(chol) => {
107                // log|S| = 2 Σ ln L_ii — the stable sum-of-logs form.
108                let l = chol.l();
109                let logdet = 2.0 * (0..p_dim).map(|i| l[(i, i)].ln()).sum::<f64>();
110                // S^{-1} innov for the quadratic form, via the triangular solves.
111                let innov_na = DVector::<f64>::from_fn(p_dim, |i, _| innov[i]);
112                let z = chol.solve(&innov_na);
113                let s_inv_innov = Array1::<f64>::from_iter((0..p_dim).map(|i| z[i]));
114                // Kalman gain K = P_pred C^T S^{-1}. With S and P_pred symmetric,
115                // K^T = S^{-1} (C P_pred), so solve S X = C P_pred and transpose —
116                // no explicit inverse.
117                let cp = c.dot(&p_pred); // (p × d)
118                let cp_na = DMatrix::<f64>::from_fn(p_dim, d_dim, |i, j| cp[(i, j)]);
119                let x = chol.solve(&cp_na); // S^{-1} (C P_pred), (p × d)
120                let k = Array2::<f64>::from_shape_fn((d_dim, p_dim), |(i, j)| x[(j, i)]);
121                (logdet, s_inv_innov, k)
122            }
123            None => {
124                // Defensive: a non-positive-definite innovation covariance cannot
125                // occur while R is positive-definite. Mirror the prior NaN
126                // propagation rather than panicking.
127                (
128                    f64::NAN,
129                    Array1::<f64>::zeros(p_dim),
130                    Array2::<f64>::zeros((d_dim, p_dim)),
131                )
132            }
133        };
134
135        let quad_form = innov.dot(&s_inv_innov);
136        log_lik += -0.5 * (p_dim as f64 * two_pi_log + logdet_s + quad_form);
137
138        // Filtered state: x_filt = x_pred + K e
139        let x_filt = &x_pred + &k_gain.dot(&innov);
140
141        // Joseph form for filtered covariance:
142        //   P_filt = (I - K C) P_pred (I - K C)^T + K R K^T
143        let i_minus_kc = &i_d - &k_gain.dot(&c);
144        let p_filt = i_minus_kc.dot(&p_pred).dot(&i_minus_kc.t()) + k_gain.dot(&r).dot(&k_gain.t());
145
146        means.slice_mut(s![t, ..]).assign(&x_filt);
147        covs.slice_mut(s![t, .., ..]).assign(&p_filt);
148
149        // Predict next state
150        let mut x_next = a.dot(&x_filt);
151        if has_control {
152            let u_t = controls.slice(s![t, ..]);
153            x_next = x_next + b.dot(&u_t);
154        }
155        let p_next = a.dot(&p_filt).dot(&a.t()) + q;
156        x_pred = x_next;
157        p_pred = p_next;
158    }
159
160    KalmanResult {
161        means,
162        covariances: covs,
163        pred_means,
164        pred_covariances: pred_covs,
165        log_likelihood: log_lik,
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn kalman_scalar_random_walk_matches_analytic() {
        // 1-D random walk: A=1, C=1, Q=0.1, R=1.
        // First-step prediction: mu_0 = 0, Sigma_0 = 1.
        // After observing y_0 = 1:
        //   S = 1 + 1 = 2
        //   K = 1 / 2 = 0.5
        //   x_filt = 0 + 0.5 * (1 - 0) = 0.5
        //   P_filt = (1 - 0.5) * 1 * (1 - 0.5) + 0.5 * 1 * 0.5 = 0.25 + 0.25 = 0.5
        let a = array![[1.0]];
        let b = array![[]];
        let c = array![[1.0]];
        let d = array![[]];
        let q = array![[0.1]];
        let r_mat = array![[1.0]];
        let mu_0 = array![0.0];
        let sigma_0 = array![[1.0]];

        let obs = array![[1.0_f64]];
        let controls = Array2::<f64>::zeros((1, 0));

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert!((result.means[(0, 0)] - 0.5).abs() < 1e-12);
        assert!((result.covariances[(0, 0, 0)] - 0.5).abs() < 1e-12);

        // Exact single-step Gaussian log-likelihood exercises the Cholesky
        // log-determinant and quadratic-form path: S = 2, innov = 1, p = 1.
        //   log N(y | y_hat, S) = -0.5 (log 2π + log|S| + innovᵀ S⁻¹ innov)
        //                       = -0.5 (log 2π + log 2 + 0.5)
        let expected_ll = -0.5 * ((2.0 * std::f64::consts::PI).ln() + 2.0_f64.ln() + 0.5);
        assert!((result.log_likelihood - expected_ll).abs() < 1e-12);
    }

    #[test]
    fn kalman_log_likelihood_finite() {
        // 2-D state, 1-D obs, T=10 hand-picked slowly rising sequence — log-lik
        // must be finite.
        let a = array![[0.9, 0.1], [0.0, 0.95]];
        let b = array![[], []];
        let c = array![[1.0, 0.0]];
        let d = array![[]];
        let q = array![[0.01, 0.0], [0.0, 0.01]];
        let r_mat = array![[0.1]];
        let mu_0 = array![0.0, 0.0];
        let sigma_0 = array![[1.0, 0.0], [0.0, 1.0]];

        let obs = array![
            [0.1],
            [0.2],
            [0.15],
            [0.18],
            [0.22],
            [0.25],
            [0.21],
            [0.24],
            [0.27],
            [0.26],
        ];
        let controls = Array2::<f64>::zeros((10, 0));

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert!(result.log_likelihood.is_finite());
    }

    #[test]
    fn kalman_two_dim_obs_symmetric_psd_and_finite() {
        // 2-D state, 2-D observation with a non-diagonal C and a non-diagonal R so
        // the gain transpose and the 2×2 Cholesky solve are exercised (p = d = 2).
        let a = array![[0.95, 0.0], [0.1, 0.9]];
        let b = array![[], []];
        let c = array![[1.0, 0.2], [0.0, 1.0]];
        let d = array![[], []];
        let q = array![[0.02, 0.0], [0.0, 0.02]];
        let r_mat = array![[0.15, 0.05], [0.05, 0.2]];
        let mu_0 = array![0.0, 0.0];
        let sigma_0 = array![[1.0, 0.0], [0.0, 1.0]];

        let obs = array![
            [0.10, 0.05],
            [0.20, 0.12],
            [0.18, 0.09],
            [0.25, 0.15],
            [0.30, 0.20],
        ];
        let controls = Array2::<f64>::zeros((5, 0));

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert!(result.log_likelihood.is_finite());
        for t in 0..obs.nrows() {
            assert!(result.means[(t, 0)].is_finite());
            assert!(result.means[(t, 1)].is_finite());
            // Filtered covariance must stay symmetric (Joseph form) and PSD.
            let p00 = result.covariances[(t, 0, 0)];
            let p11 = result.covariances[(t, 1, 1)];
            let p01 = result.covariances[(t, 0, 1)];
            let p10 = result.covariances[(t, 1, 0)];
            assert!(
                (p01 - p10).abs() < 1e-12,
                "covariance not symmetric at t={t}"
            );
            assert!(p00 >= 0.0 && p11 >= 0.0, "negative variance at t={t}");
            assert!(
                p00 * p11 - p01 * p10 >= -1e-12,
                "covariance not PSD at t={t}"
            );
        }
    }

    #[test]
    fn kalman_scalar_with_controls_matches_analytic() {
        // 1-D model with both a state control (B = 1) and a feedthrough (D = 0.5):
        //   A = 1, C = 1, Q = 0.1, R = 1, mu_0 = 0, Sigma_0 = 1, u = [2, 0], y = [2, 3].
        // t = 0: y_hat = 0 + 0.5*2 = 1, e = 1, S = 2, K = 0.5,
        //        x_filt = 0.5, P_filt = 0.5.
        // t = 1: x_pred = 0.5 + 1*2 = 2.5, P_pred = 0.5 + 0.1 = 0.6,
        //        y_hat = 2.5 + 0.5*0 = 2.5, e = 0.5, S = 1.6, K = 0.375,
        //        x_filt = 2.5 + 0.375*0.5 = 2.6875, P_filt = (1 - K) P_pred = 0.375.
        let a = array![[1.0]];
        let b = array![[1.0]];
        let c = array![[1.0]];
        let d = array![[0.5]];
        let q = array![[0.1]];
        let r_mat = array![[1.0]];
        let mu_0 = array![0.0];
        let sigma_0 = array![[1.0]];

        let obs = array![[2.0_f64], [3.0]];
        let controls = array![[2.0_f64], [0.0]];

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        let tol = 1e-12;
        assert!((result.means[(0, 0)] - 0.5).abs() < tol);
        assert!((result.covariances[(0, 0, 0)] - 0.5).abs() < tol);
        assert!((result.pred_means[(1, 0)] - 2.5).abs() < tol);
        assert!((result.pred_covariances[(1, 0, 0)] - 0.6).abs() < tol);
        assert!((result.means[(1, 0)] - 2.6875).abs() < tol);
        assert!((result.covariances[(1, 0, 0)] - 0.375).abs() < tol);

        let log_2pi = (2.0 * std::f64::consts::PI).ln();
        let expected_ll = -0.5 * (log_2pi + 2.0_f64.ln() + 0.5)
            - 0.5 * (log_2pi + 1.6_f64.ln() + 0.25 / 1.6);
        assert!((result.log_likelihood - expected_ll).abs() < tol);
    }

    #[test]
    fn kalman_empty_sequence_returns_empty_result() {
        // T = 0: no steps run, every output is empty and the log-likelihood is 0.
        let a = array![[1.0]];
        let b = Array2::<f64>::zeros((1, 0));
        let c = array![[1.0]];
        let d = Array2::<f64>::zeros((1, 0));
        let q = array![[0.1]];
        let r_mat = array![[1.0]];
        let mu_0 = array![0.0];
        let sigma_0 = array![[1.0]];

        let obs = Array2::<f64>::zeros((0, 1));
        let controls = Array2::<f64>::zeros((0, 0));

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert_eq!(result.means.dim(), (0, 1));
        assert_eq!(result.covariances.dim(), (0, 1, 1));
        assert_eq!(result.pred_means.dim(), (0, 1));
        assert_eq!(result.pred_covariances.dim(), (0, 1, 1));
        assert_eq!(result.log_likelihood, 0.0);
    }
}