use rayon::prelude::*;
use std::f64::consts::PI;

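/// Population-vector decoding of direction from binned spike counts.
///
/// Each train is split into non-overlapping bins of `window` samples; within a
/// bin, every neuron's spike count weights the unit vector of its preferred
/// direction (radians), and the decoded angle is the `atan2` of the summed
/// vector. Returns one angle per complete bin, or an empty vector if `trains`
/// is empty, `window` is zero, or no complete bin fits the shortest train.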
pub fn population_vector_decode(
    trains: &[&[i32]],
    preferred_directions: &[f64],
    window: usize,
) -> Vec<f64> {
    if trains.is_empty() || window == 0 {
        return vec![];
    }
    // Only complete bins that fit the shortest train are decoded.
    let min_len = trains.iter().map(|t| t.len()).min().unwrap_or(0);
    let n_bins = min_len / window;
    if n_bins == 0 {
        return vec![];
    }
    // Precompute the unit-vector components of each preferred direction.
    let dirs_cos: Vec<f64> = preferred_directions.iter().map(|&d| d.cos()).collect();
    let dirs_sin: Vec<f64> = preferred_directions.iter().map(|&d| d.sin()).collect();

    let decoded: Vec<f64> = (0..n_bins)
        .into_par_iter()
        .map(|b| {
            let mut sx = 0.0_f64;
            let mut sy = 0.0_f64;
            let start = b * window;
            let end = (b + 1) * window;
            for (i, t) in trains.iter().enumerate() {
                let count: i64 = t[start..end].iter().map(|&v| v as i64).sum();
                // Neurons without a listed preferred direction default to 0 rad.
                let c = dirs_cos.get(i).copied().unwrap_or(1.0);
                let s = dirs_sin.get(i).copied().unwrap_or(0.0);
                sx += count as f64 * c;
                sy += count as f64 * s;
            }
            sy.atan2(sx)
        })
        .collect();
    decoded
}

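/// Bayesian decoding under an independent-Poisson spiking model.
///
/// `tuning_rates` is a row-major `n_stimuli x n_neurons` matrix of expected
/// rates. For each candidate stimulus, the (unnormalised) log posterior is the
/// log prior plus the Poisson log likelihood of `spike_counts` (the
/// count-factorial constant is dropped); the index of the maximum is returned.
/// An empty `prior` means a uniform prior over stimuli.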
pub fn bayesian_decode(
    spike_counts: &[f64],
    tuning_rates: &[f64],
    n_stimuli: usize,
    n_neurons: usize,
    prior: &[f64],
) -> usize {
    if n_stimuli == 0 || n_neurons == 0 {
        return 0;
    }
    let use_uniform = prior.is_empty();
    let log_prior_uniform = -(n_stimuli as f64).ln();

    let (best_s, _best_lp) = (0..n_stimuli)
        .into_par_iter()
        .map(|s| {
            let mut lp = if use_uniform {
                log_prior_uniform
            } else {
                // Clamp the prior away from zero so the log stays finite.
                (prior.get(s).copied().unwrap_or(1e-30) + 1e-30).ln()
            };
            let row_rates = &tuning_rates[s * n_neurons..(s + 1) * n_neurons];
            // Poisson log likelihood: sum_j count_j * ln(lambda_j) - lambda_j,
            // accumulated four neurons at a time with a scalar remainder loop.
            let mut j = 0;
            while j + 3 < n_neurons {
                let lam0 = row_rates[j].max(1e-10);
                let lam1 = row_rates[j + 1].max(1e-10);
                let lam2 = row_rates[j + 2].max(1e-10);
                let lam3 = row_rates[j + 3].max(1e-10);

                lp += spike_counts[j] * lam0.ln() - lam0;
                lp += spike_counts[j + 1] * lam1.ln() - lam1;
                lp += spike_counts[j + 2] * lam2.ln() - lam2;
                lp += spike_counts[j + 3] * lam3.ln() - lam3;
                j += 4;
            }
            while j < n_neurons {
                let lam = row_rates[j].max(1e-10);
                lp += spike_counts[j] * lam.ln() - lam;
                j += 1;
            }
            (s, lp)
        })
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap_or((0, f64::NEG_INFINITY));
    best_s
}

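/// Maximum-likelihood decoding: Bayesian decoding with a uniform prior.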
pub fn maximum_likelihood_decode(
    spike_counts: &[f64],
    tuning_rates: &[f64],
    n_stimuli: usize,
    n_neurons: usize,
) -> usize {
    bayesian_decode(spike_counts, tuning_rates, n_stimuli, n_neurons, &[])
}

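/// Linear discriminant classification of a single test point.
///
/// `train_data` is a row-major `n_samples x n_features` matrix. Per-class means
/// and a regularised within-class scatter matrix are computed; each class gets
/// a projection direction `S_w^-1 (mean_c - overall_mean)`, and the class whose
/// direction yields the largest score `w · test_point` is returned.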
pub fn linear_discriminant_decode(
    train_data: &[f64],
    n_samples: usize,
    n_features: usize,
    labels: &[i64],
    test_point: &[f64],
) -> i64 {
    if n_samples == 0 || n_features == 0 {
        return 0;
    }

    let mut classes: Vec<i64> = labels[..n_samples].to_vec();
    classes.sort();
    classes.dedup();
    if classes.len() < 2 {
        return classes.first().copied().unwrap_or(0);
    }

    // Per-class means and the sample indices belonging to each class.
    let (class_means, class_indices): (Vec<Vec<f64>>, Vec<Vec<usize>>) = classes
        .par_iter()
        .map(|&c| {
            let indices: Vec<usize> = (0..n_samples).filter(|&i| labels[i] == c).collect();
            let mut mean = vec![0.0_f64; n_features];
            for &idx in &indices {
                let row = &train_data[idx * n_features..(idx + 1) * n_features];
                for f in 0..n_features {
                    mean[f] += row[f];
                }
            }
            let n_c = indices.len() as f64;
            for v in &mut mean {
                *v /= n_c;
            }
            (mean, indices)
        })
        .unzip();

    // Within-class scatter matrix S_w, with a small ridge on the diagonal so it
    // stays invertible even for degenerate data.
    let nf = n_features;
    let mut s_w = vec![0.0_f64; nf * nf];
    for (ci, indices) in class_indices.iter().enumerate() {
        let mean = &class_means[ci];
        for &idx in indices {
            for i in 0..nf {
                let di = train_data[idx * nf + i] - mean[i];
                for j in 0..nf {
                    let dj = train_data[idx * nf + j] - mean[j];
                    s_w[i * nf + j] += di * dj;
                }
            }
        }
    }
    for i in 0..nf {
        s_w[i * nf + i] += 1e-8;
    }

    let mut overall_mean = vec![0.0_f64; nf];
    for i in 0..n_samples {
        for f in 0..nf {
            overall_mean[f] += train_data[i * nf + f];
        }
    }
    for v in &mut overall_mean {
        *v /= n_samples as f64;
    }

    // Score each class by projecting the test point onto S_w^-1 (mean_c - overall_mean).
    let mut best_class = classes[0];
    let mut best_score = f64::NEG_INFINITY;

    for (ci, &c) in classes.iter().enumerate() {
        let diff: Vec<f64> = (0..nf)
            .map(|f| class_means[ci][f] - overall_mean[f])
            .collect();
        let w = solve_linear(&s_w, &diff, nf);
        let score: f64 = (0..nf).map(|f| w[f] * test_point[f]).sum();
        if score > best_score {
            best_score = score;
            best_class = c;
        }
    }
    best_class
}

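/// Gaussian naive Bayes classification of a single test point.
///
/// For each class, per-feature means and variances are estimated from the rows
/// of `train_data` carrying that label; the test point is assigned to the class
/// maximising log prior plus the sum of per-feature Gaussian log densities.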
pub fn naive_bayes_decode(
    train_data: &[f64],
    n_samples: usize,
    n_features: usize,
    labels: &[i64],
    test_point: &[f64],
) -> i64 {
    if n_samples == 0 || n_features == 0 {
        return 0;
    }

    let mut classes: Vec<i64> = labels[..n_samples].to_vec();
    classes.sort();
    classes.dedup();

    let mut best_class = classes.first().copied().unwrap_or(0);
    let mut best_log_p = f64::NEG_INFINITY;

    for &c in &classes {
        let indices: Vec<usize> = (0..n_samples).filter(|&i| labels[i] == c).collect();
        let n_c = indices.len() as f64;
        let log_prior = (n_c / n_samples as f64).ln();

        // Features are treated as independent Gaussians within each class; a
        // small floor on the variance guards against zero-variance features.
        let mut log_likelihood = 0.0_f64;
        for f in 0..n_features {
            let vals: Vec<f64> = indices
                .iter()
                .map(|&i| train_data[i * n_features + f])
                .collect();
            let mu: f64 = vals.iter().sum::<f64>() / n_c;
            let var: f64 = vals.iter().map(|&v| (v - mu).powi(2)).sum::<f64>() / n_c + 1e-10;
            let x = test_point[f];
            log_likelihood += -0.5 * ((2.0 * PI * var).ln() + (x - mu).powi(2) / var);
        }

        let log_p = log_prior + log_likelihood;
        if log_p > best_log_p {
            best_log_p = log_p;
            best_class = c;
        }
    }
    best_class
}

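/// Solves the linear system `a * x = b`, with `a` a row-major `n x n` matrix,
/// by Gaussian elimination with partial pivoting followed by back-substitution.
/// Near-singular pivots are skipped, yielding zeros for the affected unknowns.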
fn solve_linear(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    // Build the augmented matrix [a | b].
    let mut aug = vec![0.0_f64; n * (n + 1)];
    for i in 0..n {
        for j in 0..n {
            aug[i * (n + 1) + j] = a[i * n + j];
        }
        aug[i * (n + 1) + n] = b[i];
    }
    let stride = n + 1;
    // Forward elimination with partial pivoting.
    for col in 0..n {
        let mut max_row = col;
        let mut max_val = aug[col * stride + col].abs();
        for row in (col + 1)..n {
            let v = aug[row * stride + col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_row != col {
            for j in 0..stride {
                aug.swap(col * stride + j, max_row * stride + j);
            }
        }
        let pivot = aug[col * stride + col];
        if pivot.abs() < 1e-30 {
            continue;
        }
        for row in (col + 1)..n {
            let factor = aug[row * stride + col] / pivot;
            for j in col..stride {
                aug[row * stride + j] -= factor * aug[col * stride + j];
            }
        }
    }
    // Back-substitution.
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let mut sum = aug[i * stride + n];
        for j in (i + 1)..n {
            sum -= aug[i * stride + j] * x[j];
        }
        let diag = aug[i * stride + i];
        x[i] = if diag.abs() > 1e-30 { sum / diag } else { 0.0 };
    }
    x
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pv_single_neuron_right() {
        let train = vec![1i32; 100];
        let trains: Vec<&[i32]> = vec![&train];
        let dirs = vec![0.0_f64];
        let decoded = population_vector_decode(&trains, &dirs, 50);
        assert_eq!(decoded.len(), 2);
        assert!((decoded[0] - 0.0).abs() < 1e-10, "should decode to 0 rad");
    }

    #[test]
    fn test_pv_two_neurons_45deg() {
        let train = vec![1i32; 100];
        let trains: Vec<&[i32]> = vec![&train, &train];
        let dirs = vec![0.0, PI / 2.0];
        let decoded = population_vector_decode(&trains, &dirs, 100);
        assert_eq!(decoded.len(), 1);
        assert!(
            (decoded[0] - PI / 4.0).abs() < 1e-10,
            "equal firing at 0 and π/2 → π/4, got {}",
            decoded[0]
        );
    }

    #[test]
    fn test_pv_empty() {
        let decoded = population_vector_decode(&[], &[], 50);
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_pv_no_bins() {
        let train = vec![1i32; 10];
        let trains: Vec<&[i32]> = vec![&train];
        let decoded = population_vector_decode(&trains, &[0.0], 100);
        assert!(decoded.is_empty(), "train shorter than window → empty");
    }

    #[test]
    fn test_bayesian_obvious() {
        let tuning = vec![10.0, 0.1, 0.1, 10.0];
        let counts = vec![8.0, 0.0];
        let s = bayesian_decode(&counts, &tuning, 2, 2, &[]);
        assert_eq!(s, 0, "high neuron 0 firing → stimulus 0");
    }

    #[test]
    fn test_bayesian_with_prior() {
        let tuning = vec![5.0, 5.0, 5.0, 5.0];
        let counts = vec![5.0, 5.0];
        let prior = vec![0.1, 0.9];
        let s = bayesian_decode(&counts, &tuning, 2, 2, &prior);
        assert_eq!(s, 1, "equal evidence + strong prior → stimulus 1");
    }

    #[test]
    fn test_bayesian_empty() {
        assert_eq!(bayesian_decode(&[], &[], 0, 0, &[]), 0);
    }

    #[test]
    fn test_ml_matches_bayesian_uniform() {
        let tuning = vec![10.0, 0.1, 0.1, 10.0];
        let counts = vec![0.0, 8.0];
        let s_ml = maximum_likelihood_decode(&counts, &tuning, 2, 2);
        let s_bay = bayesian_decode(&counts, &tuning, 2, 2, &[]);
        assert_eq!(s_ml, s_bay);
        assert_eq!(s_ml, 1);
    }

    #[test]
    fn test_lda_separable() {
        #[rustfmt::skip]
        let data = vec![
            0.0, 0.0,
            0.1, 0.1,
            -0.1, 0.1,
            10.0, 10.0,
            10.1, 9.9,
            9.9, 10.1,
        ];
        let labels = vec![0_i64, 0, 0, 1, 1, 1];
        let test_1 = vec![10.0, 10.0];
        assert_eq!(linear_discriminant_decode(&data, 6, 2, &labels, &test_1), 1);
        let r0 = linear_discriminant_decode(&data, 6, 2, &labels, &[-5.0, -5.0]);
        let r1 = linear_discriminant_decode(&data, 6, 2, &labels, &[15.0, 15.0]);
        assert_ne!(r0, r1, "distant points should decode to different classes");
    }

    #[test]
    fn test_lda_single_class() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let labels = vec![5_i64, 5];
        let test = vec![2.0, 3.0];
        assert_eq!(linear_discriminant_decode(&data, 2, 2, &labels, &test), 5);
    }

    #[test]
    fn test_lda_empty() {
        assert_eq!(linear_discriminant_decode(&[], 0, 0, &[], &[]), 0);
    }

    #[test]
    fn test_nb_separable() {
        #[rustfmt::skip]
        let data = vec![
            0.0, 0.0,
            0.1, 0.1,
            -0.1, -0.1,
            10.0, 10.0,
            10.1, 10.1,
            9.9, 9.9,
        ];
        let labels = vec![0_i64, 0, 0, 1, 1, 1];
        let test_0 = vec![0.2, 0.2];
        let test_1 = vec![9.8, 9.8];
        assert_eq!(naive_bayes_decode(&data, 6, 2, &labels, &test_0), 0);
        assert_eq!(naive_bayes_decode(&data, 6, 2, &labels, &test_1), 1);
    }

    #[test]
    fn test_nb_single_class() {
        let data = vec![1.0, 2.0];
        let labels = vec![7_i64];
        assert_eq!(naive_bayes_decode(&data, 1, 2, &labels, &[1.0, 2.0]), 7);
    }

    #[test]
    fn test_nb_agrees_with_lda_simple() {
        #[rustfmt::skip]
        let data = vec![
            -5.0, -5.0,
            -4.9, -5.1,
            5.0, 5.0,
            5.1, 4.9,
        ];
        let labels = vec![0_i64, 0, 1, 1];
        let test = vec![4.0, 4.0];
        let lda = linear_discriminant_decode(&data, 4, 2, &labels, &test);
        let nb = naive_bayes_decode(&data, 4, 2, &labels, &test);
        assert_eq!(lda, nb, "well-separated → both predict same class");
    }

    #[test]
    fn test_solve_2x2() {
        let a = vec![2.0, 1.0, 1.0, 3.0];
        let b = vec![5.0, 10.0];
        let x = solve_linear(&a, &b, 2);
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 3.0).abs() < 1e-10);
    }
}