Skip to main content

sc_neurocore_engine/
wong_wang.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
// SC-NeuroCore — Rust N-step simulator for the Wong-Wang 2006 decision unit
8
//! Batch parity with `WongWangUnit.step` in
//! `src/sc_neurocore/neurons/models/wong_wang.py` (Wong & Wang 2006,
//! J. Neurosci. 26:1314–1328).
//!
//! Per step, for each population `k` in `{1, 2}`:
//!   1. `i_k = j_n * s_k - j_cross * s_(3-k) + i_0 + stim_k + sigma * xi_k`
//!   2. `r_k = phi(i_k)` where
//!      `phi(i) = (a*i - b) / (1 - exp(-d*(a*i - b)))`
//!      with singularity guard `|a*i - b| < 1e-6 -> 1/d`.
//!   3. integrate the coupled `s1`, `s2` ODE with fixed-step RK4 while
//!      holding the step's pre-drawn noise samples constant.
//!   4. clamp the candidate `s_k` into `[0, 1]`. The Rust loop performs no
//!      finiteness check of its own: a NaN candidate passes through
//!      `f64::clamp` unchanged.
//!
//! The firing rates written to the output traces are the ones evaluated at
//! the start-of-step state (first RK4 stage).
//!
//! The Python primary draws `np.random.randn()` twice per step; the
//! Rust simulator takes `xi` pre-drawn from the Python RNG (interleaved as
//! `xi_1, xi_2` for step 0, then step 1, ...) so trajectories are bit-exact
//! for matching seeds. This mirrors the ping.rs + PINGCircuit pattern:
//! Python owns the RNG, Rust owns the inner loop.
26
/// Coefficient `a` of the rate function `phi`: scales the synaptic current.
const A: f64 = 270.0;
/// Coefficient `b` of the rate function `phi`: subtracted from `a * i`.
const B: f64 = 108.0;
/// Coefficient `d` of the rate function `phi`: scales the exponent.
const D: f64 = 0.154;

/// Firing-rate transfer function
/// `phi(i) = (a*i - b) / (1 - exp(-d*(a*i - b)))`.
///
/// `i_syn` is the total synaptic current of one population. When
/// `|a*i - b| < 1e-6` the expression approaches `0/0`, so the limit value
/// `1/d` is returned instead. A NaN input fails that comparison and therefore
/// propagates through the formula as NaN.
#[inline]
fn phi(i_syn: f64) -> f64 {
    let drive = A * i_syn - B;

    // Removable singularity at drive == 0: return the analytic limit.
    if drive.abs() < 1e-6 {
        return 1.0 / D;
    }

    let denominator = 1.0 - (-D * drive).exp();
    drive / denominator
}
40
41#[inline]
42fn derivatives(
43    s1: f64,
44    s2: f64,
45    stim1: f64,
46    stim2: f64,
47    xi1: f64,
48    xi2: f64,
49    tau_s: f64,
50    gamma: f64,
51    j_n: f64,
52    j_cross: f64,
53    i_0: f64,
54    sigma: f64,
55) -> (f64, f64, f64, f64) {
56    let i1 = j_n * s1 - j_cross * s2 + i_0 + stim1 + sigma * xi1;
57    let i2 = j_n * s2 - j_cross * s1 + i_0 + stim2 + sigma * xi2;
58    let r1 = phi(i1);
59    let r2 = phi(i2);
60    (
61        -s1 / tau_s + (1.0 - s1) * gamma * r1,
62        -s2 / tau_s + (1.0 - s2) * gamma * r2,
63        r1,
64        r2,
65    )
66}
67
68/// Simulate `n_steps` Wong-Wang iterations; write per-step state +
69/// firing-rate traces. `xi` must be length `2 * n_steps` (two noise
70/// samples per step, consumed in `i1, i2` order).
71pub fn simulate(
72    mut s1: f64,
73    mut s2: f64,
74    tau_s: f64,
75    gamma: f64,
76    j_n: f64,
77    j_cross: f64,
78    i_0: f64,
79    sigma: f64,
80    dt: f64,
81    stim1: &[f64],
82    stim2: &[f64],
83    xi: &[f64],
84    s1_out: &mut [f64],
85    s2_out: &mut [f64],
86    r1_out: &mut [f64],
87    r2_out: &mut [f64],
88) -> (f64, f64) {
89    let n = stim1.len();
90    assert_eq!(stim2.len(), n, "stim2 length mismatch");
91    assert_eq!(xi.len(), 2 * n, "xi length must be 2 * n_steps");
92    assert_eq!(s1_out.len(), n, "s1_out length mismatch");
93    assert_eq!(s2_out.len(), n, "s2_out length mismatch");
94    assert_eq!(r1_out.len(), n, "r1_out length mismatch");
95    assert_eq!(r2_out.len(), n, "r2_out length mismatch");
96
97    for t in 0..n {
98        let xi1 = xi[2 * t];
99        let xi2 = xi[2 * t + 1];
100        let (k1_s1, k1_s2, r1, r2) = derivatives(
101            s1, s2, stim1[t], stim2[t], xi1, xi2, tau_s, gamma, j_n, j_cross, i_0, sigma,
102        );
103        let (k2_s1, k2_s2, _, _) = derivatives(
104            s1 + 0.5 * dt * k1_s1,
105            s2 + 0.5 * dt * k1_s2,
106            stim1[t],
107            stim2[t],
108            xi1,
109            xi2,
110            tau_s,
111            gamma,
112            j_n,
113            j_cross,
114            i_0,
115            sigma,
116        );
117        let (k3_s1, k3_s2, _, _) = derivatives(
118            s1 + 0.5 * dt * k2_s1,
119            s2 + 0.5 * dt * k2_s2,
120            stim1[t],
121            stim2[t],
122            xi1,
123            xi2,
124            tau_s,
125            gamma,
126            j_n,
127            j_cross,
128            i_0,
129            sigma,
130        );
131        let (k4_s1, k4_s2, _, _) = derivatives(
132            s1 + dt * k3_s1,
133            s2 + dt * k3_s2,
134            stim1[t],
135            stim2[t],
136            xi1,
137            xi2,
138            tau_s,
139            gamma,
140            j_n,
141            j_cross,
142            i_0,
143            sigma,
144        );
145        s1 = (s1 + dt * (k1_s1 + 2.0 * k2_s1 + 2.0 * k3_s1 + k4_s1) / 6.0).clamp(0.0, 1.0);
146        s2 = (s2 + dt * (k1_s2 + 2.0 * k2_s2 + 2.0 * k3_s2 + k4_s2) / 6.0).clamp(0.0, 1.0);
147        s1_out[t] = s1;
148        s2_out[t] = s2;
149        r1_out[t] = r1;
150        r2_out[t] = r2;
151    }
152    (s1, s2)
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference parameter set, ordered as
    /// `(tau_s, gamma, j_n, j_cross, i_0, sigma, dt)`.
    fn params() -> (f64, f64, f64, f64, f64, f64, f64) {
        (0.1, 0.641, 0.2609, 0.0497, 0.3255, 0.0, 0.001)
    }

    /// Final gates and firing-rate traces produced by one `simulate` call.
    struct Outcome {
        s1: f64,
        s2: f64,
        r1_trace: Vec<f64>,
        r2_trace: Vec<f64>,
    }

    /// Calls `simulate` with the reference parameters and `sigma = 0`,
    /// allocating output buffers sized from `stim1`.
    fn run(
        s1_init: f64,
        s2_init: f64,
        dt: f64,
        stim1: &[f64],
        stim2: &[f64],
        xi: &[f64],
    ) -> Outcome {
        let (tau_s, gamma, j_n, j_cross, i_0, _, _) = params();
        let steps = stim1.len();
        let mut s1_trace = vec![0.0_f64; steps];
        let mut s2_trace = vec![0.0_f64; steps];
        let mut r1_trace = vec![0.0_f64; steps];
        let mut r2_trace = vec![0.0_f64; steps];
        let (s1, s2) = simulate(
            s1_init,
            s2_init,
            tau_s,
            gamma,
            j_n,
            j_cross,
            i_0,
            0.0,
            dt,
            stim1,
            stim2,
            xi,
            &mut s1_trace,
            &mut s2_trace,
            &mut r1_trace,
            &mut r2_trace,
        );
        Outcome {
            s1,
            s2,
            r1_trace,
            r2_trace,
        }
    }

    #[test]
    fn phi_singularity_guard_returns_finite() {
        // a*i - b vanishes at i = b/a = 0.4, which must hit the guard.
        let rate = phi(B / A);
        assert!(rate.is_finite());
        assert!((rate - 1.0 / D).abs() < 1e-6);
    }

    #[test]
    fn phi_monotone_increasing() {
        assert!(phi(1.0) > phi(0.5));
    }

    #[test]
    fn rk4_state_differs_from_forward_euler() {
        // A single coarse step makes the RK4 / Euler gap clearly visible.
        let (s1_0, s2_0, dt) = (0.24_f64, 0.11_f64, 0.02_f64);
        let (u1, u2) = (0.17_f64, 0.03_f64);
        let out = run(s1_0, s2_0, dt, &[u1], &[u2], &[0.0_f64; 2]);

        // Hand-rolled forward-Euler step with the same parameters.
        let (tau_s, gamma, j_n, j_cross, i_0, _, _) = params();
        let r1 = phi(j_n * s1_0 - j_cross * s2_0 + i_0 + u1);
        let r2 = phi(j_n * s2_0 - j_cross * s1_0 + i_0 + u2);
        let euler = |s: f64, r: f64| {
            (s + (-s / tau_s + (1.0 - s) * gamma * r) * dt).clamp(0.0, 1.0)
        };

        assert!((out.s1 - euler(s1_0, r1)).abs() > 1e-5);
        assert!((out.s2 - euler(s2_0, r2)).abs() > 1e-5);
    }

    #[test]
    fn zero_noise_zero_stim_converges_to_fixed_point() {
        // No noise, no stimulus, identical starting gates: both populations
        // see the same dynamics, so they must remain equal and inside [0, 1].
        let (.., dt) = params();
        let n = 10_000;
        let stim = vec![0.0_f64; n];
        let xi = vec![0.0_f64; 2 * n];
        let out = run(0.1, 0.1, dt, &stim, &stim, &xi);

        assert!(out.s1.is_finite() && out.s2.is_finite());
        assert!((0.0..=1.0).contains(&out.s1));
        assert!((0.0..=1.0).contains(&out.s2));
        assert!(
            (out.s1 - out.s2).abs() < 1e-9,
            "symmetric init must stay symmetric under zero noise"
        );
    }

    #[test]
    fn biased_stimulus_drives_winner() {
        // Only population 1 is stimulated and there is no noise, so s1 must
        // end high while s2 is suppressed.
        let (.., dt) = params();
        let n = 50_000;
        let stim1 = vec![0.2_f64; n];
        let stim2 = vec![0.0_f64; n];
        let xi = vec![0.0_f64; 2 * n];
        let out = run(0.1, 0.1, dt, &stim1, &stim2, &xi);
        let (s1_f, s2_f) = (out.s1, out.s2);

        assert!(s1_f > 0.5, "winner s1 should reach attractor; got {s1_f}");
        assert!(s2_f < 0.2, "loser s2 should be suppressed; got {s2_f}");
    }

    #[test]
    fn output_trace_shape_matches_input() {
        let n = 128;
        let stim = vec![0.1_f64; n];
        let xi = vec![0.0_f64; 2 * n];
        let out = run(0.1, 0.1, 0.001, &stim, &stim, &xi);

        // The rate buffers start zero-filled and phi is strictly positive on
        // this parameter set, so a remaining zero would mean a step was not
        // written. The gate traces are not checked because a gate may
        // legitimately sit near zero.
        assert!(out.r1_trace.iter().all(|&r| r > 0.0));
        assert!(out.r2_trace.iter().all(|&r| r > 0.0));
    }

    #[test]
    #[should_panic(expected = "xi length must be 2 * n_steps")]
    fn mismatched_xi_length_panics() {
        let n = 10;
        let stim = vec![0.0_f64; n];
        // Deliberately half the required length (should be 2 * n).
        let xi = vec![0.0_f64; n];
        run(0.1, 0.1, 0.001, &stim, &stim, &xi);
    }
}