Kuramoto JAX autodiff tier¶
The networked-Kuramoto forward integrators dispatch Rust → Julia → NumPy, and each carries a hand-written reverse-mode adjoint. This page documents a fourth kind of tier — the same RK4 solve expressed in JAX, where the gradient comes from automatic differentiation rather than a hand-derived scheme, and the whole solve runs on whatever accelerator JAX selected (a CUDA GPU when one is present). It is Slice 1 of the 7.7 backend programme (one facade, two tiers: a JAX autodiff tier and, later, a hand-tuned GPU-kernel tier).
The tier is opt-in: accel.jax_kuramoto.jax_kuramoto_rk4_trajectory and
jax_kuramoto_rk4_gradient are directly callable accelerated paths, re-exported through the
kuramoto facade, but they are not members of the default dispatch chain. The default
kuramoto_rk4_trajectory still serves the Rust tier, so every existing behaviour is unchanged.
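For orientation, the kind of solve this tier expresses can be sketched in a few lines of JAX. This is a minimal illustration, not the package's implementation: the coupling form (omega_i plus the sum over j of K[i, j] sin(theta_j - theta_i)), the helper names kuramoto_rhs and rk4_trajectory, and the 4-oscillator chain are all assumptions made for the example.

```python
import jax

# Assumption: 64-bit mode is switched on before any array is created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def kuramoto_rhs(theta, omega, K):
    """Assumed model: dtheta_i/dt = omega_i + sum_j K[i, j] * sin(theta_j - theta_i)."""
    diff = theta[None, :] - theta[:, None]  # diff[i, j] = theta_j - theta_i
    return omega + jnp.sum(K * jnp.sin(diff), axis=1)


def rk4_trajectory(theta0, omega, K, dt, n_steps):
    """Classical RK4; returns an array of shape (n_steps + 1, N) that includes theta0."""

    def step(theta, _):
        k1 = kuramoto_rhs(theta, omega, K)
        k2 = kuramoto_rhs(theta + 0.5 * dt * k1, omega, K)
        k3 = kuramoto_rhs(theta + 0.5 * dt * k2, omega, K)
        k4 = kuramoto_rhs(theta + dt * k3, omega, K)
        theta_next = theta + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        return theta_next, theta_next

    # lax.scan keeps the whole time loop inside one traced computation.
    _, tail = jax.lax.scan(step, theta0, None, length=n_steps)
    return jnp.concatenate([theta0[None, :], tail], axis=0)


# Hypothetical 4-oscillator chain with nearest-neighbour coupling.
theta0 = jnp.array([0.0, 0.5, 1.0, 1.5])
omega = jnp.array([1.0, 1.1, 0.9, 1.0])
K = 0.3 * (jnp.eye(4, k=1) + jnp.eye(4, k=-1))

traj = rk4_trajectory(theta0, omega, K, dt=0.01, n_steps=200)
print(traj.shape, traj.dtype, jax.devices()[0].platform)
```

The last line reports which device JAX picked, which is the only place the accelerator choice surfaces in this sketch.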
Two claims, verified¶
- Forward faithfulness (reproducible). With 64-bit precision enabled, the JAX RK4 forward matches
the production Rust integrator to machine precision. The committed benchmark
(docs/benchmarks/kuramoto_jax_tier.json) records parity_max_abs_diff = 8.66e-15 for a
256-oscillator network over 200 steps on a CUDA GPU. Parity is asserted under a tolerance (not as
bit-identity) because GPU reduction ordering need not equal NumPy's; here it lands at machine
precision.
- Gradient faithfulness (reproducible). The autodiff gradient (∂L/∂θ₀, ∂L/∂ω, ∂L/∂K from jax.vjp
of the forward solve) matches the hand-derived kuramoto_rk4_vjp to machine precision (~1e-15)
across networks of 6 to 64 oscillators (tests/test_jax_kuramoto.py). The autodiff tier therefore
both verifies the hand-written adjoint and supplies the same gradient for objectives whose adjoint
would be laborious to derive by hand.
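Both checks can be illustrated with a self-contained sketch. It is not the project's test suite: the coupling form, the loss (sum of cos of the final phases), the step size and the network are assumptions, a NumPy loop stands in for the production integrator, and a central finite difference stands in for the hand-derived adjoint, so the gradient tolerance is far looser than the ~1e-15 quoted above.

```python
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)  # float64 is needed for tight parity
import jax.numpy as jnp

DT, N_STEPS = 0.01, 50  # hypothetical step size and step count


def rhs(xp, theta, omega, K):
    # Assumed model: dtheta_i/dt = omega_i + sum_j K[i, j] * sin(theta_j - theta_i)
    diff = theta[None, :] - theta[:, None]
    return omega + xp.sum(K * xp.sin(diff), axis=1)


def rk4_final(xp, theta0, omega, K):
    """Same RK4 arithmetic for either array module (np or jnp)."""
    theta = theta0
    for _ in range(N_STEPS):
        k1 = rhs(xp, theta, omega, K)
        k2 = rhs(xp, theta + 0.5 * DT * k1, omega, K)
        k3 = rhs(xp, theta + 0.5 * DT * k2, omega, K)
        k4 = rhs(xp, theta + DT * k3, omega, K)
        theta = theta + (DT / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return theta


def loss(theta0, omega, K):
    # Hypothetical scalar objective on the final state.
    return jnp.sum(jnp.cos(rk4_final(jnp, theta0, omega, K)))


rng = np.random.default_rng(0)
N = 6
theta0 = rng.uniform(-np.pi, np.pi, N)
omega = rng.normal(size=N)
K = rng.uniform(0.0, 1.0, (N, N)) / N
args = (jnp.asarray(theta0), jnp.asarray(omega), jnp.asarray(K))

# Forward parity: asserted under a tolerance, not as bit-identity.
ref = rk4_final(np, theta0, omega, K)
out = rk4_final(jnp, *args)
max_abs_diff = float(np.max(np.abs(np.asarray(out) - ref)))

# Reverse-mode gradient of the scalar loss with respect to all three inputs.
loss_val, pullback = jax.vjp(loss, *args)
g_theta0, g_omega, g_K = pullback(jnp.ones_like(loss_val))

# Central finite difference on omega[0] as an independent (approximate) check.
eps = 1e-6
e0 = np.zeros(N)
e0[0] = eps
plus = float(loss(args[0], jnp.asarray(omega + e0), args[2]))
minus = float(loss(args[0], jnp.asarray(omega - e0), args[2]))
fd_abs_err = abs((plus - minus) / (2.0 * eps) - float(g_omega[0]))
print(max_abs_diff, fd_abs_err)
```

A finite difference only bounds the gradient error to roughly 1e-9; comparison against an exact adjoint is what allows the machine-precision statement.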
Batched ensembles (vmap)¶
jax_kuramoto_rk4_ensemble and jax_kuramoto_rk4_ensemble_gradient solve — and differentiate — a
whole batch of B initial conditions in a single accelerator call, by jax.vmap over the batch
axis. This is a vectorisation of the entire solve, not just the inner force evaluation, and the
NumPy and Rust tiers cannot express it: they would loop over the ensemble one member at a time.
The reproducible guarantee is that batching changes nothing but the layout — each batched member is
bit-for-bit identical to its single-initial-condition jax_kuramoto_rk4_trajectory /
jax_kuramoto_rk4_gradient, and the batched gradient matches the per-member single gradient to
machine precision (tests/test_jax_kuramoto.py). The committed artefact records the ensemble parity
alongside advisory batched-versus-sequential timings (ensemble_forward_us versus
sequential_forward_us); the throughput factor is advisory and host/GPU-bound, the parity is the
reproducible quantity. This is the capability that makes the tier useful for Monte-Carlo basin studies
and machine-learning pipelines that evaluate many initial conditions or parameters at once.
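The batching idea can be sketched as follows, assuming the same illustrative model as elsewhere on this page. The names ensemble_forward and ensemble_gradient, the loss and the problem sizes are hypothetical, and the sketch compares a batched member with its single solve under a tolerance rather than asserting bit-identity.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

DT, N_STEPS = 0.01, 100  # hypothetical step size and step count


def rhs(theta, omega, K):
    # Assumed model: dtheta_i/dt = omega_i + sum_j K[i, j] * sin(theta_j - theta_i)
    diff = theta[None, :] - theta[:, None]
    return omega + jnp.sum(K * jnp.sin(diff), axis=1)


def rk4_final(theta0, omega, K):
    def step(theta, _):
        k1 = rhs(theta, omega, K)
        k2 = rhs(theta + 0.5 * DT * k1, omega, K)
        k3 = rhs(theta + 0.5 * DT * k2, omega, K)
        k4 = rhs(theta + DT * k3, omega, K)
        return theta + (DT / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    theta_T, _ = jax.lax.scan(step, theta0, None, length=N_STEPS)
    return theta_T


def loss(theta0, omega, K):
    return jnp.sum(jnp.cos(rk4_final(theta0, omega, K)))  # hypothetical objective


single_gradient = jax.grad(loss, argnums=(0, 1, 2))

# vmap over the leading axis of theta0 only; omega and K are shared by all members.
ensemble_forward = jax.jit(jax.vmap(rk4_final, in_axes=(0, None, None)))
ensemble_gradient = jax.jit(jax.vmap(single_gradient, in_axes=(0, None, None)))

B, N = 8, 5
key = jax.random.PRNGKey(0)
theta0_batch = jax.random.uniform(key, (B, N), minval=-jnp.pi, maxval=jnp.pi)
omega = jnp.linspace(0.8, 1.2, N)
K = jnp.full((N, N), 0.2)

finals = ensemble_forward(theta0_batch, omega, K)                   # shape (B, N)
g_theta0, g_omega, g_K = ensemble_gradient(theta0_batch, omega, K)  # leading axis B

# Compare one batched member with its single-initial-condition counterpart.
member = 3
forward_gap = float(jnp.max(jnp.abs(finals[member] - rk4_final(theta0_batch[member], omega, K))))
grad_gap = float(jnp.max(jnp.abs(g_K[member] - single_gradient(theta0_batch[member], omega, K)[2])))
print(forward_gap, grad_gap)
```

Because omega and K are not batched here, each member still receives its own gradient with respect to them; summing over the leading axis would give the gradient of an ensemble-total loss.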
Wall clock (host- and GPU-dependent, boundary-guarded — not a claim)¶
The benchmark also records advisory per-call timings. On the recorded host (an 11th Gen Intel Core
i5-11600K with a CUDA GPU) the JAX tier integrated the 256-oscillator, 200-step problem with a median
of about 29 ms against about 1082 ms for the production tier on the same host. These milliseconds are
excluded from any performance claim (production_claim_allowed: false): they depend on the host,
the governor, the GPU model and its clock, and the JIT warm-up. A clean absolute number needs a
quiesced, reserved host with a fixed GPU clock. The reproducible quantity is the parity, not the
milliseconds.
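One way to keep JIT warm-up and asynchronous dispatch out of such advisory timings is sketched below. The helper median_call_us, the 64-oscillator problem and the repeat count are illustrative assumptions, not the committed benchmark harness, and the resulting numbers carry the same host-dependence caveat.

```python
import statistics
import time
from functools import partial

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def rhs(theta, omega, K):
    # Assumed model: dtheta_i/dt = omega_i + sum_j K[i, j] * sin(theta_j - theta_i)
    diff = theta[None, :] - theta[:, None]
    return omega + jnp.sum(K * jnp.sin(diff), axis=1)


@partial(jax.jit, static_argnames=("n_steps",))
def rk4_final(theta0, omega, K, dt, n_steps):
    def step(theta, _):
        k1 = rhs(theta, omega, K)
        k2 = rhs(theta + 0.5 * dt * k1, omega, K)
        k3 = rhs(theta + 0.5 * dt * k2, omega, K)
        k4 = rhs(theta + dt * k3, omega, K)
        return theta + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    theta_T, _ = jax.lax.scan(step, theta0, None, length=n_steps)
    return theta_T


def median_call_us(fn, *args, repeats=5, **kwargs):
    """Median wall time per call in microseconds, compilation excluded."""
    fn(*args, **kwargs).block_until_ready()  # warm-up call triggers compilation
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        # block_until_ready waits for the device, since dispatch is asynchronous
        fn(*args, **kwargs).block_until_ready()
        samples.append((time.perf_counter() - start) * 1e6)
    return statistics.median(samples)


# Hypothetical problem, smaller than the 256-oscillator benchmark.
N = 64
key = jax.random.PRNGKey(0)
theta0 = jax.random.uniform(key, (N,), minval=-jnp.pi, maxval=jnp.pi)
omega = jnp.ones(N)
K = jnp.full((N, N), 1.0 / N)

out = rk4_final(theta0, omega, K, dt=0.01, n_steps=200)
median_us = median_call_us(rk4_final, theta0, omega, K, dt=0.01, n_steps=200)
print(f"advisory median: {median_us:.1f} us on {jax.devices()[0].platform}")
```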
Requirements and precision¶
The tier requires the optional [jax] extra (pip install scpn-quantum-control[jax],
jax[cuda12] for a GPU). JAX is imported lazily behind a guard, so the pure-NumPy/Rust core stays
importable without it; calling the tier without JAX raises ImportError with the install hint — the
exception the dispatcher treats as "fall through", so a later slice may place this tier in a
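size-aware accelerated chain. 64-bit precision (jax_enable_x64) is a global JAX flag the tier sets
lazily before tracing. It enables float64 support, and arrays created with an explicit float32 dtype
keep it, but because the flag is process-wide, other JAX code that relies on default dtypes will
also receive 64-bit values once it is set.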
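The lazy-import guard and the fall-through behaviour described above can be sketched as follows. The names require_module, jax_tier, fallback_tier and dispatch are hypothetical stand-ins for illustration; the project's actual guard and dispatcher may be structured differently.

```python
import importlib


def require_module(name, install_hint):
    """Import an optional dependency on first use; raise ImportError with the install hint."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(f"{name} is required for this tier: {install_hint}") from exc


def jax_tier(values):
    """Hypothetical accelerated tier: JAX is imported only when the tier is called."""
    jax = require_module("jax", "pip install scpn-quantum-control[jax]")
    jax.config.update("jax_enable_x64", True)  # global flag, set before any tracing
    import jax.numpy as jnp

    return float(jnp.sum(jnp.asarray(values, dtype=jnp.float64)))


def fallback_tier(values):
    """Hypothetical pure-Python tier that is always available."""
    return float(sum(values))


def dispatch(tiers, values):
    """Try tiers in order; an ImportError means 'fall through to the next tier'."""
    for tier in tiers:
        try:
            return tier(values)
        except ImportError:
            continue
    raise RuntimeError("no tier could serve the call")


# Gives the same answer whether or not JAX is installed on this machine.
result = dispatch([jax_tier, fallback_tier], [0.5, 1.5, 2.0])
print(result)
```

Keeping the import inside the tier function is what lets the module that defines it stay importable on a machine without JAX.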
Reproduce¶
Running the benchmark writes docs/benchmarks/kuramoto_jax_tier.json with the parity, the advisory timings, the JAX
version and device, and the full host provenance. In continuous integration JAX is CPU-only
(jax[cpu]), where the same 64-bit parity holds within tolerance; the GPU path is exercised locally.
Related¶
- Kuramoto Tier Benchmark — the Rust / Julia / Python tier provenance for the default dispatch.
- Kuramoto Handbook — the facade the tier's forward and gradient are re-exported through.