1use core_affinity;
11use crossbeam_channel::{bounded, Receiver, Sender, TrySendError};
12use pyo3::exceptions::PyRuntimeError;
13use pyo3::prelude::*;
14use std::sync::{
15 atomic::{AtomicBool, Ordering},
16 Arc,
17};
18use std::thread;
19use std::time::Duration;
20use std::{error::Error, fmt};
21use z3::{
22 ast::{Bool, Int},
23 SatResult, Solver,
24};
25
/// Number of places (token holders) in the supervised Petri net.
const NUM_PLACES: usize = 4;
/// Number of transitions in the supervised Petri net.
const NUM_TRANSITIONS: usize = 3;
/// Number of firing steps the Z3 worker unrolls from each snapshot.
const VERIFICATION_DEPTH: usize = 4;
/// Bound on the token count of place P3 (index 3); a count strictly greater than this is a
/// safety violation.
const SAFETY_THRESHOLD_P3: i64 = 100;
/// Capacity of the bounded snapshot channel. When it is full, the control loop drops the new
/// snapshot instead of blocking.
const DEFAULT_SNAPSHOT_CAPACITY: usize = 2;
/// Default number of control steps between snapshots (Python constructor default).
const DEFAULT_SNAPSHOT_PERIOD: u64 = 30;
/// Default pause between control steps in nanoseconds; 0 disables the sleep entirely.
const DEFAULT_STEP_INTERVAL_NS: u64 = 0;

/// Input (consumption) arc weights: `W_IN[t][p]` tokens are removed from place `p` when
/// transition `t` fires. t0 consumes from P0, t1 from P1, t2 from P2.
const W_IN: [[i64; NUM_PLACES]; NUM_TRANSITIONS] = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]];

/// Output (production) arc weights: `W_OUT[t][p]` tokens are added to place `p` when transition
/// `t` fires. t0 produces into P1, t1 into P2 and P3, t2 into P0. No transition consumes from
/// P3, so its count can only grow under firing.
const W_OUT: [[i64; NUM_PLACES]; NUM_TRANSITIONS] = [[0, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0, 0]];
38
/// Point-in-time copy of the controller state handed to the Z3 verification worker.
///
/// Derives `PartialEq` so snapshots can be compared directly (e.g. in tests or when
/// de-duplicating); `Eq` is not possible because of the `f64` rates.
#[derive(Clone, Debug, PartialEq)]
pub struct PetriNetSnapshot {
    /// Control-loop step at which the snapshot was taken.
    pub step_index: u64,
    /// Token count per place, indexed by place number. The verifier treats places missing from
    /// this vector as holding 0 tokens.
    pub active_markings: Vec<i64>,
    /// Per-transition rates; snapshots built by the pool append that step's control output as
    /// the last element.
    pub transition_rates: Vec<f64>,
}
46
/// Handles the SNN control loop uses to cooperate with the Z3 verification worker.
#[derive(Clone, Debug)]
pub struct SupervisorState {
    /// Shared stop flag. The control loop checks it before every step; the verification worker
    /// stores `true` here when a snapshot fails verification.
    pub safe_shutdown_flag: Arc<AtomicBool>,
    /// Sending side of the bounded snapshot channel read by the verification worker.
    pub tx_snapshot: Sender<PetriNetSnapshot>,
}
53
/// Reasons a supervised run can fail.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SupervisorExecutionError {
    /// The pool was configured with zero neurons.
    InvalidNeuronCount,
    /// The run ended with the safety/shutdown flag tripped.
    SafetyViolation,
}

impl fmt::Display for SupervisorExecutionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the fixed message first, then emit it with a single write.
        let message = match *self {
            SupervisorExecutionError::InvalidNeuronCount => "n_neurons must be > 0",
            SupervisorExecutionError::SafetyViolation => "safety contract violation detected",
        };
        f.write_str(message)
    }
}

/// Leaf error: no underlying source.
impl Error for SupervisorExecutionError {}
70
71#[derive(Debug)]
72struct LightweightSnnPool {
73 step_index: u64,
74 n_neurons: usize,
75 rng: u64,
76 transition_rates: [f64; NUM_TRANSITIONS],
77 markings: [i64; NUM_PLACES],
78}
79
80impl LightweightSnnPool {
81 fn new(n_neurons: usize, seed: u64) -> Self {
82 Self {
83 step_index: 0,
84 n_neurons,
85 rng: seed ^ 0xA3BF_0000_1234_5678u64,
86 transition_rates: [0.0; NUM_TRANSITIONS],
87 markings: [12, 15, 8, 0],
88 }
89 }
90
91 fn step(&mut self) -> f64 {
92 self.step_index = self.step_index.saturating_add(1);
93 self.rng = self
94 .rng
95 .wrapping_mul(6_364_136_223_846_793_005)
96 .wrapping_add(1);
97
98 let drift = ((self.rng >> 32) as i64) & 0xF;
99 let drift = drift.saturating_sub(6);
100
101 for (idx, mark) in self.markings.iter_mut().enumerate() {
103 let local = ((drift + idx as i64) as f64) / 10.0;
104 *mark = (*mark + drift)
105 .clamp(0, 200)
106 .saturating_add((self.n_neurons as i64) % 2);
107 self.transition_rates[idx % NUM_TRANSITIONS] = local.abs();
108 }
109
110 self.markings[3] = (self.markings[3] + ((self.step_index as i64) / 8).min(2)) % 210;
112
113 (self.rng as f64) / (u64::MAX as f64)
114 }
115
116 fn snapshot(&self, snapshot_step: u64, control_output: f64) -> PetriNetSnapshot {
117 let mut transition_rates = self.transition_rates.to_vec();
118 transition_rates.push(control_output);
119
120 PetriNetSnapshot {
121 step_index: snapshot_step,
122 active_markings: self.markings.to_vec(),
123 transition_rates,
124 }
125 }
126}
127
128fn bind_core(core_index: usize) {
129 let Some(core_ids) = core_affinity::get_core_ids() else {
130 return;
131 };
132 if let Some(core_id) = core_ids.get(core_index) {
133 let _ = core_affinity::set_for_current(*core_id);
134 }
135}
136
137pub fn verify_bounds_at_depth(snapshot: &PetriNetSnapshot, depth: usize) -> bool {
138 let solver = Solver::new();
141
142 let mut markings = Vec::with_capacity(depth + 1);
143 for step in 0..=depth {
144 let mut step_markings = Vec::with_capacity(NUM_PLACES);
145 for place in 0..NUM_PLACES {
146 let initial = i64::from(*snapshot.active_markings.get(place).unwrap_or(&0));
147 if step == 0 {
148 step_markings.push(Int::from_i64(initial));
149 } else {
150 step_markings.push(Int::new_const(format!("mark_{step}_{place}")));
151 }
152 }
153 markings.push(step_markings);
154 }
155
156 let mut firings = Vec::with_capacity(depth);
157 for step in 0..depth {
158 let mut step_firings = Vec::with_capacity(NUM_TRANSITIONS);
159 for transition in 0..NUM_TRANSITIONS {
160 step_firings.push(Bool::new_const(format!("fire_{step}_{transition}")));
161 }
162 firings.push(step_firings);
163 }
164
165 for place in 0..NUM_PLACES {
166 solver.assert(markings[0][place].ge(Int::from_i64(0)));
167 }
168
169 for step in 0..depth {
170 for place in 0..NUM_PLACES {
171 let mut next_value = markings[step][place].clone();
172
173 for transition in 0..NUM_TRANSITIONS {
174 let fire = &firings[step][transition];
175 let as_int = fire.ite(&Int::from_i64(1), &Int::from_i64(0));
176
177 let win = Int::from_i64(W_IN[transition][place]);
178 let wout = Int::from_i64(W_OUT[transition][place]);
179 if W_IN[transition][place] != 0 {
180 next_value -= &win * &as_int;
181 }
182 if W_OUT[transition][place] != 0 {
183 next_value += &wout * &as_int;
184 }
185 }
186
187 solver.assert(markings[step + 1][place]._eq(&next_value));
188 solver.assert(markings[step + 1][place].ge(Int::from_i64(0)));
189 }
190 }
191
192 let threshold = Int::from_i64(SAFETY_THRESHOLD_P3);
194 let mut violation_conditions: Vec<Bool> = Vec::with_capacity(depth);
195 for step in 1..=depth {
196 violation_conditions.push(markings[step][3].gt(&threshold));
197 }
198 if !violation_conditions.is_empty() {
199 let violation = Bool::or(&violation_conditions);
200 solver.assert(&violation);
201 } else {
202 return true;
203 }
204
205 match solver.check() {
206 SatResult::Unsat => true,
207 SatResult::Sat => false,
208 SatResult::Unknown => false,
209 }
210}
211
212pub fn spawn_z3_verification_worker(
213 rx_snapshot: Receiver<PetriNetSnapshot>,
214 shutdown_flag: Arc<AtomicBool>,
215 target_core: usize,
216) -> thread::JoinHandle<()> {
217 thread::spawn(move || {
218 bind_core(target_core);
219
220 for snapshot in rx_snapshot {
221 let valid = verify_bounds_at_depth(&snapshot, VERIFICATION_DEPTH);
222 if !valid {
223 shutdown_flag.store(true, Ordering::Release);
224 break;
225 }
226 }
227 })
228}
229
230fn execute_snn_control_loop(
231 mut pool: LightweightSnnPool,
232 supervisor: &SupervisorState,
233 snapshot_period: u64,
234 target_core: usize,
235 max_steps: u64,
236 step_interval_ns: u64,
237) -> u64 {
238 bind_core(target_core);
239
240 let snapshot_period = snapshot_period.max(1);
241 let mut executed_steps = 0;
242
243 loop {
244 if supervisor.safe_shutdown_flag.load(Ordering::Acquire) {
245 break;
246 }
247
248 if max_steps != 0 && executed_steps >= max_steps {
249 break;
250 }
251
252 let control_output = pool.step();
253 executed_steps = executed_steps.saturating_add(1);
254
255 if pool.step_index.is_multiple_of(snapshot_period) {
256 let snapshot = pool.snapshot(pool.step_index, control_output);
257 match supervisor.tx_snapshot.try_send(snapshot) {
258 Ok(()) => {}
259 Err(TrySendError::Full(_)) => {}
260 Err(TrySendError::Disconnected(_)) => break,
261 }
262 }
263
264 if step_interval_ns != 0 {
265 thread::sleep(Duration::from_nanos(step_interval_ns));
266 }
267 }
268
269 executed_steps
270}
271
272fn run_supervisor_steps_with_flag(
273 n_neurons: usize,
274 seed: u64,
275 snapshot_period: u64,
276 step_interval_ns: u64,
277 core_snn: usize,
278 core_z3: usize,
279 max_steps: u64,
280 safe_shutdown_flag: Arc<AtomicBool>,
281) -> Result<u64, SupervisorExecutionError> {
282 if n_neurons == 0 {
283 return Err(SupervisorExecutionError::InvalidNeuronCount);
284 }
285
286 safe_shutdown_flag.store(false, Ordering::Release);
287
288 let (tx_snapshot, rx_snapshot) = bounded::<PetriNetSnapshot>(DEFAULT_SNAPSHOT_CAPACITY);
289 let z3_handle = spawn_z3_verification_worker(rx_snapshot, safe_shutdown_flag.clone(), core_z3);
290
291 let pool = LightweightSnnPool::new(n_neurons, seed);
292 let executed = {
293 let supervisor = SupervisorState {
294 safe_shutdown_flag: safe_shutdown_flag.clone(),
295 tx_snapshot,
296 };
297 execute_snn_control_loop(
298 pool,
299 &supervisor,
300 snapshot_period,
301 core_snn,
302 max_steps,
303 step_interval_ns,
304 )
305 };
306
307 let _ = z3_handle.join();
309
310 if safe_shutdown_flag.load(Ordering::Acquire) {
311 return Err(SupervisorExecutionError::SafetyViolation);
312 }
313
314 Ok(executed)
315}
316
317pub fn run_supervisor_steps(
318 n_neurons: usize,
319 seed: u64,
320 snapshot_period: u64,
321 step_interval_ns: u64,
322 core_snn: usize,
323 core_z3: usize,
324 max_steps: u64,
325) -> Result<u64, SupervisorExecutionError> {
326 run_supervisor_steps_with_flag(
327 n_neurons,
328 seed,
329 snapshot_period,
330 step_interval_ns,
331 core_snn,
332 core_z3,
333 max_steps,
334 Arc::new(AtomicBool::new(false)),
335 )
336}
337
/// Python-facing handle that runs the supervised SNN control loop with a Z3 safety worker.
#[pyclass(
    name = "PySpikingControllerPool",
    module = "sc_neurocore_engine.sc_neurocore_engine"
)]
pub struct PySpikingControllerPool {
    /// Pool size used for every run; the constructor rejects 0.
    n_neurons: usize,
    /// Seed for the pool's deterministic RNG.
    seed: u64,
    /// Control steps between snapshots sent to the verifier.
    snapshot_period: u64,
    /// Sleep between control steps in nanoseconds (0 = no sleep).
    step_interval_ns: u64,
    /// Shutdown/safety flag shared with the control loop and verifier of each run.
    safe_shutdown_flag: Arc<AtomicBool>,
}
349
350#[pymethods]
351impl PySpikingControllerPool {
352 #[new]
353 #[pyo3(signature = (n_neurons=64, seed=7, step_interval_ns=DEFAULT_STEP_INTERVAL_NS, snapshot_period=DEFAULT_SNAPSHOT_PERIOD))]
354 fn new(
355 n_neurons: usize,
356 seed: u64,
357 step_interval_ns: u64,
358 snapshot_period: u64,
359 ) -> PyResult<Self> {
360 if n_neurons == 0 {
361 return Err(PyRuntimeError::new_err("n_neurons must be > 0."));
362 }
363
364 Ok(Self {
365 n_neurons,
366 seed,
367 snapshot_period,
368 step_interval_ns,
369 safe_shutdown_flag: Arc::new(AtomicBool::new(false)),
370 })
371 }
372
373 #[pyo3(signature = (core_snn=1, core_z3=2, max_steps=0))]
379 fn start(&self, core_snn: usize, core_z3: usize, max_steps: usize) -> PyResult<usize> {
380 match run_supervisor_steps_with_flag(
381 self.n_neurons,
382 self.seed,
383 self.snapshot_period,
384 self.step_interval_ns,
385 core_snn,
386 core_z3,
387 max_steps as u64,
388 self.safe_shutdown_flag.clone(),
389 ) {
390 Ok(executed) => Ok(executed as usize),
391 Err(SupervisorExecutionError::SafetyViolation) => Err(PyRuntimeError::new_err(
392 "Hardware execution terminated: safety contract violation detected by Z3 worker.",
393 )),
394 Err(SupervisorExecutionError::InvalidNeuronCount) => Err(PyRuntimeError::new_err(
395 SupervisorExecutionError::InvalidNeuronCount.to_string(),
396 )),
397 }
398 }
399
400 fn is_safety_tripped(&self) -> bool {
401 self.safe_shutdown_flag.load(Ordering::Acquire)
402 }
403
404 fn force_shutdown(&self) {
405 self.safe_shutdown_flag.store(true, Ordering::Release);
406 }
407}