Tutorial 37: JAX JIT Training

SC-NeuroCore supports JAX for JIT-compiled, GPU-accelerated SNN training. JAX's functional transformation model (grad, jit, vmap) enables efficient surrogate gradient computation on stochastic computing layers.

Prerequisites

Bash
pip install sc-neurocore[jax]
# or: pip install jax jaxlib

1. JAX Dense Layer

The JaxSCDenseLayer provides a JAX-compatible SC dense layer. Inputs of shape (n_inputs,) are projected through the layer weights into neuron currents; direct current vectors of shape (n_neurons,) remain supported for low-level experiments.

Python
import jax.numpy as jnp
from sc_neurocore.layers.jax_dense_layer import JaxSCDenseLayer

weights = jnp.asarray(
    [
        [1.0, 0.0, 0.2, 0.0],
        [0.0, 0.8, 0.0, 0.2],
        [0.5, 0.5, 0.0, 0.0],
    ],
    dtype=jnp.float32,
)
layer = JaxSCDenseLayer(
    n_inputs=4,
    n_neurons=3,
    bitstream_length=256,
    weights=weights,
    seed=42,
)

# Single input vector: shape (n_inputs,)
spikes = layer.step(jnp.asarray([0.9, 0.1, 0.5, 0.0], dtype=jnp.float32))
print(f"Step spikes: {spikes}")

# Time sequence: shape (T, n_inputs)
spike_train = layer.run(
    jnp.asarray(
        [
            [0.9, 0.1, 0.5, 0.0],
            [0.0, 0.8, 0.0, 0.2],
        ],
        dtype=jnp.float32,
    )
)
print(f"Spike train shape: {spike_train.shape}")

Constructor and runtime inputs fail closed:

  • dimensions must be positive
  • weights must have shape (n_neurons, n_inputs)
  • seed must fit the JAX PRNG range
  • neuron parameters must be known finite values
  • runtime arrays must be floating-point, finite, non-empty, and shape-compatible
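
For example, a weights matrix with the transposed shape is rejected at construction time. A minimal sketch of that behaviour, assuming only the shape rule listed above (the exact exception type is not documented here, so the example catches broadly):

Python
import jax.numpy as jnp
from sc_neurocore.layers.jax_dense_layer import JaxSCDenseLayer

# Deliberately wrong shape: (n_inputs, n_neurons) instead of (n_neurons, n_inputs)
bad_weights = jnp.zeros((4, 3), dtype=jnp.float32)
try:
    JaxSCDenseLayer(
        n_inputs=4,
        n_neurons=3,
        bitstream_length=256,
        weights=bad_weights,
        seed=42,
    )
except Exception as exc:  # exact exception type depends on the library's validation
    print(f"Rejected as expected: {exc}")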

2. Surrogate Gradient Training

SC-NeuroCore exposes two explicit JAX surrogate paths:

  • custom_vjp — hard spikes forward, fast-sigmoid proxy backward
  • legacy_stop_gradient — historical straight-through reset path

For new work, prefer custom_vjp:

Python
import jax.numpy as jnp
from sc_neurocore.accel.jax_backend import jax_surrogate_gradient_step

weights = [jnp.asarray([[0.4, -0.1], [0.2, 0.5]], dtype=jnp.float32)]
x = jnp.asarray([[1.0, 0.2], [0.3, 1.1]], dtype=jnp.float32)
targets = jnp.asarray([[1.0, 0.0], [0.0, 1.0]], dtype=jnp.float32)

updated, loss_value = jax_surrogate_gradient_step(
    weights,
    x,
    targets,
    n_steps=5,
    lr=1e-2,
    surrogate_path="custom_vjp",
)
print(f"Loss after one step: {loss_value:.4f}")

The legacy route remains available when you need a direct comparison:

Python
updated_legacy, loss_legacy = jax_surrogate_gradient_step(
    weights,
    x,
    targets,
    n_steps=5,
    lr=1e-2,
    surrogate_path="legacy_stop_gradient",
)
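
Running both paths on identical inputs makes a side-by-side check straightforward; for example:

Python
print(f"custom_vjp loss: {loss_value:.4f} | legacy loss: {loss_legacy:.4f}")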

3. Training Loop on Synthetic Data

Python
import jax
import jax.numpy as jnp
import numpy as np

# Parameters
n_inputs, n_hidden, n_outputs = 10, 32, 3
n_steps = 50
lr = 0.01

# Initialize weights
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
w1 = jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
w2 = jax.random.normal(k2, (n_hidden, n_outputs)) * 0.1

def forward(w1, w2, x_seq):
    """Forward pass: n_steps of LIF dynamics."""
    v1 = jnp.zeros(n_hidden)
    v2 = jnp.zeros(n_outputs)
    out_spikes = jnp.zeros(n_outputs)

    for t in range(n_steps):
        # Hidden layer
        I1 = x_seq[t] @ w1
        dv1 = (-v1 + I1) / 20.0
        v1 = v1 + dv1
        s1 = jax.nn.sigmoid(10.0 * (v1 - 1.0))
        v1 = v1 * (1.0 - s1)

        # Output layer
        I2 = s1 @ w2
        dv2 = (-v2 + I2) / 20.0
        v2 = v2 + dv2
        s2 = jax.nn.sigmoid(10.0 * (v2 - 1.0))
        v2 = v2 * (1.0 - s2)
        out_spikes = out_spikes + s2

    return out_spikes / n_steps

def loss_fn(w1, w2, x_seq, target):
    pred = forward(w1, w2, x_seq)
    return jnp.mean((pred - target) ** 2)

# JIT-compile the gradient function
grad_fn = jax.jit(jax.grad(loss_fn, argnums=(0, 1)))

# Training
for epoch in range(100):
    # Random input sequence; fixed one-hot target
    x_seq = jax.random.uniform(jax.random.PRNGKey(epoch), (n_steps, n_inputs))
    target = jnp.array([1.0, 0.0, 0.0])

    g1, g2 = grad_fn(w1, w2, x_seq, target)
    w1 = w1 - lr * g1
    w2 = w2 - lr * g2

    if epoch % 20 == 0:
        l = loss_fn(w1, w2, x_seq, target)
        print(f"Epoch {epoch}: loss = {l:.4f}")

4. GPU Acceleration

JAX detects available accelerators automatically, so the same training code runs on CUDA or ROCm GPUs without changes:

Python
# Check if GPU is available
print(jax.devices())  # e.g. [CudaDevice(id=0)] on GPU, [CpuDevice(id=0)] on CPU

# vmap for batch training
batched_loss = jax.vmap(loss_fn, in_axes=(None, None, 0, 0))
# Process entire batch in parallel on GPU
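
For example, a batch of input sequences with per-example targets can be evaluated in one fused, JIT-compiled call (the batch size here is illustrative):

Python
batch_size = 8
x_batch = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, n_steps, n_inputs))
t_batch = jnp.tile(jnp.array([1.0, 0.0, 0.0]), (batch_size, 1))

# One loss per batch element, computed in parallel on the selected device
losses = jax.jit(batched_loss)(w1, w2, x_batch, t_batch)
print(losses.shape)  # (8,)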

5. Export Trained Weights to SC

After training, convert JAX weights to SC-NeuroCore format:

Python
import numpy as np
from sc_neurocore.layers.sc_dense_layer import SCDenseLayer

# Convert JAX arrays to NumPy, transposed to the (n_neurons, n_inputs) layout
# documented for JaxSCDenseLayer (assumed to match SCDenseLayer here)
w1_np = np.array(w1).T

# Normalize weights to [0, 1] for SC bitstream encoding
w_min, w_max = w1_np.min(), w1_np.max()
w_normalized = (w1_np - w_min) / (w_max - w_min)

# Create SC layer with trained weights
sc_layer = SCDenseLayer(n_inputs=n_inputs, n_neurons=n_hidden, length=512)
sc_layer.set_weights(w_normalized)

# Now run with stochastic bitstreams
output = sc_layer.forward([0.5] * n_inputs)
print(f"SC output: {output}")

6. Holonomic Adapters with JAX

The SCPN holonomic adapter ecosystem uses JAX for differentiable layer transforms:

Python
from sc_neurocore.adapters.holonomic.l1_quantum import L1_QuantumAdapter

adapter = L1_QuantumAdapter()
# step_jax() uses JAX JIT under the hood
state = adapter.step_jax(dt=0.01)

JAX vs PyTorch vs NumPy

| Feature | JAX | PyTorch | NumPy |
| --- | --- | --- | --- |
| Autodiff | jax.grad (functional) | loss.backward() (imperative) | Manual |
| JIT | jax.jit (trace-based) | torch.compile | N/A |
| GPU | Automatic | .to("cuda") | N/A (use CuPy) |
| Batch | jax.vmap | DataLoader | Manual |
| SC integration | JaxSCDenseLayer | SpikingNet.to_sc_weights() | VectorizedSCLayer |
| Best for | Research, custom surrogates | Production training | Inference, co-sim |

Runnable Demo

Bash
PYTHONPATH=src python examples/jax_training_demo.py

Further Reading