1use rayon::prelude::*;
2use nalgebra::{Cholesky, DMatrix};
11use std::f64::consts::PI;
12
13pub fn population_vector_decode(
17 trains: &[&[i32]],
18 preferred_directions: &[f64],
19 window: usize,
20) -> Vec<f64> {
21 if trains.is_empty() || window == 0 {
22 return vec![];
23 }
24 let min_len = trains.iter().map(|t| t.len()).min().unwrap_or(0);
25 let n_bins = min_len / window;
26 if n_bins == 0 {
27 return vec![];
28 }
29 let dirs_cos: Vec<f64> = preferred_directions.iter().map(|&d| d.cos()).collect();
31 let dirs_sin: Vec<f64> = preferred_directions.iter().map(|&d| d.sin()).collect();
32
33 let decoded: Vec<f64> = (0..n_bins)
34 .into_par_iter()
35 .map(|b| {
36 let mut sx = 0.0_f64;
37 let mut sy = 0.0_f64;
38 let start = b * window;
39 let end = (b + 1) * window;
40 for (i, t) in trains.iter().enumerate() {
41 let count: i64 = t[start..end].iter().map(|&v| v as i64).sum();
42 let c = dirs_cos.get(i).copied().unwrap_or(1.0);
43 let s = dirs_sin.get(i).copied().unwrap_or(0.0);
44 sx += count as f64 * c;
45 sy += count as f64 * s;
46 }
47 sy.atan2(sx)
48 })
49 .collect();
50 decoded
51}
52
53pub fn bayesian_decode(
57 spike_counts: &[f64],
58 tuning_rates: &[f64],
59 n_stimuli: usize,
60 n_neurons: usize,
61 prior: &[f64],
62) -> usize {
63 if n_stimuli == 0 || n_neurons == 0 {
64 return 0;
65 }
66 let use_uniform = prior.is_empty();
67 let log_prior_uniform = -(n_stimuli as f64).ln();
68
69 let (best_s, _best_lp) = (0..n_stimuli)
70 .into_par_iter()
71 .map(|s| {
72 let mut lp = if use_uniform {
73 log_prior_uniform
74 } else {
75 (prior.get(s).copied().unwrap_or(1e-30) + 1e-30).ln()
76 };
77 let row_rates = &tuning_rates[s * n_neurons..(s + 1) * n_neurons];
78 let mut j = 0;
79 while j + 3 < n_neurons {
80 let lam0 = row_rates[j].max(1e-10);
81 let lam1 = row_rates[j + 1].max(1e-10);
82 let lam2 = row_rates[j + 2].max(1e-10);
83 let lam3 = row_rates[j + 3].max(1e-10);
84
85 lp += spike_counts[j] * lam0.ln() - lam0;
86 lp += spike_counts[j + 1] * lam1.ln() - lam1;
87 lp += spike_counts[j + 2] * lam2.ln() - lam2;
88 lp += spike_counts[j + 3] * lam3.ln() - lam3;
89 j += 4;
90 }
91 while j < n_neurons {
92 let lam = row_rates[j].max(1e-10);
93 lp += spike_counts[j] * lam.ln() - lam;
94 j += 1;
95 }
96 (s, lp)
97 })
98 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
99 .unwrap_or((0, f64::NEG_INFINITY));
100 best_s
101}
102
103pub fn maximum_likelihood_decode(
105 spike_counts: &[f64],
106 tuning_rates: &[f64],
107 n_stimuli: usize,
108 n_neurons: usize,
109) -> usize {
110 bayesian_decode(spike_counts, tuning_rates, n_stimuli, n_neurons, &[])
111}
112
113pub fn linear_discriminant_decode(
117 train_data: &[f64],
118 n_samples: usize,
119 n_features: usize,
120 labels: &[i64],
121 test_point: &[f64],
122) -> i64 {
123 if n_samples == 0 || n_features == 0 {
124 return 0;
125 }
126
127 let mut classes: Vec<i64> = labels[..n_samples].to_vec();
129 classes.sort();
130 classes.dedup();
131 if classes.len() < 2 {
132 return classes.first().copied().unwrap_or(0);
133 }
134
135 let (class_means, class_indices): (Vec<Vec<f64>>, Vec<Vec<usize>>) = classes
137 .par_iter()
138 .map(|&c| {
139 let indices: Vec<usize> = (0..n_samples).filter(|&i| labels[i] == c).collect();
140 let mut mean = vec![0.0_f64; n_features];
141 for &idx in &indices {
142 let row = &train_data[idx * n_features..(idx + 1) * n_features];
143 for f in 0..n_features {
144 mean[f] += row[f];
145 }
146 }
147 let n_c = indices.len() as f64;
148 for v in &mut mean {
149 *v /= n_c;
150 }
151 (mean, indices)
152 })
153 .unzip();
154
155 let nf = n_features;
157 let mut s_w = vec![0.0_f64; nf * nf];
158 for (ci, indices) in class_indices.iter().enumerate() {
159 let mean = &class_means[ci];
160 for &idx in indices {
161 for i in 0..nf {
162 let di = train_data[idx * nf + i] - mean[i];
163 for j in 0..nf {
164 let dj = train_data[idx * nf + j] - mean[j];
165 s_w[i * nf + j] += di * dj;
166 }
167 }
168 }
169 }
170 for i in 0..nf {
172 s_w[i * nf + i] += 1e-8;
173 }
174
175 let mut overall_mean = vec![0.0_f64; nf];
177 for i in 0..n_samples {
178 for f in 0..nf {
179 overall_mean[f] += train_data[i * nf + f];
180 }
181 }
182 for v in &mut overall_mean {
183 *v /= n_samples as f64;
184 }
185
186 let n_classes = classes.len();
191 let mut diffs = vec![0.0_f64; nf * n_classes];
192 for (ci, mean) in class_means.iter().enumerate() {
193 for f in 0..nf {
194 diffs[f * n_classes + ci] = mean[f] - overall_mean[f];
195 }
196 }
197 let weights = solve_spd(&s_w, &diffs, nf, n_classes);
198
199 let mut best_class = classes[0];
200 let mut best_score = f64::NEG_INFINITY;
201 for (ci, &c) in classes.iter().enumerate() {
202 let score: f64 = (0..nf)
203 .map(|f| weights[f * n_classes + ci] * test_point[f])
204 .sum();
205 if score > best_score {
206 best_score = score;
207 best_class = c;
208 }
209 }
210 best_class
211}
212
/// Classifies one test point with a Gaussian naive-Bayes model.
///
/// Each class gets an independent normal distribution per feature, fitted by
/// maximum likelihood. The variance is the population variance plus `1e-10`, so
/// a constant feature never divides by zero. The class prior is its empirical
/// frequency among the first `n_samples` labels. The returned label maximises
/// `ln P(c) + Σ_f ln N(x_f; μ_cf, σ²_cf)`.
///
/// * `train_data` — `n_samples × n_features`, row-major.
/// * `labels` — the first `n_samples` entries are used.
/// * `test_point` — the first `n_features` entries are used.
///
/// Returns `0` if `n_samples` or `n_features` is zero. On equal scores the
/// smaller label wins. A NaN score is never selected; if every score is NaN,
/// the smallest label is returned.
///
/// # Panics
/// If `labels`, `train_data` or `test_point` is shorter than the stated sizes.
pub fn naive_bayes_decode(
    train_data: &[f64],
    n_samples: usize,
    n_features: usize,
    labels: &[i64],
    test_point: &[f64],
) -> i64 {
    if n_samples == 0 || n_features == 0 {
        return 0;
    }

    let members = &labels[..n_samples];
    let mut distinct = members.to_vec();
    distinct.sort_unstable();
    distinct.dedup();

    let total = n_samples as f64;
    let feature = |row: usize, col: usize| train_data[row * n_features + col];
    let fallback = distinct.first().copied().unwrap_or(0);

    let (winner, _) = distinct.iter().fold(
        (fallback, f64::NEG_INFINITY),
        |(best, best_score), &label| {
            let rows: Vec<usize> = members
                .iter()
                .enumerate()
                .filter_map(|(i, &l)| (l == label).then_some(i))
                .collect();
            let count = rows.len() as f64;

            let log_likelihood = (0..n_features).fold(0.0_f64, |acc, col| {
                let mean = rows.iter().map(|&r| feature(r, col)).sum::<f64>() / count;
                let variance = rows
                    .iter()
                    .map(|&r| (feature(r, col) - mean).powi(2))
                    .sum::<f64>()
                    / count
                    + 1e-10;
                let dx = test_point[col] - mean;
                acc + -0.5 * ((2.0 * PI * variance).ln() + dx.powi(2) / variance)
            });

            let score = (count / total).ln() + log_likelihood;
            // Strict `>` keeps the earlier (smaller) label on ties and skips NaN.
            if score > best_score {
                (label, score)
            } else {
                (best, best_score)
            }
        },
    );
    winner
}
260
261fn solve_spd(a: &[f64], b: &[f64], n: usize, m: usize) -> Vec<f64> {
269 let a_mat = DMatrix::<f64>::from_row_slice(n, n, a);
270 let b_mat = DMatrix::<f64>::from_row_slice(n, m, b);
271 match Cholesky::new(a_mat) {
272 Some(chol) => {
273 let x = chol.solve(&b_mat);
274 let mut out = vec![0.0_f64; n * m];
275 for i in 0..n {
276 for j in 0..m {
277 out[i * m + j] = x[(i, j)];
278 }
279 }
280 out
281 }
282 None => vec![0.0_f64; n * m],
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    #[test]
    fn test_pv_single_neuron_right() {
        // One neuron tuned to 0 rad firing in every sample: both 50-sample bins
        // point along 0 rad.
        let spikes = vec![1i32; 100];
        let trains: Vec<&[i32]> = vec![&spikes];
        let angles = population_vector_decode(&trains, &[0.0_f64], 50);
        assert_eq!(angles.len(), 2);
        assert!((angles[0] - 0.0).abs() < TOL, "should decode to 0 rad");
    }

    #[test]
    fn test_pv_two_neurons_45deg() {
        // Identical activity in neurons tuned to 0 and π/2 gives a resultant at π/4.
        let spikes = vec![1i32; 100];
        let trains: Vec<&[i32]> = vec![&spikes, &spikes];
        let directions = [0.0, PI / 2.0];
        let angles = population_vector_decode(&trains, &directions, 100);
        assert_eq!(angles.len(), 1);
        assert!(
            (angles[0] - PI / 4.0).abs() < TOL,
            "equal firing at 0 and π/2 → π/4, got {}",
            angles[0]
        );
    }

    #[test]
    fn test_pv_empty() {
        let angles = population_vector_decode(&[], &[], 50);
        assert!(angles.is_empty());
    }

    #[test]
    fn test_pv_no_bins() {
        let spikes = vec![1i32; 10];
        let trains: Vec<&[i32]> = vec![&spikes];
        let angles = population_vector_decode(&trains, &[0.0], 100);
        assert!(angles.is_empty(), "train shorter than window → empty");
    }

    #[test]
    fn test_bayesian_obvious() {
        // Row-major tuning: stimulus 0 drives neuron 0, stimulus 1 drives neuron 1.
        let tuning = [10.0, 0.1, 0.1, 10.0];
        let counts = [8.0, 0.0];
        let stimulus = bayesian_decode(&counts, &tuning, 2, 2, &[]);
        assert_eq!(stimulus, 0, "high neuron 0 firing → stimulus 0");
    }

    #[test]
    fn test_bayesian_with_prior() {
        // Flat tuning carries no evidence, so the prior alone decides.
        let tuning = [5.0; 4];
        let counts = [5.0, 5.0];
        let prior = [0.1, 0.9];
        let stimulus = bayesian_decode(&counts, &tuning, 2, 2, &prior);
        assert_eq!(stimulus, 1, "equal evidence + strong prior → stimulus 1");
    }

    #[test]
    fn test_bayesian_empty() {
        assert_eq!(bayesian_decode(&[], &[], 0, 0, &[]), 0);
    }

    #[test]
    fn test_ml_matches_bayesian_uniform() {
        let tuning = [10.0, 0.1, 0.1, 10.0];
        let counts = [0.0, 8.0];
        let ml = maximum_likelihood_decode(&counts, &tuning, 2, 2);
        let map = bayesian_decode(&counts, &tuning, 2, 2, &[]);
        assert_eq!(ml, map);
        assert_eq!(ml, 1);
    }

    #[test]
    fn test_lda_separable() {
        #[rustfmt::skip]
        let data = vec![
            0.0, 0.0,
            0.1, 0.1,
            -0.1, 0.1,
            10.0, 10.0,
            10.1, 9.9,
            9.9, 10.1,
        ];
        let labels = vec![0_i64, 0, 0, 1, 1, 1];
        assert_eq!(
            linear_discriminant_decode(&data, 6, 2, &labels, &[10.0, 10.0]),
            1
        );
        let far_low = linear_discriminant_decode(&data, 6, 2, &labels, &[-5.0, -5.0]);
        let far_high = linear_discriminant_decode(&data, 6, 2, &labels, &[15.0, 15.0]);
        assert_ne!(
            far_low, far_high,
            "distant points should decode to different classes"
        );
    }

    #[test]
    fn test_lda_single_class() {
        let data = [1.0, 2.0, 3.0, 4.0];
        let labels = [5_i64, 5];
        assert_eq!(
            linear_discriminant_decode(&data, 2, 2, &labels, &[2.0, 3.0]),
            5
        );
    }

    #[test]
    fn test_lda_empty() {
        assert_eq!(linear_discriminant_decode(&[], 0, 0, &[], &[]), 0);
    }

    #[test]
    fn test_nb_separable() {
        #[rustfmt::skip]
        let data = vec![
            0.0, 0.0,
            0.1, 0.1,
            -0.1, -0.1,
            10.0, 10.0,
            10.1, 10.1,
            9.9, 9.9,
        ];
        let labels = vec![0_i64, 0, 0, 1, 1, 1];
        assert_eq!(naive_bayes_decode(&data, 6, 2, &labels, &[0.2, 0.2]), 0);
        assert_eq!(naive_bayes_decode(&data, 6, 2, &labels, &[9.8, 9.8]), 1);
    }

    #[test]
    fn test_nb_single_class() {
        let data = [1.0, 2.0];
        let labels = [7_i64];
        assert_eq!(naive_bayes_decode(&data, 1, 2, &labels, &[1.0, 2.0]), 7);
    }

    #[test]
    fn test_nb_agrees_with_lda_simple() {
        #[rustfmt::skip]
        let data = vec![
            -5.0, -5.0,
            -4.9, -5.1,
            5.0, 5.0,
            5.1, 4.9,
        ];
        let labels = vec![0_i64, 0, 1, 1];
        let probe = [4.0, 4.0];
        let lda = linear_discriminant_decode(&data, 4, 2, &labels, &probe);
        let nb = naive_bayes_decode(&data, 4, 2, &labels, &probe);
        assert_eq!(lda, nb, "well-separated → both predict same class");
    }

    #[test]
    fn test_solve_spd_2x2() {
        // [[2, 1], [1, 3]] · [1, 3]ᵀ = [5, 10]ᵀ
        let a = [2.0, 1.0, 1.0, 3.0];
        let b = [5.0, 10.0];
        let x = solve_spd(&a, &b, 2, 1);
        assert!((x[0] - 1.0).abs() < TOL);
        assert!((x[1] - 3.0).abs() < TOL);
    }

    #[test]
    fn test_solve_spd_multi_rhs() {
        // diag(2, 4) with two right-hand sides; the expected solution is row-major.
        let a = [2.0, 0.0, 0.0, 4.0];
        let b = [2.0, 4.0, 4.0, 8.0];
        let x = solve_spd(&a, &b, 2, 2);
        assert_eq!(x.len(), 4);
        for (got, want) in x.iter().zip([1.0, 2.0, 1.0, 2.0]) {
            assert!((got - want).abs() < TOL);
        }
    }

    #[test]
    fn test_solve_spd_non_pd_falls_back_to_zero() {
        // [[0, 1], [1, 0]] is indefinite, so the Cholesky factorisation fails.
        let a = [0.0, 1.0, 1.0, 0.0];
        let b = [1.0, 1.0];
        assert_eq!(solve_spd(&a, &b, 2, 1), vec![0.0, 0.0]);
    }
}