1use numpy::{IntoPyArray, PyReadonlyArray1};
12use pyo3::exceptions::PyValueError;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15
/// Membrane potential at or above which [`IzhikevichRk4`] registers a spike and resets.
const IZH_SPIKE_THRESHOLD: f64 = 30.0;

/// Izhikevich simple spiking neuron (`dv/dt = 0.04 v^2 + 5 v + 140 - u + I`,
/// `du/dt = a (b v - u)`) integrated with classical fourth-order Runge–Kutta.
#[derive(Clone, Debug)]
pub struct IzhikevichRk4 {
    /// Membrane potential.
    pub v: f64,
    /// Recovery variable.
    pub u: f64,
    /// Rate constant of the recovery variable.
    pub a: f64,
    /// Coupling of the recovery variable to `v`.
    pub b: f64,
    /// Value `v` is reset to after a spike.
    pub c: f64,
    /// Increment added to `u` after a spike.
    pub d: f64,
    /// Integration step size used by [`IzhikevichRk4::step`].
    pub dt: f64,
}

impl IzhikevichRk4 {
    /// Builds a neuron with `a = 0.02`, `b = 0.2`, `c = -65`, `d = 8`, starting at
    /// `v = c` and `u = b * c`.
    pub fn new(dt: f64) -> Self {
        let (a, b, c, d) = (0.02, 0.2, -65.0, 8.0);
        Self {
            v: c,
            u: b * c,
            a,
            b,
            c,
            d,
            dt,
        }
    }

    /// Time derivatives `(dv/dt, du/dt)` at state `(v, u)` under input `current`.
    fn rhs(&self, v: f64, u: f64, current: f64) -> (f64, f64) {
        let dv = 0.04 * v.powi(2) + 5.0 * v + 140.0 - u + current;
        let du = self.a * (self.b * v - u);
        (dv, du)
    }

    /// Advances the neuron by one RK4 step of length `dt` with constant `current`.
    ///
    /// Returns `1` when the post-step potential reaches [`IZH_SPIKE_THRESHOLD`]; in that
    /// case `v` is reset to `c` and `u` is incremented by `d`. Otherwise returns `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let (v0, u0, dt) = (self.v, self.u, self.dt);
        let half_dt = 0.5 * dt;

        let k1 = self.rhs(v0, u0, current);
        let k2 = self.rhs(v0 + half_dt * k1.0, u0 + half_dt * k1.1, current);
        let k3 = self.rhs(v0 + half_dt * k2.0, u0 + half_dt * k2.1, current);
        let k4 = self.rhs(v0 + dt * k3.0, u0 + dt * k3.1, current);

        let weight = dt / 6.0;
        self.v = v0 + weight * (k1.0 + 2.0 * k2.0 + 2.0 * k3.0 + k4.0);
        self.u = u0 + weight * (k1.1 + 2.0 * k2.1 + 2.0 * k3.1 + k4.1);

        let spiked = self.v >= IZH_SPIKE_THRESHOLD;
        if spiked {
            self.v = self.c;
            self.u += self.d;
        }
        i32::from(spiked)
    }
}
76
/// Izhikevich (2007) two-variable neuron with quadratic voltage drive
/// `k (v - vr) (v - vt)`, integrated with classical fourth-order Runge–Kutta.
#[derive(Clone, Debug)]
pub struct Izhikevich2007Rk4 {
    /// Membrane potential.
    pub v: f64,
    /// Recovery variable.
    pub u: f64,
    /// Membrane capacitance; divides the voltage derivative.
    pub cap: f64,
    /// Gain of the quadratic voltage term.
    pub k: f64,
    /// Resting potential (a root of the quadratic term and the reference for `u`).
    pub vr: f64,
    /// Threshold potential (the other root of the quadratic term).
    pub vt: f64,
    /// Spike cutoff: reaching this potential triggers the reset.
    pub vpeak: f64,
    /// Rate constant of the recovery variable.
    pub a: f64,
    /// Coupling of the recovery variable to `v - vr`.
    pub b: f64,
    /// Value `v` is reset to after a spike.
    pub c: f64,
    /// Increment added to `u` after a spike.
    pub d: f64,
    /// Integration step size.
    pub dt: f64,
}

impl Izhikevich2007Rk4 {
    /// Time derivatives `(dv/dt, du/dt)` at state `(v, u)` under input `current`.
    fn rhs(&self, v: f64, u: f64, current: f64) -> (f64, f64) {
        let quadratic = self.k * (v - self.vr) * (v - self.vt);
        let dv = (quadratic - u + current) / self.cap;
        let du = self.a * (self.b * (v - self.vr) - u);
        (dv, du)
    }

    /// Advances the neuron by one RK4 step of length `dt` with constant `current`.
    ///
    /// Returns `1` if the post-step potential reaches `vpeak` (after which `v = c` and
    /// `u += d`), otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let (v0, u0) = (self.v, self.u);
        let half_dt = 0.5 * self.dt;

        let (k1v, k1u) = self.rhs(v0, u0, current);
        let (k2v, k2u) = self.rhs(v0 + half_dt * k1v, u0 + half_dt * k1u, current);
        let (k3v, k3u) = self.rhs(v0 + half_dt * k2v, u0 + half_dt * k2u, current);
        let (k4v, k4u) = self.rhs(v0 + self.dt * k3v, u0 + self.dt * k3u, current);

        let weight = self.dt / 6.0;
        self.v = v0 + weight * (k1v + 2.0 * k2v + 2.0 * k3v + k4v);
        self.u = u0 + weight * (k1u + 2.0 * k2u + 2.0 * k3u + k4u);

        let spiked = self.v >= self.vpeak;
        if spiked {
            self.v = self.c;
            self.u += self.d;
        }
        i32::from(spiked)
    }

    /// Runs `n_steps` steps under a constant `current`.
    ///
    /// Returns the membrane potential recorded after every step (post-reset on spiking
    /// steps) and the total number of spikes.
    pub fn simulate(&mut self, n_steps: usize, current: f64) -> (Vec<f64>, i64) {
        let mut voltage_trace = Vec::with_capacity(n_steps);
        let mut spike_total = 0_i64;
        for _ in 0..n_steps {
            spike_total += i64::from(self.step(current));
            voltage_trace.push(self.v);
        }
        (voltage_trace, spike_total)
    }
}
148
/// Adaptive exponential integrate-and-fire (AdEx) neuron integrated with classical
/// fourth-order Runge–Kutta.
#[derive(Clone, Debug)]
pub struct AdExRk4 {
    /// Membrane potential.
    pub v: f64,
    /// Adaptation variable.
    pub w: f64,
    /// Leak reversal potential.
    pub v_rest: f64,
    /// Value `v` is reset to after a spike.
    pub v_reset: f64,
    /// Spike cutoff: reaching this potential triggers the reset.
    pub v_threshold: f64,
    /// Potential at which the exponential term equals `delta_t`.
    pub v_rh: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Adaptation time constant.
    pub tau_w: f64,
    /// Sub-threshold coupling of `w` to `v - v_rest`.
    pub a: f64,
    /// Increment added to `w` after a spike.
    pub b: f64,
    /// Membrane capacitance; scales the input and adaptation currents.
    pub c_m: f64,
    /// Integration step size.
    pub dt: f64,
}

impl AdExRk4 {
    /// Builds a neuron at `v = v_rest = -65`, `w = 0` with the module's default
    /// parameter set.
    pub fn new(dt: f64) -> Self {
        Self {
            v: -65.0,
            w: 0.0,
            v_rest: -65.0,
            v_reset: -68.0,
            v_threshold: -50.0,
            v_rh: -55.0,
            delta_t: 2.0,
            tau: 20.0,
            tau_w: 100.0,
            a: 0.5,
            b: 7.0,
            c_m: 200.0,
            dt,
        }
    }

    /// Time derivatives `(dv/dt, dw/dt)` at state `(v, w)` under input `current`.
    fn rhs(&self, v: f64, w: f64, current: f64) -> (f64, f64) {
        // The exponent is clamped so a runaway upstroke cannot overflow `exp`.
        let exponent = ((v - self.v_rh) / self.delta_t).clamp(-20.0, 20.0);
        let spike_term = self.delta_t * exponent.exp();
        let leak_term = -(v - self.v_rest);
        let dv = (leak_term + spike_term) / self.tau + (-w + current) / self.c_m;
        let dw = (self.a * (v - self.v_rest) - w) / self.tau_w;
        (dv, dw)
    }

    /// Advances the neuron by one RK4 step of length `dt` with constant `current`.
    ///
    /// Returns `1` if the post-step potential reaches `v_threshold` (after which
    /// `v = v_reset` and `w += b`), otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let (v0, w0) = (self.v, self.w);
        let half_dt = 0.5 * self.dt;

        let (k1v, k1w) = self.rhs(v0, w0, current);
        let (k2v, k2w) = self.rhs(v0 + half_dt * k1v, w0 + half_dt * k1w, current);
        let (k3v, k3w) = self.rhs(v0 + half_dt * k2v, w0 + half_dt * k2w, current);
        let (k4v, k4w) = self.rhs(v0 + self.dt * k3v, w0 + self.dt * k3w, current);

        let weight = self.dt / 6.0;
        self.v = v0 + weight * (k1v + 2.0 * k2v + 2.0 * k3v + k4v);
        self.w = w0 + weight * (k1w + 2.0 * k2w + 2.0 * k3w + k4w);

        let spiked = self.v >= self.v_threshold;
        if spiked {
            self.v = self.v_reset;
            self.w += self.b;
        }
        i32::from(spiked)
    }
}
219
/// Hodgkin–Huxley point neuron (sodium, potassium and leak currents) integrated with
/// classical fourth-order Runge–Kutta.
///
/// Each call to [`HodgkinHuxleyRk4::step`] advances the model by a fixed window of
/// [`HodgkinHuxleyRk4::STEP_WINDOW`] time units. That window is split into equal RK4
/// substeps, none longer than `dt`.
#[derive(Clone, Debug)]
pub struct HodgkinHuxleyRk4 {
    /// Membrane potential.
    pub v: f64,
    /// Sodium activation gate (enters the sodium current as `m^3`).
    pub m: f64,
    /// Sodium inactivation gate.
    pub h: f64,
    /// Potassium activation gate (enters the potassium current as `n^4`).
    pub n: f64,
    /// Membrane capacitance.
    pub c_m: f64,
    /// Maximal sodium conductance.
    pub g_na: f64,
    /// Maximal potassium conductance.
    pub g_k: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Sodium reversal potential.
    pub e_na: f64,
    /// Potassium reversal potential.
    pub e_k: f64,
    /// Leak reversal potential.
    pub e_l: f64,
    /// Maximum RK4 substep length; must be positive and finite.
    pub dt: f64,
    /// Potential whose upward crossing is reported as a spike.
    pub v_threshold: f64,
}

impl HodgkinHuxleyRk4 {
    /// Simulated time covered by a single [`step`](Self::step) call.
    pub const STEP_WINDOW: f64 = 1.0;

    /// Builds a neuron with the classic HH conductances and reversal potentials
    /// (`g_Na = 120`, `g_K = 36`, `g_L = 0.3`, `E_Na = 50`, `E_K = -77`,
    /// `E_L = -54.4`).
    ///
    /// It starts at `v = -65` with gates close to their steady-state values there.
    /// Spikes are detected at 0.
    pub fn new(dt: f64) -> Self {
        Self {
            v: -65.0,
            m: 0.05,
            h: 0.6,
            n: 0.32,
            c_m: 1.0,
            g_na: 120.0,
            g_k: 36.0,
            g_l: 0.3,
            e_na: 50.0,
            e_k: -77.0,
            e_l: -54.4,
            dt,
            v_threshold: 0.0,
        }
    }

    /// Opening rate of `m`. The removable singularity at `v = -40` is replaced by its
    /// limit, 1.0.
    fn alpha_m(v: f64) -> f64 {
        let d = v + 40.0;
        if d.abs() < 1e-7 {
            1.0
        } else {
            0.1 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `m`.
    fn beta_m(v: f64) -> f64 {
        4.0 * (-(v + 65.0) / 18.0).exp()
    }

    /// Opening rate of `h`.
    fn alpha_h(v: f64) -> f64 {
        0.07 * (-(v + 65.0) / 20.0).exp()
    }

    /// Closing rate of `h`.
    fn beta_h(v: f64) -> f64 {
        1.0 / (1.0 + (-(v + 35.0) / 10.0).exp())
    }

    /// Opening rate of `n`. The removable singularity at `v = -55` is replaced by its
    /// limit, 0.1.
    fn alpha_n(v: f64) -> f64 {
        let d = v + 55.0;
        if d.abs() < 1e-7 {
            0.1
        } else {
            0.01 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `n`.
    fn beta_n(v: f64) -> f64 {
        0.125 * (-(v + 65.0) / 80.0).exp()
    }

    /// Time derivatives of `[v, m, h, n]` under input `current`.
    fn rhs(&self, state: [f64; 4], current: f64) -> [f64; 4] {
        let [v, m, h, n] = state;
        let am = Self::alpha_m(v);
        let bm = Self::beta_m(v);
        let ah = Self::alpha_h(v);
        let bh = Self::beta_h(v);
        let an = Self::alpha_n(v);
        let bn = Self::beta_n(v);

        let dm = am * (1.0 - m) - bm * m;
        let dh = ah * (1.0 - h) - bh * h;
        let dn = an * (1.0 - n) - bn * n;
        let i_na = self.g_na * m.powi(3) * h * (v - self.e_na);
        let i_k = self.g_k * n.powi(4) * (v - self.e_k);
        let i_l = self.g_l * (v - self.e_l);
        let dv = (-i_na - i_k - i_l + current) / self.c_m;
        [dv, dm, dh, dn]
    }

    /// Number of RK4 substeps per window: `ceil(STEP_WINDOW / dt)`, and at least one.
    fn substep_count(&self) -> usize {
        // Shave a relative epsilon before `ceil`. Otherwise a dt that divides the window
        // exactly, but whose reciprocal rounds one ulp high, would gain a spurious substep.
        let ratio = Self::STEP_WINDOW / self.dt * (1.0 - 1e-12);
        ratio.ceil().max(1.0) as usize
    }

    /// Advances the neuron by [`STEP_WINDOW`](Self::STEP_WINDOW) time units under a
    /// constant input `current`.
    ///
    /// Returns `1` if `v` crossed `v_threshold` upwards at any point during the window,
    /// and `0` otherwise.
    ///
    /// The window is split into `ceil(STEP_WINDOW / dt)` equal substeps. As a result,
    /// every call covers the same span of time, no substep exceeds `dt`, and a `dt`
    /// larger than the window still integrates rather than silently doing nothing.
    /// For step sizes such as 0.01, 0.05 or 0.25 the substep equals `dt` exactly.
    pub fn step(&mut self, current: f64) -> i32 {
        let substeps = self.substep_count();
        let sub_dt = Self::STEP_WINDOW / substeps as f64;
        let half = 0.5 * sub_dt;
        let sixth = sub_dt / 6.0;

        let mut state = [self.v, self.m, self.h, self.n];
        let mut spiked = false;
        for _ in 0..substeps {
            let v_before = state[0];
            let k1 = self.rhs(state, current);
            let k2 = self.rhs(add_scaled(state, k1, half), current);
            let k3 = self.rhs(add_scaled(state, k2, half), current);
            let k4 = self.rhs(add_scaled(state, k3, sub_dt), current);
            state = std::array::from_fn(move |i| {
                state[i] + sixth * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i])
            });
            // Check each substep: a whole action potential can rise through the threshold
            // and fall back within one window, which a start/end comparison would miss.
            spiked |= v_before < self.v_threshold && state[0] >= self.v_threshold;
        }
        [self.v, self.m, self.h, self.n] = state;
        i32::from(spiked)
    }
}

/// Returns `state + scale * deriv` element-wise (one RK4 stage offset).
fn add_scaled(state: [f64; 4], deriv: [f64; 4], scale: f64) -> [f64; 4] {
    std::array::from_fn(|i| state[i] + scale * deriv[i])
}
343
344#[pyfunction]
345#[pyo3(signature = (model_name, current_trace, dt=None))]
346pub fn py_rk4_neuron_simulate<'py>(
347 py: Python<'py>,
348 model_name: &str,
349 current_trace: PyReadonlyArray1<'py, f64>,
350 dt: Option<f64>,
351) -> PyResult<Py<PyAny>> {
352 let currents = current_trace.as_slice()?;
353 match normalise_model_name(model_name).as_str() {
354 "izhikevich" | "scizhikevichneuron" | "izhikevichneuron" => {
355 let dt = validate_trace_dt(currents, dt.unwrap_or(1.0))?;
356 simulate_izhikevich(py, currents, dt)
357 }
358 "hodgkinhuxley" | "hodgkinhuxleyneuron" => {
359 let dt = validate_trace_dt(currents, dt.unwrap_or(0.01))?;
360 simulate_hodgkin_huxley(py, currents, dt)
361 }
362 "adex" | "adexneuron" => {
363 let dt = validate_trace_dt(currents, dt.unwrap_or(0.1))?;
364 simulate_adex(py, currents, dt)
365 }
366 _ => Err(PyValueError::new_err(format!(
367 "unsupported RK4 neuron model {model_name:?}"
368 ))),
369 }
370}
371
372fn validate_trace_dt(currents: &[f64], dt: f64) -> PyResult<f64> {
373 if !dt.is_finite() || dt <= 0.0 {
374 return Err(PyValueError::new_err("dt must be a positive finite scalar"));
375 }
376 if currents.is_empty() {
377 return Err(PyValueError::new_err("current_trace must be non-empty"));
378 }
379 if currents.iter().any(|current| !current.is_finite()) {
380 return Err(PyValueError::new_err(
381 "current_trace must contain only finite values",
382 ));
383 }
384 Ok(dt)
385}
386
/// Canonical form of a model name: ASCII letters and digits only, lower-cased.
///
/// Punctuation, whitespace and any non-ASCII characters are dropped, so
/// `"Hodgkin-HuxleyNeuron"` becomes `"hodgkinhuxleyneuron"`.
fn normalise_model_name(name: &str) -> String {
    name.chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
393
394fn simulate_izhikevich<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
395 let mut neuron = IzhikevichRk4::new(dt);
396 let mut v = Vec::with_capacity(currents.len());
397 let mut u = Vec::with_capacity(currents.len());
398 let mut spikes = Vec::new();
399 for (idx, ¤t) in currents.iter().enumerate() {
400 if neuron.step(current) != 0 {
401 spikes.push(idx as u64);
402 }
403 v.push(neuron.v);
404 u.push(neuron.u);
405 }
406 let d = PyDict::new(py);
407 d.set_item("v", v.into_pyarray(py))?;
408 d.set_item("u", u.into_pyarray(py))?;
409 d.set_item("spikes", spikes.into_pyarray(py))?;
410 d.set_item("n_steps", currents.len())?;
411 Ok(d.into_any().unbind())
412}
413
414fn simulate_adex<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
415 let mut neuron = AdExRk4::new(dt);
416 let mut v = Vec::with_capacity(currents.len());
417 let mut w = Vec::with_capacity(currents.len());
418 let mut spikes = Vec::new();
419 for (idx, ¤t) in currents.iter().enumerate() {
420 if neuron.step(current) != 0 {
421 spikes.push(idx as u64);
422 }
423 v.push(neuron.v);
424 w.push(neuron.w);
425 }
426 let d = PyDict::new(py);
427 d.set_item("v", v.into_pyarray(py))?;
428 d.set_item("w", w.into_pyarray(py))?;
429 d.set_item("spikes", spikes.into_pyarray(py))?;
430 d.set_item("n_steps", currents.len())?;
431 Ok(d.into_any().unbind())
432}
433
434fn simulate_hodgkin_huxley<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
435 let mut neuron = HodgkinHuxleyRk4::new(dt);
436 let mut v = Vec::with_capacity(currents.len());
437 let mut m = Vec::with_capacity(currents.len());
438 let mut h = Vec::with_capacity(currents.len());
439 let mut n = Vec::with_capacity(currents.len());
440 let mut spikes = Vec::new();
441 for (idx, ¤t) in currents.iter().enumerate() {
442 if neuron.step(current) != 0 {
443 spikes.push(idx as u64);
444 }
445 v.push(neuron.v);
446 m.push(neuron.m);
447 h.push(neuron.h);
448 n.push(neuron.n);
449 }
450 let d = PyDict::new(py);
451 d.set_item("v", v.into_pyarray(py))?;
452 d.set_item("m", m.into_pyarray(py))?;
453 d.set_item("h", h.into_pyarray(py))?;
454 d.set_item("n", n.into_pyarray(py))?;
455 d.set_item("spikes", spikes.into_pyarray(py))?;
456 d.set_item("n_steps", currents.len())?;
457 Ok(d.into_any().unbind())
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identically constructed neurons fed the same input must stay bit-identical,
    /// and a constant drive of 10 must produce at least one spike in 100 steps.
    #[test]
    fn izhikevich_rk4_is_deterministic_and_spikes() {
        let mut first = IzhikevichRk4::new(1.0);
        let mut second = IzhikevichRk4::new(1.0);
        let spike_count: i32 = (0..100)
            .map(|_| {
                second.step(10.0);
                first.step(10.0)
            })
            .sum();
        assert!(spike_count > 0);
        assert_eq!(first.v, second.v);
        assert_eq!(first.u, second.u);
    }

    /// A strong sustained drive must make AdEx fire without the state diverging.
    #[test]
    fn adex_rk4_remains_finite_under_sustained_current() {
        let mut neuron = AdExRk4::new(0.1);
        let spike_count: i32 = (0..3000).map(|_| neuron.step(500.0)).sum();
        assert!(spike_count > 0);
        assert!(neuron.v.is_finite());
        assert!(neuron.w.is_finite());
    }

    /// Under tonic drive HH must fire while every gating variable stays within [0, 1].
    #[test]
    fn hodgkin_huxley_rk4_keeps_gates_bounded() {
        let mut neuron = HodgkinHuxleyRk4::new(0.01);
        let spike_count: i32 = (0..1000).map(|_| neuron.step(10.0)).sum();
        assert!(spike_count > 0);
        assert!(neuron.v.is_finite());
        for gate in [neuron.m, neuron.h, neuron.n] {
            assert!((0.0..=1.0).contains(&gate));
        }
    }

    /// Punctuation and case must not affect model-name matching.
    #[test]
    fn model_name_normalisation_accepts_common_aliases() {
        let cases = [
            ("Hodgkin-HuxleyNeuron", "hodgkinhuxleyneuron"),
            ("AdEx", "adex"),
        ];
        for (raw, expected) in cases {
            assert_eq!(normalise_model_name(raw), expected);
        }
    }
}