1use crate::neuron::{mask, FixedPointLif};
15use rand::SeedableRng;
16use rand_xoshiro::Xoshiro256PlusPlus;
17
18pub struct BrunelNetwork {
20 neurons: Vec<FixedPointLif>,
21 prev_spikes: Vec<bool>,
22 w_row_offsets: Vec<usize>,
24 w_col_indices: Vec<usize>,
25 w_values: Vec<i16>,
26 n_neurons: usize,
27 leak_k: i16,
28 gain_k: i16,
29 ext_lambda: f64,
30 ext_weight_fp: i16,
31 rng: Xoshiro256PlusPlus,
32}
33
34impl BrunelNetwork {
35 #[allow(clippy::too_many_arguments)]
40 pub fn new(
41 n_neurons: usize,
42 w_row_offsets: Vec<usize>,
43 w_col_indices: Vec<usize>,
44 w_values: Vec<i16>,
45 data_width: u32,
46 fraction: u32,
47 v_rest: i16,
48 v_reset: i16,
49 v_threshold: i16,
50 refractory_period: i32,
51 leak_k: i16,
52 gain_k: i16,
53 ext_lambda: f64,
54 ext_weight_fp: i16,
55 seed: u64,
56 ) -> Result<Self, String> {
57 if w_row_offsets.len() != n_neurons + 1 {
58 return Err(format!(
59 "w_row_offsets length {} != n_neurons+1={}",
60 w_row_offsets.len(),
61 n_neurons + 1
62 ));
63 }
64 if w_col_indices.len() != w_values.len() {
65 return Err(format!(
66 "w_col_indices len {} != w_values len {}",
67 w_col_indices.len(),
68 w_values.len()
69 ));
70 }
71 let neurons: Vec<FixedPointLif> = (0..n_neurons)
72 .map(|_| {
73 FixedPointLif::new(
74 data_width,
75 fraction,
76 v_rest,
77 v_reset,
78 v_threshold,
79 refractory_period,
80 )
81 })
82 .collect();
83
84 Ok(Self {
85 neurons,
86 prev_spikes: vec![false; n_neurons],
87 w_row_offsets,
88 w_col_indices,
89 w_values,
90 n_neurons,
91 leak_k,
92 gain_k,
93 ext_lambda,
94 ext_weight_fp,
95 rng: Xoshiro256PlusPlus::seed_from_u64(seed),
96 })
97 }
98
99 pub fn run(&mut self, n_steps: usize) -> Vec<u32> {
101 let n = self.n_neurons;
102 let mut i_syn = vec![0i32; n];
103 let mut counts = Vec::with_capacity(n_steps);
104
105 for _ in 0..n_steps {
106 i_syn.iter_mut().for_each(|x| *x = 0);
109 for pre in 0..n {
110 if !self.prev_spikes[pre] {
111 continue;
112 }
113 let start = self.w_row_offsets[pre];
114 let end = self.w_row_offsets[pre + 1];
115 for idx in start..end {
116 let post = self.w_col_indices[idx];
117 i_syn[post] += self.w_values[idx] as i32;
118 }
119 }
120
121 let mut step_spikes = 0u32;
123 #[allow(clippy::needless_range_loop)]
124 for i in 0..n {
125 let ext_count = poisson_sample(&mut self.rng, self.ext_lambda);
126 let ext_current = (ext_count as i32) * (self.ext_weight_fp as i32);
127 let total_current = i_syn[i] + ext_current;
128 let dw = self.neurons[i].data_width;
129 let i_t = mask(total_current, dw);
130
131 let (spike, _) = self.neurons[i].step(self.leak_k, self.gain_k, i_t, 0);
132 self.prev_spikes[i] = spike > 0;
133 if spike > 0 {
134 step_spikes += 1;
135 }
136 }
137 counts.push(step_spikes);
138 }
139 counts
140 }
141
142 pub fn total_spikes(&self, counts: &[u32]) -> u64 {
143 counts.iter().map(|&c| c as u64).sum()
144 }
145}
146
147fn poisson_sample(rng: &mut Xoshiro256PlusPlus, lambda: f64) -> u32 {
149 if lambda <= 0.0 {
150 return 0;
151 }
152 use rand::RngExt;
153 let l = (-lambda).exp();
154 let mut k = 0u32;
155 let mut p = 1.0f64;
156 loop {
157 k += 1;
158 p *= rng.random::<f64>();
159 if p <= l {
160 return k - 1;
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a network using the neuron and network constants shared by every test.
    /// Only the topology and the external-input settings vary between tests.
    fn build(
        n: usize,
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: Vec<i16>,
        ext_lambda: f64,
        ext_weight_fp: i16,
    ) -> Result<BrunelNetwork, String> {
        BrunelNetwork::new(
            n,
            row_offsets,
            col_indices,
            values,
            16,  // data_width
            8,   // fraction
            0,   // v_rest
            0,   // v_reset
            256, // v_threshold
            2,   // refractory_period
            1,   // leak_k
            256, // gain_k
            ext_lambda,
            ext_weight_fp,
            42, // seed
        )
    }

    /// Four neurons connected all-to-all without self-loops, every weight 26. The
    /// network is driven by external input with `ext_lambda = 5.0` and event weight 26.
    fn make_small_network() -> BrunelNetwork {
        let n = 4;
        let mut row_offsets = Vec::with_capacity(n + 1);
        row_offsets.push(0usize);
        let mut col_indices = Vec::new();
        for pre in 0..n {
            col_indices.extend((0..n).filter(|&post| post != pre));
            row_offsets.push(col_indices.len());
        }
        let values = vec![26i16; col_indices.len()];
        build(n, row_offsets, col_indices, values, 5.0, 26).unwrap()
    }

    #[test]
    fn brunel_produces_spikes() {
        let mut net = make_small_network();
        let counts = net.run(100);
        assert!(
            net.total_spikes(&counts) > 0,
            "network must produce spikes"
        );
    }

    #[test]
    fn brunel_empty_network() {
        let mut net = build(0, vec![0], vec![], vec![], 0.0, 0).unwrap();
        let counts = net.run(10);
        assert!(counts.iter().all(|&c| c == 0));
    }

    #[test]
    fn brunel_csr_validation() {
        // Four neurons but only two row offsets: construction must fail.
        let result = build(4, vec![0, 1], vec![0], vec![10], 0.0, 0);
        assert!(result.is_err());
    }
}