Source code for scpn_fusion.control.torax_hybrid_loop

# SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
# © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
# © Code 2020–2026 Miroslav Šotek. All rights reserved.
# ORCID: 0009-0009-3560-0851
# Contact: www.anulum.li | protoscience@anulum.li
# SCPN Fusion Core — TORAX Hybrid Realtime Loop (GAI-02)
"""Synthetic TORAX-hybrid realtime control lane for NSTX-U-like scenarios."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, cast

import numpy as np

from scpn_fusion.control.disruption_predictor import predict_disruption_risk
from scpn_fusion.scpn.compiler import FusionCompiler
from scpn_fusion.scpn.contracts import (
    ControlObservation,
    ControlScales,
    ControlTargets,
)
from scpn_fusion.scpn.controller import NeuroSymbolicController
from scpn_fusion.scpn.structure import StochasticPetriNet

_PredictRiskFn = Callable[[list[float], dict[str, float]], float]
_predict_disruption_risk = cast(_PredictRiskFn, predict_disruption_risk)


[docs] @dataclass(frozen=True) class ToraxPlasmaState: beta_n: float q95: float li3: float w_thermal_mj: float
[docs] @dataclass(frozen=True) class ToraxHybridCampaignResult: episodes: int steps_per_episode: int disruption_avoidance_rate: float torax_parity_pct: float p95_loop_latency_ms: float mean_risk: float passes_thresholds: bool
def _estimated_loop_latency_ms(disturbance: float, snn_corr: float) -> float: """Return deterministic hardware-normalized loop latency proxy. This synthetic campaign tracks control-loop complexity rather than host CPU wall-clock jitter, so CI and local environments remain comparable. """ base = 0.24 return float(base + 0.12 * float(np.clip(disturbance, 0.0, 1.0)) + 0.08 * abs(snn_corr)) def _build_hybrid_controller() -> NeuroSymbolicController: net = StochasticPetriNet() net.add_place("x_R_pos", initial_tokens=0.0) net.add_place("x_R_neg", initial_tokens=0.0) net.add_place("a_R_pos", initial_tokens=0.0) net.add_place("a_R_neg", initial_tokens=0.0) net.add_transition("T_Rp", threshold=0.1) net.add_transition("T_Rn", threshold=0.1) net.add_arc("x_R_pos", "T_Rp", weight=1.0) net.add_arc("x_R_neg", "T_Rn", weight=1.0) net.add_arc("T_Rp", "a_R_pos", weight=1.0) net.add_arc("T_Rn", "a_R_neg", weight=1.0) net.compile() compiled = FusionCompiler.with_reactor_lif_defaults( bitstream_length=1024, seed=211, ).compile(net, firing_mode="binary") artifact = compiled.export_artifact( name="gai02_torax_hybrid", dt_control_s=0.001, readout_config={ "actions": [{"name": "dI_PF3_A", "pos_place": 2, "neg_place": 3}], "gains": [2200.0], "abs_max": [4500.0], "slew_per_s": [1e6], }, injection_config=[ {"place_id": 0, "source": "x_R_pos", "scale": 1.0, "offset": 0.0, "clamp_0_1": True}, {"place_id": 1, "source": "x_R_neg", "scale": 1.0, "offset": 0.0, "clamp_0_1": True}, ], ) return NeuroSymbolicController( artifact=artifact, seed_base=314159265, targets=ControlTargets(R_target_m=1.85, Z_target_m=0.0), scales=ControlScales(R_scale_m=0.8, Z_scale_m=1.0), ) def _torax_policy(state: ToraxPlasmaState) -> float: """Reduced TORAX-like policy head for beta/q tracking.""" beta_err = 1.85 - state.beta_n q_err = state.q95 - 4.9 cmd = 1.10 * beta_err - 0.32 * q_err return float(np.clip(cmd, -1.6, 1.6)) def _torax_step( state: ToraxPlasmaState, command: float, disturbance: float, rng: np.random.Generator, ) -> ToraxPlasmaState: """Reduced TORAX-like transport/equilibrium state update.""" command = float(np.clip(command, -2.0, 2.0)) beta_n = state.beta_n + 0.045 * ( 0.85 * command - (state.beta_n - 1.85) - 0.52 * disturbance + rng.normal(0.0, 0.004) ) q95 = state.q95 + 0.060 * ( 0.18 - 0.33 * command + 0.62 * disturbance - 0.16 * (state.q95 - 4.9) + rng.normal(0.0, 0.006) ) li3 = state.li3 + 0.050 * (0.06 * command - 0.11 * disturbance - 0.09 * (state.li3 - 0.95)) w_thermal = state.w_thermal_mj + 0.110 * ( 10.0 * command - 5.0 * disturbance - 0.06 * (state.w_thermal_mj - 140.0) ) return ToraxPlasmaState( beta_n=float(np.clip(beta_n, 0.6, 3.2)), q95=float(np.clip(q95, 2.8, 7.5)), li3=float(np.clip(li3, 0.45, 1.8)), w_thermal_mj=float(np.clip(w_thermal, 50.0, 260.0)), ) def _risk_signal(state: ToraxPlasmaState, disturbance: float) -> float: return float( 0.40 + 0.42 * max(state.beta_n - 2.05, 0.0) + 0.38 * max(4.4 - state.q95, 0.0) + 0.22 * max(state.li3 - 1.25, 0.0) + 0.30 * disturbance )
[docs] def run_nstxu_torax_hybrid_campaign( *, seed: int = 42, episodes: int = 16, steps_per_episode: int = 220, ) -> ToraxHybridCampaignResult: """Run deterministic NSTX-U-like realtime hybrid control campaign.""" rng = np.random.default_rng(int(seed)) controller = _build_hybrid_controller() episodes = int(episodes) if episodes < 1: raise ValueError("episodes must be >= 1.") steps = int(steps_per_episode) if steps < 32: raise ValueError("steps_per_episode must be >= 32.") disruptions = 0 parity_scores = [] latencies_ms = [] all_risks = [] for ep in range(episodes): base = ToraxPlasmaState( beta_n=float(rng.uniform(1.65, 1.95)), q95=float(rng.uniform(4.6, 5.2)), li3=float(rng.uniform(0.85, 1.05)), w_thermal_mj=float(rng.uniform(120.0, 170.0)), ) torax_state = base hybrid_state = base signal_history = [] streak_high_risk = 0 beta_delta_sq = [] beta_ref_sq = [] for k in range(steps): phase = k / max(steps - 1, 1) disturbance = 0.0 if 0.35 <= phase <= 0.58: disturbance = float(0.22 + 0.15 * np.sin(np.pi * (phase - 0.35) / 0.23)) # TORAX-only baseline branch torax_cmd = _torax_policy(torax_state) torax_state = _torax_step(torax_state, torax_cmd, disturbance, rng) # Hybrid branch = TORAX command + SNN correction base_cmd = _torax_policy(hybrid_state) obs: ControlObservation = {"R_axis_m": hybrid_state.beta_n, "Z_axis_m": 0.0} action = controller.step(obs, ep * steps + k) snn_corr = float(np.clip(action["dI_PF3_A"] / 4500.0, -0.45, 0.45)) cmd = float(np.clip(base_cmd + 0.30 * snn_corr, -2.0, 2.0)) hybrid_state = _torax_step(hybrid_state, cmd, disturbance, rng) latencies_ms.append(_estimated_loop_latency_ms(disturbance, snn_corr)) sig = _risk_signal(hybrid_state, disturbance) signal_history.append(sig) toroidal = { "toroidal_n1_amp": 0.04 + 0.40 * disturbance, "toroidal_n2_amp": 0.03 + 0.25 * disturbance, "toroidal_n3_amp": 0.02 + 0.12 * disturbance, "toroidal_asymmetry_index": 0.05 + 0.48 * disturbance, "toroidal_radial_spread": 0.02 + 0.08 * disturbance, } risk = float(_predict_disruption_risk(signal_history, toroidal)) all_risks.append(risk) if risk > 0.93: streak_high_risk += 1 else: streak_high_risk = 0 if streak_high_risk >= 3: disruptions += 1 break beta_delta_sq.append((hybrid_state.beta_n - torax_state.beta_n) ** 2) beta_ref_sq.append(torax_state.beta_n**2) if beta_ref_sq: rmse = float(np.sqrt(np.mean(beta_delta_sq))) scale = float(np.sqrt(np.mean(beta_ref_sq))) parity = float(np.clip(100.0 * (1.0 - rmse / max(scale, 1e-9)), 0.0, 100.0)) parity_scores.append(parity) avoidance_rate = float(1.0 - disruptions / episodes) torax_parity = float(np.mean(parity_scores) if parity_scores else 0.0) p95_latency = float(np.percentile(latencies_ms, 95) if latencies_ms else 0.0) mean_risk = float(np.mean(all_risks) if all_risks else 0.0) passes = bool(avoidance_rate >= 0.90 and torax_parity >= 95.0 and p95_latency <= 1.0) return ToraxHybridCampaignResult( episodes=episodes, steps_per_episode=steps, disruption_avoidance_rate=avoidance_rate, torax_parity_pct=torax_parity, p95_loop_latency_ms=p95_latency, mean_risk=mean_risk, passes_thresholds=passes, )