1use crate::bitstream;
15
16pub fn prediction_error_packed(predicted: &[u64], actual: &[u64], length: usize) -> f64 {
19 if length == 0 {
20 return 0.0;
21 }
22 let n = predicted.len().min(actual.len());
23 let mut xor_result = vec![0u64; n];
24 for i in 0..n {
25 xor_result[i] = predicted[i] ^ actual[i];
26 }
27 let hamming = bitstream::popcount_words_portable(&xor_result);
28 hamming as f64 / length as f64
29}
30
31pub fn batch_prediction_error(
34 predicted: &[Vec<u64>], actual: &[Vec<u64>], n_neurons: usize,
37 n_inputs: usize,
38 length: usize,
39) -> Vec<f64> {
40 let mut surprises = vec![0.0f64; n_neurons];
41 for j in 0..n_neurons {
42 let mut total_error = 0.0;
43 for i in 0..n_inputs {
44 let pred_idx = j * n_inputs + i;
45 if pred_idx < predicted.len() && i < actual.len() {
46 total_error += prediction_error_packed(&predicted[pred_idx], &actual[i], length);
47 }
48 }
49 surprises[j] = total_error / n_inputs.max(1) as f64;
50 }
51 surprises
52}
53
/// Moves each prediction weight toward the observed probability, then
/// clamps it to `[0.0, 1.0]`.
///
/// `weights` is laid out as `weights[neuron * n_inputs + input]`. Every entry
/// for which both `weights` and `actual_probs` have a value is updated with
/// `w <- clamp(w + lr * (actual_probs[input] - w), 0, 1)`. Indices beyond
/// either slice are left untouched.
pub fn update_prediction_weights(
    weights: &mut [f64],
    actual_probs: &[f64],
    n_neurons: usize,
    n_inputs: usize,
    lr: f64,
) {
    // Inputs without an observed probability are skipped, so only this
    // prefix of `actual_probs` can ever be used.
    let usable = n_inputs.min(actual_probs.len());
    for neuron in 0..n_neurons {
        let row_start = neuron * n_inputs;
        for (input, &target) in actual_probs[..usable].iter().enumerate() {
            if let Some(w) = weights.get_mut(row_start + input) {
                let moved = *w + lr * (target - *w);
                *w = moved.clamp(0.0, 1.0);
            }
        }
    }
}
72
73use crate::encoder::Lfsr16;
76
/// Encodes a spike raster as XOR residuals against a per-channel EMA predictor.
///
/// `spikes` is row-major: sample `t` of channel `ch` is at
/// `spikes[t * n_channels + ch]`. For each sample:
/// - The prediction is `1` if the channel's running rate is strictly above
///   `threshold`, otherwise `0`. Every rate starts at `0.0`.
/// - The residual `actual ^ predicted` is written to the output.
/// - The rate is updated as `(1 - alpha) * rate + alpha * actual`.
///
/// Returns `(errors, correct)`. `errors` has the same length as `spikes`, and
/// `correct` counts the samples whose residual is `0`. [`xor_and_recover_ema`]
/// with the same `n_channels`, `alpha` and `threshold` inverts this.
///
/// Only complete rows are encoded. If `spikes.len()` is not a multiple of
/// `n_channels`, the trailing partial row is left as `0` in `errors` and is
/// not counted. With `n_channels == 0` no row can be formed, so all errors
/// are `0` and `correct` is `0`. (Previously this case panicked with a
/// divide-by-zero.)
pub fn predict_and_xor_ema(
    spikes: &[i8],
    n_channels: usize,
    alpha: f64,
    threshold: f64,
) -> (Vec<i8>, usize) {
    let mut errors = vec![0i8; spikes.len()];
    if n_channels == 0 {
        return (errors, 0);
    }

    let mut rates = vec![0.0f64; n_channels];
    let mut correct: usize = 0;
    let one_minus_alpha = 1.0 - alpha;

    // `chunks_exact` yields only full rows, matching the `len / n_channels`
    // step count; any remainder stays zero in `errors`.
    for (row_in, row_out) in spikes
        .chunks_exact(n_channels)
        .zip(errors.chunks_exact_mut(n_channels))
    {
        for ((&actual, err_slot), rate) in row_in
            .iter()
            .zip(row_out.iter_mut())
            .zip(rates.iter_mut())
        {
            let predicted = i8::from(*rate > threshold);
            let err = actual ^ predicted;
            *err_slot = err;
            if err == 0 {
                correct += 1;
            }
            // Must stay bit-identical to the decoder's update, or the
            // decoder's predictions would drift from the encoder's.
            *rate = one_minus_alpha * *rate + alpha * f64::from(actual);
        }
    }
    (errors, correct)
}
106
/// Reconstructs a spike raster from XOR residuals produced by
/// [`predict_and_xor_ema`].
///
/// It runs the same per-channel EMA predictor as the encoder, with rates
/// starting at `0.0` and predicting `1` when a rate is strictly above
/// `threshold`. Each spike is recovered as `error ^ predicted`, and the rate
/// is then updated from the recovered value. `errors` is row-major:
/// `errors[t * n_channels + ch]`.
///
/// Only complete rows are decoded. A trailing partial row, or every entry
/// when `n_channels == 0`, is returned as `0`. (The zero-channel case
/// previously panicked with a divide-by-zero.)
pub fn xor_and_recover_ema(
    errors: &[i8],
    n_channels: usize,
    alpha: f64,
    threshold: f64,
) -> Vec<i8> {
    let mut spikes = vec![0i8; errors.len()];
    if n_channels == 0 {
        return spikes;
    }

    let mut rates = vec![0.0f64; n_channels];
    let one_minus_alpha = 1.0 - alpha;

    for (err_row, spike_row) in errors
        .chunks_exact(n_channels)
        .zip(spikes.chunks_exact_mut(n_channels))
    {
        for ((&err, spike_slot), rate) in err_row
            .iter()
            .zip(spike_row.iter_mut())
            .zip(rates.iter_mut())
        {
            let predicted = i8::from(*rate > threshold);
            let actual = err ^ predicted;
            *spike_slot = actual;
            // Same update as the encoder so both sides predict identically.
            *rate = one_minus_alpha * *rate + alpha * f64::from(actual);
        }
    }
    spikes
}
130
131pub fn predict_and_xor_lfsr(
134 spikes: &[i8],
135 n_channels: usize,
136 alpha_q8: i32,
137 seed: u16,
138) -> (Vec<i8>, usize) {
139 let t_steps = spikes.len() / n_channels;
140 let mut rates_q8 = vec![0i32; n_channels];
141 let mut errors = vec![0i8; spikes.len()];
142 let mut correct: usize = 0;
143
144 let mut lfsrs: Vec<Lfsr16> = (0..n_channels)
146 .map(|ch| {
147 let s = ((seed as u32).wrapping_add((ch as u32).wrapping_mul(7919))) & 0xFFFF;
148 Lfsr16::new(if s == 0 { 1 } else { s as u16 })
149 })
150 .collect();
151
152 for t in 0..t_steps {
153 let row_start = t * n_channels;
154 for ch in 0..n_channels {
155 let actual = spikes[row_start + ch];
156 let predicted = if (lfsrs[ch].reg as i32) < rates_q8[ch] {
157 1i8
158 } else {
159 0i8
160 };
161 lfsrs[ch].step();
162
163 let err = actual ^ predicted;
164 errors[row_start + ch] = err;
165 if err == 0 {
166 correct += 1;
167 }
168
169 let target: i32 = if actual != 0 { 255 } else { 0 };
170 rates_q8[ch] += (alpha_q8 * (target - rates_q8[ch])) >> 8;
171 rates_q8[ch] = rates_q8[ch].clamp(0, 255);
172 }
173 }
174 (errors, correct)
175}
176
177pub fn xor_and_recover_lfsr(errors: &[i8], n_channels: usize, alpha_q8: i32, seed: u16) -> Vec<i8> {
179 let t_steps = errors.len() / n_channels;
180 let mut rates_q8 = vec![0i32; n_channels];
181 let mut spikes = vec![0i8; errors.len()];
182
183 let mut lfsrs: Vec<Lfsr16> = (0..n_channels)
184 .map(|ch| {
185 let s = ((seed as u32).wrapping_add((ch as u32).wrapping_mul(7919))) & 0xFFFF;
186 Lfsr16::new(if s == 0 { 1 } else { s as u16 })
187 })
188 .collect();
189
190 for t in 0..t_steps {
191 let row_start = t * n_channels;
192 for ch in 0..n_channels {
193 let predicted = if (lfsrs[ch].reg as i32) < rates_q8[ch] {
194 1i8
195 } else {
196 0i8
197 };
198 lfsrs[ch].step();
199
200 let actual = errors[row_start + ch] ^ predicted;
201 spikes[row_start + ch] = actual;
202
203 let target: i32 = if actual != 0 { 255 } else { 0 };
204 rates_q8[ch] += (alpha_q8 * (target - rates_q8[ch])) >> 8;
205 rates_q8[ch] = rates_q8[ch].clamp(0, 255);
206 }
207 }
208 spikes
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identical_streams_zero_error() {
        let a = vec![0xFF_FF_FF_FF_FF_FF_FF_FFu64; 16];
        let error = prediction_error_packed(&a, &a, 1024);
        assert!(error.abs() < 1e-10);
    }

    #[test]
    fn test_opposite_streams_max_error() {
        let a = vec![0xFF_FF_FF_FF_FF_FF_FF_FFu64; 16];
        let b = vec![0u64; 16];
        let error = prediction_error_packed(&a, &b, 1024);
        assert!((error - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_zero_length_is_zero_error() {
        let a = vec![0xFF_FF_FF_FF_FF_FF_FF_FFu64; 4];
        let b = vec![0u64; 4];
        assert_eq!(prediction_error_packed(&a, &b, 0), 0.0);
    }

    #[test]
    fn test_batch_error_shape() {
        // 2 neurons x 3 inputs: every predicted stream is all-zeros and
        // every actual stream all-ones, so each pair disagrees on all 256 bits.
        let pred = vec![vec![0u64; 4]; 6];
        let actual = vec![vec![0xFF_FF_FF_FF_FF_FF_FF_FFu64; 4]; 3];
        let surprises = batch_prediction_error(&pred, &actual, 2, 3, 256);
        assert_eq!(surprises.len(), 2);
        for s in &surprises {
            assert!((s - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_weight_update() {
        let mut weights = vec![0.5, 0.5, 0.5, 0.5];
        let actual = vec![0.8, 0.2];
        update_prediction_weights(&mut weights, &actual, 2, 2, 0.5);
        // Each weight moves halfway toward its input's probability.
        let expected = [0.65, 0.35, 0.65, 0.35];
        for (got, want) in weights.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12);
        }
    }

    #[test]
    fn test_weight_update_clamps() {
        // lr = 2.0 overshoots the target; the result must be clamped to [0, 1].
        let mut weights = vec![0.5, 0.5];
        update_prediction_weights(&mut weights, &[1.0, 0.0], 1, 2, 2.0);
        assert_eq!(weights, vec![1.0, 0.0]);
    }

    #[test]
    fn test_ema_roundtrip() {
        let spikes = vec![0i8; 100];
        let (errors, correct) = predict_and_xor_ema(&spikes, 10, 0.005, 0.5);
        assert_eq!(errors.len(), 100);
        // An all-silent raster is always predicted correctly.
        assert_eq!(correct, 100);
        let recovered = xor_and_recover_ema(&errors, 10, 0.005, 0.5);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_ema_roundtrip_with_spikes() {
        let mut spikes = vec![0i8; 200];
        spikes[5] = 1;
        spikes[15] = 1;
        let (errors, _) = predict_and_xor_ema(&spikes, 10, 0.01, 0.5);
        let recovered = xor_and_recover_ema(&errors, 10, 0.01, 0.5);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_ema_learns_constant_stream() {
        // Single always-on channel: the first sample is mispredicted (rate
        // starts at 0), after which the rate exceeds the threshold.
        let spikes = vec![1i8; 4];
        let (errors, correct) = predict_and_xor_ema(&spikes, 1, 0.5, 0.3);
        assert_eq!(errors, vec![1, 0, 0, 0]);
        assert_eq!(correct, 3);
        assert_eq!(xor_and_recover_ema(&errors, 1, 0.5, 0.3), spikes);
    }

    #[test]
    fn test_lfsr_roundtrip() {
        let spikes = vec![0i8; 100];
        let (errors, correct) = predict_and_xor_lfsr(&spikes, 10, 1, 0xACE1);
        assert_eq!(correct, 100);
        let recovered = xor_and_recover_lfsr(&errors, 10, 1, 0xACE1);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_lfsr_roundtrip_with_spikes() {
        let mut spikes = vec![0i8; 200];
        spikes[5] = 1;
        spikes[15] = 1;
        spikes[100] = 1;
        let (errors, _) = predict_and_xor_lfsr(&spikes, 10, 2, 0x1234);
        assert_eq!(errors.len(), 200);
        let recovered = xor_and_recover_lfsr(&errors, 10, 2, 0x1234);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_lfsr_deterministic() {
        let spikes = vec![0i8; 50];
        let (e1, _) = predict_and_xor_lfsr(&spikes, 10, 1, 0xBEEF);
        let (e2, _) = predict_and_xor_lfsr(&spikes, 10, 1, 0xBEEF);
        assert_eq!(e1, e2);
    }
}