1use rand::{RngExt, SeedableRng};
15use rand_chacha::ChaCha8Rng;
16use rand_xoshiro::Xoshiro256PlusPlus;
17use rayon::prelude::*;
18
19use crate::{bitstream, simd};
20
/// Minimum number of inputs before input encoding fans out to Rayon.
const RAYON_ENCODE_THRESHOLD: usize = 128;
/// Minimum number of neurons before the per-neuron dot products fan out to Rayon.
const RAYON_NEURON_THRESHOLD: usize = 8;
25
/// A dense (fully connected) layer evaluated with stochastic bitstream
/// arithmetic: probabilities are encoded as packed Bernoulli bitstreams and
/// the weight·input product is computed as AND + popcount.
#[derive(Clone, Debug)]
pub struct DenseLayer {
    /// Number of inputs feeding the layer.
    pub n_inputs: usize,
    /// Number of neurons (outputs) in the layer.
    pub n_neurons: usize,
    /// Bitstream length in bits.
    pub length: usize,
    /// Precomputed `1.0 / length`, used to turn popcounts into probabilities.
    pub inv_length: f64,
    /// Number of u64 words holding one `length`-bit stream (`length.div_ceil(64)`).
    pub words_per_input: usize,
    /// Real-valued weights, row-major: `weights[neuron][input]`. Values are
    /// expected in [0, 1]; anything outside is clamped when packed.
    pub weights: Vec<Vec<f64>>,
    // Flattened packed weight bitstreams; word offset for (neuron, input) is
    // (neuron * n_inputs + input) * words_per_input.
    packed_weights_flat: Vec<u64>,
    // Base seed for regenerating packed weight streams; each neuron row uses
    // weight_seed + neuron_idx.
    weight_seed: u64,
}
45
46impl DenseLayer {
47 pub fn new(n_inputs: usize, n_neurons: usize, length: usize, seed: u64) -> Self {
49 assert!(length > 0, "bitstream length must be > 0");
50 assert!(n_inputs > 0, "n_inputs must be > 0");
51 assert!(n_neurons > 0, "n_neurons must be > 0");
52 let mut rng = ChaCha8Rng::seed_from_u64(seed);
53 let mut weights = vec![vec![0.0; n_inputs]; n_neurons];
54
55 for row in &mut weights {
56 for p in row {
57 *p = rng.random::<f64>();
58 }
59 }
60
61 let mut layer = Self {
62 n_inputs,
63 n_neurons,
64 length,
65 inv_length: 1.0 / length as f64,
66 words_per_input: length.div_ceil(64),
67 weights,
68 packed_weights_flat: vec![],
69 weight_seed: seed.wrapping_add(1),
70 };
71 layer.refresh_packed_weights();
72 layer
73 }
74
    /// Read-only view of the flattened packed weight bitstreams; see
    /// `weight_slice` for the per-(neuron, input) indexing scheme.
    #[inline]
    pub fn packed_weights_flat(&self) -> &[u64] {
        &self.packed_weights_flat
    }
82
83 #[inline]
85 fn weight_slice(&self, neuron: usize, input: usize) -> &[u64] {
86 let start = (neuron * self.n_inputs + input) * self.words_per_input;
87 &self.packed_weights_flat[start..start + self.words_per_input]
88 }
89
    /// Returns a deep copy of the real-valued weight matrix
    /// (`n_neurons` rows of `n_inputs` values each).
    pub fn get_weights(&self) -> Vec<Vec<f64>> {
        self.weights.clone()
    }
94
95 pub fn set_weights(&mut self, weights: Vec<Vec<f64>>) -> Result<(), String> {
97 if weights.len() != self.n_neurons {
98 return Err(format!(
99 "Expected {} rows, got {}.",
100 self.n_neurons,
101 weights.len()
102 ));
103 }
104 for (row_idx, row) in weights.iter().enumerate() {
105 if row.len() != self.n_inputs {
106 return Err(format!(
107 "Row {} has length {}, expected {}.",
108 row_idx,
109 row.len(),
110 self.n_inputs
111 ));
112 }
113 }
114 self.weights = weights;
115 self.refresh_packed_weights();
116 Ok(())
117 }
118
    /// Regenerates `packed_weights_flat` from the real-valued `weights`.
    ///
    /// Each neuron row is packed with its own ChaCha8 stream seeded with
    /// `weight_seed + neuron_idx`, so the result is deterministic and
    /// independent of how Rayon schedules the rows.
    pub fn refresh_packed_weights(&mut self) {
        let n_inputs = self.n_inputs;
        let words = self.words_per_input;
        let length = self.length;
        let weight_seed = self.weight_seed;
        let weights = &self.weights;

        let mut packed_weights_flat = vec![0_u64; self.n_neurons * n_inputs * words];

        packed_weights_flat
            .par_chunks_mut(n_inputs * words)
            .enumerate()
            .for_each(|(neuron_idx, neuron_chunk)| {
                // One RNG per neuron row keeps bit patterns reproducible
                // regardless of thread interleaving.
                let mut rng =
                    ChaCha8Rng::seed_from_u64(weight_seed.wrapping_add(neuron_idx as u64));
                for (input_idx, input_chunk) in neuron_chunk.chunks_mut(words).enumerate() {
                    let weight_prob = weights[neuron_idx][input_idx];
                    if weight_prob <= 0.0 {
                        // Clamp: probability <= 0 packs as the all-zero stream.
                        input_chunk.fill(0);
                    } else if weight_prob >= 1.0 {
                        // Clamp: probability >= 1 packs as all ones, with the
                        // unused high bits of the last word masked off so
                        // popcounts never exceed `length`.
                        input_chunk.fill(u64::MAX);
                        if !length.is_multiple_of(64) {
                            input_chunk[words - 1] = (1_u64 << (length % 64)) - 1;
                        }
                    } else {
                        // NOTE: only this branch advances `rng`; clamped
                        // weights leave the stream position untouched.
                        let packed =
                            bitstream::bernoulli_packed_simd(weight_prob, length, &mut rng);
                        input_chunk.copy_from_slice(&packed);
                    }
                }
            });

        self.packed_weights_flat = packed_weights_flat;
    }
154
155 pub fn forward(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
159 if input_values.len() != self.n_inputs {
160 return Err(format!(
161 "Expected input of length {}, got {}.",
162 self.n_inputs,
163 input_values.len()
164 ));
165 }
166
167 let words = self.words_per_input;
168 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
169 let mut packed_inputs_flat = vec![0_u64; self.n_inputs * words];
170
171 for (idx, p) in input_values.iter().copied().enumerate() {
172 let packed = bitstream::bernoulli_packed(p, self.length, &mut rng);
173 packed_inputs_flat[idx * words..(idx + 1) * words].copy_from_slice(&packed);
174 }
175
176 let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
177 let n_inputs = self.n_inputs;
178 self.packed_weights_flat
179 .par_chunks_exact(n_inputs * words)
180 .map(|neuron_weights| {
181 let total =
182 simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
183 total as f64 * self.inv_length
184 })
185 .collect()
186 } else {
187 let n_inputs = self.n_inputs;
188 self.packed_weights_flat
189 .chunks_exact(n_inputs * words)
190 .map(|neuron_weights| {
191 let total =
192 simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
193 total as f64 * self.inv_length
194 })
195 .collect()
196 };
197
198 Ok(out)
199 }
200
201 pub fn forward_fast(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
206 if input_values.len() != self.n_inputs {
207 return Err(format!(
208 "Expected input of length {}, got {}.",
209 self.n_inputs,
210 input_values.len()
211 ));
212 }
213
214 let words = self.words_per_input;
215 let mut packed_inputs_flat = vec![0_u64; self.n_inputs * words];
216
217 if self.n_inputs >= RAYON_ENCODE_THRESHOLD {
218 packed_inputs_flat
219 .par_chunks_mut(words)
220 .enumerate()
221 .for_each(|(idx, chunk)| {
222 let p = input_values[idx];
223 let input_seed = seed.wrapping_add(idx as u64);
224 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
225 let packed = bitstream::bernoulli_packed_simd(p, self.length, &mut rng);
226 chunk.copy_from_slice(&packed);
227 });
228 } else {
229 packed_inputs_flat
230 .chunks_mut(words)
231 .enumerate()
232 .for_each(|(idx, chunk)| {
233 let p = input_values[idx];
234 let input_seed = seed.wrapping_add(idx as u64);
235 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
236 let packed = bitstream::bernoulli_packed_simd(p, self.length, &mut rng);
237 chunk.copy_from_slice(&packed);
238 });
239 }
240
241 let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
242 let n_inputs = self.n_inputs;
243 self.packed_weights_flat
244 .par_chunks_exact(n_inputs * words)
245 .map(|neuron_weights| {
246 let total =
247 simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
248 total as f64 * self.inv_length
249 })
250 .collect()
251 } else {
252 let n_inputs = self.n_inputs;
253 self.packed_weights_flat
254 .chunks_exact(n_inputs * words)
255 .map(|neuron_weights| {
256 let total =
257 simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
258 total as f64 * self.inv_length
259 })
260 .collect()
261 };
262
263 Ok(out)
264 }
265
266 pub fn forward_fused(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
271 if input_values.len() != self.n_inputs {
272 return Err(format!(
273 "Expected input of length {}, got {}.",
274 self.n_inputs,
275 input_values.len()
276 ));
277 }
278
279 let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
280 (0..self.n_neurons)
281 .into_par_iter()
282 .map(|neuron_idx| {
283 let total: u64 = input_values
284 .iter()
285 .enumerate()
286 .map(|(input_idx, &p)| {
287 let input_seed = seed.wrapping_add(input_idx as u64);
288 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
289 bitstream::encode_and_popcount(
290 self.weight_slice(neuron_idx, input_idx),
291 p,
292 self.length,
293 &mut rng,
294 )
295 })
296 .sum();
297 total as f64 * self.inv_length
298 })
299 .collect()
300 } else {
301 (0..self.n_neurons)
302 .map(|neuron_idx| {
303 let total: u64 = input_values
304 .iter()
305 .enumerate()
306 .map(|(input_idx, &p)| {
307 let input_seed = seed.wrapping_add(input_idx as u64);
308 let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
309 bitstream::encode_and_popcount(
310 self.weight_slice(neuron_idx, input_idx),
311 p,
312 self.length,
313 &mut rng,
314 )
315 })
316 .sum();
317 total as f64 * self.inv_length
318 })
319 .collect()
320 };
321
322 Ok(out)
323 }
324
325 pub fn forward_batch_into(
330 &self,
331 inputs_flat: &[f64],
332 n_samples: usize,
333 seed: u64,
334 output: &mut [f64],
335 ) -> Result<(), String> {
336 let expected_inputs = n_samples.checked_mul(self.n_inputs).ok_or_else(|| {
337 "Input size overflow when validating n_samples * n_inputs.".to_string()
338 })?;
339 if inputs_flat.len() != expected_inputs {
340 return Err(format!(
341 "Expected {} values ({}×{}), got {}.",
342 expected_inputs,
343 n_samples,
344 self.n_inputs,
345 inputs_flat.len()
346 ));
347 }
348
349 let expected_outputs = n_samples.checked_mul(self.n_neurons).ok_or_else(|| {
350 "Output size overflow when validating n_samples * n_neurons.".to_string()
351 })?;
352 if output.len() != expected_outputs {
353 return Err(format!(
354 "Expected output length {} ({}×{}), got {}.",
355 expected_outputs,
356 n_samples,
357 self.n_neurons,
358 output.len()
359 ));
360 }
361
362 output
363 .par_chunks_mut(self.n_neurons)
364 .enumerate()
365 .for_each(|(sample_idx, out_row)| {
366 let start = sample_idx * self.n_inputs;
367 let end = start + self.n_inputs;
368 let input_row = &inputs_flat[start..end];
369 let sample_seed = seed.wrapping_add((sample_idx as u64).wrapping_mul(1_000_000));
370
371 if let Ok(res) = self.forward_fast(input_row, sample_seed) {
372 out_row.copy_from_slice(&res);
373 }
374 });
375
376 Ok(())
377 }
378
379 pub fn forward_batch(
384 &self,
385 inputs_flat: &[f64],
386 n_samples: usize,
387 seed: u64,
388 ) -> Result<Vec<f64>, String> {
389 let output_len = n_samples.checked_mul(self.n_neurons).ok_or_else(|| {
390 "Output size overflow when allocating n_samples * n_neurons.".to_string()
391 })?;
392 let mut output = vec![0.0_f64; output_len];
393 self.forward_batch_into(inputs_flat, n_samples, seed, &mut output)?;
394 Ok(output)
395 }
396
397 pub fn forward_prepacked(&self, packed_inputs: &[Vec<u64>]) -> Result<Vec<f64>, String> {
403 if packed_inputs.len() != self.n_inputs {
404 return Err(format!(
405 "Expected {} packed inputs, got {}.",
406 self.n_inputs,
407 packed_inputs.len()
408 ));
409 }
410 let expected_words = self.length.div_ceil(64);
411 for (idx, pi) in packed_inputs.iter().enumerate() {
412 if pi.len() != expected_words {
413 return Err(format!(
414 "Packed input {} has {} words, expected {}.",
415 idx,
416 pi.len(),
417 expected_words
418 ));
419 }
420 }
421
422 let out = (0..self.n_neurons)
423 .into_par_iter()
424 .map(|neuron_idx| {
425 let total: u64 = packed_inputs
426 .iter()
427 .enumerate()
428 .map(|(input_idx, input_words)| {
429 simd::fused_and_popcount_dispatch(
430 self.weight_slice(neuron_idx, input_idx),
431 input_words,
432 )
433 })
434 .sum();
435 total as f64 * self.inv_length
436 })
437 .collect();
438
439 Ok(out)
440 }
441
442 pub fn forward_prepacked_2d(
447 &self,
448 packed_flat: &[u64],
449 n_inputs: usize,
450 words: usize,
451 ) -> Result<Vec<f64>, String> {
452 if n_inputs != self.n_inputs {
453 return Err(format!(
454 "Expected {} packed inputs, got {}.",
455 self.n_inputs, n_inputs
456 ));
457 }
458 let expected_words = self.length.div_ceil(64);
459 if words != expected_words {
460 return Err(format!(
461 "Expected {} words per input, got {}.",
462 expected_words, words
463 ));
464 }
465 if packed_flat.len() != n_inputs * words {
466 return Err(format!(
467 "Flat buffer length {} != n_inputs({}) * words({}).",
468 packed_flat.len(),
469 n_inputs,
470 words
471 ));
472 }
473
474 let out = (0..self.n_neurons)
475 .into_par_iter()
476 .map(|neuron_idx| {
477 let total: u64 = (0..self.n_inputs)
478 .map(|input_idx| {
479 let row_start = input_idx * words;
480 let input_words = &packed_flat[row_start..row_start + words];
481 simd::fused_and_popcount_dispatch(
482 self.weight_slice(neuron_idx, input_idx),
483 input_words,
484 )
485 })
486 .sum();
487 total as f64 * self.inv_length
488 })
489 .collect();
490
491 Ok(out)
492 }
493
    /// Thin alias over `forward_fast` — presumably the entry point used by
    /// NumPy-facing bindings (name suggests so; TODO confirm against caller).
    pub fn forward_numpy_inner(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
        self.forward_fast(input_values, seed)
    }
500}
501
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    use super::DenseLayer;
    use crate::bitstream;

    /// Packed weights must be reproducible from the seeding scheme:
    /// `weight_seed = ctor seed + 1`, offset by the neuron index
    /// (hence `43 + neuron` for a layer built with seed 42).
    #[test]
    fn flat_weight_roundtrip() {
        let layer = DenseLayer::new(3, 2, 130, 42);
        let words = 130_usize.div_ceil(64);
        assert_eq!(layer.words_per_input, words);
        assert_eq!(layer.packed_weights_flat.len(), 3 * 2 * words);

        for neuron in 0..2 {
            let mut rng = ChaCha8Rng::seed_from_u64(43 + neuron as u64);
            for input in 0..3 {
                let expected =
                    bitstream::bernoulli_packed_simd(layer.weights[neuron][input], 130, &mut rng);
                assert_eq!(layer.weight_slice(neuron, input), expected.as_slice());
            }
        }
    }

    /// The fused path re-encodes inputs on the fly with the same per-input
    /// seeds as `forward_fast`, so the two must agree exactly — not just
    /// approximately.
    #[test]
    fn forward_fused_matches_forward_fast() {
        let layer = DenseLayer::new(16, 8, 1024, 42);
        let inputs: Vec<f64> = (0..16).map(|i| (i as f64) / 16.0).collect();
        let seed = 999_u64;

        let fast = layer
            .forward_fast(&inputs, seed)
            .expect("forward_fast should succeed");
        let fused = layer
            .forward_fused(&inputs, seed)
            .expect("forward_fused should succeed");
        assert_eq!(
            fast, fused,
            "forward_fused must be bit-identical to forward_fast"
        );
    }

    /// Batched evaluation must equal per-sample evaluation with the batch's
    /// seed spacing (`seed + sample_idx * 1_000_000`).
    #[test]
    fn forward_batch_matches_sequential_fused() {
        let layer = DenseLayer::new(4, 3, 256, 123);
        let n_samples = 5;
        let inputs_flat: Vec<f64> = (0..(n_samples * 4))
            .map(|i| ((i * 17 + 11) % 100) as f64 / 100.0)
            .collect();
        let seed = 77_u64;

        let batch = layer
            .forward_batch(&inputs_flat, n_samples, seed)
            .expect("forward_batch should succeed");

        for sample_idx in 0..n_samples {
            let row = &inputs_flat[sample_idx * 4..(sample_idx + 1) * 4];
            let sample_seed = seed.wrapping_add((sample_idx as u64).wrapping_mul(1_000_000));
            let expected = layer
                .forward_fused(row, sample_seed)
                .expect("forward_fused should succeed");
            let got = &batch[sample_idx * 3..(sample_idx + 1) * 3];
            assert_eq!(got, expected.as_slice(), "sample_idx={sample_idx}");
        }
    }
}