1use crate::bitstream;
14
15pub fn prediction_error_packed(predicted: &[u64], actual: &[u64], length: usize) -> f64 {
18 if length == 0 {
19 return 0.0;
20 }
21 let n = predicted.len().min(actual.len());
22 let mut xor_result = vec![0u64; n];
23 for i in 0..n {
24 xor_result[i] = predicted[i] ^ actual[i];
25 }
26 let hamming = bitstream::popcount_words_portable(&xor_result);
27 hamming as f64 / length as f64
28}
29
30pub fn batch_prediction_error(
33 predicted: &[Vec<u64>], actual: &[Vec<u64>], n_neurons: usize,
36 n_inputs: usize,
37 length: usize,
38) -> Vec<f64> {
39 let mut surprises = vec![0.0f64; n_neurons];
40 for j in 0..n_neurons {
41 let mut total_error = 0.0;
42 for i in 0..n_inputs {
43 let pred_idx = j * n_inputs + i;
44 if pred_idx < predicted.len() && i < actual.len() {
45 total_error += prediction_error_packed(&predicted[pred_idx], &actual[i], length);
46 }
47 }
48 surprises[j] = total_error / n_inputs.max(1) as f64;
49 }
50 surprises
51}
52
/// Nudges each prediction weight toward the observed input probability.
///
/// `weights` is read as a row-major `n_neurons x n_inputs` matrix: row `j`
/// occupies `weights[j * n_inputs .. (j + 1) * n_inputs]`. Each entry moves a
/// fraction `lr` of the way toward `actual_probs[i]` (`w += lr * (p - w)`) and
/// is then clamped to `[0.0, 1.0]`.
///
/// Entries whose index falls outside `weights`, or whose input index falls
/// outside `actual_probs`, are left untouched.
pub fn update_prediction_weights(
    weights: &mut [f64],
    actual_probs: &[f64],
    n_neurons: usize,
    n_inputs: usize,
    lr: f64,
) {
    // With no inputs there is nothing to update, and `chunks_mut(0)` would panic.
    if n_inputs == 0 {
        return;
    }
    for row in weights.chunks_mut(n_inputs).take(n_neurons) {
        // `zip` stops at the shorter of a (possibly truncated) row and `actual_probs`.
        for (w, &p) in row.iter_mut().zip(actual_probs) {
            *w = (*w + lr * (p - *w)).clamp(0.0, 1.0);
        }
    }
}
71
72use crate::encoder::Lfsr16;
75
/// Encodes a time-major spike raster as residuals against a per-channel
/// exponential-moving-average (EMA) predictor.
///
/// Sample `t * n_channels + ch` belongs to time step `t`, channel `ch`. For
/// each sample the predictor guesses `1` when the channel's running rate is
/// strictly above `threshold`, and `0` otherwise. The stored residual is
/// `actual ^ predicted`. The rate starts at `0.0` and is then updated as
/// `rate = (1 - alpha) * rate + alpha * actual`.
///
/// If `spikes.len()` is not a multiple of `n_channels`, the trailing samples
/// are encoded as a partial final row. This lets [`xor_and_recover_ema`]
/// reconstruct every sample.
///
/// Returns the residuals (same length as `spikes`) and the number of samples
/// that were predicted exactly (residual `0`).
///
/// # Panics
///
/// Panics if `n_channels` is zero.
pub fn predict_and_xor_ema(
    spikes: &[i8],
    n_channels: usize,
    alpha: f64,
    threshold: f64,
) -> (Vec<i8>, usize) {
    assert!(n_channels > 0, "n_channels must be non-zero");
    let mut rates = vec![0.0f64; n_channels];
    let mut errors = Vec::with_capacity(spikes.len());
    let mut correct: usize = 0;
    let one_minus_alpha = 1.0 - alpha;

    // `chunks` also yields the final partial row, so samples past the last
    // full row are encoded rather than left as an unexamined 0.
    for row in spikes.chunks(n_channels) {
        for (&actual, rate) in row.iter().zip(rates.iter_mut()) {
            let predicted = if *rate > threshold { 1i8 } else { 0i8 };
            let err = actual ^ predicted;
            errors.push(err);
            if err == 0 {
                correct += 1;
            }
            *rate = one_minus_alpha * *rate + alpha * f64::from(actual);
        }
    }
    (errors, correct)
}

/// Inverts [`predict_and_xor_ema`], reconstructing the spike raster from its
/// residuals.
///
/// Runs the same EMA predictor (same `alpha`, `threshold`, and zero initial
/// rates). Each sample is recovered as `residual ^ predicted`, and the
/// recovered value is fed back into the rate so that encoder and decoder stay
/// in lockstep. A trailing partial row is decoded the same way the encoder
/// produced it.
///
/// `alpha`, `threshold` and `n_channels` must match the values used for
/// encoding.
///
/// # Panics
///
/// Panics if `n_channels` is zero.
pub fn xor_and_recover_ema(
    errors: &[i8],
    n_channels: usize,
    alpha: f64,
    threshold: f64,
) -> Vec<i8> {
    assert!(n_channels > 0, "n_channels must be non-zero");
    let mut rates = vec![0.0f64; n_channels];
    let mut spikes = Vec::with_capacity(errors.len());
    let one_minus_alpha = 1.0 - alpha;

    for row in errors.chunks(n_channels) {
        for (&err, rate) in row.iter().zip(rates.iter_mut()) {
            let predicted = if *rate > threshold { 1i8 } else { 0i8 };
            let actual = err ^ predicted;
            spikes.push(actual);
            *rate = one_minus_alpha * *rate + alpha * f64::from(actual);
        }
    }
    spikes
}
129
130pub fn predict_and_xor_lfsr(
133 spikes: &[i8],
134 n_channels: usize,
135 alpha_q8: i32,
136 seed: u16,
137) -> (Vec<i8>, usize) {
138 let t_steps = spikes.len() / n_channels;
139 let mut rates_q8 = vec![0i32; n_channels];
140 let mut errors = vec![0i8; spikes.len()];
141 let mut correct: usize = 0;
142
143 let mut lfsrs: Vec<Lfsr16> = (0..n_channels)
145 .map(|ch| {
146 let s = ((seed as u32).wrapping_add((ch as u32).wrapping_mul(7919))) & 0xFFFF;
147 Lfsr16::new(if s == 0 { 1 } else { s as u16 })
148 })
149 .collect();
150
151 for t in 0..t_steps {
152 let row_start = t * n_channels;
153 for ch in 0..n_channels {
154 let actual = spikes[row_start + ch];
155 let predicted = if (lfsrs[ch].reg as i32) < rates_q8[ch] {
156 1i8
157 } else {
158 0i8
159 };
160 lfsrs[ch].step();
161
162 let err = actual ^ predicted;
163 errors[row_start + ch] = err;
164 if err == 0 {
165 correct += 1;
166 }
167
168 let target: i32 = if actual != 0 { 255 } else { 0 };
169 rates_q8[ch] += (alpha_q8 * (target - rates_q8[ch])) >> 8;
170 rates_q8[ch] = rates_q8[ch].clamp(0, 255);
171 }
172 }
173 (errors, correct)
174}
175
176pub fn xor_and_recover_lfsr(errors: &[i8], n_channels: usize, alpha_q8: i32, seed: u16) -> Vec<i8> {
178 let t_steps = errors.len() / n_channels;
179 let mut rates_q8 = vec![0i32; n_channels];
180 let mut spikes = vec![0i8; errors.len()];
181
182 let mut lfsrs: Vec<Lfsr16> = (0..n_channels)
183 .map(|ch| {
184 let s = ((seed as u32).wrapping_add((ch as u32).wrapping_mul(7919))) & 0xFFFF;
185 Lfsr16::new(if s == 0 { 1 } else { s as u16 })
186 })
187 .collect();
188
189 for t in 0..t_steps {
190 let row_start = t * n_channels;
191 for ch in 0..n_channels {
192 let predicted = if (lfsrs[ch].reg as i32) < rates_q8[ch] {
193 1i8
194 } else {
195 0i8
196 };
197 lfsrs[ch].step();
198
199 let actual = errors[row_start + ch] ^ predicted;
200 spikes[row_start + ch] = actual;
201
202 let target: i32 = if actual != 0 { 255 } else { 0 };
203 rates_q8[ch] += (alpha_q8 * (target - rates_q8[ch])) >> 8;
204 rates_q8[ch] = rates_q8[ch].clamp(0, 255);
205 }
206 }
207 spikes
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// A 64-bit word with every bit set.
    const ONES: u64 = u64::MAX;

    #[test]
    fn test_identical_streams_zero_error() {
        let stream = vec![ONES; 16];
        let error = prediction_error_packed(&stream, &stream, 1024);
        assert!(error.abs() < 1e-10);
    }

    #[test]
    fn test_opposite_streams_max_error() {
        let ones = vec![ONES; 16];
        let zeros = vec![0u64; 16];
        let error = prediction_error_packed(&ones, &zeros, 1024);
        assert!((error - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_error_shape() {
        // 2 neurons x 3 inputs worth of predictions, compared against 3 inputs.
        let predicted = vec![vec![0u64; 4]; 6];
        let actual = vec![vec![ONES; 4]; 3];
        let surprises = batch_prediction_error(&predicted, &actual, 2, 3, 256);
        assert_eq!(surprises.len(), 2);
        assert!(surprises[0] > 0.0);
    }

    #[test]
    fn test_weight_update() {
        let mut weights = vec![0.5; 4];
        let observed = [0.8, 0.2];
        update_prediction_weights(&mut weights, &observed, 2, 2, 0.5);
        assert!(weights[0] > 0.5);
        assert!(weights[1] < 0.5);
    }

    #[test]
    fn test_ema_roundtrip() {
        let spikes = vec![0i8; 100];
        let (errors, correct) = predict_and_xor_ema(&spikes, 10, 0.005, 0.5);
        assert_eq!(errors.len(), 100);
        assert_eq!(correct, 100);
        let recovered = xor_and_recover_ema(&errors, 10, 0.005, 0.5);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_ema_roundtrip_with_spikes() {
        let mut spikes = vec![0i8; 200];
        spikes[5] = 1;
        spikes[15] = 1;
        let (errors, _) = predict_and_xor_ema(&spikes, 10, 0.01, 0.5);
        let recovered = xor_and_recover_ema(&errors, 10, 0.01, 0.5);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_lfsr_roundtrip() {
        let spikes = vec![0i8; 100];
        let (errors, correct) = predict_and_xor_lfsr(&spikes, 10, 1, 0xACE1);
        assert_eq!(correct, 100);
        let recovered = xor_and_recover_lfsr(&errors, 10, 1, 0xACE1);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_lfsr_roundtrip_with_spikes() {
        let mut spikes = vec![0i8; 200];
        spikes[5] = 1;
        spikes[15] = 1;
        spikes[100] = 1;
        let (errors, _) = predict_and_xor_lfsr(&spikes, 10, 2, 0x1234);
        let recovered = xor_and_recover_lfsr(&errors, 10, 2, 0x1234);
        assert_eq!(recovered, spikes);
    }

    #[test]
    fn test_lfsr_deterministic() {
        let spikes = vec![0i8; 50];
        let (first, _) = predict_and_xor_lfsr(&spikes, 10, 1, 0xBEEF);
        let (second, _) = predict_and_xor_lfsr(&spikes, 10, 1, 0xBEEF);
        assert_eq!(first, second);
    }
}