/// Numerically stable logistic function `1 / (1 + e^(-z))`.
///
/// The exponential is only ever evaluated at a non-positive argument
/// (`-|z|`), so it cannot overflow for large `|z|`. The two algebraically
/// equivalent forms `1 / (1 + e^(-z))` (for `z >= 0`) and
/// `e^z / (1 + e^z)` (otherwise) then share the same denominator.
#[inline]
fn logistic(z: f64) -> f64 {
    // For z >= 0 this is e^(-z); for z < 0 it is e^z.
    let decay = (-z.abs()).exp();
    let numerator = if z >= 0.0 { 1.0 } else { decay };
    numerator / (1.0 + decay)
}
33
34#[inline]
35fn sigmoid(a: f64, theta: f64, x: f64) -> f64 {
36 logistic(a * (x - theta)) - logistic(-a * theta)
40}
41
42#[inline]
43fn derivatives(
44 e: f64,
45 i: f64,
46 ext: f64,
47 params: (f64, f64, f64, f64, f64, f64, f64, f64),
48) -> (f64, f64) {
49 let (w_ee, w_ei, w_ie, w_ii, tau_e, tau_i, a, theta) = params;
50 let s_e = sigmoid(a, theta, w_ee * e - w_ei * i + ext);
51 let s_i = sigmoid(a, theta, w_ie * e - w_ii * i);
52 ((-e + s_e) / tau_e, (-i + s_i) / tau_i)
53}
54
55#[expect(
59 clippy::too_many_arguments,
60 reason = "Python extension parity surface passes canonical scalar parameters"
61)]
62pub fn simulate(
63 mut e: f64,
64 mut i: f64,
65 w_ee: f64,
66 w_ei: f64,
67 w_ie: f64,
68 w_ii: f64,
69 tau_e: f64,
70 tau_i: f64,
71 a: f64,
72 theta: f64,
73 dt: f64,
74 ext_input: &[f64],
75 e_out: &mut [f64],
76 i_out: &mut [f64],
77) -> (f64, f64) {
78 let n = ext_input.len();
79 assert_eq!(e_out.len(), n, "e_out length mismatch");
80 assert_eq!(i_out.len(), n, "i_out length mismatch");
81 let params = (w_ee, w_ei, w_ie, w_ii, tau_e, tau_i, a, theta);
82
83 for t in 0..n {
84 let ext = ext_input[t];
85 let (k1_e, k1_i) = derivatives(e, i, ext, params);
86 let (k2_e, k2_i) = derivatives(e + 0.5 * dt * k1_e, i + 0.5 * dt * k1_i, ext, params);
87 let (k3_e, k3_i) = derivatives(e + 0.5 * dt * k2_e, i + 0.5 * dt * k2_i, ext, params);
88 let (k4_e, k4_i) = derivatives(e + dt * k3_e, i + dt * k3_i, ext, params);
89 e += dt * (k1_e + 2.0 * k2_e + 2.0 * k3_e + k4_e) / 6.0;
90 i += dt * (k1_i + 2.0 * k2_i + 2.0 * k3_i + k4_i) / 6.0;
91 e_out[t] = e;
92 i_out[t] = i;
93 }
94 (e, i)
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared parameter set used by most tests, in the order
    /// `(w_ee, w_ei, w_ie, w_ii, tau_e, tau_i, a, theta, dt)`.
    fn defaults() -> (f64, f64, f64, f64, f64, f64, f64, f64, f64) {
        (10.0, 6.0, 10.0, 1.0, 1.0, 2.0, 1.2, 4.0, 0.1)
    }

    #[test]
    fn sigmoid_monotone_increasing() {
        let (_, _, _, _, _, _, a, theta, _) = defaults();
        let lo = sigmoid(a, theta, 0.0);
        let mid = sigmoid(a, theta, 4.0);
        let hi = sigmoid(a, theta, 10.0);
        assert!(lo < mid && mid < hi);
    }

    #[test]
    fn sigmoid_at_zero_is_zero() {
        let (_, _, _, _, _, _, a, theta, _) = defaults();
        // The subtracted baseline is the curve's own value at x = 0.
        assert!(sigmoid(a, theta, 0.0).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_at_theta_equals_half_minus_baseline() {
        let (_, _, _, _, _, _, a, theta, _) = defaults();
        let baseline = 1.0 / (1.0 + (a * theta).exp());
        let r = sigmoid(a, theta, theta);
        assert!((r - (0.5 - baseline)).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_asymptotes_respect_baseline() {
        let (_, _, _, _, _, _, a, theta, _) = defaults();
        // `baseline` is computed with a different (algebraically equal)
        // formula than the one `logistic` uses for negative arguments, so the
        // two may differ by a rounding error. The tolerance must therefore
        // sit above f64 resolution rather than demand bit equality.
        let baseline = 1.0 / (1.0 + (a * theta).exp());
        assert!((sigmoid(a, theta, 1e6) - (1.0 - baseline)).abs() < 1e-12);
        assert!((sigmoid(a, theta, -1e6) - (-baseline)).abs() < 1e-12);
    }

    #[test]
    fn quiescent_converges() {
        let (w_ee, w_ei, w_ie, w_ii, tau_e, tau_i, a, theta, dt) = defaults();
        let n = 20_000;
        let ext = vec![0.0_f64; n];
        let mut e_out = vec![0.0_f64; n];
        let mut i_out = vec![0.0_f64; n];
        let (e_f, i_f) = simulate(
            0.1, 0.05, w_ee, w_ei, w_ie, w_ii, tau_e, tau_i, a, theta, dt, &ext, &mut e_out,
            &mut i_out,
        );
        assert!(e_f.is_finite() && i_f.is_finite());
        assert!(e_f < 0.2, "quiescent E must stay low, got {e_f}");
        assert!(i_f < 0.2, "quiescent I must stay low, got {i_f}");
    }

    #[test]
    fn high_drive_elevates_activity() {
        let (w_ee, w_ei, w_ie, w_ii, tau_e, tau_i, a, theta, dt) = defaults();
        let n = 10_000;
        let ext = vec![10.0_f64; n];
        let mut e_out = vec![0.0_f64; n];
        let mut i_out = vec![0.0_f64; n];
        let (e_f, _) = simulate(
            0.1, 0.05, w_ee, w_ei, w_ie, w_ii, tau_e, tau_i, a, theta, dt, &ext, &mut e_out,
            &mut i_out,
        );
        assert!(e_f > 0.3, "high external drive must elevate E, got {e_f}");
    }

    #[test]
    fn rk4_step_matches_reference_and_separates_from_euler() {
        let mut e_out = vec![0.0_f64; 1];
        let mut i_out = vec![0.0_f64; 1];
        let ext = vec![3.0_f64; 1];
        simulate(
            0.24, 0.11, 10.0, 6.0, 10.0, 1.0, 1.0, 2.0, 1.2, 4.0, 0.35, &ext, &mut e_out,
            &mut i_out,
        );
        // Result of a single explicit Euler step from the same state; the
        // RK4 result must differ from it noticeably at this step size.
        let euler_e = 0.40111014473980233_f64;
        let euler_i = 0.10924537850891547_f64;
        assert!((e_out[0] - 0.42143718680097664_f64).abs() < 1e-15);
        assert!((i_out[0] - 0.13798020053932203_f64).abs() < 1e-15);
        assert!((e_out[0] - euler_e).abs() > 1e-2);
        assert!((i_out[0] - euler_i).abs() > 1e-2);
    }

    #[test]
    fn output_trace_shape_matches_input() {
        let n = 64;
        let ext = vec![1.0_f64; n];
        // NaN-filled buffers prove that every slot is overwritten.
        let mut e_out = vec![f64::NAN; n];
        let mut i_out = vec![f64::NAN; n];
        simulate(
            0.1, 0.05, 10.0, 6.0, 10.0, 1.0, 1.0, 2.0, 1.2, 4.0, 0.1, &ext, &mut e_out, &mut i_out,
        );
        assert!(e_out.iter().all(|v| v.is_finite()));
        assert!(i_out.iter().all(|v| v.is_finite()));
    }

    #[test]
    #[should_panic(expected = "e_out length mismatch")]
    fn mismatched_e_out_panics() {
        let n = 10;
        let ext = vec![0.0_f64; n];
        let mut e_out = vec![0.0_f64; n + 1];
        let mut i_out = vec![0.0_f64; n];
        simulate(
            0.1, 0.05, 10.0, 6.0, 10.0, 1.0, 1.0, 2.0, 1.2, 4.0, 0.1, &ext, &mut e_out, &mut i_out,
        );
    }

    #[test]
    #[should_panic(expected = "i_out length mismatch")]
    fn mismatched_i_out_panics() {
        let n = 10;
        let ext = vec![0.0_f64; n];
        let mut e_out = vec![0.0_f64; n];
        let mut i_out = vec![0.0_f64; n + 1];
        simulate(
            0.1, 0.05, 10.0, 6.0, 10.0, 1.0, 1.0, 2.0, 1.2, 4.0, 0.1, &ext, &mut e_out, &mut i_out,
        );
    }
}