Skip to main content

sc_neurocore_engine/neurons/
multi_compartment.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Commercial license available
3// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
4// © Code 2020–2026 Miroslav Šotek. All rights reserved.
5// ORCID: 0009-0009-3560-0851
6// Contact: www.anulum.li | protoscience@anulum.li
7// SC-NeuroCore — Multi-compartment neuron models
8
9//! Multi-compartment neuron models.
10
11#[allow(unused_imports)]
12use super::biophysical::safe_rate;
13
14/// Pinsky-Rinzel 1994 — 2-compartment pyramidal cell.
15#[derive(Clone, Debug)]
16pub struct PinskyRinzelNeuron {
17    pub v_s: f64,
18    pub v_d: f64,
19    pub h: f64,
20    pub n: f64,
21    pub s: f64,
22    pub c: f64,
23    pub q: f64,
24    pub gc: f64,
25    pub p: f64,
26    pub g_na: f64,
27    pub g_kdr: f64,
28    pub g_ca: f64,
29    pub g_kahp: f64,
30    pub g_l: f64,
31    pub e_na: f64,
32    pub e_k: f64,
33    pub e_ca: f64,
34    pub e_l: f64,
35    pub dt: f64,
36    pub v_threshold: f64,
37}
38
39impl PinskyRinzelNeuron {
40    pub fn new() -> Self {
41        Self {
42            v_s: -60.0,
43            v_d: -60.0,
44            h: 0.9,
45            n: 0.1,
46            s: 0.0,
47            c: 0.0,
48            q: 0.0,
49            gc: 2.1,
50            p: 0.5,
51            g_na: 30.0,
52            g_kdr: 15.0,
53            g_ca: 10.0,
54            g_kahp: 0.8,
55            g_l: 0.1,
56            e_na: 60.0,
57            e_k: -75.0,
58            e_ca: 80.0,
59            e_l: -60.0,
60            dt: 0.02,
61            v_threshold: -20.0,
62        }
63    }
64    pub fn step(&mut self, current_soma: f64, current_dend: f64) -> i32 {
65        let v_prev = self.v_s;
66        let am = safe_rate(0.32, 54.0, self.v_s, 4.0, 8.0);
67        let bm = safe_rate(-0.28, 27.0, self.v_s, -5.0, 5.6);
68        let m_inf = am / (am + bm);
69        let ah = 0.128 * (-(self.v_s + 50.0) / 18.0).exp();
70        let bh = 4.0 / (1.0 + (-(self.v_s + 27.0) / 5.0).exp());
71        let an = safe_rate(0.032, 52.0, self.v_s, 5.0, 0.32);
72        let bn = 0.5 * (-(self.v_s + 57.0) / 40.0).exp();
73        let s_inf = 1.0 / (1.0 + (-(self.v_d + 20.0) / 9.0).exp());
74        let i_na = self.g_na * m_inf.powi(2) * self.h * (self.v_s - self.e_na);
75        let i_kdr = self.g_kdr * self.n.powi(2) * (self.v_s - self.e_k);
76        let i_ls = self.g_l * (self.v_s - self.e_l);
77        let i_ds = (self.gc / self.p) * (self.v_s - self.v_d);
78        let i_ca = self.g_ca * self.s.powi(2) * (self.v_d - self.e_ca);
79        let i_kahp = self.g_kahp * self.q * (self.v_d - self.e_k);
80        let i_ld = self.g_l * (self.v_d - self.e_l);
81        let i_sd = (self.gc / (1.0 - self.p)) * (self.v_d - self.v_s);
82        self.v_s += (-i_na - i_kdr - i_ls - i_ds + current_soma / self.p) * self.dt;
83        self.v_d += (-i_ca - i_kahp - i_ld - i_sd + current_dend / (1.0 - self.p)) * self.dt;
84        self.h += (ah * (1.0 - self.h) - bh * self.h) * self.dt;
85        self.n += (an * (1.0 - self.n) - bn * self.n) * self.dt;
86        self.s += ((s_inf - self.s) / 5.0) * self.dt;
87        self.c = (self.c + (-0.13 * i_ca - 0.075 * self.c) * self.dt).max(0.0);
88        let q_inf = (self.c / (self.c + 2.0)).min(1.0);
89        self.q += ((q_inf - self.q) / 100.0) * self.dt;
90        if self.v_s >= self.v_threshold && v_prev < self.v_threshold {
91            1
92        } else {
93            0
94        }
95    }
96    pub fn reset(&mut self) {
97        self.v_s = -60.0;
98        self.v_d = -60.0;
99        self.h = 0.9;
100        self.n = 0.1;
101        self.s = 0.0;
102        self.c = 0.0;
103        self.q = 0.0;
104    }
105}
106impl Default for PinskyRinzelNeuron {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
/// Hay et al. 2011 — Layer 5 thick-tufted pyramidal (3-compartment reduced).
///
/// Three chained compartments are coupled as soma (`v_s`) ↔ middle (`v_t`) ↔
/// distal (`v_a`):
/// - The soma carries Na, K and leak currents.
/// - The middle compartment carries an inactivating Ca current (`m_ca`, `h_ca`),
///   an `Ih` current (`m_ih`) and leak.
/// - The distal compartment carries a Ca current, a calcium-activated K current
///   driven by `ca_a`, and leak.
///
/// Each [`Self::step`] performs four classical RK4 substeps of length `dt`, so
/// one call advances the model by `4 * dt`.
#[derive(Clone, Debug)]
pub struct HayL5PyramidalNeuron {
    /// Somatic membrane potential.
    pub v_s: f64,
    /// Middle-compartment membrane potential.
    pub v_t: f64,
    /// Distal-compartment membrane potential (receives `current_tuft`).
    pub v_a: f64,
    /// Somatic Na inactivation gate.
    pub h_na: f64,
    /// Somatic K activation gate (the K current uses `n_k⁴`).
    pub n_k: f64,
    /// Middle-compartment Ca activation gate (enters squared).
    pub m_ca: f64,
    /// Middle-compartment Ca inactivation gate.
    pub h_ca: f64,
    /// `Ih` activation gate.
    pub m_ih: f64,
    /// Distal calcium concentration (kept non-negative).
    pub ca_a: f64,
    /// Somatic maximal Na conductance.
    pub g_na: f64,
    /// Somatic maximal K conductance.
    pub g_k: f64,
    /// Somatic leak conductance.
    pub g_l_s: f64,
    /// Middle-compartment maximal Ca conductance.
    pub g_ca_t: f64,
    /// Middle-compartment maximal `Ih` conductance.
    pub g_ih: f64,
    /// Middle-compartment leak conductance.
    pub g_l_t: f64,
    /// Distal maximal Ca conductance.
    pub g_ca_a: f64,
    /// Distal maximal calcium-activated K conductance.
    pub g_kca: f64,
    /// Distal leak conductance.
    pub g_l_a: f64,
    /// Soma ↔ middle coupling conductance.
    pub g_st: f64,
    /// Middle ↔ distal coupling conductance.
    pub g_ta: f64,
    /// Relative size of the soma (divides its coupling and injected current).
    pub p_s: f64,
    /// Relative size of the middle compartment (divides its coupling currents).
    pub p_t: f64,
    /// Relative size of the distal compartment (divides its coupling and injected current).
    pub p_a: f64,
    /// Na reversal potential.
    pub e_na: f64,
    /// K reversal potential (K and calcium-activated K currents).
    pub e_k: f64,
    /// Ca reversal potential (both Ca currents).
    pub e_ca: f64,
    /// `Ih` reversal potential.
    pub e_ih: f64,
    /// Leak reversal potential (all compartments).
    pub e_l: f64,
    /// Decay time constant of `ca_a`.
    pub ca_decay: f64,
    /// Conversion factor from the distal Ca current to calcium influx.
    pub f_ca: f64,
    /// Membrane capacitance; divides every voltage derivative.
    pub c_m: f64,
    /// Length of one RK4 substep.
    pub dt: f64,
    /// Somatic spike-detection threshold (upward crossing).
    pub v_threshold: f64,
}

/// Integrated Hay state, packed as `[v_s, h_na, n_k, v_t, m_ca, h_ca, m_ih, v_a, ca_a]`.
type HayState = [f64; 9];

/// Position of the distal calcium concentration inside a `HayState`.
const HAY_CA_SLOT: usize = 8;

impl HayL5PyramidalNeuron {
    /// Resting state in `HayState` order; restored by `reset`.
    const REST_STATE: HayState = [-75.0, 0.9, 0.1, -75.0, 0.0, 1.0, 0.0, -75.0, 0.0001];

    /// Creates a cell at rest (all potentials at -75) with the default parameters.
    pub fn new() -> Self {
        Self {
            // Compartment potentials.
            v_s: -75.0,
            v_t: -75.0,
            v_a: -75.0,
            // Gating variables and distal calcium.
            h_na: 0.9,
            n_k: 0.1,
            m_ca: 0.0,
            h_ca: 1.0,
            m_ih: 0.0,
            ca_a: 0.0001,
            // Somatic conductances.
            g_na: 300.0,
            g_k: 40.0,
            g_l_s: 0.03,
            // Middle-compartment conductances.
            g_ca_t: 2.0,
            g_ih: 0.02,
            g_l_t: 0.03,
            // Distal conductances.
            g_ca_a: 1.5,
            g_kca: 2.5,
            g_l_a: 0.03,
            // Inter-compartment coupling.
            g_st: 1.5,
            g_ta: 0.8,
            // Relative compartment sizes.
            p_s: 0.15,
            p_t: 0.25,
            p_a: 0.60,
            // Reversal potentials.
            e_na: 50.0,
            e_k: -85.0,
            e_ca: 140.0,
            e_ih: -45.0,
            e_l: -75.0,
            // Calcium handling, capacitance, integration and detection.
            ca_decay: 200.0,
            f_ca: 0.0002,
            c_m: 1.0,
            dt: 0.025,
            v_threshold: -30.0,
        }
    }

    /// Reports whether every state value and parameter is usable for integration.
    fn valid(&self) -> bool {
        // Only finiteness is required: state, reversal potentials, threshold.
        let finite = [
            self.v_s,
            self.h_na,
            self.n_k,
            self.v_t,
            self.m_ca,
            self.h_ca,
            self.m_ih,
            self.v_a,
            self.e_na,
            self.e_k,
            self.e_ca,
            self.e_ih,
            self.e_l,
            self.v_threshold,
        ];
        // May be zero but never negative: calcium, conductances, influx factor.
        let non_negative = [
            self.ca_a,
            self.g_na,
            self.g_k,
            self.g_l_s,
            self.g_ca_t,
            self.g_ih,
            self.g_l_t,
            self.g_ca_a,
            self.g_kca,
            self.g_l_a,
            self.g_st,
            self.g_ta,
            self.f_ca,
        ];
        // Strictly positive: compartment sizes, decay constant, capacitance, step.
        let positive = [self.p_s, self.p_t, self.p_a, self.ca_decay, self.c_m, self.dt];

        finite.iter().all(|x| x.is_finite())
            && non_negative.iter().all(|&x| x.is_finite() && x >= 0.0)
            && positive.iter().all(|&x| x.is_finite() && x > 0.0)
    }

    /// Packs the integrated state in `HayState` order.
    fn pack(&self) -> HayState {
        [
            self.v_s, self.h_na, self.n_k, self.v_t, self.m_ca, self.h_ca, self.m_ih, self.v_a,
            self.ca_a,
        ]
    }

    /// Writes a `HayState` back into the corresponding fields.
    fn unpack(&mut self, state: HayState) {
        [
            self.v_s,
            self.h_na,
            self.n_k,
            self.v_t,
            self.m_ca,
            self.h_ca,
            self.m_ih,
            self.v_a,
            self.ca_a,
        ] = state;
    }

    /// Somatic right-hand side: `[dv_s, dh_na, dn_k]`.
    fn soma_rates(&self, v_s: f64, h_na: f64, n_k: f64, v_t: f64, current_soma: f64) -> [f64; 3] {
        let m_inf = 1.0 / (1.0 + (-(v_s + 38.0) / 7.0).exp());
        let h_inf = 1.0 / (1.0 + ((v_s + 65.0) / 6.0).exp());
        let n_inf = 1.0 / (1.0 + (-(v_s + 25.0) / 12.0).exp());
        let tau_h = 0.5 + 14.0 / (1.0 + ((v_s + 35.0) / 10.0).exp());
        let tau_n = 1.0 + 5.0 / (1.0 + ((v_s + 30.0) / 10.0).exp());
        let na = self.g_na * m_inf * m_inf * m_inf * h_na * (v_s - self.e_na);
        let k = self.g_k * n_k * n_k * n_k * n_k * (v_s - self.e_k);
        let leak = self.g_l_s * (v_s - self.e_l);
        let to_middle = self.g_st * (v_s - v_t) / self.p_s;
        [
            (-na - k - leak - to_middle + current_soma / self.p_s) / self.c_m,
            (h_inf - h_na) / tau_h,
            (n_inf - n_k) / tau_n,
        ]
    }

    /// Middle-compartment right-hand side: `[dv_t, dm_ca, dh_ca, dm_ih]`.
    fn middle_rates(
        &self,
        v_s: f64,
        v_t: f64,
        m_ca: f64,
        h_ca: f64,
        m_ih: f64,
        v_a: f64,
    ) -> [f64; 4] {
        let m_ca_inf = 1.0 / (1.0 + (-(v_t + 27.0) / 7.0).exp());
        let h_ca_inf = 1.0 / (1.0 + ((v_t + 52.0) / 5.0).exp());
        let m_ih_inf = 1.0 / (1.0 + ((v_t + 75.0) / 5.5).exp());
        let ca = self.g_ca_t * m_ca * m_ca * h_ca * (v_t - self.e_ca);
        let ih = self.g_ih * m_ih * (v_t - self.e_ih);
        let leak = self.g_l_t * (v_t - self.e_l);
        let to_soma = self.g_st * (v_t - v_s) / self.p_t;
        let to_distal = self.g_ta * (v_t - v_a) / self.p_t;
        [
            (-ca - ih - leak - to_soma - to_distal) / self.c_m,
            m_ca_inf - m_ca,
            (h_ca_inf - h_ca) / 20.0,
            (m_ih_inf - m_ih) / 50.0,
        ]
    }

    /// Distal right-hand side: `[dv_a, dca_a]`; `ca_a` must already be clipped at 0.
    fn distal_rates(&self, v_t: f64, v_a: f64, ca_a: f64, current_tuft: f64) -> [f64; 2] {
        let m_inf = 1.0 / (1.0 + (-(v_a + 30.0) / 5.0).exp());
        let kca_open = ca_a / (ca_a + 0.001);
        let ca = self.g_ca_a * m_inf * m_inf * (v_a - self.e_ca);
        let kca = self.g_kca * kca_open * (v_a - self.e_k);
        let leak = self.g_l_a * (v_a - self.e_l);
        let to_middle = self.g_ta * (v_a - v_t) / self.p_a;
        [
            (-ca - kca - leak - to_middle + current_tuft / self.p_a) / self.c_m,
            -self.f_ca * ca - ca_a / self.ca_decay,
        ]
    }

    /// Full right-hand side of the ODE system in `HayState` order.
    fn derivatives(&self, s: HayState, current_soma: f64, current_tuft: f64) -> HayState {
        let [v_s, h_na, n_k, v_t, m_ca, h_ca, m_ih, v_a, ca_raw] = s;
        // Negative calcium (possible inside an RK4 stage) is treated as zero.
        let ca_a = ca_raw.max(0.0);
        let [dv_s, dh_na, dn_k] = self.soma_rates(v_s, h_na, n_k, v_t, current_soma);
        let [dv_t, dm_ca, dh_ca, dm_ih] = self.middle_rates(v_s, v_t, m_ca, h_ca, m_ih, v_a);
        let [dv_a, dca_a] = self.distal_rates(v_t, v_a, ca_a, current_tuft);
        [dv_s, dh_na, dn_k, dv_t, dm_ca, dh_ca, dm_ih, dv_a, dca_a]
    }

    /// One classical fourth-order Runge–Kutta step of length `dt`.
    fn rk4_substep(&self, s: HayState, current_soma: f64, current_tuft: f64) -> HayState {
        let dt = self.dt;
        let slope = |x: HayState| self.derivatives(x, current_soma, current_tuft);
        // Trial point `s + scale * k`, element-wise, for the next RK4 stage.
        let trial = |k: &HayState, scale: f64| -> HayState {
            std::array::from_fn(|i| s[i] + scale * k[i])
        };
        let k1 = slope(s);
        let k2 = slope(trial(&k1, 0.5 * dt));
        let k3 = slope(trial(&k2, 0.5 * dt));
        let k4 = slope(trial(&k3, dt));
        let mut next: HayState = std::array::from_fn(|i| {
            s[i] + dt * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]) / 6.0
        });
        // Calcium concentration may not become negative.
        next[HAY_CA_SLOT] = next[HAY_CA_SLOT].max(0.0);
        next
    }

    /// Advances the cell by four RK4 substeps (`4 * dt` in total).
    ///
    /// Returns `1` if `v_s` moved from below `v_threshold` to at or above it over
    /// the whole call, otherwise `0`. Non-finite currents, an invalid
    /// configuration, or a non-finite intermediate state make the call return `0`
    /// without modifying any field.
    pub fn step(&mut self, current_soma: f64, current_tuft: f64) -> i32 {
        let inputs_ok = current_soma.is_finite() && current_tuft.is_finite();
        if !(inputs_ok && self.valid()) {
            return 0;
        }
        let before = self.v_s;
        let mut state = self.pack();
        for _ in 0..4 {
            state = self.rk4_substep(state, current_soma, current_tuft);
            if state.iter().any(|x| !x.is_finite()) {
                return 0;
            }
        }
        self.unpack(state);
        i32::from(self.v_s >= self.v_threshold && before < self.v_threshold)
    }

    /// Restores the integrated state to its resting values; parameters are kept.
    pub fn reset(&mut self) {
        self.unpack(Self::REST_STATE);
    }
}

impl Default for HayL5PyramidalNeuron {
    fn default() -> Self {
        Self::new()
    }
}
366
/// Marder STG — stomatogastric ganglion LP neuron.
///
/// Single membrane potential `v` with seven gated currents (Na, CaT, CaS, A, Kd,
/// KCa, H) plus a leak. An intracellular calcium variable `ca` is fed by the two
/// Ca currents and gates KCa.
#[derive(Clone, Debug)]
pub struct MarderSTGNeuron {
    /// Membrane potential.
    pub v: f64,
    /// Na activation gate (enters cubed).
    pub m_na: f64,
    /// Na inactivation gate.
    pub h_na: f64,
    /// CaT activation gate (enters cubed).
    pub m_cat: f64,
    /// CaT inactivation gate.
    pub h_cat: f64,
    /// CaS activation gate (enters cubed).
    pub m_cas: f64,
    /// A-current activation gate (enters cubed).
    pub m_a: f64,
    /// A-current inactivation gate.
    pub h_a: f64,
    /// Kd activation gate (enters to the fourth power).
    pub m_kd: f64,
    /// H-current activation gate.
    pub m_h: f64,
    /// Intracellular calcium (clamped at 0); gates KCa via `ca / (ca + 3)`.
    pub ca: f64,
    /// Maximal Na conductance.
    pub g_na: f64,
    /// Maximal CaT conductance.
    pub g_cat: f64,
    /// Maximal CaS conductance.
    pub g_cas: f64,
    /// Maximal A-current conductance.
    pub g_a: f64,
    /// Maximal Kd conductance.
    pub g_kd: f64,
    /// Maximal KCa conductance.
    pub g_kca: f64,
    /// Maximal H-current conductance.
    pub g_h: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Na reversal potential.
    pub e_na: f64,
    /// K reversal potential (A, Kd and KCa currents).
    pub e_k: f64,
    /// Ca reversal potential (CaT and CaS currents).
    pub e_ca: f64,
    /// H-current reversal potential.
    pub e_h: f64,
    /// Leak reversal potential.
    pub e_l: f64,
    /// Integration step.
    pub dt: f64,
    /// Spike-detection threshold (upward crossing).
    pub v_threshold: f64,
}

impl MarderSTGNeuron {
    /// Creates a cell at `v = -60` with the default parameter set.
    pub fn new() -> Self {
        Self {
            v: -60.0,
            m_na: 0.0,
            h_na: 0.9,
            m_cat: 0.0,
            h_cat: 0.9,
            m_cas: 0.0,
            m_a: 0.0,
            h_a: 0.9,
            m_kd: 0.0,
            m_h: 0.0,
            ca: 0.05,
            g_na: 400.0,
            g_cat: 2.5,
            g_cas: 6.0,
            g_a: 50.0,
            g_kd: 100.0,
            g_kca: 25.0,
            g_h: 0.02,
            g_l: 0.01,
            e_na: 50.0,
            e_k: -80.0,
            e_ca: 120.0,
            e_h: -20.0,
            e_l: -50.0,
            dt: 0.05,
            v_threshold: -20.0,
        }
    }

    /// Advances the model by one step of length `dt`.
    ///
    /// The step runs in three stages:
    /// 1. Gates relax using the pre-step `v`.
    /// 2. Currents are evaluated with the updated gates and the pre-step `v`.
    /// 3. `v` and `ca` advance.
    ///
    /// Returns `1` on an upward crossing of `v_threshold`, `0` otherwise. A
    /// non-finite `current` is rejected: the call returns `0` and leaves the
    /// state untouched instead of turning every state variable into NaN.
    pub fn step(&mut self, current: f64) -> i32 {
        if !current.is_finite() {
            return 0;
        }
        let v_prev = self.v;
        // Boltzmann steady state; a negative slope `s` gives an inactivation curve.
        let b = |v: f64, vh: f64, s: f64| 1.0 / (1.0 + (-(v - vh) / s).exp());
        self.m_na += (b(self.v, -25.5, 5.29) - self.m_na) / 1.32 * self.dt;
        self.h_na += (b(self.v, -48.9, -5.18) - self.h_na)
            / (0.67 * (1.0 + ((self.v + 62.9) / -10.0).exp()) + 1.5)
            * self.dt;
        self.m_cat += (b(self.v, -27.1, 7.2) - self.m_cat) / 21.7 * self.dt;
        self.h_cat += (b(self.v, -32.1, -5.5) - self.h_cat) / 105.0 * self.dt;
        self.m_cas += (b(self.v, -33.0, 8.1) - self.m_cas) / 14.0 * self.dt;
        self.m_a += (b(self.v, -27.2, 8.7) - self.m_a) / 11.6 * self.dt;
        self.h_a += (b(self.v, -56.9, -4.9) - self.h_a) / 38.6 * self.dt;
        self.m_kd += (b(self.v, -12.3, 11.8) - self.m_kd) / 7.2 * self.dt;
        self.m_h += (b(self.v, -70.0, -6.0) - self.m_h) / 272.0 * self.dt;
        let kca_act = (self.ca / (self.ca + 3.0)).min(1.0);
        let i_na = self.g_na * self.m_na.powi(3) * self.h_na * (self.v - self.e_na);
        let i_cat = self.g_cat * self.m_cat.powi(3) * self.h_cat * (self.v - self.e_ca);
        let i_cas = self.g_cas * self.m_cas.powi(3) * (self.v - self.e_ca);
        let i_a = self.g_a * self.m_a.powi(3) * self.h_a * (self.v - self.e_k);
        let i_kd = self.g_kd * self.m_kd.powi(4) * (self.v - self.e_k);
        let i_kca = self.g_kca * kca_act.powi(4) * (self.v - self.e_k);
        let i_h = self.g_h * self.m_h * (self.v - self.e_h);
        let i_l = self.g_l * (self.v - self.e_l);
        self.v += (-i_na - i_cat - i_cas - i_a - i_kd - i_kca - i_h - i_l + current) * self.dt;
        // Inward Ca current (negative) raises calcium; first-order decay pulls it down.
        let i_ca_total = i_cat + i_cas;
        self.ca = (self.ca + (-0.0001 * i_ca_total - 0.01 * self.ca) * self.dt).max(0.0);
        if self.v >= self.v_threshold && v_prev < self.v_threshold {
            1
        } else {
            0
        }
    }

    /// Restores the state variables to the values set by [`Self::new`];
    /// parameters are left unchanged.
    pub fn reset(&mut self) {
        self.v = -60.0;
        self.m_na = 0.0;
        self.h_na = 0.9;
        self.m_cat = 0.0;
        self.h_cat = 0.9;
        self.m_cas = 0.0;
        self.m_a = 0.0;
        self.h_a = 0.9;
        self.m_kd = 0.0;
        self.m_h = 0.0;
        self.ca = 0.05;
    }
}

impl Default for MarderSTGNeuron {
    fn default() -> Self {
        Self::new()
    }
}
480
/// Rall cable — N-compartment passive dendrite model. Rall 1964.
///
/// `v[0]` is the soma and `v[n_comp - 1]` the distal end, where [`Self::step`]
/// injects its current. Each step is a backward-Euler (implicit) update of the
/// passive cable with sealed ends, solved as a tridiagonal system.
#[derive(Clone, Debug)]
pub struct RallCableNeuron {
    /// Compartment potentials, soma first.
    pub v: Vec<f64>,
    /// Number of compartments; must equal `v.len()` and be at least 1.
    pub n_comp: usize,
    /// Membrane time constant.
    pub tau_m: f64,
    /// Resting potential shared by all compartments.
    pub v_rest: f64,
    /// Axial coupling between neighbouring compartments, relative to the leak.
    pub g_ratio: f64,
    /// Somatic spike threshold (upward crossing).
    pub v_threshold: f64,
    /// Somatic potential after a spike.
    pub v_reset: f64,
    /// Time step.
    pub dt: f64,
}

impl RallCableNeuron {
    /// Creates a cable of `n_comp` compartments at rest (-65); `0` is promoted to `1`.
    pub fn new(n_comp: usize) -> Self {
        let n_comp = n_comp.max(1);
        Self {
            v: vec![-65.0; n_comp],
            n_comp,
            tau_m: 20.0,
            v_rest: -65.0,
            g_ratio: 0.5,
            v_threshold: -50.0,
            v_reset: -65.0,
            dt: 0.1,
        }
    }

    /// Advances the cable by one implicit step with `current` at the distal end.
    ///
    /// Returns `1` when the soma crosses `v_threshold` upward; the soma is then
    /// set to `v_reset`. Returns `0` otherwise. Returns `-1` without touching the
    /// state when `current` is non-finite, the configuration is invalid, or the
    /// linear solve fails.
    pub fn step(&mut self, current: f64) -> i32 {
        match self.candidate(current) {
            None => -1,
            Some(mut next) => {
                let crossed = next[0] >= self.v_threshold && self.v[0] < self.v_threshold;
                if crossed {
                    next[0] = self.v_reset;
                }
                self.v = next;
                i32::from(crossed)
            }
        }
    }

    /// Sets every compartment to `v_rest`.
    pub fn reset(&mut self) {
        let rest = self.v_rest;
        self.v.iter_mut().for_each(|x| *x = rest);
    }

    /// Reports whether the configuration and current state can be integrated.
    fn valid(&self) -> bool {
        let shape_ok = self.n_comp >= 1 && self.v.len() == self.n_comp;
        let positive = |x: f64| x.is_finite() && x > 0.0;
        let finite_levels = [self.v_rest, self.v_threshold, self.v_reset]
            .iter()
            .all(|x| x.is_finite());
        let coupling_ok = self.g_ratio.is_finite() && self.g_ratio >= 0.0;
        shape_ok
            && positive(self.tau_m)
            && finite_levels
            && coupling_ok
            && positive(self.dt)
            && self.v.iter().all(|x| x.is_finite())
    }

    /// Computes the next potentials without committing them.
    ///
    /// Solves `(I + alpha * (I + g_ratio * L)) (v' - v_rest) = (v - v_rest) + alpha * I_inj`
    /// with `alpha = dt / tau_m`, where `L` is the sealed-end Laplacian and the
    /// injected current enters only the last compartment.
    fn candidate(&self, current: f64) -> Option<Vec<f64>> {
        if !(current.is_finite() && self.valid()) {
            return None;
        }
        let n = self.n_comp;
        let alpha = self.dt / self.tau_m;
        // Lower and upper bands coincide: coupling is symmetric.
        let band = vec![-alpha * self.g_ratio; n.saturating_sub(1)];
        let diagonal: Vec<f64> = (0..n)
            .map(|i| {
                if n == 1 {
                    // Isolated compartment: leak only.
                    1.0 + alpha
                } else if i == 0 || i == n - 1 {
                    // Sealed end: a single neighbour.
                    1.0 + alpha + alpha * self.g_ratio
                } else {
                    1.0 + alpha + 2.0 * alpha * self.g_ratio
                }
            })
            .collect();
        let mut rhs: Vec<f64> = self.v.iter().map(|&v| v - self.v_rest).collect();
        rhs[n - 1] += alpha * current;
        let deviation = solve_rall_tridiagonal(&band, &diagonal, &band, &rhs)?;
        Some(deviation.into_iter().map(|d| d + self.v_rest).collect())
    }
}

/// Solves a tridiagonal linear system with the Thomas algorithm (no pivoting).
///
/// `lower[i]` multiplies `x[i]` in row `i + 1` and `upper[i]` multiplies `x[i + 1]`
/// in row `i`. `diagonal` and `rhs` have one entry per row.
///
/// Returns `None` when any of these hold:
/// - the system is empty, or the band lengths are inconsistent;
/// - a pivot is zero or non-finite;
/// - the solution contains a non-finite value.
fn solve_rall_tridiagonal(
    lower: &[f64],
    diagonal: &[f64],
    upper: &[f64],
    rhs: &[f64],
) -> Option<Vec<f64>> {
    let rows = diagonal.len();
    let bands = rows.saturating_sub(1);
    let shapes_ok =
        rows > 0 && rhs.len() == rows && lower.len() == bands && upper.len() == bands;
    if !shapes_ok {
        return None;
    }
    // Forward sweep: normalised upper band in `c`, transformed right-hand side in `d`.
    let mut c = vec![0.0; bands];
    let mut d = vec![0.0; rows];
    for i in 0..rows {
        let (pivot, numerator) = if i == 0 {
            (diagonal[0], rhs[0])
        } else {
            (
                diagonal[i] - lower[i - 1] * c[i - 1],
                rhs[i] - lower[i - 1] * d[i - 1],
            )
        };
        if !pivot.is_finite() || pivot == 0.0 {
            return None;
        }
        if i < bands {
            c[i] = upper[i] / pivot;
        }
        d[i] = numerator / pivot;
    }
    // Back substitution, in place.
    for i in (0..bands).rev() {
        d[i] -= c[i] * d[i + 1];
    }
    if d.iter().all(|x| x.is_finite()) {
        Some(d)
    } else {
        None
    }
}
610
/// Booth-Rinzel — 2-compartment motoneuron with bistability. Booth et al. 1997.
///
/// Soma (`vs`: Na, K, leak) and dendrite (`vd`: Ca, calcium-activated K, leak)
/// are coupled through `gc`. `p` is the somatic fraction of the cell and `1 - p`
/// the dendritic one. Each [`Self::step`] runs four forward-Euler substeps of `dt`.
#[derive(Clone, Debug)]
pub struct BoothRinzelNeuron {
    /// Somatic membrane potential (clamped to [-200, 100] during integration).
    pub vs: f64,
    /// Dendritic membrane potential (clamped to [-200, 100] during integration).
    pub vd: f64,
    /// Somatic Na inactivation gate, kept in [0, 1].
    pub h: f64,
    /// Somatic K activation gate, kept in [0, 1]; the K current uses `n⁴`.
    pub n: f64,
    /// Slow gate relaxing towards a sigmoid of `vd` (time constant 400), kept in
    /// [0, 1]. NOTE(review): no current reads `q` in this implementation —
    /// confirm whether it was meant to gate a dendritic current.
    pub q: f64,
    /// Dendritic calcium, driven by the Ca current and clamped at 0.
    pub ca: f64,
    /// Somatic area fraction; coupling and injected currents are divided by `p`
    /// and `1 - p`.
    pub p: f64,
    /// Soma–dendrite coupling conductance.
    pub gc: f64,
    /// Maximal Na conductance (soma).
    pub g_na: f64,
    /// Maximal K conductance (soma).
    pub g_k: f64,
    /// Maximal Ca conductance (dendrite).
    pub g_ca: f64,
    /// Maximal calcium-activated K conductance (dendrite).
    pub g_kca: f64,
    /// Leak conductance, shared by both compartments.
    pub g_l: f64,
    /// Na reversal potential.
    pub e_na: f64,
    /// K reversal potential (both K currents).
    pub e_k: f64,
    /// Ca reversal potential.
    pub e_ca: f64,
    /// Leak reversal potential.
    pub e_l: f64,
    /// Length of one Euler substep.
    pub dt: f64,
    /// Somatic spike-detection threshold (upward crossing).
    pub v_threshold: f64,
}

impl BoothRinzelNeuron {
    /// Creates a cell at rest (`vs = vd = -65`) with the default parameter set.
    pub fn new() -> Self {
        Self {
            vs: -65.0,
            vd: -65.0,
            h: 0.9,
            n: 0.0,
            q: 0.0,
            ca: 0.0,
            p: 0.5,
            gc: 0.1,
            g_na: 120.0,
            g_k: 20.0,
            g_ca: 14.0,
            g_kca: 5.0,
            g_l: 0.51,
            e_na: 55.0,
            e_k: -80.0,
            e_ca: 80.0,
            e_l: -60.0,
            dt: 0.025,
            v_threshold: -20.0,
        }
    }

    /// Advances the model by four forward-Euler substeps of `dt`.
    ///
    /// `current` is injected into the soma only, scaled by `1 / p`. Returns `1`
    /// if `vs` starts the call below `v_threshold` and ends at or above it,
    /// otherwise `0`. A non-finite `current` is rejected: the call returns `0` and
    /// leaves the state untouched. Without the guard, NaN would persist, because
    /// `clamp` passes NaN through.
    pub fn step(&mut self, current: f64) -> i32 {
        if !current.is_finite() {
            return 0;
        }
        let v_prev = self.vs;
        for _ in 0..4 {
            // Steady states and time constants from the pre-substep potentials.
            let m_inf = 1.0 / (1.0 + (-(self.vs + 35.0) / 7.8).exp());
            let h_inf = 1.0 / (1.0 + ((self.vs + 55.0) / 7.0).exp());
            let n_inf = 1.0 / (1.0 + (-(self.vs + 28.0) / 15.0).exp());
            let s_inf = 1.0 / (1.0 + (-(self.vd + 22.0) / 5.0).exp());
            let q_inf = 1.0 / (1.0 + (-(self.vd + 35.0) / 2.0).exp());
            // The 1e-12 guards the division; the 0.01 floor bounds dt / tau.
            let tau_h = (30.0
                / (((self.vs + 50.0) / 15.0).exp() + ((-(self.vs + 50.0)) / 16.0).exp() + 1e-12))
                .max(0.01);
            let tau_n = (7.0
                / (((self.vs + 40.0) / 40.0).exp() + ((-(self.vs + 40.0)) / 50.0).exp() + 1e-12))
                .max(0.01);
            self.h = (self.h + (h_inf - self.h) / tau_h * self.dt).clamp(0.0, 1.0);
            self.n = (self.n + (n_inf - self.n) / tau_n * self.dt).clamp(0.0, 1.0);
            self.q = (self.q + (q_inf - self.q) / 400.0 * self.dt).clamp(0.0, 1.0);
            let chi = (self.ca / 250.0).min(1.0);
            let i_na = self.g_na * m_inf.powi(3) * self.h * (self.vs - self.e_na);
            let i_k = self.g_k * self.n.powi(4) * (self.vs - self.e_k);
            let i_ls = self.g_l * (self.vs - self.e_l);
            let i_sd = (self.gc / self.p) * (self.vs - self.vd);
            let i_ca = self.g_ca * s_inf.powi(2) * (self.vd - self.e_ca);
            let i_kca = self.g_kca * chi * (self.vd - self.e_k);
            let i_ld = self.g_l * (self.vd - self.e_l);
            let i_ds = (self.gc / (1.0 - self.p)) * (self.vd - self.vs);
            self.vs = (self.vs + (-i_na - i_k - i_ls - i_sd + current / self.p) * self.dt)
                .clamp(-200.0, 100.0);
            self.vd = (self.vd + (-i_ca - i_kca - i_ld - i_ds) * self.dt).clamp(-200.0, 100.0);
            self.ca = (self.ca + (0.0025 * (-0.009 * i_ca) - 0.18 * self.ca) * self.dt).max(0.0);
        }
        if self.vs >= self.v_threshold && v_prev < self.v_threshold {
            1
        } else {
            0
        }
    }

    /// Restores the state variables to the values set by [`Self::new`];
    /// parameters are left unchanged.
    pub fn reset(&mut self) {
        self.vs = -65.0;
        self.vd = -65.0;
        self.h = 0.9;
        self.n = 0.0;
        self.q = 0.0;
        self.ca = 0.0;
    }
}

impl Default for BoothRinzelNeuron {
    fn default() -> Self {
        Self::new()
    }
}
710
/// Dendrify — two-compartment with active dendritic spike (NMDA-like plateau).
///
/// A leaky soma (`v_s`) and a leaky dendrite (`v_d`) are coupled through `g_c`.
/// The input current drives the dendrite only. When the dendrite reaches
/// `d_threshold`, a plateau starts that adds `d_amplitude` to the dendritic drive
/// for `d_duration` time units. The soma fires on reaching `v_threshold` and is
/// reset to `v_reset`.
#[derive(Clone, Debug)]
pub struct DendrifyNeuron {
    /// Somatic potential.
    pub v_s: f64,
    /// Dendritic potential.
    pub v_d: f64,
    /// Whether a dendritic plateau is currently active.
    pub d_active: bool,
    /// Remaining plateau time while `d_active` is set.
    pub d_timer: f64,
    /// Somatic time constant.
    pub tau_s: f64,
    /// Dendritic time constant.
    pub tau_d: f64,
    /// Soma–dendrite coupling strength.
    pub g_c: f64,
    /// Dendritic potential that triggers a plateau.
    pub d_threshold: f64,
    /// Extra dendritic drive during a plateau.
    pub d_amplitude: f64,
    /// Plateau duration.
    pub d_duration: f64,
    /// Resting potential of both compartments; also the target of [`Self::reset`].
    pub v_rest: f64,
    /// Somatic spike threshold.
    pub v_threshold: f64,
    /// Somatic potential after a spike.
    pub v_reset: f64,
    /// Euler time step.
    pub dt: f64,
}

impl DendrifyNeuron {
    /// Creates a cell at rest (-65) with no active plateau and default parameters.
    pub fn new() -> Self {
        Self {
            v_s: -65.0,
            v_d: -65.0,
            d_active: false,
            d_timer: 0.0,
            tau_s: 10.0,
            tau_d: 20.0,
            g_c: 0.8,
            d_threshold: -35.0,
            d_amplitude: 30.0,
            d_duration: 10.0,
            v_rest: -65.0,
            v_threshold: -50.0,
            v_reset: -65.0,
            dt: 0.1,
        }
    }

    /// Advances the cell by one forward-Euler step of length `dt`.
    ///
    /// The dendrite is updated first; the soma then uses the new dendritic
    /// potential. A plateau triggered in this step contributes drive from the
    /// next step on. Returns `1` (and resets `v_s`) when `v_s >= v_threshold`,
    /// otherwise `0`. A non-finite `current` is rejected: the call returns `0`
    /// and leaves the state untouched.
    pub fn step(&mut self, current: f64) -> i32 {
        if !current.is_finite() {
            return 0;
        }
        let d_input = if self.d_active { self.d_amplitude } else { 0.0 };
        self.v_d += (-(self.v_d - self.v_rest) + current + d_input
            - self.g_c * (self.v_d - self.v_s))
            / self.tau_d
            * self.dt;
        self.v_s +=
            (-(self.v_s - self.v_rest) + self.g_c * (self.v_d - self.v_s)) / self.tau_s * self.dt;
        if self.d_active {
            self.d_timer -= self.dt;
            if self.d_timer <= 0.0 {
                self.d_active = false;
            }
        } else if self.v_d >= self.d_threshold {
            self.d_active = true;
            self.d_timer = self.d_duration;
        }
        if self.v_s >= self.v_threshold {
            self.v_s = self.v_reset;
            1
        } else {
            0
        }
    }

    /// Returns both compartments to `v_rest` and cancels any plateau.
    ///
    /// Uses the configured `v_rest` rather than a hard-coded value, which matches
    /// how the other resettable models in this module treat `v_rest`. With the
    /// default parameters this is -65, as before.
    pub fn reset(&mut self) {
        self.v_s = self.v_rest;
        self.v_d = self.v_rest;
        self.d_active = false;
        self.d_timer = 0.0;
    }
}

impl Default for DendrifyNeuron {
    fn default() -> Self {
        Self::new()
    }
}
785
/// Two-compartment LIF — soma + dendrite with history-dependent coupling.
///
/// The model updates in discrete time:
/// - The dendrite decays by `exp(-dt / tau_d)` and adds `i_dend`.
/// - The soma decays by `exp(-dt / tau_s)` and adds `i_soma` plus `kappa` times
///   the freshly updated dendritic potential.
///
/// Both potentials decay towards 0; `v_rest` is only used by [`Self::reset`].
#[derive(Clone, Debug)]
pub struct TwoCompartmentLIFNeuron {
    /// Somatic potential.
    pub v_s: f64,
    /// Dendritic potential.
    pub v_d: f64,
    /// Value both compartments return to on [`Self::reset`].
    pub v_rest: f64,
    /// Somatic potential after a spike.
    pub v_reset: f64,
    /// Somatic firing threshold.
    pub theta: f64,
    /// Somatic decay time constant.
    pub tau_s: f64,
    /// Dendritic decay time constant.
    pub tau_d: f64,
    /// Dendrite-to-soma coupling gain.
    pub kappa: f64,
    /// Time step.
    pub dt: f64,
}

impl TwoCompartmentLIFNeuron {
    /// Creates a cell at 0 with threshold 1 and the default time constants.
    pub fn new() -> Self {
        Self {
            v_s: 0.0,
            v_d: 0.0,
            v_rest: 0.0,
            v_reset: 0.0,
            theta: 1.0,
            tau_s: 2.0,
            tau_d: 20.0,
            kappa: 0.5,
            dt: 1.0,
        }
    }

    /// Advances the cell by one step of length `dt`.
    ///
    /// Returns `1` (and sets `v_s` to `v_reset`) when `v_s >= theta`, otherwise
    /// `0`. Non-finite inputs are rejected: the call returns `0` and leaves both
    /// potentials untouched instead of turning them permanently into NaN.
    pub fn step(&mut self, i_soma: f64, i_dend: f64) -> i32 {
        if !i_soma.is_finite() || !i_dend.is_finite() {
            return 0;
        }
        let alpha_s = (-self.dt / self.tau_s).exp();
        let alpha_d = (-self.dt / self.tau_d).exp();
        self.v_d = alpha_d * self.v_d + i_dend;
        // The soma sees the dendritic potential after this step's update.
        self.v_s = alpha_s * self.v_s + i_soma + self.kappa * self.v_d;
        if self.v_s >= self.theta {
            self.v_s = self.v_reset;
            1
        } else {
            0
        }
    }

    /// Sets both compartments to `v_rest`.
    pub fn reset(&mut self) {
        self.v_s = self.v_rest;
        self.v_d = self.v_rest;
    }
}

impl Default for TwoCompartmentLIFNeuron {
    fn default() -> Self {
        Self::new()
    }
}
836
837#[cfg(test)]
838mod tests {
839    use super::*;
840
841    #[test]
842    fn pr_fires() {
843        let mut n = PinskyRinzelNeuron::new();
844        let t: i32 = (0..5000).map(|_| n.step(5.0, 0.0)).sum();
845        assert!(t > 0);
846    }
847    #[test]
848    fn hay_fires() {
849        let mut n = HayL5PyramidalNeuron::new();
850        let t: i32 = (0..500).map(|_| n.step(20.0, 0.0)).sum();
851        assert!(t > 0);
852    }
853    #[test]
854    fn marder_fires() {
855        let mut n = MarderSTGNeuron::new();
856        let t: i32 = (0..2000).map(|_| n.step(5.0)).sum();
857        assert!(t > 0);
858    }
859    #[test]
860    fn rall_fires() {
861        let mut n = RallCableNeuron::new(2);
862        n.g_ratio = 5.0;
863        let t: i32 = (0..5000).map(|_| n.step(500.0)).sum();
864        assert!(t > 0);
865    }
866    #[test]
867    fn booth_fires() {
868        let mut n = BoothRinzelNeuron::new();
869        let t: i32 = (0..2000).map(|_| n.step(5.0)).sum();
870        assert!(t > 0);
871    }
872    #[test]
873    fn dendrify_fires() {
874        let mut n = DendrifyNeuron::new();
875        let t: i32 = (0..2000).map(|_| n.step(50.0)).sum();
876        assert!(t > 0);
877    }
878    #[test]
879    fn tc_lif_fires() {
880        let mut n = TwoCompartmentLIFNeuron::new();
881        let t: i32 = (0..100).map(|_| n.step(0.5, 0.3)).sum();
882        assert!(t > 0);
883    }
884
885    // ── Multi-angle tests for multi-compartment models ──
886
887    // -- PinskyRinzel --
888    #[test]
889    fn pr_reset() {
890        let mut n = PinskyRinzelNeuron::new();
891        for _ in 0..100 {
892            n.step(5.0, 0.0);
893        }
894        n.reset();
895        assert!((n.v_s - (-60.0)).abs() < 1e-10);
896        assert!((n.v_d - (-60.0)).abs() < 1e-10);
897    }
898    #[test]
899    fn pr_bounded() {
900        let mut n = PinskyRinzelNeuron::new();
901        for _ in 0..5000 {
902            n.step(50.0, 0.0);
903        }
904        assert!(n.v_s.is_finite());
905    }
906    #[test]
907    fn pr_dendritic_input() {
908        let mut n = PinskyRinzelNeuron::new();
909        let _t: i32 = (0..5000).map(|_| n.step(0.0, 5.0)).sum();
910        // Dendritic input should also be able to drive spiking
911        assert!(n.v_d.is_finite());
912    }
913    #[test]
914    fn pr_nan_no_panic() {
915        PinskyRinzelNeuron::new().step(f64::NAN, 0.0);
916    }
917
918    // -- HayL5 --
919    #[test]
920    fn hay_reset() {
921        let mut n = HayL5PyramidalNeuron::new();
922        for _ in 0..100 {
923            n.step(20.0, 0.0);
924        }
925        n.reset();
926        assert!((n.v_s - (-75.0)).abs() < 1e-10);
927    }
928    #[test]
929    fn hay_bounded() {
930        let mut n = HayL5PyramidalNeuron::new();
931        for _ in 0..500 {
932            n.step(100.0, 0.0);
933        }
934        assert!(n.v_s.is_finite());
935    }
936    #[test]
937    fn hay_nan_no_panic() {
938        HayL5PyramidalNeuron::new().step(f64::NAN, 0.0);
939    }
940    #[test]
941    fn hay_rk4_somatic_anchor() {
942        let mut n = HayL5PyramidalNeuron::new();
943        let spikes: i32 = (0..20_000).map(|_| n.step(10.0, 0.0)).sum();
944        assert_eq!(spikes, 1);
945        assert!(n.ca_a >= 0.0);
946    }
947    #[test]
948    fn hay_rk4_dual_input_anchor() {
949        let mut n = HayL5PyramidalNeuron::new();
950        let spikes: i32 = (0..20_000).map(|_| n.step(5.0, 5.0)).sum();
951        assert_eq!(spikes, 4);
952    }
953    #[test]
954    fn hay_invalid_input_preserves_state() {
955        let mut n = HayL5PyramidalNeuron::new();
956        for _ in 0..10 {
957            n.step(10.0, 0.0);
958        }
959        let old = [
960            n.v_s, n.h_na, n.n_k, n.v_t, n.m_ca, n.h_ca, n.m_ih, n.v_a, n.ca_a,
961        ];
962        assert_eq!(n.step(f64::INFINITY, 0.0), 0);
963        assert_eq!(
964            [n.v_s, n.h_na, n.n_k, n.v_t, n.m_ca, n.h_ca, n.m_ih, n.v_a, n.ca_a],
965            old
966        );
967    }
968
969    // -- MarderSTG --
970    #[test]
971    fn marder_reset() {
972        let mut n = MarderSTGNeuron::new();
973        for _ in 0..100 {
974            n.step(5.0);
975        }
976        n.reset();
977        assert!((n.v - (-60.0)).abs() < 1e-10);
978    }
979    #[test]
980    fn marder_bounded() {
981        let mut n = MarderSTGNeuron::new();
982        for _ in 0..2000 {
983            n.step(50.0);
984        }
985        assert!(n.v.is_finite());
986    }
987    #[test]
988    fn marder_nan_no_panic() {
989        MarderSTGNeuron::new().step(f64::NAN);
990    }
991
992    // -- RallCable --
993    #[test]
994    fn rall_reset() {
995        let mut n = RallCableNeuron::new(5);
996        for _ in 0..100 {
997            n.step(50.0);
998        }
999        n.reset();
1000        assert!(n.v.iter().all(|&x| (x - n.v_rest).abs() < 1e-10));
1001    }
1002    #[test]
1003    fn rall_bounded() {
1004        let mut n = RallCableNeuron::new(5);
1005        for _ in 0..1000 {
1006            n.step(500.0);
1007        }
1008        assert!(n.v.iter().all(|x| x.is_finite()));
1009    }
1010    #[test]
1011    fn rall_implicit_step_reference() {
1012        let mut n = RallCableNeuron::new(3);
1013        assert_eq!(n.step(100.0), 0);
1014        assert!((n.v[0] - -64.99999695179709).abs() < 1e-12);
1015        assert!((n.v[1] - -64.99877157422763).abs() < 1e-12);
1016        assert!((n.v[2] - -64.50371903616434).abs() < 1e-12);
1017    }
1018    #[test]
1019    fn rall_nan_no_panic() {
1020        let mut n = RallCableNeuron::new(5);
1021        let before = n.v.clone();
1022        assert_eq!(n.step(f64::NAN), -1);
1023        assert_eq!(n.v, before);
1024    }
1025
1026    // -- BoothRinzel --
1027    #[test]
1028    fn booth_reset() {
1029        let mut n = BoothRinzelNeuron::new();
1030        for _ in 0..100 {
1031            n.step(5.0);
1032        }
1033        n.reset();
1034        assert!((n.vs - (-65.0)).abs() < 1e-10);
1035    }
1036    #[test]
1037    fn booth_bounded() {
1038        let mut n = BoothRinzelNeuron::new();
1039        for _ in 0..2000 {
1040            n.step(50.0);
1041        }
1042        assert!(n.vs.is_finite());
1043    }
1044    #[test]
1045    fn booth_nan_no_panic() {
1046        BoothRinzelNeuron::new().step(f64::NAN);
1047    }
1048
1049    // -- Dendrify --
1050    #[test]
1051    fn dendrify_reset() {
1052        let mut n = DendrifyNeuron::new();
1053        for _ in 0..100 {
1054            n.step(50.0);
1055        }
1056        n.reset();
1057        assert!((n.v_s - (-65.0)).abs() < 1e-10);
1058    }
1059    #[test]
1060    fn dendrify_bounded() {
1061        let mut n = DendrifyNeuron::new();
1062        for _ in 0..2000 {
1063            n.step(200.0);
1064        }
1065        assert!(n.v_s.is_finite());
1066    }
1067    #[test]
1068    fn dendrify_nan_no_panic() {
1069        DendrifyNeuron::new().step(f64::NAN);
1070    }
1071
1072    // -- TwoCompartmentLIF --
1073    #[test]
1074    fn tc_lif_reset() {
1075        let mut n = TwoCompartmentLIFNeuron::new();
1076        for _ in 0..50 {
1077            n.step(0.5, 0.3);
1078        }
1079        n.reset();
1080    }
1081    #[test]
1082    fn tc_lif_bounded() {
1083        let mut n = TwoCompartmentLIFNeuron::new();
1084        for _ in 0..1000 {
1085            n.step(100.0, 100.0);
1086        }
1087        assert!(n.v_s.is_finite());
1088    }
1089    #[test]
1090    fn tc_lif_nan_no_panic() {
1091        TwoCompartmentLIFNeuron::new().step(f64::NAN, 0.0);
1092    }
1093}
1094
/// Dendritic NMDA spike model.
///
/// A two-compartment (soma + dendrite) leaky unit. The dendrite carries an
/// NMDA-receptor current gated by the non-linear, voltage-dependent Mg²⁺
/// block. As computed in `derivatives`, the current is scaled by the
/// `glutamate` input:
///
///   I_NMDA = g_NMDA · glutamate · B(V_d) · (V_d - E_NMDA)
///   B(V)   = 1 / (1 + [Mg²⁺]/3.57 · exp(-0.062 · V))
///
/// Both compartments leak toward a hard-coded -65.0 and are coupled through
/// `g_coupling`:
///
///   τ_s dV_s/dt = -(V_s + 65) + I_soma + g_c · (V_d - V_s)
///   τ_d dV_d/dt = -(V_d + 65) + I_NMDA + g_c · (V_s - V_d)
///
/// The intent is coincidence detection: the dendrite should only pass current
/// when both presynaptic glutamate AND postsynaptic depolarisation are present.
///
/// NOTE(review): the I_NMDA formula uses the outward-positive sign convention,
/// but `derivatives` *adds* it to dV_d/dt. With the default `e_nmda = 0` and a
/// dendrite below 0, that term is negative. Glutamate therefore currently pulls
/// the dendrite down instead of depolarising it. Flipping the sign would change
/// the cross-backend spike-count anchor, so confirm the intended convention
/// against the reference backend before changing it.
///
/// Reference: Jahr & Stevens (1990), Schiller et al. (2000).
#[derive(Clone, Debug)]
pub struct DendriticNMDANeuron {
    /// Somatic potential; starts at, and resets to, -65.0.
    pub v_soma: f64,
    /// Dendritic potential; starts at -65.0 and is not reset on a spike.
    pub v_dend: f64,
    /// NMDA conductance scale (must be finite and >= 0).
    pub g_nmda: f64,
    /// NMDA reversal potential.
    pub e_nmda: f64,
    /// Extracellular Mg²⁺ concentration used by the block factor (>= 0).
    pub mg_conc: f64,
    /// Soma–dendrite coupling strength (>= 0).
    pub g_coupling: f64,
    /// Somatic time constant (> 0).
    pub tau_soma: f64,
    /// Dendritic time constant (> 0).
    pub tau_dend: f64,
    /// Somatic spike threshold.
    pub theta: f64,
    /// Integration step (> 0).
    pub dt: f64,
}

impl DendriticNMDANeuron {
    /// Creates a neuron with both compartments at -65.0 and default parameters.
    pub fn new() -> Self {
        Self {
            v_soma: -65.0,
            v_dend: -65.0,
            g_nmda: 1.5,
            e_nmda: 0.0,
            mg_conc: 1.0,
            g_coupling: 0.5,
            tau_soma: 20.0,
            tau_dend: 50.0,
            theta: -50.0,
            dt: 0.1,
        }
    }

    /// Mg²⁺ block factor (Jahr & Stevens 1990).
    ///
    /// For `mg_conc >= 0` the result lies in (0, 1]. It increases as `v`
    /// depolarises, because the exponential term shrinks.
    fn mg_block(&self, v: f64) -> f64 {
        1.0 / (1.0 + (self.mg_conc / 3.57) * (-0.062 * v).exp())
    }

    /// Returns `true` when the state is finite and every parameter is finite
    /// and in range: non-negative conductances and concentration, strictly
    /// positive time constants and step.
    fn valid(&self) -> bool {
        self.v_soma.is_finite()
            && self.v_dend.is_finite()
            && self.g_nmda.is_finite()
            && self.g_nmda >= 0.0
            && self.e_nmda.is_finite()
            && self.mg_conc.is_finite()
            && self.mg_conc >= 0.0
            && self.g_coupling.is_finite()
            && self.g_coupling >= 0.0
            && self.tau_soma.is_finite()
            && self.tau_soma > 0.0
            && self.tau_dend.is_finite()
            && self.tau_dend > 0.0
            && self.theta.is_finite()
            && self.dt.is_finite()
            && self.dt > 0.0
    }

    /// Right-hand side `(dV_soma/dt, dV_dend/dt)` at the given state, with the
    /// inputs held fixed. The -65.0 leak target is hard-coded here.
    fn derivatives(&self, v_soma: f64, v_dend: f64, i_soma: f64, glutamate: f64) -> (f64, f64) {
        let b = self.mg_block(v_dend);
        // Added (not subtracted) to dV_dend/dt below. See the sign NOTE on the type.
        let i_nmda = self.g_nmda * glutamate * b * (v_dend - self.e_nmda);
        let dv_soma =
            (-v_soma - 65.0 + i_soma + self.g_coupling * (v_dend - v_soma)) / self.tau_soma;
        let dv_dend =
            (-v_dend - 65.0 + i_nmda + self.g_coupling * (v_soma - v_dend)) / self.tau_dend;
        (dv_soma, dv_dend)
    }

    /// Advances `(v_soma, v_dend)` by one classical RK4 step of size `self.dt`.
    /// Inputs are held constant over the step, and `self` is not mutated.
    fn rk4_substep(&self, v_soma: f64, v_dend: f64, i_soma: f64, glutamate: f64) -> (f64, f64) {
        let dt = self.dt;
        let (k1s, k1d) = self.derivatives(v_soma, v_dend, i_soma, glutamate);
        let (k2s, k2d) = self.derivatives(
            v_soma + 0.5 * dt * k1s,
            v_dend + 0.5 * dt * k1d,
            i_soma,
            glutamate,
        );
        let (k3s, k3d) = self.derivatives(
            v_soma + 0.5 * dt * k2s,
            v_dend + 0.5 * dt * k2d,
            i_soma,
            glutamate,
        );
        let (k4s, k4d) = self.derivatives(v_soma + dt * k3s, v_dend + dt * k3d, i_soma, glutamate);
        (
            v_soma + dt * (k1s + 2.0 * k2s + 2.0 * k3s + k4s) / 6.0,
            v_dend + dt * (k1d + 2.0 * k2d + 2.0 * k3d + k4d) / 6.0,
        )
    }

    /// Step with somatic input and dendritic glutamate.
    ///
    /// Returns `1` when the candidate soma potential reaches `theta`. In that
    /// case the soma is reset to -65.0 while the dendrite keeps its candidate
    /// value. Otherwise returns `0`.
    ///
    /// The state is left untouched and `0` is returned if:
    /// - any input is non-finite,
    /// - `glutamate` is negative,
    /// - the configuration is invalid, or
    /// - the RK4 candidate is non-finite.
    pub fn step(&mut self, i_soma: f64, glutamate: f64) -> i32 {
        if !i_soma.is_finite() || !glutamate.is_finite() || glutamate < 0.0 || !self.valid() {
            return 0;
        }
        // Candidate-first: compute the full update before committing anything.
        let (next_v_soma, next_v_dend) =
            self.rk4_substep(self.v_soma, self.v_dend, i_soma, glutamate);
        if !next_v_soma.is_finite() || !next_v_dend.is_finite() {
            return 0;
        }
        self.v_dend = next_v_dend;
        if next_v_soma >= self.theta {
            self.v_soma = -65.0;
            1
        } else {
            self.v_soma = next_v_soma;
            0
        }
    }

    /// Returns both compartments to -65.0. Parameters are not modified.
    pub fn reset(&mut self) {
        self.v_soma = -65.0;
        self.v_dend = -65.0;
    }
}

impl Default for DendriticNMDANeuron {
    fn default() -> Self {
        Self::new()
    }
}
1219impl Default for DendriticNMDANeuron {
1220    fn default() -> Self {
1221        Self::new()
1222    }
1223}
1224
1225/// Multi-compartment neuron (MCN) matching the Spiking-WM architecture.
1226///
1227/// Dual-dendrite model with basal and apical compartments. The apical dendrite
1228/// gates how strongly basal information influences the soma, enabling
1229/// nonlinear integration for long-term temporal memory in RL tasks. The engine
1230/// uses candidate-first RK4 over `(u, v_basal, v_apical)` so all compartments
1231/// are advanced from one consistent state before the reset is committed.
1232///
1233/// Exact equations from arXiv:2503.00713 (Spiking-WM, PNAS 2025):
1234///
1235///   τ_b dV_b/dt = -V_b + x_b                                  (basal)
1236///   τ_a dV_a/dt = -V_a + x_a                                  (apical)
1237///   τ   dU/dt   = -U + σ(V_a)·[g_B/g_L·(V_b - U) + W_s·I]   (soma)
1238///   S[t] = Θ(U[t] - V_th)                                     (spike)
1239///   U[t] ← U[t]·(1 - S[t])                                    (soft reset)
1240///
1241/// Default parameters from Table II: τ = τ_a = τ_b = 2.0, g_B/g_L = 1.0,
1242/// β = 1.0 (sigmoid steepness), V_th = 1.0.
1243///
1244/// Reference: Brain-Cog-Lab, arXiv:2503.00713, PNAS 2025.
1245#[derive(Clone, Debug)]
1246pub struct MulticompartmentMCNNeuron {
1247    /// Somatic membrane potential.
1248    pub u: f64,
1249    /// Basal dendrite potential.
1250    pub v_basal: f64,
1251    /// Apical dendrite potential.
1252    pub v_apical: f64,
1253    /// Soma time constant.
1254    pub tau: f64,
1255    /// Basal dendrite time constant.
1256    pub tau_b: f64,
1257    /// Apical dendrite time constant.
1258    pub tau_a: f64,
1259    /// Basal-to-soma conductance ratio (g_B/g_L).
1260    pub g_ratio: f64,
1261    /// Sigmoid steepness for apical gating.
1262    pub beta: f64,
1263    /// Spike threshold.
1264    pub v_th: f64,
1265    /// Time step.
1266    pub dt: f64,
1267}
1268
1269impl MulticompartmentMCNNeuron {
1270    pub fn new() -> Self {
1271        Self {
1272            u: 0.0,
1273            v_basal: 0.0,
1274            v_apical: 0.0,
1275            tau: 2.0,
1276            tau_b: 2.0,
1277            tau_a: 2.0,
1278            g_ratio: 1.0,
1279            beta: 1.0,
1280            v_th: 1.0,
1281            dt: 1.0,
1282        }
1283    }
1284
1285    /// Sigmoid gating function σ(x) = 1/(1 + exp(-βx)).
1286    fn sigma(&self, x: f64) -> f64 {
1287        1.0 / (1.0 + (-self.beta * x).exp())
1288    }
1289
1290    fn valid(&self) -> bool {
1291        self.tau.is_finite()
1292            && self.tau > 0.0
1293            && self.tau_b.is_finite()
1294            && self.tau_b > 0.0
1295            && self.tau_a.is_finite()
1296            && self.tau_a > 0.0
1297            && self.g_ratio.is_finite()
1298            && self.g_ratio >= 0.0
1299            && self.beta.is_finite()
1300            && self.beta > 0.0
1301            && self.v_th.is_finite()
1302            && self.v_th > 0.0
1303            && self.dt.is_finite()
1304            && self.dt > 0.0
1305            && self.u.is_finite()
1306            && self.v_basal.is_finite()
1307            && self.v_apical.is_finite()
1308    }
1309
1310    fn derivatives(
1311        &self,
1312        u: f64,
1313        v_basal: f64,
1314        v_apical: f64,
1315        x_basal: f64,
1316        x_apical: f64,
1317        i_soma: f64,
1318    ) -> [f64; 3] {
1319        let gate = self.sigma(v_apical);
1320        let du = (-u + gate * (self.g_ratio * (v_basal - u) + i_soma)) / self.tau;
1321        let dv_basal = (-v_basal + x_basal) / self.tau_b;
1322        let dv_apical = (-v_apical + x_apical) / self.tau_a;
1323        [du, dv_basal, dv_apical]
1324    }
1325
1326    fn rk4_substep(&self, state: [f64; 3], x_basal: f64, x_apical: f64, i_soma: f64) -> [f64; 3] {
1327        let dt = self.dt;
1328        let k1 = self.derivatives(state[0], state[1], state[2], x_basal, x_apical, i_soma);
1329        let k2 = self.derivatives(
1330            state[0] + 0.5 * dt * k1[0],
1331            state[1] + 0.5 * dt * k1[1],
1332            state[2] + 0.5 * dt * k1[2],
1333            x_basal,
1334            x_apical,
1335            i_soma,
1336        );
1337        let k3 = self.derivatives(
1338            state[0] + 0.5 * dt * k2[0],
1339            state[1] + 0.5 * dt * k2[1],
1340            state[2] + 0.5 * dt * k2[2],
1341            x_basal,
1342            x_apical,
1343            i_soma,
1344        );
1345        let k4 = self.derivatives(
1346            state[0] + dt * k3[0],
1347            state[1] + dt * k3[1],
1348            state[2] + dt * k3[2],
1349            x_basal,
1350            x_apical,
1351            i_soma,
1352        );
1353        [
1354            state[0] + dt * (k1[0] + 2.0 * k2[0] + 2.0 * k3[0] + k4[0]) / 6.0,
1355            state[1] + dt * (k1[1] + 2.0 * k2[1] + 2.0 * k3[1] + k4[1]) / 6.0,
1356            state[2] + dt * (k1[2] + 2.0 * k2[2] + 2.0 * k3[2] + k4[2]) / 6.0,
1357        ]
1358    }
1359
1360    fn threshold_reached(&self, candidate_u: f64) -> bool {
1361        let margin = 16.0 * f64::EPSILON * self.v_th.abs().max(1.0);
1362        candidate_u >= self.v_th || (candidate_u - self.v_th).abs() <= margin
1363    }
1364
1365    /// Step with basal input (x_b), apical input (x_a), and direct somatic input.
1366    pub fn step_compartments(&mut self, x_basal: f64, x_apical: f64, i_soma: f64) -> i32 {
1367        if !x_basal.is_finite() || !x_apical.is_finite() || !i_soma.is_finite() || !self.valid() {
1368            return 0;
1369        }
1370        let next = self.rk4_substep(
1371            [self.u, self.v_basal, self.v_apical],
1372            x_basal,
1373            x_apical,
1374            i_soma,
1375        );
1376        if !next.iter().all(|value| value.is_finite()) {
1377            return 0;
1378        }
1379        let spike = self.threshold_reached(next[0]);
1380        self.u = if spike { 0.0 } else { next[0] };
1381        self.v_basal = next[1];
1382        self.v_apical = next[2];
1383        i32::from(spike)
1384    }
1385
1386    /// Simple step: input goes to basal dendrite only.
1387    pub fn step(&mut self, current: f64) -> i32 {
1388        self.step_compartments(current, 0.0, 0.0)
1389    }
1390
1391    pub fn reset(&mut self) {
1392        self.u = 0.0;
1393        self.v_basal = 0.0;
1394        self.v_apical = 0.0;
1395    }
1396}
1397
1398impl Default for MulticompartmentMCNNeuron {
1399    fn default() -> Self {
1400        Self::new()
1401    }
1402}
1403
1404/// Astrocyte-LIF hybrid unit with calcium wave feedback.
1405///
1406/// Models the tripartite synapse: a glial astrocyte monitors extracellular
1407/// glutamate from a paired LIF neuron and provides slow homeostatic feedback
1408/// via calcium-dependent gliotransmitter release.
1409///
1410///   dCa/dt = -Ca/τ_ca + δ · S_pre(t)        (calcium rise on presynaptic spike)
1411///   I_glio = g_glio · H(Ca - Ca_thresh)      (gliotransmitter release)
1412///   dV/dt = -(V - E_L)/τ_m + I_ext + I_glio  (LIF with glial feedback)
1413///
1414/// Reference: Perea, Navarrete & Araque, "Tripartite synapses" (2009).
1415#[derive(Clone, Debug)]
1416pub struct AstrocyteLIFNeuron {
1417    pub v: f64,
1418    pub ca: f64,
1419    pub tau_m: f64,
1420    pub tau_ca: f64,
1421    pub e_l: f64,
1422    pub theta: f64,
1423    pub v_reset: f64,
1424    pub ca_delta: f64,
1425    pub ca_thresh: f64,
1426    pub g_glio: f64,
1427    pub dt: f64,
1428}
1429
1430impl AstrocyteLIFNeuron {
1431    pub fn new() -> Self {
1432        Self {
1433            v: -65.0,
1434            ca: 0.0,
1435            tau_m: 20.0,
1436            tau_ca: 500.0,
1437            e_l: -65.0,
1438            theta: -50.0,
1439            v_reset: -65.0,
1440            ca_delta: 0.1,
1441            ca_thresh: 0.5,
1442            g_glio: 2.0,
1443            dt: 0.1,
1444        }
1445    }
1446
1447    /// Step with external current and presynaptic spike indicator.
1448    pub fn step_with_pre(&mut self, i_ext: f64, pre_spike: bool) -> i32 {
1449        // Astrocyte calcium dynamics.
1450        let dca = -self.ca / self.tau_ca
1451            + if pre_spike {
1452                self.ca_delta / self.dt
1453            } else {
1454                0.0
1455            };
1456        self.ca += dca * self.dt;
1457        self.ca = self.ca.max(0.0);
1458
1459        // Gliotransmitter release (Heaviside on calcium).
1460        let i_glio = if self.ca > self.ca_thresh {
1461            self.g_glio
1462        } else {
1463            0.0
1464        };
1465
1466        // LIF membrane dynamics with glial feedback.
1467        let dv = (-(self.v - self.e_l) + i_ext + i_glio) / self.tau_m;
1468        self.v += dv * self.dt;
1469
1470        if self.v >= self.theta {
1471            self.v = self.v_reset;
1472            1
1473        } else {
1474            0
1475        }
1476    }
1477
1478    /// Simple step (no presynaptic spike).
1479    pub fn step(&mut self, current: f64) -> i32 {
1480        self.step_with_pre(current, false)
1481    }
1482
1483    pub fn reset(&mut self) {
1484        self.v = self.e_l;
1485        self.ca = 0.0;
1486    }
1487}
1488
1489impl Default for AstrocyteLIFNeuron {
1490    fn default() -> Self {
1491        Self::new()
1492    }
1493}
1494
1495// ---- Tests for new multi-compartment / glial models ----
1496
1497#[cfg(test)]
1498mod gap_mc_tests {
1499    use super::*;
1500
1501    #[test]
1502    fn nmda_coincidence_detection() {
1503        let mut n = DendriticNMDANeuron::new();
1504        // Only soma input — dendrite contributes little.
1505        let mut spikes_soma_only = 0;
1506        for _ in 0..2000 {
1507            spikes_soma_only += n.step(8.0, 0.0);
1508        }
1509        n.reset();
1510        // Soma + glutamate — NMDA amplifies.
1511        let mut spikes_both = 0;
1512        for _ in 0..2000 {
1513            spikes_both += n.step(8.0, 1.0);
1514        }
1515        // Coincidence: both inputs together should fire more.
1516        assert!(
1517            spikes_both >= spikes_soma_only,
1518            "NMDA coincidence: both={spikes_both} must >= soma_only={spikes_soma_only}"
1519        );
1520    }
1521
1522    #[test]
1523    fn nmda_mg_block_voltage_dependent() {
1524        let n = DendriticNMDANeuron::new();
1525        let b_hyper = n.mg_block(-80.0);
1526        let b_depol = n.mg_block(-20.0);
1527        assert!(
1528            b_depol > b_hyper,
1529            "Mg block must relieve at depolarised potentials: B(-20)={b_depol:.3} > B(-80)={b_hyper:.3}"
1530        );
1531    }
1532
1533    #[test]
1534    fn nmda_zero_glutamate_no_nmda_current() {
1535        let mut n = DendriticNMDANeuron::new();
1536        let spikes: i32 = (0..500).map(|_| n.step(0.0, 0.0)).sum();
1537        assert_eq!(spikes, 0, "No input → no spikes");
1538    }
1539
1540    #[test]
1541    fn nmda_rk4_cross_backend_anchor() {
1542        let mut n = DendriticNMDANeuron::new();
1543        let spikes: i32 = (0..20_000).map(|_| n.step(50.0, 0.5)).sum();
1544        assert_eq!(spikes, 253);
1545        assert!(n.v_soma.is_finite());
1546        assert!(n.v_dend.is_finite());
1547    }
1548
1549    #[test]
1550    fn nmda_invalid_input_preserves_state() {
1551        let mut n = DendriticNMDANeuron::new();
1552        for _ in 0..10 {
1553            let _ = n.step(50.0, 0.5);
1554        }
1555        let old = (n.v_soma, n.v_dend);
1556        assert_eq!(n.step(f64::INFINITY, 0.5), 0);
1557        assert_eq!((n.v_soma, n.v_dend), old);
1558        assert_eq!(n.step(50.0, -1.0), 0);
1559        assert_eq!((n.v_soma, n.v_dend), old);
1560    }
1561
1562    #[test]
1563    fn nmda_invalid_configuration_preserves_state() {
1564        let mut n = DendriticNMDANeuron::new();
1565        for _ in 0..10 {
1566            let _ = n.step(50.0, 0.5);
1567        }
1568        let old = (n.v_soma, n.v_dend);
1569        n.tau_dend = 0.0;
1570        assert_eq!(n.step(50.0, 0.5), 0);
1571        assert_eq!((n.v_soma, n.v_dend), old);
1572    }
1573
1574    #[test]
1575    fn mcn_apical_gating() {
1576        // Without apical input, gate = σ(0) = 0.5, moderate drive.
1577        let mut n_no_apical = MulticompartmentMCNNeuron::new();
1578        let mut spikes_no = 0;
1579        for _ in 0..1000 {
1580            spikes_no += n_no_apical.step_compartments(2.5, 0.0, 0.0);
1581        }
1582        // With strong apical input, gate ≈ 1.0, full basal→soma coupling.
1583        let mut n_apical = MulticompartmentMCNNeuron::new();
1584        let mut spikes_yes = 0;
1585        for _ in 0..1000 {
1586            spikes_yes += n_apical.step_compartments(2.5, 5.0, 0.0);
1587        }
1588        assert!(
1589            spikes_yes >= spikes_no && spikes_yes > 0,
1590            "Apical gating should boost firing: apical={spikes_yes} >= none={spikes_no}"
1591        );
1592    }
1593
1594    #[test]
1595    fn mcn_rk4_cross_backend_anchor() {
1596        let mut n = MulticompartmentMCNNeuron::new();
1597        let mut spikes = 0;
1598        for _ in 0..200_000 {
1599            spikes += n.step(3.2);
1600        }
1601        assert_eq!(spikes, 49_999);
1602    }
1603
1604    #[test]
1605    fn mcn_threshold_boundary_accepts_one_ulp_roundoff() {
1606        let n = MulticompartmentMCNNeuron::new();
1607        let one_ulp_below = f64::from_bits(n.v_th.to_bits() - 1);
1608        assert!(n.threshold_reached(one_ulp_below));
1609        assert!(!n.threshold_reached(n.v_th - 1.0e-9));
1610    }
1611
1612    #[test]
1613    fn mcn_invalid_input_preserves_state() {
1614        let mut n = MulticompartmentMCNNeuron::new();
1615        for _ in 0..5 {
1616            let _ = n.step(3.2);
1617        }
1618        let old = (n.u, n.v_basal, n.v_apical);
1619        assert_eq!(n.step(f64::INFINITY), 0);
1620        assert_eq!((n.u, n.v_basal, n.v_apical), old);
1621    }
1622
1623    #[test]
1624    fn mcn_basal_dendrite_memory() {
1625        // τ_b = 2.0, dt = 1.0: V_b decays by factor (1 - dt/τ) = 0.5 per step.
1626        let mut n = MulticompartmentMCNNeuron::new();
1627        n.step_compartments(5.0, 0.0, 0.0);
1628        let v_after = n.v_basal;
1629        n.step_compartments(0.0, 0.0, 0.0);
1630        let v_decay = n.v_basal;
1631        assert!(
1632            v_decay.abs() > 0.1 * v_after.abs(),
1633            "Basal dendrite retains memory: {v_decay:.3} vs {v_after:.3}"
1634        );
1635    }
1636
1637    #[test]
1638    fn mcn_reset_clears_all() {
1639        let mut n = MulticompartmentMCNNeuron::new();
1640        for _ in 0..50 {
1641            n.step(2.0);
1642        }
1643        n.reset();
1644        assert_eq!(n.u, 0.0);
1645        assert_eq!(n.v_basal, 0.0);
1646        assert_eq!(n.v_apical, 0.0);
1647    }
1648
1649    #[test]
1650    fn astrocyte_calcium_rises_on_pre_spikes() {
1651        let mut n = AstrocyteLIFNeuron::new();
1652        let ca_before = n.ca;
1653        for _ in 0..100 {
1654            n.step_with_pre(0.0, true);
1655        }
1656        assert!(
1657            n.ca > ca_before,
1658            "Calcium must rise with presynaptic spikes"
1659        );
1660    }
1661
1662    #[test]
1663    fn astrocyte_gliotransmitter_boosts_firing() {
1664        let mut n_no_glio = AstrocyteLIFNeuron::new();
1665        let mut n_glio = AstrocyteLIFNeuron::new();
1666
1667        let mut spikes_no = 0;
1668        let mut spikes_yes = 0;
1669        for _ in 0..5000 {
1670            spikes_no += n_no_glio.step_with_pre(10.0, false);
1671            spikes_yes += n_glio.step_with_pre(10.0, true); // pre spikes → Ca → glio
1672        }
1673        assert!(
1674            spikes_yes >= spikes_no,
1675            "Gliotransmitter should boost firing: with={spikes_yes} >= without={spikes_no}"
1676        );
1677    }
1678
1679    #[test]
1680    fn astrocyte_calcium_decays() {
1681        let mut n = AstrocyteLIFNeuron::new();
1682        // Build up calcium.
1683        for _ in 0..200 {
1684            n.step_with_pre(0.0, true);
1685        }
1686        let ca_peak = n.ca;
1687        // Let it decay.
1688        for _ in 0..5000 {
1689            n.step_with_pre(0.0, false);
1690        }
1691        assert!(
1692            n.ca < ca_peak * 0.5,
1693            "Calcium must decay: current={:.4} < peak={:.4}*0.5",
1694            n.ca,
1695            ca_peak
1696        );
1697    }
1698}