
Tutorial 37: JAX JIT Training

SC-NeuroCore supports JAX for JIT-compiled, GPU-accelerated spiking neural network (SNN) training. JAX's composable function transformations (grad, jit, vmap) make it straightforward to compute surrogate gradients for networks built on stochastic computing (SC) layers.

Prerequisites

```bash
pip install "sc-neurocore[jax]"   # quotes stop shells such as zsh from expanding the brackets
# or: pip install jax jaxlib
```

1. JAX Dense Layer

The JaxSCDenseLayer provides a JAX-compatible SC dense layer:

```python
import jax.numpy as jnp
from sc_neurocore.layers.jax_dense_layer import JaxSCDenseLayer

layer = JaxSCDenseLayer(n_inputs=8, n_neurons=4, bitstream_length=256, seed=42)

# Single forward step (JIT-compiled)
inputs = jnp.array([0.3, 0.5, 0.7, 0.2, 0.8, 0.1, 0.6, 0.4])  # 8 inputs, values in [0, 1]
output = layer.step(inputs)
print(f"Output: {output}")
```
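The tutorial does not show what happens inside an SC dense layer. A common stochastic computing formulation works like this. Each input and weight is encoded as a Bernoulli bitstream whose probability of a 1 equals the value (unipolar encoding). Values are multiplied with a bitwise AND of independent streams, and the result is decoded by counting ones. The sketch below follows that scheme and is not JaxSCDenseLayer's actual implementation. The function name sc_dense_sketch and the choice to average over inputs (scaled addition) are illustrative assumptions.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("length",))
def sc_dense_sketch(key, inputs, weights, length=256):
    """Hypothetical SC dense layer: estimates (inputs @ weights) / n_inputs.

    inputs:  (n_inputs,) values in [0, 1]
    weights: (n_inputs, n_neurons) values in [0, 1]
    """
    k_x, k_w = jax.random.split(key)
    # Unipolar encoding: each bit is 1 with probability equal to the encoded value
    x_bits = jax.random.uniform(k_x, (length,) + inputs.shape) < inputs      # (L, n_in)
    w_bits = jax.random.uniform(k_w, (length,) + weights.shape) < weights    # (L, n_in, n_out)
    # Multiplication of independent unipolar streams is a bitwise AND
    products = x_bits[:, :, None] & w_bits                                   # (L, n_in, n_out)
    # Decode each product as its fraction of ones, then average over inputs
    return products.mean(axis=0).mean(axis=0)                                # (n_out,)


x = jnp.array([0.3, 0.5, 0.7, 0.2, 0.8, 0.1, 0.6, 0.4])
w = jax.random.uniform(jax.random.PRNGKey(1), (8, 4))
approx = sc_dense_sketch(jax.random.PRNGKey(0), x, w, length=4096)
exact = x @ w / x.shape[0]
print(approx, exact)
```

Longer bitstreams reduce the estimation noise. Each product's standard deviation falls roughly as 1/sqrt(length), which is the accuracy/latency trade-off behind a parameter such as bitstream_length.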

2. Surrogate Gradient Training

A hard threshold has zero derivative almost everywhere, so jax.grad cannot train through it directly. The spike function below uses a surrogate gradient instead. The forward pass emits a hard 0/1 spike, while the backward pass uses the derivative of a steep sigmoid. The straight-through trick combines the two with jax.lax.stop_gradient:

```python
import jax
import jax.numpy as jnp

def lif_step(v, current, threshold=1.0, tau=20.0, dt=1.0):
    """Single LIF step with a sigmoid surrogate gradient (straight-through trick)."""
    dv = (-v + current) * dt / tau
    v_new = v + dv
    # Spike: hard threshold forward, surrogate backward
    spike_hard = (v_new >= threshold).astype(jnp.float32)
    spike_surrogate = jax.nn.sigmoid(10.0 * (v_new - threshold))
    # Forward value equals spike_hard; the gradient flows only through spike_surrogate
    spike = spike_surrogate + jax.lax.stop_gradient(spike_hard - spike_surrogate)
    v_reset = v_new * (1.0 - spike)  # reset the membrane to 0 after a spike
    return v_reset, spike

# Differentiable through the spike function: gradient w.r.t. the incoming voltage v
grad_fn = jax.grad(lambda v, I: lif_step(v, I)[1].sum())
gradient = grad_fn(jnp.float32(0.8), jnp.float32(1.5))
print(f"Gradient of spikes w.r.t. voltage: {float(gradient):.4f}")
```
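The forward/backward split can be checked on the spike nonlinearity alone. The sketch below uses spike_fn and surrogate_derivative, names introduced here for illustration. It confirms that the forward value is the hard 0/1 spike while jax.grad returns the sigmoid derivative β·σ(β(v − θ))·(1 − σ(β(v − θ))). Writing the trick the other way round, hard − stop_gradient(hard − soft), would instead forward the soft value and return a zero gradient.

```python
import jax
import jax.numpy as jnp

BETA = 10.0       # surrogate steepness (same constant as in lif_step)
THRESHOLD = 1.0


def spike_fn(v):
    """Hard 0/1 spike in the forward pass, sigmoid derivative in the backward pass."""
    hard = (v >= THRESHOLD).astype(jnp.float32)
    soft = jax.nn.sigmoid(BETA * (v - THRESHOLD))
    # soft + (hard - soft) == hard numerically; stop_gradient hides (hard - soft) from grad
    return soft + jax.lax.stop_gradient(hard - soft)


def surrogate_derivative(v):
    """Analytic derivative of sigmoid(BETA * (v - THRESHOLD)) with respect to v."""
    s = jax.nn.sigmoid(BETA * (v - THRESHOLD))
    return BETA * s * (1.0 - s)


for v in (0.835, 1.2):
    v = jnp.float32(v)
    print(float(v), float(spike_fn(v)), float(jax.grad(spike_fn)(v)), float(surrogate_derivative(v)))
```

The value 0.835 is the post-update voltage in the lif_step example (0.8 + (1.5 − 0.8)/20). In that example the chain rule adds a further factor of 1 − dt/tau = 0.95 from v to v_new.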

3. Training Loop on Synthetic Data

The loop below trains a two-layer LIF network on random input sequences. For simplicity it uses the sigmoid as a soft spike in both the forward and backward passes, rather than the straight-through trick from section 2.

```python
import jax
import jax.numpy as jnp

# Parameters
n_inputs, n_hidden, n_outputs = 10, 32, 3
n_steps = 50   # simulation time steps per sample
lr = 0.01

# Initialize weights
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
w1 = jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
w2 = jax.random.normal(k2, (n_hidden, n_outputs)) * 0.1

def forward(w1, w2, x_seq):
    """Run n_steps of LIF dynamics (tau = 20, threshold = 1); return output spike rates."""
    v1 = jnp.zeros(n_hidden)
    v2 = jnp.zeros(n_outputs)
    out_spikes = jnp.zeros(n_outputs)

    for t in range(n_steps):  # unrolled at trace time under jax.jit
        # Hidden layer: leaky integration, soft spike, soft reset
        I1 = x_seq[t] @ w1
        dv1 = (-v1 + I1) / 20.0
        v1 = v1 + dv1
        s1 = jax.nn.sigmoid(10.0 * (v1 - 1.0))
        v1 = v1 * (1.0 - s1)

        # Output layer
        I2 = s1 @ w2
        dv2 = (-v2 + I2) / 20.0
        v2 = v2 + dv2
        s2 = jax.nn.sigmoid(10.0 * (v2 - 1.0))
        v2 = v2 * (1.0 - s2)
        out_spikes = out_spikes + s2

    return out_spikes / n_steps

def loss_fn(w1, w2, x_seq, target):
    """Mean squared error between output spike rates and target rates."""
    pred = forward(w1, w2, x_seq)
    return jnp.mean((pred - target) ** 2)

# JIT-compile the loss and its gradients together
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1)))
target = jnp.array([1.0, 0.0, 0.0])  # output 0 should fire, the others stay silent

# Training
for epoch in range(100):
    # A different random input sequence each epoch
    x_seq = jax.random.uniform(jax.random.PRNGKey(epoch), (n_steps, n_inputs))

    loss, (g1, g2) = value_and_grad_fn(w1, w2, x_seq, target)
    w1 = w1 - lr * g1
    w2 = w2 - lr * g2

    if epoch % 20 == 0:
        print(f"Epoch {epoch}: loss = {float(loss):.4f}")
```
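The Python for loop in forward() is unrolled when jax.jit traces it, so compile time grows with n_steps. jax.lax.scan traces the per-step body once instead. Below is a minimal sketch of the same soft-spike dynamics written with scan. The names forward_scan and lif_update are introduced here for illustration and are not SC-NeuroCore APIs.

```python
import jax
import jax.numpy as jnp

n_inputs, n_hidden, n_outputs, n_steps = 10, 32, 3, 50
TAU, THRESHOLD, BETA = 20.0, 1.0, 10.0


def lif_update(v, current):
    """One leaky-integration step with a soft (sigmoid) spike and soft reset."""
    v = v + (-v + current) / TAU
    s = jax.nn.sigmoid(BETA * (v - THRESHOLD))
    return v * (1.0 - s), s


def forward_scan(w1, w2, x_seq):
    """Same dynamics as forward(), with the time loop expressed as jax.lax.scan."""
    def step(carry, x_t):
        v1, v2 = carry
        v1, s1 = lif_update(v1, x_t @ w1)
        v2, s2 = lif_update(v2, s1 @ w2)
        return (v1, v2), s2  # carry the membrane voltages, emit this step's output spikes

    init = (jnp.zeros(w1.shape[1]), jnp.zeros(w2.shape[1]))
    _, s2_seq = jax.lax.scan(step, init, x_seq)  # s2_seq: (n_steps, n_outputs)
    return s2_seq.mean(axis=0)  # equals the summed spikes divided by n_steps


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
w2 = jax.random.normal(k2, (n_hidden, n_outputs)) * 0.1
x_seq = jax.random.uniform(k3, (n_steps, n_inputs))
rates = jax.jit(forward_scan)(w1, w2, x_seq)
print(rates)
```

Because jax.grad differentiates through lax.scan, forward_scan could stand in for forward inside loss_fn without other changes.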

4. GPU Acceleration

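JAX places computations on the default accelerator automatically. With a GPU-enabled JAX installation (CUDA or ROCm; for NVIDIA GPUs, for example pip install -U "jax[cuda12]"), training runs on the GPU without code changes:

```python
import jax
import jax.numpy as jnp

# Check which devices JAX sees, e.g. [CudaDevice(id=0)] with a GPU, else [CpuDevice(id=0)]
print(jax.devices())

# vmap for batch training: one loss per sample (weights shared, data batched on axis 0)
batched_loss = jax.vmap(loss_fn, in_axes=(None, None, 0, 0))
# Average over the batch, then differentiate; the whole batch runs in one compiled call
batch_grad_fn = jax.jit(jax.grad(
    lambda w1, w2, xb, tb: jnp.mean(batched_loss(w1, w2, xb, tb)), argnums=(0, 1)))
```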
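The self-contained sketch below shows what in_axes=(None, None, 0, 0) does. It re-declares a compact per-sample loss with the same signature as loss_fn, rewritten with jax.lax.scan for brevity. It then batches that loss with jax.vmap and compares the result against an explicit Python loop over samples. The batch size and the names batch_loss and loop_loss are illustrative.

```python
import jax
import jax.numpy as jnp

n_inputs, n_hidden, n_outputs, n_steps, batch = 10, 32, 3, 50, 8


def loss_fn(w1, w2, x_seq, target):
    """Per-sample MSE on output spike rates (soft-LIF dynamics, tau = 20, threshold = 1)."""
    def step(carry, x_t):
        v1, v2 = carry
        v1 = v1 + (-v1 + x_t @ w1) / 20.0
        s1 = jax.nn.sigmoid(10.0 * (v1 - 1.0))
        v2 = v2 + (-v2 + s1 @ w2) / 20.0
        s2 = jax.nn.sigmoid(10.0 * (v2 - 1.0))
        return (v1 * (1.0 - s1), v2 * (1.0 - s2)), s2  # soft reset, emit output spikes

    init = (jnp.zeros(n_hidden), jnp.zeros(n_outputs))
    _, s2_seq = jax.lax.scan(step, init, x_seq)
    return jnp.mean((s2_seq.mean(axis=0) - target) ** 2)


# Weights are shared across the batch (None); x_seq and target are batched on axis 0
per_sample_loss = jax.vmap(loss_fn, in_axes=(None, None, 0, 0))


def batch_loss(w1, w2, x_batch, t_batch):
    return jnp.mean(per_sample_loss(w1, w2, x_batch, t_batch))


batch_value_and_grad = jax.jit(jax.value_and_grad(batch_loss, argnums=(0, 1)))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
w2 = jax.random.normal(k2, (n_hidden, n_outputs)) * 0.1
x_batch = jax.random.uniform(k3, (batch, n_steps, n_inputs))
t_batch = jnp.tile(jnp.array([1.0, 0.0, 0.0]), (batch, 1))

loss, (g1, g2) = batch_value_and_grad(w1, w2, x_batch, t_batch)

# Reference: the same mean loss computed one sample at a time
loop_loss = sum(loss_fn(w1, w2, x_batch[i], t_batch[i]) for i in range(batch)) / batch
print(float(loss), float(loop_loss), g1.shape, g2.shape)
```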

5. Export Trained Weights to SC

After training, convert JAX weights to SC-NeuroCore format:

```python
import numpy as np
from sc_neurocore.layers.sc_dense_layer import SCDenseLayer

# Convert the trained JAX array to NumPy; w1 has shape (n_inputs, n_hidden)
w1_np = np.array(w1)

# Min-max normalize weights to [0, 1] for unipolar SC bitstream encoding
w_min, w_max = w1_np.min(), w1_np.max()
w_normalized = (w1_np - w_min) / (w_max - w_min)

# Create an SC layer and load the trained weights
# (check which weight layout SCDenseLayer expects; transpose if it is (n_neurons, n_inputs))
sc_layer = SCDenseLayer(n_inputs=n_inputs, n_neurons=n_hidden, length=512)
sc_layer.set_weights(w_normalized)

# Now run with stochastic bitstreams
output = sc_layer.forward([0.5] * n_inputs)
print(f"SC output: {output}")
```
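Min-max normalization is an affine rescaling. For inputs x, x·w_norm = (x·w − w_min·Σx)/(w_max − w_min), so the SC layer computes a shifted and scaled copy of the trained projection rather than the signed weights themselves. A standard SC alternative for signed values is bipolar encoding. A stream with P(1) = p represents 2p − 1 in [−1, 1], and an XNOR gate multiplies two independent streams. The tutorial does not say whether SCDenseLayer supports bipolar encoding. The sketch below only illustrates the arithmetic, and bipolar_encode and bipolar_decode are hypothetical helpers. Weights must first be scaled into [−1, 1], here by their largest magnitude.

```python
import jax
import jax.numpy as jnp


def bipolar_encode(key, values, length):
    """Hypothetical helper: value v in [-1, 1] -> bits with P(bit = 1) = (v + 1) / 2."""
    p = (values + 1.0) / 2.0
    return jax.random.uniform(key, (length,) + values.shape) < p


def bipolar_decode(bits):
    """Estimate v = 2 * P(bit = 1) - 1 from the fraction of ones along the stream axis."""
    return 2.0 * bits.mean(axis=0) - 1.0


# Signed trained weights, scaled into [-1, 1] by their largest magnitude
w = jnp.array([0.8, -1.6, 0.4, -0.2])
w_scaled = w / jnp.max(jnp.abs(w))      # [0.5, -1.0, 0.25, -0.125]
x = jnp.array([0.6, 0.5, -0.9, 0.3])    # signed inputs in [-1, 1]

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
length = 8192
w_bits = bipolar_encode(key_w, w_scaled, length)
x_bits = bipolar_encode(key_x, x, length)
# XNOR of independent streams: P = pw*px + (1-pw)*(1-px), and 2P - 1 = (2pw-1)(2px-1)
product_bits = ~(w_bits ^ x_bits)
estimate = bipolar_decode(product_bits)
print(estimate, w_scaled * x)
```

Bipolar streams keep the sign in the encoding itself, so no offset term needs undoing after decoding. The cost is somewhat higher variance than unipolar streams for the same length.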

6. Holonomic Adapters with JAX

The SCPN holonomic adapter ecosystem uses JAX for differentiable layer transforms:

```python
from sc_neurocore.adapters.holonomic.l1_quantum import L1_QuantumAdapter

adapter = L1_QuantumAdapter()
# step_jax() uses JAX JIT under the hood
state = adapter.step_jax(dt=0.01)
```

JAX vs PyTorch vs NumPy

| Feature | JAX | PyTorch | NumPy |
|---|---|---|---|
| Autodiff | jax.grad (functional) | loss.backward() (imperative) | Manual |
| JIT | jax.jit (trace-based) | torch.compile | N/A |
| GPU | Automatic (with a GPU-enabled install) | .to("cuda") | N/A (use CuPy) |
| Batching | jax.vmap | DataLoader | Manual |
| SC integration | JaxSCDenseLayer | SpikingNet, to_sc_weights() | VectorizedSCLayer |
| Best for | Research, custom surrogates | Production training | Inference, co-simulation |

Runnable Demo

```bash
PYTHONPATH=src python examples/jax_training_demo.py
```

Further Reading