/// State and connectivity of a layered column made of five populations
/// (`l4`, `l23e`, `l23i`, `l5`, `l6`) with `n` units each.
///
/// Every weight matrix is a flat, row-major `n * n` buffer: entry `[i * n + j]`
/// scales the contribution of source unit `j` to target unit `i`. The field
/// name `w_<src>_<dst>` gives the direction of the projection.
#[derive(Debug, Clone)]
pub struct CorticalColumnRust {
    /// Number of units in each population.
    pub n: usize,
    /// Multiplicative leak applied to every potential per step, `exp(-dt / tau)`.
    decay: f64,
    /// A unit emits a spike (`1.0`) when its potential is strictly above this value;
    /// the same value is subtracted from the potential of a unit that spiked.
    threshold: f64,
    /// Factor `dt / tau` applied to the summed input current before integration.
    dt_over_tau: f64,
    /// Potentials of the `l4` population.
    pub v_l4: Vec<f64>,
    /// Potentials of the `l23e` population.
    pub v_l23e: Vec<f64>,
    /// Potentials of the `l23i` population.
    pub v_l23i: Vec<f64>,
    /// Potentials of the `l5` population.
    pub v_l5: Vec<f64>,
    /// Potentials of the `l6` population.
    pub v_l6: Vec<f64>,
    /// External (thalamic) input -> `l4`.
    w_thal_l4: Vec<f64>,
    /// `l4` -> `l23e`.
    w_l4_l23e: Vec<f64>,
    /// `l23e` -> `l23i`.
    w_l23e_l23i: Vec<f64>,
    /// `l23i` -> `l23e`; built from the `w_inh` constructor argument.
    w_l23i_l23e: Vec<f64>,
    /// `l23e` -> `l5`.
    w_l23e_l5: Vec<f64>,
    /// `l5` -> `l6`.
    w_l5_l6: Vec<f64>,
    /// `l6` -> `l4` feedback; built with half of `w_exc`.
    w_l6_l4: Vec<f64>,
}
32
33impl CorticalColumnRust {
34 pub fn new(
35 n: usize,
36 tau: f64,
37 dt: f64,
38 threshold: f64,
39 w_exc: f64,
40 w_inh: f64,
41 seed: u64,
42 ) -> Self {
43 let decay = (-dt / tau).exp();
44 let dt_over_tau = dt / tau;
45 let mut rng = SimpleRng::new(seed);
46
47 let make_weights = |rng: &mut SimpleRng, strength: f64, prob: f64| -> Vec<f64> {
48 let mut w = vec![0.0f64; n * n];
49 for i in 0..n * n {
50 if rng.next_f64() < prob {
51 w[i] = rng.next_f64() * strength.abs();
52 if strength < 0.0 {
53 w[i] = -w[i];
54 }
55 }
56 }
57 w
58 };
59
60 Self {
61 n,
62 decay,
63 threshold,
64 dt_over_tau,
65 v_l4: vec![0.0; n],
66 v_l23e: vec![0.0; n],
67 v_l23i: vec![0.0; n],
68 v_l5: vec![0.0; n],
69 v_l6: vec![0.0; n],
70 w_thal_l4: make_weights(&mut rng, w_exc, 0.5),
71 w_l4_l23e: make_weights(&mut rng, w_exc, 0.4),
72 w_l23e_l23i: make_weights(&mut rng, w_exc, 0.3),
73 w_l23i_l23e: make_weights(&mut rng, w_inh, 0.3),
74 w_l23e_l5: make_weights(&mut rng, w_exc, 0.3),
75 w_l5_l6: make_weights(&mut rng, w_exc, 0.3),
76 w_l6_l4: make_weights(&mut rng, w_exc * 0.5, 0.2),
77 }
78 }
79
80 pub fn step(&mut self, thalamic_input: &[f64]) -> [Vec<f64>; 5] {
82 let n = self.n;
83
84 let matvec = |w: &[f64], x: &[f64]| -> Vec<f64> {
86 let mut out = vec![0.0; n];
87 for i in 0..n {
88 let mut sum = 0.0;
89 for j in 0..n {
90 sum += w[i * n + j] * x[j];
91 }
92 out[i] = sum;
93 }
94 out
95 };
96
97 let thresh_vec = |v: &[f64]| -> Vec<f64> {
98 v.iter()
99 .map(|&vi| if vi > self.threshold { 1.0 } else { 0.0 })
100 .collect()
101 };
102
103 let l6_spk = thresh_vec(&self.v_l6);
105
106 let i_l4_thal = matvec(&self.w_thal_l4, thalamic_input);
108 let i_l4_fb = matvec(&self.w_l6_l4, &l6_spk);
109 for i in 0..n {
110 self.v_l4[i] =
111 self.decay * self.v_l4[i] + (i_l4_thal[i] + i_l4_fb[i]) * self.dt_over_tau;
112 }
113 let spk_l4 = thresh_vec(&self.v_l4);
114 for i in 0..n {
115 self.v_l4[i] -= spk_l4[i] * self.threshold;
116 }
117
118 let i_l23e_ff = matvec(&self.w_l4_l23e, &spk_l4);
120 let l23i_spk = thresh_vec(&self.v_l23i);
121 let i_l23e_inh = matvec(&self.w_l23i_l23e, &l23i_spk);
122 for i in 0..n {
123 self.v_l23e[i] =
124 self.decay * self.v_l23e[i] + (i_l23e_ff[i] + i_l23e_inh[i]) * self.dt_over_tau;
125 }
126 let spk_l23e = thresh_vec(&self.v_l23e);
127 for i in 0..n {
128 self.v_l23e[i] -= spk_l23e[i] * self.threshold;
129 }
130
131 let i_l23i = matvec(&self.w_l23e_l23i, &spk_l23e);
133 for i in 0..n {
134 self.v_l23i[i] = self.decay * self.v_l23i[i] + i_l23i[i] * self.dt_over_tau;
135 }
136 let spk_l23i = thresh_vec(&self.v_l23i);
137 for i in 0..n {
138 self.v_l23i[i] -= spk_l23i[i] * self.threshold;
139 }
140
141 let i_l5 = matvec(&self.w_l23e_l5, &spk_l23e);
143 for i in 0..n {
144 self.v_l5[i] = self.decay * self.v_l5[i] + i_l5[i] * self.dt_over_tau;
145 }
146 let spk_l5 = thresh_vec(&self.v_l5);
147 for i in 0..n {
148 self.v_l5[i] -= spk_l5[i] * self.threshold;
149 }
150
151 let i_l6 = matvec(&self.w_l5_l6, &spk_l5);
153 for i in 0..n {
154 self.v_l6[i] = self.decay * self.v_l6[i] + i_l6[i] * self.dt_over_tau;
155 }
156 let spk_l6_new = thresh_vec(&self.v_l6);
157 for i in 0..n {
158 self.v_l6[i] -= spk_l6_new[i] * self.threshold;
159 }
160
161 [spk_l4, spk_l23e, spk_l23i, spk_l5, spk_l6_new]
162 }
163
164 pub fn reset(&mut self) {
165 self.v_l4.fill(0.0);
166 self.v_l23e.fill(0.0);
167 self.v_l23i.fill(0.0);
168 self.v_l5.fill(0.0);
169 self.v_l6.fill(0.0);
170 }
171}
172
/// Small deterministic pseudo-random generator: a 64-bit xor/shift update with
/// shift amounts 13 (left), 7 (right) and 17 (left).
struct SimpleRng {
    /// Current generator state; must never be zero, because zero maps to
    /// itself under the xor/shift update.
    state: u64,
}

impl SimpleRng {
    /// State used in place of zero, which would otherwise lock the generator.
    const ZERO_STATE_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Seeds the generator with `seed + 1` (wrapping).
    ///
    /// The only seed for which that sum is zero is `u64::MAX`; it is mapped to
    /// a fixed non-zero constant so the generator cannot get stuck emitting
    /// zeros. Every other seed yields the same sequence as before.
    fn new(seed: u64) -> Self {
        let state = match seed.wrapping_add(1) {
            0 => Self::ZERO_STATE_REPLACEMENT,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state by one xor/shift round and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next
    /// 64-bit output (53 bits fit an `f64` mantissa exactly).
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// `step` returns five spike vectors, each with one entry per unit.
    #[test]
    fn test_column_step_output_lengths() {
        let mut col = CorticalColumnRust::new(10, 10.0, 1.0, 1.0, 0.5, -0.3, 42);
        let input = vec![5.0; 10];
        let spikes = col.step(&input);
        assert_eq!(spikes.len(), 5);
        for pop in &spikes {
            assert_eq!(pop.len(), 10);
        }
    }

    /// Sustained strong input makes the first population fire at least once.
    #[test]
    fn test_column_produces_spikes() {
        let mut col = CorticalColumnRust::new(20, 10.0, 1.0, 0.5, 1.0, -0.3, 42);
        let input = vec![10.0; 20];
        let mut total = 0.0;
        for _ in 0..50 {
            let spikes = col.step(&input);
            total += spikes[0].iter().sum::<f64>();
        }
        assert!(total > 0.0, "Expected L4 spikes");
    }

    /// `reset` zeroes the potentials of all five populations.
    #[test]
    fn test_column_reset() {
        let mut col = CorticalColumnRust::new(5, 10.0, 1.0, 1.0, 0.5, -0.3, 42);
        let input = vec![10.0; 5];
        col.step(&input);
        col.reset();
        assert!(col.v_l4.iter().all(|&v| v == 0.0));
        assert!(col.v_l23e.iter().all(|&v| v == 0.0));
        assert!(col.v_l23i.iter().all(|&v| v == 0.0));
        assert!(col.v_l5.iter().all(|&v| v == 0.0));
        assert!(col.v_l6.iter().all(|&v| v == 0.0));
    }

    /// Two columns built from identical arguments stay in lock-step.
    #[test]
    fn test_column_deterministic() {
        let mut col_a = CorticalColumnRust::new(5, 10.0, 1.0, 1.0, 0.5, -0.3, 99);
        let mut col_b = CorticalColumnRust::new(5, 10.0, 1.0, 1.0, 0.5, -0.3, 99);
        let input = vec![3.0; 5];
        for _ in 0..10 {
            let sa = col_a.step(&input);
            let sb = col_b.step(&input);
            assert_eq!(sa, sb);
        }
    }
}