Source code for scpn_fusion.core._rust_compat

# SPDX-License-Identifier: AGPL-3.0-or-later
# Commercial license available
# © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
# © Code 2020–2026 Miroslav Šotek. All rights reserved.
# ORCID: 0009-0009-3560-0851
# Contact: www.anulum.li | protoscience@anulum.li
"""Backward-compatibility layer for the optional Rust acceleration backend.

Imports from ``scpn_fusion_rs`` when available, falling back to pure-Python
implementations otherwise.

Usage:
    from scpn_fusion.core._rust_compat import FusionKernel, RUST_BACKEND
"""

from __future__ import annotations

import logging
import os
from collections import deque
from pathlib import Path
from typing import Any, Optional, cast

import numpy as np
from numpy.typing import NDArray

from scpn_fusion.io.safe_loaders import checked_json_load

logger = logging.getLogger(__name__)

FloatArray = NDArray[np.float64]

# Optional compiled extension. ``_RUST_AVAILABLE`` records whether the whole
# import below succeeded; the rest of the module branches on this flag rather
# than on the presence of the individual names.
try:
    from scpn_fusion_rs import (
        PyFusionKernel,
        PyEquilibriumResult,
        PyThermodynamicsResult,
        shafranov_bv,
        solve_coil_currents,
        measure_magnetics,
        simulate_tearing_mode,
    )
except ImportError:
    _RUST_AVAILABLE = False
else:
    _RUST_AVAILABLE = True


def _require_monotonic_axis(name: str, values: FloatArray, expected_len: int) -> FloatArray:
    """Return *values* as a validated float64 coordinate axis.

    The axis must be one-dimensional, hold exactly ``expected_len`` entries,
    contain only finite numbers and be strictly increasing. Because "strictly
    increasing" is checked on consecutive differences, an axis with fewer than
    two points is also rejected.

    Raises
    ------
    ValueError
        If any of the conditions above is violated.
    """
    axis = np.asarray(values, dtype=np.float64)
    wrong_shape = axis.ndim != 1 or axis.size != int(expected_len)
    if wrong_shape:
        raise ValueError(f"{name} must be 1-D with length {expected_len}, got shape {axis.shape}")
    if not np.isfinite(axis).all():
        raise ValueError(f"{name} must contain finite values")
    steps = np.diff(axis)
    if steps.size == 0 or not (steps > 0.0).all():
        raise ValueError(f"{name} must be strictly increasing")
    return axis


def _require_state_grid(
    name: str,
    values: FloatArray,
    *,
    nz: int,
    nr: int,
    require_finite: bool,
) -> FloatArray:
    """Return *values* as a float64 grid of shape ``(nz, nr)``.

    Parameters
    ----------
    name : str
        Label used in error messages.
    values : array-like
        Candidate grid; converted with ``np.asarray(..., dtype=np.float64)``.
    nz, nr : int
        Expected number of rows (axis 0) and columns (axis 1).
    require_finite : bool
        When true, NaN/inf entries are rejected as well.

    Raises
    ------
    ValueError
        On a shape mismatch, or on non-finite entries when *require_finite*.
    """
    grid = np.asarray(values, dtype=np.float64)
    wanted = (int(nz), int(nr))
    # Comparing the full shape tuple also rejects any rank other than 2.
    if tuple(grid.shape) != wanted:
        raise ValueError(f"{name} must have shape {wanted}, got {grid.shape}")
    if require_finite and not np.isfinite(grid).all():
        raise ValueError(f"{name} must contain finite values")
    return grid


def _rust_available() -> bool:
    """Report whether ``scpn_fusion_rs`` was imported successfully at module load.

    The value is the module-level ``_RUST_AVAILABLE`` flag, fixed at import time.
    """
    available = _RUST_AVAILABLE
    return available


class RustAcceleratedKernel:
    """Drop-in wrapper around the Rust PyFusionKernel.

    Mirrors the Python FusionKernel attribute interface (.Psi, .R, .Z, .RR,
    .ZZ, .cfg, etc.) and delegates the equilibrium solve to Rust for ~20x
    speedup while keeping all attribute accesses compatible with downstream
    code.

    All 2-D arrays (``Psi``, ``J_phi``, ``B_R``, ``B_Z``, ``RR``, ``ZZ``) are
    stored with shape ``(NZ, NR)``: axis 0 is Z and axis 1 is R.
    """

    def __init__(self, config_path: str | Path) -> None:
        """Load the Rust kernel and mirror its grid and state as NumPy arrays.

        Parameters
        ----------
        config_path : str or Path
            JSON configuration file. It is passed (as ``str``) to the Rust
            kernel and also parsed with ``checked_json_load`` into ``.cfg``.

        Raises
        ------
        ValueError
            If the Rust grid is smaller than 2x2, or an axis is not a finite,
            strictly increasing 1-D array of the advertised length.
        RuntimeError
            If the initial Psi/J_phi arrays fail shape/finite validation.
        """
        self._config_path = str(config_path)
        # Counters/diagnostics updated by _sync_state_from_rust on failure.
        self.state_sync_failures = 0
        self.last_state_sync_error: Optional[str] = None

        # Load via Rust (PyO3 expects str, not Path).
        self._rust = PyFusionKernel(self._config_path)
        # Also load JSON config for attribute access (bridges read .cfg directly).
        self.cfg = checked_json_load(config_path)

        # Mirror grid attributes; the Rust side reports the shape as (nr, nz).
        nr, nz = self._rust.grid_shape()
        self.NR = int(nr)
        self.NZ = int(nz)
        if self.NR < 2 or self.NZ < 2:
            raise ValueError(f"Rust grid shape must be >= 2x2, got {(self.NR, self.NZ)}")
        self.R = _require_monotonic_axis("R", np.asarray(self._rust.get_r()), self.NR)
        self.Z = _require_monotonic_axis("Z", np.asarray(self._rust.get_z()), self.NZ)
        # Spacing comes from the first interval only; np.gradient is later fed
        # these scalars, which assumes a uniformly spaced grid.
        self.dR = float(self.R[1] - self.R[0])
        self.dZ = float(self.Z[1] - self.Z[0])
        self.RR, self.ZZ = np.meshgrid(self.R, self.Z)

        # Initialize and validate state from Rust arrays.
        self.Psi: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self.J_phi: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self.B_R: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self.B_Z: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self._sync_state_from_rust(context="init", require_finite=True)
        self.compute_b_field()

    def solve_equilibrium(self) -> Any:
        """Solve the Grad-Shafranov equilibrium via the Rust backend.

        After the solve, ``Psi``/``J_phi`` are re-synchronised from Rust and
        ``B_R``/``B_Z`` are recomputed.

        Returns
        -------
        Any
            Whatever ``PyFusionKernel.solve_equilibrium()`` returns (presumably
            a ``PyEquilibriumResult`` — confirm against the Rust bindings).

        Raises
        ------
        RuntimeError
            If the synchronised arrays fail shape/finite validation.
        """
        result = self._rust.solve_equilibrium()
        # Sync arrays back to Python attributes.
        self._sync_state_from_rust(context="solve_equilibrium", require_finite=True)
        # Compute B-field from Psi (matching Python FusionKernel.compute_b_field).
        self.compute_b_field()
        return result

    def compute_b_field(self) -> None:
        """Compute ``B_R`` and ``B_Z`` from the gradient of ``Psi``.

        Uses ``B_R = -(1/R) dPsi/dZ`` and ``B_Z = (1/R) dPsi/dR``, with R
        clamped to at least 1e-6 to avoid division by zero.

        Raises
        ------
        ValueError
            If ``Psi`` does not have shape ``(NZ, NR)`` or is not all finite.
        """
        if tuple(self.Psi.shape) != (self.NZ, self.NR):
            raise ValueError(
                f"Psi shape mismatch for B-field computation: expected {(self.NZ, self.NR)}, "
                f"got {self.Psi.shape}"
            )
        if not np.all(np.isfinite(self.Psi)):
            raise ValueError("Psi must contain finite values before B-field computation")
        # Psi is indexed (Z, R): axis-0 = Z, axis-1 = R
        dPsi_dZ, dPsi_dR = np.gradient(self.Psi, self.dZ, self.dR)
        R_safe = np.maximum(self.RR, 1e-6)
        self.B_R = -(1.0 / R_safe) * dPsi_dZ
        self.B_Z = (1.0 / R_safe) * dPsi_dR

    def _sync_state_from_rust(self, *, context: str, require_finite: bool) -> None:
        """Copy Psi/J_phi from Rust, enforcing shape and (optionally) finiteness.

        On validation failure the failure counter and last-error message are
        updated, a warning is logged, and ``RuntimeError`` is raised; the
        existing ``Psi``/``J_phi`` attributes are left untouched.
        """
        try:
            psi = _require_state_grid(
                "Psi",
                np.asarray(self._rust.get_psi()),
                nz=self.NZ,
                nr=self.NR,
                require_finite=require_finite,
            )
            j_phi = _require_state_grid(
                "J_phi",
                np.asarray(self._rust.get_j_phi()),
                nz=self.NZ,
                nr=self.NR,
                require_finite=require_finite,
            )
        except ValueError as exc:
            self.state_sync_failures += 1
            self.last_state_sync_error = str(exc)
            logger.warning("Rust state sync failed during %s: %s", context, exc)
            raise RuntimeError(f"Rust state sync failed during {context}: {exc}") from exc
        # Assign only after both arrays validated, so state stays consistent.
        self.Psi = psi
        self.J_phi = j_phi

    def find_x_point(self, Psi: FloatArray) -> tuple[tuple[float, float], float]:
        """Locate the magnetic null (X-point) in the lower divertor region.

        Matches the Python ``FusionKernel.find_x_point()`` interface. This is a
        grid search: among cells with ``Z < 0.5 * cfg["dimensions"]["Z_min"]``
        it picks the one with the smallest ``|grad Psi|``.

        Parameters
        ----------
        Psi : array-like, shape (NZ, NR)
            Poloidal flux on the kernel grid (axis 0 = Z, axis 1 = R).

        Returns
        -------
        ((R, Z), psi_x)
            Coordinates of the selected cell and Psi there. If no grid cell
            lies in the divertor region, returns ``((0.0, 0.0), min(Psi))``.

        Raises
        ------
        ValueError
            If *Psi* does not have shape ``(NZ, NR)``.
        """
        psi = np.asarray(Psi, dtype=np.float64)
        expected = (self.NZ, self.NR)
        if tuple(psi.shape) != expected:
            raise ValueError(
                f"Psi shape mismatch for X-point search: expected {expected}, got {psi.shape}"
            )
        # Psi is indexed (Z, R): axis-0 = Z, axis-1 = R
        dPsi_dZ, dPsi_dR = np.gradient(psi, self.dZ, self.dR)
        B_mag = np.sqrt(dPsi_dR**2 + dPsi_dZ**2)
        mask_divertor = (self.cfg["dimensions"]["Z_min"] * 0.5) > self.ZZ
        if np.any(mask_divertor):
            # +inf (not a finite sentinel) guarantees argmin never lands
            # outside the divertor region, however large the gradients are.
            masked_B = np.where(mask_divertor, B_mag, np.inf)
            idx_min = int(np.argmin(masked_B))
            iz, ir = np.unravel_index(idx_min, psi.shape)
            return (float(self.R[ir]), float(self.Z[iz])), float(psi[iz, ir])
        return (0.0, 0.0), float(np.min(psi))

    def calculate_thermodynamics(self, p_aux_mw: float) -> Any:
        """Calculate thermodynamics via the Rust backend.

        ``p_aux_mw`` is forwarded unchanged to
        ``PyFusionKernel.calculate_thermodynamics``; judging by the name it is
        auxiliary power in MW — confirm against the Rust bindings.
        """
        return self._rust.calculate_thermodynamics(p_aux_mw)

    def calculate_vacuum_field(self) -> FloatArray:
        """Compute the vacuum field with the Python reference implementation.

        A fresh Python ``FusionKernel`` is constructed from the same config
        file on every call.
        """
        from scpn_fusion.core.fusion_kernel import FusionKernel as _PyFusionKernel

        fk = _PyFusionKernel(self._config_path)
        return fk.calculate_vacuum_field()

    def set_solver_method(self, method: str) -> None:
        """Set inner linear solver: 'sor' or 'multigrid'."""
        self._rust.set_solver_method(method)

    def solver_method(self) -> str:
        """Get current solver method name."""
        return str(self._rust.solver_method())

    def save_results(self, filename: str = "equilibrium_nonlinear.npz") -> None:
        """Save R, Z, Psi and J_phi to an ``.npz`` file via ``np.savez``."""
        np.savez(filename, R=self.R, Z=self.Z, Psi=self.Psi, J_phi=self.J_phi)
# Backend selection and Rust-only helper re-exports. With the extension present,
# ``FusionKernel`` aliases the Rust-accelerated wrapper and the ``rust_*``
# helpers delegate to ``scpn_fusion_rs`` (with compatibility shims). Without
# it, ``FusionKernel`` is not defined here and every ``rust_*`` helper raises
# ImportError.
if _RUST_AVAILABLE:
    FusionKernel = RustAcceleratedKernel
    RUST_BACKEND = True

    def rust_shafranov_bv(*args: Any, **kwargs: Any) -> Any:
        """Compatibility wrapper for legacy config-path invocation.

        Supported call forms:
        - rust_shafranov_bv(r_geo, a_min, ip_ma) -> tuple[float, float, float]
        - rust_shafranov_bv(config_path) -> vacuum Psi array
        """
        is_legacy_path_call = (
            len(args) == 1 and not kwargs and isinstance(args[0], (str, os.PathLike))
        )
        if not is_legacy_path_call:
            return shafranov_bv(*args, **kwargs)
        # Legacy form: build the Python kernel and return its vacuum field.
        from scpn_fusion.core.fusion_kernel import FusionKernel as _PyFusionKernel

        return _PyFusionKernel(str(args[0])).calculate_vacuum_field()

    rust_solve_coil_currents = solve_coil_currents
    rust_measure_magnetics = measure_magnetics

    def rust_simulate_tearing_mode(steps: int, seed: Optional[int] = None) -> Any:
        """Rust tearing mode with optional deterministic seed compatibility.

        Without *seed* the Rust implementation is used; with a seed the call is
        routed to the Python implementation driven by a seeded NumPy Generator.
        """
        if seed is not None:
            from scpn_fusion.control.disruption_risk_runtime import (
                simulate_tearing_mode as _py_tearing,
            )

            seeded_rng = np.random.default_rng(seed=int(seed))
            return cast("Any", _py_tearing)(steps=int(steps), rng=seeded_rng)
        return simulate_tearing_mode(int(steps))

else:
    RUST_BACKEND = False

    def rust_shafranov_bv(*args: Any, **kwargs: Any) -> Any:
        """Unavailable without the Rust extension; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_solve_coil_currents(*args: Any, **kwargs: Any) -> Any:
        """Unavailable without the Rust extension; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_measure_magnetics(*args: Any, **kwargs: Any) -> Any:
        """Unavailable without the Rust extension; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_simulate_tearing_mode(steps: int, seed: Optional[int] = None) -> Any:
        """Unavailable without the Rust extension; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")
class RustSnnPool:
    """Compatibility wrapper for Rust SpikingControllerPool.

    Uses the Rust implementation when available and falls back to a
    deterministic NumPy LIF population otherwise.

    Parameters
    ----------
    n_neurons : int
        Number of LIF neurons per sub-population (positive/negative).
    gain : float
        Output scaling factor.
    window_size : int
        Sliding window length for rate-code averaging.
    allow_numpy_fallback : bool
        When ``False``, raise :class:`ImportError` if Rust extension is unavailable.
    seed : int
        Seed used by deterministic NumPy compatibility backend (ignored by Rust).
    """

    def __init__(
        self,
        n_neurons: int = 50,
        gain: float = 10.0,
        window_size: int = 20,
        *,
        allow_numpy_fallback: bool = True,
        seed: int = 42,
    ) -> None:
        self._backend = "rust"
        if not _RUST_AVAILABLE:
            if not allow_numpy_fallback:
                raise ImportError("scpn_fusion_rs not installed and allow_numpy_fallback=False.")
            self._backend = "numpy_fallback"
            self._inner = _NumpySnnPoolFallback(
                n_neurons=n_neurons,
                gain=gain,
                window_size=window_size,
                seed=seed,
            )
            return

        from scpn_fusion_rs import PySnnPool

        self._inner = PySnnPool(n_neurons, gain, window_size)

    def step(self, error: float) -> float:
        """Feed *error* through the active pool and return the scalar control output."""
        output = self._inner.step(error)
        return float(output)

    @property
    def n_neurons(self) -> int:
        """Neurons per sub-population, as reported by the active backend."""
        return int(self._inner.n_neurons)

    @property
    def gain(self) -> float:
        """Output gain, as reported by the active backend."""
        return float(self._inner.gain)

    @property
    def backend(self) -> str:
        """Either ``"rust"`` or ``"numpy_fallback"``."""
        return self._backend

    def __repr__(self) -> str:
        return (
            f"RustSnnPool(n_neurons={self.n_neurons}, gain={self.gain}, backend='{self.backend}')"
        )
class RustSnnController:
    """Compatibility wrapper for Rust NeuroCyberneticController.

    Uses the Rust implementation when available and falls back to paired
    deterministic NumPy LIF pools otherwise.

    Parameters
    ----------
    target_r : float
        Target major-radius position [m].
    target_z : float
        Target vertical position [m].
    allow_numpy_fallback : bool
        When ``False``, raise :class:`ImportError` if Rust extension is unavailable.
    seed : int
        Seed used by deterministic NumPy compatibility backend (ignored by Rust).
    """

    def __init__(
        self,
        target_r: float = 6.2,
        target_z: float = 0.0,
        *,
        allow_numpy_fallback: bool = True,
        seed: int = 42,
    ) -> None:
        self._backend = "rust"
        if not _RUST_AVAILABLE:
            if not allow_numpy_fallback:
                raise ImportError("scpn_fusion_rs not installed and allow_numpy_fallback=False.")
            self._backend = "numpy_fallback"
            self._inner = _NumpySnnControllerFallback(
                target_r=target_r,
                target_z=target_z,
                seed=seed,
            )
            return

        from scpn_fusion_rs import PySnnController

        self._inner = PySnnController(target_r, target_z)

    def step(self, measured_r: float, measured_z: float) -> tuple[float, float]:
        """Feed a measured (R, Z) position and return the (ctrl_R, ctrl_Z) pair."""
        outputs = self._inner.step(measured_r, measured_z)
        ctrl_r, ctrl_z = outputs
        return float(ctrl_r), float(ctrl_z)

    @property
    def target_r(self) -> float:
        """Target major radius held by the active backend."""
        return float(self._inner.target_r)

    @property
    def target_z(self) -> float:
        """Target vertical position held by the active backend."""
        return float(self._inner.target_z)

    @property
    def backend(self) -> str:
        """Either ``"rust"`` or ``"numpy_fallback"``."""
        return self._backend

    def __repr__(self) -> str:
        return (
            f"RustSnnController(target_r={self.target_r}, target_z={self.target_z}, "
            f"backend='{self.backend}')"
        )
class _NumpySnnPoolFallback:
    """Deterministic NumPy LIF pool mirroring the Rust SNN pool interface.

    The signed error drives two populations: one receives the positive part,
    the other the negative part. The output is the difference of their
    windowed firing rates, scaled by ``gain``.
    """

    def __init__(
        self,
        n_neurons: int,
        gain: float,
        window_size: int,
        *,
        seed: int,
    ) -> None:
        self.n_neurons = int(n_neurons)
        self.gain = float(gain)
        self.window_size = int(window_size)
        if self.n_neurons < 1:
            raise ValueError("n_neurons must be >= 1.")
        if not np.isfinite(self.gain):
            raise ValueError("gain must be finite.")
        if self.window_size < 1:
            raise ValueError("window_size must be >= 1.")

        base_seed = int(seed)
        # Separate, reproducible noise streams for the two populations.
        self._rng_pos = np.random.default_rng(base_seed)
        self._rng_neg = np.random.default_rng(base_seed + 100003)

        self._v_pos = np.zeros(self.n_neurons, dtype=np.float64)
        self._v_neg = np.zeros(self.n_neurons, dtype=np.float64)
        # Windows start full of zeros, so early rates are diluted until filled.
        zero_window = [0] * self.window_size
        self._history_pos: deque[int] = deque(zero_window, maxlen=self.window_size)
        self._history_neg: deque[int] = deque(zero_window, maxlen=self.window_size)

        # Membrane update factor (presumably dt / tau = 1 ms / 15 ms).
        self._alpha = 1.0e-3 / 15.0e-3
        self._noise_std = 0.02
        self._i_scale = 5.0
        self._i_bias = 0.1
        self._v_threshold = 0.35
        self._v_reset = 0.0

    def _step_pop(self, v: FloatArray, rng: np.random.Generator, input_current: float) -> int:
        """Advance membrane potentials *v* in place by one step; return the spike count."""
        noise = rng.normal(0.0, self._noise_std, size=v.shape)
        drive = float(input_current)
        v += self._alpha * (-v + drive + noise)
        spiking = v >= self._v_threshold
        count = int(np.count_nonzero(spiking))
        if count:
            v[spiking] = self._v_reset
        return count

    def step(self, error_signal: float) -> float:
        """Advance both populations by one step and return the scaled rate difference."""
        err = float(error_signal)
        if not np.isfinite(err):
            raise ValueError("error_signal must be finite.")
        # Rectify the signed error into one drive per population.
        drive_pos = self._i_bias + max(0.0, err) * self._i_scale
        drive_neg = self._i_bias + max(0.0, -err) * self._i_scale
        self._history_pos.append(self._step_pop(self._v_pos, self._rng_pos, drive_pos))
        self._history_neg.append(self._step_pop(self._v_neg, self._rng_neg, drive_neg))
        denom = self.window_size * self.n_neurons
        rate_pos = float(sum(self._history_pos) / denom)
        rate_neg = float(sum(self._history_neg) / denom)
        return float((rate_pos - rate_neg) * self.gain)


class _NumpySnnControllerFallback:
    """Deterministic NumPy compatibility path matching the Rust SNN controller interface.

    Uses one pool per axis (50 neurons, window 20; gain 10 for R and 20 for Z),
    each fed the tracking error ``target - measured``.
    """

    def __init__(self, target_r: float, target_z: float, *, seed: int) -> None:
        self.target_r = float(target_r)
        self.target_z = float(target_z)
        if not (np.isfinite(self.target_r) and np.isfinite(self.target_z)):
            raise ValueError("target_r and target_z must be finite.")
        base_seed = int(seed)
        self._pool_r = _NumpySnnPoolFallback(50, 10.0, 20, seed=base_seed + 1)
        self._pool_z = _NumpySnnPoolFallback(50, 20.0, 20, seed=base_seed + 2)

    def step(self, measured_r: float, measured_z: float) -> tuple[float, float]:
        """Return ``(ctrl_R, ctrl_Z)`` for one measured position."""
        r_meas = float(measured_r)
        z_meas = float(measured_z)
        if not (np.isfinite(r_meas) and np.isfinite(z_meas)):
            raise ValueError("measured_r and measured_z must be finite.")
        ctrl_r = self._pool_r.step(self.target_r - r_meas)
        ctrl_z = self._pool_z.step(self.target_z - z_meas)
        return ctrl_r, ctrl_z
def rust_multigrid_vcycle(
    source: FloatArray,
    psi_bc: FloatArray,
    r_min: float,
    r_max: float,
    z_min: float,
    z_max: float,
    nr: int,
    nz: int,
    tol: float = 1e-6,
    max_cycles: int = 500,
) -> tuple[FloatArray, float, int, bool] | None:
    """Call Rust multigrid V-cycle if available, else return None.

    All arguments are forwarded positionally to
    ``scpn_fusion_rs.multigrid_vcycle``.

    Returns
    -------
    tuple of (psi, residual, n_cycles, converged), or None when Rust is
    unavailable or does not expose ``multigrid_vcycle``. A warning is logged
    in both ``None`` cases.

    Raises
    ------
    Exception
        Any error raised by the Rust solver itself is propagated unchanged.
    """
    if not _RUST_AVAILABLE:
        logger.warning("scpn_fusion_rs not installed — falling back to Python multigrid.")
        return None
    # Guard only the import: an ImportError raised from inside the solver call
    # must not be misreported as "not exposed via PyO3".
    try:
        from scpn_fusion_rs import multigrid_vcycle as _rust_mg
    except ImportError:
        logger.warning("Rust multigrid_vcycle not exposed via PyO3 — falling back to Python.")
        return None
    result = _rust_mg(source, psi_bc, r_min, r_max, z_min, z_max, nr, nz, tol, max_cycles)
    return cast("tuple[FloatArray, float, int, bool]", result)