1use rand::{RngExt, SeedableRng};
14use rand_chacha::ChaCha8Rng;
15use rand_xoshiro::Xoshiro256PlusPlus;
16use rayon::prelude::*;
17
18use crate::{bitstream, simd};
19
/// Minimum `n_inputs` before input encoding is parallelised with rayon
/// (see `forward_fast`); below this, sequential encoding is cheaper than
/// rayon's scheduling overhead.
const RAYON_ENCODE_THRESHOLD: usize = 128;
/// Minimum `n_neurons` before the per-neuron output loop is parallelised.
const RAYON_NEURON_THRESHOLD: usize = 8;
24
/// A fully connected layer evaluated with stochastic computing: values are
/// encoded as packed Bernoulli bitstreams and neuron activation is the
/// popcount of AND-ed weight/input streams, normalised by stream length.
#[derive(Clone, Debug)]
pub struct DenseLayer {
    /// Number of input values per sample.
    pub n_inputs: usize,
    /// Number of neurons (outputs) in this layer.
    pub n_neurons: usize,
    /// Bitstream length in bits; longer streams lower estimator variance.
    pub length: usize,
    /// `length.div_ceil(64)` — number of `u64` words per packed stream.
    pub words_per_input: usize,
    /// Weight probabilities, `n_neurons` rows of `n_inputs` columns.
    pub weights: Vec<Vec<f64>>,
    /// Packed weight streams in flat `[neuron][input][word]` layout; kept in
    /// sync with `weights` by `refresh_packed_weights`.
    packed_weights_flat: Vec<u64>,
    /// Seed used to deterministically re-pack the weight streams
    /// (constructor seed + 1).
    weight_seed: u64,
}
43
44impl DenseLayer {
45 pub fn new(n_inputs: usize, n_neurons: usize, length: usize, seed: u64) -> Self {
47 assert!(length > 0, "bitstream length must be > 0");
48 assert!(n_inputs > 0, "n_inputs must be > 0");
49 assert!(n_neurons > 0, "n_neurons must be > 0");
50 let mut rng = ChaCha8Rng::seed_from_u64(seed);
51 let mut weights = vec![vec![0.0; n_inputs]; n_neurons];
52
53 for row in &mut weights {
54 for p in row {
55 *p = rng.random::<f64>();
56 }
57 }
58
59 let mut layer = Self {
60 n_inputs,
61 n_neurons,
62 length,
63 words_per_input: length.div_ceil(64),
64 weights,
65 packed_weights_flat: vec![],
66 weight_seed: seed.wrapping_add(1),
67 };
68 layer.refresh_packed_weights();
69 layer
70 }
71
    #[inline]
    /// Returns the packed weight stream for one `(neuron, input)` pair as a
    /// sub-slice of the flat buffer (neuron-major, then input).
    fn weight_slice(&self, neuron: usize, input: usize) -> &[u64] {
        // Flat layout: row index (neuron * n_inputs + input), each row
        // words_per_input words long.
        let start = (neuron * self.n_inputs + input) * self.words_per_input;
        &self.packed_weights_flat[start..start + self.words_per_input]
    }
78
79 pub fn get_weights(&self) -> Vec<Vec<f64>> {
81 self.weights.clone()
82 }
83
84 pub fn set_weights(&mut self, weights: Vec<Vec<f64>>) -> Result<(), String> {
86 if weights.len() != self.n_neurons {
87 return Err(format!(
88 "Expected {} rows, got {}.",
89 self.n_neurons,
90 weights.len()
91 ));
92 }
93 for (row_idx, row) in weights.iter().enumerate() {
94 if row.len() != self.n_inputs {
95 return Err(format!(
96 "Row {} has length {}, expected {}.",
97 row_idx,
98 row.len(),
99 self.n_inputs
100 ));
101 }
102 }
103 self.weights = weights;
104 self.refresh_packed_weights();
105 Ok(())
106 }
107
108 pub fn refresh_packed_weights(&mut self) {
110 let mut rng = ChaCha8Rng::seed_from_u64(self.weight_seed);
111 let mut packed_weights_flat =
112 vec![0_u64; self.n_neurons * self.n_inputs * self.words_per_input];
113
114 for (neuron_idx, neuron_weights) in self.weights.iter().enumerate().take(self.n_neurons) {
115 for (input_idx, weight_prob) in neuron_weights.iter().enumerate().take(self.n_inputs) {
116 let packed = bitstream::bernoulli_packed_simd(*weight_prob, self.length, &mut rng);
117 let start = (neuron_idx * self.n_inputs + input_idx) * self.words_per_input;
118 packed_weights_flat[start..start + self.words_per_input].copy_from_slice(&packed);
119 }
120 }
121
122 self.packed_weights_flat = packed_weights_flat;
123 }
124
125 pub fn forward(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
129 if input_values.len() != self.n_inputs {
130 return Err(format!(
131 "Expected input of length {}, got {}.",
132 self.n_inputs,
133 input_values.len()
134 ));
135 }
136
137 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
138 let mut packed_inputs = vec![Vec::<u64>::new(); self.n_inputs];
139 for (idx, p) in input_values.iter().copied().enumerate() {
140 packed_inputs[idx] = bitstream::bernoulli_packed(p, self.length, &mut rng);
141 }
142
143 let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
144 (0..self.n_neurons)
145 .into_par_iter()
146 .map(|neuron_idx| {
147 let total: u64 = packed_inputs
148 .iter()
149 .enumerate()
150 .map(|(input_idx, input_words)| {
151 simd::fused_and_popcount_dispatch(
152 self.weight_slice(neuron_idx, input_idx),
153 input_words,
154 )
155 })
156 .sum();
157 total as f64 / self.length as f64
158 })
159 .collect()
160 } else {
161 (0..self.n_neurons)
162 .map(|neuron_idx| {
163 let total: u64 = packed_inputs
164 .iter()
165 .enumerate()
166 .map(|(input_idx, input_words)| {
167 simd::fused_and_popcount_dispatch(
168 self.weight_slice(neuron_idx, input_idx),
169 input_words,
170 )
171 })
172 .sum();
173 total as f64 / self.length as f64
174 })
175 .collect()
176 };
177
178 Ok(out)
179 }
180
181 pub fn forward_fast(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
186 if input_values.len() != self.n_inputs {
187 return Err(format!(
188 "Expected input of length {}, got {}.",
189 self.n_inputs,
190 input_values.len()
191 ));
192 }
193
194 let packed_inputs: Vec<Vec<u64>> = if self.n_inputs >= RAYON_ENCODE_THRESHOLD {
195 input_values
196 .par_iter()
197 .enumerate()
198 .map(|(idx, &p)| {
199 let input_seed = seed.wrapping_add(idx as u64);
200 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
201 bitstream::bernoulli_packed_simd(p, self.length, &mut rng)
202 })
203 .collect()
204 } else {
205 input_values
206 .iter()
207 .enumerate()
208 .map(|(idx, &p)| {
209 let input_seed = seed.wrapping_add(idx as u64);
210 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
211 bitstream::bernoulli_packed_simd(p, self.length, &mut rng)
212 })
213 .collect()
214 };
215
216 let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
217 (0..self.n_neurons)
218 .into_par_iter()
219 .map(|neuron_idx| {
220 let total: u64 = packed_inputs
221 .iter()
222 .enumerate()
223 .map(|(input_idx, input_words)| {
224 simd::fused_and_popcount_dispatch(
225 self.weight_slice(neuron_idx, input_idx),
226 input_words,
227 )
228 })
229 .sum();
230 total as f64 / self.length as f64
231 })
232 .collect()
233 } else {
234 (0..self.n_neurons)
235 .map(|neuron_idx| {
236 let total: u64 = packed_inputs
237 .iter()
238 .enumerate()
239 .map(|(input_idx, input_words)| {
240 simd::fused_and_popcount_dispatch(
241 self.weight_slice(neuron_idx, input_idx),
242 input_words,
243 )
244 })
245 .sum();
246 total as f64 / self.length as f64
247 })
248 .collect()
249 };
250
251 Ok(out)
252 }
253
254 pub fn forward_fused(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
259 if input_values.len() != self.n_inputs {
260 return Err(format!(
261 "Expected input of length {}, got {}.",
262 self.n_inputs,
263 input_values.len()
264 ));
265 }
266
267 let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
268 (0..self.n_neurons)
269 .into_par_iter()
270 .map(|neuron_idx| {
271 let total: u64 = input_values
272 .iter()
273 .enumerate()
274 .map(|(input_idx, &p)| {
275 let input_seed = seed.wrapping_add(input_idx as u64);
276 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
277 bitstream::encode_and_popcount(
278 self.weight_slice(neuron_idx, input_idx),
279 p,
280 self.length,
281 &mut rng,
282 )
283 })
284 .sum();
285 total as f64 / self.length as f64
286 })
287 .collect()
288 } else {
289 (0..self.n_neurons)
290 .map(|neuron_idx| {
291 let total: u64 = input_values
292 .iter()
293 .enumerate()
294 .map(|(input_idx, &p)| {
295 let input_seed = seed.wrapping_add(input_idx as u64);
296 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
297 bitstream::encode_and_popcount(
298 self.weight_slice(neuron_idx, input_idx),
299 p,
300 self.length,
301 &mut rng,
302 )
303 })
304 .sum();
305 total as f64 / self.length as f64
306 })
307 .collect()
308 };
309
310 Ok(out)
311 }
312
313 pub fn forward_batch_into(
318 &self,
319 inputs_flat: &[f64],
320 n_samples: usize,
321 seed: u64,
322 output: &mut [f64],
323 ) -> Result<(), String> {
324 let expected_inputs = n_samples.checked_mul(self.n_inputs).ok_or_else(|| {
325 "Input size overflow when validating n_samples * n_inputs.".to_string()
326 })?;
327 if inputs_flat.len() != expected_inputs {
328 return Err(format!(
329 "Expected {} values ({}×{}), got {}.",
330 expected_inputs,
331 n_samples,
332 self.n_inputs,
333 inputs_flat.len()
334 ));
335 }
336
337 let expected_outputs = n_samples.checked_mul(self.n_neurons).ok_or_else(|| {
338 "Output size overflow when validating n_samples * n_neurons.".to_string()
339 })?;
340 if output.len() != expected_outputs {
341 return Err(format!(
342 "Expected output length {} ({}×{}), got {}.",
343 expected_outputs,
344 n_samples,
345 self.n_neurons,
346 output.len()
347 ));
348 }
349
350 output
351 .par_chunks_mut(self.n_neurons)
352 .enumerate()
353 .for_each(|(sample_idx, out_row)| {
354 let start = sample_idx * self.n_inputs;
355 let end = start + self.n_inputs;
356 let input_row = &inputs_flat[start..end];
357 let sample_seed = seed.wrapping_add((sample_idx as u64).wrapping_mul(1_000_000));
358
359 for (neuron_idx, out_val) in out_row.iter_mut().enumerate() {
360 let total: u64 = input_row
361 .iter()
362 .enumerate()
363 .map(|(input_idx, &p)| {
364 let input_seed = sample_seed.wrapping_add(input_idx as u64);
365 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
366 bitstream::encode_and_popcount(
367 self.weight_slice(neuron_idx, input_idx),
368 p,
369 self.length,
370 &mut rng,
371 )
372 })
373 .sum();
374 *out_val = total as f64 / self.length as f64;
375 }
376 });
377
378 Ok(())
379 }
380
381 pub fn forward_batch(
386 &self,
387 inputs_flat: &[f64],
388 n_samples: usize,
389 seed: u64,
390 ) -> Result<Vec<f64>, String> {
391 let output_len = n_samples.checked_mul(self.n_neurons).ok_or_else(|| {
392 "Output size overflow when allocating n_samples * n_neurons.".to_string()
393 })?;
394 let mut output = vec![0.0_f64; output_len];
395 self.forward_batch_into(inputs_flat, n_samples, seed, &mut output)?;
396 Ok(output)
397 }
398
399 pub fn forward_prepacked(&self, packed_inputs: &[Vec<u64>]) -> Result<Vec<f64>, String> {
405 if packed_inputs.len() != self.n_inputs {
406 return Err(format!(
407 "Expected {} packed inputs, got {}.",
408 self.n_inputs,
409 packed_inputs.len()
410 ));
411 }
412 let expected_words = self.length.div_ceil(64);
413 for (idx, pi) in packed_inputs.iter().enumerate() {
414 if pi.len() != expected_words {
415 return Err(format!(
416 "Packed input {} has {} words, expected {}.",
417 idx,
418 pi.len(),
419 expected_words
420 ));
421 }
422 }
423
424 let out = (0..self.n_neurons)
425 .into_par_iter()
426 .map(|neuron_idx| {
427 let total: u64 = packed_inputs
428 .iter()
429 .enumerate()
430 .map(|(input_idx, input_words)| {
431 simd::fused_and_popcount_dispatch(
432 self.weight_slice(neuron_idx, input_idx),
433 input_words,
434 )
435 })
436 .sum();
437 total as f64 / self.length as f64
438 })
439 .collect();
440
441 Ok(out)
442 }
443
444 pub fn forward_prepacked_2d(
449 &self,
450 packed_flat: &[u64],
451 n_inputs: usize,
452 words: usize,
453 ) -> Result<Vec<f64>, String> {
454 if n_inputs != self.n_inputs {
455 return Err(format!(
456 "Expected {} packed inputs, got {}.",
457 self.n_inputs, n_inputs
458 ));
459 }
460 let expected_words = self.length.div_ceil(64);
461 if words != expected_words {
462 return Err(format!(
463 "Expected {} words per input, got {}.",
464 expected_words, words
465 ));
466 }
467 if packed_flat.len() != n_inputs * words {
468 return Err(format!(
469 "Flat buffer length {} != n_inputs({}) * words({}).",
470 packed_flat.len(),
471 n_inputs,
472 words
473 ));
474 }
475
476 let out = (0..self.n_neurons)
477 .into_par_iter()
478 .map(|neuron_idx| {
479 let total: u64 = (0..self.n_inputs)
480 .map(|input_idx| {
481 let row_start = input_idx * words;
482 let input_words = &packed_flat[row_start..row_start + words];
483 simd::fused_and_popcount_dispatch(
484 self.weight_slice(neuron_idx, input_idx),
485 input_words,
486 )
487 })
488 .sum();
489 total as f64 / self.length as f64
490 })
491 .collect();
492
493 Ok(out)
494 }
495
    /// Thin alias for [`Self::forward_fused`] — presumably the entry point
    /// used by a NumPy/Python binding layer (name suggests it; verify against
    /// the bindings module).
    pub fn forward_numpy_inner(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
        self.forward_fused(input_values, seed)
    }
502}
503
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    use super::DenseLayer;
    use crate::bitstream;

    // The flat packed-weight buffer must hold n_neurons * n_inputs streams of
    // words_per_input words, in the exact order weight_slice indexes them.
    #[test]
    fn flat_weight_roundtrip() {
        let layer = DenseLayer::new(3, 2, 130, 42);
        let words = 130_usize.div_ceil(64);
        assert_eq!(layer.words_per_input, words);
        assert_eq!(layer.packed_weights_flat.len(), 3 * 2 * words);

        // Re-derive the streams with the layer's weight seed (ctor seed + 1)
        // and check each (neuron, input) slice matches.
        let mut rng = ChaCha8Rng::seed_from_u64(43);
        for neuron in 0..2 {
            for input in 0..3 {
                let expected =
                    bitstream::bernoulli_packed_simd(layer.weights[neuron][input], 130, &mut rng);
                assert_eq!(layer.weight_slice(neuron, input), expected.as_slice());
            }
        }
    }

    // The fused path re-encodes inputs per neuron but derives the same
    // per-input seeds as forward_fast, so outputs must match exactly.
    #[test]
    fn forward_fused_matches_forward_fast() {
        let layer = DenseLayer::new(16, 8, 1024, 42);
        let inputs: Vec<f64> = (0..16).map(|i| (i as f64) / 16.0).collect();
        let seed = 999_u64;

        let fast = layer
            .forward_fast(&inputs, seed)
            .expect("forward_fast should succeed");
        let fused = layer
            .forward_fused(&inputs, seed)
            .expect("forward_fused should succeed");
        assert_eq!(
            fast, fused,
            "forward_fused must be bit-identical to forward_fast"
        );
    }

    // Batched evaluation must equal running forward_fused on each row with
    // the batch's per-sample seed (seed + sample_idx * 1_000_000).
    #[test]
    fn forward_batch_matches_sequential_fused() {
        let layer = DenseLayer::new(4, 3, 256, 123);
        let n_samples = 5;
        let inputs_flat: Vec<f64> = (0..(n_samples * 4))
            .map(|i| ((i * 17 + 11) % 100) as f64 / 100.0)
            .collect();
        let seed = 77_u64;

        let batch = layer
            .forward_batch(&inputs_flat, n_samples, seed)
            .expect("forward_batch should succeed");

        for sample_idx in 0..n_samples {
            let row = &inputs_flat[sample_idx * 4..(sample_idx + 1) * 4];
            let sample_seed = seed.wrapping_add((sample_idx as u64).wrapping_mul(1_000_000));
            let expected = layer
                .forward_fused(row, sample_seed)
                .expect("forward_fused should succeed");
            let got = &batch[sample_idx * 3..(sample_idx + 1) * 3];
            assert_eq!(got, expected.as_slice(), "sample_idx={sample_idx}");
        }
    }
}