# SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
# © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
# © Code 2020–2026 Miroslav Šotek. All rights reserved.
# ORCID: 0009-0009-3560-0851
# Contact: www.anulum.li | protoscience@anulum.li
# SCPN Fusion Core — Neuro-Symbolic Logic Compiler
"""
Neuro-Symbolic Controller — oracle + SC dual paths.
Loads a ``.scpnctl.json`` artifact and provides deterministic
``step(obs, k) → ControlAction`` with JSONL logging.
"""
from __future__ import annotations
import json
import time
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, cast
import numpy as np
from numpy.typing import NDArray
from .artifact import Artifact
from .contracts import (
ControlAction,
FeatureAxisSpec,
ControlScales,
ControlTargets,
)
from .controller_backend_mixin import NeuroSymbolicControllerBackendMixin
from .controller_features_mixin import NeuroSymbolicControllerFeaturesMixin
from .controller_runtime_backend import probe_rust_runtime_bindings
from scpn_fusion.fallback_telemetry import record_fallback_event
# Shorthand for 1-D/2-D float64 NumPy arrays used throughout the controller.
FloatArray = NDArray[np.float64]
# Module-level Rust runtime bindings. The defaults below exist to declare the
# names and their types for static checkers; they are overwritten immediately
# by the probe call at import time.
_HAS_RUST_SCPN_RUNTIME = False
_rust_dense_activations: Optional[Callable[[FloatArray, FloatArray], object]] = None
_rust_marking_update: Optional[
    Callable[[FloatArray, FloatArray, FloatArray, FloatArray], object]
] = None
_rust_sample_firing: Optional[Callable[[FloatArray, int, int, bool], object]] = None
# Probe once at import: returns (available?, dense-activations fn,
# marking-update fn, firing-sampler fn) — fns are None when unavailable.
(
    _HAS_RUST_SCPN_RUNTIME,
    _rust_dense_activations,
    _rust_marking_update,
    _rust_sample_firing,
) = probe_rust_runtime_bindings()
# (removed stray "[docs]" Sphinx viewcode export artifact — not valid Python)
class NeuroSymbolicController(
NeuroSymbolicControllerFeaturesMixin,
NeuroSymbolicControllerBackendMixin,
):
"""Reference controller with oracle float and stochastic paths.
Parameters
----------
artifact : loaded ``.scpnctl.json`` artifact.
seed_base : 64-bit base seed for deterministic stochastic execution.
targets : control setpoint targets.
scales : normalisation scales.
"""
    def __init__(
        self,
        artifact: Artifact,
        seed_base: int,
        targets: ControlTargets,
        scales: ControlScales,
        sc_n_passes: int = 8,
        sc_bitflip_rate: float = 0.0,
        sc_binary_margin: Optional[float] = None,
        sc_antithetic: bool = True,
        enable_oracle_diagnostics: bool = True,
        feature_axes: Optional[Sequence[FeatureAxisSpec]] = None,
        runtime_profile: str = "adaptive",
        runtime_backend: str = "auto",
        rust_backend_min_problem_size: int = 1,
        sc_antithetic_chunk_size: int = 2048,
    ) -> None:
        """Validate configuration and precompute every per-step buffer.

        All scratch arrays are allocated once here so that ``step`` /
        ``step_traceable`` run allocation-free in tight control loops.
        """
        def _require_int_ge(name: str, value: object, minimum: int) -> int:
            # ``bool`` subclasses ``int`` — reject it explicitly so that
            # True/False cannot masquerade as 1/0.
            if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
                raise ValueError(f"{name} must be an integer >= {minimum}.")
            parsed = int(value)
            if parsed < minimum:
                raise ValueError(f"{name} must be an integer >= {minimum}.")
            return parsed

        self.artifact = artifact
        self.seed_base = int(seed_base)
        self.targets = targets
        self.scales = scales
        self._sc_n_passes = _require_int_ge("sc_n_passes", sc_n_passes, 1)
        self._sc_bitflip_rate = float(sc_bitflip_rate)
        # Bit-flip rate is a probability; NaN/inf must be rejected as well.
        if (
            not np.isfinite(self._sc_bitflip_rate)
            or self._sc_bitflip_rate < 0.0
            or self._sc_bitflip_rate > 1.0
        ):
            raise ValueError("sc_bitflip_rate must be finite and in [0, 1].")
        self._runtime_profile = runtime_profile.strip().lower()
        if self._runtime_profile not in {"adaptive", "deterministic", "traceable"}:
            raise ValueError("runtime_profile must be 'adaptive', 'deterministic', or 'traceable'")
        self._sc_antithetic = bool(sc_antithetic)
        self._enable_oracle_diagnostics = bool(enable_oracle_diagnostics)
        self._feature_axes = list(feature_axes) if feature_axes is not None else None
        self._runtime_backend_request = runtime_backend.strip().lower()
        if self._runtime_backend_request not in {"auto", "numpy", "rust"}:
            raise ValueError("runtime_backend must be 'auto', 'numpy', or 'rust'")
        self._rust_backend_min_problem_size = _require_int_ge(
            "rust_backend_min_problem_size", rust_backend_min_problem_size, 1
        )
        self._sc_antithetic_chunk_size = _require_int_ge(
            "sc_antithetic_chunk_size", sc_antithetic_chunk_size, 1
        )
        # Default feature axes: R/Z axis position errors against the targets.
        if self._feature_axes is not None:
            axes = list(self._feature_axes)
        else:
            axes = [
                FeatureAxisSpec(
                    obs_key="R_axis_m",
                    target=self.targets.R_target_m,
                    scale=self.scales.R_scale_m,
                    pos_key="x_R_pos",
                    neg_key="x_R_neg",
                ),
                FeatureAxisSpec(
                    obs_key="Z_axis_m",
                    target=self.targets.Z_target_m,
                    scale=self.scales.Z_scale_m,
                    pos_key="x_Z_pos",
                    neg_key="x_Z_neg",
                ),
            ]
        self._feature_axes_effective = axes
        self._axis_count = len(axes)
        self._axis_obs_keys = [axis.obs_key for axis in axes]
        self._axis_targets = np.asarray([axis.target for axis in axes], dtype=np.float64)
        # Floor |scale| at 1e-12 so axis-error normalisation never divides
        # by (near-)zero.
        self._axis_scales = np.asarray(
            [axis.scale if abs(axis.scale) > 1e-12 else 1e-12 for axis in axes],
            dtype=np.float64,
        )
        self._axis_pos_keys = [axis.pos_key for axis in axes]
        self._axis_neg_keys = [axis.neg_key for axis in axes]
        # Pre-allocated scratch buffers, reused every tick.
        self._empty = np.zeros(0, dtype=np.float64)
        self._tmp_obs_vals = np.zeros(self._axis_count, dtype=np.float64)
        self._tmp_feature_err = np.zeros(self._axis_count, dtype=np.float64)
        self._tmp_feature_pos = np.zeros(self._axis_count, dtype=np.float64)
        self._tmp_feature_neg = np.zeros(self._axis_count, dtype=np.float64)
        # Flatten weight matrices for fast indexing
        self._w_in = artifact.weights.w_in.data[:]
        self._w_out = artifact.weights.w_out.data[:]
        self._nP = artifact.nP
        self._nT = artifact.nT
        # W_in: transitions × places; W_out: places × transitions.
        self._W_in = np.asarray(self._w_in, dtype=np.float64).reshape(self._nT, self._nP)
        self._W_out = np.asarray(self._w_out, dtype=np.float64).reshape(self._nP, self._nT)
        self._W_in_t = self._W_in.T
        self._tmp_activations = np.zeros(self._nT, dtype=np.float64)
        self._tmp_consumption = np.zeros(self._nP, dtype=np.float64)
        self._tmp_production = np.zeros(self._nP, dtype=np.float64)
        self._tmp_marking_oracle = np.zeros(self._nP, dtype=np.float64)
        self._tmp_marking_sc = np.zeros(self._nP, dtype=np.float64)
        self._tmp_marking_input = np.zeros(self._nP, dtype=np.float64)
        self._tmp_sc_counts = np.zeros(self._nT, dtype=np.int64)
        self._thresholds = np.asarray(
            [tr.threshold for tr in artifact.topology.transitions], dtype=np.float64
        )
        # Per-transition firing delays in control ticks (negatives clamp to 0;
        # transitions without the attribute default to immediate firing).
        self._delay_ticks = np.asarray(
            [max(int(getattr(tr, "delay_ticks", 0)), 0) for tr in artifact.topology.transitions],
            dtype=np.int64,
        )
        self._delay_immediate_idx = np.flatnonzero(self._delay_ticks == 0).astype(
            np.int64, copy=False
        )
        self._delay_delayed_idx = np.flatnonzero(self._delay_ticks > 0).astype(np.int64, copy=False)
        if self._delay_delayed_idx.size:
            self._delay_delayed_offsets = np.asarray(
                self._delay_ticks[self._delay_delayed_idx], dtype=np.int64
            )
            self._tmp_delay_slots = np.zeros(self._delay_delayed_idx.size, dtype=np.int64)
        else:
            self._delay_delayed_offsets = np.asarray([], dtype=np.int64)
            self._tmp_delay_slots = np.asarray([], dtype=np.int64)
        # Ring buffers for delayed firings — one row per possible delay value,
        # separate cursors for the oracle and SC paths.
        self._max_delay_ticks = int(np.max(self._delay_ticks)) if self._delay_ticks.size else 0
        pending_len = self._max_delay_ticks + 1
        self._oracle_pending = np.zeros((pending_len, self._nT), dtype=np.float64)
        self._sc_pending = np.zeros((pending_len, self._nT), dtype=np.float64)
        self._oracle_cursor = 0
        self._sc_cursor = 0
        self._firing_mode = artifact.meta.firing_mode
        # Per-transition margins; `or default_margin` treats both None and 0.0
        # as "use the artifact-level default".
        default_margin = float(getattr(artifact.meta, "firing_margin", 0.05) or 0.05)
        self._margins = np.asarray(
            [
                float((tr.margin if tr.margin is not None else default_margin) or default_margin)
                for tr in artifact.topology.transitions
            ],
            dtype=np.float64,
        )
        # Binary-margin default depends on the runtime profile.
        if sc_binary_margin is None:
            if self._runtime_profile == "adaptive":
                self._sc_binary_margin = 0.05
            else:
                self._sc_binary_margin = 0.0
        else:
            self._sc_binary_margin = float(sc_binary_margin)
        if not np.isfinite(self._sc_binary_margin) or self._sc_binary_margin < 0.0:
            raise ValueError("sc_binary_margin must be finite and >= 0.")
        # Backend selection: an explicit request wins; "auto" takes Rust only
        # when bindings exist and the problem is large enough to amortise the
        # FFI overhead. Fallbacks are recorded for telemetry.
        problem_size = int(self._nP * self._nT)
        rust_eligible = _HAS_RUST_SCPN_RUNTIME and (
            problem_size >= self._rust_backend_min_problem_size
        )
        if self._runtime_backend_request == "numpy":
            self._runtime_backend = "numpy"
        elif self._runtime_backend_request == "rust":
            if _HAS_RUST_SCPN_RUNTIME:
                self._runtime_backend = "rust"
            else:
                self._runtime_backend = "numpy"
                record_fallback_event(
                    "scpn_controller",
                    "rust_backend_unavailable",
                    context={"runtime_backend_request": "rust"},
                )
        else:
            self._runtime_backend = "rust" if rust_eligible else "numpy"
            if self._runtime_backend == "numpy" and not _HAS_RUST_SCPN_RUNTIME:
                record_fallback_event(
                    "scpn_controller",
                    "auto_backend_numpy_due_to_missing_rust",
                    context={"problem_size": int(problem_size)},
                )
        # Injection sources not produced by the feature axes are "passthrough"
        # (read straight from obs); any passthrough disables step_traceable.
        produced_feature_keys = set(self._axis_pos_keys)
        produced_feature_keys.update(self._axis_neg_keys)
        passthrough_sources: list[str] = []
        for inj in self.artifact.initial_state.place_injections:
            src = inj.source
            if src not in produced_feature_keys and src not in passthrough_sources:
                passthrough_sources.append(src)
        self._passthrough_sources = passthrough_sources
        self._traceable_ready = len(self._passthrough_sources) == 0
        # Map feature key -> (axis index, is-positive-component).
        key_to_axis: Dict[str, Tuple[int, bool]] = {}
        for i, key in enumerate(self._axis_pos_keys):
            key_to_axis[key] = (i, True)
        for i, key in enumerate(self._axis_neg_keys):
            key_to_axis[key] = (i, False)
        # Live state
        self._marking = np.asarray(artifact.initial_state.marking, dtype=np.float64).copy()
        injections = artifact.initial_state.place_injections
        self._inj_sources = [inj.source for inj in injections]
        self._inj_count = len(self._inj_sources)
        self._inj_place_ids = np.asarray([inj.place_id for inj in injections], dtype=np.int64)
        self._inj_scales = np.asarray([inj.scale for inj in injections], dtype=np.float64)
        self._inj_offsets = np.asarray([inj.offset for inj in injections], dtype=np.float64)
        self._inj_clamp_mask = np.asarray(
            [bool(inj.clamp_0_1) for inj in injections], dtype=np.bool_
        )
        self._inj_clamp_idx = np.flatnonzero(self._inj_clamp_mask)
        self._inj_has_clamp = bool(self._inj_clamp_idx.size)
        # Pre-resolved axis index / polarity per injection (-1 = passthrough).
        self._inj_source_axis_idx = np.full(self._inj_count, -1, dtype=np.int64)
        self._inj_source_axis_pos = np.zeros(self._inj_count, dtype=np.bool_)
        self._tmp_inj_values = np.zeros(self._inj_count, dtype=np.float64)
        passthrough_pairs: list[Tuple[int, str]] = []
        for i, src in enumerate(self._inj_sources):
            axis_info = key_to_axis.get(src)
            if axis_info is not None:
                axis_idx, is_pos = axis_info
                self._inj_source_axis_idx[i] = int(axis_idx)
                self._inj_source_axis_pos[i] = bool(is_pos)
            else:
                passthrough_pairs.append((i, src))
        self._inj_passthrough_pairs = passthrough_pairs
        # Readout: each action differences a positive and a negative place,
        # scaled by a gain and limited by abs/slew clamps.
        self._action_names = [a.name for a in artifact.readout.actions]
        self._action_pos_idx = np.asarray(
            [a.pos_place for a in artifact.readout.actions], dtype=np.int64
        )
        self._action_neg_idx = np.asarray(
            [a.neg_place for a in artifact.readout.actions], dtype=np.int64
        )
        self._action_gains = np.asarray(artifact.readout.gains, dtype=np.float64)
        self._action_abs_max = np.asarray(artifact.readout.abs_max, dtype=np.float64)
        self._action_slew_per_s = np.asarray(artifact.readout.slew_per_s, dtype=np.float64)
        self._action_count = len(self._action_names)
        self._dt = float(artifact.meta.dt_control_s)
        # Per-tick slew limit = rate [1/s] × control period [s].
        self._action_max_delta = self._action_slew_per_s * self._dt
        self._prev_actions = np.zeros(self._action_count, dtype=np.float64)
        self._tmp_actions = np.zeros(self._action_count, dtype=np.float64)
        # Diagnostic snapshots consumed by deterministic benchmark gates.
        self.last_oracle_firing: List[float] = []
        self.last_sc_firing: List[float] = []
        self.last_oracle_marking: List[float] = self._marking.tolist()
        self.last_sc_marking: List[float] = self._marking.tolist()
# ── Public API ───────────────────────────────────────────────────────
    # (removed stray "[docs]" Sphinx viewcode export artifact — not valid Python)
def reset(self) -> None:
"""Restore initial marking and zero previous actions."""
np.copyto(
self._marking,
np.asarray(self.artifact.initial_state.marking, dtype=np.float64),
)
self._prev_actions.fill(0.0)
self._oracle_pending.fill(0.0)
self._sc_pending.fill(0.0)
self._oracle_cursor = 0
self._sc_cursor = 0
self.last_oracle_firing = []
self.last_sc_firing = []
self.last_oracle_marking = self._marking.tolist() if self._enable_oracle_diagnostics else []
self.last_sc_marking = self._marking.tolist()
    @property
    def runtime_backend_name(self) -> str:
        """Name of the compute backend selected at construction: ``"rust"`` or ``"numpy"``."""
        return self._runtime_backend
    @property
    def runtime_profile_name(self) -> str:
        """Runtime profile in effect: ``"adaptive"``, ``"deterministic"`` or ``"traceable"``."""
        return self._runtime_profile
    @property
    def marking(self) -> List[float]:
        """Current place marking as a plain Python list (a copy; safe to mutate)."""
        return cast(List[float], self._marking.tolist())
    @marking.setter
    def marking(self, values: Sequence[float]) -> None:
        # Materialise the sequence so generators are accepted; values are
        # clamped into [0, 1] since markings are unipolar token levels.
        arr = np.asarray(list(values), dtype=np.float64)
        if arr.shape != (self._nP,):
            raise ValueError(f"marking must have length {self._nP}, got {arr.size}")
        self._marking = np.clip(arr, 0.0, 1.0)
    # (removed stray "[docs]" Sphinx viewcode export artifact — not valid Python)
def step(
self,
obs: Mapping[str, float],
k: int,
log_path: Optional[str] = None,
) -> ControlAction:
"""Execute one control tick.
Steps:
1. ``extract_features(obs)`` → 4 unipolar features
2. ``_inject_places(features)``
3. ``_oracle_step()`` — float path (optional)
4. ``_sc_step(k)`` — deterministic stochastic path
5. ``_decode_actions()`` — gain × differencing, slew + abs clamp
6. Optional JSONL logging
"""
t0 = time.perf_counter()
# 1. Feature extraction (fast compiled mapping)
pos_vals, neg_vals = self._compute_feature_components(obs)
feats = self._build_feature_dict(obs, pos_vals, neg_vals) if log_path is not None else None
# 2. Inject features into marking
m = self._tmp_marking_input
np.copyto(m, self._marking)
self._inject_places(m, obs, pos_vals, neg_vals)
# 3. Oracle float path (optional)
if self._enable_oracle_diagnostics:
f_oracle, m_oracle = self._oracle_step(m)
else:
f_oracle = np.asarray([], dtype=np.float64)
m_oracle = np.asarray([], dtype=np.float64)
# 4. Stochastic path
f_sc, m_sc = self._sc_step(m, k)
# Diagnostics (used by deterministic benchmark gates)
self.last_oracle_firing = f_oracle.tolist()
self.last_sc_firing = f_sc.tolist()
self.last_oracle_marking = m_oracle.tolist() if self._enable_oracle_diagnostics else []
self.last_sc_marking = m_sc.tolist()
# Commit SC state
np.copyto(self._marking, m_sc)
# 5. Decode actions
actions_dict = self._decode_actions(m_sc)
t1 = time.perf_counter()
# 6. Optional JSONL logging
if log_path is not None:
rec = {
"k": int(k),
"obs": dict(obs),
"features": feats,
"f_oracle": f_oracle.tolist(),
"f_sc": f_sc.tolist(),
"marking": m_sc.tolist(),
"actions": actions_dict,
"timing_ms": (t1 - t0) * 1000.0,
}
with open(log_path, "a", encoding="utf-8") as fh:
fh.write(json.dumps(rec) + "\n")
# Preserve all decoded action channels from the artifact readout.
return cast(ControlAction, dict(actions_dict))
    # (removed stray "[docs]" Sphinx viewcode export artifact — not valid Python)
def step_traceable(
self,
obs_vector: Sequence[float],
k: int,
log_path: Optional[str] = None,
) -> FloatArray:
"""Execute one control tick from a fixed-order observation vector.
The vector order is ``self._axis_obs_keys``. This avoids per-step key
lookups/dict allocation in tight control loops.
"""
if not self._traceable_ready:
raise RuntimeError(
"step_traceable requires axis-only injections (no passthrough sources)"
)
t0 = time.perf_counter()
pos_vals, neg_vals = self._compute_feature_components_vector(obs_vector)
m = self._tmp_marking_input
np.copyto(m, self._marking)
self._inject_places(m, {}, pos_vals, neg_vals)
if self._enable_oracle_diagnostics:
f_oracle, m_oracle = self._oracle_step(m)
else:
f_oracle = np.asarray([], dtype=np.float64)
m_oracle = np.asarray([], dtype=np.float64)
f_sc, m_sc = self._sc_step(m, k)
self.last_oracle_firing = f_oracle.tolist()
self.last_sc_firing = f_sc.tolist()
self.last_oracle_marking = m_oracle.tolist() if self._enable_oracle_diagnostics else []
self.last_sc_marking = m_sc.tolist()
np.copyto(self._marking, m_sc)
actions_vec = np.asarray(self._decode_actions_vector(m_sc), dtype=np.float64).copy()
t1 = time.perf_counter()
if log_path is not None:
obs_payload = {key: float(value) for key, value in zip(self._axis_obs_keys, obs_vector)}
rec = {
"k": int(k),
"obs": obs_payload,
"features": self._build_feature_dict(obs_payload, pos_vals, neg_vals),
"f_oracle": f_oracle.tolist(),
"f_sc": f_sc.tolist(),
"marking": m_sc.tolist(),
"actions": {
name: float(actions_vec[i]) for i, name in enumerate(self._action_names)
},
"timing_ms": (t1 - t0) * 1000.0,
}
with open(log_path, "a", encoding="utf-8") as fh:
fh.write(json.dumps(rec) + "\n")
return actions_vec