SPDX-License-Identifier: AGPL-3.0-or-later¶
Commercial license available¶
© Concepts 1996–2026 Miroslav Šotek. All rights reserved.¶
© Code 2020–2026 Miroslav Šotek. All rights reserved.¶
ORCID: 0009-0009-3560-0851¶
Contact: www.anulum.li | protoscience@anulum.li¶
scpn-quantum-control — Neural Quantum States Documentation¶
Neural Quantum States¶
Two modules for variational ground state search using neural network wavefunctions:
- RBM ansatz (phase/nqs_ansatz.py): Restricted Boltzmann Machine with NumPy, exact mode for \(n \leq 12\)
- JAX-accelerated NQS (phase/jax_nqs.py): same RBM with JAX automatic differentiation, ~100× faster gradients
Caveat: These are pedagogical/research implementations for the Kuramoto-XY system. For production NQS at scale, use NetKet (Vicentini et al., 2022).
Theory¶
Restricted Boltzmann Machine Wavefunction¶
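In the standard form of Carleo & Troyer (2017), with the hidden units summed out, the RBM ansatz parameterises the wavefunction as:
\[
\psi(\sigma) = \exp\!\Big(\sum_{i=1}^{n} a_i \sigma_i\Big) \prod_{j=1}^{n_h} 2\cosh\!\Big(b_j + \sum_{i=1}^{n} W_{ji}\,\sigma_i\Big)
\]
where \(\sigma \in \{+1, -1\}^n\) is a spin configuration, \(a_i\) are visible biases, \(b_j\) are hidden biases, and \(W_{ji}\) are weights connecting the \(n_\text{visible} = n\) spins to the \(n_\text{hidden} = n_h\) hidden units.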
Total parameters: \(n + n_h + n \cdot n_h\) (typically \(n_h = 2n\)).
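A minimal NumPy sketch of the standard RBM log-amplitude, \(\log\psi(\sigma) = \sum_i a_i\sigma_i + \sum_j \log 2\cosh(b_j + \sum_i W_{ji}\sigma_i)\), together with the parameter count. The names `rbm_log_psi` and `rbm_n_params` are illustrative and not part of scpn_quantum_control; the library's own normalisation and initialisation may differ.

```python
import numpy as np


def rbm_log_psi(sigma, a, b, W):
    """Log-amplitude of the standard RBM ansatz (hidden units summed out).

    sigma: (n,) array of +1/-1 spins
    a: (n,) visible biases, b: (n_h,) hidden biases, W: (n_h, n) weights
    """
    theta = b + W @ sigma  # effective field seen by each hidden unit
    # Each hidden unit contributes log(2 cosh(theta_j))
    return a @ sigma + np.sum(np.log(2.0 * np.cosh(theta)))


def rbm_n_params(n, n_h):
    """Visible biases + hidden biases + weights."""
    return n + n_h + n * n_h


n = 4
n_h = 2 * n  # the typical choice quoted in the text
rng = np.random.default_rng(0)
a = 0.01 * rng.standard_normal(n)          # small random initial values (assumption)
b = 0.01 * rng.standard_normal(n_h)
W = 0.01 * rng.standard_normal((n_h, n))
sigma = np.array([1, -1, 1, -1])

log_amp = rbm_log_psi(sigma, a, b, W)
n_params = rbm_n_params(n, n_h)  # 4 + 8 + 4 * 8 = 44
print(f"log psi = {log_amp:.6f}, parameters = {n_params}")
```

Working with \(\log\psi\) rather than \(\psi\) keeps the product over hidden units numerically stable as \(n_h\) grows.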
Variational Monte Carlo (VMC)¶
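The energy expectation is:
\[
E = \frac{\langle\psi|H|\psi\rangle}{\langle\psi|\psi\rangle} = \frac{\sum_{\sigma,\sigma'} \psi^*(\sigma)\, H_{\sigma\sigma'}\, \psi(\sigma')}{\sum_{\sigma} |\psi(\sigma)|^2}
\]
For small systems (\(n \leq 12\)), we evaluate the sums over all \(2^n\) configurations exactly. The gradient is computed via central finite differences (NumPy) or automatic differentiation (JAX).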
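The exact-enumeration energy can be sketched as a Rayleigh quotient over the full \(2^n\)-dimensional basis. The sketch below uses a random Hermitian matrix and a simple product-state amplitude as stand-ins; neither is the Kuramoto-XY Hamiltonian nor the library's RBM, and the helper names and basis ordering are assumptions made for illustration.

```python
import itertools

import numpy as np


def all_configurations(n):
    """All 2^n spin configurations as rows of +1/-1 (ordering is an assumption)."""
    return np.array(list(itertools.product([1, -1], repeat=n)))


def exact_energy(psi, H):
    """Rayleigh quotient <psi|H|psi> / <psi|psi> for unnormalised amplitudes."""
    return float(np.real(np.vdot(psi, H @ psi) / np.vdot(psi, psi)))


n = 3
dim = 2**n
rng = np.random.default_rng(1)

# Stand-in Hermitian matrix (NOT the Kuramoto-XY Hamiltonian)
A = rng.standard_normal((dim, dim)) + 1j * rng.standard_normal((dim, dim))
H = (A + A.conj().T) / 2

# Stand-in amplitudes: psi(sigma) = exp(sum_i a_i sigma_i), one per configuration
configs = all_configurations(n)
a = 0.1 * rng.standard_normal(n)
psi = np.exp(configs @ a)

E = exact_energy(psi, H)
E0 = np.linalg.eigvalsh(H)[0]
print(f"variational E = {E:.6f}, exact ground state E0 = {E0:.6f}")
```

By the variational principle the computed energy can never fall below the lowest eigenvalue, which makes exact diagonalisation a convenient check at these sizes.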
Why RBM?¶
Carleo & Troyer (Science, 2017) showed that RBMs can represent ground states of 1D and 2D spin models with accuracy competitive with DMRG. RBMs are also universal approximators: a sufficiently wide RBM can represent any quantum state to arbitrary accuracy, although the number of hidden units required can grow exponentially with \(n\).
In practice, convergence depends on the problem. For the XY model with moderate coupling, RBMs with \(n_h = 2n\) typically converge within 100–500 iterations.
Part 1: RBM Ansatz (NumPy)¶
scpn_quantum_control.phase.nqs_ansatz
API Reference¶
RBMWavefunction Class¶
from scpn_quantum_control.phase.nqs_ansatz import RBMWavefunction
rbm = RBMWavefunction(
    n_visible: int,               # number of spins
    n_hidden: int | None = None,  # hidden units (default: 2 * n_visible)
    seed: int | None = None,      # RNG seed
)
Methods:
| Method | Signature | Returns | Description |
|---|---|---|---|
| log_psi(sigma) | ndarray → complex | \(\log\psi(\sigma)\) | Log-amplitude for configuration \(\sigma \in \{+1,-1\}^n\) |
| psi(sigma) | ndarray → complex | \(\psi(\sigma)\) | Amplitude (exponentiated) |
| all_amplitudes() | () → ndarray | Shape \((2^n,)\) | All amplitudes (exact, for \(n \leq 12\)) |
| n_params() | () → int | Total parameter count | \(n + n_h + n \cdot n_h\) |
vmc_ground_state¶
from scpn_quantum_control.phase.nqs_ansatz import vmc_ground_state
result = vmc_ground_state(
    K: np.ndarray,                     # (n, n) coupling matrix
    omega: np.ndarray,                 # (n,) frequencies
    n_hidden: int | None = None,       # hidden units (default: 2n)
    learning_rate: float = 0.01,       # gradient descent step
    n_iterations: int = 200,           # optimisation steps
    n_samples: None = None,            # sampling unsupported in this path
    seed: int | None = None,
    max_dense_gib: float | None = None,
) -> dict
max_dense_gib gates the exact dense Hamiltonian and full-configuration
statevector workspace before allocation.
Returns:
{
    "energy": float,                  # final variational energy
    "energy_history": list[float],    # energy at each iteration
    "wavefunction": RBMWavefunction,  # trained RBM
    "n_params": int,                  # total parameters
    "sampling_mode": "exact_enumeration",
    "n_samples_used": 2**n,
    "gradient_method": "central_finite_difference",
}
Note: For \(n > 12\), raises ValueError (exact summation is \(O(2^n)\)).
Passing n_samples raises ValueError; this NumPy path is exact
enumeration with central finite-difference gradients, not sampled VMC.
Use NetKet for larger systems or MCMC sampling.
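A back-of-envelope estimate shows why exact enumeration is capped at small \(n\) and why a memory gate such as max_dense_gib is useful. The sketch assumes a dense \(2^n \times 2^n\) matrix of complex128 entries (16 bytes each); how the library itself computes its check is not stated here, and `dense_hamiltonian_gib` is a hypothetical helper.

```python
def dense_hamiltonian_gib(n, bytes_per_entry=16):
    """Memory for a dense 2^n x 2^n matrix in GiB (complex128 assumed: 16 bytes)."""
    dim = 2**n
    return dim * dim * bytes_per_entry / 2**30


for n in (8, 10, 12, 14):
    print(f"n={n:2d}  dim={2**n:6d}  dense H ~ {dense_hamiltonian_gib(n):.4f} GiB")
```

Under these assumptions the matrix occupies 0.25 GiB at \(n = 12\) and 4 GiB at \(n = 14\): each additional spin multiplies the footprint by four.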
Example¶
import numpy as np
from scpn_quantum_control.phase.nqs_ansatz import vmc_ground_state
n = 6
K = 0.45 * np.exp(-0.3 * np.abs(np.subtract.outer(range(n), range(n))))
np.fill_diagonal(K, 0.0)
omega = np.linspace(0.8, 1.2, n)
result = vmc_ground_state(K, omega, n_iterations=300, seed=42)
print(f"VMC energy: {result['energy']:.6f}")
print(f"Parameters: {result['n_params']}")
# Compare with exact diagonalisation
from scpn_quantum_control.bridge.knm_hamiltonian import knm_to_dense_matrix
H = knm_to_dense_matrix(K, omega)
E_exact = np.linalg.eigvalsh(H)[0]
print(f"Exact energy: {E_exact:.6f}")
print(f"Relative error: {abs(result['energy'] - E_exact) / abs(E_exact):.2e}")
Part 2: JAX-Accelerated NQS¶
scpn_quantum_control.phase.jax_nqs
Why JAX?¶
The numpy VMC uses finite-difference gradients: \(2 \times n_\text{params}\)
function evaluations per iteration. JAX replaces this with automatic
differentiation via jax.grad — a single forward + backward pass.
Measured speedup: ~100× for gradient computation at \(n=8\).
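The evaluation count quoted above follows directly from the central-difference formula, which perturbs each parameter once in each direction. A minimal sketch on a toy objective with a known gradient; `central_difference_gradient` and `CountingObjective` are illustrative names, not the library's implementation.

```python
import numpy as np


def central_difference_gradient(f, params, eps=1e-5):
    """Central finite-difference gradient of a scalar function f at params.

    Costs 2 * params.size evaluations of f.
    """
    grad = np.zeros(params.size, dtype=float)
    for k in range(params.size):
        shift = np.zeros(params.size, dtype=float)
        shift[k] = eps
        grad[k] = (f(params + shift) - f(params - shift)) / (2.0 * eps)
    return grad


class CountingObjective:
    """Toy objective sum(p^2), with gradient 2p, that counts its own calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, p):
        self.calls += 1
        return float(np.sum(p**2))


objective = CountingObjective()
p = np.array([0.5, -1.0, 2.0])
g = central_difference_gradient(objective, p)

print("gradient:", g)
print("function evaluations:", objective.calls)  # 2 * 3 parameters
```

For an RBM with \(n = 8\) and \(n_h = 16\) there are \(8 + 16 + 128 = 152\) parameters, so each finite-difference gradient costs 304 energy evaluations, compared with one forward and one backward pass under automatic differentiation.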
API Reference¶
from scpn_quantum_control.phase.jax_nqs import (
    is_jax_available,
    jax_rbm_energy,
    jax_vmc_ground_state,
)
jax_rbm_energy¶
energy = jax_rbm_energy(
    params: dict[str, Any],  # {'a': visible, 'b': hidden, 'W': weights}
    H: jax.Array,            # dense Hamiltonian
    n: int,                  # number of qubits
) -> jax.Array
Differentiable energy expectation. Use with jax.grad for gradients.
jax_vmc_ground_state¶
result = jax_vmc_ground_state(
    K: np.ndarray,
    omega: np.ndarray,
    n_hidden: int | None = None,
    learning_rate: float = 0.01,
    n_iterations: int = 200,
    seed: int = 42,
    max_dense_gib: float | None = None,
) -> dict
The JAX path uses the same exact-enumeration Hilbert-space boundary as the
NumPy RBM path; max_dense_gib gates the dense Hamiltonian and configuration
workspace before transferring arrays to JAX.
Returns:
{
    "energy": float,
    "energy_history": list[float],
    "params": dict,  # JAX parameter dict
    "n_params": int,
}
Example¶
import numpy as np
from scpn_quantum_control.phase.jax_nqs import (
    jax_vmc_ground_state, is_jax_available
)
if not is_jax_available():
    raise ImportError("Install JAX: pip install jax jaxlib")
n = 8
K = 0.45 * np.exp(-0.3 * np.abs(np.subtract.outer(range(n), range(n))))
np.fill_diagonal(K, 0.0)
omega = np.linspace(0.8, 1.2, n)
result = jax_vmc_ground_state(K, omega, n_iterations=300, seed=42)
print(f"JAX VMC energy: {result['energy']:.6f}")
print(f"Parameters: {result['n_params']}")
Tutorial: Comparing NQS Methods¶
import numpy as np
from scpn_quantum_control.bridge.knm_hamiltonian import knm_to_dense_matrix
n = 6
K = 0.45 * np.exp(-0.3 * np.abs(np.subtract.outer(range(n), range(n))))
np.fill_diagonal(K, 0.0)
omega = np.linspace(0.8, 1.2, n)
# Exact reference
H = knm_to_dense_matrix(K, omega)
E_exact = np.linalg.eigvalsh(H)[0]
# NumPy RBM
from scpn_quantum_control.phase.nqs_ansatz import vmc_ground_state
result_np = vmc_ground_state(K, omega, n_iterations=300, seed=42)
# JAX RBM (if available)
from scpn_quantum_control.phase.jax_nqs import jax_vmc_ground_state, is_jax_available
if is_jax_available():
    result_jax = jax_vmc_ground_state(K, omega, n_iterations=300, seed=42)
    jax_energy = result_jax['energy']
else:
    jax_energy = float('nan')
print(f"Exact: {E_exact:.6f}")
print(f"NumPy VMC: {result_np['energy']:.6f} (error: {abs(result_np['energy'] - E_exact):.2e})")
print(f"JAX VMC: {jax_energy:.6f}")
Comparison¶
| Feature | This module (numpy) | This module (JAX) | NetKet | PennyLane |
|---|---|---|---|---|
| RBM ansatz | Yes | Yes | Yes | No |
| Other ansätze | No | No | Many (RNN, GCN, etc.) | VQE circuits |
| Gradient method | Finite difference | Auto-diff (JAX) | Auto-diff (JAX) | Auto-diff |
| MCMC sampling | No (\(n \leq 12\)) | No (\(n \leq 12\)) | Yes | N/A |
| Max system size | 12 | 12 | 1000+ | Circuit-limited |
| GPU | No | Yes (JAX) | Yes (JAX) | Yes |
| Hamiltonian | XY (built-in) | XY (built-in) | Any | Any |
References¶
- Carleo, G. & Troyer, M. "Solving the quantum many-body problem with artificial neural networks." Science 355, 602 (2017).
- Vicentini, F. et al. "NetKet 3: Machine learning toolbox for many-body quantum systems." SoftwareX 17, 100933 (2022).
- Bradbury, J. et al. "JAX: composable transformations of Python+NumPy programs." (2018). http://github.com/google/jax
See Also¶
- Variational Methods — parameter-shift gradient rule
- GPU Batch VQE — parallel parameter scanning
- Tensor Networks — MPS/DMRG alternative