Skip to content

Acceleration

Backend modules for high-performance SC operations.

Module Purpose
vector_ops Packed uint64 bitwise AND, popcount, pack/unpack
gpu_backend CuPy GPU dispatch (transparent NumPy fallback)
jax_backend JAX JIT-compiled LIF step for TPU/GPU scaling
jit_kernels Numba-accelerated inner loops
mpi_driver MPI-based distributed simulation

Rust Safety Mirrors

src/sc_neurocore/accel/rust/ is a nested Rust crate for safety and contract mirrors of higher-level Python modules. It is separate from the PyO3 engine: the mirror crate is tested directly with Cargo, while the Python modules keep their NumPy/Python path importable when optional engine submodules are absent.

Current documented mirrors:

Mirror Python surface Verification
safety/analysis.rs studio.analysis Rust unit tests plus tests/test_studio_analysis.py
safety/dna_mapper.rs bridges.dna_mapper Rust unit tests plus 139 DNA mapper tests
safety/l7_symbolic.rs scpn.layers.l7_symbolic Rust unit tests plus L7 and cross-layer contract tests
safety/predictive_model.rs world_model.predictive_model Rust unit tests plus 77 passed predictive-model tests, with 3 optional-path skips

Cargo command:

Bash
cargo test --manifest-path src/sc_neurocore/accel/rust/Cargo.toml --lib --no-default-features

Vector Operations

sc_neurocore.accel.vector_ops

pack_bitstream(bitstream)

Packs a uint8 bitstream (0s and 1s) into uint64 integers. This allows processing 64 time steps in parallel.

Parameters:

Name Type Description Default
bitstream ndarray[Any, Any]

Shape (N,) or (Batch, N) of uint8 {0,1}

required

Returns:

Name Type Description
packed ndarray[Any, Any]

Shape (ceil(N/64),) or (Batch, ceil(N/64)) of uint64

Source code in src/sc_neurocore/accel/vector_ops.py
Python
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def pack_bitstream(bitstream: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
    """
    Pack a uint8 bitstream of 0s and 1s into uint64 words.

    Packing lets downstream code process 64 stochastic time steps with a
    single bitwise machine operation. Bits are stored LSB-first within
    each 64-bit word; streams are zero-padded to a multiple of 64.

    Args:
        bitstream: Shape (N,) or (Batch, N) of uint8 {0,1}

    Returns:
        packed: Shape (ceil(N/64),) or (Batch, ceil(N/64)) of uint64

    Raises:
        ValueError: if the input is not 1-D or 2-D.
    """
    bits = np.asarray(bitstream, dtype=np.uint8)

    if bits.ndim not in (1, 2):
        raise ValueError(f"Expected 1D or 2D array, got {bits.ndim}D")

    # Zero-pad the last axis so every 64-bit word is full.
    n_bits = bits.shape[-1]
    pad = (-n_bits) % 64
    if pad:
        pad_widths = [(0, 0)] * (bits.ndim - 1) + [(0, pad)]
        bits = np.pad(bits, pad_widths)

    # Bit j of each chunk contributes 2**j (LSB-first layout).
    weights = np.uint64(1) << np.arange(64, dtype=np.uint64)
    words = bits.reshape(bits.shape[:-1] + (-1, 64))
    packed: np.ndarray[Any, Any] = (words * weights).sum(axis=-1, dtype=np.uint64)
    return packed

unpack_bitstream(packed, original_length, original_shape=None)

Unpacks uint64 array back to uint8 bitstream.

Parameters:

Name Type Description Default
packed ndarray[Any, Any]

Packed uint64 array (1D or 2D)

required
original_length int

Total number of bits to extract

required
original_shape Optional[tuple[Any, ...]]

Optional tuple for reshaping output (batch, length)

None

Returns:

Type Description
ndarray[Any, Any]

Unpacked bitstream of shape (original_length,) or original_shape

Source code in src/sc_neurocore/accel/vector_ops.py
Python
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def unpack_bitstream(
    packed: np.ndarray[Any, Any],
    original_length: int,
    original_shape: Optional[tuple[Any, ...]] = None,
) -> np.ndarray[Any, Any]:
    """
    Unpack a uint64 word array back into a uint8 bitstream.

    Inverse of ``pack_bitstream``: each word is expanded into its 64
    LSB-first bits, then any zero padding is trimmed off.

    Args:
        packed: Packed uint64 array (1D or 2D)
        original_length: Total number of bits to extract
        original_shape: Optional tuple for reshaping output (batch, length)

    Returns:
        Unpacked bitstream of shape (original_length,) or original_shape

    Raises:
        ValueError: if ``packed`` is not 1-D or 2-D.
    """
    masks = np.uint64(1) << np.arange(64, dtype=np.uint64)

    if packed.ndim == 1:
        flat = ((packed[:, None] & masks) != 0).astype(np.uint8).ravel()
        trimmed: np.ndarray[Any, Any] = flat[:original_length]
        return trimmed

    if packed.ndim == 2:
        n_rows = packed.shape[0]
        # (batch, n_words, 64) bit grid, then flatten the word axis.
        bit_grid = ((packed[:, :, None] & masks) != 0).astype(np.uint8)
        rows = bit_grid.reshape(n_rows, -1)
        if original_shape is not None:
            shaped: np.ndarray[Any, Any] = rows[:, : original_shape[1]]
            return shaped
        # No explicit shape: treat original_length as the grand total and
        # split it evenly across the batch rows.
        evenly: np.ndarray[Any, Any] = rows[:, : original_length // n_rows]
        return evenly

    raise ValueError(f"Expected 1D or 2D packed array, got {packed.ndim}D")

vec_and(a_packed, b_packed)

Bitwise AND on packed arrays. Simulates SC Multiplication.

Source code in src/sc_neurocore/accel/vector_ops.py
Python
105
106
107
108
109
110
def vec_and(a_packed: np.ndarray[Any, Any], b_packed: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
    """Elementwise AND of two packed streams — the SC unipolar product."""
    out: np.ndarray[Any, Any] = a_packed & b_packed
    return out

vec_xnor(a_packed, b_packed)

Bitwise XNOR on packed arrays. SC bipolar multiplication: P(A XNOR B) = P(A)P(B) + (1-P(A))(1-P(B)).

Source code in src/sc_neurocore/accel/vector_ops.py
Python
113
114
115
116
117
118
def vec_xnor(
    a_packed: np.ndarray[Any, Any], b_packed: np.ndarray[Any, Any]
) -> np.ndarray[Any, Any]:
    """Elementwise XNOR of packed streams.

    This is the SC bipolar product:
    P(A XNOR B) = P(A)*P(B) + (1-P(A))*(1-P(B)).
    """
    out: np.ndarray[Any, Any] = np.invert(a_packed ^ b_packed)
    return out

vec_not(packed)

Bitwise NOT on packed arrays. SC complement: P(NOT A) = 1 - P(A).

Source code in src/sc_neurocore/accel/vector_ops.py
Python
121
122
123
124
def vec_not(packed: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
    """Bitwise complement of a packed stream: P(NOT A) = 1 - P(A)."""
    out: np.ndarray[Any, Any] = np.invert(packed)
    return out

vec_mux(select_packed, a_packed, b_packed)

Bitwise MUX on packed arrays. SC scaled addition: P(out) = P(sel)P(A) + (1-P(sel))P(B).

When sel is a Bernoulli(0.5) stream, this computes the average (A+B)/2.

Source code in src/sc_neurocore/accel/vector_ops.py
Python
127
128
129
130
131
132
133
134
135
136
137
def vec_mux(
    select_packed: np.ndarray[Any, Any],
    a_packed: np.ndarray[Any, Any],
    b_packed: np.ndarray[Any, Any],
) -> np.ndarray[Any, Any]:
    """Bitwise multiplexer over packed streams (SC scaled addition).

    For each bit position, sel=1 selects the bit of A and sel=0 selects
    the bit of B, so P(out) = P(sel)*P(A) + (1-P(sel))*P(B). With a
    Bernoulli(0.5) select stream this averages the operands: (A+B)/2.
    """
    from_a = select_packed & a_packed
    from_b = np.invert(select_packed) & b_packed
    out: np.ndarray[Any, Any] = from_a | from_b
    return out

vec_popcount(packed)

Count total set bits (1s) in the packed array. Used for integration/accumulation.

Source code in src/sc_neurocore/accel/vector_ops.py
Python
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def vec_popcount(packed: np.ndarray[Any, Any]) -> int:
    """
    Return the total number of set bits across a packed uint64 array.

    Uses the classic 64-bit SWAR (SIMD-within-a-register) reduction, so
    the count stays fully vectorised in NumPy — no Python-level loop and
    no lookup table. Used for SC integration/accumulation.
    """
    pair_mask = np.uint64(0x5555555555555555)   # adjacent bit pairs
    nibble_mask = np.uint64(0x3333333333333333)  # 2-bit sums per nibble
    byte_mask = np.uint64(0x0F0F0F0F0F0F0F0F)    # 4-bit sums per byte
    fold = np.uint64(0x0101010101010101)         # sums all bytes into the top byte

    x = packed.copy()
    x -= (x >> np.uint64(1)) & pair_mask
    x = (x & nibble_mask) + ((x >> np.uint64(2)) & nibble_mask)
    x = (x + (x >> np.uint64(4))) & byte_mask
    # The wrap-around multiply accumulates the eight byte counts; shifting
    # by 56 exposes that total in the low byte.
    x = (x * fold) >> np.uint64(56)
    return int(x.sum())

GPU Backend

sc_neurocore.accel.gpu_backend

to_device(arr)

Move a NumPy array to the active backend (GPU copy or no-op).

Source code in src/sc_neurocore/accel/gpu_backend.py
Python
 93
 94
 95
 96
 97
 98
 99
100
def to_device(arr: np.ndarray[Any, Any]) -> xp.ndarray:  # type: ignore
    """Copy ``arr`` onto the active backend; a no-op on the NumPy path."""
    if not _gpu_enabled():
        return arr
    try:  # pragma: no cover - requires live GPU hardware
        return cp.asarray(arr)
    except RuntimeError as exc:  # pragma: no cover
        # GPU broke at runtime: record it and fall back to the host array.
        _mark_gpu_runtime_broken(exc)
    return arr

to_host(arr)

Bring an array back to host RAM as a NumPy array.

Source code in src/sc_neurocore/accel/gpu_backend.py
Python
103
104
105
106
107
108
109
110
111
112
def to_host(arr: Any) -> np.ndarray[Any, Any]:
    """Return ``arr`` as a host-resident NumPy array."""
    on_device = HAS_CUPY and isinstance(arr, cp.ndarray)
    if on_device:  # pragma: no cover - requires CuPy hardware
        try:
            fetched: np.ndarray[Any, Any] = arr.get()
            return fetched
        except RuntimeError as exc:  # pragma: no cover
            _mark_gpu_runtime_broken(exc)
    host: np.ndarray[Any, Any] = np.asarray(arr)
    return host

gpu_pack_bitstream(bits)

Pack uint8 {0,1} array into uint64 words.

Works on both CuPy and NumPy arrays.

Parameters:

Name Type Description Default
bits ndarray

Shape (N,) or (B, N) of uint8.

required

Returns:

Type Description
ndarray

Packed uint64 array, shape (ceil(N/64),) or (B, ceil(N/64)).

Source code in src/sc_neurocore/accel/gpu_backend.py
Python
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def gpu_pack_bitstream(bits: xp.ndarray) -> xp.ndarray:  # type: ignore
    """
    Pack uint8 {0,1} array into uint64 words.

    Works on both CuPy and NumPy arrays.

    Args:
        bits: Shape ``(N,)`` or ``(B, N)`` of uint8.

    Returns:
        Packed uint64 array, shape ``(ceil(N/64),)`` or ``(B, ceil(N/64))``.
    """
    if _gpu_enabled():  # pragma: no cover
        try:
            bits = cp.asarray(bits, dtype=cp.uint8)
            if bits.ndim == 1:
                length = bits.size
                # Zero-pad up to the next multiple of 64 so every word is full.
                pad = (64 - length % 64) % 64
                if pad:
                    bits = cp.concatenate([bits, cp.zeros(pad, dtype=cp.uint8)])
                chunks = bits.reshape(-1, 64)
                # LSB-first weights: bit j of each chunk contributes 2**j.
                powers = cp.uint64(1) << cp.arange(64, dtype=cp.uint64)
                return (chunks.astype(cp.uint64) * powers).sum(axis=1)

            if bits.ndim == 2:
                batch, length = bits.shape
                pad = (64 - length % 64) % 64
                if pad:
                    bits = cp.concatenate(
                        [bits, cp.zeros((batch, pad), dtype=cp.uint8)],
                        axis=1,
                    )
                n_words = bits.shape[1] // 64
                chunks = bits.reshape(batch, n_words, 64)
                powers = cp.uint64(1) << cp.arange(64, dtype=cp.uint64)
                return (chunks.astype(cp.uint64) * powers).sum(axis=2)
        except RuntimeError as exc:  # pragma: no cover
            # GPU failed at runtime: disable it and drop to the CPU path.
            _mark_gpu_runtime_broken(exc)

    # NumPy fallback. Note this is also reached on the GPU path when the
    # input is neither 1-D nor 2-D (both ifs fall through).
    _warn_cpu_fallback()
    return _numpy_pack_bitstream(to_host(bits))

gpu_vec_and(a, b)

Bitwise AND on packed uint64 arrays (SC multiplication).

Source code in src/sc_neurocore/accel/gpu_backend.py
Python
203
204
205
206
207
208
209
210
211
def gpu_vec_and(a: xp.ndarray, b: xp.ndarray) -> xp.ndarray:  # type: ignore
    """AND the packed uint64 operands on the active backend (SC multiplication)."""
    if _gpu_enabled():  # pragma: no cover - requires CUDA hardware
        try:
            return cp.bitwise_and(a, b)
        except RuntimeError as exc:  # pragma: no cover
            _mark_gpu_runtime_broken(exc)
    _warn_cpu_fallback()
    host_a = to_host(a)
    host_b = to_host(b)
    return np.bitwise_and(host_a, host_b)

gpu_popcount(packed)

Vectorised SWAR popcount on uint64 arrays — returns per-element counts.

On CuPy this runs as a fused GPU kernel; on NumPy it uses the same SWAR bit-trick as vector_ops.vec_popcount but returns an array instead of a scalar.

Source code in src/sc_neurocore/accel/gpu_backend.py
Python
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
def gpu_popcount(packed: xp.ndarray) -> xp.ndarray:  # type: ignore
    """
    Vectorised SWAR popcount on uint64 arrays — returns per-element counts.

    On CuPy this runs as a fused GPU kernel; on NumPy it uses the same
    SWAR bit-trick as ``vector_ops.vec_popcount`` but returns an array
    instead of a scalar.
    """
    if _gpu_enabled():  # pragma: no cover
        try:
            # Copy so the in-place SWAR steps never mutate the caller's data.
            x = cp.asarray(packed, dtype=cp.uint64).copy()
            # Classic SWAR constants: pair / nibble / byte masks, plus the
            # multiplier that folds the eight byte-sums into the top byte.
            m1 = cp.uint64(0x5555555555555555)
            m2 = cp.uint64(0x3333333333333333)
            m4 = cp.uint64(0x0F0F0F0F0F0F0F0F)
            h01 = cp.uint64(0x0101010101010101)

            x -= (x >> cp.uint64(1)) & m1
            x = (x & m2) + ((x >> cp.uint64(2)) & m2)
            x = (x + (x >> cp.uint64(4))) & m4
            # Wrap-around multiply sums the byte counts; >>56 extracts them.
            return (x * h01) >> cp.uint64(56)
        except RuntimeError as exc:  # pragma: no cover
            _mark_gpu_runtime_broken(exc)

    _warn_cpu_fallback()
    return _numpy_popcount(to_host(packed))

gpu_vec_mac(packed_weights, packed_inputs)

GPU-accelerated multiply-accumulate for a dense SC layer.

Parameters:

Name Type Description Default
packed_weights ndarray

(n_neurons, n_inputs, n_words) uint64

required
packed_inputs ndarray

(n_inputs, n_words) uint64

required

Returns:

Type Description
ndarray

(n_neurons,) total bit counts (= SC dot products).

Source code in src/sc_neurocore/accel/gpu_backend.py
Python
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
def gpu_vec_mac(
    packed_weights: xp.ndarray,  # type: ignore
    packed_inputs: xp.ndarray,  # type: ignore
) -> xp.ndarray:  # type: ignore
    """
    GPU-accelerated multiply-accumulate for a dense SC layer.

    Args:
        packed_weights: ``(n_neurons, n_inputs, n_words)`` uint64
        packed_inputs:  ``(n_inputs, n_words)`` uint64

    Returns:
        ``(n_neurons,)`` total bit counts (= SC dot products).
    """
    if _gpu_enabled():  # pragma: no cover
        try:
            # AND = SC multiplication; inputs broadcast over the neuron axis.
            products = cp.bitwise_and(packed_weights, packed_inputs[None, :, :])
            counts = gpu_popcount(products)
            # Reduce over inputs and words -> one accumulated count per neuron.
            return counts.sum(axis=(1, 2))
        except RuntimeError as exc:  # pragma: no cover
            _mark_gpu_runtime_broken(exc)

    # CPU fallback mirrors the GPU path exactly with NumPy ops.
    _warn_cpu_fallback()
    weights_np = to_host(packed_weights)
    inputs_np = to_host(packed_inputs)
    products = np.bitwise_and(weights_np, inputs_np[None, :, :])
    counts = _numpy_popcount(products)
    return counts.sum(axis=(1, 2))

JAX Backend

sc_neurocore.accel.jax_backend

JAX backend for SC-NeuroCore.

Provides JAX-accelerated primitives for stochastic computing, unlocking automatic differentiation, JIT compilation (XLA), and native TPU/GPU scaling.

Usage::

Text Only
from sc_neurocore.accel.jax_backend import jnp, HAS_JAX, to_jax, to_host
from sc_neurocore.accel.jax_backend import jax_pack_bitstream, jax_vec_mac

if HAS_JAX:
    bits = jnp.array([1, 0, 1, 1], dtype=jnp.uint8)
    packed = jax_pack_bitstream(bits)

to_jax(arr)

Move a NumPy array to the JAX device.

Source code in src/sc_neurocore/accel/jax_backend.py
Python
63
64
65
66
67
def to_jax(arr: Any) -> Any:
    """Place ``arr`` on the JAX device, or hand it back untouched without JAX."""
    return jnp.asarray(arr) if HAS_JAX else arr

to_host(arr)

Bring a JAX array back to host RAM as a NumPy array.

Source code in src/sc_neurocore/accel/jax_backend.py
Python
70
71
72
73
74
def to_host(arr: Any) -> np.ndarray[Any, Any]:
    """Bring a JAX array back to host RAM as a NumPy array.

    ``np.asarray`` already covers both cases: JAX arrays implement the
    ``__array__`` protocol (triggering a device-to-host transfer), and
    host data passes through unchanged. The previous explicit
    ``isinstance(arr, jax.Array)`` branch returned the exact same
    expression, so it was dead code and has been removed.

    Args:
        arr: A JAX array, NumPy array, or any array-like.

    Returns:
        A host-resident NumPy array for ``arr``.
    """
    return np.asarray(arr)

jax_pack_bitstream(bits)

Pack uint8 {0,1} array into uint64 words using JAX.

Source code in src/sc_neurocore/accel/jax_backend.py
Python
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def jax_pack_bitstream(bits: Any) -> Any:
    """
    Pack uint8 {0,1} array into uint64 words using JAX.

    Raises:
        SCDependencyError: if JAX is not installed.
        SCEncodingError: for a non-uint8 dtype, wrong rank, empty input,
            or any value other than 0/1.
    """
    if not HAS_JAX:
        from sc_neurocore.exceptions import SCDependencyError

        raise SCDependencyError("JAX is not available.")

    from sc_neurocore.exceptions import SCEncodingError

    # Validate on a host copy: cheap NumPy checks before touching JAX.
    host_bits = np.asarray(bits)
    if host_bits.dtype != np.uint8:
        raise SCEncodingError("Expected a uint8 binary bitstream containing only 0 and 1")
    if host_bits.ndim not in (1, 2):
        raise SCEncodingError(f"Expected 1-D or 2-D, got {host_bits.ndim}-D")
    if host_bits.size == 0 or 0 in host_bits.shape:
        raise SCEncodingError("Expected a non-empty uint8 binary bitstream")
    if not np.isin(host_bits, np.array([0, 1], dtype=np.uint8)).all():
        raise SCEncodingError("Expected a uint8 binary bitstream containing only 0 and 1")

    bits = jnp.asarray(host_bits, dtype=jnp.uint8)

    if bits.ndim == 1:
        return _jax_pack_1d(bits)
    if bits.ndim == 2:
        return _jax_pack_2d(bits)

    # Defensive only: unreachable, since ndim was restricted to 1 or 2 above.
    raise SCEncodingError(f"Expected 1-D or 2-D, got {bits.ndim}-D")

jax_vec_and(a, b)

Bitwise AND on matching non-empty uint64 packed arrays.

Source code in src/sc_neurocore/accel/jax_backend.py
Python
389
390
391
392
393
394
395
396
def jax_vec_and(a: Any, b: Any) -> jax.Array:
    """AND two equal-shape, non-empty uint64 packed arrays (SC product)."""
    lhs = _validate_uint64_array("a", a)
    rhs = _validate_uint64_array("b", b)
    if lhs.shape != rhs.shape:
        raise ValueError(f"a shape {lhs.shape} must match b shape {rhs.shape}")
    out: jax.Array = _jax_vec_and_impl(lhs, rhs)
    return out

jax_popcount(packed)

Vectorised SWAR popcount on a non-empty uint64 array using JAX.

Source code in src/sc_neurocore/accel/jax_backend.py
Python
415
416
417
418
419
420
421
def jax_popcount(packed: Any) -> jax.Array:
    """Per-element SWAR popcount of a validated, non-empty uint64 array."""
    checked = _validate_uint64_array("packed", packed)
    counts: jax.Array = _jax_popcount_impl(checked)
    return counts

jax_vec_mac(packed_weights, packed_inputs)

JAX-accelerated multiply-accumulate for packed uint64 dense SC layers.

Source code in src/sc_neurocore/accel/jax_backend.py
Python
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
def jax_vec_mac(packed_weights: Any, packed_inputs: Any) -> jax.Array:
    """
    JAX-accelerated multiply-accumulate for packed uint64 dense SC layers.

    Validates dtype, rank, and shape agreement fail-closed, then defers
    to the compiled implementation.
    """
    w = _validate_uint64_array("packed_weights", packed_weights)
    i = _validate_uint64_array("packed_inputs", packed_inputs)

    if w.ndim != 3:
        raise ValueError(f"packed_weights must be 3-D, got {w.ndim}-D")
    if i.ndim != 2:
        raise ValueError(f"packed_inputs must be 2-D, got {i.ndim}-D")
    # Axis agreement: (n_neurons, n_inputs, n_words) x (n_inputs, n_words).
    if w.shape[1] != i.shape[0]:
        raise ValueError(
            f"packed_weights input dimension {w.shape[1]} does not match "
            f"packed_inputs input dimension {i.shape[0]}"
        )
    if w.shape[2] != i.shape[1]:
        raise ValueError(
            f"packed_weights word dimension {w.shape[2]} does not match "
            f"packed_inputs word dimension {i.shape[1]}"
        )

    out: jax.Array = _jax_vec_mac_impl(w, i)
    return out

jax_lif_step(v, I_t, v_rest, v_reset, v_threshold, alpha, resistance, noise)

Vectorized LIF step using JAX with fail-closed public input guards.

dv = (v_rest - v) * alpha + I_t * resistance + noise

Source code in src/sc_neurocore/accel/jax_backend.py
Python
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
def jax_lif_step(
    v: Any,
    I_t: Any,
    v_rest: float,
    v_reset: float,
    v_threshold: float,
    alpha: float,
    resistance: float,
    noise: Any,
) -> tuple[jax.Array, jax.Array]:
    """
    Vectorized LIF step using JAX with fail-closed public input guards.

    dv = (v_rest - v) * alpha + I_t * resistance + noise

    Args:
        v: membrane potentials (array-like).
        I_t: input currents; must match the shape of ``v``.
        v_rest: resting potential (finite scalar).
        v_reset: post-spike reset potential (finite scalar).
        v_threshold: firing threshold (finite scalar).
        alpha: leak/decay factor (validated as positive and finite).
        resistance: membrane resistance (validated as positive and finite).
        noise: additive noise term; must match the shape of ``v``.

    Returns:
        The pair produced by ``_jax_lif_step_impl`` — presumably
        (updated potentials, spike indicators); confirm against the impl.
    """
    # Scalars are validated before arrays, so bad hyperparameters fail
    # before any array conversion work happens.
    _validate_finite_scalar("v_rest", v_rest)
    _validate_finite_scalar("v_reset", v_reset)
    _validate_finite_scalar("v_threshold", v_threshold)
    _validate_positive_finite_scalar("alpha", alpha)
    _validate_positive_finite_scalar("resistance", resistance)
    v_arr = _validate_lif_array("v", v)
    # I_t and noise are pinned to v's shape, so the update is elementwise-safe.
    current_arr = _validate_lif_array("I_t", I_t, expected_shape=v_arr.shape)
    noise_arr = _validate_lif_array("noise", noise, expected_shape=v_arr.shape)
    result: tuple[jax.Array, jax.Array] = _jax_lif_step_impl(
        v_arr,
        current_arr,
        v_rest,
        v_reset,
        v_threshold,
        alpha,
        resistance,
        noise_arr,
    )
    return result

jax_forward_pass(weights, x, n_steps, v_rest=0.0, v_reset=0.0, v_threshold=1.0, alpha=0.9)

Multi-layer SNN forward pass with LIF neurons.

Returns (spike_trains_per_layer, final_membrane_potentials). Each layer: s = Heaviside(v - threshold), v = alpha * v * (1-s) + W @ s_prev

Source code in src/sc_neurocore/accel/jax_backend.py
Python
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
def jax_forward_pass(
    weights: list[jax.Array],
    x: Any,
    n_steps: int,
    v_rest: float = 0.0,
    v_reset: float = 0.0,
    v_threshold: float = 1.0,
    alpha: float = 0.9,
) -> tuple[list[jax.Array], jax.Array]:
    """
    Multi-layer SNN forward pass with LIF neurons.

    Args:
        weights: per-layer weight matrices, each ``(n_out, n_in)``.
        x: input batch (validated by ``_validate_forward_inputs``).
        n_steps: simulated time steps per layer.
        v_rest: initial membrane potential.
        v_reset: potential applied after a spike.
        v_threshold: firing threshold.
        alpha: membrane decay factor per step.

    Returns:
        (spike_trains_per_layer, final_membrane_potentials). Each spike
        train has shape ``(n_steps, batch, n_out)``; the potentials are
        those of the last layer after the final step.
    """
    weights, x = _validate_forward_inputs(
        weights=weights,
        x=x,
        n_steps=n_steps,
        v_rest=v_rest,
        v_reset=v_reset,
        v_threshold=v_threshold,
        alpha=alpha,
    )
    batch = x.shape[0]
    spikes = x
    all_spikes = []

    for W in weights:
        n_out = W.shape[0]
        v = jnp.full((batch, n_out), v_rest)
        layer_spikes = []

        for _t in range(n_steps):
            current = spikes @ W.T
            # NOTE(review): the (1.0 - v_reset) decay factor does not match
            # the documented "alpha * v * (1-s)" update; with the default
            # v_reset=0.0 it is a no-op. Preserved as-is — confirm intent.
            v = alpha * v * (1.0 - v_reset) + current
            s = (v >= v_threshold).astype(jnp.float32)
            v = jnp.where(s > 0.5, v_reset, v)
            layer_spikes.append(s)

        # Stack once and reuse it: previously the same stack was built
        # twice (once for the mean, once for the recorded train).
        train = jnp.stack(layer_spikes, axis=0)
        # Output spikes = mean firing rate over time.
        spikes = train.mean(axis=0)
        all_spikes.append(train)

    return all_spikes, v

jax_surrogate_loss(weights, x, targets, n_steps=25, beta=10.0, threshold=1.0, surrogate_path='custom_vjp')

Cross-entropy loss for JAX SNN training with explicit surrogate paths.

Available paths: - custom_vjp: hard spikes forward, fast-sigmoid proxy backward via jax.custom_vjp - legacy_stop_gradient: historical straight-through reset path using jax.lax.stop_gradient

Source code in src/sc_neurocore/accel/jax_backend.py
Python
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
def jax_surrogate_loss(
    weights: list[jax.Array],
    x: jax.Array,
    targets: jax.Array,
    n_steps: int = 25,
    beta: float = 10.0,
    threshold: float = 1.0,
    surrogate_path: str = "custom_vjp",
) -> jax.Array:
    """
    Cross-entropy loss for JAX SNN training with explicit surrogate paths.

    Available paths:
    - ``custom_vjp``: hard spikes forward, fast-sigmoid proxy backward
      via ``jax.custom_vjp``
    - ``legacy_stop_gradient``: historical straight-through reset path
      using ``jax.lax.stop_gradient``
    """
    weights, x, targets = _validate_surrogate_inputs(
        weights=weights,
        x=x,
        targets=targets,
        n_steps=n_steps,
        beta=beta,
        threshold=threshold,
        surrogate_path=surrogate_path,
    )

    # Select the implementation for the requested path; anything that is
    # not custom_vjp falls through to the legacy route, as before.
    loss_impl = (
        _jax_loss_with_custom_vjp_surrogate
        if surrogate_path == "custom_vjp"
        else _jax_loss_with_legacy_stop_gradient_surrogate
    )
    return loss_impl(
        weights=weights,
        x=x,
        targets=targets,
        n_steps=n_steps,
        beta=beta,
        threshold=threshold,
    )

jax_surrogate_gradient_step(weights, x, targets, n_steps=25, lr=0.001, beta=10.0, threshold=1.0, surrogate_path='custom_vjp')

One training step with surrogate gradient over an explicit JAX path.

custom_vjp is the modern path. legacy_stop_gradient keeps the historical training route available for side-by-side verification.

Source code in src/sc_neurocore/accel/jax_backend.py
Python
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
def jax_surrogate_gradient_step(
    weights: list[jax.Array],
    x: jax.Array,
    targets: jax.Array,
    n_steps: int = 25,
    lr: float = 1e-3,
    beta: float = 10.0,
    threshold: float = 1.0,
    surrogate_path: str = "custom_vjp",
) -> tuple[list[jax.Array], float]:
    """
    Run one gradient-descent step with a surrogate gradient over an
    explicit JAX path.

    ``custom_vjp`` is the modern path; ``legacy_stop_gradient`` keeps the
    historical training route available for side-by-side verification.

    Returns:
        (updated_weights, loss), where loss is the pre-update value.
    """
    _validate_positive_finite_scalar("lr", lr)
    weights, x, targets = _validate_surrogate_inputs(
        weights=weights,
        x=x,
        targets=targets,
        n_steps=n_steps,
        beta=beta,
        threshold=threshold,
        surrogate_path=surrogate_path,
    )

    def _objective(candidate: list[jax.Array]) -> jax.Array:
        # Close over every hyperparameter; differentiate w.r.t. weights only.
        return jax_surrogate_loss(
            weights=candidate,
            x=x,
            targets=targets,
            n_steps=n_steps,
            beta=beta,
            threshold=threshold,
            surrogate_path=surrogate_path,
        )

    loss_value, gradients = jax.value_and_grad(_objective)(weights)
    stepped = [w - lr * g for w, g in zip(weights, gradients)]
    return stepped, float(loss_value)

JIT Kernels

sc_neurocore.accel.jit_kernels

jit_pack_bits(bitstream, packed_arr)

Packs a uint8 bitstream into uint64 array. bitstream: (N,) uint8 {0, 1} packed_arr: (N//64,) uint64

Source code in src/sc_neurocore/accel/jit_kernels.py
Python
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@jit(nopython=True)  # type: ignore[untyped-decorator]
def jit_pack_bits(
    bitstream: np.ndarray[Any, Any], packed_arr: np.ndarray[Any, Any]
) -> None:  # pragma: no cover
    """
    Pack a uint8 {0,1} bitstream into 64-bit words, LSB first.

    bitstream: (N,) uint8 {0, 1}
    packed_arr: (N//64,) uint64 — written in place; trailing bits past the
    last full word are ignored.
    """
    n_words = bitstream.size // 64
    one = np.uint64(1)

    for word_idx in range(n_words):
        word = np.uint64(0)
        offset = word_idx * 64
        for bit_idx in range(64):
            # Any nonzero entry counts as a set bit, matching the original.
            if bitstream[offset + bit_idx] > 0:
                word |= one << np.uint64(bit_idx)
        packed_arr[word_idx] = word

jit_vec_mac(packed_weights, packed_inputs, outputs)

Vectorized Multiply-Accumulate (MAC). Simulates: Output[i] = Sum(Weights[i] AND Inputs) weights: (n_neurons, n_inputs, n_words) inputs: (n_inputs, n_words) outputs: (n_neurons,)

Source code in src/sc_neurocore/accel/jit_kernels.py
Python
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@jit(nopython=True)  # type: ignore[untyped-decorator]
def jit_vec_mac(
    packed_weights: np.ndarray[Any, Any],
    packed_inputs: np.ndarray[Any, Any],
    outputs: np.ndarray[Any, Any],
) -> None:  # pragma: no cover
    """
    Vectorized Multiply-Accumulate (MAC).
    Simulates: Output[i] = Sum(Weights[i] AND Inputs)
    weights: (n_neurons, n_inputs, n_words)
    inputs: (n_inputs, n_words)
    outputs: (n_neurons,) — written in place, one total bit count per neuron
    """
    n_neurons = packed_weights.shape[0]
    n_inputs = packed_weights.shape[1]
    n_words = packed_weights.shape[2]

    for i in range(n_neurons):
        total_bits = 0
        for j in range(n_inputs):
            for k in range(n_words):
                # Bitwise AND = SC Multiplication
                res = packed_weights[i, j, k] & packed_inputs[j, k]

                # Popcount (Hamming Weight)
                # SWAR Algorithm for 64-bit popcount (Safe for Numba nopython mode)
                # Masks: 0x5555... = bit pairs, 0x3333... = nibbles,
                # 0x0F0F... = bytes; the final multiply folds the eight
                # byte sums into the top byte, exposed by the >>56 shift.
                x = res
                x = x - ((x >> np.uint64(1)) & np.uint64(0x5555555555555555))
                x = (x & np.uint64(0x3333333333333333)) + (
                    (x >> np.uint64(2)) & np.uint64(0x3333333333333333)
                )
                x = (x + (x >> np.uint64(4))) & np.uint64(0x0F0F0F0F0F0F0F0F)
                x = (x * np.uint64(0x0101010101010101)) >> np.uint64(56)

                total_bits += x
        outputs[i] = total_bits

MPI Driver

sc_neurocore.accel.mpi_driver

MPIDriver

Distributed SC-NeuroCore Driver using MPI. Handles partitioning and synchronization of bitstreams across cluster nodes.

Source code in src/sc_neurocore/accel/mpi_driver.py
Python
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class MPIDriver:
    """
    Distributed SC-NeuroCore Driver using MPI.
    Handles partitioning and synchronization of bitstreams across cluster nodes.

    Without mpi4py the driver degrades to a single-rank no-op: rank 0,
    size 1, and scatter/gather return their inputs unchanged.
    """

    def __init__(self) -> None:
        # Typed as optional: stays None when mpi4py is unavailable.
        self.comm: Any | None
        if HAS_MPI:  # pragma: no cover
            self.comm = MPI.COMM_WORLD
            self.rank = self.comm.Get_rank()
            self.size = self.comm.Get_size()
        else:
            self.comm = None
            self.rank = 0
            self.size = 1

    def scatter_workload(self, global_inputs: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
        """
        Distributes a large input array across nodes.
        Splits along axis 0 (Batch or Neurons).

        NOTE(review): chunk_size = total_len // size truncates, so trailing
        elements are dropped when the length is not divisible by the node
        count — confirm callers pad. The receive buffer is allocated 1-D;
        presumably inputs are 1-D here — verify against call sites.
        """
        if not HAS_MPI or self.size == 1:
            return global_inputs
        comm = self.comm
        if comm is None:  # pragma: no cover
            return global_inputs

        # MPI multi-node path  # pragma: no cover
        total_len = len(global_inputs)  # pragma: no cover
        chunk_size = total_len // self.size  # pragma: no cover
        local_input = np.zeros(chunk_size, dtype=global_inputs.dtype)  # pragma: no cover
        comm.Scatter(global_inputs, local_input, root=0)  # pragma: no cover
        return local_input  # pragma: no cover

    def gather_results(self, local_results: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
        """
        Collects results from all nodes to Root.

        Only rank 0 allocates the receive buffer; every other rank
        returns an empty float64 array (np.zeros(0)) regardless of the
        input dtype.
        """
        if not HAS_MPI or self.size == 1:
            return local_results
        comm = self.comm
        if comm is None:  # pragma: no cover
            return local_results

        # MPI multi-node path  # pragma: no cover
        total_len = len(local_results) * self.size  # pragma: no cover
        global_results = None  # pragma: no cover
        if self.rank == 0:  # pragma: no cover
            global_results = np.zeros(total_len, dtype=local_results.dtype)  # pragma: no cover
        comm.Gather(local_results, global_results, root=0)  # pragma: no cover
        if global_results is None:
            return np.zeros(0)
        return global_results

    def barrier(self) -> None:
        """Synchronize all nodes (no-op when MPI is absent)."""
        if HAS_MPI and self.comm is not None:  # pragma: no cover
            self.comm.Barrier()

scatter_workload(global_inputs)

Distributes a large input array across nodes. Splits along axis 0 (Batch or Neurons).

Source code in src/sc_neurocore/accel/mpi_driver.py
Python
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def scatter_workload(self, global_inputs: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
    """
    Distributes a large input array across nodes.
    Splits along axis 0 (Batch or Neurons).

    NOTE(review): chunk_size = total_len // size truncates, so trailing
    elements are dropped when the length is not divisible by the node
    count — confirm callers pad. The receive buffer is allocated 1-D;
    presumably inputs are 1-D here — verify against call sites.
    """
    if not HAS_MPI or self.size == 1:
        return global_inputs
    comm = self.comm
    if comm is None:  # pragma: no cover
        return global_inputs

    # MPI multi-node path  # pragma: no cover
    total_len = len(global_inputs)  # pragma: no cover
    chunk_size = total_len // self.size  # pragma: no cover
    local_input = np.zeros(chunk_size, dtype=global_inputs.dtype)  # pragma: no cover
    comm.Scatter(global_inputs, local_input, root=0)  # pragma: no cover
    return local_input  # pragma: no cover

gather_results(local_results)

Collects results from all nodes to Root.

Source code in src/sc_neurocore/accel/mpi_driver.py
Python
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def gather_results(self, local_results: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
    """
    Collects results from all nodes to Root.

    Only rank 0 allocates the receive buffer; every other rank returns an
    empty float64 array (np.zeros(0)) regardless of the input dtype.
    """
    if not HAS_MPI or self.size == 1:
        return local_results
    comm = self.comm
    if comm is None:  # pragma: no cover
        return local_results

    # MPI multi-node path  # pragma: no cover
    total_len = len(local_results) * self.size  # pragma: no cover
    global_results = None  # pragma: no cover
    if self.rank == 0:  # pragma: no cover
        global_results = np.zeros(total_len, dtype=local_results.dtype)  # pragma: no cover
    comm.Gather(local_results, global_results, root=0)  # pragma: no cover
    if global_results is None:
        return np.zeros(0)
    return global_results

barrier()

Synchronize all nodes.

Source code in src/sc_neurocore/accel/mpi_driver.py
Python
77
78
79
80
def barrier(self) -> None:
    """Synchronize all nodes (no-op when MPI is absent)."""
    if HAS_MPI and self.comm is not None:  # pragma: no cover
        self.comm.Barrier()