1use nalgebra::{Cholesky as NaCholesky, DMatrix, DVector};
31use ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
32
33pub struct KalmanResult {
35 pub means: Array2<f64>, pub covariances: Array3<f64>, pub pred_means: Array2<f64>, pub pred_covariances: Array3<f64>, pub log_likelihood: f64,
40}
41
42pub fn kalman_filter(
52 observations: ArrayView2<f64>, controls: ArrayView2<f64>, a: ArrayView2<f64>, b: ArrayView2<f64>, c: ArrayView2<f64>, d: ArrayView2<f64>, q: ArrayView2<f64>, r: ArrayView2<f64>, mu_0: ArrayView1<f64>, sigma_0: ArrayView2<f64>, ) -> KalmanResult {
63 let t_len = observations.nrows();
64 let p_dim = observations.ncols();
65 let d_dim = a.nrows();
66 let m_dim = b.ncols();
67
68 let has_control = m_dim > 0;
69
70 let mut means = Array2::<f64>::zeros((t_len, d_dim));
71 let mut covs = Array3::<f64>::zeros((t_len, d_dim, d_dim));
72 let mut pred_means = Array2::<f64>::zeros((t_len, d_dim));
73 let mut pred_covs = Array3::<f64>::zeros((t_len, d_dim, d_dim));
74
75 let mut x_pred: Array1<f64> = mu_0.to_owned();
76 let mut p_pred: Array2<f64> = sigma_0.to_owned();
77
78 let mut log_lik = 0.0_f64;
79 let two_pi_log = (2.0 * std::f64::consts::PI).ln();
80 let i_d = Array2::<f64>::eye(d_dim);
81
82 for t in 0..t_len {
83 pred_means.slice_mut(s![t, ..]).assign(&x_pred);
85 pred_covs.slice_mut(s![t, .., ..]).assign(&p_pred);
86
87 let y_t = observations.slice(s![t, ..]);
88
89 let mut y_hat = c.dot(&x_pred);
91 if has_control {
92 let u_t = controls.slice(s![t, ..]);
93 y_hat = y_hat + d.dot(&u_t);
94 }
95 let innov = &y_t - &y_hat;
96
97 let s_mat = c.dot(&p_pred).dot(&c.t()) + r;
99
100 let s_na = DMatrix::<f64>::from_fn(p_dim, p_dim, |i, j| s_mat[(i, j)]);
105 let (logdet_s, s_inv_innov, k_gain) = match NaCholesky::new(s_na) {
106 Some(chol) => {
107 let l = chol.l();
109 let logdet = 2.0 * (0..p_dim).map(|i| l[(i, i)].ln()).sum::<f64>();
110 let innov_na = DVector::<f64>::from_fn(p_dim, |i, _| innov[i]);
112 let z = chol.solve(&innov_na);
113 let s_inv_innov = Array1::<f64>::from_iter((0..p_dim).map(|i| z[i]));
114 let cp = c.dot(&p_pred); let cp_na = DMatrix::<f64>::from_fn(p_dim, d_dim, |i, j| cp[(i, j)]);
119 let x = chol.solve(&cp_na); let k = Array2::<f64>::from_shape_fn((d_dim, p_dim), |(i, j)| x[(j, i)]);
121 (logdet, s_inv_innov, k)
122 }
123 None => {
124 (
128 f64::NAN,
129 Array1::<f64>::zeros(p_dim),
130 Array2::<f64>::zeros((d_dim, p_dim)),
131 )
132 }
133 };
134
135 let quad_form = innov.dot(&s_inv_innov);
136 log_lik += -0.5 * (p_dim as f64 * two_pi_log + logdet_s + quad_form);
137
138 let x_filt = &x_pred + &k_gain.dot(&innov);
140
141 let i_minus_kc = &i_d - &k_gain.dot(&c);
144 let p_filt = i_minus_kc.dot(&p_pred).dot(&i_minus_kc.t()) + k_gain.dot(&r).dot(&k_gain.t());
145
146 means.slice_mut(s![t, ..]).assign(&x_filt);
147 covs.slice_mut(s![t, .., ..]).assign(&p_filt);
148
149 let mut x_next = a.dot(&x_filt);
151 if has_control {
152 let u_t = controls.slice(s![t, ..]);
153 x_next = x_next + b.dot(&u_t);
154 }
155 let p_next = a.dot(&p_filt).dot(&a.t()) + q;
156 x_pred = x_next;
157 p_pred = p_next;
158 }
159
160 KalmanResult {
161 means,
162 covariances: covs,
163 pred_means,
164 pred_covariances: pred_covs,
165 log_likelihood: log_lik,
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Matrix with `rows` rows and no columns, used for the control-related
    /// inputs of models that have no control term.
    fn no_columns(rows: usize) -> Array2<f64> {
        Array2::<f64>::zeros((rows, 0))
    }

    /// Scalar model with a single observation, checked against hand-computed
    /// values: S = 2, K = 0.5, filtered mean = 0.5, filtered variance = 0.5.
    #[test]
    fn kalman_scalar_random_walk_matches_analytic() {
        let obs = array![[1.0_f64]];
        let controls = no_columns(1);

        let a = array![[1.0]];
        let c = array![[1.0]];
        let b = no_columns(1);
        let d = no_columns(1);
        let q = array![[0.1]];
        let r_mat = array![[1.0]];
        let mu_0 = array![0.0];
        let sigma_0 = array![[1.0]];

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert!((result.means[(0, 0)] - 0.5).abs() < 1e-12);
        assert!((result.covariances[(0, 0, 0)] - 0.5).abs() < 1e-12);

        // -0.5 * (ln(2*pi) + ln(S) + innovation^2 / S) with S = 2, innovation = 1.
        let expected_ll = -0.5 * ((2.0 * std::f64::consts::PI).ln() + 2.0_f64.ln() + 0.5);
        assert!((result.log_likelihood - expected_ll).abs() < 1e-12);
    }

    /// Two-dimensional state observed through its first component: the
    /// accumulated log-likelihood must stay finite.
    #[test]
    fn kalman_log_likelihood_finite() {
        let obs = array![
            [0.1],
            [0.2],
            [0.15],
            [0.18],
            [0.22],
            [0.25],
            [0.21],
            [0.24],
            [0.27],
            [0.26],
        ];
        let controls = no_columns(10);

        let a = array![[0.9, 0.1], [0.0, 0.95]];
        let c = array![[1.0, 0.0]];
        let b = no_columns(2);
        let d = no_columns(1);
        let q = array![[0.01, 0.0], [0.0, 0.01]];
        let r_mat = array![[0.1]];
        let mu_0 = array![0.0, 0.0];
        let sigma_0 = array![[1.0, 0.0], [0.0, 1.0]];

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert!(result.log_likelihood.is_finite());
    }

    /// Two-dimensional observations with correlated noise: filtered means are
    /// finite and every filtered covariance is symmetric and positive
    /// semi-definite.
    #[test]
    fn kalman_two_dim_obs_symmetric_psd_and_finite() {
        let obs = array![
            [0.10, 0.05],
            [0.20, 0.12],
            [0.18, 0.09],
            [0.25, 0.15],
            [0.30, 0.20],
        ];
        let controls = no_columns(5);

        let a = array![[0.95, 0.0], [0.1, 0.9]];
        let c = array![[1.0, 0.2], [0.0, 1.0]];
        let b = no_columns(2);
        let d = no_columns(2);
        let q = array![[0.02, 0.0], [0.0, 0.02]];
        let r_mat = array![[0.15, 0.05], [0.05, 0.2]];
        let mu_0 = array![0.0, 0.0];
        let sigma_0 = array![[1.0, 0.0], [0.0, 1.0]];

        let result = kalman_filter(
            obs.view(),
            controls.view(),
            a.view(),
            b.view(),
            c.view(),
            d.view(),
            q.view(),
            r_mat.view(),
            mu_0.view(),
            sigma_0.view(),
        );

        assert!(result.log_likelihood.is_finite());

        for t in 0..obs.nrows() {
            for component in 0..2 {
                assert!(result.means[(t, component)].is_finite());
            }

            let p00 = result.covariances[(t, 0, 0)];
            let p01 = result.covariances[(t, 0, 1)];
            let p10 = result.covariances[(t, 1, 0)];
            let p11 = result.covariances[(t, 1, 1)];

            let asymmetry = (p01 - p10).abs();
            assert!(asymmetry < 1e-12, "covariance not symmetric at t={t}");

            assert!(p00 >= 0.0 && p11 >= 0.0, "negative variance at t={t}");

            // A 2x2 symmetric matrix with non-negative diagonal is PSD when
            // its determinant is non-negative (up to round-off).
            let determinant = p00 * p11 - p01 * p10;
            assert!(determinant >= -1e-12, "covariance not PSD at t={t}");
        }
    }
}