sc_neurocore_engine/
recurrent.rs1use rand::{RngExt, SeedableRng};
14use rand_chacha::ChaCha8Rng;
15
/// A fully connected recurrent layer whose activations are clamped to `[0, 1]`.
///
/// Both weight matrices are stored flattened in row-major order: row `i` holds the
/// weights feeding neuron `i`. All fields are public, so callers that replace them
/// are responsible for keeping the lengths consistent with `n_inputs` / `n_neurons`.
#[derive(Clone, Debug, PartialEq)]
pub struct RecurrentLayer {
    /// Number of external inputs expected on every step.
    pub n_inputs: usize,
    /// Number of neurons, which is also the length of `state`.
    pub n_neurons: usize,
    /// Input weights: `n_neurons * n_inputs` entries, indexed `[i * n_inputs + j]`.
    pub w_in: Vec<f64>,
    /// Recurrent weights: `n_neurons * n_neurons` entries, indexed `[i * n_neurons + j]`.
    pub w_rec: Vec<f64>,
    /// Current activation of each neuron.
    pub state: Vec<f64>,
}
28
29impl RecurrentLayer {
30 pub fn new(n_inputs: usize, n_neurons: usize, seed: u64) -> Self {
31 let mut rng = ChaCha8Rng::seed_from_u64(seed);
32 let w_in: Vec<f64> = (0..n_neurons * n_inputs)
33 .map(|_| rng.random::<f64>() * 0.5)
34 .collect();
35 let w_rec: Vec<f64> = (0..n_neurons * n_neurons)
36 .map(|_| rng.random::<f64>() * 0.2)
37 .collect();
38 Self {
39 n_inputs,
40 n_neurons,
41 w_in,
42 w_rec,
43 state: vec![0.0; n_neurons],
44 }
45 }
46
47 #[allow(clippy::needless_range_loop)]
49 pub fn step(&mut self, input: &[f64]) -> &[f64] {
50 assert_eq!(input.len(), self.n_inputs);
51 let mut new_state = vec![0.0; self.n_neurons];
52 for i in 0..self.n_neurons {
53 let mut val = 0.0;
54 for j in 0..self.n_inputs {
55 val += self.w_in[i * self.n_inputs + j] * input[j];
56 }
57 for j in 0..self.n_neurons {
58 val += self.w_rec[i * self.n_neurons + j] * self.state[j];
59 }
60 new_state[i] = val.clamp(0.0, 1.0);
61 }
62 self.state = new_state;
63 &self.state
64 }
65
66 pub fn reset(&mut self) {
67 self.state.fill(0.0);
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise comparison with a tolerance suited to a handful of f64 operations.
    fn approx_eq(actual: &[f64], expected: &[f64]) -> bool {
        actual.len() == expected.len()
            && actual
                .iter()
                .zip(expected)
                .all(|(a, e)| (a - e).abs() < 1e-12)
    }

    /// Builds a layer whose seeded weights are replaced with hand-picked ones, so that
    /// expected activations can be worked out by hand.
    fn layer_with_weights(
        n_inputs: usize,
        n_neurons: usize,
        w_in: Vec<f64>,
        w_rec: Vec<f64>,
    ) -> RecurrentLayer {
        let mut layer = RecurrentLayer::new(n_inputs, n_neurons, 0);
        layer.w_in = w_in;
        layer.w_rec = w_rec;
        layer
    }

    #[test]
    fn state_changes_after_step() {
        let mut layer = RecurrentLayer::new(3, 5, 42);
        let input = vec![0.5, 0.3, 0.8];
        let state = layer.step(&input).to_vec();
        assert!(state.iter().any(|&s| s > 0.0));
    }

    #[test]
    fn reset_clears_state() {
        let mut layer = RecurrentLayer::new(3, 5, 42);
        layer.step(&[0.5, 0.3, 0.8]);
        // Guard against a vacuous pass: the state must be non-zero before the reset.
        assert!(layer.state.iter().any(|&s| s != 0.0));
        layer.reset();
        assert_eq!(layer.state.len(), 5);
        assert!(layer.state.iter().all(|&s| s == 0.0));
    }

    #[test]
    fn state_bounded() {
        let mut layer = RecurrentLayer::new(2, 4, 99);
        for _ in 0..100 {
            layer.step(&[1.0, 1.0]);
        }
        assert!(layer.state.iter().all(|&s| (0.0..=1.0).contains(&s)));
    }

    #[test]
    fn output_shape() {
        let mut layer = RecurrentLayer::new(4, 8, 0);
        let out = layer.step(&[0.1, 0.2, 0.3, 0.4]);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn same_seed_gives_same_layer() {
        let mut a = RecurrentLayer::new(3, 4, 7);
        let mut b = RecurrentLayer::new(3, 4, 7);
        assert_eq!(a.w_in, b.w_in);
        assert_eq!(a.w_rec, b.w_rec);
        let input = [0.2, 0.9, 0.4];
        assert_eq!(a.step(&input).to_vec(), b.step(&input).to_vec());
        // A different seed should produce a different layout.
        assert_ne!(RecurrentLayer::new(3, 4, 7).w_in, RecurrentLayer::new(3, 4, 8).w_in);
    }

    #[test]
    fn weights_have_expected_shape_and_range() {
        let layer = RecurrentLayer::new(3, 4, 11);
        assert_eq!(layer.w_in.len(), 12);
        assert_eq!(layer.w_rec.len(), 16);
        assert!(layer.w_in.iter().all(|&w| (0.0..=0.5).contains(&w)));
        assert!(layer.w_rec.iter().all(|&w| (0.0..=0.2).contains(&w)));
    }

    #[test]
    fn step_matches_hand_computation() {
        // Row-major layout: neuron 0 sees the inputs through [0.1, 0.2] and neuron 1
        // through [0.3, 0.4]. Each neuron feeds back only to itself, with weight 0.5.
        let mut layer =
            layer_with_weights(2, 2, vec![0.1, 0.2, 0.3, 0.4], vec![0.5, 0.0, 0.0, 0.5]);
        assert!(approx_eq(layer.step(&[1.0, 1.0]), &[0.3, 0.7]));
        // The second step adds half of the previous state; neuron 1 saturates at 1.0.
        assert!(approx_eq(layer.step(&[1.0, 1.0]), &[0.45, 1.0]));
    }

    #[test]
    fn negative_drive_is_clamped_to_zero() {
        let mut layer = layer_with_weights(1, 1, vec![0.5], vec![0.2]);
        assert_eq!(layer.step(&[-2.0]).to_vec(), vec![0.0]);
    }

    #[test]
    fn zero_inputs_are_supported() {
        let mut layer = RecurrentLayer::new(0, 3, 5);
        assert!(layer.w_in.is_empty());
        assert_eq!(layer.step(&[]).to_vec(), vec![0.0; 3]);
    }

    #[test]
    #[should_panic]
    fn wrong_input_length_panics() {
        let mut layer = RecurrentLayer::new(3, 2, 1);
        layer.step(&[1.0, 2.0]);
    }
}