1use fusion_types::error::{FusionError, FusionResult};
13use ndarray::Array1;
14use std::collections::HashMap;
15
/// Smallest integration time step accepted by `KernelCompileSpec::validated`;
/// any `dt_s` below this value (or non-finite) is rejected.
const MIN_DT_S: f64 = 1.0e-9;
17
/// Plasma operating regime, as classified by [`detect_regime`].
///
/// Used as the cache key in `RuntimeKernelJit`; each variant selects its own
/// fixed gain set when a `CompiledKernel` is built.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PlasmaRegime {
    /// Fallback regime when neither a current ramp nor an H-mode condition is detected.
    LMode,
    /// Selected (absent a current ramp) when `beta_n >= 2.2`, when `beta_n >= 2.0`
    /// with `q95 <= 3.7`, or when `density_line_avg_1e20_m3 >= 1.0`.
    HMode,
    /// Selected when `current_ramp_ma_s > 0.2`; takes precedence over H/L checks.
    RampUp,
    /// Selected when `current_ramp_ma_s < -0.2`; takes precedence over H/L checks.
    RampDown,
}
26
/// Scalar plasma diagnostics consumed by [`detect_regime`].
///
/// Non-finite fields are rejected by `RuntimeKernelJit::refresh_for_observation`,
/// but not by `detect_regime` itself.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RegimeObservation {
    /// Presumably normalized beta; `detect_regime` compares it against 2.0 and 2.2.
    pub beta_n: f64,
    /// Safety-factor value; `<= 3.7` enables the moderate-beta H-mode condition.
    pub q95: f64,
    /// Line-averaged density; the field-name suffix suggests units of 1e20 m^-3.
    pub density_line_avg_1e20_m3: f64,
    /// Current ramp rate (suffix suggests MA/s); > 0.2 is ramp-up, < -0.2 ramp-down.
    pub current_ramp_ma_s: f64,
}
35
36impl Default for RegimeObservation {
37 fn default() -> Self {
38 Self {
39 beta_n: 1.5,
40 q95: 4.5,
41 density_line_avg_1e20_m3: 0.8,
42 current_ramp_ma_s: 0.0,
43 }
44 }
45}
46
47pub fn detect_regime(observation: &RegimeObservation) -> PlasmaRegime {
49 if observation.current_ramp_ma_s > 0.2 {
50 PlasmaRegime::RampUp
51 } else if observation.current_ramp_ma_s < -0.2 {
52 PlasmaRegime::RampDown
53 } else if observation.beta_n >= 2.2
54 || (observation.beta_n >= 2.0 && observation.q95 <= 3.7)
55 || observation.density_line_avg_1e20_m3 >= 1.0
56 {
57 PlasmaRegime::HMode
58 } else {
59 PlasmaRegime::LMode
60 }
61}
62
63fn validate_observation(observation: &RegimeObservation) -> FusionResult<()> {
64 if !observation.beta_n.is_finite() {
65 return Err(FusionError::ConfigError(
66 "jit observation beta_n must be finite".to_string(),
67 ));
68 }
69 if !observation.q95.is_finite() {
70 return Err(FusionError::ConfigError(
71 "jit observation q95 must be finite".to_string(),
72 ));
73 }
74 if !observation.density_line_avg_1e20_m3.is_finite() {
75 return Err(FusionError::ConfigError(
76 "jit observation density_line_avg_1e20_m3 must be finite".to_string(),
77 ));
78 }
79 if !observation.current_ramp_ma_s.is_finite() {
80 return Err(FusionError::ConfigError(
81 "jit observation current_ramp_ma_s must be finite".to_string(),
82 ));
83 }
84 Ok(())
85}
86
/// Shape and integration parameters for a compiled regime kernel.
///
/// Run it through [`KernelCompileSpec::validated`] before use. `RuntimeKernelJit`
/// compares specs with `==` to decide whether a cached kernel can be reused.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct KernelCompileSpec {
    /// Required length of the state vector passed to a kernel step (must be > 0).
    pub n_state: usize,
    /// Required length of the control vector passed to a kernel step (must be > 0).
    pub n_control: usize,
    /// Explicit-Euler step size (seconds, per the `_s` suffix); finite and >= `MIN_DT_S`.
    pub dt_s: f64,
    /// Unrolling hint; must be > 0 and is part of the cache key. It does not
    /// change the numerical result of a step.
    pub unroll_factor: usize,
}
95
96impl KernelCompileSpec {
97 pub fn validated(self) -> FusionResult<Self> {
98 if self.n_state == 0 {
99 return Err(FusionError::ConfigError(
100 "jit kernel n_state must be > 0".to_string(),
101 ));
102 }
103 if self.n_control == 0 {
104 return Err(FusionError::ConfigError(
105 "jit kernel n_control must be > 0".to_string(),
106 ));
107 }
108 if !self.dt_s.is_finite() || self.dt_s < MIN_DT_S {
109 return Err(FusionError::ConfigError(format!(
110 "jit kernel dt_s must be finite and >= {MIN_DT_S}"
111 )));
112 }
113 if self.unroll_factor == 0 {
114 return Err(FusionError::ConfigError(
115 "jit kernel unroll_factor must be > 0".to_string(),
116 ));
117 }
118 Ok(self)
119 }
120}
121
122impl Default for KernelCompileSpec {
123 fn default() -> Self {
124 Self {
125 n_state: 8,
126 n_control: 4,
127 dt_s: 1e-3,
128 unroll_factor: 4,
129 }
130 }
131}
132
133#[derive(Debug, Clone)]
134pub struct CompiledKernel {
135 pub regime: PlasmaRegime,
136 pub generation: u64,
137 pub spec: KernelCompileSpec,
138 nonlinear_gain: f64,
139 control_gain: f64,
140 bias: f64,
141}
142
143impl CompiledKernel {
144 fn from_regime(regime: PlasmaRegime, spec: KernelCompileSpec, generation: u64) -> Self {
145 let (nonlinear_gain, control_gain, bias) = match regime {
146 PlasmaRegime::LMode => (0.22, 0.08, -0.01),
147 PlasmaRegime::HMode => (0.30, 0.12, 0.02),
148 PlasmaRegime::RampUp => (0.26, 0.14, 0.04),
149 PlasmaRegime::RampDown => (0.24, 0.13, -0.04),
150 };
151 Self {
152 regime,
153 generation,
154 spec,
155 nonlinear_gain,
156 control_gain,
157 bias,
158 }
159 }
160
161 pub fn step(&self, state: &Array1<f64>, control: &Array1<f64>) -> FusionResult<Array1<f64>> {
163 if state.len() != self.spec.n_state {
164 return Err(FusionError::ConfigError(format!(
165 "jit step state length mismatch: expected {}, got {}",
166 self.spec.n_state,
167 state.len()
168 )));
169 }
170 if control.len() != self.spec.n_control {
171 return Err(FusionError::ConfigError(format!(
172 "jit step control length mismatch: expected {}, got {}",
173 self.spec.n_control,
174 control.len()
175 )));
176 }
177 if state.iter().any(|v| !v.is_finite()) {
178 return Err(FusionError::ConfigError(
179 "jit step state vector must contain only finite values".to_string(),
180 ));
181 }
182 if control.iter().any(|v| !v.is_finite()) {
183 return Err(FusionError::ConfigError(
184 "jit step control vector must contain only finite values".to_string(),
185 ));
186 }
187
188 let control_mean = control.iter().copied().sum::<f64>() / self.spec.n_control as f64;
189 if !control_mean.is_finite() {
190 return Err(FusionError::ConfigError(
191 "jit step control mean became non-finite".to_string(),
192 ));
193 }
194
195 let forcing = self.control_gain * control_mean + self.bias;
196 let dt = self.spec.dt_s;
197 let mut next = Array1::zeros(self.spec.n_state);
198
199 if let (Some(state_slice), Some(next_slice)) = (state.as_slice(), next.as_slice_mut()) {
200 for start in (0..self.spec.n_state).step_by(self.spec.unroll_factor) {
201 let end = (start + self.spec.unroll_factor).min(self.spec.n_state);
202 for idx in start..end {
203 let x = state_slice[idx];
204 let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
205 next_slice[idx] = x + dt * (drift + forcing);
206 }
207 }
208 if next_slice.iter().any(|v| !v.is_finite()) {
209 return Err(FusionError::ConfigError(
210 "jit step produced non-finite state output".to_string(),
211 ));
212 }
213 return Ok(next);
214 }
215
216 for (idx, out) in next.iter_mut().enumerate() {
217 let x = state[idx];
218 let drift = (self.nonlinear_gain * x).tanh() - 0.10 * x;
219 *out = x + dt * (drift + forcing);
220 }
221 if next.iter().any(|v| !v.is_finite()) {
222 return Err(FusionError::ConfigError(
223 "jit step produced non-finite state output".to_string(),
224 ));
225 }
226 Ok(next)
227 }
228}
229
230#[derive(Debug, Default, Clone)]
232pub struct RuntimeKernelJit {
233 kernels: HashMap<PlasmaRegime, CompiledKernel>,
234 active: Option<PlasmaRegime>,
235 compile_events: u64,
236}
237
238impl RuntimeKernelJit {
239 pub fn new() -> Self {
240 Self::default()
241 }
242
243 pub fn compile_for_regime(
245 &mut self,
246 regime: PlasmaRegime,
247 spec: KernelCompileSpec,
248 ) -> FusionResult<u64> {
249 let spec = spec.validated()?;
250 if let Some(existing) = self.kernels.get(®ime) {
251 if existing.spec == spec {
252 self.active = Some(regime);
253 return Ok(existing.generation);
254 }
255 }
256
257 self.compile_events += 1;
258 let generation = self.compile_events;
259 let compiled = CompiledKernel::from_regime(regime, spec, generation);
260 self.kernels.insert(regime, compiled);
261 self.active = Some(regime);
262 Ok(generation)
263 }
264
265 pub fn refresh_for_observation(
267 &mut self,
268 observation: &RegimeObservation,
269 spec: KernelCompileSpec,
270 ) -> FusionResult<(PlasmaRegime, u64)> {
271 validate_observation(observation)?;
272 let regime = detect_regime(observation);
273 let generation = self.compile_for_regime(regime, spec)?;
274 Ok((regime, generation))
275 }
276
277 pub fn active_regime(&self) -> Option<PlasmaRegime> {
278 self.active
279 }
280
281 pub fn cache_size(&self) -> usize {
282 self.kernels.len()
283 }
284
285 pub fn compile_events(&self) -> u64 {
286 self.compile_events
287 }
288
289 pub fn step_active(
291 &self,
292 state: &Array1<f64>,
293 control: &Array1<f64>,
294 ) -> FusionResult<Array1<f64>> {
295 let Some(regime) = self.active else {
296 return Err(FusionError::ConfigError(
297 "jit step_active requires an active regime; compile or refresh first".to_string(),
298 ));
299 };
300 let Some(kernel) = self.kernels.get(®ime) else {
301 return Err(FusionError::ConfigError(
302 "jit step_active active regime has no compiled kernel".to_string(),
303 ));
304 };
305 kernel.step(state, control)
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Axis;

    #[test]
    fn test_detect_regime_routing() {
        let ramp_up = RegimeObservation {
            current_ramp_ma_s: 0.4,
            ..RegimeObservation::default()
        };
        let ramp_down = RegimeObservation {
            current_ramp_ma_s: -0.5,
            ..RegimeObservation::default()
        };
        let h_mode = RegimeObservation {
            beta_n: 2.4,
            ..RegimeObservation::default()
        };
        let l_mode = RegimeObservation::default();

        assert_eq!(detect_regime(&ramp_up), PlasmaRegime::RampUp);
        assert_eq!(detect_regime(&ramp_down), PlasmaRegime::RampDown);
        assert_eq!(detect_regime(&h_mode), PlasmaRegime::HMode);
        assert_eq!(detect_regime(&l_mode), PlasmaRegime::LMode);
    }

    #[test]
    fn test_detect_regime_threshold_boundaries() {
        let base = RegimeObservation::default();
        let classify = |observation: RegimeObservation| detect_regime(&observation);

        // Ramp thresholds are strict: exactly +/-0.2 is not treated as a ramp.
        assert_eq!(
            classify(RegimeObservation { current_ramp_ma_s: 0.2, ..base }),
            PlasmaRegime::LMode
        );
        assert_eq!(
            classify(RegimeObservation { current_ramp_ma_s: -0.2, ..base }),
            PlasmaRegime::LMode
        );
        // Ramps take precedence over H-mode conditions.
        assert_eq!(
            classify(RegimeObservation {
                beta_n: 3.0,
                current_ramp_ma_s: 0.3,
                ..base
            }),
            PlasmaRegime::RampUp
        );
        // Moderate beta_n only counts together with low q95.
        assert_eq!(
            classify(RegimeObservation {
                beta_n: 2.0,
                q95: 3.7,
                ..base
            }),
            PlasmaRegime::HMode
        );
        assert_eq!(
            classify(RegimeObservation {
                beta_n: 2.0,
                q95: 3.8,
                ..base
            }),
            PlasmaRegime::LMode
        );
        // High density alone is sufficient.
        assert_eq!(
            classify(RegimeObservation {
                density_line_avg_1e20_m3: 1.0,
                ..base
            }),
            PlasmaRegime::HMode
        );
    }

    #[test]
    fn test_compile_cache_reuses_generation() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let gen1 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        assert_eq!(gen1, gen2);
        assert_eq!(jit.compile_events(), 1);
        assert_eq!(jit.cache_size(), 1);
    }

    #[test]
    fn test_compile_for_regime_recompiles_when_spec_changes() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let finer = KernelCompileSpec { dt_s: 5e-4, ..spec };

        let gen1 = jit
            .compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let gen2 = jit
            .compile_for_regime(PlasmaRegime::LMode, finer)
            .expect("valid compile spec");
        let gen3 = jit
            .compile_for_regime(PlasmaRegime::LMode, finer)
            .expect("valid compile spec");

        assert_ne!(gen1, gen2, "a changed spec must produce a new generation");
        assert_eq!(gen2, gen3, "the replacement kernel must then be reused");
        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 1);
    }

    #[test]
    fn test_hot_swap_changes_active_regime_and_response() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let state = Array1::from_vec(vec![0.7; spec.n_state]);
        let control = Array1::from_vec(vec![0.2; spec.n_control]);

        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");
        let l_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        jit.compile_for_regime(PlasmaRegime::HMode, spec)
            .expect("valid compile spec");
        let h_step = jit
            .step_active(&state, &control)
            .expect("valid active-kernel step inputs");

        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
        let delta_sum = (&h_step - &l_step).iter().map(|v| v.abs()).sum::<f64>();
        assert!(
            delta_sum > 1e-9,
            "Expected regime hot-swap to alter response"
        );
    }

    #[test]
    fn test_step_active_handles_non_contiguous_state_layout() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::HMode, spec)
            .expect("valid compile spec");

        let values: Vec<f64> = (0..spec.n_state).map(|i| 0.25 * i as f64 - 0.6).collect();
        // Inverting the axis gives a negative stride, i.e. a non-standard layout.
        let mut reversed_strided = Array1::from_vec(values.clone());
        reversed_strided.invert_axis(Axis(0));
        assert!(
            reversed_strided.as_slice().is_none(),
            "inverted array should not expose a standard-layout slice"
        );
        let reversed_contiguous = Array1::from_vec(values.iter().rev().copied().collect());
        let control = Array1::from_vec(vec![0.3; spec.n_control]);

        let strided_out = jit
            .step_active(&reversed_strided, &control)
            .expect("valid active-kernel step inputs");
        let contiguous_out = jit
            .step_active(&reversed_contiguous, &control)
            .expect("valid active-kernel step inputs");
        assert_eq!(strided_out, contiguous_out);
    }

    #[test]
    fn test_refresh_for_observation_compiles_once_per_regime() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();

        let obs_l = RegimeObservation::default();
        let obs_h = RegimeObservation {
            beta_n: 2.6,
            ..RegimeObservation::default()
        };
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_l, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");
        jit.refresh_for_observation(&obs_h, spec)
            .expect("valid compile spec");

        assert_eq!(jit.compile_events(), 2);
        assert_eq!(jit.cache_size(), 2);
        assert_eq!(jit.active_regime(), Some(PlasmaRegime::HMode));
    }

    #[test]
    fn test_step_active_without_kernel_rejects_missing_active_kernel() {
        let jit = RuntimeKernelJit::new();
        let state = Array1::from_vec(vec![1.0, -2.0, 0.5]);
        let control = Array1::from_vec(vec![0.1, 0.1]);
        assert!(jit.step_active(&state, &control).is_err());
    }

    #[test]
    fn test_compile_for_regime_rejects_invalid_compile_specs() {
        let mut jit = RuntimeKernelJit::new();
        let bad_n_state = KernelCompileSpec {
            n_state: 0,
            ..KernelCompileSpec::default()
        };
        let bad_n_control = KernelCompileSpec {
            n_control: 0,
            ..KernelCompileSpec::default()
        };
        let bad_dt = KernelCompileSpec {
            dt_s: f64::NAN,
            ..KernelCompileSpec::default()
        };
        let bad_unroll = KernelCompileSpec {
            unroll_factor: 0,
            ..KernelCompileSpec::default()
        };

        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_state)
            .is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_n_control)
            .is_err());
        assert!(jit.compile_for_regime(PlasmaRegime::LMode, bad_dt).is_err());
        assert!(jit
            .compile_for_regime(PlasmaRegime::LMode, bad_unroll)
            .is_err());
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn test_refresh_for_observation_rejects_non_finite_inputs() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        let base = RegimeObservation::default();
        let bad_observations = [
            RegimeObservation {
                beta_n: f64::NAN,
                ..base
            },
            RegimeObservation {
                q95: f64::INFINITY,
                ..base
            },
            RegimeObservation {
                density_line_avg_1e20_m3: f64::NEG_INFINITY,
                ..base
            },
            RegimeObservation {
                current_ramp_ma_s: f64::NAN,
                ..base
            },
        ];
        for bad in &bad_observations {
            assert!(jit.refresh_for_observation(bad, spec).is_err());
        }
        assert_eq!(jit.compile_events(), 0);
        assert_eq!(jit.cache_size(), 0);
        assert_eq!(jit.active_regime(), None);
    }

    #[test]
    fn test_step_active_rejects_invalid_runtime_vectors() {
        let mut jit = RuntimeKernelJit::new();
        let spec = KernelCompileSpec::default();
        jit.compile_for_regime(PlasmaRegime::LMode, spec)
            .expect("valid compile spec");

        let bad_state = Array1::from_vec(vec![0.0; spec.n_state - 1]);
        let good_control = Array1::from_vec(vec![0.1; spec.n_control]);
        assert!(jit.step_active(&bad_state, &good_control).is_err());

        let good_state = Array1::from_vec(vec![0.0; spec.n_state]);
        let bad_control = Array1::from_vec(vec![0.1; spec.n_control - 1]);
        assert!(jit.step_active(&good_state, &bad_control).is_err());

        let nan_state = Array1::from_vec(vec![f64::NAN; spec.n_state]);
        assert!(jit.step_active(&nan_state, &good_control).is_err());

        let nan_control = Array1::from_vec(vec![f64::NAN; spec.n_control]);
        assert!(jit.step_active(&good_state, &nan_control).is_err());
    }
}