Skip to main content

sc_neurocore_engine/
rk4_neurons.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — RK4 neuron integrator ports
8
9//! Explicit RK4 ports for the priority neuron integrator paths.
10
11use numpy::{IntoPyArray, PyReadonlyArray1};
12use pyo3::exceptions::PyValueError;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15
/// Value of `v` at or above which [`IzhikevichRk4::step`] reports a spike and
/// applies the `v = c`, `u += d` reset.
const IZH_SPIKE_THRESHOLD: f64 = 30.0;

/// Izhikevich neuron `dv/dt = 0.04 v^2 + 5 v + 140 - u + I`,
/// `du/dt = a (b v - u)`, advanced by one explicit RK4 step per call.
#[derive(Clone, Debug)]
pub struct IzhikevichRk4 {
    /// Membrane-potential state variable.
    pub v: f64,
    /// Recovery state variable.
    pub u: f64,
    /// Rate constant scaling `du/dt`.
    pub a: f64,
    /// Coupling of `u` to `v` in `du/dt`.
    pub b: f64,
    /// Value `v` is reset to after a spike.
    pub c: f64,
    /// Amount added to `u` after a spike.
    pub d: f64,
    /// RK4 step size.
    pub dt: f64,
}

impl IzhikevichRk4 {
    /// Builds a neuron with `a = 0.02`, `b = 0.2`, `c = -65`, `d = 8`.
    /// The initial state is `v = c` and `u = b * c`.
    pub fn new(dt: f64) -> Self {
        let (a, b, c, d) = (0.02, 0.2, -65.0, 8.0);
        Self {
            v: c,
            u: b * c,
            a,
            b,
            c,
            d,
            dt,
        }
    }

    /// Right-hand side `(dv/dt, du/dt)` evaluated at `(v, u)` under `current`.
    fn rhs(&self, v: f64, u: f64, current: f64) -> (f64, f64) {
        let quadratic = 0.04 * v.powi(2) + 5.0 * v + 140.0;
        (quadratic - u + current, self.a * (self.b * v - u))
    }

    /// Advances the state by one RK4 step of size `dt` under a constant input.
    ///
    /// Returns `1` if the updated `v` reached [`IZH_SPIKE_THRESHOLD`] (the
    /// reset has then already been applied), otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let h = self.dt;
        let half = 0.5 * h;
        let (v0, u0) = (self.v, self.u);

        let (dv1, du1) = self.rhs(v0, u0, current);
        let (dv2, du2) = self.rhs(v0 + half * dv1, u0 + half * du1, current);
        let (dv3, du3) = self.rhs(v0 + half * dv2, u0 + half * du2, current);
        let (dv4, du4) = self.rhs(v0 + h * dv3, u0 + h * du3, current);

        let weight = h / 6.0;
        self.v = v0 + weight * (dv1 + 2.0 * dv2 + 2.0 * dv3 + dv4);
        self.u = u0 + weight * (du1 + 2.0 * du2 + 2.0 * du3 + du4);

        // `>=` is false for NaN, so a non-finite `v` is never treated as a spike.
        let fired = self.v >= IZH_SPIKE_THRESHOLD;
        if fired {
            self.v = self.c;
            self.u += self.d;
        }
        i32::from(fired)
    }
}
76
/// Izhikevich (2007) biophysical neuron, as parameterised by NeuroML `izhikevich2007Cell`.
///
/// Equations:
/// - `C dv/dt = k (v - vr)(v - vt) - u + I`
/// - `du/dt = a (b (v - vr) - u)`
/// - reset: when `v >= vpeak`, set `v = c` and `u += d`
///
/// The system is advanced with classic RK4. The right-hand side uses only
/// `+`, `-`, `*` and `/` (no transcendental functions). A Python reference
/// that evaluates the same expressions in the same order therefore reproduces
/// `simulate` bit-for-bit.
#[derive(Clone, Debug)]
pub struct Izhikevich2007Rk4 {
    /// Membrane-potential state variable.
    pub v: f64,
    /// Recovery state variable.
    pub u: f64,
    /// Capacitance `C` dividing the current balance in `dv/dt`.
    pub cap: f64,
    /// Gain `k` of the quadratic term.
    pub k: f64,
    /// `vr` in the equations (conventionally the resting potential).
    pub vr: f64,
    /// `vt` in the equations (conventionally the threshold potential).
    pub vt: f64,
    /// Spike cut-off: reaching it triggers the reset.
    pub vpeak: f64,
    /// Rate constant scaling `du/dt`.
    pub a: f64,
    /// Coupling of `u` to `v - vr`.
    pub b: f64,
    /// Value `v` is reset to after a spike.
    pub c: f64,
    /// Amount added to `u` after a spike.
    pub d: f64,
    /// RK4 step size.
    pub dt: f64,
}

impl Izhikevich2007Rk4 {
    /// Right-hand side `(dv/dt, du/dt)` evaluated at `(v, u)` under `current`.
    fn rhs(&self, v: f64, u: f64, current: f64) -> (f64, f64) {
        let above_vr = v - self.vr;
        let dv = (self.k * above_vr * (v - self.vt) - u + current) / self.cap;
        let du = self.a * (self.b * above_vr - u);
        (dv, du)
    }

    /// Advances the state by one RK4 step of size `dt` under a constant input.
    ///
    /// Returns `1` if the updated `v` reached `vpeak` (the reset has then
    /// already been applied), otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let (v0, u0, h) = (self.v, self.u, self.dt);
        let half = 0.5 * h;

        let (a1, b1) = self.rhs(v0, u0, current);
        let (a2, b2) = self.rhs(v0 + half * a1, u0 + half * b1, current);
        let (a3, b3) = self.rhs(v0 + half * a2, u0 + half * b2, current);
        let (a4, b4) = self.rhs(v0 + h * a3, u0 + h * b3, current);

        let sixth = h / 6.0;
        self.v = v0 + sixth * (a1 + 2.0 * a2 + 2.0 * a3 + a4);
        self.u = u0 + sixth * (b1 + 2.0 * b2 + 2.0 * b3 + b4);

        let fired = self.v >= self.vpeak;
        if fired {
            self.v = self.c;
            self.u += self.d;
        }
        i32::from(fired)
    }

    /// Runs `n_steps` RK4 updates under a constant `current`.
    ///
    /// Returns the per-step `v` trace together with the spike count. On a
    /// spiking step the trace holds the post-reset value `c`.
    ///
    /// Every update goes through [`Self::step`], so the trace is identical to
    /// calling `step` in a loop. The final state remains in `self.v` /
    /// `self.u`.
    pub fn simulate(&mut self, n_steps: usize, current: f64) -> (Vec<f64>, i64) {
        let mut spike_count: i64 = 0;
        let trace: Vec<f64> = (0..n_steps)
            .map(|_| {
                spike_count += i64::from(self.step(current) == 1);
                self.v
            })
            .collect();
        (trace, spike_count)
    }
}
148
/// Adaptive exponential integrate-and-fire neuron, advanced by one explicit
/// RK4 step per call.
///
/// Equations:
/// - `dv/dt = (-(v - v_rest) + delta_t * exp((v - v_rh) / delta_t)) / tau + (-w + I) / c_m`
/// - `dw/dt = (a (v - v_rest) - w) / tau_w`
/// - reset: when `v >= v_threshold`, set `v = v_reset` and `w += b`
#[derive(Clone, Debug)]
pub struct AdExRk4 {
    /// Membrane-potential state variable.
    pub v: f64,
    /// Adaptation state variable.
    pub w: f64,
    /// Potential the leak term pulls `v` towards.
    pub v_rest: f64,
    /// Value `v` is reset to after a spike.
    pub v_reset: f64,
    /// Level at or above which a spike is registered.
    pub v_threshold: f64,
    /// Potential about which the exponential term is centred.
    pub v_rh: f64,
    /// Slope factor of the exponential term (not the time step; see `dt`).
    pub delta_t: f64,
    /// Time constant dividing the leak + exponential terms.
    pub tau: f64,
    /// Time constant of the adaptation variable.
    pub tau_w: f64,
    /// Sub-threshold coupling of `w` to `v - v_rest`.
    pub a: f64,
    /// Amount added to `w` after a spike.
    pub b: f64,
    /// Divisor of the `(-w + I)` term.
    pub c_m: f64,
    /// RK4 step size.
    pub dt: f64,
}

impl AdExRk4 {
    /// Builds a neuron with the fixed default parameter set below. It starts
    /// at `v = v_rest = -65` with `w = 0`.
    pub fn new(dt: f64) -> Self {
        let v_rest = -65.0;
        Self {
            v: v_rest,
            w: 0.0,
            v_rest,
            v_reset: -68.0,
            v_threshold: -50.0,
            v_rh: -55.0,
            delta_t: 2.0,
            tau: 20.0,
            tau_w: 100.0,
            a: 0.5,
            b: 7.0,
            c_m: 200.0,
            dt,
        }
    }

    /// Right-hand side `(dv/dt, dw/dt)` evaluated at `(v, w)` under `current`.
    fn rhs(&self, v: f64, w: f64, current: f64) -> (f64, f64) {
        let offset = v - self.v_rest;
        // The exponent is clamped to [-20, 20] so the exponential term stays bounded.
        let exponent = ((v - self.v_rh) / self.delta_t).clamp(-20.0, 20.0);
        let spike_drive = self.delta_t * exponent.exp();
        let membrane = (-offset + spike_drive) / self.tau;
        let injected = (-w + current) / self.c_m;
        (membrane + injected, (self.a * offset - w) / self.tau_w)
    }

    /// Advances the state by one RK4 step of size `dt` under a constant input.
    ///
    /// Returns `1` if the updated `v` reached `v_threshold` (the reset has
    /// then already been applied), otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let (v0, w0, h) = (self.v, self.w, self.dt);
        let half = 0.5 * h;

        let (p1, q1) = self.rhs(v0, w0, current);
        let (p2, q2) = self.rhs(v0 + half * p1, w0 + half * q1, current);
        let (p3, q3) = self.rhs(v0 + half * p2, w0 + half * q2, current);
        let (p4, q4) = self.rhs(v0 + h * p3, w0 + h * q3, current);

        let sixth = h / 6.0;
        self.v = v0 + sixth * (p1 + 2.0 * p2 + 2.0 * p3 + p4);
        self.w = w0 + sixth * (q1 + 2.0 * q2 + 2.0 * q3 + q4);

        let fired = self.v >= self.v_threshold;
        if fired {
            self.v = self.v_reset;
            self.w += self.b;
        }
        i32::from(fired)
    }
}
219
/// Hodgkin–Huxley neuron (sodium, potassium and leak currents gated by `m`,
/// `h`, `n`), integrated with explicit RK4.
///
/// One call to [`HodgkinHuxleyRk4::step`] performs `round(1 / dt)` RK4
/// substeps of size `dt`, and never fewer than one. For example, `dt = 0.01`
/// gives 100 substeps per call.
#[derive(Clone, Debug)]
pub struct HodgkinHuxleyRk4 {
    /// Membrane-potential state variable.
    pub v: f64,
    /// Sodium activation gate (enters `i_na` cubed).
    pub m: f64,
    /// Sodium inactivation gate.
    pub h: f64,
    /// Potassium activation gate (enters `i_k` to the fourth power).
    pub n: f64,
    /// Capacitance dividing the net current in `dv/dt`.
    pub c_m: f64,
    /// Maximal sodium conductance.
    pub g_na: f64,
    /// Maximal potassium conductance.
    pub g_k: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Sodium reversal potential.
    pub e_na: f64,
    /// Potassium reversal potential.
    pub e_k: f64,
    /// Leak reversal potential.
    pub e_l: f64,
    /// RK4 substep size.
    pub dt: f64,
    /// Level whose upward crossing by `v`, between the start and end of a
    /// `step` call, counts as a spike.
    pub v_threshold: f64,
}

impl HodgkinHuxleyRk4 {
    /// Builds a neuron with the fixed conductance/reversal parameters below.
    /// The initial state is `v = -65`, `m = 0.05`, `h = 0.6`, `n = 0.32`.
    pub fn new(dt: f64) -> Self {
        Self {
            v: -65.0,
            m: 0.05,
            h: 0.6,
            n: 0.32,
            c_m: 1.0,
            g_na: 120.0,
            g_k: 36.0,
            g_l: 0.3,
            e_na: 50.0,
            e_k: -77.0,
            e_l: -54.4,
            dt,
            v_threshold: 0.0,
        }
    }

    /// Opening rate of `m`. At `v = -40` the formula is 0/0; its limit there
    /// is `0.1 * 10 = 1.0`.
    fn alpha_m(v: f64) -> f64 {
        let d = v + 40.0;
        if d.abs() < 1e-7 {
            1.0
        } else {
            0.1 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `m`.
    fn beta_m(v: f64) -> f64 {
        4.0 * (-(v + 65.0) / 18.0).exp()
    }

    /// Opening rate of `h`.
    fn alpha_h(v: f64) -> f64 {
        0.07 * (-(v + 65.0) / 20.0).exp()
    }

    /// Closing rate of `h`.
    fn beta_h(v: f64) -> f64 {
        1.0 / (1.0 + (-(v + 35.0) / 10.0).exp())
    }

    /// Opening rate of `n`. At `v = -55` the formula is 0/0; its limit there
    /// is `0.01 * 10 = 0.1`.
    fn alpha_n(v: f64) -> f64 {
        let d = v + 55.0;
        if d.abs() < 1e-7 {
            0.1
        } else {
            0.01 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `n`.
    fn beta_n(v: f64) -> f64 {
        0.125 * (-(v + 65.0) / 80.0).exp()
    }

    /// Time derivatives `[dv, dm, dh, dn]` of the state `[v, m, h, n]` under
    /// the input `current`.
    fn rhs(&self, state: [f64; 4], current: f64) -> [f64; 4] {
        let [v, m, h, n] = state;
        let am = Self::alpha_m(v);
        let bm = Self::beta_m(v);
        let ah = Self::alpha_h(v);
        let bh = Self::beta_h(v);
        let an = Self::alpha_n(v);
        let bn = Self::beta_n(v);

        let dm = am * (1.0 - m) - bm * m;
        let dh = ah * (1.0 - h) - bh * h;
        let dn = an * (1.0 - n) - bn * n;
        let i_na = self.g_na * m.powi(3) * h * (v - self.e_na);
        let i_k = self.g_k * n.powi(4) * (v - self.e_k);
        let i_l = self.g_l * (v - self.e_l);
        let dv = (-i_na - i_k - i_l + current) / self.c_m;
        [dv, dm, dh, dn]
    }

    /// Advances the neuron by `max(1, round(1 / dt))` RK4 substeps of size
    /// `dt` under a constant input.
    ///
    /// Returns `1` if `v` crossed `v_threshold` upwards: it started the call
    /// below the threshold and ended at or above it. Otherwise returns `0`.
    /// Only the endpoints of the call are compared, so a crossing that rises
    /// and falls back within one call is not counted.
    pub fn step(&mut self, current: f64) -> i32 {
        let v_prev = self.v;
        let mut state = [self.v, self.m, self.h, self.n];
        // `round(1/dt)` is 0 for dt > 2 (and the `as` cast maps NaN/negative
        // values to 0). Without the floor of one substep the neuron would
        // silently never evolve for such `dt`.
        let substeps = ((1.0 / self.dt).round() as usize).max(1);
        let half = 0.5 * self.dt;
        let sixth = self.dt / 6.0;
        for _ in 0..substeps {
            let k1 = self.rhs(state, current);
            let k2 = self.rhs(add_scaled(state, k1, half), current);
            let k3 = self.rhs(add_scaled(state, k2, half), current);
            let k4 = self.rhs(add_scaled(state, k3, self.dt), current);
            for (idx, slot) in state.iter_mut().enumerate() {
                *slot += sixth * (k1[idx] + 2.0 * k2[idx] + 2.0 * k3[idx] + k4[idx]);
            }
        }
        let [v, m, h, n] = state;
        self.v = v;
        self.m = m;
        self.h = h;
        self.n = n;

        let crossed_up = self.v >= self.v_threshold && v_prev < self.v_threshold;
        i32::from(crossed_up)
    }
}

/// Element-wise `state + scale * deriv` for a four-component state vector
/// (one RK4 stage probe point).
fn add_scaled(state: [f64; 4], deriv: [f64; 4], scale: f64) -> [f64; 4] {
    [
        state[0] + scale * deriv[0],
        state[1] + scale * deriv[1],
        state[2] + scale * deriv[2],
        state[3] + scale * deriv[3],
    ]
}
343
344#[pyfunction]
345#[pyo3(signature = (model_name, current_trace, dt=None))]
346pub fn py_rk4_neuron_simulate<'py>(
347    py: Python<'py>,
348    model_name: &str,
349    current_trace: PyReadonlyArray1<'py, f64>,
350    dt: Option<f64>,
351) -> PyResult<Py<PyAny>> {
352    let currents = current_trace.as_slice()?;
353    match normalise_model_name(model_name).as_str() {
354        "izhikevich" | "scizhikevichneuron" | "izhikevichneuron" => {
355            let dt = validate_trace_dt(currents, dt.unwrap_or(1.0))?;
356            simulate_izhikevich(py, currents, dt)
357        }
358        "hodgkinhuxley" | "hodgkinhuxleyneuron" => {
359            let dt = validate_trace_dt(currents, dt.unwrap_or(0.01))?;
360            simulate_hodgkin_huxley(py, currents, dt)
361        }
362        "adex" | "adexneuron" => {
363            let dt = validate_trace_dt(currents, dt.unwrap_or(0.1))?;
364            simulate_adex(py, currents, dt)
365        }
366        _ => Err(PyValueError::new_err(format!(
367            "unsupported RK4 neuron model {model_name:?}"
368        ))),
369    }
370}
371
372fn validate_trace_dt(currents: &[f64], dt: f64) -> PyResult<f64> {
373    if !dt.is_finite() || dt <= 0.0 {
374        return Err(PyValueError::new_err("dt must be a positive finite scalar"));
375    }
376    if currents.is_empty() {
377        return Err(PyValueError::new_err("current_trace must be non-empty"));
378    }
379    if currents.iter().any(|current| !current.is_finite()) {
380        return Err(PyValueError::new_err(
381            "current_trace must contain only finite values",
382        ));
383    }
384    Ok(dt)
385}
386
/// Canonicalises a user-supplied model name for matching. It keeps only
/// ASCII letters and digits and lower-cases them, so `"Hodgkin-HuxleyNeuron"`
/// becomes `"hodgkinhuxleyneuron"`.
fn normalise_model_name(name: &str) -> String {
    // Only ASCII characters survive the filter, so ASCII lower-casing is exact.
    name.chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
393
394fn simulate_izhikevich<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
395    let mut neuron = IzhikevichRk4::new(dt);
396    let mut v = Vec::with_capacity(currents.len());
397    let mut u = Vec::with_capacity(currents.len());
398    let mut spikes = Vec::new();
399    for (idx, &current) in currents.iter().enumerate() {
400        if neuron.step(current) != 0 {
401            spikes.push(idx as u64);
402        }
403        v.push(neuron.v);
404        u.push(neuron.u);
405    }
406    let d = PyDict::new(py);
407    d.set_item("v", v.into_pyarray(py))?;
408    d.set_item("u", u.into_pyarray(py))?;
409    d.set_item("spikes", spikes.into_pyarray(py))?;
410    d.set_item("n_steps", currents.len())?;
411    Ok(d.into_any().unbind())
412}
413
414fn simulate_adex<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
415    let mut neuron = AdExRk4::new(dt);
416    let mut v = Vec::with_capacity(currents.len());
417    let mut w = Vec::with_capacity(currents.len());
418    let mut spikes = Vec::new();
419    for (idx, &current) in currents.iter().enumerate() {
420        if neuron.step(current) != 0 {
421            spikes.push(idx as u64);
422        }
423        v.push(neuron.v);
424        w.push(neuron.w);
425    }
426    let d = PyDict::new(py);
427    d.set_item("v", v.into_pyarray(py))?;
428    d.set_item("w", w.into_pyarray(py))?;
429    d.set_item("spikes", spikes.into_pyarray(py))?;
430    d.set_item("n_steps", currents.len())?;
431    Ok(d.into_any().unbind())
432}
433
434fn simulate_hodgkin_huxley<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
435    let mut neuron = HodgkinHuxleyRk4::new(dt);
436    let mut v = Vec::with_capacity(currents.len());
437    let mut m = Vec::with_capacity(currents.len());
438    let mut h = Vec::with_capacity(currents.len());
439    let mut n = Vec::with_capacity(currents.len());
440    let mut spikes = Vec::new();
441    for (idx, &current) in currents.iter().enumerate() {
442        if neuron.step(current) != 0 {
443            spikes.push(idx as u64);
444        }
445        v.push(neuron.v);
446        m.push(neuron.m);
447        h.push(neuron.h);
448        n.push(neuron.n);
449    }
450    let d = PyDict::new(py);
451    d.set_item("v", v.into_pyarray(py))?;
452    d.set_item("m", m.into_pyarray(py))?;
453    d.set_item("h", h.into_pyarray(py))?;
454    d.set_item("n", n.into_pyarray(py))?;
455    d.set_item("spikes", spikes.into_pyarray(py))?;
456    d.set_item("n_steps", currents.len())?;
457    Ok(d.into_any().unbind())
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn izhikevich_rk4_is_deterministic_and_spikes() {
        let (mut first, mut second) = (IzhikevichRk4::new(1.0), IzhikevichRk4::new(1.0));
        let spike_total: i32 = (0..100)
            .map(|_| {
                let fired = first.step(10.0);
                second.step(10.0);
                fired
            })
            .sum();
        assert!(spike_total > 0);
        assert_eq!(first.v, second.v);
        assert_eq!(first.u, second.u);
    }

    #[test]
    fn adex_rk4_remains_finite_under_sustained_current() {
        let mut neuron = AdExRk4::new(0.1);
        let spike_total: i32 = (0..3000).map(|_| neuron.step(500.0)).sum();
        assert!(spike_total > 0);
        assert!(neuron.v.is_finite());
        assert!(neuron.w.is_finite());
    }

    #[test]
    fn hodgkin_huxley_rk4_keeps_gates_bounded() {
        let mut neuron = HodgkinHuxleyRk4::new(0.01);
        let spike_total: i32 = (0..1000).map(|_| neuron.step(10.0)).sum();
        assert!(spike_total > 0);
        assert!(neuron.v.is_finite());
        for gate in [neuron.m, neuron.h, neuron.n] {
            assert!((0.0..=1.0).contains(&gate));
        }
    }

    #[test]
    fn model_name_normalisation_accepts_common_aliases() {
        let cases = [("Hodgkin-HuxleyNeuron", "hodgkinhuxleyneuron"), ("AdEx", "adex")];
        for (raw, expected) in cases {
            assert_eq!(normalise_model_name(raw), expected);
        }
    }
}