# Source code for scpn_fusion.core._rust_compat

# SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
# © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
# © Code 2020–2026 Miroslav Šotek. All rights reserved.
# ORCID: 0009-0009-3560-0851
# Contact: www.anulum.li | protoscience@anulum.li
# SCPN Fusion Core — Rust Compat
from __future__ import annotations

"""
Backward compatibility layer: imports from Rust (scpn_fusion_rs) if available,
falls back to pure-Python implementations.

Usage:
    from scpn_fusion.core._rust_compat import FusionKernel, RUST_BACKEND
"""
import os
import logging
from collections import deque
from typing import Optional
import numpy as np

logger = logging.getLogger(__name__)

try:
    from scpn_fusion_rs import (
        PyFusionKernel,
        PyEquilibriumResult,
        PyThermodynamicsResult,
        shafranov_bv,
        solve_coil_currents,
        measure_magnetics,
        simulate_tearing_mode,
    )

    _RUST_AVAILABLE = True
except ImportError:
    _RUST_AVAILABLE = False


def _require_monotonic_axis(name: str, values: np.ndarray, expected_len: int) -> np.ndarray:
    arr = np.asarray(values, dtype=np.float64)
    if arr.ndim != 1 or arr.size != int(expected_len):
        raise ValueError(f"{name} must be 1-D with length {expected_len}, got shape {arr.shape}")
    if not np.all(np.isfinite(arr)):
        raise ValueError(f"{name} must contain finite values")
    delta = np.diff(arr)
    if delta.size == 0 or not np.all(delta > 0.0):
        raise ValueError(f"{name} must be strictly increasing")
    return arr


def _require_state_grid(
    name: str,
    values: np.ndarray,
    *,
    nz: int,
    nr: int,
    require_finite: bool,
) -> np.ndarray:
    arr = np.asarray(values, dtype=np.float64)
    expected = (int(nz), int(nr))
    if arr.ndim != 2 or tuple(arr.shape) != expected:
        raise ValueError(f"{name} must have shape {expected}, got {arr.shape}")
    if require_finite and not np.all(np.isfinite(arr)):
        raise ValueError(f"{name} must contain finite values")
    return arr


def _rust_available():
    """Report whether the ``scpn_fusion_rs`` import at module load succeeded."""
    backend_loaded = _RUST_AVAILABLE
    return backend_loaded


class RustAcceleratedKernel:
    """Rust-backed stand-in for the Python ``FusionKernel``.

    Wraps ``PyFusionKernel`` from ``scpn_fusion_rs`` and exposes the same
    attribute surface as the Python kernel (``Psi``, ``R``, ``Z``, ``RR``,
    ``ZZ``, ``cfg``, ...), so downstream code can use either one. The
    equilibrium solve itself runs in Rust (the upstream docstring quotes
    roughly a 20x speedup); the resulting arrays are copied back onto this
    object after every solve.
    """

    def __init__(self, config_path):
        """Load the Rust kernel and mirror its grid and state.

        Parameters
        ----------
        config_path : str or path-like
            JSON configuration file. It is handed to Rust as a string and
            also parsed here so ``self.cfg`` is available.

        Raises
        ------
        ValueError
            If the Rust grid is smaller than 2x2 or an axis is malformed.
        RuntimeError
            If the initial Psi/J_phi arrays fail validation.
        """
        import json

        self._config_path = str(config_path)
        self.state_sync_failures = 0
        self.last_state_sync_error: Optional[str] = None

        # The PyO3 constructor wants a plain ``str`` rather than a ``Path``.
        self._rust = PyFusionKernel(self._config_path)

        # Bridges read ``.cfg`` directly, so keep the parsed JSON around.
        with open(config_path, "r", encoding="utf-8") as handle:
            self.cfg = json.load(handle)

        # Grid dimensions and coordinate axes, copied from the Rust side.
        n_r, n_z = self._rust.grid_shape()
        self.NR, self.NZ = int(n_r), int(n_z)
        if min(self.NR, self.NZ) < 2:
            raise ValueError(f"Rust grid shape must be >= 2x2, got {(self.NR, self.NZ)}")
        self.R = _require_monotonic_axis("R", np.asarray(self._rust.get_r()), self.NR)
        self.Z = _require_monotonic_axis("Z", np.asarray(self._rust.get_z()), self.NZ)
        # Spacing is taken from the first cell of each axis.
        self.dR, self.dZ = (float(axis[1] - axis[0]) for axis in (self.R, self.Z))
        self.RR, self.ZZ = np.meshgrid(self.R, self.Z)

        # Zero placeholders first, then pull and validate the real state.
        state_shape = (self.NZ, self.NR)
        self.Psi, self.J_phi, self.B_R, self.B_Z = (
            np.zeros(state_shape, dtype=np.float64) for _ in range(4)
        )
        self._sync_state_from_rust(context="init", require_finite=True)
        self.compute_b_field()

    def solve_equilibrium(self):
        """Run the Rust Grad-Shafranov solve and refresh the mirrored state.

        Returns whatever the Rust ``solve_equilibrium`` call returns.
        """
        outcome = self._rust.solve_equilibrium()
        # Copy the fresh arrays back, then rebuild B from the new Psi
        # (same step the Python FusionKernel performs).
        self._sync_state_from_rust(context="solve_equilibrium", require_finite=True)
        self.compute_b_field()
        return outcome

    def compute_b_field(self):
        """Derive ``B_R`` and ``B_Z`` from the gradient of ``Psi``.

        Raises
        ------
        ValueError
            If ``Psi`` does not have shape ``(NZ, NR)`` or holds
            non-finite values.
        """
        wanted_shape = (self.NZ, self.NR)
        if tuple(self.Psi.shape) != wanted_shape:
            raise ValueError(
                f"Psi shape mismatch for B-field computation: expected {(self.NZ, self.NR)}, "
                f"got {self.Psi.shape}"
            )
        if not np.isfinite(self.Psi).all():
            raise ValueError("Psi must contain finite values before B-field computation")

        # Psi is stored as (Z, R): axis 0 runs along Z, axis 1 along R.
        dPsi_dZ, dPsi_dR = np.gradient(self.Psi, self.dZ, self.dR)
        # Clamp R so the division stays defined for radii at or below 1e-6.
        inv_r = 1.0 / np.maximum(self.RR, 1e-6)
        self.B_R = -inv_r * dPsi_dZ
        self.B_Z = inv_r * dPsi_dR

    def _sync_state_from_rust(self, *, context: str, require_finite: bool) -> None:
        """Pull Psi and J_phi from Rust, enforcing shape/finiteness contracts.

        On a validation failure the failure counter and last-error string
        are updated, a warning is logged, and ``RuntimeError`` is raised.
        Neither attribute is replaced unless both arrays validate.
        """

        def checked(label, raw):
            return _require_state_grid(
                label,
                np.asarray(raw),
                nz=self.NZ,
                nr=self.NR,
                require_finite=require_finite,
            )

        try:
            fresh_psi = checked("Psi", self._rust.get_psi())
            fresh_j_phi = checked("J_phi", self._rust.get_j_phi())
        except ValueError as exc:
            self.state_sync_failures += 1
            self.last_state_sync_error = str(exc)
            logger.warning("Rust state sync failed during %s: %s", context, exc)
            raise RuntimeError(f"Rust state sync failed during {context}: {exc}") from exc

        self.Psi = fresh_psi
        self.J_phi = fresh_j_phi

    def find_x_point(self, Psi):
        """Locate the field null by minimising ``|grad Psi|`` in the divertor zone.

        Mirrors the Python ``FusionKernel.find_x_point()`` interface. The
        search is restricted to grid points whose Z lies below half of the
        configured ``Z_min``.

        Returns
        -------
        tuple
            ``((R, Z), Psi_value)`` at the selected grid point, or
            ``((0, 0), min(Psi))`` when no grid point lies in that zone.
        """
        # Psi is stored as (Z, R): axis 0 runs along Z, axis 1 along R.
        dPsi_dZ, dPsi_dR = np.gradient(Psi, self.dZ, self.dR)
        grad_mag = np.sqrt(dPsi_dR**2 + dPsi_dZ**2)

        in_divertor = (self.cfg["dimensions"]["Z_min"] * 0.5) > self.ZZ
        if not np.any(in_divertor):
            return (0, 0), np.min(Psi)

        # Points outside the zone get a huge sentinel so argmin ignores them.
        candidates = np.where(in_divertor, grad_mag, 1e9)
        iz, ir = np.unravel_index(np.argmin(candidates), Psi.shape)
        return (self.R[ir], self.Z[iz]), Psi[iz, ir]

    def calculate_thermodynamics(self, p_aux_mw):
        """Forward the thermodynamics calculation to the Rust kernel."""
        return self._rust.calculate_thermodynamics(p_aux_mw)

    def calculate_vacuum_field(self):
        """Compute the vacuum field with the Python reference kernel.

        A fresh Python ``FusionKernel`` is built from the stored config
        path on every call.
        """
        from scpn_fusion.core.fusion_kernel import FusionKernel as _PyFusionKernel

        reference = _PyFusionKernel(self._config_path)
        return reference.calculate_vacuum_field()

    def set_solver_method(self, method: str) -> None:
        """Select the inner linear solver: 'sor' or 'multigrid'."""
        self._rust.set_solver_method(method)

    def solver_method(self) -> str:
        """Return the name of the solver currently selected in Rust."""
        return self._rust.solver_method()

    def save_results(self, filename="equilibrium_nonlinear.npz"):
        """Write ``R``, ``Z``, ``Psi`` and ``J_phi`` to an ``.npz`` archive."""
        payload = {"R": self.R, "Z": self.Z, "Psi": self.Psi, "J_phi": self.J_phi}
        np.savez(filename, **payload)
# Backend selection plus re-export of the Rust-only helpers (with
# compatibility shims where the legacy call signature differs).
# NOTE(review): ``FusionKernel`` is only bound when the Rust extension is
# present; the module docstring's usage example imports it unconditionally.
# Confirm whether a pure-Python binding is expected in the fallback branch.
if _RUST_AVAILABLE:
    FusionKernel = RustAcceleratedKernel
    RUST_BACKEND = True

    def rust_shafranov_bv(*args, **kwargs):
        """Call Rust ``shafranov_bv``, keeping the legacy config-path form alive.

        Accepted call forms:
        - rust_shafranov_bv(r_geo, a_min, ip_ma) -> tuple[float, float, float]
        - rust_shafranov_bv(config_path) -> vacuum Psi array
        """
        legacy_config_call = (
            len(args) == 1 and not kwargs and isinstance(args[0], (str, os.PathLike))
        )
        if not legacy_config_call:
            return shafranov_bv(*args, **kwargs)

        # A single path argument is served by the Python reference kernel.
        from scpn_fusion.core.fusion_kernel import FusionKernel as _PyFusionKernel

        reference = _PyFusionKernel(str(args[0]))
        return reference.calculate_vacuum_field()

    rust_solve_coil_currents = solve_coil_currents
    rust_measure_magnetics = measure_magnetics

    def rust_simulate_tearing_mode(steps: int, seed: Optional[int] = None):
        """Simulate a tearing mode; a given ``seed`` selects the seeded Python path.

        Without a seed the Rust routine is called. With a seed, the Python
        implementation is used with a NumPy generator built from that seed.
        """
        if seed is not None:
            from scpn_fusion.control.disruption_predictor import (
                simulate_tearing_mode as _py_tearing,
            )

            seeded_rng = np.random.default_rng(seed=int(seed))
            return _py_tearing(steps=int(steps), rng=seeded_rng)
        return simulate_tearing_mode(int(steps))

else:
    RUST_BACKEND = False

    def rust_shafranov_bv(*args, **kwargs):
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_solve_coil_currents(*args, **kwargs):
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_measure_magnetics(*args, **kwargs):
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_simulate_tearing_mode(*args, **kwargs):
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")
class RustSnnPool:
    """Compatibility wrapper for Rust SpikingControllerPool.

    Uses the Rust implementation when available and falls back to a
    deterministic NumPy LIF population otherwise. "Available" means both
    that ``scpn_fusion_rs`` imported and that the installed build exposes
    ``PySnnPool``; a build lacking that symbol is treated like a missing
    extension instead of failing construction.

    Parameters
    ----------
    n_neurons : int
        Number of LIF neurons per sub-population (positive/negative).
    gain : float
        Output scaling factor.
    window_size : int
        Sliding window length for rate-code averaging.
    allow_numpy_fallback : bool
        When ``False``, raise :class:`ImportError` if the Rust pool is
        unavailable.
    seed : int
        Seed used by deterministic NumPy compatibility backend.

    Raises
    ------
    ImportError
        If the Rust pool cannot be used and ``allow_numpy_fallback`` is
        ``False``.
    """

    def __init__(
        self,
        n_neurons: int = 50,
        gain: float = 10.0,
        window_size: int = 20,
        *,
        allow_numpy_fallback: bool = True,
        seed: int = 42,
    ):
        missing_symbol: Optional[ImportError] = None
        if _RUST_AVAILABLE:
            try:
                from scpn_fusion_rs import PySnnPool  # type: ignore[import-untyped]
            except ImportError as exc:
                # The extension loaded but this build does not export the
                # pool class; handle it the same way as a missing extension.
                missing_symbol = exc
                logger.warning("Rust PySnnPool not exposed via PyO3 — falling back to NumPy.")
            else:
                self._backend = "rust"
                self._inner = PySnnPool(n_neurons, gain, window_size)
                return

        if not allow_numpy_fallback:
            raise ImportError(
                "scpn_fusion_rs not installed and allow_numpy_fallback=False."
            ) from missing_symbol

        self._backend = "numpy_fallback"
        self._inner = _NumpySnnPoolFallback(
            n_neurons=n_neurons,
            gain=gain,
            window_size=window_size,
            seed=seed,
        )

    def step(self, error: float) -> float:
        """Process *error* through SNN pool and return scalar control output."""
        return self._inner.step(error)

    @property
    def n_neurons(self) -> int:
        """Neuron count per sub-population, as reported by the active backend."""
        return self._inner.n_neurons

    @property
    def gain(self) -> float:
        """Output scaling factor, as reported by the active backend."""
        return self._inner.gain

    @property
    def backend(self) -> str:
        """Either ``"rust"`` or ``"numpy_fallback"``."""
        return self._backend

    def __repr__(self) -> str:
        return (
            f"RustSnnPool(n_neurons={self.n_neurons}, gain={self.gain}, backend='{self.backend}')"
        )
class RustSnnController:
    """Compatibility wrapper for Rust NeuroCyberneticController.

    Uses the Rust implementation when available and falls back to paired
    deterministic NumPy LIF pools otherwise. "Available" means both that
    ``scpn_fusion_rs`` imported and that the installed build exposes
    ``PySnnController``; a build lacking that symbol is treated like a
    missing extension instead of failing construction.

    Parameters
    ----------
    target_r : float
        Target major-radius position [m].
    target_z : float
        Target vertical position [m].
    allow_numpy_fallback : bool
        When ``False``, raise :class:`ImportError` if the Rust controller
        is unavailable.
    seed : int
        Seed used by deterministic NumPy compatibility backend.

    Raises
    ------
    ImportError
        If the Rust controller cannot be used and ``allow_numpy_fallback``
        is ``False``.
    """

    def __init__(
        self,
        target_r: float = 6.2,
        target_z: float = 0.0,
        *,
        allow_numpy_fallback: bool = True,
        seed: int = 42,
    ):
        missing_symbol: Optional[ImportError] = None
        if _RUST_AVAILABLE:
            try:
                from scpn_fusion_rs import PySnnController  # type: ignore[import-untyped]
            except ImportError as exc:
                # The extension loaded but this build does not export the
                # controller class; handle it like a missing extension.
                missing_symbol = exc
                logger.warning(
                    "Rust PySnnController not exposed via PyO3 — falling back to NumPy."
                )
            else:
                self._backend = "rust"
                self._inner = PySnnController(target_r, target_z)
                return

        if not allow_numpy_fallback:
            raise ImportError(
                "scpn_fusion_rs not installed and allow_numpy_fallback=False."
            ) from missing_symbol

        self._backend = "numpy_fallback"
        self._inner = _NumpySnnControllerFallback(
            target_r=target_r,
            target_z=target_z,
            seed=seed,
        )

    def step(self, measured_r: float, measured_z: float) -> tuple[float, float]:
        """Process measured (R, Z) position and return (ctrl_R, ctrl_Z)."""
        return self._inner.step(measured_r, measured_z)

    @property
    def target_r(self) -> float:
        """Target major-radius position held by the active backend."""
        return self._inner.target_r

    @property
    def target_z(self) -> float:
        """Target vertical position held by the active backend."""
        return self._inner.target_z

    @property
    def backend(self) -> str:
        """Either ``"rust"`` or ``"numpy_fallback"``."""
        return self._backend

    def __repr__(self) -> str:
        return (
            f"RustSnnController(target_r={self.target_r}, target_z={self.target_z}, "
            f"backend='{self.backend}')"
        )
class _NumpySnnPoolFallback: """Deterministic local compatibility path matching the Rust SNN pool interface.""" def __init__( self, n_neurons: int, gain: float, window_size: int, *, seed: int, ) -> None: self.n_neurons = int(n_neurons) self.gain = float(gain) self.window_size = int(window_size) if self.n_neurons < 1: raise ValueError("n_neurons must be >= 1.") if not np.isfinite(self.gain): raise ValueError("gain must be finite.") if self.window_size < 1: raise ValueError("window_size must be >= 1.") self._rng_pos = np.random.default_rng(int(seed)) self._rng_neg = np.random.default_rng(int(seed) + 100003) self._v_pos = np.zeros(self.n_neurons, dtype=np.float64) self._v_neg = np.zeros(self.n_neurons, dtype=np.float64) self._history_pos: deque[int] = deque([0] * self.window_size, maxlen=self.window_size) self._history_neg: deque[int] = deque([0] * self.window_size, maxlen=self.window_size) self._alpha = 1.0e-3 / 15.0e-3 self._noise_std = 0.02 self._i_scale = 5.0 self._i_bias = 0.1 self._v_threshold = 0.35 self._v_reset = 0.0 def _step_pop(self, v: np.ndarray, rng: np.random.Generator, input_current: float) -> int: noise = rng.normal(0.0, self._noise_std, size=v.shape) v += self._alpha * (-v + float(input_current) + noise) fired = v >= self._v_threshold n_fired = int(np.count_nonzero(fired)) if n_fired > 0: v[fired] = self._v_reset return n_fired def step(self, error_signal: float) -> float: err = float(error_signal) if not np.isfinite(err): raise ValueError("error_signal must be finite.") input_pos = max(0.0, err) * self._i_scale input_neg = max(0.0, -err) * self._i_scale spikes_pos = self._step_pop(self._v_pos, self._rng_pos, self._i_bias + input_pos) spikes_neg = self._step_pop(self._v_neg, self._rng_neg, self._i_bias + input_neg) self._history_pos.append(spikes_pos) self._history_neg.append(spikes_neg) rate_pos = float(sum(self._history_pos) / (self.window_size * self.n_neurons)) rate_neg = float(sum(self._history_neg) / (self.window_size * self.n_neurons)) return 
float((rate_pos - rate_neg) * self.gain) class _NumpySnnControllerFallback: """Deterministic local compatibility path matching the Rust SNN controller interface.""" def __init__(self, target_r: float, target_z: float, *, seed: int) -> None: self.target_r = float(target_r) self.target_z = float(target_z) if not np.isfinite(self.target_r) or not np.isfinite(self.target_z): raise ValueError("target_r and target_z must be finite.") self._pool_r = _NumpySnnPoolFallback(50, 10.0, 20, seed=int(seed) + 1) self._pool_z = _NumpySnnPoolFallback(50, 20.0, 20, seed=int(seed) + 2) def step(self, measured_r: float, measured_z: float) -> tuple[float, float]: mr = float(measured_r) mz = float(measured_z) if not np.isfinite(mr) or not np.isfinite(mz): raise ValueError("measured_r and measured_z must be finite.") err_r = self.target_r - mr err_z = self.target_z - mz return self._pool_r.step(err_r), self._pool_z.step(err_z)
def rust_multigrid_vcycle(
    source: np.ndarray,
    psi_bc: np.ndarray,
    r_min: float,
    r_max: float,
    z_min: float,
    z_max: float,
    nr: int,
    nz: int,
    tol: float = 1e-6,
    max_cycles: int = 500,
) -> tuple[np.ndarray, float, int, bool] | None:
    """Run the Rust multigrid V-cycle when the extension provides it.

    All arguments are forwarded positionally, unchanged, to
    ``scpn_fusion_rs.multigrid_vcycle``.

    Returns
    -------
    tuple of (psi, residual, n_cycles, converged), or None when Rust is
    unavailable or the installed build does not expose the routine. A
    warning is logged in both ``None`` cases so callers can fall back to
    the Python multigrid.
    """
    if _RUST_AVAILABLE:
        try:
            from scpn_fusion_rs import multigrid_vcycle as _rust_mg  # type: ignore

            return _rust_mg(source, psi_bc, r_min, r_max, z_min, z_max, nr, nz, tol, max_cycles)
        except ImportError:
            logger.warning("Rust multigrid_vcycle not exposed via PyO3 — falling back to Python.")
            return None

    logger.warning("scpn_fusion_rs not installed — falling back to Python multigrid.")
    return None