1use fusion_types::error::{FusionError, FusionResult};
8use ndarray::Array1;
9use std::collections::HashMap;
10
/// Smallest integration step (`dt_s`) that `KernelCompileSpec::validated` accepts.
const MIN_DT_S: f64 = 1.0e-9;
12
/// Operating regime used to select a specialised kernel in [`RuntimeKernelJit`].
///
/// Values are produced by [`detect_regime`] and serve as cache keys, which is why
/// the type is `Copy`, `Eq` and `Hash`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PlasmaRegime {
    /// Fallback classification: no significant current ramp and no H-mode criterion met.
    LMode,
    /// Selected when one of the beta_n / q95 / density thresholds in `detect_regime` is met.
    HMode,
    /// Current ramp rate above the positive ramp threshold.
    RampUp,
    /// Current ramp rate below the negative ramp threshold.
    RampDown,
}
21
/// Snapshot of plasma diagnostics that [`detect_regime`] classifies.
///
/// Callers that need finite values should validate before classifying;
/// `detect_regime` itself does not reject NaN or infinite fields.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RegimeObservation {
    /// Beta value compared against the 2.0 / 2.2 H-mode thresholds in `detect_regime`.
    pub beta_n: f64,
    /// q95 value; `<= 3.7` together with `beta_n >= 2.0` counts toward H-mode.
    pub q95: f64,
    /// Line-averaged density; the field name indicates units of 1e20 m^-3.
    pub density_line_avg_1e20_m3: f64,
    /// Current ramp rate (the field name indicates MA/s); positive means ramping up.
    pub current_ramp_ma_s: f64,
}

impl Default for RegimeObservation {
    /// An observation with no current ramp that `detect_regime` classifies as `LMode`.
    fn default() -> Self {
        let beta_n = 1.5;
        let q95 = 4.5;
        let density_line_avg_1e20_m3 = 0.8;
        let current_ramp_ma_s = 0.0;
        Self {
            beta_n,
            q95,
            density_line_avg_1e20_m3,
            current_ramp_ma_s,
        }
    }
}
41
42pub fn detect_regime(observation: &RegimeObservation) -> PlasmaRegime {
44 if observation.current_ramp_ma_s > 0.2 {
45 PlasmaRegime::RampUp
46 } else if observation.current_ramp_ma_s < -0.2 {
47 PlasmaRegime::RampDown
48 } else if observation.beta_n >= 2.2
49 || (observation.beta_n >= 2.0 && observation.q95 <= 3.7)
50 || observation.density_line_avg_1e20_m3 >= 1.0
51 {
52 PlasmaRegime::HMode
53 } else {
54 PlasmaRegime::LMode
55 }
56}
57
58fn validate_observation(observation: &RegimeObservation) -> FusionResult<()> {
59 if !observation.beta_n.is_finite() {
60 return Err(FusionError::ConfigError(
61 "jit observation beta_n must be finite".to_string(),
62 ));
63 }
64 if !observation.q95.is_finite() {
65 return Err(FusionError::ConfigError(
66 "jit observation q95 must be finite".to_string(),
67 ));
68 }
69 if !observation.density_line_avg_1e20_m3.is_finite() {
70 return Err(FusionError::ConfigError(
71 "jit observation density_line_avg_1e20_m3 must be finite".to_string(),
72 ));
73 }
74 if !observation.current_ramp_ma_s.is_finite() {
75 return Err(FusionError::ConfigError(
76 "jit observation current_ramp_ma_s must be finite".to_string(),
77 ));
78 }
79 Ok(())
80}
81
82#[derive(Debug, Clone, Copy, PartialEq)]
84pub struct KernelCompileSpec {
85 pub n_state: usize,
86 pub n_control: usize,
87 pub dt_s: f64,
88 pub unroll_factor: usize,
89}
90
91impl KernelCompileSpec {
92 pub fn validated(self) -> FusionResult<Self> {
93 if self.n_state == 0 {
94 return Err(FusionError::ConfigError(
95 "jit kernel n_state must be > 0".to_string(),
96 ));
97 }
98 if self.n_control == 0 {
99 return Err(FusionError::ConfigError(
100 "jit kernel n_control must be > 0".to_string(),
101 ));
102 }
103 if !self.dt_s.is_finite() || self.dt_s < MIN_DT_S {
104 return Err(FusionError::ConfigError(format!(
105 "jit kernel dt_s must be finite and >= {MIN_DT_S}"
106 )));
107 }
108 if self.unroll_factor == 0 {
109 return Err(FusionError::ConfigError(
110 "jit kernel unroll_factor must be > 0".to_string(),
111 ));
112 }
113 Ok(self)
114 }
115}
116
117impl Default for KernelCompileSpec {
118 fn default() -> Self {
119 Self {
120 n_state: 8,
121 n_control: 4,
122 dt_s: 1e-3,
123 unroll_factor: 4,
124 }
125 }
126}
127
128#[derive(Debug, Clone)]
129pub struct CompiledKernel {
130 pub regime: PlasmaRegime,
131 pub generation: u64,
132 pub spec: KernelCompileSpec,
133 nonlinear_gain: f64,
134 control_gain: f64,
135 bias: f64,
136}
137
138impl CompiledKernel {
139 fn from_regime(regime: PlasmaRegime, spec: KernelCompileSpec, generation: u64) -> Self {
140 let (nonlinear_gain, control_gain, bias) = match regime {
141 PlasmaRegime::LMode => (0.22, 0.08, -0.01),
142 PlasmaRegime::HMode => (0.30, 0.12, 0.02),
143 PlasmaRegime::RampUp => (0.26, 0.14, 0.04),
144 PlasmaRegime::RampDown => (0.24, 0.13, -0.04),
145 };
146 Self {
147 regime,
148 generation,
149 spec,
150 nonlinear_gain,
151 control_gain,
152 bias,
153 }
154 }
155
156 pub fn step(&self, state: &Array1<f64>, control: &Array1<f64>) -> FusionResult<Array1<f64>> {
158 if state.len() != self.spec.n_state {
159 return Err(FusionError::ConfigError(format!(
160 "jit step state length mismatch: expected {}, got {}",
161 self.spec.n_state,
162 state.len()
163 )));
164 }
165 if control.len() != self.spec.n_control {
166 return Err(FusionError::ConfigError(format!(
167 "jit step control length mismatch: expected {}, got {}",
168 self.spec.n_control,
169 control.len()
170 )));
171 }
172 if state.iter().any(|v| !v.is_finite()) {
173 return Err(FusionError::ConfigError(
174 "jit step state vector must contain only finite values".to_string(),
175 ));
176 }
177 if control.iter().any(|v| !v.is_finite()) {
178 return Err(FusionError::ConfigError(
179 "jit step control vector must contain only finite values".to_string(),
180 ));
181 }
182
183 let control_mean = control.iter().copied().sum::<f64>() / self.spec.n_control as f64;
184 if !control_mean.is_finite() {
185 return Err(FusionError::ConfigError(
186 "jit step control mean became non-finite".to_string(),
187 ));
188 }
189
190 let forcing = self.control_gain * control_mean + self.bias;
191 let dt = self.spec.dt_s;
192 let mut next = Array1::zeros(self.spec.n_state);
193
194 if let (Some(state_slice), Some(next_slice)) = (state.as_slice(), next.as_slice_mut()) {
195 for start in (0..self.spec.n_state).step_by(self.spec.unroll_factor) {
196 let end = (start + self.spec.unroll_factor).min(self.spec.n_state);
197 for idx in start..end {
198 let x = state_slice[idx];
199 let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
200 next_slice[idx] = x + dt * (drift + forcing);
201 }
202 }
203 if next_slice.iter().any(|v| !v.is_finite()) {
204 return Err(FusionError::ConfigError(
205 "jit step produced non-finite state output".to_string(),
206 ));
207 }
208 return Ok(next);
209 }
210
211 for (idx, out) in next.iter_mut().enumerate() {
212 let x = state[idx];
213 let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
214 *out = x + dt * (drift + forcing);
215 }
216 if next.iter().any(|v| !v.is_finite()) {
217 return Err(FusionError::ConfigError(
218 "jit step produced non-finite state output".to_string(),
219 ));
220 }
221 Ok(next)
222 }
223}
224
225#[derive(Debug, Default, Clone)]
227pub struct RuntimeKernelJit {
228 kernels: HashMap<PlasmaRegime, CompiledKernel>,
229 active: Option<PlasmaRegime>,
230 compile_events: u64,
231}
232
233impl RuntimeKernelJit {
234 pub fn new() -> Self {
235 Self::default()
236 }
237
238 pub fn compile_for_regime(
240 &mut self,
241 regime: PlasmaRegime,
242 spec: KernelCompileSpec,
243 ) -> FusionResult<u64> {
244 let spec = spec.validated()?;
245 if let Some(existing) = self.kernels.get(®ime) {
246 if existing.spec == spec {
247 self.active = Some(regime);
248 return Ok(existing.generation);
249 }
250 }
251
252 self.compile_events += 1;
253 let generation = self.compile_events;
254 let compiled = CompiledKernel::from_regime(regime, spec, generation);
255 self.kernels.insert(regime, compiled);
256 self.active = Some(regime);
257 Ok(generation)
258 }
259
260 pub fn refresh_for_observation(
262 &mut self,
263 observation: &RegimeObservation,
264 spec: KernelCompileSpec,
265 ) -> FusionResult<(PlasmaRegime, u64)> {
266 validate_observation(observation)?;
267 let regime = detect_regime(observation);
268 let generation = self.compile_for_regime(regime, spec)?;
269 Ok((regime, generation))
270 }
271
272 pub fn active_regime(&self) -> Option<PlasmaRegime> {
273 self.active
274 }
275
276 pub fn cache_size(&self) -> usize {
277 self.kernels.len()
278 }
279
280 pub fn compile_events(&self) -> u64 {
281 self.compile_events
282 }
283
284 pub fn step_active(
286 &self,
287 state: &Array1<f64>,
288 control: &Array1<f64>,
289 ) -> FusionResult<Array1<f64>> {
290 let Some(regime) = self.active else {
291 return Err(FusionError::ConfigError(
292 "jit step_active requires an active regime; compile or refresh first".to_string(),
293 ));
294 };
295 let Some(kernel) = self.kernels.get(®ime) else {
296 return Err(FusionError::ConfigError(
297 "jit step_active active regime has no compiled kernel".to_string(),
298 ));
299 };
300 kernel.step(state, control)
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    // Regime routing: ramps take precedence, then the H-mode criteria, else L-mode.
    #[test]
    fn test_detect_regime_routing() {
        let ramp_up = RegimeObservation {
            current_ramp_ma_s: 0.4,
            ..RegimeObservation::default()
        };
        let ramp_down = RegimeObservation {
            current_ramp_ma_s: -0.5,
            ..RegimeObservation::default()
        };
        let h_mode = RegimeObservation {
            beta_n: 2.4,
            ..RegimeObservation::default()
        };
        let l_mode = RegimeObservation::default();

        assert_eq!(detect_regime(&ramp_up), PlasmaRegime::RampUp);
        assert_eq!(detect_regime(&ramp_down), PlasmaRegime::RampDown);
        assert_eq!(detect_regime(&h_mode), PlasmaRegime::HMode);
        assert_eq!(detect_regime(&l_mode), PlasmaRegime::LMode);
    }

    // Compiling the same regime with an equal spec must hit the cache.
    #[test]
    fn test_compile_cache_reuses_generation() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let gen1 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        assert_eq!(gen1, gen2);
        assert_eq!(jit.compile_events(), 1);
        assert_eq!(jit.cache_size(), 1);
    }

    // Switching the active regime must switch the kernel used by step_active.
    #[test]
    fn test_hot_swap_changes_active_regime_and_response() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let state = Array1::from_vec(vec![0.7; spec.n_state]);
        let control = Array1::from_vec(vec![0.2; spec.n_control]);

        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let l_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        jit.compile_for_regime(PlasmaRegime::HMode, spec)
            .expect("valid compile spec");
        let h_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
        let delta_sum = (&h_step - &l_step).iter().map(|v| v.abs()).sum::<f64>();
        assert!(
            delta_sum > 1e-9,
            "Expected regime hot-swap to alter response"
        );
    }

    // Repeated observations of the same regime must not recompile.
    #[test]
    fn test_refresh_for_observation_compiles_once_per_regime() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();

        let obs_l = RegimeObservation::default();
        let obs_h = RegimeObservation {
            beta_n: 2.6,
            ..RegimeObservation::default()
        };
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");

        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 2);
        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
    }

    #[test]
    fn test_step_active_without_kernel_rejects_missing_active_kernel() {
        let jit = RuntimeKernelJit::new();
        let state = Array1::from_vec(vec![1.0, -2.0, 0.5]);
        let control = Array1::from_vec(vec![0.1, 0.1]);
        assert!(jit.step_active(&state, &control).is_err());
    }

    // Invalid specs must be rejected before any cache or counter mutation.
    #[test]
    fn test_compile_for_regime_rejects_invalid_compile_specs() {
        let mut jit = RuntimeKernelJit::new();
        let bad_n_state = KernelCompileSpec {
            n_state: 0,
            ..KernelCompileSpec::default()
        };
        let bad_n_control = KernelCompileSpec {
            n_control: 0,
            ..KernelCompileSpec::default()
        };
        let bad_dt = KernelCompileSpec {
            dt_s: f64::NAN,
            ..KernelCompileSpec::default()
        };
        let bad_unroll = KernelCompileSpec {
            unroll_factor: 0,
            ..KernelCompileSpec::default()
        };

        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_state)
            .is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_control)
            .is_err());
        assert!(jit.compile_for_regime(PlasmaRegime::LMode, bad_dt).is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_unroll)
            .is_err());
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn test_refresh_for_observation_rejects_non_finite_inputs() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let bad = RegimeObservation {
            beta_n: f64::NAN,
            ..RegimeObservation::default()
        };
        assert!(jit.refresh_for_observation(&bad, spec).is_err());
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn test_step_active_rejects_invalid_runtime_vectors() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");

        let bad_state = Array1::from_vec(vec![0.0; spec.n_state - 1]);
        let good_control = Array1::from_vec(vec![0.1; spec.n_control]);
        assert!(jit.step_active(&bad_state, &good_control).is_err());

        let good_state = Array1::from_vec(vec![0.0; spec.n_state]);
        let bad_control = Array1::from_vec(vec![0.1; spec.n_control - 1]);
        assert!(jit.step_active(&good_state, &bad_control).is_err());

        let nan_state = Array1::from_vec(vec![f64::NAN; spec.n_state]);
        assert!(jit.step_active(&nan_state, &good_control).is_err());
    }

    // Recompiling a cached regime with a different spec must build a new
    // generation in place rather than growing the cache.
    #[test]
    fn test_compile_with_changed_spec_replaces_cached_kernel() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let changed = KernelCompileSpec {
            dt_s: 2e-3,
            ..spec
        };

        let gen1 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::LMode, changed)
            .expect("valid compile spec");

        assert_ne!(gen1, gen2);
        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 1);
        assert_eq!(jit.active_regime(), Some(PlasmaRegime::LMode));
    }

    // A strided (non-standard-layout) state must give exactly the same result
    // as the same logical values stored contiguously.
    #[test]
    fn test_step_active_accepts_non_contiguous_state() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::HMode, spec)
            .expect("valid compile spec");
        let control = Array1::from_vec(vec![0.2; spec.n_control]);

        let values: Vec<f64> = (0..spec.n_state).map(|i| 0.1 * i as f64 - 0.3).collect();
        let mut strided_state = Array1::from_vec(values.clone());
        strided_state.invert_axis(ndarray::Axis(0));
        assert!(strided_state.as_slice().is_none());

        let contiguous_state = Array1::from_vec(values.iter().rev().copied().collect());

        let from_strided = jit
            .step_active(&strided_state, &control)
            .expect("valid active-kernel step inputs");
        let from_contiguous = jit
            .step_active(&contiguous_state, &control)
            .expect("valid active-kernel step inputs");
        assert_eq!(from_strided, from_contiguous);
    }
}