/// A five-population spiking model of a cortical column: L4, L2/3 excitatory (`l23e`),
/// L2/3 inhibitory (`l23i`), L5 and L6, each holding `n` leaky integrate-and-fire neurons.
///
/// Every weight matrix is a dense `n * n` vector in row-major order indexed
/// `w[post * n + pre]`, i.e. row `post` holds the weights onto target neuron `post`.
#[derive(Debug, Clone)]
pub struct CorticalColumnRust {
    /// Number of neurons in each population.
    pub n: usize,
    /// Per-step voltage retention factor, `exp(-dt / tau)`.
    decay: f64,
    /// Voltage above which a neuron spikes; also the amount subtracted when it does.
    threshold: f64,
    /// Scale applied to input currents each step, `dt / tau`.
    dt_over_tau: f64,
    /// Membrane voltages of the L4 population.
    pub v_l4: Vec<f64>,
    /// Membrane voltages of the L2/3 excitatory population.
    pub v_l23e: Vec<f64>,
    /// Membrane voltages of the L2/3 inhibitory population.
    pub v_l23i: Vec<f64>,
    /// Membrane voltages of the L5 population.
    pub v_l5: Vec<f64>,
    /// Membrane voltages of the L6 population.
    pub v_l6: Vec<f64>,
    /// Thalamic input -> L4.
    w_thal_l4: Vec<f64>,
    /// L4 -> L2/3 excitatory.
    w_l4_l23e: Vec<f64>,
    /// L2/3 excitatory -> L2/3 inhibitory.
    w_l23e_l23i: Vec<f64>,
    /// L2/3 inhibitory -> L2/3 excitatory (drawn with the inhibitory strength).
    w_l23i_l23e: Vec<f64>,
    /// L2/3 excitatory -> L5.
    w_l23e_l5: Vec<f64>,
    /// L5 -> L6.
    w_l5_l6: Vec<f64>,
    /// L6 -> L4 feedback.
    w_l6_l4: Vec<f64>,
}

34impl CorticalColumnRust {
35 pub fn new(
36 n: usize,
37 tau: f64,
38 dt: f64,
39 threshold: f64,
40 w_exc: f64,
41 w_inh: f64,
42 seed: u64,
43 ) -> Self {
44 let decay = (-dt / tau).exp();
45 let dt_over_tau = dt / tau;
46 let mut rng = SimpleRng::new(seed);
47
48 let make_weights = |rng: &mut SimpleRng, strength: f64, prob: f64| -> Vec<f64> {
49 let mut w = vec![0.0f64; n * n];
50 for i in 0..n * n {
51 if rng.next_f64() < prob {
52 w[i] = rng.next_f64() * strength.abs();
53 if strength < 0.0 {
54 w[i] = -w[i];
55 }
56 }
57 }
58 w
59 };
60
61 Self {
62 n,
63 decay,
64 threshold,
65 dt_over_tau,
66 v_l4: vec![0.0; n],
67 v_l23e: vec![0.0; n],
68 v_l23i: vec![0.0; n],
69 v_l5: vec![0.0; n],
70 v_l6: vec![0.0; n],
71 w_thal_l4: make_weights(&mut rng, w_exc, 0.5),
72 w_l4_l23e: make_weights(&mut rng, w_exc, 0.4),
73 w_l23e_l23i: make_weights(&mut rng, w_exc, 0.3),
74 w_l23i_l23e: make_weights(&mut rng, w_inh, 0.3),
75 w_l23e_l5: make_weights(&mut rng, w_exc, 0.3),
76 w_l5_l6: make_weights(&mut rng, w_exc, 0.3),
77 w_l6_l4: make_weights(&mut rng, w_exc * 0.5, 0.2),
78 }
79 }
80
81 pub fn step(&mut self, thalamic_input: &[f64]) -> [Vec<f64>; 5] {
83 let n = self.n;
84
85 let matvec = |w: &[f64], x: &[f64]| -> Vec<f64> {
87 let mut out = vec![0.0; n];
88 for i in 0..n {
89 let mut sum = 0.0;
90 for j in 0..n {
91 sum += w[i * n + j] * x[j];
92 }
93 out[i] = sum;
94 }
95 out
96 };
97
98 let thresh_vec = |v: &[f64]| -> Vec<f64> {
99 v.iter()
100 .map(|&vi| if vi > self.threshold { 1.0 } else { 0.0 })
101 .collect()
102 };
103
104 let l6_spk = thresh_vec(&self.v_l6);
106
107 let i_l4_thal = matvec(&self.w_thal_l4, thalamic_input);
109 let i_l4_fb = matvec(&self.w_l6_l4, &l6_spk);
110 for i in 0..n {
111 self.v_l4[i] =
112 self.decay * self.v_l4[i] + (i_l4_thal[i] + i_l4_fb[i]) * self.dt_over_tau;
113 }
114 let spk_l4 = thresh_vec(&self.v_l4);
115 for i in 0..n {
116 self.v_l4[i] -= spk_l4[i] * self.threshold;
117 }
118
119 let i_l23e_ff = matvec(&self.w_l4_l23e, &spk_l4);
121 let l23i_spk = thresh_vec(&self.v_l23i);
122 let i_l23e_inh = matvec(&self.w_l23i_l23e, &l23i_spk);
123 for i in 0..n {
124 self.v_l23e[i] =
125 self.decay * self.v_l23e[i] + (i_l23e_ff[i] + i_l23e_inh[i]) * self.dt_over_tau;
126 }
127 let spk_l23e = thresh_vec(&self.v_l23e);
128 for i in 0..n {
129 self.v_l23e[i] -= spk_l23e[i] * self.threshold;
130 }
131
132 let i_l23i = matvec(&self.w_l23e_l23i, &spk_l23e);
134 for i in 0..n {
135 self.v_l23i[i] = self.decay * self.v_l23i[i] + i_l23i[i] * self.dt_over_tau;
136 }
137 let spk_l23i = thresh_vec(&self.v_l23i);
138 for i in 0..n {
139 self.v_l23i[i] -= spk_l23i[i] * self.threshold;
140 }
141
142 let i_l5 = matvec(&self.w_l23e_l5, &spk_l23e);
144 for i in 0..n {
145 self.v_l5[i] = self.decay * self.v_l5[i] + i_l5[i] * self.dt_over_tau;
146 }
147 let spk_l5 = thresh_vec(&self.v_l5);
148 for i in 0..n {
149 self.v_l5[i] -= spk_l5[i] * self.threshold;
150 }
151
152 let i_l6 = matvec(&self.w_l5_l6, &spk_l5);
154 for i in 0..n {
155 self.v_l6[i] = self.decay * self.v_l6[i] + i_l6[i] * self.dt_over_tau;
156 }
157 let spk_l6_new = thresh_vec(&self.v_l6);
158 for i in 0..n {
159 self.v_l6[i] -= spk_l6_new[i] * self.threshold;
160 }
161
162 [spk_l4, spk_l23e, spk_l23i, spk_l5, spk_l6_new]
163 }
164
165 pub fn reset(&mut self) {
166 self.v_l4.fill(0.0);
167 self.v_l23e.fill(0.0);
168 self.v_l23i.fill(0.0);
169 self.v_l5.fill(0.0);
170 self.v_l6.fill(0.0);
171 }
172}
173
/// Minimal xorshift generator (shift triple 13, 7, 17) used for reproducible weight draws.
///
/// Not suitable for cryptographic use.
#[derive(Debug, Clone)]
struct SimpleRng {
    /// Current generator state. Never zero, because zero is a fixed point of xorshift.
    state: u64,
}

impl SimpleRng {
    /// Replacement state used when `seed + 1` wraps to zero (i.e. `seed == u64::MAX`).
    const ZERO_STATE_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator whose state is `seed + 1` (wrapping).
    ///
    /// With a zero state the generator would return 0 forever. The single seed that wraps to
    /// zero therefore gets a fixed non-zero state instead; all other seeds keep their
    /// original streams.
    fn new(seed: u64) -> Self {
        let state = match seed.wrapping_add(1) {
            0 => Self::ZERO_STATE_FALLBACK,
            s => s,
        };
        Self { state }
    }

    /// Advances the state with one xorshift round and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next `u64`.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_column_step_output_lengths() {
        let mut col = CorticalColumnRust::new(10, 10.0, 1.0, 1.0, 0.5, -0.3, 42);
        let spikes = col.step(&[5.0; 10]);
        assert_eq!(spikes.len(), 5);
        for pop in &spikes {
            assert_eq!(pop.len(), 10);
            // Spike outputs are strictly binary.
            assert!(pop.iter().all(|&s| s == 0.0 || s == 1.0));
        }
    }

    #[test]
    fn test_column_produces_spikes() {
        let mut col = CorticalColumnRust::new(20, 10.0, 1.0, 0.5, 1.0, -0.3, 42);
        let input = [10.0; 20];
        let total: f64 = (0..50)
            .map(|_| col.step(&input)[0].iter().sum::<f64>())
            .sum();
        assert!(total > 0.0, "Expected L4 spikes");
    }

    #[test]
    fn test_column_reset() {
        let mut col = CorticalColumnRust::new(5, 10.0, 1.0, 1.0, 0.5, -0.3, 42);
        let input = [10.0; 5];
        let first = col.step(&input);
        col.reset();
        for v in [&col.v_l4, &col.v_l23e, &col.v_l23i, &col.v_l5, &col.v_l6] {
            assert!(v.iter().all(|&x| x == 0.0));
        }
        // Weights survive a reset, so the column replays its very first step exactly.
        assert_eq!(col.step(&input), first);
    }

    #[test]
    fn test_column_deterministic() {
        let mut col_a = CorticalColumnRust::new(5, 10.0, 1.0, 1.0, 0.5, -0.3, 99);
        let mut col_b = CorticalColumnRust::new(5, 10.0, 1.0, 1.0, 0.5, -0.3, 99);
        let input = [3.0; 5];
        // Compare several steps so state carried between steps is covered as well.
        for _ in 0..10 {
            assert_eq!(col_a.step(&input), col_b.step(&input));
        }
        assert_eq!(col_a.v_l4, col_b.v_l4);
        assert_eq!(col_a.v_l6, col_b.v_l6);
    }

    #[test]
    fn test_empty_column() {
        let mut col = CorticalColumnRust::new(0, 10.0, 1.0, 1.0, 0.5, -0.3, 1);
        let spikes = col.step(&[]);
        assert!(spikes.iter().all(|pop| pop.is_empty()));
    }
}