Source code for scpn_fusion.control.jax_traceable_runtime

# SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
# © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
# © Code 2020–2026 Miroslav Šotek. All rights reserved.
# ORCID: 0009-0009-3560-0851
# Contact: www.anulum.li | protoscience@anulum.li
# SCPN Fusion Core — JAX Traceable Runtime
"""Optional JAX-traceable control-loop utilities with NumPy fallback."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

try:
    import jax
    import jax.numpy as jnp

    # Enable float64 for high-precision control analysis
    jax.config.update("jax_enable_x64", True)

    _HAS_JAX = True
except Exception:
    jax = None  # type: ignore[assignment]
    jnp = None  # type: ignore[assignment]
    _HAS_JAX = False

try:
    import torch

    _HAS_TORCH = True
except Exception:
    torch = None  # type: ignore[assignment]
    _HAS_TORCH = False


# Alias for float64 NumPy arrays; used in the command/state-history annotations
# throughout this module.
FloatArray = NDArray[np.float64]


@dataclass(frozen=True)
class TraceableRuntimeSpec:
    """Immutable parameters of the reduced first-order actuator model.

    On every step the state relaxes toward ``gain * clip(command)`` with the
    discrete blend factor ``dt_s / (tau_s + dt_s)``. Commands are clipped to
    ``[-command_limit, command_limit]`` before the gain is applied.
    """

    # Integration time step (the ``_s`` suffix presumably means seconds).
    dt_s: float = 0.001
    # First-order time constant, same time unit as ``dt_s``.
    tau_s: float = 0.005
    # Static gain applied to the clipped command.
    gain: float = 1.0
    # Symmetric saturation bound applied to incoming commands.
    command_limit: float = 1.0
@dataclass(frozen=True)
class TraceableRuntimeResult:
    """Outcome of one unbatched traceable control-loop rollout."""

    # Actuator state after each step, one entry per command sample.
    state_history: FloatArray
    # Name of the backend that produced ``state_history``.
    backend_used: str
    # True when a JIT/scripted backend (jax or torchscript) ran the rollout.
    compiled: bool
@dataclass(frozen=True)
class TraceableRuntimeBatchResult:
    """Outcome of a batched traceable control-loop rollout."""

    # Actuator state per trajectory and step, laid out like the
    # ``(batch, steps)`` command matrix.
    state_history: FloatArray
    # Name of the backend that produced ``state_history``.
    backend_used: str
    # True when a JIT/scripted backend (jax or torchscript) ran the rollout.
    compiled: bool
@dataclass(frozen=True)
class TraceableBackendParityReport:
    """Deviation of one backend from the NumPy reference rollouts."""

    # Backend that was compared against NumPy.
    backend: str
    # Largest absolute state difference in the single-trajectory rollout.
    single_max_abs_err: float
    # Largest absolute state difference in the batched rollout.
    batch_max_abs_err: float
    # Whether ``single_max_abs_err`` is within the requested tolerance.
    single_within_tol: bool
    # Whether ``batch_max_abs_err`` is within the requested tolerance.
    batch_within_tol: bool
def _validate_spec(spec: TraceableRuntimeSpec) -> None:
    """Check that every field of *spec* is usable by the rollout kernels.

    Fields are checked in declaration order, so the first offending field
    determines the error message.

    Raises:
        ValueError: If ``dt_s``, ``tau_s`` or ``command_limit`` is non-finite
            or not strictly positive, or if ``gain`` is non-finite.
    """
    checks = (
        (spec.dt_s, True, "dt_s must be finite and > 0."),
        (spec.tau_s, True, "tau_s must be finite and > 0."),
        (spec.gain, False, "gain must be finite."),
        (spec.command_limit, True, "command_limit must be finite and > 0."),
    )
    for value, needs_positive, message in checks:
        if not np.isfinite(value) or (needs_positive and value <= 0.0):
            raise ValueError(message)


def _validate_commands(commands: FloatArray) -> None:
    """Require a non-empty, finite, one-dimensional command array.

    Raises:
        ValueError: On wrong rank, zero length, or any NaN/inf entry.
    """
    well_shaped = commands.ndim == 1 and commands.size > 0
    if not well_shaped:
        raise ValueError("commands must be a non-empty 1D array.")
    if not np.isfinite(commands).all():
        raise ValueError("commands must contain only finite values.")


def _validate_batch_commands(commands: FloatArray) -> None:
    """Require a finite ``(batch, steps)`` command matrix with both sizes > 0.

    Raises:
        ValueError: On wrong rank, an empty axis, or any NaN/inf entry.
    """
    well_shaped = commands.ndim == 2 and min(commands.shape) > 0
    if not well_shaped:
        raise ValueError("commands must have shape (batch, steps) with non-zero sizes.")
    if not np.isfinite(commands).all():
        raise ValueError("commands must contain only finite values.")


def _resolve_backend(backend: str) -> str:
    """Map a user-supplied backend label onto a concrete backend name.

    The label is case-insensitive and may carry surrounding whitespace.
    ``"auto"`` picks the first available of jax, torchscript and numpy.
    Explicit names are returned unchanged even when that backend is not
    installed; the corresponding simulator raises in that case.

    Raises:
        ValueError: If the label is not one of auto/numpy/jax/torchscript.
    """
    name = str(backend).strip().lower()
    if name == "auto":
        ranked = (("jax", _HAS_JAX), ("torchscript", _HAS_TORCH))
        return next((candidate for candidate, present in ranked if present), "numpy")
    if name in ("numpy", "jax", "torchscript"):
        return name
    raise ValueError("backend must be one of: auto, numpy, jax, torchscript.")
def available_traceable_backends() -> list[str]:
    """List the rollout backends usable in this interpreter.

    ``"numpy"`` is always present and listed first. ``"jax"`` and
    ``"torchscript"`` follow, in that order, when their optional imports
    succeeded at module load.
    """
    optional = [
        name
        for name, present in (("jax", _HAS_JAX), ("torchscript", _HAS_TORCH))
        if present
    ]
    return ["numpy", *optional]
def _resolve_backend_set(
    backends: str | list[str] | tuple[str, ...] | None,
) -> list[str]:
    """Normalise an explicit backend selection for parity checks.

    Args:
        backends: ``None`` to select every available backend, a single backend
            name, or a list/tuple of names. Names are case-insensitive and
            surrounding whitespace is ignored. Duplicates are dropped and
            first-seen order is kept.

    Returns:
        The de-duplicated list of backend names.

    Raises:
        ValueError: If a name is unsupported, unavailable on this host, or the
            selection is empty.
    """
    available = available_traceable_backends()
    if backends is None:
        return available
    # A bare string would otherwise be iterated character by character and
    # fail with a confusing "Unsupported backend 'j'" error.
    if isinstance(backends, str):
        backends = (backends,)
    out: list[str] = []
    for raw in backends:
        name = str(raw).strip().lower()
        if name not in {"numpy", "jax", "torchscript"}:
            raise ValueError(f"Unsupported backend '{raw}'. Allowed: numpy, jax, torchscript.")
        if name not in available:
            raise ValueError(f"Requested backend '{name}' is not available on this host.")
        if name not in out:
            out.append(name)
    if not out:
        raise ValueError("backends must contain at least one backend when provided.")
    return out


def _simulate_numpy(
    commands: FloatArray, initial_state: float, spec: TraceableRuntimeSpec
) -> FloatArray:
    """Reference pure-Python/NumPy rollout of the first-order actuator model.

    Each step applies ``x <- x + alpha * (gain * clip(u) - x)`` with
    ``alpha = dt_s / (tau_s + dt_s)`` and records the updated state.
    """
    alpha = float(spec.dt_s / (spec.tau_s + spec.dt_s))
    # Clip the whole command vector once instead of once per element.
    clipped = np.clip(commands, -spec.command_limit, spec.command_limit)
    state = float(initial_state)
    out = np.empty_like(commands, dtype=np.float64)
    for i, u in enumerate(clipped):
        state = state + alpha * ((spec.gain * float(u)) - state)
        out[i] = state
    return out


def _jax_scan_rollout(x0, u, alpha, gain, limit):
    """Scan the first-order update over ``u`` starting from scalar ``x0``.

    ``alpha``, ``gain`` and ``limit`` are traced arguments rather than closure
    constants, so one compiled program serves every spec.
    """

    def _step(state, u_t):
        u_clip = jnp.clip(u_t, -limit, limit)
        next_state = state + alpha * ((gain * u_clip) - state)
        return next_state, next_state

    _, hist = jax.lax.scan(_step, x0, u)
    return hist


# Wrap with jax.jit once at import. JAX caches compiled programs per wrapped
# function, so jitting a fresh closure inside every call (as before) re-traced
# and re-compiled the rollout on each invocation.
_jax_rollout = jax.jit(_jax_scan_rollout) if _HAS_JAX else None


def _simulate_jax(
    commands: FloatArray, initial_state: float, spec: TraceableRuntimeSpec
) -> FloatArray:
    """JIT-compiled JAX rollout; numerically mirrors ``_simulate_numpy``.

    Raises:
        RuntimeError: If JAX is not installed or its imports are unavailable.
    """
    if not _HAS_JAX:
        raise RuntimeError("JAX backend requested but JAX is not installed.")
    if jnp is None or jax is None or _jax_rollout is None:
        raise RuntimeError("JAX backend requested but JAX imports are unavailable.")
    f64 = jnp.float64
    hist = _jax_rollout(
        jnp.asarray(initial_state, dtype=f64),
        jnp.asarray(commands, dtype=f64),
        jnp.asarray(spec.dt_s / (spec.tau_s + spec.dt_s), dtype=f64),
        jnp.asarray(spec.gain, dtype=f64),
        jnp.asarray(spec.command_limit, dtype=f64),
    )
    return np.asarray(hist, dtype=np.float64)


if _HAS_TORCH:

    @torch.jit.script
    def _torchscript_rollout(
        cmd: torch.Tensor,
        initial_state: float,
        alpha: float,
        gain: float,
        limit: float,
    ) -> torch.Tensor:
        # Scripted single-trajectory rollout; same update as _simulate_numpy.
        n = cmd.numel()
        out = torch.empty((n,), dtype=cmd.dtype, device=cmd.device)
        state = torch.tensor(initial_state, dtype=cmd.dtype, device=cmd.device)
        for i in range(n):
            u = torch.clamp(cmd[i], -limit, limit)
            state = state + alpha * ((gain * u) - state)
            out[i] = state
        return out

    @torch.jit.script
    def _torchscript_rollout_batch(
        cmd: torch.Tensor,
        initial_state: torch.Tensor,
        alpha: float,
        gain: float,
        limit: float,
    ) -> torch.Tensor:
        # Scripted batched rollout: cmd is (batch, steps), state is (batch,).
        batch = cmd.size(0)
        steps = cmd.size(1)
        out = torch.empty((batch, steps), dtype=cmd.dtype, device=cmd.device)
        state = initial_state.clone()
        for t in range(steps):
            u = torch.clamp(cmd[:, t], -limit, limit)
            state = state + alpha * ((gain * u) - state)
            out[:, t] = state
        return out

else:
    _torchscript_rollout = None
    _torchscript_rollout_batch = None


def _simulate_torchscript(
    commands: FloatArray, initial_state: float, spec: TraceableRuntimeSpec
) -> FloatArray:
    """TorchScript rollout in float64; numerically mirrors ``_simulate_numpy``.

    Raises:
        RuntimeError: If torch is not installed or its imports are unavailable.
    """
    if not _HAS_TORCH or _torchscript_rollout is None:
        raise RuntimeError("TorchScript backend requested but torch is not installed.")
    if torch is None:
        raise RuntimeError("TorchScript backend requested but torch imports are unavailable.")
    cmd = torch.as_tensor(commands, dtype=torch.float64)
    alpha = float(spec.dt_s / (spec.tau_s + spec.dt_s))
    hist = _torchscript_rollout(
        cmd,
        float(initial_state),
        alpha,
        float(spec.gain),
        float(spec.command_limit),
    )
    return np.asarray(hist.detach().cpu().numpy(), dtype=np.float64)


def _simulate_numpy_batch(
    commands: FloatArray, initial_state: FloatArray, spec: TraceableRuntimeSpec
) -> FloatArray:
    """Reference NumPy rollout for a ``(batch, steps)`` command matrix.

    The time loop is explicit; each step updates all trajectories at once.
    """
    alpha = float(spec.dt_s / (spec.tau_s + spec.dt_s))
    # Clip once up front; per-column slices of the clipped matrix are identical.
    clipped = np.clip(commands, -spec.command_limit, spec.command_limit)
    state = np.asarray(initial_state, dtype=np.float64).copy()
    out = np.empty_like(commands, dtype=np.float64)
    for t in range(commands.shape[1]):
        state = state + alpha * ((spec.gain * clipped[:, t]) - state)
        out[:, t] = state
    return out
def _jax_scan_rollout_batch(x0, u_bt, alpha, gain, limit):
    """Scan the first-order update over time for a whole batch at once.

    ``u_bt`` is laid out ``(batch, steps)``. It is transposed so ``lax.scan``
    iterates over steps with a ``(batch,)`` carry, and the history is
    transposed back to ``(batch, steps)``. ``alpha``, ``gain`` and ``limit``
    are traced arguments, so one compiled program serves every spec.
    """

    def _step(state, u_t):
        u_clip = jnp.clip(u_t, -limit, limit)
        next_state = state + alpha * ((gain * u_clip) - state)
        return next_state, next_state

    _, hist_tb = jax.lax.scan(_step, x0, jnp.swapaxes(u_bt, 0, 1))
    return jnp.swapaxes(hist_tb, 0, 1)


# Wrap with jax.jit once at import. JAX caches compiled programs per wrapped
# function, so jitting a fresh closure inside every call (as before) re-traced
# and re-compiled the batched rollout on each invocation.
_jax_rollout_batch = jax.jit(_jax_scan_rollout_batch) if _HAS_JAX else None


def _simulate_jax_batch(
    commands: FloatArray, initial_state: FloatArray, spec: TraceableRuntimeSpec
) -> FloatArray:
    """JIT-compiled JAX rollout for a ``(batch, steps)`` command matrix.

    Raises:
        RuntimeError: If JAX is not installed or its imports are unavailable.
    """
    if not _HAS_JAX:
        raise RuntimeError("JAX backend requested but JAX is not installed.")
    if jnp is None or jax is None or _jax_rollout_batch is None:
        raise RuntimeError("JAX backend requested but JAX imports are unavailable.")
    f64 = jnp.float64
    hist = _jax_rollout_batch(
        jnp.asarray(initial_state, dtype=f64),
        jnp.asarray(commands, dtype=f64),
        jnp.asarray(spec.dt_s / (spec.tau_s + spec.dt_s), dtype=f64),
        jnp.asarray(spec.gain, dtype=f64),
        jnp.asarray(spec.command_limit, dtype=f64),
    )
    return np.asarray(hist, dtype=np.float64)


def _simulate_torchscript_batch(
    commands: FloatArray, initial_state: FloatArray, spec: TraceableRuntimeSpec
) -> FloatArray:
    """TorchScript rollout for a ``(batch, steps)`` command matrix in float64.

    Raises:
        RuntimeError: If torch is not installed or its imports are unavailable.
    """
    if not _HAS_TORCH or _torchscript_rollout_batch is None:
        raise RuntimeError("TorchScript backend requested but torch is not installed.")
    if torch is None:
        raise RuntimeError("TorchScript backend requested but torch imports are unavailable.")
    cmd = torch.as_tensor(commands, dtype=torch.float64)
    x0 = torch.as_tensor(initial_state, dtype=torch.float64)
    alpha = float(spec.dt_s / (spec.tau_s + spec.dt_s))
    hist = _torchscript_rollout_batch(
        cmd,
        x0,
        alpha,
        float(spec.gain),
        float(spec.command_limit),
    )
    return np.asarray(hist.detach().cpu().numpy(), dtype=np.float64)
def run_traceable_control_loop(
    commands: FloatArray,
    *,
    initial_state: float = 0.0,
    spec: TraceableRuntimeSpec | None = None,
    backend: str = "auto",
) -> TraceableRuntimeResult:
    """Roll out the reduced first-order actuator loop for one command sequence.

    Args:
        commands: Command samples; converted to float64 and flattened to 1D.
        initial_state: Actuator state before the first command; must be finite.
        spec: Model parameters; ``TraceableRuntimeSpec()`` when omitted.
        backend: ``auto``, ``numpy``, ``jax`` or ``torchscript`` (case- and
            whitespace-insensitive). ``auto`` prefers jax, then torchscript,
            then numpy.

    Returns:
        A ``TraceableRuntimeResult``. Its ``compiled`` flag is True for the
        jax and torchscript backends and False for numpy.

    Raises:
        ValueError: On empty or non-finite commands, non-finite
            ``initial_state``, an invalid ``spec`` or an unknown backend name.
        RuntimeError: If an explicitly requested backend is not installed.
    """
    cmd_arr = np.asarray(commands, dtype=np.float64).reshape(-1)
    _validate_commands(cmd_arr)
    if not np.isfinite(initial_state):
        raise ValueError("initial_state must be finite.")
    runtime_spec = TraceableRuntimeSpec() if spec is None else spec
    _validate_spec(runtime_spec)

    resolved = _resolve_backend(backend)
    if resolved == "jax":
        simulate, compiled = _simulate_jax, True
    elif resolved == "torchscript":
        simulate, compiled = _simulate_torchscript, True
    else:
        resolved, simulate, compiled = "numpy", _simulate_numpy, False

    history = simulate(cmd_arr, float(initial_state), runtime_spec)
    return TraceableRuntimeResult(
        state_history=history,
        backend_used=resolved,
        compiled=compiled,
    )
def run_traceable_control_batch(
    commands: FloatArray,
    *,
    initial_state: FloatArray | float | None = None,
    spec: TraceableRuntimeSpec | None = None,
    backend: str = "auto",
) -> TraceableRuntimeBatchResult:
    """Roll out many reduced control loops at once.

    Args:
        commands: Command matrix of shape ``(batch, steps)``; converted to
            float64.
        initial_state: ``None`` (all zeros), a scalar broadcast to every
            trajectory, or an array with exactly ``batch`` finite values
            (flattened).
        spec: Model parameters; ``TraceableRuntimeSpec()`` when omitted.
        backend: ``auto``, ``numpy``, ``jax`` or ``torchscript``.

    Returns:
        A ``TraceableRuntimeBatchResult``. Its ``compiled`` flag is True for
        the jax and torchscript backends and False for numpy.

    Raises:
        ValueError: On malformed commands, a mismatched or non-finite
            ``initial_state``, an invalid ``spec`` or an unknown backend.
        RuntimeError: If an explicitly requested backend is not installed.
    """
    cmd_arr = np.asarray(commands, dtype=np.float64)
    _validate_batch_commands(cmd_arr)
    n_batch = int(cmd_arr.shape[0])

    if initial_state is None:
        x0 = np.zeros(n_batch, dtype=np.float64)
    else:
        raw_x0 = np.asarray(initial_state, dtype=np.float64)
        if raw_x0.ndim == 0:
            x0 = np.full(n_batch, float(raw_x0), dtype=np.float64)
        else:
            x0 = raw_x0.reshape(-1)
    if x0.size != n_batch:
        raise ValueError("initial_state length must match commands batch dimension.")
    if not np.all(np.isfinite(x0)):
        raise ValueError("initial_state must contain only finite values.")

    runtime_spec = TraceableRuntimeSpec() if spec is None else spec
    _validate_spec(runtime_spec)

    resolved = _resolve_backend(backend)
    if resolved == "jax":
        simulate, compiled = _simulate_jax_batch, True
    elif resolved == "torchscript":
        simulate, compiled = _simulate_torchscript_batch, True
    else:
        resolved, simulate, compiled = "numpy", _simulate_numpy_batch, False

    return TraceableRuntimeBatchResult(
        state_history=simulate(cmd_arr, x0, runtime_spec),
        backend_used=resolved,
        compiled=compiled,
    )
def validate_traceable_backend_parity(
    *,
    steps: int = 64,
    batch: int = 8,
    seed: int = 42,
    spec: TraceableRuntimeSpec | None = None,
    atol: float = 1e-8,
    backends: list[str] | tuple[str, ...] | None = None,
) -> dict[str, TraceableBackendParityReport]:
    """Measure how closely each selected backend reproduces the NumPy reference.

    Random commands and initial states are drawn from a seeded generator. The
    NumPy backend produces reference trajectories for one single rollout and
    one batched rollout. Every selected backend (including ``numpy`` itself
    when ``backends`` is ``None``) is then run on the same inputs.

    Args:
        steps: Time steps per rollout; must be > 0.
        batch: Trajectories in the batched rollout; must be > 0.
        seed: Seed for ``numpy.random.default_rng``.
        spec: Model parameters; ``TraceableRuntimeSpec()`` when omitted.
        atol: Absolute tolerance for the pass/fail flags; finite and >= 0.
        backends: Backends to compare, or ``None`` for all available ones.

    Returns:
        Mapping from backend name to its parity report.

    Raises:
        ValueError: On invalid ``steps``, ``batch``, ``atol``, ``spec`` or
            backend selection.
    """
    if steps <= 0:
        raise ValueError("steps must be > 0.")
    if batch <= 0:
        raise ValueError("batch must be > 0.")
    if not (np.isfinite(atol) and atol >= 0.0):
        raise ValueError("atol must be finite and >= 0.")
    runtime_spec = TraceableRuntimeSpec() if spec is None else spec
    _validate_spec(runtime_spec)

    rng = np.random.default_rng(int(seed))
    # Draw order fixes which random numbers each input receives; keep it stable.
    single_cmd = np.asarray(rng.normal(0.0, 1.0, size=steps), dtype=np.float64)
    batch_cmd = np.asarray(rng.normal(0.0, 1.0, size=(batch, steps)), dtype=np.float64)
    batch_x0 = np.asarray(rng.normal(0.0, 0.2, size=batch), dtype=np.float64)
    x0 = float(rng.normal(0.0, 0.2))

    def _histories(name: str) -> tuple[FloatArray, FloatArray]:
        # Run the single and batched rollouts on the shared inputs.
        single = run_traceable_control_loop(
            single_cmd, initial_state=x0, spec=runtime_spec, backend=name
        ).state_history
        batched = run_traceable_control_batch(
            batch_cmd, initial_state=batch_x0, spec=runtime_spec, backend=name
        ).state_history
        return single, batched

    ref_single, ref_batch = _histories("numpy")

    reports: dict[str, TraceableBackendParityReport] = {}
    for name in _resolve_backend_set(backends):
        got_single, got_batch = _histories(name)
        single_err = float(np.max(np.abs(got_single - ref_single)))
        batch_err = float(np.max(np.abs(got_batch - ref_batch)))
        reports[name] = TraceableBackendParityReport(
            backend=name,
            single_max_abs_err=single_err,
            batch_max_abs_err=batch_err,
            single_within_tol=bool(single_err <= atol),
            batch_within_tol=bool(batch_err <= atol),
        )
    return reports