Skip to main content

sc_neurocore_engine/
lgssm.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Rust LGSSM Kalman filter (parity with src/sc_neurocore/world_model/predictive_model.py)
8
9//! Rust implementation of the Kalman filter forward pass for a
10//! linear Gaussian state-space model.
11//!
//! Rust counterpart of `KalmanFilter.filter()` in
//! `src/sc_neurocore/world_model/predictive_model.py`: the Python and
//! Rust paths are meant to return the same means, covariances, and
//! log-likelihood, up to float64 round-off.
16//!
17//! References match the Python module:
18//!   Kalman 1960; Bishop 2006 §13.3.1.
19//!
//! Algorithm per timestep t. The prior for t = 0 is x_pred = mu_0,
//! P_pred = Sigma_0, used as given (no transition is applied first):
//!   e_t = y_t - C x_pred - D u_t
//!   S = C P_pred C^T + R
//!   K = P_pred C^T S^{-1}
//!   x_filt = x_pred + K e_t
//!   P_filt = (I - K C) P_pred (I - K C)^T + K R K^T   (Joseph form)
//!   log-lik += -0.5 * (p log 2pi + log |S| + e_t^T S^{-1} e_t)
//!   x_pred <- A x_filt + B u_t                        (prior for t + 1)
//!   P_pred <- A P_filt A^T + Q
29
30use ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
31
32/// Result of the Kalman filter forward pass.
33pub struct KalmanResult {
34    pub means: Array2<f64>,            // (T, d) filtered means
35    pub covariances: Array3<f64>,      // (T, d, d) filtered covariances
36    pub pred_means: Array2<f64>,       // (T, d) predicted means
37    pub pred_covariances: Array3<f64>, // (T, d, d) predicted covariances
38    pub log_likelihood: f64,
39}
40
41/// Run the forward Kalman filter on a sequence of observations.
42///
43/// All input matrices are dense `Array2<f64>`. Controls may be
44/// empty (shape (T, 0)) when the model has no control input.
45///
46/// # Panics
47///
48/// Panics if dimensions are inconsistent. The caller (PyO3
49/// wrapper) validates shapes before invoking.
50pub fn kalman_filter(
51    observations: ArrayView2<f64>, // (T, p)
52    controls: ArrayView2<f64>,     // (T, m) or (T, 0)
53    a: ArrayView2<f64>,            // (d, d)
54    b: ArrayView2<f64>,            // (d, m) or (d, 0)
55    c: ArrayView2<f64>,            // (p, d)
56    d: ArrayView2<f64>,            // (p, m) or (p, 0)
57    q: ArrayView2<f64>,            // (d, d)
58    r: ArrayView2<f64>,            // (p, p)
59    mu_0: ArrayView1<f64>,         // (d,)
60    sigma_0: ArrayView2<f64>,      // (d, d)
61) -> KalmanResult {
62    let t_len = observations.nrows();
63    let p_dim = observations.ncols();
64    let d_dim = a.nrows();
65    let m_dim = b.ncols();
66
67    let has_control = m_dim > 0;
68
69    let mut means = Array2::<f64>::zeros((t_len, d_dim));
70    let mut covs = Array3::<f64>::zeros((t_len, d_dim, d_dim));
71    let mut pred_means = Array2::<f64>::zeros((t_len, d_dim));
72    let mut pred_covs = Array3::<f64>::zeros((t_len, d_dim, d_dim));
73
74    let mut x_pred: Array1<f64> = mu_0.to_owned();
75    let mut p_pred: Array2<f64> = sigma_0.to_owned();
76
77    let mut log_lik = 0.0_f64;
78    let two_pi_log = (2.0 * std::f64::consts::PI).ln();
79    let i_d = Array2::<f64>::eye(d_dim);
80
81    for t in 0..t_len {
82        // Record predicted state for this step
83        pred_means.slice_mut(s![t, ..]).assign(&x_pred);
84        pred_covs.slice_mut(s![t, .., ..]).assign(&p_pred);
85
86        let y_t = observations.slice(s![t, ..]);
87
88        // Innovation: e = y - C x_pred - D u
89        let mut y_hat = c.dot(&x_pred);
90        if has_control {
91            let u_t = controls.slice(s![t, ..]);
92            y_hat = y_hat + d.dot(&u_t);
93        }
94        let innov = &y_t - &y_hat;
95
96        // Innovation covariance: S = C P_pred C^T + R
97        let s_mat = c.dot(&p_pred).dot(&c.t()) + r;
98
99        // Solve S z = innov using Gaussian elimination (small p)
100        let s_inv = invert_psd_matrix(&s_mat);
101        let s_inv_innov = s_inv.dot(&innov);
102
103        // Log-determinant of S via LU/Cholesky-style recursion
104        let logdet_s = log_det_psd(&s_mat);
105        let quad_form = innov.dot(&s_inv_innov);
106        log_lik += -0.5 * (p_dim as f64 * two_pi_log + logdet_s + quad_form);
107
108        // Kalman gain: K = P_pred C^T S^{-1}
109        let k_gain = p_pred.dot(&c.t()).dot(&s_inv);
110
111        // Filtered state: x_filt = x_pred + K e
112        let x_filt = &x_pred + &k_gain.dot(&innov);
113
114        // Joseph form for filtered covariance:
115        //   P_filt = (I - K C) P_pred (I - K C)^T + K R K^T
116        let i_minus_kc = &i_d - &k_gain.dot(&c);
117        let p_filt = i_minus_kc.dot(&p_pred).dot(&i_minus_kc.t()) + k_gain.dot(&r).dot(&k_gain.t());
118
119        means.slice_mut(s![t, ..]).assign(&x_filt);
120        covs.slice_mut(s![t, .., ..]).assign(&p_filt);
121
122        // Predict next state
123        let mut x_next = a.dot(&x_filt);
124        if has_control {
125            let u_t = controls.slice(s![t, ..]);
126            x_next = x_next + b.dot(&u_t);
127        }
128        let p_next = a.dot(&p_filt).dot(&a.t()) + q;
129        x_pred = x_next;
130        p_pred = p_next;
131    }
132
133    KalmanResult {
134        means,
135        covariances: covs,
136        pred_means,
137        pred_covariances: pred_covs,
138        log_likelihood: log_lik,
139    }
140}
141
142/// Cholesky-decomposition log-determinant of a symmetric PSD matrix.
143///
144/// Returns 2 * sum(log diag(L)) where M = L L^T. Returns NaN if
145/// the matrix is not positive definite.
146fn log_det_psd(m: &Array2<f64>) -> f64 {
147    let l = cholesky(m);
148    let mut acc = 0.0_f64;
149    for i in 0..l.nrows() {
150        let d = l[(i, i)];
151        if d <= 0.0 {
152            return f64::NAN;
153        }
154        acc += d.ln();
155    }
156    2.0 * acc
157}
158
159/// Cholesky-based inversion of a symmetric PSD matrix.
160///
161/// Computes `M^{-1}` via L L^T = M then inverting via forward +
162/// backward substitution on the identity.
163fn invert_psd_matrix(m: &Array2<f64>) -> Array2<f64> {
164    let n = m.nrows();
165    let l = cholesky(m);
166    let mut inv = Array2::<f64>::eye(n);
167    // Solve L Y = I
168    for k in 0..n {
169        for i in 0..n {
170            let mut sum = inv[(i, k)];
171            for j in 0..i {
172                sum -= l[(i, j)] * inv[(j, k)];
173            }
174            inv[(i, k)] = sum / l[(i, i)];
175        }
176    }
177    // Solve L^T X = Y
178    let mut out = Array2::<f64>::zeros((n, n));
179    for k in 0..n {
180        for i in (0..n).rev() {
181            let mut sum = inv[(i, k)];
182            for j in (i + 1)..n {
183                sum -= l[(j, i)] * out[(j, k)];
184            }
185            out[(i, k)] = sum / l[(i, i)];
186        }
187    }
188    out
189}
190
191/// Cholesky decomposition: returns L lower-triangular such that
192/// L L^T = M for symmetric PSD M. No checks on PSD-ness; if M is
193/// not PSD, the diagonal of L will contain a NaN.
194fn cholesky(m: &Array2<f64>) -> Array2<f64> {
195    let n = m.nrows();
196    let mut l = Array2::<f64>::zeros((n, n));
197    for i in 0..n {
198        for j in 0..=i {
199            let mut sum = m[(i, j)];
200            for k in 0..j {
201                sum -= l[(i, k)] * l[(j, k)];
202            }
203            if i == j {
204                l[(i, j)] = sum.sqrt();
205            } else {
206                l[(i, j)] = sum / l[(j, j)];
207            }
208        }
209    }
210    l
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Panics unless `actual` is within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn cholesky_identity() {
        let i = Array2::<f64>::eye(3);
        let l = cholesky(&i);
        for r in 0..3 {
            for c in 0..3 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert_close(l[(r, c)], expected, 1e-12);
            }
        }
    }

    #[test]
    fn cholesky_known_2x2() {
        // [[4, 2], [2, 3]] = L L^T with L = [[2, 0], [1, sqrt 2]].
        let m = array![[4.0, 2.0], [2.0, 3.0]];
        let l = cholesky(&m);
        assert_close(l[(0, 0)], 2.0, 1e-12);
        assert_close(l[(0, 1)], 0.0, 1e-12);
        assert_close(l[(1, 0)], 1.0, 1e-12);
        assert_close(l[(1, 1)], 2.0_f64.sqrt(), 1e-12);
    }

    #[test]
    fn invert_psd_matrix_identity() {
        let i = Array2::<f64>::eye(3);
        let inv = invert_psd_matrix(&i);
        for r in 0..3 {
            for c in 0..3 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert_close(inv[(r, c)], expected, 1e-12);
            }
        }
    }

    #[test]
    fn invert_psd_matrix_known_2x2() {
        // inv([[4, 2], [2, 3]]) = (1/8) [[3, -2], [-2, 4]].
        let m = array![[4.0, 2.0], [2.0, 3.0]];
        let inv = invert_psd_matrix(&m);
        let expected = array![[0.375, -0.25], [-0.25, 0.5]];
        for r in 0..2 {
            for c in 0..2 {
                assert_close(inv[(r, c)], expected[(r, c)], 1e-12);
            }
        }
    }

    #[test]
    fn log_det_identity_is_zero() {
        let i = Array2::<f64>::eye(4);
        assert_close(log_det_psd(&i), 0.0, 1e-12);
    }

    #[test]
    fn log_det_known_2x2() {
        // det([[4, 2], [2, 3]]) = 8.
        let m = array![[4.0, 2.0], [2.0, 3.0]];
        assert_close(log_det_psd(&m), 8.0_f64.ln(), 1e-12);
    }

    #[test]
    fn log_det_not_positive_definite_is_nan() {
        // Eigenvalues 3 and -1: indefinite, so the Cholesky pivot goes negative.
        let m = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(log_det_psd(&m).is_nan());
    }

    #[test]
    fn kalman_scalar_random_walk_matches_analytic() {
        // 1-D random walk: A=1, C=1, Q=0.1, R=1.
        // First-step prediction: mu_0 = 0, Sigma_0 = 1.
        // After observing y_0 = 1:
        //   S = 1 + 1 = 2
        //   K = 1 / 2 = 0.5
        //   x_filt = 0 + 0.5 * (1 - 0) = 0.5
        //   P_filt = (1 - 0.5) * 1 * (1 - 0.5) + 0.5 * 1 * 0.5 = 0.25 + 0.25 = 0.5
        //   log-lik = log N(1; 0, 2) = -0.5 * (ln 2pi + ln 2 + 1^2 / 2)
        let a = array![[1.0]];
        let b = array![[]];
        let c = array![[1.0]];
        let d = array![[]];
        let q = array![[0.1]];
        let r_mat = array![[1.0]];
        let mu_0 = array![0.0];
        let sigma_0 = array![[1.0]];

        let obs = array![[1.0_f64]];
        let controls = Array2::<f64>::zeros((1, 0));

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert_close(result.pred_means[(0, 0)], 0.0, 1e-12);
        assert_close(result.pred_covariances[(0, 0, 0)], 1.0, 1e-12);
        assert_close(result.means[(0, 0)], 0.5, 1e-12);
        assert_close(result.covariances[(0, 0, 0)], 0.5, 1e-12);
        let expected_ll = -0.5 * ((2.0 * std::f64::consts::PI).ln() + 2.0_f64.ln() + 0.5);
        assert_close(result.log_likelihood, expected_ll, 1e-12);
    }

    #[test]
    fn kalman_control_input_enters_observation_and_transition() {
        // 1-D model: A = B = C = 1, D = 0.5, Q = 0.1, R = 1, mu_0 = 0, Sigma_0 = 1.
        // t = 0 (u = 2, y = 1): y_hat = 0.5 * 2 = 1, so e = 0, x_filt = 0,
        //        P_filt = 0.5; next prior x = 0 + 1 * 2 = 2, P = 0.5 + 0.1 = 0.6.
        // t = 1 (u = 0, y = 3): y_hat = 2, e = 1, S = 1.6, K = 0.375,
        //        x_filt = 2.375, P_filt = 0.6 * 1 / 1.6 = 0.375.
        let a = array![[1.0]];
        let b = array![[1.0]];
        let c = array![[1.0]];
        let d = array![[0.5]];
        let q = array![[0.1]];
        let r_mat = array![[1.0]];
        let mu_0 = array![0.0];
        let sigma_0 = array![[1.0]];

        let obs = array![[1.0], [3.0]];
        let controls = array![[2.0], [0.0]];

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert_close(result.means[(0, 0)], 0.0, 1e-12);
        assert_close(result.covariances[(0, 0, 0)], 0.5, 1e-12);
        assert_close(result.pred_means[(1, 0)], 2.0, 1e-12);
        assert_close(result.pred_covariances[(1, 0, 0)], 0.6, 1e-12);
        assert_close(result.means[(1, 0)], 2.375, 1e-12);
        assert_close(result.covariances[(1, 0, 0)], 0.375, 1e-12);
    }

    #[test]
    fn kalman_log_likelihood_finite() {
        // 2-D state, 1-D obs, T=10 random sequence — log-lik must be finite.
        let a = array![[0.9, 0.1], [0.0, 0.95]];
        let b = array![[], []];
        let c = array![[1.0, 0.0]];
        let d = array![[]];
        let q = array![[0.01, 0.0], [0.0, 0.01]];
        let r_mat = array![[0.1]];
        let mu_0 = array![0.0, 0.0];
        let sigma_0 = array![[1.0, 0.0], [0.0, 1.0]];

        let obs = array![
            [0.1],
            [0.2],
            [0.15],
            [0.18],
            [0.22],
            [0.25],
            [0.21],
            [0.24],
            [0.27],
            [0.26],
        ];
        let controls = Array2::<f64>::zeros((10, 0));

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert!(result.log_likelihood.is_finite());
    }
}