1use crate::neuron::{mask, FixedPointLif};
16use rand::SeedableRng;
17use rand_xoshiro::Xoshiro256PlusPlus;
18
19pub struct BrunelNetwork {
21 neurons: Vec<FixedPointLif>,
22 prev_spikes: Vec<bool>,
23 w_row_offsets: Vec<usize>,
25 w_col_indices: Vec<usize>,
26 w_values: Vec<i16>,
27 n_neurons: usize,
28 leak_k: i16,
29 gain_k: i16,
30 ext_lambda: f64,
31 ext_weight_fp: i16,
32 rng: Xoshiro256PlusPlus,
33}
34
35impl BrunelNetwork {
36 #[allow(clippy::too_many_arguments)]
41 pub fn new(
42 n_neurons: usize,
43 w_row_offsets: Vec<usize>,
44 w_col_indices: Vec<usize>,
45 w_values: Vec<i16>,
46 data_width: u32,
47 fraction: u32,
48 v_rest: i16,
49 v_reset: i16,
50 v_threshold: i16,
51 refractory_period: i32,
52 leak_k: i16,
53 gain_k: i16,
54 ext_lambda: f64,
55 ext_weight_fp: i16,
56 seed: u64,
57 ) -> Result<Self, String> {
58 if w_row_offsets.len() != n_neurons + 1 {
59 return Err(format!(
60 "w_row_offsets length {} != n_neurons+1={}",
61 w_row_offsets.len(),
62 n_neurons + 1
63 ));
64 }
65 if w_col_indices.len() != w_values.len() {
66 return Err(format!(
67 "w_col_indices len {} != w_values len {}",
68 w_col_indices.len(),
69 w_values.len()
70 ));
71 }
72 let neurons: Vec<FixedPointLif> = (0..n_neurons)
73 .map(|_| {
74 FixedPointLif::new(
75 data_width,
76 fraction,
77 v_rest,
78 v_reset,
79 v_threshold,
80 refractory_period,
81 )
82 })
83 .collect();
84
85 Ok(Self {
86 neurons,
87 prev_spikes: vec![false; n_neurons],
88 w_row_offsets,
89 w_col_indices,
90 w_values,
91 n_neurons,
92 leak_k,
93 gain_k,
94 ext_lambda,
95 ext_weight_fp,
96 rng: Xoshiro256PlusPlus::seed_from_u64(seed),
97 })
98 }
99
100 pub fn run(&mut self, n_steps: usize) -> Vec<u32> {
102 let n = self.n_neurons;
103 let mut i_syn = vec![0i32; n];
104 let mut counts = Vec::with_capacity(n_steps);
105
106 for _ in 0..n_steps {
107 i_syn.iter_mut().for_each(|x| *x = 0);
110 for pre in 0..n {
111 if !self.prev_spikes[pre] {
112 continue;
113 }
114 let start = self.w_row_offsets[pre];
115 let end = self.w_row_offsets[pre + 1];
116 for idx in start..end {
117 let post = self.w_col_indices[idx];
118 i_syn[post] += self.w_values[idx] as i32;
119 }
120 }
121
122 let mut step_spikes = 0u32;
124 #[allow(clippy::needless_range_loop)]
125 for i in 0..n {
126 let ext_count = poisson_sample(&mut self.rng, self.ext_lambda);
127 let ext_current = (ext_count as i32) * (self.ext_weight_fp as i32);
128 let total_current = i_syn[i] + ext_current;
129 let dw = self.neurons[i].data_width;
130 let i_t = mask(total_current, dw);
131
132 let (spike, _) = self.neurons[i].step(self.leak_k, self.gain_k, i_t, 0);
133 self.prev_spikes[i] = spike > 0;
134 if spike > 0 {
135 step_spikes += 1;
136 }
137 }
138 counts.push(step_spikes);
139 }
140 counts
141 }
142
143 pub fn total_spikes(&self, counts: &[u32]) -> u64 {
144 counts.iter().map(|&c| c as u64).sum()
145 }
146}
147
148fn poisson_sample(rng: &mut Xoshiro256PlusPlus, lambda: f64) -> u32 {
150 if lambda <= 0.0 {
151 return 0;
152 }
153 use rand::RngExt;
154 let l = (-lambda).exp();
155 let mut k = 0u32;
156 let mut p = 1.0f64;
157 loop {
158 k += 1;
159 p *= rng.random::<f64>();
160 if p <= l {
161 return k - 1;
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    // Parameters shared by every test, named after the matching `BrunelNetwork::new` arguments.
    const DATA_WIDTH: u32 = 16;
    const FRACTION: u32 = 8;
    const V_REST: i16 = 0;
    const V_RESET: i16 = 0;
    const V_THRESHOLD: i16 = 256;
    const REFRACTORY_PERIOD: i32 = 2;
    const LEAK_K: i16 = 1;
    const GAIN_K: i16 = 256;
    const SEED: u64 = 42;

    /// Calls `BrunelNetwork::new` with the shared parameters above plus the given
    /// connectivity and external drive.
    fn build(
        n: usize,
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: Vec<i16>,
        ext_lambda: f64,
        ext_weight_fp: i16,
    ) -> Result<BrunelNetwork, String> {
        BrunelNetwork::new(
            n,
            row_offsets,
            col_indices,
            values,
            DATA_WIDTH,
            FRACTION,
            V_REST,
            V_RESET,
            V_THRESHOLD,
            REFRACTORY_PERIOD,
            LEAK_K,
            GAIN_K,
            ext_lambda,
            ext_weight_fp,
            SEED,
        )
    }

    /// Four neurons connected all-to-all (no self-connections), every weight 26,
    /// with external Poisson drive of rate 5.0 and weight 26.
    fn make_small_network() -> BrunelNetwork {
        const N: usize = 4;
        let targets_of = |pre: usize| (0..N).filter(move |&post| post != pre);
        let col_indices: Vec<usize> = (0..N).flat_map(targets_of).collect();
        let values = vec![26i16; col_indices.len()];
        // Every row holds N - 1 synapses, so row r ends at (r + 1) * (N - 1).
        let row_offsets: Vec<usize> = std::iter::once(0)
            .chain((1..=N).map(|rows| rows * (N - 1)))
            .collect();
        build(N, row_offsets, col_indices, values, 5.0, 26).unwrap()
    }

    #[test]
    fn brunel_produces_spikes() {
        let mut net = make_small_network();
        let counts = net.run(100);
        let total = net.total_spikes(&counts);
        assert!(total > 0, "network must produce spikes");
    }

    #[test]
    fn brunel_empty_network() {
        let mut net = build(0, vec![0], vec![], vec![], 0.0, 0).unwrap();
        let counts = net.run(10);
        assert!(counts.iter().all(|&c| c == 0));
    }

    #[test]
    fn brunel_csr_validation() {
        // Four neurons need five row offsets; two must be rejected.
        let result = build(4, vec![0, 1], vec![0], vec![10], 0.0, 0);
        assert!(result.is_err());
    }
}