# SPDX-License-Identifier: AGPL-3.0-or-later
# Commercial license available
# © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
# © Code 2020–2026 Miroslav Šotek. All rights reserved.
# ORCID: 0009-0009-3560-0851
# Contact: www.anulum.li | protoscience@anulum.li
"""Backward-compatibility layer for the optional Rust acceleration backend.
Imports from ``scpn_fusion_rs`` when available, falling back to pure-Python
implementations otherwise.
Usage:
from scpn_fusion.core._rust_compat import FusionKernel, RUST_BACKEND
"""
from __future__ import annotations
import logging
import os
from collections import deque
from pathlib import Path
from typing import Any, Optional, cast
import numpy as np
from numpy.typing import NDArray
from scpn_fusion.io.safe_loaders import checked_json_load
logger = logging.getLogger(__name__)
FloatArray = NDArray[np.float64]
try:
from scpn_fusion_rs import (
PyFusionKernel,
PyEquilibriumResult,
PyThermodynamicsResult,
shafranov_bv,
solve_coil_currents,
measure_magnetics,
simulate_tearing_mode,
)
_RUST_AVAILABLE = True
except ImportError:
_RUST_AVAILABLE = False
def _require_monotonic_axis(name: str, values: FloatArray, expected_len: int) -> FloatArray:
    """Validate a coordinate axis and return it as a float64 array.

    Parameters
    ----------
    name : str
        Label used as the prefix of every error message.
    values : array-like
        Axis samples; converted with ``np.asarray(..., dtype=np.float64)``.
    expected_len : int
        Required number of samples.

    Returns
    -------
    numpy.ndarray
        The converted 1-D float64 array.

    Raises
    ------
    ValueError
        If the array is not 1-D with ``expected_len`` entries, contains a
        non-finite value, or is not strictly increasing. An axis with a
        single sample has no differences and is rejected as not increasing.
    """
    axis = np.asarray(values, dtype=np.float64)

    has_expected_layout = axis.ndim == 1 and axis.size == int(expected_len)
    if not has_expected_layout:
        raise ValueError(f"{name} must be 1-D with length {expected_len}, got shape {axis.shape}")

    if not np.isfinite(axis).all():
        raise ValueError(f"{name} must contain finite values")

    # Consecutive differences; every one must be positive.
    steps = axis[1:] - axis[:-1]
    if steps.size == 0 or not (steps > 0.0).all():
        raise ValueError(f"{name} must be strictly increasing")

    return axis
def _require_state_grid(
    name: str,
    values: FloatArray,
    *,
    nz: int,
    nr: int,
    require_finite: bool,
) -> FloatArray:
    """Validate a 2-D state array against the ``(nz, nr)`` grid shape.

    Parameters
    ----------
    name : str
        Label used as the prefix of every error message.
    values : array-like
        Candidate state array; converted to float64.
    nz, nr : int
        Required number of rows and columns, in that order.
    require_finite : bool
        When true, NaN and infinite entries are rejected.

    Returns
    -------
    numpy.ndarray
        The converted float64 array of shape ``(nz, nr)``.

    Raises
    ------
    ValueError
        On a shape mismatch, or on non-finite entries when
        ``require_finite`` is set.
    """
    grid = np.asarray(values, dtype=np.float64)
    wanted_shape = (int(nz), int(nr))

    # A two-element target shape can only be matched by a 2-D array, so one
    # comparison covers both the rank and the extent check.
    if tuple(grid.shape) != wanted_shape:
        raise ValueError(f"{name} must have shape {wanted_shape}, got {grid.shape}")

    if require_finite:
        if not np.isfinite(grid).all():
            raise ValueError(f"{name} must contain finite values")

    return grid
def _rust_available() -> bool:
    """Report whether the Rust backend was importable.

    Returns the module-level ``_RUST_AVAILABLE`` flag unchanged.
    """
    rust_loaded = _RUST_AVAILABLE
    return rust_loaded
# [docs]  (Sphinx source-link marker left by HTML extraction; not code)
class RustAcceleratedKernel:
    """Drop-in wrapper around the Rust ``PyFusionKernel``.

    Mirrors the Python FusionKernel attribute interface (.Psi, .R, .Z, .RR, .ZZ,
    .cfg, etc.) and delegates the equilibrium solve to Rust for ~20x speedup while
    keeping all attribute accesses compatible with downstream code.

    Attributes
    ----------
    cfg
        Configuration loaded from the JSON file at ``config_path``.
    NR, NZ : int
        Grid extents reported by the Rust kernel.
    R, Z : numpy.ndarray
        Strictly increasing 1-D axes of length ``NR`` and ``NZ``.
    dR, dZ : float
        Spacing between the first two samples of each axis.
    RR, ZZ : numpy.ndarray
        ``np.meshgrid(R, Z)`` coordinate grids, shape ``(NZ, NR)``.
    Psi, J_phi : numpy.ndarray
        State arrays copied from Rust, shape ``(NZ, NR)``.
    B_R, B_Z : numpy.ndarray
        Field components derived from ``Psi`` by :meth:`compute_b_field`.
    state_sync_failures : int
        Number of failed Rust-to-Python state synchronisations.
    last_state_sync_error : str or None
        Message of the most recent synchronisation failure.
    """

    def __init__(self, config_path: str | Path) -> None:
        """Load the Rust kernel and mirror its grid and state as NumPy arrays.

        Parameters
        ----------
        config_path : str or Path
            Configuration file handed to both the Rust kernel and the JSON
            loader.

        Raises
        ------
        ValueError
            If the Rust grid is smaller than 2x2 or an axis fails validation.
        RuntimeError
            If the initial ``Psi``/``J_phi`` arrays fail validation.
        """
        self._config_path = str(config_path)
        self.state_sync_failures = 0
        self.last_state_sync_error: Optional[str] = None

        # Load via Rust (PyO3 expects str, not Path)
        self._rust = PyFusionKernel(self._config_path)

        # Also load JSON config for attribute access (bridges read .cfg directly)
        self.cfg = checked_json_load(config_path)

        # Mirror grid attributes
        nr, nz = self._rust.grid_shape()
        self.NR = int(nr)
        self.NZ = int(nz)
        if self.NR < 2 or self.NZ < 2:
            raise ValueError(f"Rust grid shape must be >= 2x2, got {(self.NR, self.NZ)}")
        self.R = _require_monotonic_axis("R", np.asarray(self._rust.get_r()), self.NR)
        self.Z = _require_monotonic_axis("Z", np.asarray(self._rust.get_z()), self.NZ)
        self.dR = float(self.R[1] - self.R[0])
        self.dZ = float(self.Z[1] - self.Z[0])
        self.RR, self.ZZ = np.meshgrid(self.R, self.Z)

        # Initialize and validate state from Rust arrays.
        self.Psi: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self.J_phi: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self.B_R: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self.B_Z: FloatArray = np.zeros((self.NZ, self.NR), dtype=np.float64)
        self._sync_state_from_rust(context="init", require_finite=True)
        self.compute_b_field()

    def solve_equilibrium(self) -> Any:
        """Solve Grad-Shafranov equilibrium via Rust backend.

        Returns
        -------
        Any
            Whatever the Rust ``solve_equilibrium`` call returns, unchanged.

        Raises
        ------
        RuntimeError
            If the arrays read back from Rust fail shape or finiteness checks.
        """
        result = self._rust.solve_equilibrium()
        # Sync arrays back to Python attributes
        self._sync_state_from_rust(context="solve_equilibrium", require_finite=True)
        # Compute B-field from Psi (matching Python FusionKernel.compute_b_field)
        self.compute_b_field()
        return result

    def compute_b_field(self) -> None:
        """Compute magnetic field components from Psi gradient.

        Stores ``B_R = -(1/R) dPsi/dZ`` and ``B_Z = (1/R) dPsi/dR`` on the
        instance.

        Raises
        ------
        ValueError
            If ``Psi`` does not have shape ``(NZ, NR)`` or holds non-finite
            values.
        """
        if tuple(self.Psi.shape) != (self.NZ, self.NR):
            raise ValueError(
                f"Psi shape mismatch for B-field computation: expected {(self.NZ, self.NR)}, "
                f"got {self.Psi.shape}"
            )
        if not np.all(np.isfinite(self.Psi)):
            raise ValueError("Psi must contain finite values before B-field computation")
        # Psi is indexed (Z, R): axis-0 = Z, axis-1 = R
        dPsi_dZ, dPsi_dR = np.gradient(self.Psi, self.dZ, self.dR)
        # Clamp R away from zero so the 1/R factor stays finite.
        R_safe = np.maximum(self.RR, 1e-6)
        self.B_R = -(1.0 / R_safe) * dPsi_dZ
        self.B_Z = (1.0 / R_safe) * dPsi_dR

    def _sync_state_from_rust(self, *, context: str, require_finite: bool) -> None:
        """Synchronize Psi/J_phi arrays from Rust and enforce shape/finite contracts.

        Parameters
        ----------
        context : str
            Name of the calling operation, used in log and error messages.
        require_finite : bool
            Whether NaN/inf entries cause the synchronisation to fail.

        Raises
        ------
        RuntimeError
            If either array fails validation. The failure counter and last
            error message are updated first, and ``Psi``/``J_phi`` keep their
            previous values.
        """
        try:
            psi = _require_state_grid(
                "Psi",
                np.asarray(self._rust.get_psi()),
                nz=self.NZ,
                nr=self.NR,
                require_finite=require_finite,
            )
            j_phi = _require_state_grid(
                "J_phi",
                np.asarray(self._rust.get_j_phi()),
                nz=self.NZ,
                nr=self.NR,
                require_finite=require_finite,
            )
        except ValueError as exc:
            self.state_sync_failures += 1
            self.last_state_sync_error = str(exc)
            logger.warning("Rust state sync failed during %s: %s", context, exc)
            raise RuntimeError(f"Rust state sync failed during {context}: {exc}") from exc
        # Assign only after both arrays validated, so a failure never leaves
        # a half-updated state.
        self.Psi = psi
        self.J_phi = j_phi

    def find_x_point(self, Psi: FloatArray) -> tuple[tuple[float, float], float]:
        """Locate the null point (B=0) using local minimisation.

        Matches the Python ``FusionKernel.find_x_point()`` interface.

        The search is restricted to grid cells whose Z coordinate lies below
        ``0.5 * cfg["dimensions"]["Z_min"]``. Among those, the cell with the
        smallest gradient magnitude of ``Psi`` is returned.

        Parameters
        ----------
        Psi : numpy.ndarray
            Flux array indexed ``(Z, R)``.

        Returns
        -------
        tuple
            ``((R, Z), Psi_value)`` at the selected cell, or
            ``((0.0, 0.0), min(Psi))`` when no cell lies in the search region.
        """
        # Psi is indexed (Z, R): axis-0 = Z, axis-1 = R
        dPsi_dZ, dPsi_dR = np.gradient(Psi, self.dZ, self.dR)
        B_mag = np.sqrt(dPsi_dR**2 + dPsi_dZ**2)

        mask_divertor = (self.cfg["dimensions"]["Z_min"] * 0.5) > self.ZZ
        if np.any(mask_divertor):
            # Exclude cells outside the search region with +inf: a finite
            # sentinel would win the argmin whenever every in-region gradient
            # magnitude exceeded it.
            masked_B = np.where(mask_divertor, B_mag, np.inf)
            idx_min = int(np.argmin(masked_B))
            iz, ir = np.unravel_index(idx_min, Psi.shape)
            return (float(self.R[ir]), float(self.Z[iz])), float(Psi[iz, ir])
        return (0.0, 0.0), float(np.min(Psi))

    def calculate_thermodynamics(self, p_aux_mw: float) -> Any:
        """Calculate thermodynamics via Rust backend.

        ``p_aux_mw`` is forwarded unchanged and the Rust result is returned
        as-is.
        """
        return self._rust.calculate_thermodynamics(p_aux_mw)

    def calculate_vacuum_field(self) -> FloatArray:
        """Compute vacuum field with Python reference implementation.

        A fresh pure-Python ``FusionKernel`` is built from the stored config
        path on every call; the Rust kernel is not involved.
        """
        from scpn_fusion.core.fusion_kernel import FusionKernel as _PyFusionKernel

        fk = _PyFusionKernel(self._config_path)
        return fk.calculate_vacuum_field()

    def set_solver_method(self, method: str) -> None:
        """Set inner linear solver: 'sor' or 'multigrid'."""
        self._rust.set_solver_method(method)

    def solver_method(self) -> str:
        """Get current solver method name."""
        return str(self._rust.solver_method())

    def save_results(self, filename: str = "equilibrium_nonlinear.npz") -> None:
        """Save current state to .npz file.

        Writes the ``R``, ``Z``, ``Psi`` and ``J_phi`` arrays with
        ``np.savez``.
        """
        np.savez(filename, R=self.R, Z=self.Z, Psi=self.Psi, J_phi=self.J_phi)
# Select the public ``FusionKernel`` implementation once, at import time.
if _RUST_AVAILABLE:
    # Rust extension imported: expose the accelerated wrapper.
    FusionKernel = RustAcceleratedKernel
    RUST_BACKEND = True
else:
    # No Rust extension: bind the pure-Python kernel so that
    # ``from scpn_fusion.core._rust_compat import FusionKernel`` keeps working,
    # as the module docstring promises.
    # NOTE(review): assumes scpn_fusion.core.fusion_kernel does not import this
    # module at import time (no circular import) — confirm.
    from scpn_fusion.core.fusion_kernel import FusionKernel  # type: ignore[assignment]

    RUST_BACKEND = False
# Re-export Rust-only helpers (with compatibility shims where needed)
if _RUST_AVAILABLE:

    def rust_shafranov_bv(*args: Any, **kwargs: Any) -> Any:
        """Compatibility wrapper for legacy config-path invocation.

        Supported call forms:

        - rust_shafranov_bv(r_geo, a_min, ip_ma) -> tuple[float, float, float]
        - rust_shafranov_bv(config_path) -> vacuum Psi array

        A single positional ``str``/``os.PathLike`` argument selects the legacy
        form, which is served by the pure-Python ``FusionKernel``. Every other
        call is forwarded unchanged to the Rust ``shafranov_bv``.
        """
        if len(args) == 1 and not kwargs and isinstance(args[0], (str, os.PathLike)):
            from scpn_fusion.core.fusion_kernel import FusionKernel as _PyFusionKernel

            # os.fsdecode honours __fspath__ and always yields str; str() on an
            # arbitrary PathLike is not guaranteed to return the path itself.
            fk = _PyFusionKernel(os.fsdecode(args[0]))
            return fk.calculate_vacuum_field()
        return shafranov_bv(*args, **kwargs)

    # Direct re-exports: no signature adaptation needed.
    rust_solve_coil_currents = solve_coil_currents
    rust_measure_magnetics = measure_magnetics

    def rust_simulate_tearing_mode(steps: int, seed: Optional[int] = None) -> Any:
        """Rust tearing mode with optional deterministic seed compatibility.

        Without a seed the Rust ``simulate_tearing_mode`` is called with
        ``steps``. With a seed, the Python ``simulate_tearing_mode`` from
        ``scpn_fusion.control.disruption_risk_runtime`` is used instead, driven
        by ``np.random.default_rng(seed)``.
        """
        if seed is None:
            return simulate_tearing_mode(int(steps))
        from scpn_fusion.control.disruption_risk_runtime import (
            simulate_tearing_mode as _py_tearing,
        )

        rng = np.random.default_rng(seed=int(seed))
        return cast("Any", _py_tearing)(steps=int(steps), rng=rng)

else:

    def rust_shafranov_bv(*args: Any, **kwargs: Any) -> Any:
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_solve_coil_currents(*args: Any, **kwargs: Any) -> Any:
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_measure_magnetics(*args: Any, **kwargs: Any) -> Any:
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")

    def rust_simulate_tearing_mode(steps: int, seed: Optional[int] = None) -> Any:
        """Stub used when the Rust extension is missing; always raises ImportError."""
        raise ImportError("scpn_fusion_rs not installed. Run: maturin develop")
# [docs]  (Sphinx source-link marker left by HTML extraction; not code)
class RustSnnPool:
    """Facade over the Rust ``PySnnPool`` with a NumPy stand-in.

    The Rust pool is used whenever the extension is importable; otherwise a
    deterministic NumPy LIF population takes its place.

    Parameters
    ----------
    n_neurons : int
        Number of LIF neurons per sub-population (positive/negative).
    gain : float
        Output scaling factor.
    window_size : int
        Sliding window length for rate-code averaging.
    allow_numpy_fallback : bool
        When ``False``, raise :class:`ImportError` if Rust extension is
        unavailable.
    seed : int
        Seed used by deterministic NumPy compatibility backend.
    """

    def __init__(
        self,
        n_neurons: int = 50,
        gain: float = 10.0,
        window_size: int = 20,
        *,
        allow_numpy_fallback: bool = True,
        seed: int = 42,
    ):
        """Pick the backend and build the wrapped pool object."""
        self._backend = "rust"
        if _RUST_AVAILABLE:
            from scpn_fusion_rs import PySnnPool

            self._inner = PySnnPool(n_neurons, gain, window_size)
        elif allow_numpy_fallback:
            self._backend = "numpy_fallback"
            self._inner = _NumpySnnPoolFallback(
                n_neurons=n_neurons,
                gain=gain,
                window_size=window_size,
                seed=seed,
            )
        else:
            raise ImportError("scpn_fusion_rs not installed and allow_numpy_fallback=False.")

    def step(self, error: float) -> float:
        """Feed *error* to the wrapped pool and return its output as a float."""
        control = self._inner.step(error)
        return float(control)

    @property
    def n_neurons(self) -> int:
        """Neuron count reported by the wrapped pool, as an ``int``."""
        count = self._inner.n_neurons
        return int(count)

    @property
    def gain(self) -> float:
        """Gain reported by the wrapped pool, as a ``float``."""
        scale = self._inner.gain
        return float(scale)

    @property
    def backend(self) -> str:
        """Backend label: ``"rust"`` or ``"numpy_fallback"``."""
        return self._backend

    def __repr__(self) -> str:
        neurons = self.n_neurons
        scale = self.gain
        label = self.backend
        return f"RustSnnPool(n_neurons={neurons}, gain={scale}, backend='{label}')"
# [docs]  (Sphinx source-link marker left by HTML extraction; not code)
class RustSnnController:
    """Facade over the Rust ``PySnnController`` with a NumPy stand-in.

    The Rust controller is used whenever the extension is importable;
    otherwise a pair of deterministic NumPy LIF pools takes its place.

    Parameters
    ----------
    target_r : float
        Target major-radius position [m].
    target_z : float
        Target vertical position [m].
    allow_numpy_fallback : bool
        When ``False``, raise :class:`ImportError` if Rust extension is
        unavailable.
    seed : int
        Seed used by deterministic NumPy compatibility backend.
    """

    def __init__(
        self,
        target_r: float = 6.2,
        target_z: float = 0.0,
        *,
        allow_numpy_fallback: bool = True,
        seed: int = 42,
    ):
        """Pick the backend and build the wrapped controller object."""
        self._backend = "rust"
        if _RUST_AVAILABLE:
            from scpn_fusion_rs import PySnnController

            self._inner = PySnnController(target_r, target_z)
        elif allow_numpy_fallback:
            self._backend = "numpy_fallback"
            self._inner = _NumpySnnControllerFallback(
                target_r=target_r,
                target_z=target_z,
                seed=seed,
            )
        else:
            raise ImportError("scpn_fusion_rs not installed and allow_numpy_fallback=False.")

    def step(self, measured_r: float, measured_z: float) -> tuple[float, float]:
        """Feed the measured (R, Z) position through and return (ctrl_R, ctrl_Z)."""
        radial, vertical = self._inner.step(measured_r, measured_z)
        return (float(radial), float(vertical))

    @property
    def target_r(self) -> float:
        """Major-radius target reported by the wrapped controller."""
        value = self._inner.target_r
        return float(value)

    @property
    def target_z(self) -> float:
        """Vertical target reported by the wrapped controller."""
        value = self._inner.target_z
        return float(value)

    @property
    def backend(self) -> str:
        """Backend label: ``"rust"`` or ``"numpy_fallback"``."""
        return self._backend

    def __repr__(self) -> str:
        head = f"RustSnnController(target_r={self.target_r}, target_z={self.target_z}, "
        tail = f"backend='{self.backend}')"
        return head + tail
class _NumpySnnPoolFallback:
    """Seeded NumPy stand-in exposing the same interface as the Rust SNN pool.

    Two neuron populations (one per error sign) are integrated with separate
    random generators; the output is the gain-scaled difference of their
    windowed firing rates.
    """

    def __init__(
        self,
        n_neurons: int,
        gain: float,
        window_size: int,
        *,
        seed: int,
    ) -> None:
        """Validate the sizes and set up generators, potentials and histories.

        Raises
        ------
        ValueError
            If ``n_neurons`` or ``window_size`` is below 1, or ``gain`` is not
            finite.
        """
        self.n_neurons = int(n_neurons)
        self.gain = float(gain)
        self.window_size = int(window_size)

        if self.n_neurons < 1:
            raise ValueError("n_neurons must be >= 1.")
        if not np.isfinite(self.gain):
            raise ValueError("gain must be finite.")
        if self.window_size < 1:
            raise ValueError("window_size must be >= 1.")

        # Independent generators per population; the second seed is offset so
        # the two noise streams differ.
        base_seed = int(seed)
        self._rng_pos = np.random.default_rng(base_seed)
        self._rng_neg = np.random.default_rng(base_seed + 100003)

        # Membrane potentials start at zero.
        self._v_pos = np.zeros(self.n_neurons, dtype=np.float64)
        self._v_neg = np.zeros(self.n_neurons, dtype=np.float64)

        # Spike-count histories start as full windows of zeros.
        width = self.window_size
        self._history_pos: deque[int] = deque((0 for _ in range(width)), maxlen=width)
        self._history_neg: deque[int] = deque((0 for _ in range(width)), maxlen=width)

        # Model constants.
        # NOTE(review): alpha is presumably dt / tau (1 ms over 15 ms) — confirm.
        self._alpha = 1.0e-3 / 15.0e-3
        self._noise_std = 0.02
        self._i_scale = 5.0
        self._i_bias = 0.1
        self._v_threshold = 0.35
        self._v_reset = 0.0

    def _step_pop(self, v: FloatArray, rng: np.random.Generator, input_current: float) -> int:
        """Advance one population in place and return how many neurons fired."""
        jitter = rng.normal(0.0, self._noise_std, size=v.shape)
        drive = -v + float(input_current) + jitter
        v += self._alpha * drive

        spiking = v >= self._v_threshold
        spike_count = int(spiking.sum())
        if spike_count:
            v[spiking] = self._v_reset
        return spike_count

    def step(self, error_signal: float) -> float:
        """Advance both populations by one step and return the control output.

        Raises
        ------
        ValueError
            If ``error_signal`` is not finite.
        """
        err = float(error_signal)
        if not np.isfinite(err):
            raise ValueError("error_signal must be finite.")

        # Positive errors drive one population, negative errors the other.
        drive_pos = self._i_bias + max(0.0, err) * self._i_scale
        drive_neg = self._i_bias + max(0.0, -err) * self._i_scale

        self._history_pos.append(self._step_pop(self._v_pos, self._rng_pos, drive_pos))
        self._history_neg.append(self._step_pop(self._v_neg, self._rng_neg, drive_neg))

        # Firing rate = spikes in the window / (window length * population size).
        slots = self.window_size * self.n_neurons
        rate_pos = float(sum(self._history_pos) / slots)
        rate_neg = float(sum(self._history_neg) / slots)
        return float((rate_pos - rate_neg) * self.gain)
class _NumpySnnControllerFallback:
    """Seeded NumPy stand-in exposing the same interface as the Rust SNN controller.

    Holds one fallback pool per axis and feeds each the difference between
    the target and the measured coordinate.
    """

    def __init__(self, target_r: float, target_z: float, *, seed: int) -> None:
        """Store the targets and create the two per-axis pools.

        Raises
        ------
        ValueError
            If either target is not finite.
        """
        self.target_r = float(target_r)
        self.target_z = float(target_z)
        targets_finite = np.isfinite(self.target_r) and np.isfinite(self.target_z)
        if not targets_finite:
            raise ValueError("target_r and target_z must be finite.")

        # Each axis gets its own seed offset; the Z pool uses twice the gain.
        base_seed = int(seed)
        self._pool_r = _NumpySnnPoolFallback(50, 10.0, 20, seed=base_seed + 1)
        self._pool_z = _NumpySnnPoolFallback(50, 20.0, 20, seed=base_seed + 2)

    def step(self, measured_r: float, measured_z: float) -> tuple[float, float]:
        """Return ``(ctrl_R, ctrl_Z)`` for the measured position.

        Raises
        ------
        ValueError
            If either measurement is not finite.
        """
        r_now = float(measured_r)
        z_now = float(measured_z)
        if not (np.isfinite(r_now) and np.isfinite(z_now)):
            raise ValueError("measured_r and measured_z must be finite.")

        ctrl_r = self._pool_r.step(self.target_r - r_now)
        ctrl_z = self._pool_z.step(self.target_z - z_now)
        return ctrl_r, ctrl_z
# [docs]  (Sphinx source-link marker left by HTML extraction; not code)
def rust_multigrid_vcycle(
    source: FloatArray,
    psi_bc: FloatArray,
    r_min: float,
    r_max: float,
    z_min: float,
    z_max: float,
    nr: int,
    nz: int,
    tol: float = 1e-6,
    max_cycles: int = 500,
) -> tuple[FloatArray, float, int, bool] | None:
    """Call Rust multigrid V-cycle if available, else return None.

    All arguments are forwarded positionally and unchanged to
    ``scpn_fusion_rs.multigrid_vcycle``.

    Returns
    -------
    tuple of (psi, residual, n_cycles, converged), or None when Rust is unavailable.

    Notes
    -----
    Only a failure to import the Rust function is treated as "unavailable".
    Exceptions raised by the solver call itself propagate to the caller.
    """
    if not _RUST_AVAILABLE:
        logger.warning("scpn_fusion_rs not installed — falling back to Python multigrid.")
        return None

    # Keep the try body to the import alone, so an ImportError raised inside
    # the solver call is not misreported as a missing binding.
    try:
        from scpn_fusion_rs import multigrid_vcycle as _rust_mg
    except ImportError:
        logger.warning("Rust multigrid_vcycle not exposed via PyO3 — falling back to Python.")
        return None

    result = _rust_mg(source, psi_bc, r_min, r_max, z_min, z_max, nr, nz, tol, max_cycles)
    return cast("tuple[FloatArray, float, int, bool]", result)