fusion_core/
vmec_interface.rs

1// ─────────────────────────────────────────────────────────────────────
2// SCPN Fusion Core — VMEC Interface
3// © 1998–2026 Miroslav Šotek. All rights reserved.
4// Contact: www.anulum.li | protoscience@anulum.li
5// ORCID: https://orcid.org/0009-0009-3560-0851
6// License: GNU AGPL v3 | Commercial licensing available
7// ─────────────────────────────────────────────────────────────────────
//! Lightweight VMEC-compatible boundary-state wrapper.
//!
//! This module does not implement the full VMEC 3D force-balance solve.
//! It provides a deterministic interoperability lane for exchanging reduced
//! Fourier boundary states with external VMEC-class workflows, together with
//! a simplified VMEC-like fixed-boundary steepest-descent solver
//! (`vmec_fixed_boundary_solve`) that consumes those boundary states.
13
14use fusion_types::error::{FusionError, FusionResult};
15use ndarray::{Array1, Array2};
16use std::collections::HashSet;
17
/// One Fourier harmonic of a boundary description, addressed by `(m, n)`.
///
/// `VmecBoundaryState::validate` requires `m >= 0`, finite coefficients and
/// a unique `(m, n)` pair within one boundary state.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VmecFourierMode {
    /// Poloidal mode number; negative values are rejected by validation.
    pub m: i32,
    /// Toroidal mode number; validation places no sign restriction on it.
    pub n: i32,
    /// Cosine coefficient of R; added to the boundary R-cosine spectrum by
    /// `vmec_fixed_boundary_solve`.
    pub r_cos: f64,
    /// Sine coefficient of R; carried through the text format but not read
    /// by the solver in the visible part of this file.
    pub r_sin: f64,
    /// Cosine coefficient of Z; carried through the text format but not read
    /// by the solver in the visible part of this file.
    pub z_cos: f64,
    /// Sine coefficient of Z; added to the boundary Z-sine spectrum by
    /// `vmec_fixed_boundary_solve`.
    pub z_sin: f64,
}
27
/// Reduced description of a plasma boundary: a handful of shape scalars plus
/// an optional list of extra Fourier harmonics.
///
/// Units are not fixed by this file (presumably metres for lengths — TODO
/// confirm against callers).
#[derive(Debug, Clone, PartialEq)]
pub struct VmecBoundaryState {
    /// Major-radius scalar; `vmec_fixed_boundary_solve` uses it as the
    /// `(m=0, n=0)` R-cosine coefficient on every surface.
    pub r_axis: f64,
    /// Vertical axis position. Validated and serialized, but not read by the
    /// solver in the visible part of this file.
    pub z_axis: f64,
    /// Minor radius; must be `> 0`. Seeds the `(1, 0)` R-cosine coefficient.
    pub a_minor: f64,
    /// Elongation; must be `> 0`. The `(1, 0)` Z-sine coefficient is
    /// `a_minor * kappa`.
    pub kappa: f64,
    /// Triangularity; the `(2, 0)` R-cosine coefficient is seeded with
    /// `-triangularity * a_minor / 2`.
    pub triangularity: f64,
    /// Number of toroidal field periods; must be `>= 1`.
    pub nfp: usize,
    /// Additional harmonics added on top of the scalar-derived shape. Each
    /// `(m, n)` pair may appear at most once.
    pub modes: Vec<VmecFourierMode>,
}
38
39impl VmecBoundaryState {
40    pub fn validate(&self) -> FusionResult<()> {
41        if !self.r_axis.is_finite()
42            || !self.z_axis.is_finite()
43            || !self.a_minor.is_finite()
44            || !self.kappa.is_finite()
45            || !self.triangularity.is_finite()
46        {
47            return Err(FusionError::PhysicsViolation(
48                "VMEC boundary contains non-finite scalar".to_string(),
49            ));
50        }
51        if self.a_minor <= 0.0 {
52            return Err(FusionError::PhysicsViolation(
53                "VMEC boundary requires a_minor > 0".to_string(),
54            ));
55        }
56        if self.kappa <= 0.0 {
57            return Err(FusionError::PhysicsViolation(
58                "VMEC boundary requires kappa > 0".to_string(),
59            ));
60        }
61        if self.nfp < 1 {
62            return Err(FusionError::PhysicsViolation(
63                "VMEC boundary requires nfp >= 1".to_string(),
64            ));
65        }
66        let mut seen_modes: HashSet<(i32, i32)> = HashSet::with_capacity(self.modes.len());
67        for (idx, mode) in self.modes.iter().enumerate() {
68            if mode.m < 0 {
69                return Err(FusionError::PhysicsViolation(format!(
70                    "VMEC boundary mode[{idx}] requires m >= 0, got {}",
71                    mode.m
72                )));
73            }
74            if !mode.r_cos.is_finite()
75                || !mode.r_sin.is_finite()
76                || !mode.z_cos.is_finite()
77                || !mode.z_sin.is_finite()
78            {
79                return Err(FusionError::PhysicsViolation(format!(
80                    "VMEC boundary mode[{idx}] contains non-finite coefficients"
81                )));
82            }
83            let key = (mode.m, mode.n);
84            if !seen_modes.insert(key) {
85                return Err(FusionError::PhysicsViolation(format!(
86                    "VMEC boundary contains duplicate mode (m={}, n={})",
87                    mode.m, mode.n
88                )));
89            }
90        }
91        Ok(())
92    }
93}
94
95pub fn export_vmec_like_text(state: &VmecBoundaryState) -> FusionResult<String> {
96    state.validate()?;
97    let mut out = String::new();
98    out.push_str("format=vmec_like_v1\n");
99    out.push_str(&format!("r_axis={:.16e}\n", state.r_axis));
100    out.push_str(&format!("z_axis={:.16e}\n", state.z_axis));
101    out.push_str(&format!("a_minor={:.16e}\n", state.a_minor));
102    out.push_str(&format!("kappa={:.16e}\n", state.kappa));
103    out.push_str(&format!("triangularity={:.16e}\n", state.triangularity));
104    out.push_str(&format!("nfp={}\n", state.nfp));
105    for mode in &state.modes {
106        out.push_str(&format!(
107            "mode,{},{},{:.16e},{:.16e},{:.16e},{:.16e}\n",
108            mode.m, mode.n, mode.r_cos, mode.r_sin, mode.z_cos, mode.z_sin
109        ));
110    }
111    Ok(out)
112}
113
114fn parse_float(key: &str, text: &str) -> FusionResult<f64> {
115    let val = text.parse::<f64>().map_err(|e| {
116        FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as float: {e}"))
117    })?;
118    if !val.is_finite() {
119        return Err(FusionError::PhysicsViolation(format!(
120            "VMEC key '{key}' must be finite, got {val}"
121        )));
122    }
123    Ok(val)
124}
125
126fn parse_int<T>(key: &str, text: &str) -> FusionResult<T>
127where
128    T: std::str::FromStr,
129    T::Err: std::fmt::Display,
130{
131    text.parse::<T>().map_err(|e| {
132        FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as integer: {e}"))
133    })
134}
135
136pub fn import_vmec_like_text(text: &str) -> FusionResult<VmecBoundaryState> {
137    let mut format_seen = false;
138    let mut r_axis: Option<f64> = None;
139    let mut z_axis: Option<f64> = None;
140    let mut a_minor: Option<f64> = None;
141    let mut kappa: Option<f64> = None;
142    let mut triangularity: Option<f64> = None;
143    let mut nfp: Option<usize> = None;
144    let mut modes: Vec<VmecFourierMode> = Vec::new();
145
146    for raw_line in text.lines() {
147        let line = raw_line.trim();
148        if line.is_empty() || line.starts_with('#') {
149            continue;
150        }
151        if let Some(format_name) = line.strip_prefix("format=") {
152            if format_seen {
153                return Err(FusionError::PhysicsViolation(
154                    "Duplicate VMEC key: format".to_string(),
155                ));
156            }
157            if format_name.trim() != "vmec_like_v1" {
158                return Err(FusionError::PhysicsViolation(format!(
159                    "Unsupported VMEC format: {}",
160                    format_name.trim()
161                )));
162            }
163            format_seen = true;
164            continue;
165        }
166        if let Some(rest) = line.strip_prefix("mode,") {
167            let cols: Vec<&str> = rest.split(',').map(|v| v.trim()).collect();
168            if cols.len() != 6 {
169                return Err(FusionError::PhysicsViolation(
170                    "VMEC mode line must contain exactly 6 columns".to_string(),
171                ));
172            }
173            modes.push(VmecFourierMode {
174                m: parse_int("mode.m", cols[0])?,
175                n: parse_int("mode.n", cols[1])?,
176                r_cos: parse_float("mode.r_cos", cols[2])?,
177                r_sin: parse_float("mode.r_sin", cols[3])?,
178                z_cos: parse_float("mode.z_cos", cols[4])?,
179                z_sin: parse_float("mode.z_sin", cols[5])?,
180            });
181            continue;
182        }
183        let (key, value) = line.split_once('=').ok_or_else(|| {
184            FusionError::PhysicsViolation(format!("Invalid VMEC line (missing '='): {line}"))
185        })?;
186        let key = key.trim();
187        let value = value.trim();
188        match key {
189            "r_axis" => {
190                if r_axis.is_some() {
191                    return Err(FusionError::PhysicsViolation(
192                        "Duplicate VMEC key: r_axis".to_string(),
193                    ));
194                }
195                r_axis = Some(parse_float(key, value)?);
196            }
197            "z_axis" => {
198                if z_axis.is_some() {
199                    return Err(FusionError::PhysicsViolation(
200                        "Duplicate VMEC key: z_axis".to_string(),
201                    ));
202                }
203                z_axis = Some(parse_float(key, value)?);
204            }
205            "a_minor" => {
206                if a_minor.is_some() {
207                    return Err(FusionError::PhysicsViolation(
208                        "Duplicate VMEC key: a_minor".to_string(),
209                    ));
210                }
211                a_minor = Some(parse_float(key, value)?);
212            }
213            "kappa" => {
214                if kappa.is_some() {
215                    return Err(FusionError::PhysicsViolation(
216                        "Duplicate VMEC key: kappa".to_string(),
217                    ));
218                }
219                kappa = Some(parse_float(key, value)?);
220            }
221            "triangularity" => {
222                if triangularity.is_some() {
223                    return Err(FusionError::PhysicsViolation(
224                        "Duplicate VMEC key: triangularity".to_string(),
225                    ));
226                }
227                triangularity = Some(parse_float(key, value)?);
228            }
229            "nfp" => {
230                if nfp.is_some() {
231                    return Err(FusionError::PhysicsViolation(
232                        "Duplicate VMEC key: nfp".to_string(),
233                    ));
234                }
235                nfp = Some(parse_int(key, value)?);
236            }
237            other => {
238                return Err(FusionError::PhysicsViolation(format!(
239                    "Unknown VMEC key: {other}"
240                )));
241            }
242        }
243    }
244
245    let state = VmecBoundaryState {
246        r_axis: r_axis
247            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: r_axis".to_string()))?,
248        z_axis: z_axis
249            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: z_axis".to_string()))?,
250        a_minor: a_minor.ok_or_else(|| {
251            FusionError::PhysicsViolation("Missing VMEC key: a_minor".to_string())
252        })?,
253        kappa: kappa
254            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: kappa".to_string()))?,
255        triangularity: triangularity.ok_or_else(|| {
256            FusionError::PhysicsViolation("Missing VMEC key: triangularity".to_string())
257        })?,
258        nfp: nfp
259            .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: nfp".to_string()))?,
260        modes,
261    };
262    state.validate()?;
263    Ok(state)
264}
265
266// ═══════════════════════════════════════════════════════════════════════
267// VMEC-like Fixed-Boundary 3D Equilibrium Solver
268// ═══════════════════════════════════════════════════════════════════════
269//
270// Implements a variational equilibrium solver following Hirshman & Whitson
271// (1983). Given boundary Fourier modes, pressure p(s), and rotational
272// transform iota(s), finds force-balanced interior flux surface shapes via
273// steepest descent on the MHD energy functional.
274//
275// Stellarator symmetry: R uses cos(mθ − nNζ), Z uses sin(mθ − nNζ).
276
/// Solver configuration for the VMEC fixed-boundary equilibrium.
///
/// Resolution and iteration controls consumed by `vmec_fixed_boundary_solve`.
/// See `validate` for the constraints enforced before a solve starts.
#[derive(Debug, Clone)]
pub struct VmecSolverConfig {
    /// Maximum poloidal mode number (modes `m = 0..=m_pol` are retained).
    pub m_pol: usize,
    /// Maximum toroidal mode number (0 = axisymmetric).
    pub n_tor: usize,
    /// Number of flux surfaces (radial, including axis + boundary).
    /// Must be at least 3.
    pub ns: usize,
    /// Poloidal angle grid points. Must be at least 8.
    pub ntheta: usize,
    /// Toroidal angle grid points per field period.
    /// The solver uses a single toroidal point when `n_tor == 0` and raises
    /// this value to at least 4 otherwise.
    pub nzeta: usize,
    /// Maximum steepest-descent iterations. Must be at least 1.
    pub max_iter: usize,
    /// Force residual convergence tolerance; compared against the RMS force
    /// residual after each iteration. Must be finite and positive.
    pub tol: f64,
    /// Steepest descent step size. Must be finite and positive.
    pub step_size: f64,
}
297
298impl Default for VmecSolverConfig {
299    fn default() -> Self {
300        Self {
301            m_pol: 6,
302            n_tor: 0,
303            ns: 25,
304            ntheta: 32,
305            nzeta: 1,
306            max_iter: 500,
307            tol: 1e-8,
308            step_size: 5e-3,
309        }
310    }
311}
312
313impl VmecSolverConfig {
314    pub fn validate(&self) -> FusionResult<()> {
315        if self.ns < 3 {
316            return Err(FusionError::PhysicsViolation(
317                "VMEC solver requires ns >= 3".into(),
318            ));
319        }
320        if self.ntheta < 8 {
321            return Err(FusionError::PhysicsViolation(
322                "VMEC solver requires ntheta >= 8".into(),
323            ));
324        }
325        if self.max_iter == 0 {
326            return Err(FusionError::PhysicsViolation(
327                "VMEC solver requires max_iter >= 1".into(),
328            ));
329        }
330        if !self.tol.is_finite() || self.tol <= 0.0 {
331            return Err(FusionError::PhysicsViolation(
332                "VMEC solver tol must be finite and > 0".into(),
333            ));
334        }
335        if !self.step_size.is_finite() || self.step_size <= 0.0 {
336            return Err(FusionError::PhysicsViolation(
337                "VMEC solver step_size must be finite and > 0".into(),
338            ));
339        }
340        Ok(())
341    }
342}
343
/// Solution from the VMEC fixed-boundary solver.
///
/// Coefficient arrays have one row per flux surface and one column per
/// Fourier mode; the column of mode `(m, n)` is given by `vmec_mode_idx`
/// and the column count by `vmec_n_modes(m_pol, n_tor)`.
#[derive(Debug, Clone)]
pub struct VmecEquilibrium {
    /// R cosine Fourier coefficients per surface [ns × n_modes].
    pub rmnc: Array2<f64>,
    /// Z sine Fourier coefficients per surface [ns × n_modes].
    pub zmns: Array2<f64>,
    /// Rotational transform profile iota(s) [ns].
    pub iota: Array1<f64>,
    /// Pressure profile [Pa] [ns].
    pub pressure: Array1<f64>,
    /// Total toroidal flux [Wb].
    pub phi_edge: f64,
    /// Plasma volume [m³].
    pub volume: f64,
    /// Volume-averaged beta.
    pub beta_avg: f64,
    /// Final force residual norm.
    pub force_residual: f64,
    /// Number of iterations.
    pub iterations: usize,
    /// Whether the solver converged.
    pub converged: bool,
    /// Number of flux surfaces (rows of `rmnc` / `zmns`).
    pub ns_grid: usize,
    /// Maximum poloidal mode number used to lay out the coefficient columns.
    pub m_pol: usize,
    /// Maximum toroidal mode number used to lay out the coefficient columns.
    pub n_tor: usize,
    /// Number of field periods, copied from the boundary state.
    pub nfp: usize,
}
373
/// Number of Fourier modes retained for a given `(m_pol, n_tor)`.
///
/// There are `m_pol + 1` poloidal numbers (`0..=m_pol`) and `2 * n_tor + 1`
/// toroidal numbers (`-n_tor..=n_tor`). For `n_tor == 0` the toroidal factor
/// is 1, so the count reduces to `m_pol + 1`.
pub fn vmec_n_modes(m_pol: usize, n_tor: usize) -> usize {
    let poloidal_count = m_pol + 1;
    let toroidal_count = 2 * n_tor + 1;
    poloidal_count * toroidal_count
}
382
/// Flat column index of mode `(m, n)` in a coefficient row.
///
/// Modes are stored row-major: `m` is the slow index and `n`, shifted by
/// `n_tor` so it starts at zero, is the fast index. Returns `None` when
/// `m > m_pol` or `|n| > n_tor`; for `n_tor == 0` only `n == 0` is in range
/// and the index is `m` itself.
pub fn vmec_mode_idx(m: usize, n: i32, m_pol: usize, n_tor: usize) -> Option<usize> {
    let m_in_range = m <= m_pol;
    let n_in_range = n.unsigned_abs() as usize <= n_tor;
    // The closure only runs for in-range modes, so `n + n_tor` is non-negative.
    (m_in_range && n_in_range).then(|| m * (2 * n_tor + 1) + (n + n_tor as i32) as usize)
}
401
/// Evaluate R(θ,ζ) and Z(θ,ζ) from one surface's Fourier coefficients.
///
/// R is the sum of `rmnc[k] * cos(α)` and Z the sum of `zmns[k] * sin(α)`
/// with `α = m·θ − n·nfp·ζ`. When `n_tor == 0` only `n = 0` exists, the
/// coefficient index is `m`, and `zeta` is not read at all.
///
/// Both slices must hold at least `vmec_n_modes(m_pol, n_tor)` entries;
/// shorter slices cause an out-of-bounds panic.
fn eval_surface_point(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    let mut r_sum = 0.0;
    let mut z_sum = 0.0;
    // Adds the contribution of the harmonic stored at `k` with phase `phase`.
    let mut add_harmonic = |k: usize, phase: f64| {
        let (sin_p, cos_p) = phase.sin_cos();
        r_sum += rmnc[k] * cos_p;
        z_sum += zmns[k] * sin_p;
    };

    if n_tor == 0 {
        for m in 0..=m_pol {
            add_harmonic(m, m as f64 * theta);
        }
    } else {
        let n_max = n_tor as i32;
        for m in 0..=m_pol {
            for nn in -n_max..=n_max {
                let k = m * (2 * n_tor + 1) + (nn + n_max) as usize;
                add_harmonic(k, m as f64 * theta - nn as f64 * nfp as f64 * zeta);
            }
        }
    }
    (r_sum, z_sum)
}
434
/// Analytic ∂R/∂θ and ∂Z/∂θ from Fourier coefficients.
///
/// Differentiating the series term by term gives
/// `∂R/∂θ = −Σ m·rmnc[k]·sin(α)` and `∂Z/∂θ = Σ m·zmns[k]·cos(α)` with
/// `α = m·θ − n·nfp·ζ`. When `n_tor == 0` the coefficient index is `m` and
/// `zeta` is not read.
///
/// Both slices must hold at least `vmec_n_modes(m_pol, n_tor)` entries;
/// shorter slices cause an out-of-bounds panic.
fn eval_surface_deriv_theta(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    let mut dr_dtheta = 0.0;
    let mut dz_dtheta = 0.0;
    // Adds the θ-derivative of the harmonic at `k` with poloidal factor `mf`.
    let mut add_harmonic = |k: usize, mf: f64, phase: f64| {
        let (sin_p, cos_p) = phase.sin_cos();
        dr_dtheta -= mf * rmnc[k] * sin_p;
        dz_dtheta += mf * zmns[k] * cos_p;
    };

    if n_tor == 0 {
        for m in 0..=m_pol {
            let mf = m as f64;
            add_harmonic(m, mf, mf * theta);
        }
    } else {
        let n_max = n_tor as i32;
        for m in 0..=m_pol {
            let mf = m as f64;
            for nn in -n_max..=n_max {
                let k = m * (2 * n_tor + 1) + (nn + n_max) as usize;
                add_harmonic(k, mf, mf * theta - nn as f64 * nfp as f64 * zeta);
            }
        }
    }
    (dr_dtheta, dz_dtheta)
}
469
470/// Solve VMEC fixed-boundary 3D equilibrium.
471///
472/// Given boundary shape (from `VmecBoundaryState`), pressure and rotational
473/// transform profiles, finds force-balanced interior flux surface shapes.
474///
475/// The algorithm minimises the MHD energy functional W = ∫(B²/2μ₀ + p)dV
476/// via steepest descent on the interior Fourier coefficients R_mn^c(s),
477/// Z_mn^s(s) while holding the axis and boundary fixed.
478pub fn vmec_fixed_boundary_solve(
479    boundary: &VmecBoundaryState,
480    config: &VmecSolverConfig,
481    pressure: &[f64],
482    iota: &[f64],
483    phi_edge: f64,
484) -> FusionResult<VmecEquilibrium> {
485    boundary.validate()?;
486    config.validate()?;
487
488    let ns = config.ns;
489    let m_pol = config.m_pol;
490    let n_tor = config.n_tor;
491    let nfp = boundary.nfp;
492    let nmodes = vmec_n_modes(m_pol, n_tor);
493    let ntheta = config.ntheta;
494    let nzeta = if n_tor == 0 { 1 } else { config.nzeta.max(4) };
495
496    if pressure.len() != ns || iota.len() != ns {
497        return Err(FusionError::PhysicsViolation(format!(
498            "pressure/iota length must match ns={ns}, got p={} iota={}",
499            pressure.len(),
500            iota.len()
501        )));
502    }
503    if pressure.iter().any(|v| !v.is_finite()) || iota.iter().any(|v| !v.is_finite()) {
504        return Err(FusionError::PhysicsViolation(
505            "pressure/iota profiles must be finite".into(),
506        ));
507    }
508    if !phi_edge.is_finite() || phi_edge <= 0.0 {
509        return Err(FusionError::PhysicsViolation(
510            "phi_edge must be finite and > 0".into(),
511        ));
512    }
513
514    // Build boundary Fourier coefficients from VmecBoundaryState
515    let mut bnd_rmnc = vec![0.0; nmodes];
516    let mut bnd_zmns = vec![0.0; nmodes];
517
518    if let Some(idx) = vmec_mode_idx(0, 0, m_pol, n_tor) {
519        bnd_rmnc[idx] = boundary.r_axis;
520    }
521    if let Some(idx) = vmec_mode_idx(1, 0, m_pol, n_tor) {
522        bnd_rmnc[idx] = boundary.a_minor;
523        bnd_zmns[idx] = boundary.a_minor * boundary.kappa;
524    }
525    if m_pol >= 2 {
526        if let Some(idx) = vmec_mode_idx(2, 0, m_pol, n_tor) {
527            bnd_rmnc[idx] = -boundary.triangularity * boundary.a_minor * 0.5;
528        }
529    }
530    for mode in &boundary.modes {
531        let m = mode.m as usize;
532        if let Some(idx) = vmec_mode_idx(m, mode.n, m_pol, n_tor) {
533            bnd_rmnc[idx] += mode.r_cos;
534            bnd_zmns[idx] += mode.z_sin;
535        }
536    }
537
538    // Initialize: linear interpolation from axis to boundary
539    let mut rmnc = Array2::zeros((ns, nmodes));
540    let mut zmns = Array2::zeros((ns, nmodes));
541    let s_grid: Vec<f64> = (0..ns).map(|i| i as f64 / (ns - 1) as f64).collect();
542
543    for js in 0..ns {
544        let s = s_grid[js];
545        for k in 0..nmodes {
546            let axis_r = if k == vmec_mode_idx(0, 0, m_pol, n_tor).unwrap_or(usize::MAX) {
547                boundary.r_axis
548            } else {
549                0.0
550            };
551            rmnc[[js, k]] = axis_r * (1.0 - s) + bnd_rmnc[k] * s;
552            zmns[[js, k]] = bnd_zmns[k] * s;
553        }
554    }
555
556    let pi = std::f64::consts::PI;
557    let mu0 = 4.0e-7 * pi;
558    let ds = 1.0 / (ns - 1) as f64;
559    let dtheta = 2.0 * pi / ntheta as f64;
560    let dzeta = 2.0 * pi / (nzeta as f64 * nfp.max(1) as f64);
561    let n_angle_pts = (ntheta * nzeta) as f64;
562
563    // Precompute angles
564    let theta_arr: Vec<f64> = (0..ntheta)
565        .map(|i| 2.0 * pi * i as f64 / ntheta as f64)
566        .collect();
567    let zeta_arr: Vec<f64> = (0..nzeta)
568        .map(|i| 2.0 * pi * i as f64 / (nzeta as f64 * nfp.max(1) as f64))
569        .collect();
570
571    // Steepest descent iteration
572    let mut global_force = f64::MAX;
573    let mut converged = false;
574    let mut iter_count = 0;
575    let mut total_volume = 0.0;
576    let mut total_beta_num = 0.0;
577    let mut total_b2_vol = 0.0;
578
579    for iteration in 0..config.max_iter {
580        let mut force_sq_sum = 0.0;
581        total_volume = 0.0;
582        total_beta_num = 0.0;
583        total_b2_vol = 0.0;
584
585        let mut grad_r: Array2<f64> = Array2::zeros((ns, nmodes));
586        let mut grad_z: Array2<f64> = Array2::zeros((ns, nmodes));
587
588        // Evaluate force on each interior surface
589        for js in 1..(ns - 1) {
590            let p = pressure[js];
591            let iota_s = iota[js];
592            let dp_ds =
593                (pressure[(js + 1).min(ns - 1)] - pressure[js.saturating_sub(1)]) / (2.0 * ds);
594
595            let rmnc_s: Vec<f64> = (0..nmodes).map(|k| rmnc[[js, k]]).collect();
596            let zmns_s: Vec<f64> = (0..nmodes).map(|k| zmns[[js, k]]).collect();
597            let rmnc_p: Vec<f64> = (0..nmodes).map(|k| rmnc[[js + 1, k]]).collect();
598            let zmns_p: Vec<f64> = (0..nmodes).map(|k| zmns[[js + 1, k]]).collect();
599            let rmnc_m: Vec<f64> = (0..nmodes).map(|k| rmnc[[js - 1, k]]).collect();
600            let zmns_m: Vec<f64> = (0..nmodes).map(|k| zmns[[js - 1, k]]).collect();
601
602            for &theta in theta_arr.iter().take(ntheta) {
603                for &zeta in zeta_arr.iter().take(nzeta) {
604                    let (r_val, _z_val) =
605                        eval_surface_point(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
606                    let (r_p, _) =
607                        eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
608                    let (r_m, _) =
609                        eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
610                    let (_, z_p) =
611                        eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
612                    let (_, z_m) =
613                        eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
614
615                    let r_s = (r_p - r_m) / (2.0 * ds);
616                    let z_s = (z_p - z_m) / (2.0 * ds);
617
618                    let (r_theta, z_theta) =
619                        eval_surface_deriv_theta(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
620
621                    // Jacobian: √g = R · (R_s Z_θ − R_θ Z_s)
622                    let jac = r_val * (r_s * z_theta - r_theta * z_s);
623                    let jac_abs = jac.abs().max(1e-20);
624
625                    // B-field: B^ζ = Φ'/(2π√g), B^θ = ι·B^ζ
626                    let b_zeta = phi_edge / (2.0 * pi * jac_abs);
627                    let b_theta = iota_s * b_zeta;
628
629                    // |B|² via covariant metric
630                    let g_tt = r_theta * r_theta + z_theta * z_theta;
631                    let g_zz = r_val * r_val;
632                    let b_sq = b_theta * b_theta * g_tt + b_zeta * b_zeta * g_zz;
633
634                    // Force residual: F_s = dp/ds + d(B²/2μ₀)/ds
635                    let force_s = dp_ds + b_sq / (2.0 * mu0 * jac_abs.max(1e-10));
636                    force_sq_sum += force_s * force_s;
637
638                    let dvol = jac_abs * dtheta * dzeta * ds;
639                    total_volume += dvol;
640                    total_beta_num += p * dvol;
641                    total_b2_vol += b_sq * dvol;
642
643                    // Accumulate gradient for steepest descent
644                    if n_tor == 0 {
645                        for m in 0..=m_pol {
646                            let angle = m as f64 * theta;
647                            let (sin_a, cos_a) = angle.sin_cos();
648                            grad_r[[js, m]] += force_s * cos_a * dtheta * dzeta;
649                            grad_z[[js, m]] += force_s * sin_a * dtheta * dzeta;
650                        }
651                    } else {
652                        for m in 0..=m_pol {
653                            for nn in -(n_tor as i32)..=(n_tor as i32) {
654                                let idx = m * (2 * n_tor + 1) + (nn + n_tor as i32) as usize;
655                                let angle = m as f64 * theta - nn as f64 * nfp as f64 * zeta;
656                                let (sin_a, cos_a) = angle.sin_cos();
657                                grad_r[[js, idx]] += force_s * cos_a * dtheta * dzeta;
658                                grad_z[[js, idx]] += force_s * sin_a * dtheta * dzeta;
659                            }
660                        }
661                    }
662                }
663            }
664        }
665
666        let n_interior = ((ns - 2) as f64 * n_angle_pts).max(1.0);
667        global_force = (force_sq_sum / n_interior).sqrt();
668
669        // Update interior surfaces
670        for js in 1..(ns - 1) {
671            for k in 0..nmodes {
672                rmnc[[js, k]] -= config.step_size * grad_r[[js, k]] / n_angle_pts.max(1.0);
673                zmns[[js, k]] -= config.step_size * grad_z[[js, k]] / n_angle_pts.max(1.0);
674            }
675        }
676
677        iter_count = iteration + 1;
678        if global_force < config.tol {
679            converged = true;
680            break;
681        }
682        if !global_force.is_finite() {
683            return Err(FusionError::PhysicsViolation(format!(
684                "VMEC solver diverged at iteration {}: force={}",
685                iteration, global_force
686            )));
687        }
688    }
689
690    let beta_avg = if total_b2_vol > 0.0 {
691        2.0 * mu0 * total_beta_num / total_b2_vol
692    } else {
693        0.0
694    };
695
696    Ok(VmecEquilibrium {
697        rmnc,
698        zmns,
699        iota: Array1::from_vec(iota.to_vec()),
700        pressure: Array1::from_vec(pressure.to_vec()),
701        phi_edge,
702        volume: total_volume.abs(),
703        beta_avg,
704        force_residual: global_force,
705        iterations: iter_count,
706        converged,
707        ns_grid: ns,
708        m_pol,
709        n_tor,
710        nfp,
711    })
712}
713
714/// Evaluate the equilibrium geometry at given (s, θ, ζ) coordinates.
715pub fn vmec_eval_geometry(
716    eq: &VmecEquilibrium,
717    s_idx: usize,
718    theta: f64,
719    zeta: f64,
720) -> FusionResult<(f64, f64)> {
721    if s_idx >= eq.ns_grid {
722        return Err(FusionError::PhysicsViolation(format!(
723            "Surface index {} >= ns_grid {}",
724            s_idx, eq.ns_grid
725        )));
726    }
727    let nmodes = vmec_n_modes(eq.m_pol, eq.n_tor);
728    let rmnc_row: Vec<f64> = (0..nmodes).map(|k| eq.rmnc[[s_idx, k]]).collect();
729    let zmns_row: Vec<f64> = (0..nmodes).map(|k| eq.zmns[[s_idx, k]]).collect();
730    Ok(eval_surface_point(
731        &rmnc_row, &zmns_row, theta, zeta, eq.m_pol, eq.n_tor, eq.nfp,
732    ))
733}
734
735#[cfg(test)]
736mod tests {
737    use super::*;
738
739    fn sample_state() -> VmecBoundaryState {
740        VmecBoundaryState {
741            r_axis: 6.2,
742            z_axis: 0.0,
743            a_minor: 2.0,
744            kappa: 1.7,
745            triangularity: 0.33,
746            nfp: 1,
747            modes: vec![
748                VmecFourierMode {
749                    m: 1,
750                    n: 1,
751                    r_cos: 0.03,
752                    r_sin: -0.01,
753                    z_cos: 0.0,
754                    z_sin: 0.02,
755                },
756                VmecFourierMode {
757                    m: 2,
758                    n: 1,
759                    r_cos: 0.015,
760                    r_sin: 0.0,
761                    z_cos: 0.0,
762                    z_sin: 0.008,
763                },
764            ],
765        }
766    }
767
768    #[test]
769    fn test_vmec_text_roundtrip() {
770        let state = sample_state();
771        let text = export_vmec_like_text(&state).expect("export must succeed");
772        let parsed = import_vmec_like_text(&text).expect("import must succeed");
773        assert_eq!(parsed.nfp, state.nfp);
774        assert_eq!(parsed.modes.len(), state.modes.len());
775        assert!((parsed.r_axis - state.r_axis).abs() < 1e-12);
776        assert!((parsed.kappa - state.kappa).abs() < 1e-12);
777    }
778
779    #[test]
780    fn test_vmec_missing_key_errors() {
781        let text = "format=vmec_like_v1\nr_axis=6.2\n";
782        let err = import_vmec_like_text(text).expect_err("missing keys should error");
783        match err {
784            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Missing VMEC key")),
785            other => panic!("Unexpected error: {other:?}"),
786        }
787    }
788
789    #[test]
790    fn test_vmec_invalid_minor_radius_errors() {
791        let text = "\
792format=vmec_like_v1
793r_axis=6.2
794z_axis=0.0
795a_minor=0.0
796kappa=1.7
797triangularity=0.2
798nfp=1
799";
800        let err = import_vmec_like_text(text).expect_err("a_minor=0 must fail");
801        match err {
802            FusionError::PhysicsViolation(msg) => assert!(msg.contains("a_minor")),
803            other => panic!("Unexpected error: {other:?}"),
804        }
805    }
806
807    #[test]
808    fn test_vmec_rejects_invalid_modes_and_duplicates() {
809        let mut state = sample_state();
810        state.modes[0].r_cos = f64::NAN;
811        let err = export_vmec_like_text(&state).expect_err("non-finite mode coeff must fail");
812        match err {
813            FusionError::PhysicsViolation(msg) => assert!(msg.contains("mode")),
814            other => panic!("Unexpected error: {other:?}"),
815        }
816
817        let dup_text = "\
818format=vmec_like_v1
819r_axis=6.2
820r_axis=6.3
821z_axis=0.0
822a_minor=2.0
823kappa=1.7
824triangularity=0.2
825nfp=1
826";
827        let err = import_vmec_like_text(dup_text).expect_err("duplicate keys must fail");
828        match err {
829            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Duplicate VMEC key")),
830            other => panic!("Unexpected error: {other:?}"),
831        }
832    }
833
834    #[test]
835    fn test_vmec_rejects_duplicate_modes_and_bad_format() {
836        let mut state = sample_state();
837        state.modes.push(VmecFourierMode {
838            m: state.modes[0].m,
839            n: state.modes[0].n,
840            r_cos: 0.001,
841            r_sin: 0.0,
842            z_cos: 0.0,
843            z_sin: 0.001,
844        });
845        let err = export_vmec_like_text(&state).expect_err("duplicate mode index must fail");
846        match err {
847            FusionError::PhysicsViolation(msg) => assert!(msg.contains("duplicate mode")),
848            other => panic!("Unexpected error: {other:?}"),
849        }
850
851        let bad_format = "\
852format=vmec_like_v2
853r_axis=6.2
854z_axis=0.0
855a_minor=2.0
856kappa=1.7
857triangularity=0.2
858nfp=1
859";
860        let err = import_vmec_like_text(bad_format).expect_err("unsupported format must fail");
861        match err {
862            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Unsupported VMEC format")),
863            other => panic!("Unexpected error: {other:?}"),
864        }
865    }
866
867    #[test]
868    fn test_vmec_rejects_malformed_mode_and_duplicate_format() {
869        let malformed_mode = "\
870format=vmec_like_v1
871r_axis=6.2
872z_axis=0.0
873a_minor=2.0
874kappa=1.7
875triangularity=0.2
876nfp=1
877mode,1,1,0.1,0.0,0.2
878";
879        let err =
880            import_vmec_like_text(malformed_mode).expect_err("mode with wrong arity must fail");
881        match err {
882            FusionError::PhysicsViolation(msg) => assert!(msg.contains("exactly 6 columns")),
883            other => panic!("Unexpected error: {other:?}"),
884        }
885
886        let duplicate_format = "\
887format=vmec_like_v1
888format=vmec_like_v1
889r_axis=6.2
890z_axis=0.0
891a_minor=2.0
892kappa=1.7
893triangularity=0.2
894nfp=1
895";
896        let err =
897            import_vmec_like_text(duplicate_format).expect_err("duplicate format key must fail");
898        match err {
899            FusionError::PhysicsViolation(msg) => {
900                assert!(msg.contains("Duplicate VMEC key: format"))
901            }
902            other => panic!("Unexpected error: {other:?}"),
903        }
904    }
905
906    // === VMEC Solver Tests ===
907
908    fn iter_like_boundary() -> VmecBoundaryState {
909        VmecBoundaryState {
910            r_axis: 6.2,
911            z_axis: 0.0,
912            a_minor: 2.0,
913            kappa: 1.7,
914            triangularity: 0.33,
915            nfp: 1,
916            modes: vec![],
917        }
918    }
919
920    #[test]
921    fn test_vmec_mode_indexing() {
922        assert_eq!(vmec_n_modes(6, 0), 7);
923        assert_eq!(vmec_n_modes(3, 2), 4 * 5);
924        assert_eq!(vmec_mode_idx(0, 0, 6, 0), Some(0));
925        assert_eq!(vmec_mode_idx(3, 0, 6, 0), Some(3));
926        assert_eq!(vmec_mode_idx(0, 1, 6, 0), None);
927        assert_eq!(vmec_mode_idx(7, 0, 6, 0), None);
928        assert_eq!(vmec_mode_idx(1, -1, 3, 2), Some(6));
929    }
930
931    #[test]
932    fn test_vmec_eval_surface_axis_is_circle() {
933        let m_pol = 4;
934        let n_tor = 0;
935        let nmodes = vmec_n_modes(m_pol, n_tor);
936        let mut rmnc = vec![0.0; nmodes];
937        let mut zmns = vec![0.0; nmodes];
938        rmnc[0] = 6.2; // R_00 = major radius
939        rmnc[1] = 2.0; // R_10 = minor radius
940        zmns[1] = 2.0; // Z_10 = minor radius (circular cross-section)
941
942        let (r, z) = eval_surface_point(&rmnc, &zmns, 0.0, 0.0, m_pol, n_tor, 1);
943        assert!((r - 8.2).abs() < 1e-12, "outboard midplane R");
944        assert!(z.abs() < 1e-12, "midplane Z");
945
946        let (r2, z2) = eval_surface_point(
947            &rmnc,
948            &zmns,
949            std::f64::consts::FRAC_PI_2,
950            0.0,
951            m_pol,
952            n_tor,
953            1,
954        );
955        assert!((r2 - 6.2).abs() < 1e-12, "top R = R_axis");
956        assert!((z2 - 2.0).abs() < 1e-12, "top Z = a_minor");
957    }
958
959    #[test]
960    fn test_vmec_solver_runs_axisymmetric() {
961        let boundary = iter_like_boundary();
962        let config = VmecSolverConfig {
963            m_pol: 4,
964            n_tor: 0,
965            ns: 11,
966            ntheta: 16,
967            nzeta: 1,
968            max_iter: 50,
969            tol: 1e-6,
970            step_size: 1e-3,
971        };
972        let ns = config.ns;
973        let pressure: Vec<f64> = (0..ns)
974            .map(|i| {
975                let s = i as f64 / (ns - 1) as f64;
976                1e5 * (1.0 - s * s)
977            })
978            .collect();
979        let iota: Vec<f64> = (0..ns)
980            .map(|i| {
981                let s = i as f64 / (ns - 1) as f64;
982                0.3 + 0.7 * s * s
983            })
984            .collect();
985
986        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0)
987            .expect("VMEC solve should succeed");
988
989        assert!(eq.iterations > 0);
990        assert!(eq.force_residual.is_finite());
991        assert!(eq.volume > 0.0, "Volume must be positive");
992        assert_eq!(eq.ns_grid, ns);
993        assert_eq!(eq.m_pol, 4);
994        assert_eq!(eq.nfp, 1);
995    }
996
997    #[test]
998    fn test_vmec_solver_boundary_preserved() {
999        let boundary = iter_like_boundary();
1000        let config = VmecSolverConfig {
1001            m_pol: 4,
1002            n_tor: 0,
1003            ns: 11,
1004            ntheta: 16,
1005            nzeta: 1,
1006            max_iter: 10,
1007            tol: 1e-12,
1008            step_size: 1e-3,
1009        };
1010        let ns = config.ns;
1011        let pressure: Vec<f64> = (0..ns).map(|_| 1e4).collect();
1012        let iota: Vec<f64> = (0..ns).map(|_| 0.5).collect();
1013
1014        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0).unwrap();
1015
1016        // Boundary (last surface) R_10 should be a_minor = 2.0
1017        let r10_bnd = eq.rmnc[[ns - 1, 1]];
1018        assert!(
1019            (r10_bnd - 2.0).abs() < 1e-10,
1020            "Boundary R_10 must be preserved: got {r10_bnd}"
1021        );
1022    }
1023
1024    #[test]
1025    fn test_vmec_solver_rejects_invalid_inputs() {
1026        let boundary = iter_like_boundary();
1027        let config = VmecSolverConfig::default();
1028        let ns = config.ns;
1029        let p: Vec<f64> = vec![0.0; ns];
1030        let iota: Vec<f64> = vec![0.5; ns];
1031
1032        // Invalid phi_edge
1033        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, -1.0).is_err());
1034        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, f64::NAN).is_err());
1035
1036        // Wrong profile length
1037        let short_p: Vec<f64> = vec![0.0; 3];
1038        assert!(vmec_fixed_boundary_solve(&boundary, &config, &short_p, &iota, 1.0).is_err());
1039
1040        // Invalid config
1041        let bad_cfg = VmecSolverConfig {
1042            ns: 2,
1043            ..Default::default()
1044        };
1045        assert!(bad_cfg.validate().is_err());
1046    }
1047
1048    #[test]
1049    fn test_vmec_eval_geometry_valid() {
1050        let boundary = iter_like_boundary();
1051        let config = VmecSolverConfig {
1052            m_pol: 4,
1053            n_tor: 0,
1054            ns: 7,
1055            ntheta: 16,
1056            nzeta: 1,
1057            max_iter: 5,
1058            tol: 1e-12,
1059            step_size: 1e-3,
1060        };
1061        let ns = config.ns;
1062        let p: Vec<f64> = (0..ns)
1063            .map(|i| 1e4 * (1.0 - (i as f64 / (ns - 1) as f64)))
1064            .collect();
1065        let iota: Vec<f64> = vec![0.5; ns];
1066
1067        let eq = vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, 1.0).unwrap();
1068        let (r, z) = vmec_eval_geometry(&eq, ns - 1, 0.0, 0.0).unwrap();
1069        assert!(r > 0.0 && r.is_finite());
1070        assert!(z.is_finite());
1071
1072        // Out of bounds
1073        assert!(vmec_eval_geometry(&eq, ns + 5, 0.0, 0.0).is_err());
1074    }
1075
1076    #[test]
1077    fn test_vmec_deriv_theta_consistency() {
1078        let m_pol = 3;
1079        let n_tor = 0;
1080        let _nmodes = vmec_n_modes(m_pol, n_tor);
1081        let rmnc: Vec<f64> = vec![6.2, 2.0, 0.1, 0.05];
1082        let zmns: Vec<f64> = vec![0.0, 1.8, 0.08, 0.03];
1083        let theta = 1.0;
1084        let eps = 1e-7;
1085
1086        let (r_p, z_p) = eval_surface_point(&rmnc, &zmns, theta + eps, 0.0, m_pol, n_tor, 1);
1087        let (r_m, z_m) = eval_surface_point(&rmnc, &zmns, theta - eps, 0.0, m_pol, n_tor, 1);
1088        let dr_num = (r_p - r_m) / (2.0 * eps);
1089        let dz_num = (z_p - z_m) / (2.0 * eps);
1090
1091        let (dr_ana, dz_ana) = eval_surface_deriv_theta(&rmnc, &zmns, theta, 0.0, m_pol, n_tor, 1);
1092
1093        assert!(
1094            (dr_ana - dr_num).abs() < 1e-5,
1095            "dR/dθ mismatch: analytic={dr_ana}, numerical={dr_num}"
1096        );
1097        assert!(
1098            (dz_ana - dz_num).abs() < 1e-5,
1099            "dZ/dθ mismatch: analytic={dz_ana}, numerical={dz_num}"
1100        );
1101    }
1102}