Tutorial 37: JAX JIT Training¶
SC-NeuroCore supports JAX for JIT-compiled, GPU-accelerated SNN training. JAX's functional transformation model (grad, jit, vmap) enables efficient surrogate gradient computation on stochastic computing layers.
Prerequisites¶
pip install sc-neurocore[jax]
# or: pip install jax jaxlib
1. JAX Dense Layer¶
The JaxSCDenseLayer provides a JAX-compatible SC dense layer. Inputs of
shape (n_inputs,) are projected through the layer weights into neuron
currents; direct current vectors of shape (n_neurons,) remain supported for
low-level experiments.
import jax.numpy as jnp
from sc_neurocore.layers.jax_dense_layer import JaxSCDenseLayer
weights = jnp.asarray(
    [
        [1.0, 0.0, 0.2, 0.0],
        [0.0, 0.8, 0.0, 0.2],
        [0.5, 0.5, 0.0, 0.0],
    ],
    dtype=jnp.float32,
)
layer = JaxSCDenseLayer(
    n_inputs=4,
    n_neurons=3,
    bitstream_length=256,
    weights=weights,
    seed=42,
)
# Single input vector: shape (n_inputs,)
spikes = layer.step(jnp.asarray([0.9, 0.1, 0.5, 0.0], dtype=jnp.float32))
print(f"Step spikes: {spikes}")
# Time sequence: shape (T, n_inputs)
spike_train = layer.run(
    jnp.asarray(
        [
            [0.9, 0.1, 0.5, 0.0],
            [0.0, 0.8, 0.0, 0.2],
        ],
        dtype=jnp.float32,
    )
)
print(f"Spike train shape: {spike_train.shape}")
Constructor and runtime inputs fail closed: dimensions must be positive,
weights must have shape (n_neurons, n_inputs), seed must fit the JAX PRNG
range, neuron parameters must be known finite values, and runtime arrays must
be floating-point, finite, non-empty, and shape-compatible.
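The runtime-array rules can be mirrored in user code before calling into the layer. A minimal sketch of such a fail-closed check — a hypothetical helper written for illustration, not the library's actual validator, and assuming errors should surface as `ValueError`/`TypeError`:

```python
import jax.numpy as jnp

def validate_runtime_input(x, n_inputs):
    """Fail-closed input check mirroring the rules above (hypothetical helper)."""
    x = jnp.asarray(x)
    if x.size == 0:
        raise ValueError("input must be non-empty")
    if not jnp.issubdtype(x.dtype, jnp.floating):
        raise TypeError("input must be floating-point")
    if not bool(jnp.all(jnp.isfinite(x))):
        raise ValueError("input must contain only finite values")
    if x.shape[-1] != n_inputs:
        raise ValueError(f"expected last dimension {n_inputs}, got {x.shape[-1]}")
    return x
```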
2. Surrogate Gradient Training¶
SC-NeuroCore exposes two explicit JAX surrogate paths:
- custom_vjp — hard spikes forward, fast-sigmoid proxy backward
- legacy_stop_gradient — historical straight-through reset path
For new work, prefer custom_vjp:
import jax.numpy as jnp
from sc_neurocore.accel.jax_backend import jax_surrogate_gradient_step
weights = [jnp.asarray([[0.4, -0.1], [0.2, 0.5]], dtype=jnp.float32)]
x = jnp.asarray([[1.0, 0.2], [0.3, 1.1]], dtype=jnp.float32)
targets = jnp.asarray([[1.0, 0.0], [0.0, 1.0]], dtype=jnp.float32)
updated, loss_value = jax_surrogate_gradient_step(
    weights,
    x,
    targets,
    n_steps=5,
    lr=1e-2,
    surrogate_path="custom_vjp",
)
print(f"Loss after one step: {loss_value:.4f}")
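The custom_vjp pattern itself is easy to inspect in isolation. A minimal standalone sketch in plain JAX, independent of SC-NeuroCore, pairing a hard-threshold forward with a fast-sigmoid backward:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike(v):
    # Forward: hard threshold (non-differentiable step)
    return (v > 0.0).astype(v.dtype)

def spike_fwd(v):
    # Save v as the residual for the backward pass
    return spike(v), v

def spike_bwd(v, g):
    # Backward: fast-sigmoid surrogate derivative 1 / (1 + |v|)^2
    return (g / (1.0 + jnp.abs(v)) ** 2,)

spike.defvjp(spike_fwd, spike_bwd)

v = jnp.asarray([-0.5, 0.5])
print(spike(v))                               # hard 0/1 spikes
print(jax.grad(lambda u: spike(u).sum())(v))  # smooth surrogate gradients
```

The forward pass stays binary, so inference matches the SC hardware semantics, while gradients flow through the smooth proxy.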
The legacy route remains available when you need direct comparison:
updated_legacy, loss_legacy = jax_surrogate_gradient_step(
    weights,
    x,
    targets,
    n_steps=5,
    lr=1e-2,
    surrogate_path="legacy_stop_gradient",
)
3. Training Loop on Synthetic Data¶
import jax
import jax.numpy as jnp
import numpy as np
# Parameters
n_inputs, n_hidden, n_outputs = 10, 32, 3
n_steps = 50
lr = 0.01
# Initialize weights
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
w1 = jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
w2 = jax.random.normal(k2, (n_hidden, n_outputs)) * 0.1
def forward(w1, w2, x_seq):
    """Forward pass: n_steps of LIF dynamics with soft (sigmoid) spikes."""
    v1 = jnp.zeros(n_hidden)
    v2 = jnp.zeros(n_outputs)
    out_spikes = jnp.zeros(n_outputs)
    for t in range(n_steps):
        # Hidden layer: leaky integration with tau = 20
        I1 = x_seq[t] @ w1
        dv1 = (-v1 + I1) / 20.0
        v1 = v1 + dv1
        # Soft spike: sigmoid threshold at v = 1 keeps the pass differentiable
        s1 = jax.nn.sigmoid(10.0 * (v1 - 1.0))
        v1 = v1 * (1.0 - s1)  # soft reset, scaled by spike probability
        # Output layer
        I2 = s1 @ w2
        dv2 = (-v2 + I2) / 20.0
        v2 = v2 + dv2
        s2 = jax.nn.sigmoid(10.0 * (v2 - 1.0))
        v2 = v2 * (1.0 - s2)
        out_spikes = out_spikes + s2
    return out_spikes / n_steps  # mean firing rate

def loss_fn(w1, w2, x_seq, target):
    pred = forward(w1, w2, x_seq)
    return jnp.mean((pred - target) ** 2)
# JIT-compile the gradient function
grad_fn = jax.jit(jax.grad(loss_fn, argnums=(0, 1)))
# Training
for epoch in range(100):
    # Fresh random input sequence each epoch; fixed one-hot target
    x_seq = jax.random.uniform(jax.random.PRNGKey(epoch), (n_steps, n_inputs))
    target = jnp.array([1.0, 0.0, 0.0])
    g1, g2 = grad_fn(w1, w2, x_seq, target)
    w1 = w1 - lr * g1
    w2 = w2 - lr * g2
    if epoch % 20 == 0:
        loss = loss_fn(w1, w2, x_seq, target)
        print(f"Epoch {epoch}: loss = {loss:.4f}")
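The Python for loop above is unrolled at trace time, so compile time grows with n_steps. jax.lax.scan keeps the compiled program size constant for long sequences. A sketch of the same forward pass rewritten with scan, using the same dynamics and constants as above:

```python
import jax
import jax.numpy as jnp

def forward_scan(w1, w2, x_seq):
    """Same LIF dynamics as forward(), expressed as a lax.scan over time."""
    def step(carry, x_t):
        v1, v2 = carry
        # Hidden layer: leaky integration, soft spike, soft reset
        v1 = v1 + (-v1 + x_t @ w1) / 20.0
        s1 = jax.nn.sigmoid(10.0 * (v1 - 1.0))
        v1 = v1 * (1.0 - s1)
        # Output layer
        v2 = v2 + (-v2 + s1 @ w2) / 20.0
        s2 = jax.nn.sigmoid(10.0 * (v2 - 1.0))
        v2 = v2 * (1.0 - s2)
        return (v1, v2), s2

    init = (jnp.zeros(w1.shape[1]), jnp.zeros(w2.shape[1]))
    _, s2_seq = jax.lax.scan(step, init, x_seq)
    return s2_seq.mean(axis=0)  # mean firing rate, as in forward()
```

Swapping forward for forward_scan leaves loss_fn and the jitted gradient function unchanged.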
4. GPU Acceleration¶
JAX detects available GPUs automatically; the same training code runs on CUDA or ROCm without changes:
# Check if GPU is available
print(jax.devices()) # [GpuDevice(id=0)] or [CpuDevice(id=0)]
# vmap for batch training
batched_loss = jax.vmap(loss_fn, in_axes=(None, None, 0, 0))
# Process entire batch in parallel on GPU
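A sketch of how vmap composes with grad for batched updates. It uses a toy per-example loss in place of the loss_fn defined above, but the pattern carries over unchanged:

```python
import jax
import jax.numpy as jnp

def loss_one(w, x, y):
    # Toy per-example loss standing in for the SNN loss_fn
    pred = jnp.tanh(x @ w)
    return jnp.mean((pred - y) ** 2)

def batch_loss(w, xb, yb):
    # vmap over the leading (batch) axis of x and y, sharing w
    return jnp.mean(jax.vmap(loss_one, in_axes=(None, 0, 0))(w, xb, yb))

train_step = jax.jit(jax.grad(batch_loss))

w = jnp.zeros((4, 2))
xb = jnp.ones((8, 4))      # batch of 8 inputs
yb = jnp.ones((8, 2))
g = train_step(w, xb, yb)  # gradient averaged over the batch
```

On a GPU backend the vmapped batch evaluates in parallel; no per-example Python loop is involved.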
5. Export Trained Weights to SC¶
After training, convert JAX weights to SC-NeuroCore format:
import numpy as np
from sc_neurocore.layers.sc_dense_layer import SCDenseLayer
# Convert JAX arrays to NumPy
w1_np = np.array(w1)
# Normalize weights to [0, 1] for SC bitstream encoding
w_min, w_max = w1_np.min(), w1_np.max()
w_normalized = (w1_np - w_min) / (w_max - w_min)
# Create SC layer with trained weights
sc_layer = SCDenseLayer(n_inputs=n_inputs, n_neurons=n_hidden, length=512)
sc_layer.set_weights(w_normalized)
# Now run with stochastic bitstreams
output = sc_layer.forward([0.5] * n_inputs)
print(f"SC output: {output}")
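Min-max scaling as written discards the sign and scale of the trained weights and divides by zero if all weights are equal. A small sketch that guards the degenerate case and keeps the range for later de-normalization — a hypothetical helper written for illustration, not part of the library:

```python
import numpy as np

def normalize_for_sc(w, eps=1e-12):
    """Map weights to [0, 1] for SC bitstream encoding.

    Returns the normalized array plus (w_min, w_max) so downstream
    results can be mapped back to the original range if needed.
    """
    w = np.asarray(w, dtype=np.float64)
    w_min, w_max = float(w.min()), float(w.max())
    span = max(w_max - w_min, eps)  # guard: constant weight matrix
    return (w - w_min) / span, (w_min, w_max)
```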
6. Holonomic Adapters with JAX¶
The SCPN holonomic adapter ecosystem uses JAX for differentiable layer transforms:
from sc_neurocore.adapters.holonomic.l1_quantum import L1_QuantumAdapter
adapter = L1_QuantumAdapter()
# step_jax() uses JAX JIT under the hood
state = adapter.step_jax(dt=0.01)
JAX vs PyTorch vs NumPy¶
| Feature | JAX | PyTorch | NumPy |
|---|---|---|---|
| Autodiff | jax.grad (functional) | loss.backward() (imperative) | Manual |
| JIT | jax.jit (trace-based) | torch.compile | N/A |
| GPU | Automatic | .to("cuda") | N/A (use CuPy) |
| Batch | jax.vmap | DataLoader | Manual |
| SC integration | JaxSCDenseLayer | SpikingNet → to_sc_weights() | VectorizedSCLayer |
| Best for | Research, custom surrogates | Production training | Inference, co-sim |
Runnable Demo¶
PYTHONPATH=src python examples/jax_training_demo.py
Further Reading¶
- Tutorial 03: Surrogate Gradient Training — PyTorch path
- Tutorial 20: Adapter Ecosystem — JAX holonomic adapters
- API: Acceleration — JAX/CuPy/MPI backend API