1use numpy::{IntoPyArray, PyReadonlyArray1};
12use pyo3::exceptions::PyValueError;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15
/// Cutoff for the Izhikevich membrane variable: if `v` is at or above this value
/// after a step, the step is reported as a spike and the neuron is reset.
const IZH_SPIKE_THRESHOLD: f64 = 30.0;

/// Izhikevich neuron whose state (`v`, `u`) is advanced by one classical
/// fourth-order Runge-Kutta step of size `dt` per call to [`IzhikevichRk4::step`].
///
/// NOTE(review): the RK4 stages are evaluated without clamping `v`. A step that
/// crosses the cutoff therefore samples the quadratic term far above it, and can add
/// an increment to `u` much larger than the reset bump `d`. Confirm this is intended.
#[derive(Clone, Debug)]
pub struct IzhikevichRk4 {
    /// Membrane variable.
    pub v: f64,
    /// Recovery variable.
    pub u: f64,
    /// Rate scale of the recovery dynamics (`du = a * (b * v - u)`).
    pub a: f64,
    /// Coupling of the recovery variable to `v`.
    pub b: f64,
    /// Value `v` is reset to after a spike.
    pub c: f64,
    /// Amount added to `u` after a spike.
    pub d: f64,
    /// RK4 step size used by `step`.
    pub dt: f64,
}

impl IzhikevichRk4 {
    /// Builds a neuron with `a = 0.02`, `b = 0.2`, `c = -65`, `d = 8`, starting from
    /// `v = c` and `u = b * c`.
    ///
    /// These values appear to be the standard "regular spiking" set (TODO confirm).
    pub fn new(dt: f64) -> Self {
        let (a, b, c, d) = (0.02, 0.2, -65.0, 8.0);
        Self {
            v: c,
            u: b * c,
            a,
            b,
            c,
            d,
            dt,
        }
    }

    /// Returns the time derivatives `(dv, du)` at state `(v, u)` under input `current`.
    fn rhs(&self, v: f64, u: f64, current: f64) -> (f64, f64) {
        let dv = 0.04 * v.powi(2) + 5.0 * v + 140.0 - u + current;
        let du = self.a * (self.b * v - u);
        (dv, du)
    }

    /// Advances the neuron by one RK4 step of `dt`, holding `current` constant.
    ///
    /// Returns `1` if `v` ends the step at or above [`IZH_SPIKE_THRESHOLD`]. In that
    /// case `v` is reset to `c` and `u` is increased by `d`. Otherwise returns `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let half = 0.5 * self.dt;
        let (v0, u0) = (self.v, self.u);

        let (k1v, k1u) = self.rhs(v0, u0, current);
        let (k2v, k2u) = self.rhs(v0 + half * k1v, u0 + half * k1u, current);
        let (k3v, k3u) = self.rhs(v0 + half * k2v, u0 + half * k2u, current);
        let (k4v, k4u) = self.rhs(v0 + self.dt * k3v, u0 + self.dt * k3u, current);

        let sixth = self.dt / 6.0;
        self.v = v0 + sixth * (k1v + 2.0 * k2v + 2.0 * k3v + k4v);
        self.u = u0 + sixth * (k1u + 2.0 * k2u + 2.0 * k3u + k4u);

        let fired = self.v >= IZH_SPIKE_THRESHOLD;
        if fired {
            self.v = self.c;
            self.u += self.d;
        }
        i32::from(fired)
    }
}
76
/// Adaptive exponential integrate-and-fire (AdEx) neuron advanced by one
/// classical RK4 step of size `dt` per call to [`AdExRk4::step`].
#[derive(Clone, Debug)]
pub struct AdExRk4 {
    /// Membrane potential.
    pub v: f64,
    /// Adaptation variable.
    pub w: f64,
    /// Level the leak term relaxes `v` towards.
    pub v_rest: f64,
    /// Value `v` is set to after a spike.
    pub v_reset: f64,
    /// Spike cutoff: reaching it at the end of a step counts as a spike.
    pub v_threshold: f64,
    /// Centre of the exponential term `exp((v - v_rh) / delta_t)`.
    pub v_rh: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Time constant dividing the leak and exponential terms.
    pub tau: f64,
    /// Time constant of the adaptation variable.
    pub tau_w: f64,
    /// Coupling of `w` to `v - v_rest`.
    pub a: f64,
    /// Increment added to `w` after a spike.
    pub b: f64,
    /// Divides the input and adaptation currents (presumably membrane capacitance).
    pub c_m: f64,
    /// RK4 step size used by `step`.
    pub dt: f64,
}

impl AdExRk4 {
    /// Builds a neuron at `v = v_rest = -65`, `w = 0`, with the fixed default
    /// parameter set listed below.
    pub fn new(dt: f64) -> Self {
        Self {
            dt,
            c_m: 200.0,
            b: 7.0,
            a: 0.5,
            tau_w: 100.0,
            tau: 20.0,
            delta_t: 2.0,
            v_rh: -55.0,
            v_threshold: -50.0,
            v_reset: -68.0,
            v_rest: -65.0,
            w: 0.0,
            v: -65.0,
        }
    }

    /// Returns the time derivatives `(dv, dw)` at state `(v, w)` under input `current`.
    fn rhs(&self, v: f64, w: f64, current: f64) -> (f64, f64) {
        // The exponent is clamped to [-20, 20] so `exp` stays finite even when an RK4
        // stage overshoots far beyond `v_rh`.
        let exponent = ((v - self.v_rh) / self.delta_t).clamp(-20.0, 20.0);
        let spike_drive = self.delta_t * exponent.exp();
        let leak = -(v - self.v_rest);
        let dv = (leak + spike_drive) / self.tau + (-w + current) / self.c_m;
        let dw = (self.a * (v - self.v_rest) - w) / self.tau_w;
        (dv, dw)
    }

    /// Advances the neuron by one RK4 step of `dt`, holding `current` constant.
    ///
    /// Returns `1` if `v` ends the step at or above `v_threshold`. In that case `v`
    /// becomes `v_reset` and `w` is increased by `b`. Otherwise returns `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let half = 0.5 * self.dt;
        let (v0, w0) = (self.v, self.w);

        let (k1v, k1w) = self.rhs(v0, w0, current);
        let (k2v, k2w) = self.rhs(v0 + half * k1v, w0 + half * k1w, current);
        let (k3v, k3w) = self.rhs(v0 + half * k2v, w0 + half * k2w, current);
        let (k4v, k4w) = self.rhs(v0 + self.dt * k3v, w0 + self.dt * k3w, current);

        let sixth = self.dt / 6.0;
        self.v = v0 + sixth * (k1v + 2.0 * k2v + 2.0 * k3v + k4v);
        self.w = w0 + sixth * (k1w + 2.0 * k2w + 2.0 * k3w + k4w);

        if self.v >= self.v_threshold {
            self.v = self.v_reset;
            self.w += self.b;
            return 1;
        }
        0
    }
}
147
/// Hodgkin-Huxley neuron (sodium, potassium and leak currents) integrated with
/// fixed-size classical RK4 substeps.
///
/// Each call to [`HodgkinHuxleyRk4::step`] runs `round(1 / dt)` substeps of size `dt`
/// (at least one). That covers roughly 1.0 time units per call, presumably 1 ms.
#[derive(Clone, Debug)]
pub struct HodgkinHuxleyRk4 {
    /// Membrane potential.
    pub v: f64,
    /// Sodium gating variable (enters the sodium current as `m^3`).
    pub m: f64,
    /// Sodium gating variable (enters the sodium current linearly).
    pub h: f64,
    /// Potassium gating variable (enters the potassium current as `n^4`).
    pub n: f64,
    /// Divides the net membrane current in `dv`.
    pub c_m: f64,
    /// Maximal sodium conductance.
    pub g_na: f64,
    /// Maximal potassium conductance.
    pub g_k: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Sodium reversal potential.
    pub e_na: f64,
    /// Potassium reversal potential.
    pub e_k: f64,
    /// Leak reversal potential.
    pub e_l: f64,
    /// RK4 substep size.
    pub dt: f64,
    /// Upward crossings of this potential are reported as spikes.
    pub v_threshold: f64,
}

impl HodgkinHuxleyRk4 {
    /// Builds a neuron at `v = -65` with fixed initial gates and conductances and a
    /// spike-detection threshold of `0.0`.
    pub fn new(dt: f64) -> Self {
        Self {
            v: -65.0,
            m: 0.05,
            h: 0.6,
            n: 0.32,
            c_m: 1.0,
            g_na: 120.0,
            g_k: 36.0,
            g_l: 0.3,
            e_na: 50.0,
            e_k: -77.0,
            e_l: -54.4,
            dt,
            v_threshold: 0.0,
        }
    }

    /// Opening rate of `m`.
    ///
    /// The removable 0/0 singularity at `v = -40` is replaced by its limit, 1.0.
    fn alpha_m(v: f64) -> f64 {
        let d = v + 40.0;
        if d.abs() < 1e-7 {
            1.0
        } else {
            0.1 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `m`.
    fn beta_m(v: f64) -> f64 {
        4.0 * (-(v + 65.0) / 18.0).exp()
    }

    /// Opening rate of `h`.
    fn alpha_h(v: f64) -> f64 {
        0.07 * (-(v + 65.0) / 20.0).exp()
    }

    /// Closing rate of `h`.
    fn beta_h(v: f64) -> f64 {
        1.0 / (1.0 + (-(v + 35.0) / 10.0).exp())
    }

    /// Opening rate of `n`.
    ///
    /// The removable 0/0 singularity at `v = -55` is replaced by its limit, 0.1.
    fn alpha_n(v: f64) -> f64 {
        let d = v + 55.0;
        if d.abs() < 1e-7 {
            0.1
        } else {
            0.01 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `n`.
    fn beta_n(v: f64) -> f64 {
        0.125 * (-(v + 65.0) / 80.0).exp()
    }

    /// Returns the time derivatives of `[v, m, h, n]` under input `current`.
    fn rhs(&self, state: [f64; 4], current: f64) -> [f64; 4] {
        let [v, m, h, n] = state;
        let am = Self::alpha_m(v);
        let bm = Self::beta_m(v);
        let ah = Self::alpha_h(v);
        let bh = Self::beta_h(v);
        let an = Self::alpha_n(v);
        let bn = Self::beta_n(v);

        let dm = am * (1.0 - m) - bm * m;
        let dh = ah * (1.0 - h) - bh * h;
        let dn = an * (1.0 - n) - bn * n;
        let i_na = self.g_na * m.powi(3) * h * (v - self.e_na);
        let i_k = self.g_k * n.powi(4) * (v - self.e_k);
        let i_l = self.g_l * (v - self.e_l);
        let dv = (-i_na - i_k - i_l + current) / self.c_m;
        [dv, dm, dh, dn]
    }

    /// Advances the neuron by `substeps * dt`, holding `current` constant.
    ///
    /// `substeps` is `round(1 / dt)`, clamped to at least 1.
    ///
    /// Returns `1` if `v` crossed `v_threshold` upwards during any substep, otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        // For dt > 2 the rounded count is 0. Without the clamp the state would never
        // advance and the simulation would silently return a frozen trace.
        let substeps = ((1.0 / self.dt).round() as usize).max(1);
        let weight = self.dt / 6.0;
        let mut state = [self.v, self.m, self.h, self.n];
        let mut spiked = false;

        for _ in 0..substeps {
            let v_before = state[0];
            let k1 = self.rhs(state, current);
            let k2 = self.rhs(add_scaled(state, k1, 0.5 * self.dt), current);
            let k3 = self.rhs(add_scaled(state, k2, 0.5 * self.dt), current);
            let k4 = self.rhs(add_scaled(state, k3, self.dt), current);
            for (idx, value) in state.iter_mut().enumerate() {
                *value += weight * (k1[idx] + 2.0 * k2[idx] + 2.0 * k3[idx] + k4[idx]);
            }
            // Check for a crossing after every substep, not only at the end of the call.
            // Otherwise an action potential that rises above and falls back below the
            // threshold within a single call would go unreported.
            if v_before < self.v_threshold && state[0] >= self.v_threshold {
                spiked = true;
            }
        }

        let [v, m, h, n] = state;
        self.v = v;
        self.m = m;
        self.h = h;
        self.n = n;

        i32::from(spiked)
    }
}

/// Returns `state + scale * deriv`, element-wise; used to build the RK4 stage inputs.
fn add_scaled(state: [f64; 4], deriv: [f64; 4], scale: f64) -> [f64; 4] {
    [
        state[0] + scale * deriv[0],
        state[1] + scale * deriv[1],
        state[2] + scale * deriv[2],
        state[3] + scale * deriv[3],
    ]
}
271
272#[pyfunction]
273#[pyo3(signature = (model_name, current_trace, dt=None))]
274pub fn py_rk4_neuron_simulate<'py>(
275 py: Python<'py>,
276 model_name: &str,
277 current_trace: PyReadonlyArray1<'py, f64>,
278 dt: Option<f64>,
279) -> PyResult<Py<PyAny>> {
280 let currents = current_trace.as_slice()?;
281 match normalise_model_name(model_name).as_str() {
282 "izhikevich" | "scizhikevichneuron" | "izhikevichneuron" => {
283 let dt = validate_trace_dt(currents, dt.unwrap_or(1.0))?;
284 simulate_izhikevich(py, currents, dt)
285 }
286 "hodgkinhuxley" | "hodgkinhuxleyneuron" => {
287 let dt = validate_trace_dt(currents, dt.unwrap_or(0.01))?;
288 simulate_hodgkin_huxley(py, currents, dt)
289 }
290 "adex" | "adexneuron" => {
291 let dt = validate_trace_dt(currents, dt.unwrap_or(0.1))?;
292 simulate_adex(py, currents, dt)
293 }
294 _ => Err(PyValueError::new_err(format!(
295 "unsupported RK4 neuron model {model_name:?}"
296 ))),
297 }
298}
299
300fn validate_trace_dt(currents: &[f64], dt: f64) -> PyResult<f64> {
301 if !dt.is_finite() || dt <= 0.0 {
302 return Err(PyValueError::new_err("dt must be a positive finite scalar"));
303 }
304 if currents.is_empty() {
305 return Err(PyValueError::new_err("current_trace must be non-empty"));
306 }
307 if currents.iter().any(|current| !current.is_finite()) {
308 return Err(PyValueError::new_err(
309 "current_trace must contain only finite values",
310 ));
311 }
312 Ok(dt)
313}
314
/// Builds the lookup key for a user-supplied model name.
///
/// Every character that is not an ASCII letter or digit is dropped, and the rest are
/// lowercased. For example, `"Hodgkin-HuxleyNeuron"` becomes `"hodgkinhuxleyneuron"`.
fn normalise_model_name(name: &str) -> String {
    let mut key = String::with_capacity(name.len());
    for ch in name.chars() {
        if ch.is_ascii_alphanumeric() {
            key.push(ch.to_ascii_lowercase());
        }
    }
    key
}
321
322fn simulate_izhikevich<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
323 let mut neuron = IzhikevichRk4::new(dt);
324 let mut v = Vec::with_capacity(currents.len());
325 let mut u = Vec::with_capacity(currents.len());
326 let mut spikes = Vec::new();
327 for (idx, ¤t) in currents.iter().enumerate() {
328 if neuron.step(current) != 0 {
329 spikes.push(idx as u64);
330 }
331 v.push(neuron.v);
332 u.push(neuron.u);
333 }
334 let d = PyDict::new(py);
335 d.set_item("v", v.into_pyarray(py))?;
336 d.set_item("u", u.into_pyarray(py))?;
337 d.set_item("spikes", spikes.into_pyarray(py))?;
338 d.set_item("n_steps", currents.len())?;
339 Ok(d.into_any().unbind())
340}
341
342fn simulate_adex<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
343 let mut neuron = AdExRk4::new(dt);
344 let mut v = Vec::with_capacity(currents.len());
345 let mut w = Vec::with_capacity(currents.len());
346 let mut spikes = Vec::new();
347 for (idx, ¤t) in currents.iter().enumerate() {
348 if neuron.step(current) != 0 {
349 spikes.push(idx as u64);
350 }
351 v.push(neuron.v);
352 w.push(neuron.w);
353 }
354 let d = PyDict::new(py);
355 d.set_item("v", v.into_pyarray(py))?;
356 d.set_item("w", w.into_pyarray(py))?;
357 d.set_item("spikes", spikes.into_pyarray(py))?;
358 d.set_item("n_steps", currents.len())?;
359 Ok(d.into_any().unbind())
360}
361
362fn simulate_hodgkin_huxley<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
363 let mut neuron = HodgkinHuxleyRk4::new(dt);
364 let mut v = Vec::with_capacity(currents.len());
365 let mut m = Vec::with_capacity(currents.len());
366 let mut h = Vec::with_capacity(currents.len());
367 let mut n = Vec::with_capacity(currents.len());
368 let mut spikes = Vec::new();
369 for (idx, ¤t) in currents.iter().enumerate() {
370 if neuron.step(current) != 0 {
371 spikes.push(idx as u64);
372 }
373 v.push(neuron.v);
374 m.push(neuron.m);
375 h.push(neuron.h);
376 n.push(neuron.n);
377 }
378 let d = PyDict::new(py);
379 d.set_item("v", v.into_pyarray(py))?;
380 d.set_item("m", m.into_pyarray(py))?;
381 d.set_item("h", h.into_pyarray(py))?;
382 d.set_item("n", n.into_pyarray(py))?;
383 d.set_item("spikes", spikes.into_pyarray(py))?;
384 d.set_item("n_steps", currents.len())?;
385 Ok(d.into_any().unbind())
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identically configured Izhikevich neurons under the same input must spike
    /// and end in bit-identical states.
    #[test]
    fn izhikevich_rk4_is_deterministic_and_spikes() {
        let mut first = IzhikevichRk4::new(1.0);
        let mut second = IzhikevichRk4::new(1.0);
        let fired: i32 = (0..100)
            .map(|_| {
                let spike = first.step(10.0);
                second.step(10.0);
                spike
            })
            .sum();
        assert!(fired > 0);
        assert_eq!(first.v, second.v);
        assert_eq!(first.u, second.u);
    }

    /// A strong sustained input must produce spikes without the state diverging.
    #[test]
    fn adex_rk4_remains_finite_under_sustained_current() {
        let mut neuron = AdExRk4::new(0.1);
        let fired: i32 = (0..3000).map(|_| neuron.step(500.0)).sum();
        assert!(fired > 0);
        assert!(neuron.v.is_finite());
        assert!(neuron.w.is_finite());
    }

    /// Repetitive firing must keep every gating variable inside [0, 1].
    #[test]
    fn hodgkin_huxley_rk4_keeps_gates_bounded() {
        let mut neuron = HodgkinHuxleyRk4::new(0.01);
        let fired: i32 = (0..1000).map(|_| neuron.step(10.0)).sum();
        assert!(fired > 0);
        assert!(neuron.v.is_finite());
        let unit = 0.0..=1.0;
        for gate in [neuron.m, neuron.h, neuron.n] {
            assert!(unit.contains(&gate));
        }
    }

    /// Punctuation and case must not matter when looking up a model name.
    #[test]
    fn model_name_normalisation_accepts_common_aliases() {
        let cases = [
            ("Hodgkin-HuxleyNeuron", "hodgkinhuxleyneuron"),
            ("AdEx", "adex"),
        ];
        for (raw, expected) in cases {
            assert_eq!(normalise_model_name(raw), expected);
        }
    }
}