Skip to main content

sc_neurocore_engine/
rk4_neurons.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — RK4 neuron integrator ports
8
9//! Explicit RK4 ports for the priority neuron integrator paths.
10
11use numpy::{IntoPyArray, PyReadonlyArray1};
12use pyo3::exceptions::PyValueError;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15
/// Value of `v` at or above which an Izhikevich step registers a spike and resets.
const IZH_SPIKE_THRESHOLD: f64 = 30.0;

/// Izhikevich neuron (`v' = 0.04v² + 5v + 140 − u + I`, `u' = a(bv − u)`) advanced with
/// the classic fourth-order Runge–Kutta method.
///
/// Spike handling is applied after the full RK4 step: when `v` reaches
/// [`IZH_SPIKE_THRESHOLD`], `v` is reset to `c` and `d` is added to `u`.
#[derive(Clone, Debug)]
pub struct IzhikevichRk4 {
    /// Membrane variable.
    pub v: f64,
    /// Recovery variable.
    pub u: f64,
    /// Rate of the recovery variable.
    pub a: f64,
    /// Coupling of the recovery variable to `v`.
    pub b: f64,
    /// After-spike reset value of `v`.
    pub c: f64,
    /// After-spike increment of `u`.
    pub d: f64,
    /// RK4 step length.
    pub dt: f64,
}

impl IzhikevichRk4 {
    /// Creates a neuron with `a = 0.02`, `b = 0.2`, `c = -65`, `d = 8`, starting at
    /// `v = c`, `u = b * c`, integrated with step `dt`.
    pub fn new(dt: f64) -> Self {
        const RESET_V: f64 = -65.0;
        const COUPLING: f64 = 0.2;
        Self {
            v: RESET_V,
            u: COUPLING * RESET_V,
            a: 0.02,
            b: COUPLING,
            c: RESET_V,
            d: 8.0,
            dt,
        }
    }

    /// Returns `(dv/dt, du/dt)` evaluated at `(v, u)` under the constant input `current`.
    fn rhs(&self, v: f64, u: f64, current: f64) -> (f64, f64) {
        let quadratic = 0.04 * v.powi(2) + 5.0 * v + 140.0;
        (quadratic - u + current, self.a * (self.b * v - u))
    }

    /// Advances the neuron by one RK4 step of length `dt` with input `current`.
    ///
    /// Returns `1` if the post-step `v` reached the spike threshold (and the neuron was
    /// reset), otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let (v0, u0) = (self.v, self.u);
        let half = 0.5 * self.dt;

        // RK4 stages; no clipping is applied to the intermediate `v` values.
        let (dv1, du1) = self.rhs(v0, u0, current);
        let (dv2, du2) = self.rhs(v0 + half * dv1, u0 + half * du1, current);
        let (dv3, du3) = self.rhs(v0 + half * dv2, u0 + half * du2, current);
        let (dv4, du4) = self.rhs(v0 + self.dt * dv3, u0 + self.dt * du3, current);

        let weight = self.dt / 6.0;
        self.v = v0 + weight * (dv1 + 2.0 * dv2 + 2.0 * dv3 + dv4);
        self.u = u0 + weight * (du1 + 2.0 * du2 + 2.0 * du3 + du4);

        let fired = self.v >= IZH_SPIKE_THRESHOLD;
        if fired {
            self.v = self.c;
            self.u += self.d;
        }
        i32::from(fired)
    }
}
76
/// Adaptive exponential integrate-and-fire (AdEx) neuron advanced with classic RK4.
///
/// After each full RK4 step, `v >= v_threshold` counts as a spike: `v` is set to
/// `v_reset` and `b` is added to the adaptation variable `w`.
#[derive(Clone, Debug)]
pub struct AdExRk4 {
    /// Membrane potential.
    pub v: f64,
    /// Adaptation variable (enters `dv` as `-w / c_m`).
    pub w: f64,
    /// Resting / leak reversal potential.
    pub v_rest: f64,
    /// Value `v` is set to after a spike.
    pub v_reset: f64,
    /// Spike-detection level checked after each step.
    pub v_threshold: f64,
    /// Centre of the exponential term (`exp((v - v_rh) / delta_t)`).
    pub v_rh: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Membrane time constant.
    pub tau: f64,
    /// Adaptation time constant.
    pub tau_w: f64,
    /// Subthreshold coupling of `w` to `v - v_rest`.
    pub a: f64,
    /// Spike-triggered increment of `w`.
    pub b: f64,
    /// Capacitance dividing the adaptation and input currents.
    pub c_m: f64,
    /// RK4 step length.
    pub dt: f64,
}

impl AdExRk4 {
    /// Creates a neuron at rest (`v = v_rest = -65`, `w = 0`) with the module defaults and
    /// integration step `dt`.
    pub fn new(dt: f64) -> Self {
        let v_rest = -65.0;
        Self {
            v: v_rest,
            w: 0.0,
            v_rest,
            v_reset: -68.0,
            v_threshold: -50.0,
            v_rh: -55.0,
            delta_t: 2.0,
            tau: 20.0,
            tau_w: 100.0,
            a: 0.5,
            b: 7.0,
            c_m: 200.0,
            dt,
        }
    }

    /// Returns `(dv/dt, dw/dt)` evaluated at `(v, w)` under the constant input `current`.
    fn rhs(&self, v: f64, w: f64, current: f64) -> (f64, f64) {
        let depolarisation = v - self.v_rest;
        // Bounding the exponent keeps exp() finite even when an intermediate RK4 stage
        // overshoots far past threshold.
        let scaled = ((v - self.v_rh) / self.delta_t).clamp(-20.0, 20.0);
        let spike_term = self.delta_t * scaled.exp();
        let membrane = (-depolarisation + spike_term) / self.tau;
        let drive = (-w + current) / self.c_m;
        let adaptation = (self.a * depolarisation - w) / self.tau_w;
        (membrane + drive, adaptation)
    }

    /// Advances the neuron by one RK4 step of length `dt` with input `current`.
    ///
    /// Returns `1` if the post-step `v` reached `v_threshold` (and the neuron was reset),
    /// otherwise `0`.
    pub fn step(&mut self, current: f64) -> i32 {
        let (v0, w0) = (self.v, self.w);
        let half = 0.5 * self.dt;

        let (dv1, dw1) = self.rhs(v0, w0, current);
        let (dv2, dw2) = self.rhs(v0 + half * dv1, w0 + half * dw1, current);
        let (dv3, dw3) = self.rhs(v0 + half * dv2, w0 + half * dw2, current);
        let (dv4, dw4) = self.rhs(v0 + self.dt * dv3, w0 + self.dt * dw3, current);

        let weight = self.dt / 6.0;
        self.v = v0 + weight * (dv1 + 2.0 * dv2 + 2.0 * dv3 + dv4);
        self.w = w0 + weight * (dw1 + 2.0 * dw2 + 2.0 * dw3 + dw4);

        let fired = self.v >= self.v_threshold;
        if fired {
            self.v = self.v_reset;
            self.w += self.b;
        }
        i32::from(fired)
    }
}
147
/// Hodgkin–Huxley neuron (sodium, potassium and leak currents gated by `m`, `h`, `n`)
/// integrated with the classic fourth-order Runge–Kutta method.
///
/// Each call to [`HodgkinHuxleyRk4::step`] runs `max(1, round(1 / dt))` RK4 substeps of
/// length `dt` and reports whether `v` crossed `v_threshold` upward.
#[derive(Clone, Debug)]
pub struct HodgkinHuxleyRk4 {
    /// Membrane potential.
    pub v: f64,
    /// Sodium activation gate (enters as `m³`).
    pub m: f64,
    /// Sodium inactivation gate.
    pub h: f64,
    /// Potassium activation gate (enters as `n⁴`).
    pub n: f64,
    /// Membrane capacitance.
    pub c_m: f64,
    /// Maximal sodium conductance.
    pub g_na: f64,
    /// Maximal potassium conductance.
    pub g_k: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Sodium reversal potential.
    pub e_na: f64,
    /// Potassium reversal potential.
    pub e_k: f64,
    /// Leak reversal potential.
    pub e_l: f64,
    /// RK4 substep length.
    pub dt: f64,
    /// Potential whose upward crossing counts as a spike.
    pub v_threshold: f64,
}

impl HodgkinHuxleyRk4 {
    /// Creates a neuron at `v = -65` with gates `m = 0.05`, `h = 0.6`, `n = 0.32`, the
    /// classic HH conductances/reversal potentials, a spike threshold of `0.0`, and RK4
    /// substep `dt`.
    pub fn new(dt: f64) -> Self {
        Self {
            v: -65.0,
            m: 0.05,
            h: 0.6,
            n: 0.32,
            c_m: 1.0,
            g_na: 120.0,
            g_k: 36.0,
            g_l: 0.3,
            e_na: 50.0,
            e_k: -77.0,
            e_l: -54.4,
            dt,
            v_threshold: 0.0,
        }
    }

    /// Opening rate of `m`. The removable singularity at `v = -40` is replaced by its limit `1.0`.
    fn alpha_m(v: f64) -> f64 {
        let d = v + 40.0;
        if d.abs() < 1e-7 {
            1.0
        } else {
            0.1 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `m`.
    fn beta_m(v: f64) -> f64 {
        4.0 * (-(v + 65.0) / 18.0).exp()
    }

    /// Opening rate of `h`.
    fn alpha_h(v: f64) -> f64 {
        0.07 * (-(v + 65.0) / 20.0).exp()
    }

    /// Closing rate of `h`.
    fn beta_h(v: f64) -> f64 {
        1.0 / (1.0 + (-(v + 35.0) / 10.0).exp())
    }

    /// Opening rate of `n`. The removable singularity at `v = -55` is replaced by its limit `0.1`.
    fn alpha_n(v: f64) -> f64 {
        let d = v + 55.0;
        if d.abs() < 1e-7 {
            0.1
        } else {
            0.01 * d / (1.0 - (-d / 10.0).exp())
        }
    }

    /// Closing rate of `n`.
    fn beta_n(v: f64) -> f64 {
        0.125 * (-(v + 65.0) / 80.0).exp()
    }

    /// Returns `[dv, dm, dh, dn]` for `state = [v, m, h, n]` under constant input `current`.
    ///
    /// Each gate follows `α(1 − x) − βx`, and `dv = (−I_Na − I_K − I_L + current) / c_m`.
    fn rhs(&self, state: [f64; 4], current: f64) -> [f64; 4] {
        let [v, m, h, n] = state;
        let am = Self::alpha_m(v);
        let bm = Self::beta_m(v);
        let ah = Self::alpha_h(v);
        let bh = Self::beta_h(v);
        let an = Self::alpha_n(v);
        let bn = Self::beta_n(v);

        let dm = am * (1.0 - m) - bm * m;
        let dh = ah * (1.0 - h) - bh * h;
        let dn = an * (1.0 - n) - bn * n;
        let i_na = self.g_na * m.powi(3) * h * (v - self.e_na);
        let i_k = self.g_k * n.powi(4) * (v - self.e_k);
        let i_l = self.g_l * (v - self.e_l);
        let dv = (-i_na - i_k - i_l + current) / self.c_m;
        [dv, dm, dh, dn]
    }

    /// Advances the neuron by `max(1, round(1 / dt))` RK4 substeps of length `dt` under a
    /// constant input `current`.
    ///
    /// Returns `1` when `v` was below `v_threshold` before the call and is at or above it
    /// afterwards, otherwise `0`. Only the endpoints of the call are compared, so a crossing
    /// that rises and falls back within a single call is not reported.
    pub fn step(&mut self, current: f64) -> i32 {
        let v_prev = self.v;
        let mut state = [self.v, self.m, self.h, self.n];

        // A call spans roughly one time unit of `dt`-sized substeps. The `max(1)` floor
        // matters for dt > 2: there round(1/dt) is 0, and the call would otherwise silently
        // leave the state untouched forever.
        // NOTE(review): the substep count grows without bound for tiny dt, and large dt can
        // make this explicit RK4 update unstable.
        let substeps = ((1.0 / self.dt).round() as usize).max(1);
        let half_dt = 0.5 * self.dt;
        let sixth_dt = self.dt / 6.0;

        for _ in 0..substeps {
            let k1 = self.rhs(state, current);
            let k2 = self.rhs(add_scaled(state, k1, half_dt), current);
            let k3 = self.rhs(add_scaled(state, k2, half_dt), current);
            let k4 = self.rhs(add_scaled(state, k3, self.dt), current);
            for (idx, value) in state.iter_mut().enumerate() {
                *value += sixth_dt * (k1[idx] + 2.0 * k2[idx] + 2.0 * k3[idx] + k4[idx]);
            }
        }

        self.v = state[0];
        self.m = state[1];
        self.h = state[2];
        self.n = state[3];

        i32::from(self.v >= self.v_threshold && v_prev < self.v_threshold)
    }
}

/// Returns `state + scale * deriv`, component-wise (an RK4 stage's trial state).
fn add_scaled(state: [f64; 4], deriv: [f64; 4], scale: f64) -> [f64; 4] {
    [
        state[0] + scale * deriv[0],
        state[1] + scale * deriv[1],
        state[2] + scale * deriv[2],
        state[3] + scale * deriv[3],
    ]
}
271
272#[pyfunction]
273#[pyo3(signature = (model_name, current_trace, dt=None))]
274pub fn py_rk4_neuron_simulate<'py>(
275    py: Python<'py>,
276    model_name: &str,
277    current_trace: PyReadonlyArray1<'py, f64>,
278    dt: Option<f64>,
279) -> PyResult<Py<PyAny>> {
280    let currents = current_trace.as_slice()?;
281    match normalise_model_name(model_name).as_str() {
282        "izhikevich" | "scizhikevichneuron" | "izhikevichneuron" => {
283            let dt = validate_trace_dt(currents, dt.unwrap_or(1.0))?;
284            simulate_izhikevich(py, currents, dt)
285        }
286        "hodgkinhuxley" | "hodgkinhuxleyneuron" => {
287            let dt = validate_trace_dt(currents, dt.unwrap_or(0.01))?;
288            simulate_hodgkin_huxley(py, currents, dt)
289        }
290        "adex" | "adexneuron" => {
291            let dt = validate_trace_dt(currents, dt.unwrap_or(0.1))?;
292            simulate_adex(py, currents, dt)
293        }
294        _ => Err(PyValueError::new_err(format!(
295            "unsupported RK4 neuron model {model_name:?}"
296        ))),
297    }
298}
299
300fn validate_trace_dt(currents: &[f64], dt: f64) -> PyResult<f64> {
301    if !dt.is_finite() || dt <= 0.0 {
302        return Err(PyValueError::new_err("dt must be a positive finite scalar"));
303    }
304    if currents.is_empty() {
305        return Err(PyValueError::new_err("current_trace must be non-empty"));
306    }
307    if currents.iter().any(|current| !current.is_finite()) {
308        return Err(PyValueError::new_err(
309            "current_trace must contain only finite values",
310        ));
311    }
312    Ok(dt)
313}
314
/// Reduces a model name to a lookup key.
///
/// Every character that is not an ASCII letter or digit is dropped, and the remaining ones
/// are lowercased. For example, `"Hodgkin-HuxleyNeuron"` becomes `"hodgkinhuxleyneuron"`.
fn normalise_model_name(name: &str) -> String {
    let mut key = String::with_capacity(name.len());
    for ch in name.chars() {
        if ch.is_ascii_alphanumeric() {
            key.push(ch.to_ascii_lowercase());
        }
    }
    key
}
321
322fn simulate_izhikevich<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
323    let mut neuron = IzhikevichRk4::new(dt);
324    let mut v = Vec::with_capacity(currents.len());
325    let mut u = Vec::with_capacity(currents.len());
326    let mut spikes = Vec::new();
327    for (idx, &current) in currents.iter().enumerate() {
328        if neuron.step(current) != 0 {
329            spikes.push(idx as u64);
330        }
331        v.push(neuron.v);
332        u.push(neuron.u);
333    }
334    let d = PyDict::new(py);
335    d.set_item("v", v.into_pyarray(py))?;
336    d.set_item("u", u.into_pyarray(py))?;
337    d.set_item("spikes", spikes.into_pyarray(py))?;
338    d.set_item("n_steps", currents.len())?;
339    Ok(d.into_any().unbind())
340}
341
342fn simulate_adex<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
343    let mut neuron = AdExRk4::new(dt);
344    let mut v = Vec::with_capacity(currents.len());
345    let mut w = Vec::with_capacity(currents.len());
346    let mut spikes = Vec::new();
347    for (idx, &current) in currents.iter().enumerate() {
348        if neuron.step(current) != 0 {
349            spikes.push(idx as u64);
350        }
351        v.push(neuron.v);
352        w.push(neuron.w);
353    }
354    let d = PyDict::new(py);
355    d.set_item("v", v.into_pyarray(py))?;
356    d.set_item("w", w.into_pyarray(py))?;
357    d.set_item("spikes", spikes.into_pyarray(py))?;
358    d.set_item("n_steps", currents.len())?;
359    Ok(d.into_any().unbind())
360}
361
362fn simulate_hodgkin_huxley<'py>(py: Python<'py>, currents: &[f64], dt: f64) -> PyResult<Py<PyAny>> {
363    let mut neuron = HodgkinHuxleyRk4::new(dt);
364    let mut v = Vec::with_capacity(currents.len());
365    let mut m = Vec::with_capacity(currents.len());
366    let mut h = Vec::with_capacity(currents.len());
367    let mut n = Vec::with_capacity(currents.len());
368    let mut spikes = Vec::new();
369    for (idx, &current) in currents.iter().enumerate() {
370        if neuron.step(current) != 0 {
371            spikes.push(idx as u64);
372        }
373        v.push(neuron.v);
374        m.push(neuron.m);
375        h.push(neuron.h);
376        n.push(neuron.n);
377    }
378    let d = PyDict::new(py);
379    d.set_item("v", v.into_pyarray(py))?;
380    d.set_item("m", m.into_pyarray(py))?;
381    d.set_item("h", h.into_pyarray(py))?;
382    d.set_item("n", n.into_pyarray(py))?;
383    d.set_item("spikes", spikes.into_pyarray(py))?;
384    d.set_item("n_steps", currents.len())?;
385    Ok(d.into_any().unbind())
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn izhikevich_rk4_is_deterministic_and_spikes() {
        let mut first = IzhikevichRk4::new(1.0);
        let mut second = IzhikevichRk4::new(1.0);
        // Step both copies in lockstep; only the first one's spikes are counted.
        let fired: i32 = (0..100)
            .map(|_| {
                second.step(10.0);
                first.step(10.0)
            })
            .sum();
        assert!(fired > 0);
        assert_eq!((first.v, first.u), (second.v, second.u));
    }

    #[test]
    fn adex_rk4_remains_finite_under_sustained_current() {
        let mut neuron = AdExRk4::new(0.1);
        let fired: i32 = (0..3000).map(|_| neuron.step(500.0)).sum();
        assert!(fired > 0);
        assert!(neuron.v.is_finite() && neuron.w.is_finite());
    }

    #[test]
    fn hodgkin_huxley_rk4_keeps_gates_bounded() {
        let mut neuron = HodgkinHuxleyRk4::new(0.01);
        let fired: i32 = (0..1000).map(|_| neuron.step(10.0)).sum();
        assert!(fired > 0);
        assert!(neuron.v.is_finite());
        for gate in [neuron.m, neuron.h, neuron.n] {
            assert!((0.0..=1.0).contains(&gate));
        }
    }

    #[test]
    fn model_name_normalisation_accepts_common_aliases() {
        let cases = [
            ("Hodgkin-HuxleyNeuron", "hodgkinhuxleyneuron"),
            ("AdEx", "adex"),
        ];
        for (raw, expected) in cases {
            assert_eq!(normalise_model_name(raw), expected);
        }
    }
}
436            "hodgkinhuxleyneuron"
437        );
438        assert_eq!(normalise_model_name("AdEx"), "adex");
439    }
440}