fusion_core/
jit.rs

1//! Reduced runtime kernel specialization and hot-swap lane.
2//!
3//! This module provides a deterministic stand-in for regime-triggered runtime
4//! compilation. It keeps compilation/cache semantics explicit without external
5//! LLVM/JIT dependencies so CI remains bounded.
6
7use fusion_types::error::{FusionError, FusionResult};
8use ndarray::Array1;
9use std::collections::HashMap;
10
/// Smallest kernel step size (seconds) that `KernelCompileSpec::validated` accepts.
const MIN_DT_S: f64 = 1.0e-9;
12
/// Operating regime used to pick a specialized runtime kernel.
///
/// [`detect_regime`] maps a [`RegimeObservation`] onto exactly one of these variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlasmaRegime {
    /// Fallback chosen when neither a ramp nor an H-mode condition holds.
    LMode,
    /// Chosen for high `beta_n`, `beta_n` near threshold with low `q95`, or high density.
    HMode,
    /// Chosen while `current_ramp_ma_s` is above `+0.2`.
    RampUp,
    /// Chosen while `current_ramp_ma_s` is below `-0.2`.
    RampDown,
}
21
/// Scalar measurements that [`detect_regime`] uses to choose a [`PlasmaRegime`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RegimeObservation {
    /// Beta value compared against the 2.2 / 2.0 H-mode thresholds.
    pub beta_n: f64,
    /// Safety-factor value; `<= 3.7` combined with `beta_n >= 2.0` selects H-mode.
    pub q95: f64,
    /// Line-averaged density (the suffix indicates units of 1e20 m^-3); `>= 1.0` selects H-mode.
    pub density_line_avg_1e20_m3: f64,
    /// Current ramp rate (the suffix indicates MA/s); beyond +/-0.2 selects a ramp regime.
    pub current_ramp_ma_s: f64,
}

impl Default for RegimeObservation {
    /// A quiet, non-ramping observation that routes to [`PlasmaRegime::LMode`].
    fn default() -> Self {
        RegimeObservation {
            current_ramp_ma_s: 0.0,
            density_line_avg_1e20_m3: 0.8,
            q95: 4.5,
            beta_n: 1.5,
        }
    }
}
41
42/// Routing heuristic for reduced runtime specialization.
43pub fn detect_regime(observation: &RegimeObservation) -> PlasmaRegime {
44    if observation.current_ramp_ma_s > 0.2 {
45        PlasmaRegime::RampUp
46    } else if observation.current_ramp_ma_s < -0.2 {
47        PlasmaRegime::RampDown
48    } else if observation.beta_n >= 2.2
49        || (observation.beta_n >= 2.0 && observation.q95 <= 3.7)
50        || observation.density_line_avg_1e20_m3 >= 1.0
51    {
52        PlasmaRegime::HMode
53    } else {
54        PlasmaRegime::LMode
55    }
56}
57
58fn validate_observation(observation: &RegimeObservation) -> FusionResult<()> {
59    if !observation.beta_n.is_finite() {
60        return Err(FusionError::ConfigError(
61            "jit observation beta_n must be finite".to_string(),
62        ));
63    }
64    if !observation.q95.is_finite() {
65        return Err(FusionError::ConfigError(
66            "jit observation q95 must be finite".to_string(),
67        ));
68    }
69    if !observation.density_line_avg_1e20_m3.is_finite() {
70        return Err(FusionError::ConfigError(
71            "jit observation density_line_avg_1e20_m3 must be finite".to_string(),
72        ));
73    }
74    if !observation.current_ramp_ma_s.is_finite() {
75        return Err(FusionError::ConfigError(
76            "jit observation current_ramp_ma_s must be finite".to_string(),
77        ));
78    }
79    Ok(())
80}
81
82/// Compile-time shape metadata for generated kernels.
83#[derive(Debug, Clone, Copy, PartialEq)]
84pub struct KernelCompileSpec {
85    pub n_state: usize,
86    pub n_control: usize,
87    pub dt_s: f64,
88    pub unroll_factor: usize,
89}
90
91impl KernelCompileSpec {
92    pub fn validated(self) -> FusionResult<Self> {
93        if self.n_state == 0 {
94            return Err(FusionError::ConfigError(
95                "jit kernel n_state must be > 0".to_string(),
96            ));
97        }
98        if self.n_control == 0 {
99            return Err(FusionError::ConfigError(
100                "jit kernel n_control must be > 0".to_string(),
101            ));
102        }
103        if !self.dt_s.is_finite() || self.dt_s < MIN_DT_S {
104            return Err(FusionError::ConfigError(format!(
105                "jit kernel dt_s must be finite and >= {MIN_DT_S}"
106            )));
107        }
108        if self.unroll_factor == 0 {
109            return Err(FusionError::ConfigError(
110                "jit kernel unroll_factor must be > 0".to_string(),
111            ));
112        }
113        Ok(self)
114    }
115}
116
117impl Default for KernelCompileSpec {
118    fn default() -> Self {
119        Self {
120            n_state: 8,
121            n_control: 4,
122            dt_s: 1e-3,
123            unroll_factor: 4,
124        }
125    }
126}
127
128#[derive(Debug, Clone)]
129pub struct CompiledKernel {
130    pub regime: PlasmaRegime,
131    pub generation: u64,
132    pub spec: KernelCompileSpec,
133    nonlinear_gain: f64,
134    control_gain: f64,
135    bias: f64,
136}
137
138impl CompiledKernel {
139    fn from_regime(regime: PlasmaRegime, spec: KernelCompileSpec, generation: u64) -> Self {
140        let (nonlinear_gain, control_gain, bias) = match regime {
141            PlasmaRegime::LMode => (0.22, 0.08, -0.01),
142            PlasmaRegime::HMode => (0.30, 0.12, 0.02),
143            PlasmaRegime::RampUp => (0.26, 0.14, 0.04),
144            PlasmaRegime::RampDown => (0.24, 0.13, -0.04),
145        };
146        Self {
147            regime,
148            generation,
149            spec,
150            nonlinear_gain,
151            control_gain,
152            bias,
153        }
154    }
155
156    /// Execute one reduced control step for this specialized kernel.
157    pub fn step(&self, state: &Array1<f64>, control: &Array1<f64>) -> FusionResult<Array1<f64>> {
158        if state.len() != self.spec.n_state {
159            return Err(FusionError::ConfigError(format!(
160                "jit step state length mismatch: expected {}, got {}",
161                self.spec.n_state,
162                state.len()
163            )));
164        }
165        if control.len() != self.spec.n_control {
166            return Err(FusionError::ConfigError(format!(
167                "jit step control length mismatch: expected {}, got {}",
168                self.spec.n_control,
169                control.len()
170            )));
171        }
172        if state.iter().any(|v| !v.is_finite()) {
173            return Err(FusionError::ConfigError(
174                "jit step state vector must contain only finite values".to_string(),
175            ));
176        }
177        if control.iter().any(|v| !v.is_finite()) {
178            return Err(FusionError::ConfigError(
179                "jit step control vector must contain only finite values".to_string(),
180            ));
181        }
182
183        let control_mean = control.iter().copied().sum::<f64>() / self.spec.n_control as f64;
184        if !control_mean.is_finite() {
185            return Err(FusionError::ConfigError(
186                "jit step control mean became non-finite".to_string(),
187            ));
188        }
189
190        let forcing = self.control_gain * control_mean + self.bias;
191        let dt = self.spec.dt_s;
192        let mut next = Array1::zeros(self.spec.n_state);
193
194        if let (Some(state_slice), Some(next_slice)) = (state.as_slice(), next.as_slice_mut()) {
195            for start in (0..self.spec.n_state).step_by(self.spec.unroll_factor) {
196                let end = (start + self.spec.unroll_factor).min(self.spec.n_state);
197                for idx in start..end {
198                    let x = state_slice[idx];
199                    let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
200                    next_slice[idx] = x + dt * (drift + forcing);
201                }
202            }
203            if next_slice.iter().any(|v| !v.is_finite()) {
204                return Err(FusionError::ConfigError(
205                    "jit step produced non-finite state output".to_string(),
206                ));
207            }
208            return Ok(next);
209        }
210
211        for (idx, out) in next.iter_mut().enumerate() {
212            let x = state[idx];
213            let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
214            *out = x + dt * (drift + forcing);
215        }
216        if next.iter().any(|v| !v.is_finite()) {
217            return Err(FusionError::ConfigError(
218                "jit step produced non-finite state output".to_string(),
219            ));
220        }
221        Ok(next)
222    }
223}
224
225/// Runtime kernel specialization manager with cache + hot-swap semantics.
226#[derive(Debug, Default, Clone)]
227pub struct RuntimeKernelJit {
228    kernels: HashMap<PlasmaRegime, CompiledKernel>,
229    active: Option<PlasmaRegime>,
230    compile_events: u64,
231}
232
233impl RuntimeKernelJit {
234    pub fn new() -> Self {
235        Self::default()
236    }
237
238    /// Compile or reuse a kernel for the requested regime and activate it.
239    pub fn compile_for_regime(
240        &mut self,
241        regime: PlasmaRegime,
242        spec: KernelCompileSpec,
243    ) -> FusionResult<u64> {
244        let spec = spec.validated()?;
245        if let Some(existing) = self.kernels.get(&regime) {
246            if existing.spec == spec {
247                self.active = Some(regime);
248                return Ok(existing.generation);
249            }
250        }
251
252        self.compile_events += 1;
253        let generation = self.compile_events;
254        let compiled = CompiledKernel::from_regime(regime, spec, generation);
255        self.kernels.insert(regime, compiled);
256        self.active = Some(regime);
257        Ok(generation)
258    }
259
260    /// Detect regime, compile if needed, and activate specialized kernel.
261    pub fn refresh_for_observation(
262        &mut self,
263        observation: &RegimeObservation,
264        spec: KernelCompileSpec,
265    ) -> FusionResult<(PlasmaRegime, u64)> {
266        validate_observation(observation)?;
267        let regime = detect_regime(observation);
268        let generation = self.compile_for_regime(regime, spec)?;
269        Ok((regime, generation))
270    }
271
272    pub fn active_regime(&self) -> Option<PlasmaRegime> {
273        self.active
274    }
275
276    pub fn cache_size(&self) -> usize {
277        self.kernels.len()
278    }
279
280    pub fn compile_events(&self) -> u64 {
281        self.compile_events
282    }
283
284    /// Execute one step with the active specialized kernel.
285    pub fn step_active(
286        &self,
287        state: &Array1<f64>,
288        control: &Array1<f64>,
289    ) -> FusionResult<Array1<f64>> {
290        let Some(regime) = self.active else {
291            return Err(FusionError::ConfigError(
292                "jit step_active requires an active regime; compile or refresh first".to_string(),
293            ));
294        };
295        let Some(kernel) = self.kernels.get(&regime) else {
296            return Err(FusionError::ConfigError(
297                "jit step_active active regime has no compiled kernel".to_string(),
298            ));
299        };
300        kernel.step(state, control)
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::s;

    #[test]
    fn test_detect_regime_routing() {
        let ramp_up = RegimeObservation {
            current_ramp_ma_s: 0.4,
            ..RegimeObservation::default()
        };
        let ramp_down = RegimeObservation {
            current_ramp_ma_s: -0.5,
            ..RegimeObservation::default()
        };
        let h_mode = RegimeObservation {
            beta_n: 2.4,
            ..RegimeObservation::default()
        };
        let l_mode = RegimeObservation::default();

        assert_eq!(detect_regime(&ramp_up), PlasmaRegime::RampUp);
        assert_eq!(detect_regime(&ramp_down), PlasmaRegime::RampDown);
        assert_eq!(detect_regime(&h_mode), PlasmaRegime::HMode);
        assert_eq!(detect_regime(&l_mode), PlasmaRegime::LMode);
    }

    #[test]
    fn test_compile_cache_reuses_generation() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let gen1 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        assert_eq!(gen1, gen2);
        assert_eq!(jit.compile_events(), 1);
        assert_eq!(jit.cache_size(), 1);
    }

    #[test]
    fn test_hot_swap_changes_active_regime_and_response() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let state = Array1::from_vec(vec![0.7; spec.n_state]);
        let control = Array1::from_vec(vec![0.2; spec.n_control]);

        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let l_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        jit.compile_for_regime(PlasmaRegime::HMode, spec)
            .expect("valid compile spec");
        let h_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
        let delta_sum = (&h_step - &l_step).iter().map(|v| v.abs()).sum::<f64>();
        assert!(
            delta_sum > 1e-9,
            "Expected regime hot-swap to alter response"
        );
    }

    #[test]
    fn test_refresh_for_observation_compiles_once_per_regime() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();

        let obs_l = RegimeObservation::default();
        let obs_h = RegimeObservation {
            beta_n: 2.6,
            ..RegimeObservation::default()
        };
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");

        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 2);
        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
    }

    #[test]
    fn test_step_active_without_kernel_rejects_missing_active_kernel() {
        let jit = RuntimeKernelJit::new();
        let state = Array1::from_vec(vec![1.0, -2.0, 0.5]);
        let control = Array1::from_vec(vec![0.1, 0.1]);
        assert!(jit.step_active(&state, &control).is_err());
    }

    #[test]
    fn test_compile_for_regime_rejects_invalid_compile_specs() {
        let mut jit = RuntimeKernelJit::new();
        let bad_n_state = KernelCompileSpec {
            n_state: 0,
            ..KernelCompileSpec::default()
        };
        let bad_n_control = KernelCompileSpec {
            n_control: 0,
            ..KernelCompileSpec::default()
        };
        let bad_dt = KernelCompileSpec {
            dt_s: f64::NAN,
            ..KernelCompileSpec::default()
        };
        let bad_unroll = KernelCompileSpec {
            unroll_factor: 0,
            ..KernelCompileSpec::default()
        };

        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_state)
            .is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_control)
            .is_err());
        assert!(jit.compile_for_regime(PlasmaRegime::LMode, bad_dt).is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_unroll)
            .is_err());
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn test_refresh_for_observation_rejects_non_finite_inputs() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let bad = RegimeObservation {
            beta_n: f64::NAN,
            ..RegimeObservation::default()
        };
        assert!(jit.refresh_for_observation(&bad, spec).is_err());
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn test_step_active_rejects_invalid_runtime_vectors() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");

        let bad_state = Array1::from_vec(vec![0.0; spec.n_state - 1]);
        let good_control = Array1::from_vec(vec![0.1; spec.n_control]);
        assert!(jit.step_active(&bad_state, &good_control).is_err());

        let good_state = Array1::from_vec(vec![0.0; spec.n_state]);
        let bad_control = Array1::from_vec(vec![0.1; spec.n_control - 1]);
        assert!(jit.step_active(&good_state, &bad_control).is_err());

        let nan_state = Array1::from_vec(vec![f64::NAN; spec.n_state]);
        assert!(jit.step_active(&nan_state, &good_control).is_err());
    }

    #[test]
    fn test_step_active_rejects_non_finite_or_overflowing_control() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let state = Array1::from_vec(vec![0.0; spec.n_state]);

        let mut inf_control = Array1::from_vec(vec![0.1; spec.n_control]);
        inf_control[0] = f64::INFINITY;
        assert!(jit.step_active(&state, &inf_control).is_err());

        // Every element is finite, but their sum overflows to infinity.
        let huge_control = Array1::from_vec(vec![f64::MAX; spec.n_control]);
        assert!(jit.step_active(&state, &huge_control).is_err());
    }

    #[test]
    fn test_step_matches_across_memory_layouts() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::RampUp, spec)
            .expect("valid compile spec");
        let control = Array1::from_vec(vec![0.3; spec.n_control]);

        // Every other element of a longer array is a valid-length state without contiguous
        // storage, which forces the non-slice path in `CompiledKernel::step`.
        let wide = Array1::from_vec(
            (0..2 * spec.n_state)
                .map(|i| i as f64 * 0.25 - 1.5)
                .collect(),
        );
        let strided = wide.slice_move(s![..;2]);
        assert!(strided.as_slice().is_none());
        let contiguous = Array1::from_vec(strided.to_vec());
        assert!(contiguous.as_slice().is_some());

        let from_strided = jit
            .step_active(&strided, &control)
            .expect("valid strided input");
        let from_contiguous = jit
            .step_active(&contiguous, &control)
            .expect("valid contiguous input");
        assert_eq!(from_strided, from_contiguous);

        // Negative stride: also not a standard-layout slice.
        let reversed = contiguous.clone().slice_move(s![..;-1]);
        assert!(reversed.as_slice().is_none());
        let reversed_contiguous = Array1::from_vec(reversed.to_vec());
        assert_eq!(
            jit.step_active(&reversed, &control)
                .expect("valid reversed input"),
            jit.step_active(&reversed_contiguous, &control)
                .expect("valid contiguous input")
        );
    }

    #[test]
    fn test_unroll_factor_does_not_change_results_and_spec_change_recompiles() {
        let mut jit = RuntimeKernelJit::new();
        let state = Array1::from_vec(vec![0.9, -0.4, 0.1, 2.0, -1.2, 0.0, 0.5, 0.33]);
        let control = Array1::from_vec(vec![0.2, -0.1, 0.05, 0.4]);

        let unroll_1 = KernelCompileSpec {
            unroll_factor: 1,
            ..KernelCompileSpec::default()
        };
        // 3 does not divide n_state = 8, so the final block is partial.
        let unroll_3 = KernelCompileSpec {
            unroll_factor: 3,
            ..KernelCompileSpec::default()
        };

        let gen1 = jit
            .compile_for_regime(PlasmaRegime::HMode, unroll_1)
            .expect("valid compile spec");
        let step1 = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::HMode, unroll_3)
            .expect("valid compile spec");
        let step3 = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        // A changed spec for a cached regime recompiles in place instead of adding an entry.
        assert_ne!(gen1, gen2);
        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 1);
        assert_eq!(step1, step3);
    }

    #[test]
    fn test_failed_refresh_keeps_previous_active_kernel() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.refresh_for_observation(&RegimeObservation::default(), spec)
            .expect("valid compile spec");

        let bad_obs = RegimeObservation {
            q95: f64::INFINITY,
            ..RegimeObservation::default()
        };
        assert!(jit.refresh_for_observation(&bad_obs, spec).is_err());

        let ramping = RegimeObservation {
            current_ramp_ma_s: 1.0,
            ..RegimeObservation::default()
        };
        let bad_spec = KernelCompileSpec { dt_s: 0.0, ..spec };
        assert!(jit.refresh_for_observation(&ramping, bad_spec).is_err());

        assert_eq!(jit.active_regime(), Some(PlasmaRegime::LMode));
        assert_eq!(jit.compile_events(), 1);
        assert_eq!(jit.cache_size(), 1);
    }
}