fusion_core/
vmec_interface.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
2// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
3// © Code 2020–2026 Miroslav Šotek. All rights reserved.
4// ORCID: 0009-0009-3560-0851
5// Contact: www.anulum.li | protoscience@anulum.li
6// SCPN Fusion Core — VMEC Interface
7//! Lightweight VMEC-compatible boundary-state wrapper.
8//!
9//! This module intentionally does not implement a full 3D force-balance solve.
10//! Instead it provides a deterministic interoperability lane for exchanging
11//! reduced Fourier boundary states with external VMEC-class workflows.
12
13use fusion_types::error::{FusionError, FusionResult};
14use ndarray::{Array1, Array2};
15use std::collections::HashSet;
16
/// One Fourier harmonic of a VMEC-style boundary description.
///
/// `m` is the poloidal and `n` the toroidal mode number; the four amplitudes are
/// the cosine/sine parts of `R` and `Z`. The fixed-boundary solver in this module
/// reads only `r_cos` and `z_sin` (the stellarator-symmetric part).
#[derive(Debug)]
#[derive(Clone, Copy, PartialEq)]
pub struct VmecFourierMode {
    /// Poloidal mode number (boundary validation requires `m >= 0`).
    pub m: i32,
    /// Toroidal mode number; negative values are allowed.
    pub n: i32,
    /// Cosine amplitude of `R`.
    pub r_cos: f64,
    /// Sine amplitude of `R`.
    pub r_sin: f64,
    /// Cosine amplitude of `Z`.
    pub z_cos: f64,
    /// Sine amplitude of `Z`.
    pub z_sin: f64,
}
26
27#[derive(Debug, Clone, PartialEq)]
28pub struct VmecBoundaryState {
29    pub r_axis: f64,
30    pub z_axis: f64,
31    pub a_minor: f64,
32    pub kappa: f64,
33    pub triangularity: f64,
34    pub nfp: usize,
35    pub modes: Vec<VmecFourierMode>,
36}
37
38impl VmecBoundaryState {
39    pub fn validate(&self) -> FusionResult<()> {
40        if !self.r_axis.is_finite()
41            || !self.z_axis.is_finite()
42            || !self.a_minor.is_finite()
43            || !self.kappa.is_finite()
44            || !self.triangularity.is_finite()
45        {
46            return Err(FusionError::PhysicsViolation(
47                "VMEC boundary contains non-finite scalar".to_string(),
48            ));
49        }
50        if self.a_minor <= 0.0 {
51            return Err(FusionError::PhysicsViolation(
52                "VMEC boundary requires a_minor > 0".to_string(),
53            ));
54        }
55        if self.kappa <= 0.0 {
56            return Err(FusionError::PhysicsViolation(
57                "VMEC boundary requires kappa > 0".to_string(),
58            ));
59        }
60        if self.nfp < 1 {
61            return Err(FusionError::PhysicsViolation(
62                "VMEC boundary requires nfp >= 1".to_string(),
63            ));
64        }
65        let mut seen_modes: HashSet<(i32, i32)> = HashSet::with_capacity(self.modes.len());
66        for (idx, mode) in self.modes.iter().enumerate() {
67            if mode.m < 0 {
68                return Err(FusionError::PhysicsViolation(format!(
69                    "VMEC boundary mode[{idx}] requires m >= 0, got {}",
70                    mode.m
71                )));
72            }
73            if !mode.r_cos.is_finite()
74                || !mode.r_sin.is_finite()
75                || !mode.z_cos.is_finite()
76                || !mode.z_sin.is_finite()
77            {
78                return Err(FusionError::PhysicsViolation(format!(
79                    "VMEC boundary mode[{idx}] contains non-finite coefficients"
80                )));
81            }
82            let key = (mode.m, mode.n);
83            if !seen_modes.insert(key) {
84                return Err(FusionError::PhysicsViolation(format!(
85                    "VMEC boundary contains duplicate mode (m={}, n={})",
86                    mode.m, mode.n
87                )));
88            }
89        }
90        Ok(())
91    }
92}
93
94pub fn export_vmec_like_text(state: &VmecBoundaryState) -> FusionResult<String> {
95    state.validate()?;
96    let mut out = String::new();
97    out.push_str("format=vmec_like_v1\n");
98    out.push_str(&format!("r_axis={:.16e}\n", state.r_axis));
99    out.push_str(&format!("z_axis={:.16e}\n", state.z_axis));
100    out.push_str(&format!("a_minor={:.16e}\n", state.a_minor));
101    out.push_str(&format!("kappa={:.16e}\n", state.kappa));
102    out.push_str(&format!("triangularity={:.16e}\n", state.triangularity));
103    out.push_str(&format!("nfp={}\n", state.nfp));
104    for mode in &state.modes {
105        out.push_str(&format!(
106            "mode,{},{},{:.16e},{:.16e},{:.16e},{:.16e}\n",
107            mode.m, mode.n, mode.r_cos, mode.r_sin, mode.z_cos, mode.z_sin
108        ));
109    }
110    Ok(out)
111}
112
113fn parse_float(key: &str, text: &str) -> FusionResult<f64> {
114    let val = text.parse::<f64>().map_err(|e| {
115        FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as float: {e}"))
116    })?;
117    if !val.is_finite() {
118        return Err(FusionError::PhysicsViolation(format!(
119            "VMEC key '{key}' must be finite, got {val}"
120        )));
121    }
122    Ok(val)
123}
124
125fn parse_int<T>(key: &str, text: &str) -> FusionResult<T>
126where
127    T: std::str::FromStr,
128    T::Err: std::fmt::Display,
129{
130    text.parse::<T>().map_err(|e| {
131        FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as integer: {e}"))
132    })
133}
134
135pub fn import_vmec_like_text(text: &str) -> FusionResult<VmecBoundaryState> {
136    let mut format_seen = false;
137    let mut r_axis: Option<f64> = None;
138    let mut z_axis: Option<f64> = None;
139    let mut a_minor: Option<f64> = None;
140    let mut kappa: Option<f64> = None;
141    let mut triangularity: Option<f64> = None;
142    let mut nfp: Option<usize> = None;
143    let mut modes: Vec<VmecFourierMode> = Vec::new();
144
145    for raw_line in text.lines() {
146        let line = raw_line.trim();
147        if line.is_empty() || line.starts_with('#') {
148            continue;
149        }
150        if let Some(format_name) = line.strip_prefix("format=") {
151            if format_seen {
152                return Err(FusionError::PhysicsViolation(
153                    "Duplicate VMEC key: format".to_string(),
154                ));
155            }
156            if format_name.trim() != "vmec_like_v1" {
157                return Err(FusionError::PhysicsViolation(format!(
158                    "Unsupported VMEC format: {}",
159                    format_name.trim()
160                )));
161            }
162            format_seen = true;
163            continue;
164        }
165        if let Some(rest) = line.strip_prefix("mode,") {
166            let cols: Vec<&str> = rest.split(',').map(|v| v.trim()).collect();
167            if cols.len() != 6 {
168                return Err(FusionError::PhysicsViolation(
169                    "VMEC mode line must contain exactly 6 columns".to_string(),
170                ));
171            }
172            modes.push(VmecFourierMode {
173                m: parse_int("mode.m", cols[0])?,
174                n: parse_int("mode.n", cols[1])?,
175                r_cos: parse_float("mode.r_cos", cols[2])?,
176                r_sin: parse_float("mode.r_sin", cols[3])?,
177                z_cos: parse_float("mode.z_cos", cols[4])?,
178                z_sin: parse_float("mode.z_sin", cols[5])?,
179            });
180            continue;
181        }
182        let (key, value) = line.split_once('=').ok_or_else(|| {
183            FusionError::PhysicsViolation(format!("Invalid VMEC line (missing '='): {line}"))
184        })?;
185        let key = key.trim();
186        let value = value.trim();
187        match key {
188            "r_axis" => {
189                if r_axis.is_some() {
190                    return Err(FusionError::PhysicsViolation(
191                        "Duplicate VMEC key: r_axis".to_string(),
192                    ));
193                }
194                r_axis = Some(parse_float(key, value)?);
195            }
196            "z_axis" => {
197                if z_axis.is_some() {
198                    return Err(FusionError::PhysicsViolation(
199                        "Duplicate VMEC key: z_axis".to_string(),
200                    ));
201                }
202                z_axis = Some(parse_float(key, value)?);
203            }
204            "a_minor" => {
205                if a_minor.is_some() {
206                    return Err(FusionError::PhysicsViolation(
207                        "Duplicate VMEC key: a_minor".to_string(),
208                    ));
209                }
210                a_minor = Some(parse_float(key, value)?);
211            }
212            "kappa" => {
213                if kappa.is_some() {
214                    return Err(FusionError::PhysicsViolation(
215                        "Duplicate VMEC key: kappa".to_string(),
216                    ));
217                }
218                kappa = Some(parse_float(key, value)?);
219            }
220            "triangularity" => {
221                if triangularity.is_some() {
222                    return Err(FusionError::PhysicsViolation(
223                        "Duplicate VMEC key: triangularity".to_string(),
224                    ));
225                }
226                triangularity = Some(parse_float(key, value)?);
227            }
228            "nfp" => {
229                if nfp.is_some() {
230                    return Err(FusionError::PhysicsViolation(
231                        "Duplicate VMEC key: nfp".to_string(),
232                    ));
233                }
234                nfp = Some(parse_int(key, value)?);
235            }
236            other => {
237                return Err(FusionError::PhysicsViolation(format!(
238                    "Unknown VMEC key: {other}"
239                )));
240            }
241        }
242    }
243
244    let state = VmecBoundaryState {
245        r_axis: r_axis
246            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: r_axis".to_string()))?,
247        z_axis: z_axis
248            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: z_axis".to_string()))?,
249        a_minor: a_minor.ok_or_else(|| {
250            FusionError::PhysicsViolation("Missing VMEC key: a_minor".to_string())
251        })?,
252        kappa: kappa
253            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: kappa".to_string()))?,
254        triangularity: triangularity.ok_or_else(|| {
255            FusionError::PhysicsViolation("Missing VMEC key: triangularity".to_string())
256        })?,
257        nfp: nfp
258            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: nfp".to_string()))?,
259        modes,
260    };
261    state.validate()?;
262    Ok(state)
263}
264
265// ═══════════════════════════════════════════════════════════════════════
266// VMEC-like Fixed-Boundary 3D Equilibrium Solver
267// ═══════════════════════════════════════════════════════════════════════
268//
269// Implements a variational equilibrium solver following Hirshman & Whitson
270// (1983). Given boundary Fourier modes, pressure p(s), and rotational
271// transform iota(s), finds force-balanced interior flux surface shapes via
272// steepest descent on the MHD energy functional.
273//
274// Stellarator symmetry: R uses cos(mθ − nNζ), Z uses sin(mθ − nNζ).
275
/// Resolution and iteration controls for the VMEC-like fixed-boundary solver.
#[derive(Clone, Debug)]
pub struct VmecSolverConfig {
    /// Highest poloidal mode number `m` kept in the Fourier basis.
    pub m_pol: usize,
    /// Highest toroidal mode number `|n|`; `0` selects the axisymmetric basis.
    pub n_tor: usize,
    /// Radial surface count, axis and boundary included (`>= 3`).
    pub ns: usize,
    /// Poloidal quadrature points (`>= 8`).
    pub ntheta: usize,
    /// Toroidal quadrature points per field period (used only when `n_tor > 0`).
    pub nzeta: usize,
    /// Upper bound on descent iterations (`>= 1`).
    pub max_iter: usize,
    /// Convergence threshold on the RMS force residual (finite, `> 0`).
    pub tol: f64,
    /// Fixed descent step length (finite, `> 0`).
    pub step_size: f64,
}
296
297impl Default for VmecSolverConfig {
298    fn default() -> Self {
299        Self {
300            m_pol: 6,
301            n_tor: 0,
302            ns: 25,
303            ntheta: 32,
304            nzeta: 1,
305            max_iter: 500,
306            tol: 1e-8,
307            step_size: 5e-3,
308        }
309    }
310}
311
312impl VmecSolverConfig {
313    pub fn validate(&self) -> FusionResult<()> {
314        if self.ns < 3 {
315            return Err(FusionError::PhysicsViolation(
316                "VMEC solver requires ns >= 3".into(),
317            ));
318        }
319        if self.ntheta < 8 {
320            return Err(FusionError::PhysicsViolation(
321                "VMEC solver requires ntheta >= 8".into(),
322            ));
323        }
324        if self.max_iter == 0 {
325            return Err(FusionError::PhysicsViolation(
326                "VMEC solver requires max_iter >= 1".into(),
327            ));
328        }
329        if !self.tol.is_finite() || self.tol <= 0.0 {
330            return Err(FusionError::PhysicsViolation(
331                "VMEC solver tol must be finite and > 0".into(),
332            ));
333        }
334        if !self.step_size.is_finite() || self.step_size <= 0.0 {
335            return Err(FusionError::PhysicsViolation(
336                "VMEC solver step_size must be finite and > 0".into(),
337            ));
338        }
339        Ok(())
340    }
341}
342
343/// Solution from the VMEC fixed-boundary solver.
344#[derive(Debug, Clone)]
345pub struct VmecEquilibrium {
346    /// R cosine Fourier coefficients per surface [ns × n_modes].
347    pub rmnc: Array2<f64>,
348    /// Z sine Fourier coefficients per surface [ns × n_modes].
349    pub zmns: Array2<f64>,
350    /// Rotational transform profile iota(s) [ns].
351    pub iota: Array1<f64>,
352    /// Pressure profile [Pa] [ns].
353    pub pressure: Array1<f64>,
354    /// Total toroidal flux [Wb].
355    pub phi_edge: f64,
356    /// Plasma volume [m³].
357    pub volume: f64,
358    /// Volume-averaged beta.
359    pub beta_avg: f64,
360    /// Final force residual norm.
361    pub force_residual: f64,
362    /// Number of iterations.
363    pub iterations: usize,
364    /// Whether the solver converged.
365    pub converged: bool,
366    /// Grid parameters.
367    pub ns_grid: usize,
368    pub m_pol: usize,
369    pub n_tor: usize,
370    pub nfp: usize,
371}
372
/// Count the Fourier modes in the basis truncated at `m_pol` / `n_tor`.
///
/// The axisymmetric basis (`n_tor == 0`) keeps `m = 0..=m_pol` only; otherwise each
/// `m` is paired with every `n` in `-n_tor..=n_tor`.
pub fn vmec_n_modes(m_pol: usize, n_tor: usize) -> usize {
    let toroidal_count = match n_tor {
        0 => 1,
        n => 2 * n + 1,
    };
    (m_pol + 1) * toroidal_count
}
381
/// Map mode `(m, n)` to its column in the flattened coefficient layout.
///
/// Columns are grouped by `m`; inside each group `n` runs upward from `-n_tor`.
/// Returns `None` when `m > m_pol` or `|n| > n_tor` (which, for the axisymmetric
/// layout `n_tor == 0`, means any `n != 0`).
pub fn vmec_mode_idx(m: usize, n: i32, m_pol: usize, n_tor: usize) -> Option<usize> {
    let n_abs = n.unsigned_abs() as usize;
    if m > m_pol || n_abs > n_tor {
        return None;
    }
    if n_tor == 0 {
        return Some(m);
    }
    let row_len = 2 * n_tor + 1;
    Some(m * row_len + (n + n_tor as i32) as usize)
}
400
/// Evaluate `(R, Z)` on one surface at angles `(theta, zeta)`.
///
/// Sums `R = Σ rmnc·cos(mθ − n·nfp·ζ)` and `Z = Σ zmns·sin(mθ − n·nfp·ζ)` over the
/// `vmec_mode_idx` layout. In the axisymmetric layout (`n_tor == 0`) `zeta` and `nfp`
/// are not used. Slices must hold at least `vmec_n_modes(m_pol, n_tor)` entries;
/// indexing panics otherwise.
fn eval_surface_point(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    if n_tor == 0 {
        return (0..=m_pol).fold((0.0, 0.0), |(r_acc, z_acc), m| {
            let (sin_a, cos_a) = (m as f64 * theta).sin_cos();
            (r_acc + rmnc[m] * cos_a, z_acc + zmns[m] * sin_a)
        });
    }

    let n_max = n_tor as i32;
    let stride = 2 * n_tor + 1;
    let mut r_sum = 0.0;
    let mut z_sum = 0.0;
    for m in 0..=m_pol {
        for nn in -n_max..=n_max {
            let k = m * stride + (nn + n_max) as usize;
            let phase = m as f64 * theta - nn as f64 * nfp as f64 * zeta;
            let (sin_a, cos_a) = phase.sin_cos();
            r_sum += rmnc[k] * cos_a;
            z_sum += zmns[k] * sin_a;
        }
    }
    (r_sum, z_sum)
}
433
/// Analytic poloidal derivatives `(∂R/∂θ, ∂Z/∂θ)` on one surface at `(theta, zeta)`.
///
/// Differentiates the same sums as `eval_surface_point` term by term:
/// `∂R/∂θ = −Σ m·rmnc·sin(·)` and `∂Z/∂θ = Σ m·zmns·cos(·)`. In the axisymmetric
/// layout (`n_tor == 0`) `zeta` and `nfp` are not used.
fn eval_surface_deriv_theta(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    if n_tor == 0 {
        return (0..=m_pol).fold((0.0, 0.0), |(dr_acc, dz_acc), m| {
            let mf = m as f64;
            let (sin_a, cos_a) = (mf * theta).sin_cos();
            (dr_acc - mf * rmnc[m] * sin_a, dz_acc + mf * zmns[m] * cos_a)
        });
    }

    let n_max = n_tor as i32;
    let stride = 2 * n_tor + 1;
    let mut dr_dtheta = 0.0;
    let mut dz_dtheta = 0.0;
    for m in 0..=m_pol {
        let mf = m as f64;
        for nn in -n_max..=n_max {
            let k = m * stride + (nn + n_max) as usize;
            let phase = mf * theta - nn as f64 * nfp as f64 * zeta;
            let (sin_a, cos_a) = phase.sin_cos();
            dr_dtheta -= mf * rmnc[k] * sin_a;
            dz_dtheta += mf * zmns[k] * cos_a;
        }
    }
    (dr_dtheta, dz_dtheta)
}
468
469/// Solve VMEC fixed-boundary 3D equilibrium.
470///
471/// Given boundary shape (from `VmecBoundaryState`), pressure and rotational
472/// transform profiles, finds force-balanced interior flux surface shapes.
473///
474/// The algorithm minimises the MHD energy functional W = ∫(B²/2μ₀ + p)dV
475/// via steepest descent on the interior Fourier coefficients R_mn^c(s),
476/// Z_mn^s(s) while holding the axis and boundary fixed.
477pub fn vmec_fixed_boundary_solve(
478    boundary: &VmecBoundaryState,
479    config: &VmecSolverConfig,
480    pressure: &[f64],
481    iota: &[f64],
482    phi_edge: f64,
483) -> FusionResult<VmecEquilibrium> {
484    boundary.validate()?;
485    config.validate()?;
486
487    let ns = config.ns;
488    let m_pol = config.m_pol;
489    let n_tor = config.n_tor;
490    let nfp = boundary.nfp;
491    let nmodes = vmec_n_modes(m_pol, n_tor);
492    let ntheta = config.ntheta;
493    let nzeta = if n_tor == 0 { 1 } else { config.nzeta.max(4) };
494
495    if pressure.len() != ns || iota.len() != ns {
496        return Err(FusionError::PhysicsViolation(format!(
497            "pressure/iota length must match ns={ns}, got p={} iota={}",
498            pressure.len(),
499            iota.len()
500        )));
501    }
502    if pressure.iter().any(|v| !v.is_finite()) || iota.iter().any(|v| !v.is_finite()) {
503        return Err(FusionError::PhysicsViolation(
504            "pressure/iota profiles must be finite".into(),
505        ));
506    }
507    if !phi_edge.is_finite() || phi_edge <= 0.0 {
508        return Err(FusionError::PhysicsViolation(
509            "phi_edge must be finite and > 0".into(),
510        ));
511    }
512
513    // Build boundary Fourier coefficients from VmecBoundaryState
514    let mut bnd_rmnc = vec![0.0; nmodes];
515    let mut bnd_zmns = vec![0.0; nmodes];
516
517    if let Some(idx) = vmec_mode_idx(0, 0, m_pol, n_tor) {
518        bnd_rmnc[idx] = boundary.r_axis;
519    }
520    if let Some(idx) = vmec_mode_idx(1, 0, m_pol, n_tor) {
521        bnd_rmnc[idx] = boundary.a_minor;
522        bnd_zmns[idx] = boundary.a_minor * boundary.kappa;
523    }
524    if m_pol >= 2 {
525        if let Some(idx) = vmec_mode_idx(2, 0, m_pol, n_tor) {
526            bnd_rmnc[idx] = -boundary.triangularity * boundary.a_minor * 0.5;
527        }
528    }
529    for mode in &boundary.modes {
530        let m = mode.m as usize;
531        if let Some(idx) = vmec_mode_idx(m, mode.n, m_pol, n_tor) {
532            bnd_rmnc[idx] += mode.r_cos;
533            bnd_zmns[idx] += mode.z_sin;
534        }
535    }
536
537    // Initialize: linear interpolation from axis to boundary
538    let mut rmnc = Array2::zeros((ns, nmodes));
539    let mut zmns = Array2::zeros((ns, nmodes));
540    let s_grid: Vec<f64> = (0..ns).map(|i| i as f64 / (ns - 1) as f64).collect();
541
542    for js in 0..ns {
543        let s = s_grid[js];
544        for k in 0..nmodes {
545            let axis_r = if k == vmec_mode_idx(0, 0, m_pol, n_tor).unwrap_or(usize::MAX) {
546                boundary.r_axis
547            } else {
548                0.0
549            };
550            rmnc[[js, k]] = axis_r * (1.0 - s) + bnd_rmnc[k] * s;
551            zmns[[js, k]] = bnd_zmns[k] * s;
552        }
553    }
554
555    let pi = std::f64::consts::PI;
556    let mu0 = 4.0e-7 * pi;
557    let ds = 1.0 / (ns - 1) as f64;
558    let dtheta = 2.0 * pi / ntheta as f64;
559    let dzeta = 2.0 * pi / (nzeta as f64 * nfp.max(1) as f64);
560    let n_angle_pts = (ntheta * nzeta) as f64;
561
562    // Precompute angles
563    let theta_arr: Vec<f64> = (0..ntheta)
564        .map(|i| 2.0 * pi * i as f64 / ntheta as f64)
565        .collect();
566    let zeta_arr: Vec<f64> = (0..nzeta)
567        .map(|i| 2.0 * pi * i as f64 / (nzeta as f64 * nfp.max(1) as f64))
568        .collect();
569
570    // Steepest descent iteration
571    let mut global_force = f64::MAX;
572    let mut converged = false;
573    let mut iter_count = 0;
574    let mut total_volume = 0.0;
575    let mut total_beta_num = 0.0;
576    let mut total_b2_vol = 0.0;
577
578    for iteration in 0..config.max_iter {
579        let mut force_sq_sum = 0.0;
580        total_volume = 0.0;
581        total_beta_num = 0.0;
582        total_b2_vol = 0.0;
583
584        let mut grad_r: Array2<f64> = Array2::zeros((ns, nmodes));
585        let mut grad_z: Array2<f64> = Array2::zeros((ns, nmodes));
586
587        // Evaluate force on each interior surface
588        for js in 1..(ns - 1) {
589            let p = pressure[js];
590            let iota_s = iota[js];
591            let dp_ds =
592                (pressure[(js + 1).min(ns - 1)] - pressure[js.saturating_sub(1)]) / (2.0 * ds);
593
594            let rmnc_s: Vec<f64> = (0..nmodes).map(|k| rmnc[[js, k]]).collect();
595            let zmns_s: Vec<f64> = (0..nmodes).map(|k| zmns[[js, k]]).collect();
596            let rmnc_p: Vec<f64> = (0..nmodes).map(|k| rmnc[[js + 1, k]]).collect();
597            let zmns_p: Vec<f64> = (0..nmodes).map(|k| zmns[[js + 1, k]]).collect();
598            let rmnc_m: Vec<f64> = (0..nmodes).map(|k| rmnc[[js - 1, k]]).collect();
599            let zmns_m: Vec<f64> = (0..nmodes).map(|k| zmns[[js - 1, k]]).collect();
600
601            for &theta in theta_arr.iter().take(ntheta) {
602                for &zeta in zeta_arr.iter().take(nzeta) {
603                    let (r_val, _z_val) =
604                        eval_surface_point(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
605                    let (r_p, _) =
606                        eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
607                    let (r_m, _) =
608                        eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
609                    let (_, z_p) =
610                        eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
611                    let (_, z_m) =
612                        eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
613
614                    let r_s = (r_p - r_m) / (2.0 * ds);
615                    let z_s = (z_p - z_m) / (2.0 * ds);
616
617                    let (r_theta, z_theta) =
618                        eval_surface_deriv_theta(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
619
620                    // Jacobian: √g = R · (R_s Z_θ − R_θ Z_s)
621                    let jac = r_val * (r_s * z_theta - r_theta * z_s);
622                    let jac_abs = jac.abs().max(1e-20);
623
624                    // B-field: B^ζ = Φ'/(2π√g), B^θ = ι·B^ζ
625                    let b_zeta = phi_edge / (2.0 * pi * jac_abs);
626                    let b_theta = iota_s * b_zeta;
627
628                    // |B|² via covariant metric
629                    let g_tt = r_theta * r_theta + z_theta * z_theta;
630                    let g_zz = r_val * r_val;
631                    let b_sq = b_theta * b_theta * g_tt + b_zeta * b_zeta * g_zz;
632
633                    // Force residual: F_s = dp/ds + d(B²/2μ₀)/ds
634                    let force_s = dp_ds + b_sq / (2.0 * mu0 * jac_abs.max(1e-10));
635                    force_sq_sum += force_s * force_s;
636
637                    let dvol = jac_abs * dtheta * dzeta * ds;
638                    total_volume += dvol;
639                    total_beta_num += p * dvol;
640                    total_b2_vol += b_sq * dvol;
641
642                    // Accumulate gradient for steepest descent
643                    if n_tor == 0 {
644                        for m in 0..=m_pol {
645                            let angle = m as f64 * theta;
646                            let (sin_a, cos_a) = angle.sin_cos();
647                            grad_r[[js, m]] += force_s * cos_a * dtheta * dzeta;
648                            grad_z[[js, m]] += force_s * sin_a * dtheta * dzeta;
649                        }
650                    } else {
651                        for m in 0..=m_pol {
652                            for nn in -(n_tor as i32)..=(n_tor as i32) {
653                                let idx = m * (2 * n_tor + 1) + (nn + n_tor as i32) as usize;
654                                let angle = m as f64 * theta - nn as f64 * nfp as f64 * zeta;
655                                let (sin_a, cos_a) = angle.sin_cos();
656                                grad_r[[js, idx]] += force_s * cos_a * dtheta * dzeta;
657                                grad_z[[js, idx]] += force_s * sin_a * dtheta * dzeta;
658                            }
659                        }
660                    }
661                }
662            }
663        }
664
665        let n_interior = ((ns - 2) as f64 * n_angle_pts).max(1.0);
666        global_force = (force_sq_sum / n_interior).sqrt();
667
668        // Update interior surfaces
669        for js in 1..(ns - 1) {
670            for k in 0..nmodes {
671                rmnc[[js, k]] -= config.step_size * grad_r[[js, k]] / n_angle_pts.max(1.0);
672                zmns[[js, k]] -= config.step_size * grad_z[[js, k]] / n_angle_pts.max(1.0);
673            }
674        }
675
676        iter_count = iteration + 1;
677        if global_force < config.tol {
678            converged = true;
679            break;
680        }
681        if !global_force.is_finite() {
682            return Err(FusionError::PhysicsViolation(format!(
683                "VMEC solver diverged at iteration {}: force={}",
684                iteration, global_force
685            )));
686        }
687    }
688
689    let beta_avg = if total_b2_vol > 0.0 {
690        2.0 * mu0 * total_beta_num / total_b2_vol
691    } else {
692        0.0
693    };
694
695    Ok(VmecEquilibrium {
696        rmnc,
697        zmns,
698        iota: Array1::from_vec(iota.to_vec()),
699        pressure: Array1::from_vec(pressure.to_vec()),
700        phi_edge,
701        volume: total_volume.abs(),
702        beta_avg,
703        force_residual: global_force,
704        iterations: iter_count,
705        converged,
706        ns_grid: ns,
707        m_pol,
708        n_tor,
709        nfp,
710    })
711}
712
713/// Evaluate the equilibrium geometry at given (s, θ, ζ) coordinates.
714pub fn vmec_eval_geometry(
715    eq: &VmecEquilibrium,
716    s_idx: usize,
717    theta: f64,
718    zeta: f64,
719) -> FusionResult<(f64, f64)> {
720    if s_idx >= eq.ns_grid {
721        return Err(FusionError::PhysicsViolation(format!(
722            "Surface index {} >= ns_grid {}",
723            s_idx, eq.ns_grid
724        )));
725    }
726    let nmodes = vmec_n_modes(eq.m_pol, eq.n_tor);
727    let rmnc_row: Vec<f64> = (0..nmodes).map(|k| eq.rmnc[[s_idx, k]]).collect();
728    let zmns_row: Vec<f64> = (0..nmodes).map(|k| eq.zmns[[s_idx, k]]).collect();
729    Ok(eval_surface_point(
730        &rmnc_row, &zmns_row, theta, zeta, eq.m_pol, eq.n_tor, eq.nfp,
731    ))
732}
733
734#[cfg(test)]
735mod tests {
736    use super::*;
737
738    fn sample_state() -> VmecBoundaryState {
739        VmecBoundaryState {
740            r_axis: 6.2,
741            z_axis: 0.0,
742            a_minor: 2.0,
743            kappa: 1.7,
744            triangularity: 0.33,
745            nfp: 1,
746            modes: vec![
747                VmecFourierMode {
748                    m: 1,
749                    n: 1,
750                    r_cos: 0.03,
751                    r_sin: -0.01,
752                    z_cos: 0.0,
753                    z_sin: 0.02,
754                },
755                VmecFourierMode {
756                    m: 2,
757                    n: 1,
758                    r_cos: 0.015,
759                    r_sin: 0.0,
760                    z_cos: 0.0,
761                    z_sin: 0.008,
762                },
763            ],
764        }
765    }
766
767    #[test]
768    fn test_vmec_text_roundtrip() {
769        let state = sample_state();
770        let text = export_vmec_like_text(&state).expect("export must succeed");
771        let parsed = import_vmec_like_text(&text).expect("import must succeed");
772        assert_eq!(parsed.nfp, state.nfp);
773        assert_eq!(parsed.modes.len(), state.modes.len());
774        assert!((parsed.r_axis - state.r_axis).abs() < 1e-12);
775        assert!((parsed.kappa - state.kappa).abs() < 1e-12);
776    }
777
778    #[test]
779    fn test_vmec_missing_key_errors() {
780        let text = "format=vmec_like_v1\nr_axis=6.2\n";
781        let err = import_vmec_like_text(text).expect_err("missing keys should error");
782        match err {
783            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Missing VMEC key")),
784            other => panic!("Unexpected error: {other:?}"),
785        }
786    }
787
788    #[test]
789    fn test_vmec_invalid_minor_radius_errors() {
790        let text = "\
791format=vmec_like_v1
792r_axis=6.2
793z_axis=0.0
794a_minor=0.0
795kappa=1.7
796triangularity=0.2
797nfp=1
798";
799        let err = import_vmec_like_text(text).expect_err("a_minor=0 must fail");
800        match err {
801            FusionError::PhysicsViolation(msg) => assert!(msg.contains("a_minor")),
802            other => panic!("Unexpected error: {other:?}"),
803        }
804    }
805
806    #[test]
807    fn test_vmec_rejects_invalid_modes_and_duplicates() {
808        let mut state = sample_state();
809        state.modes[0].r_cos = f64::NAN;
810        let err = export_vmec_like_text(&state).expect_err("non-finite mode coeff must fail");
811        match err {
812            FusionError::PhysicsViolation(msg) => assert!(msg.contains("mode")),
813            other => panic!("Unexpected error: {other:?}"),
814        }
815
816        let dup_text = "\
817format=vmec_like_v1
818r_axis=6.2
819r_axis=6.3
820z_axis=0.0
821a_minor=2.0
822kappa=1.7
823triangularity=0.2
824nfp=1
825";
826        let err = import_vmec_like_text(dup_text).expect_err("duplicate keys must fail");
827        match err {
828            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Duplicate VMEC key")),
829            other => panic!("Unexpected error: {other:?}"),
830        }
831    }
832
833    #[test]
834    fn test_vmec_rejects_duplicate_modes_and_bad_format() {
835        let mut state = sample_state();
836        state.modes.push(VmecFourierMode {
837            m: state.modes[0].m,
838            n: state.modes[0].n,
839            r_cos: 0.001,
840            r_sin: 0.0,
841            z_cos: 0.0,
842            z_sin: 0.001,
843        });
844        let err = export_vmec_like_text(&state).expect_err("duplicate mode index must fail");
845        match err {
846            FusionError::PhysicsViolation(msg) => assert!(msg.contains("duplicate mode")),
847            other => panic!("Unexpected error: {other:?}"),
848        }
849
850        let bad_format = "\
851format=vmec_like_v2
852r_axis=6.2
853z_axis=0.0
854a_minor=2.0
855kappa=1.7
856triangularity=0.2
857nfp=1
858";
859        let err = import_vmec_like_text(bad_format).expect_err("unsupported format must fail");
860        match err {
861            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Unsupported VMEC format")),
862            other => panic!("Unexpected error: {other:?}"),
863        }
864    }
865
866    #[test]
867    fn test_vmec_rejects_malformed_mode_and_duplicate_format() {
868        let malformed_mode = "\
869format=vmec_like_v1
870r_axis=6.2
871z_axis=0.0
872a_minor=2.0
873kappa=1.7
874triangularity=0.2
875nfp=1
876mode,1,1,0.1,0.0,0.2
877";
878        let err =
879            import_vmec_like_text(malformed_mode).expect_err("mode with wrong arity must fail");
880        match err {
881            FusionError::PhysicsViolation(msg) => assert!(msg.contains("exactly 6 columns")),
882            other => panic!("Unexpected error: {other:?}"),
883        }
884
885        let duplicate_format = "\
886format=vmec_like_v1
887format=vmec_like_v1
888r_axis=6.2
889z_axis=0.0
890a_minor=2.0
891kappa=1.7
892triangularity=0.2
893nfp=1
894";
895        let err =
896            import_vmec_like_text(duplicate_format).expect_err("duplicate format key must fail");
897        match err {
898            FusionError::PhysicsViolation(msg) => {
899                assert!(msg.contains("Duplicate VMEC key: format"))
900            }
901            other => panic!("Unexpected error: {other:?}"),
902        }
903    }
904
905    // === VMEC Solver Tests ===
906
907    fn iter_like_boundary() -> VmecBoundaryState {
908        VmecBoundaryState {
909            r_axis: 6.2,
910            z_axis: 0.0,
911            a_minor: 2.0,
912            kappa: 1.7,
913            triangularity: 0.33,
914            nfp: 1,
915            modes: vec![],
916        }
917    }
918
919    #[test]
920    fn test_vmec_mode_indexing() {
921        assert_eq!(vmec_n_modes(6, 0), 7);
922        assert_eq!(vmec_n_modes(3, 2), 4 * 5);
923        assert_eq!(vmec_mode_idx(0, 0, 6, 0), Some(0));
924        assert_eq!(vmec_mode_idx(3, 0, 6, 0), Some(3));
925        assert_eq!(vmec_mode_idx(0, 1, 6, 0), None);
926        assert_eq!(vmec_mode_idx(7, 0, 6, 0), None);
927        assert_eq!(vmec_mode_idx(1, -1, 3, 2), Some(6));
928    }
929
930    #[test]
931    fn test_vmec_eval_surface_axis_is_circle() {
932        let m_pol = 4;
933        let n_tor = 0;
934        let nmodes = vmec_n_modes(m_pol, n_tor);
935        let mut rmnc = vec![0.0; nmodes];
936        let mut zmns = vec![0.0; nmodes];
937        rmnc[0] = 6.2; // R_00 = major radius
938        rmnc[1] = 2.0; // R_10 = minor radius
939        zmns[1] = 2.0; // Z_10 = minor radius (circular cross-section)
940
941        let (r, z) = eval_surface_point(&rmnc, &zmns, 0.0, 0.0, m_pol, n_tor, 1);
942        assert!((r - 8.2).abs() < 1e-12, "outboard midplane R");
943        assert!(z.abs() < 1e-12, "midplane Z");
944
945        let (r2, z2) = eval_surface_point(
946            &rmnc,
947            &zmns,
948            std::f64::consts::FRAC_PI_2,
949            0.0,
950            m_pol,
951            n_tor,
952            1,
953        );
954        assert!((r2 - 6.2).abs() < 1e-12, "top R = R_axis");
955        assert!((z2 - 2.0).abs() < 1e-12, "top Z = a_minor");
956    }
957
958    #[test]
959    fn test_vmec_solver_runs_axisymmetric() {
960        let boundary = iter_like_boundary();
961        let config = VmecSolverConfig {
962            m_pol: 4,
963            n_tor: 0,
964            ns: 11,
965            ntheta: 16,
966            nzeta: 1,
967            max_iter: 50,
968            tol: 1e-6,
969            step_size: 1e-3,
970        };
971        let ns = config.ns;
972        let pressure: Vec<f64> = (0..ns)
973            .map(|i| {
974                let s = i as f64 / (ns - 1) as f64;
975                1e5 * (1.0 - s * s)
976            })
977            .collect();
978        let iota: Vec<f64> = (0..ns)
979            .map(|i| {
980                let s = i as f64 / (ns - 1) as f64;
981                0.3 + 0.7 * s * s
982            })
983            .collect();
984
985        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0)
986            .expect("VMEC solve should succeed");
987
988        assert!(eq.iterations > 0);
989        assert!(eq.force_residual.is_finite());
990        assert!(eq.volume > 0.0, "Volume must be positive");
991        assert_eq!(eq.ns_grid, ns);
992        assert_eq!(eq.m_pol, 4);
993        assert_eq!(eq.nfp, 1);
994    }
995
996    #[test]
997    fn test_vmec_solver_boundary_preserved() {
998        let boundary = iter_like_boundary();
999        let config = VmecSolverConfig {
1000            m_pol: 4,
1001            n_tor: 0,
1002            ns: 11,
1003            ntheta: 16,
1004            nzeta: 1,
1005            max_iter: 10,
1006            tol: 1e-12,
1007            step_size: 1e-3,
1008        };
1009        let ns = config.ns;
1010        let pressure: Vec<f64> = (0..ns).map(|_| 1e4).collect();
1011        let iota: Vec<f64> = (0..ns).map(|_| 0.5).collect();
1012
1013        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0).unwrap();
1014
1015        // Boundary (last surface) R_10 should be a_minor = 2.0
1016        let r10_bnd = eq.rmnc[[ns - 1, 1]];
1017        assert!(
1018            (r10_bnd - 2.0).abs() < 1e-10,
1019            "Boundary R_10 must be preserved: got {r10_bnd}"
1020        );
1021    }
1022
1023    #[test]
1024    fn test_vmec_solver_rejects_invalid_inputs() {
1025        let boundary = iter_like_boundary();
1026        let config = VmecSolverConfig::default();
1027        let ns = config.ns;
1028        let p: Vec<f64> = vec![0.0; ns];
1029        let iota: Vec<f64> = vec![0.5; ns];
1030
1031        // Invalid phi_edge
1032        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, -1.0).is_err());
1033        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, f64::NAN).is_err());
1034
1035        // Wrong profile length
1036        let short_p: Vec<f64> = vec![0.0; 3];
1037        assert!(vmec_fixed_boundary_solve(&boundary, &config, &short_p, &iota, 1.0).is_err());
1038
1039        // Invalid config
1040        let bad_cfg = VmecSolverConfig {
1041            ns: 2,
1042            ..Default::default()
1043        };
1044        assert!(bad_cfg.validate().is_err());
1045    }
1046
1047    #[test]
1048    fn test_vmec_eval_geometry_valid() {
1049        let boundary = iter_like_boundary();
1050        let config = VmecSolverConfig {
1051            m_pol: 4,
1052            n_tor: 0,
1053            ns: 7,
1054            ntheta: 16,
1055            nzeta: 1,
1056            max_iter: 5,
1057            tol: 1e-12,
1058            step_size: 1e-3,
1059        };
1060        let ns = config.ns;
1061        let p: Vec<f64> = (0..ns)
1062            .map(|i| 1e4 * (1.0 - (i as f64 / (ns - 1) as f64)))
1063            .collect();
1064        let iota: Vec<f64> = vec![0.5; ns];
1065
1066        let eq = vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, 1.0).unwrap();
1067        let (r, z) = vmec_eval_geometry(&eq, ns - 1, 0.0, 0.0).unwrap();
1068        assert!(r > 0.0 && r.is_finite());
1069        assert!(z.is_finite());
1070
1071        // Out of bounds
1072        assert!(vmec_eval_geometry(&eq, ns + 5, 0.0, 0.0).is_err());
1073    }
1074
1075    #[test]
1076    fn test_vmec_deriv_theta_consistency() {
1077        let m_pol = 3;
1078        let n_tor = 0;
1079        let _nmodes = vmec_n_modes(m_pol, n_tor);
1080        let rmnc: Vec<f64> = vec![6.2, 2.0, 0.1, 0.05];
1081        let zmns: Vec<f64> = vec![0.0, 1.8, 0.08, 0.03];
1082        let theta = 1.0;
1083        let eps = 1e-7;
1084
1085        let (r_p, z_p) = eval_surface_point(&rmnc, &zmns, theta + eps, 0.0, m_pol, n_tor, 1);
1086        let (r_m, z_m) = eval_surface_point(&rmnc, &zmns, theta - eps, 0.0, m_pol, n_tor, 1);
1087        let dr_num = (r_p - r_m) / (2.0 * eps);
1088        let dz_num = (z_p - z_m) / (2.0 * eps);
1089
1090        let (dr_ana, dz_ana) = eval_surface_deriv_theta(&rmnc, &zmns, theta, 0.0, m_pol, n_tor, 1);
1091
1092        assert!(
1093            (dr_ana - dr_num).abs() < 1e-5,
1094            "dR/dθ mismatch: analytic={dr_ana}, numerical={dr_num}"
1095        );
1096        assert!(
1097            (dz_ana - dz_num).abs() < 1e-5,
1098            "dZ/dθ mismatch: analytic={dz_ana}, numerical={dz_num}"
1099        );
1100    }
1101}