fusion_core/
jit.rs

// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
// © Code 2020–2026 Miroslav Šotek. All rights reserved.
// ORCID: 0009-0009-3560-0851
// Contact: www.anulum.li | protoscience@anulum.li
//! Reduced runtime kernel specialization and hot-swap lane.
//!
//! This module provides a deterministic stand-in for regime-triggered runtime
//! compilation: kernels are specialized per plasma regime, cached, and
//! hot-swapped, with the compile/cache semantics kept explicit. No external
//! LLVM/JIT dependency is used, so CI cost stays bounded.
11
12use fusion_types::error::{FusionError, FusionResult};
13use ndarray::Array1;
14use std::collections::HashMap;
15
/// Smallest integration step, in seconds, that `KernelCompileSpec::validated` accepts.
const MIN_DT_S: f64 = 1.0e-9;
17
/// Plasma operating regime; each variant selects its own specialized kernel
/// coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlasmaRegime {
    /// Low-confinement operation; the fallback when no other routing rule fires.
    LMode,
    /// High-confinement operation.
    HMode,
    /// Plasma current ramping up (positive ramp rate above threshold).
    RampUp,
    /// Plasma current ramping down (negative ramp rate below threshold).
    RampDown,
}
26
/// Minimal bundle of plasma measurements consulted by [`detect_regime`] when
/// choosing a [`PlasmaRegime`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RegimeObservation {
    /// Beta figure (presumably normalized beta, β_N — confirm with caller);
    /// `>= 2.2` on its own routes to H-mode.
    pub beta_n: f64,
    /// `q95` value (presumably the 95%-flux safety factor); `<= 3.7` combined
    /// with `beta_n >= 2.0` routes to H-mode.
    pub q95: f64,
    /// Line-averaged density in units of 1e20 m^-3; `>= 1.0` routes to H-mode.
    pub density_line_avg_1e20_m3: f64,
    /// Current ramp rate in MA/s; beyond ±0.2 it selects a ramp regime.
    pub current_ramp_ma_s: f64,
}

impl Default for RegimeObservation {
    /// Quiescent baseline that `detect_regime` classifies as L-mode.
    fn default() -> Self {
        RegimeObservation {
            current_ramp_ma_s: 0.0,
            density_line_avg_1e20_m3: 0.8,
            q95: 4.5,
            beta_n: 1.5,
        }
    }
}
46
47/// Routing heuristic for reduced runtime specialization.
48pub fn detect_regime(observation: &RegimeObservation) -> PlasmaRegime {
49    if observation.current_ramp_ma_s > 0.2 {
50        PlasmaRegime::RampUp
51    } else if observation.current_ramp_ma_s < -0.2 {
52        PlasmaRegime::RampDown
53    } else if observation.beta_n >= 2.2
54        || (observation.beta_n >= 2.0 && observation.q95 <= 3.7)
55        || observation.density_line_avg_1e20_m3 >= 1.0
56    {
57        PlasmaRegime::HMode
58    } else {
59        PlasmaRegime::LMode
60    }
61}
62
63fn validate_observation(observation: &RegimeObservation) -> FusionResult<()> {
64    if !observation.beta_n.is_finite() {
65        return Err(FusionError::ConfigError(
66            "jit observation beta_n must be finite".to_string(),
67        ));
68    }
69    if !observation.q95.is_finite() {
70        return Err(FusionError::ConfigError(
71            "jit observation q95 must be finite".to_string(),
72        ));
73    }
74    if !observation.density_line_avg_1e20_m3.is_finite() {
75        return Err(FusionError::ConfigError(
76            "jit observation density_line_avg_1e20_m3 must be finite".to_string(),
77        ));
78    }
79    if !observation.current_ramp_ma_s.is_finite() {
80        return Err(FusionError::ConfigError(
81            "jit observation current_ramp_ma_s must be finite".to_string(),
82        ));
83    }
84    Ok(())
85}
86
/// Shape and timing metadata that a kernel is compiled against.
///
/// Two kernels for the same regime are interchangeable only when their specs
/// compare equal.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KernelCompileSpec {
    /// Required length of the state vector handed to a kernel step.
    pub n_state: usize,
    /// Required length of the control vector handed to a kernel step.
    pub n_control: usize,
    /// Explicit integration step in seconds; must be finite and `>= MIN_DT_S`.
    pub dt_s: f64,
    /// Unroll/chunk-width hint for the generated loop; must be non-zero.
    pub unroll_factor: usize,
}
95
96impl KernelCompileSpec {
97    pub fn validated(self) -> FusionResult<Self> {
98        if self.n_state == 0 {
99            return Err(FusionError::ConfigError(
100                "jit kernel n_state must be > 0".to_string(),
101            ));
102        }
103        if self.n_control == 0 {
104            return Err(FusionError::ConfigError(
105                "jit kernel n_control must be > 0".to_string(),
106            ));
107        }
108        if !self.dt_s.is_finite() || self.dt_s < MIN_DT_S {
109            return Err(FusionError::ConfigError(format!(
110                "jit kernel dt_s must be finite and >= {MIN_DT_S}"
111            )));
112        }
113        if self.unroll_factor == 0 {
114            return Err(FusionError::ConfigError(
115                "jit kernel unroll_factor must be > 0".to_string(),
116            ));
117        }
118        Ok(self)
119    }
120}
121
122impl Default for KernelCompileSpec {
123    fn default() -> Self {
124        Self {
125            n_state: 8,
126            n_control: 4,
127            dt_s: 1e-3,
128            unroll_factor: 4,
129        }
130    }
131}
132
133#[derive(Debug, Clone)]
134pub struct CompiledKernel {
135    pub regime: PlasmaRegime,
136    pub generation: u64,
137    pub spec: KernelCompileSpec,
138    nonlinear_gain: f64,
139    control_gain: f64,
140    bias: f64,
141}
142
143impl CompiledKernel {
144    fn from_regime(regime: PlasmaRegime, spec: KernelCompileSpec, generation: u64) -> Self {
145        let (nonlinear_gain, control_gain, bias) = match regime {
146            PlasmaRegime::LMode => (0.22, 0.08, -0.01),
147            PlasmaRegime::HMode => (0.30, 0.12, 0.02),
148            PlasmaRegime::RampUp => (0.26, 0.14, 0.04),
149            PlasmaRegime::RampDown => (0.24, 0.13, -0.04),
150        };
151        Self {
152            regime,
153            generation,
154            spec,
155            nonlinear_gain,
156            control_gain,
157            bias,
158        }
159    }
160
161    /// Execute one reduced control step for this specialized kernel.
162    pub fn step(&self, state: &Array1<f64>, control: &Array1<f64>) -> FusionResult<Array1<f64>> {
163        if state.len() != self.spec.n_state {
164            return Err(FusionError::ConfigError(format!(
165                "jit step state length mismatch: expected {}, got {}",
166                self.spec.n_state,
167                state.len()
168            )));
169        }
170        if control.len() != self.spec.n_control {
171            return Err(FusionError::ConfigError(format!(
172                "jit step control length mismatch: expected {}, got {}",
173                self.spec.n_control,
174                control.len()
175            )));
176        }
177        if state.iter().any(|v| !v.is_finite()) {
178            return Err(FusionError::ConfigError(
179                "jit step state vector must contain only finite values".to_string(),
180            ));
181        }
182        if control.iter().any(|v| !v.is_finite()) {
183            return Err(FusionError::ConfigError(
184                "jit step control vector must contain only finite values".to_string(),
185            ));
186        }
187
188        let control_mean = control.iter().copied().sum::<f64>() / self.spec.n_control as f64;
189        if !control_mean.is_finite() {
190            return Err(FusionError::ConfigError(
191                "jit step control mean became non-finite".to_string(),
192            ));
193        }
194
195        let forcing = self.control_gain * control_mean + self.bias;
196        let dt = self.spec.dt_s;
197        let mut next = Array1::zeros(self.spec.n_state);
198
199        if let (Some(state_slice), Some(next_slice)) = (state.as_slice(), next.as_slice_mut()) {
200            for start in (0..self.spec.n_state).step_by(self.spec.unroll_factor) {
201                let end = (start + self.spec.unroll_factor).min(self.spec.n_state);
202                for idx in start..end {
203                    let x = state_slice[idx];
204                    let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
205                    next_slice[idx] = x + dt * (drift + forcing);
206                }
207            }
208            if next_slice.iter().any(|v| !v.is_finite()) {
209                return Err(FusionError::ConfigError(
210                    "jit step produced non-finite state output".to_string(),
211                ));
212            }
213            return Ok(next);
214        }
215
216        for (idx, out) in next.iter_mut().enumerate() {
217            let x = state[idx];
218            let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
219            *out = x + dt * (drift + forcing);
220        }
221        if next.iter().any(|v| !v.is_finite()) {
222            return Err(FusionError::ConfigError(
223                "jit step produced non-finite state output".to_string(),
224            ));
225        }
226        Ok(next)
227    }
228}
229
230/// Runtime kernel specialization manager with cache + hot-swap semantics.
231#[derive(Debug, Default, Clone)]
232pub struct RuntimeKernelJit {
233    kernels: HashMap<PlasmaRegime, CompiledKernel>,
234    active: Option<PlasmaRegime>,
235    compile_events: u64,
236}
237
238impl RuntimeKernelJit {
239    pub fn new() -> Self {
240        Self::default()
241    }
242
243    /// Compile or reuse a kernel for the requested regime and activate it.
244    pub fn compile_for_regime(
245        &mut self,
246        regime: PlasmaRegime,
247        spec: KernelCompileSpec,
248    ) -> FusionResult<u64> {
249        let spec = spec.validated()?;
250        if let Some(existing) = self.kernels.get(&regime) {
251            if existing.spec == spec {
252                self.active = Some(regime);
253                return Ok(existing.generation);
254            }
255        }
256
257        self.compile_events += 1;
258        let generation = self.compile_events;
259        let compiled = CompiledKernel::from_regime(regime, spec, generation);
260        self.kernels.insert(regime, compiled);
261        self.active = Some(regime);
262        Ok(generation)
263    }
264
265    /// Detect regime, compile if needed, and activate specialized kernel.
266    pub fn refresh_for_observation(
267        &mut self,
268        observation: &RegimeObservation,
269        spec: KernelCompileSpec,
270    ) -> FusionResult<(PlasmaRegime, u64)> {
271        validate_observation(observation)?;
272        let regime = detect_regime(observation);
273        let generation = self.compile_for_regime(regime, spec)?;
274        Ok((regime, generation))
275    }
276
277    pub fn active_regime(&self) -> Option<PlasmaRegime> {
278        self.active
279    }
280
281    pub fn cache_size(&self) -> usize {
282        self.kernels.len()
283    }
284
285    pub fn compile_events(&self) -> u64 {
286        self.compile_events
287    }
288
289    /// Execute one step with the active specialized kernel.
290    pub fn step_active(
291        &self,
292        state: &Array1<f64>,
293        control: &Array1<f64>,
294    ) -> FusionResult<Array1<f64>> {
295        let Some(regime) = self.active else {
296            return Err(FusionError::ConfigError(
297                "jit step_active requires an active regime; compile or refresh first".to_string(),
298            ));
299        };
300        let Some(kernel) = self.kernels.get(&regime) else {
301            return Err(FusionError::ConfigError(
302                "jit step_active active regime has no compiled kernel".to_string(),
303            ));
304        };
305        kernel.step(state, control)
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_regime_routing() {
        let ramp_up = RegimeObservation {
            current_ramp_ma_s: 0.4,
            ..RegimeObservation::default()
        };
        let ramp_down = RegimeObservation {
            current_ramp_ma_s: -0.5,
            ..RegimeObservation::default()
        };
        let h_mode = RegimeObservation {
            beta_n: 2.4,
            ..RegimeObservation::default()
        };
        let l_mode = RegimeObservation::default();

        assert_eq!(detect_regime(&ramp_up), PlasmaRegime::RampUp);
        assert_eq!(detect_regime(&ramp_down), PlasmaRegime::RampDown);
        assert_eq!(detect_regime(&h_mode), PlasmaRegime::HMode);
        assert_eq!(detect_regime(&l_mode), PlasmaRegime::LMode);
    }

    /// Covers the secondary H-mode rules, ramp precedence and the strict
    /// ramp threshold.
    #[test]
    fn test_detect_regime_secondary_h_mode_rules_and_boundaries() {
        let low_q95 = RegimeObservation {
            beta_n: 2.1,
            q95: 3.5,
            ..RegimeObservation::default()
        };
        let high_q95 = RegimeObservation {
            beta_n: 2.1,
            q95: 4.0,
            ..RegimeObservation::default()
        };
        let dense = RegimeObservation {
            density_line_avg_1e20_m3: 1.2,
            ..RegimeObservation::default()
        };
        let ramp_beats_h_mode = RegimeObservation {
            beta_n: 3.0,
            current_ramp_ma_s: 0.5,
            ..RegimeObservation::default()
        };
        let ramp_at_threshold = RegimeObservation {
            current_ramp_ma_s: 0.2,
            ..RegimeObservation::default()
        };

        assert_eq!(detect_regime(&low_q95), PlasmaRegime::HMode);
        assert_eq!(detect_regime(&high_q95), PlasmaRegime::LMode);
        assert_eq!(detect_regime(&dense), PlasmaRegime::HMode);
        assert_eq!(detect_regime(&ramp_beats_h_mode), PlasmaRegime::RampUp);
        assert_eq!(detect_regime(&ramp_at_threshold), PlasmaRegime::LMode);
    }

    #[test]
    fn test_compile_cache_reuses_generation() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let gen1 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        assert_eq!(gen1, gen2);
        assert_eq!(jit.compile_events(), 1);
        assert_eq!(jit.cache_size(), 1);
    }

    /// A different spec for an already-cached regime must rebuild the kernel.
    #[test]
    fn test_compile_recompiles_when_spec_changes() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let gen1 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let finer = KernelCompileSpec {
            dt_s: spec.dt_s * 0.5,
            ..spec
        };
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::LMode, finer)
            .expect("valid compile spec");
        assert_ne!(gen1, gen2);
        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 1);
        assert_eq!(jit.active_regime(), Some(PlasmaRegime::LMode));
    }

    /// A rejected spec must not hot-swap away from the current kernel.
    #[test]
    fn test_failed_compile_keeps_previous_active_regime() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::HMode, spec)
            .expect("valid compile spec");
        let bad = KernelCompileSpec {
            unroll_factor: 0,
            ..spec
        };
        assert!(jit.compile_for_regime(PlasmaRegime::LMode, bad).is_err());
        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
        assert_eq!(jit.cache_size(), 1);
        assert_eq!(jit.compile_events(), 1);
    }

    #[test]
    fn test_hot_swap_changes_active_regime_and_response() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let state = Array1::from_vec(vec![0.7; spec.n_state]);
        let control = Array1::from_vec(vec![0.2; spec.n_control]);

        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let l_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        jit.compile_for_regime(PlasmaRegime::HMode, spec)
            .expect("valid compile spec");
        let h_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
        let delta_sum = (&h_step - &l_step).iter().map(|v| v.abs()).sum::<f64>();
        assert!(
            delta_sum > 1e-9,
            "Expected regime hot-swap to alter response"
        );
    }

    #[test]
    fn test_refresh_for_observation_compiles_once_per_regime() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();

        let obs_l = RegimeObservation::default();
        let obs_h = RegimeObservation {
            beta_n: 2.6,
            ..RegimeObservation::default()
        };
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");

        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 2);
        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
    }

    #[test]
    fn test_step_active_without_kernel_rejects_missing_active_kernel() {
        let jit = RuntimeKernelJit::new();
        let state = Array1::from_vec(vec![1.0, -2.0, 0.5]);
        let control = Array1::from_vec(vec![0.1, 0.1]);
        assert!(jit.step_active(&state, &control).is_err());
    }

    #[test]
    fn test_compile_for_regime_rejects_invalid_compile_specs() {
        let mut jit = RuntimeKernelJit::new();
        let bad_n_state = KernelCompileSpec {
            n_state: 0,
            ..KernelCompileSpec::default()
        };
        let bad_n_control = KernelCompileSpec {
            n_control: 0,
            ..KernelCompileSpec::default()
        };
        let bad_dt = KernelCompileSpec {
            dt_s: f64::NAN,
            ..KernelCompileSpec::default()
        };
        let bad_unroll = KernelCompileSpec {
            unroll_factor: 0,
            ..KernelCompileSpec::default()
        };

        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_state)
            .is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_control)
            .is_err());
        assert!(jit.compile_for_regime(PlasmaRegime::LMode, bad_dt).is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_unroll)
            .is_err());
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn test_refresh_for_observation_rejects_non_finite_inputs() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let bad = RegimeObservation {
            beta_n: f64::NAN,
            ..RegimeObservation::default()
        };
        assert!(jit.refresh_for_observation(&bad, spec).is_err());
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn test_step_active_rejects_invalid_runtime_vectors() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");

        let bad_state = Array1::from_vec(vec![0.0; spec.n_state - 1]);
        let good_control = Array1::from_vec(vec![0.1; spec.n_control]);
        assert!(jit.step_active(&bad_state, &good_control).is_err());

        let good_state = Array1::from_vec(vec![0.0; spec.n_state]);
        let bad_control = Array1::from_vec(vec![0.1; spec.n_control - 1]);
        assert!(jit.step_active(&good_state, &bad_control).is_err());

        let nan_state = Array1::from_vec(vec![f64::NAN; spec.n_state]);
        assert!(jit.step_active(&nan_state, &good_control).is_err());
    }

    /// Non-finite controls are rejected, as are finite controls whose sum
    /// overflows.
    #[test]
    fn test_step_active_rejects_non_finite_control() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");

        let state = Array1::from_vec(vec![0.0; spec.n_state]);
        let nan_control = Array1::from_vec(vec![f64::NAN; spec.n_control]);
        assert!(jit.step_active(&state, &nan_control).is_err());

        let huge_control = Array1::from_vec(vec![f64::MAX; spec.n_control]);
        assert!(jit.step_active(&state, &huge_control).is_err());
    }

    /// A strided (reversed) state must produce exactly the same output as the
    /// equivalent contiguous state.
    #[test]
    fn test_step_matches_for_strided_and_contiguous_state() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::RampUp, spec)
            .expect("valid compile spec");
        let control = Array1::from_vec(vec![0.3; spec.n_control]);

        let values: Vec<f64> = (0..spec.n_state).map(|i| i as f64 * 0.25 - 1.0).collect();
        let mut reversed = Array1::from_vec(values);
        reversed.invert_axis(ndarray::Axis(0));
        assert!(
            reversed.as_slice().is_none(),
            "expected a non-contiguous state layout"
        );
        let contiguous = Array1::from_vec(reversed.iter().copied().collect::<Vec<f64>>());

        let strided_out = jit
            .step_active(&reversed, &control)
            .expect("valid active-kernel step inputs");
        let contiguous_out = jit
            .step_active(&contiguous, &control)
            .expect("valid active-kernel step inputs");
        assert_eq!(strided_out, contiguous_out);
    }
}