sc_neurocore_engine/
recurrent.rs1use rand::{RngExt, SeedableRng};
13use rand_chacha::ChaCha8Rng;
14
/// A fully connected layer whose neurons also feed back into each other.
///
/// Both weight matrices are stored flat in row-major order, one row per
/// neuron, and are indexed that way by `step`.
#[derive(Debug, Clone)]
pub struct RecurrentLayer {
    /// Length of the input slice that `step` accepts.
    pub n_inputs: usize,
    /// Number of neurons, which is also the length of `state`.
    pub n_neurons: usize,
    /// Input weights, `n_neurons * n_inputs` entries; the weight from input
    /// `j` to neuron `i` lives at `i * n_inputs + j`.
    pub w_in: Vec<f64>,
    /// Recurrent weights, `n_neurons * n_neurons` entries; the weight from
    /// neuron `j`'s previous activation to neuron `i` lives at
    /// `i * n_neurons + j`.
    pub w_rec: Vec<f64>,
    /// Activation of every neuron after the most recent `step` (all zeros
    /// for a freshly built or reset layer).
    pub state: Vec<f64>,
}
27
28impl RecurrentLayer {
29 pub fn new(n_inputs: usize, n_neurons: usize, seed: u64) -> Self {
30 let mut rng = ChaCha8Rng::seed_from_u64(seed);
31 let w_in: Vec<f64> = (0..n_neurons * n_inputs)
32 .map(|_| rng.random::<f64>() * 0.5)
33 .collect();
34 let w_rec: Vec<f64> = (0..n_neurons * n_neurons)
35 .map(|_| rng.random::<f64>() * 0.2)
36 .collect();
37 Self {
38 n_inputs,
39 n_neurons,
40 w_in,
41 w_rec,
42 state: vec![0.0; n_neurons],
43 }
44 }
45
46 #[allow(clippy::needless_range_loop)]
48 pub fn step(&mut self, input: &[f64]) -> &[f64] {
49 assert_eq!(input.len(), self.n_inputs);
50 let mut new_state = vec![0.0; self.n_neurons];
51 for i in 0..self.n_neurons {
52 let mut val = 0.0;
53 for j in 0..self.n_inputs {
54 val += self.w_in[i * self.n_inputs + j] * input[j];
55 }
56 for j in 0..self.n_neurons {
57 val += self.w_rec[i * self.n_neurons + j] * self.state[j];
58 }
59 new_state[i] = val.clamp(0.0, 1.0);
60 }
61 self.state = new_state;
62 &self.state
63 }
64
65 pub fn reset(&mut self) {
66 self.state.fill(0.0);
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// One step with non-zero input must leave at least one neuron active.
    #[test]
    fn state_changes_after_step() {
        let mut net = RecurrentLayer::new(3, 5, 42);
        let stimulus = [0.5, 0.3, 0.8];
        let activations = net.step(&stimulus).to_vec();
        assert!(activations.iter().any(|&a| a > 0.0));
    }

    /// `reset` must zero every activation after the layer has been driven.
    #[test]
    fn reset_clears_state() {
        let mut net = RecurrentLayer::new(3, 5, 42);
        net.step(&[0.5, 0.3, 0.8]);
        net.reset();
        assert!(!net.state.iter().any(|&a| a != 0.0));
    }

    /// Activations stay inside `[0, 1]` even under sustained maximal input.
    #[test]
    fn state_bounded() {
        let mut net = RecurrentLayer::new(2, 4, 99);
        (0..100).for_each(|_| {
            net.step(&[1.0, 1.0]);
        });
        let unit_interval = 0.0..=1.0;
        assert!(net.state.iter().all(|a| unit_interval.contains(a)));
    }

    /// `step` returns one activation per neuron.
    #[test]
    fn output_shape() {
        let mut net = RecurrentLayer::new(4, 8, 0);
        let produced = net.step(&[0.1, 0.2, 0.3, 0.4]).len();
        assert_eq!(produced, 8);
    }
}