1use fusion_types::error::{FusionError, FusionResult};
15use ndarray::{Array1, Array2};
16use std::collections::HashSet;
17
/// One Fourier harmonic `(m, n)` of a boundary surface description.
///
/// Only `r_cos` and `z_sin` are consumed by `vmec_fixed_boundary_solve` in this
/// file; `r_sin` and `z_cos` are validated and round-tripped through the text
/// format but otherwise not read by the solver.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VmecFourierMode {
    /// Poloidal mode number; `VmecBoundaryState::validate` requires `m >= 0`.
    pub m: i32,
    /// Toroidal mode number (either sign is accepted by validation).
    pub n: i32,
    /// Cosine coefficient of R (added onto the solver's boundary R harmonic).
    pub r_cos: f64,
    /// Sine coefficient of R (not read by the solver).
    pub r_sin: f64,
    /// Cosine coefficient of Z (not read by the solver).
    pub z_cos: f64,
    /// Sine coefficient of Z (added onto the solver's boundary Z harmonic).
    pub z_sin: f64,
}
27
/// Boundary-shape description used for text import/export and as the fixed
/// boundary of `vmec_fixed_boundary_solve`.
///
/// The solver builds its boundary harmonics from the scalars as follows:
/// R(0,0) = `r_axis`, R(1,0) = `a_minor`, Z(1,0) = `a_minor * kappa`, and
/// R(2,0) = `-triangularity * a_minor / 2`. The entries of `modes` are then
/// added on top.
#[derive(Debug, Clone, PartialEq)]
pub struct VmecBoundaryState {
    /// Axis major radius; also the R(0,0) harmonic of the boundary.
    pub r_axis: f64,
    /// Axis vertical position. Validated and serialized, but not read by
    /// `vmec_fixed_boundary_solve`.
    pub z_axis: f64,
    /// Minor radius; must be > 0.
    pub a_minor: f64,
    /// Ratio of the m=1 Z amplitude to the m=1 R amplitude; must be > 0.
    pub kappa: f64,
    /// Scales the m=2 R harmonic (see struct docs).
    pub triangularity: f64,
    /// Number of field periods; must be >= 1.
    pub nfp: usize,
    /// Additional Fourier harmonics; `(m, n)` pairs must be unique.
    pub modes: Vec<VmecFourierMode>,
}
38
39impl VmecBoundaryState {
40 pub fn validate(&self) -> FusionResult<()> {
41 if !self.r_axis.is_finite()
42 || !self.z_axis.is_finite()
43 || !self.a_minor.is_finite()
44 || !self.kappa.is_finite()
45 || !self.triangularity.is_finite()
46 {
47 return Err(FusionError::PhysicsViolation(
48 "VMEC boundary contains non-finite scalar".to_string(),
49 ));
50 }
51 if self.a_minor <= 0.0 {
52 return Err(FusionError::PhysicsViolation(
53 "VMEC boundary requires a_minor > 0".to_string(),
54 ));
55 }
56 if self.kappa <= 0.0 {
57 return Err(FusionError::PhysicsViolation(
58 "VMEC boundary requires kappa > 0".to_string(),
59 ));
60 }
61 if self.nfp < 1 {
62 return Err(FusionError::PhysicsViolation(
63 "VMEC boundary requires nfp >= 1".to_string(),
64 ));
65 }
66 let mut seen_modes: HashSet<(i32, i32)> = HashSet::with_capacity(self.modes.len());
67 for (idx, mode) in self.modes.iter().enumerate() {
68 if mode.m < 0 {
69 return Err(FusionError::PhysicsViolation(format!(
70 "VMEC boundary mode[{idx}] requires m >= 0, got {}",
71 mode.m
72 )));
73 }
74 if !mode.r_cos.is_finite()
75 || !mode.r_sin.is_finite()
76 || !mode.z_cos.is_finite()
77 || !mode.z_sin.is_finite()
78 {
79 return Err(FusionError::PhysicsViolation(format!(
80 "VMEC boundary mode[{idx}] contains non-finite coefficients"
81 )));
82 }
83 let key = (mode.m, mode.n);
84 if !seen_modes.insert(key) {
85 return Err(FusionError::PhysicsViolation(format!(
86 "VMEC boundary contains duplicate mode (m={}, n={})",
87 mode.m, mode.n
88 )));
89 }
90 }
91 Ok(())
92 }
93}
94
95pub fn export_vmec_like_text(state: &VmecBoundaryState) -> FusionResult<String> {
96 state.validate()?;
97 let mut out = String::new();
98 out.push_str("format=vmec_like_v1\n");
99 out.push_str(&format!("r_axis={:.16e}\n", state.r_axis));
100 out.push_str(&format!("z_axis={:.16e}\n", state.z_axis));
101 out.push_str(&format!("a_minor={:.16e}\n", state.a_minor));
102 out.push_str(&format!("kappa={:.16e}\n", state.kappa));
103 out.push_str(&format!("triangularity={:.16e}\n", state.triangularity));
104 out.push_str(&format!("nfp={}\n", state.nfp));
105 for mode in &state.modes {
106 out.push_str(&format!(
107 "mode,{},{},{:.16e},{:.16e},{:.16e},{:.16e}\n",
108 mode.m, mode.n, mode.r_cos, mode.r_sin, mode.z_cos, mode.z_sin
109 ));
110 }
111 Ok(out)
112}
113
114fn parse_float(key: &str, text: &str) -> FusionResult<f64> {
115 let val = text.parse::<f64>().map_err(|e| {
116 FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as float: {e}"))
117 })?;
118 if !val.is_finite() {
119 return Err(FusionError::PhysicsViolation(format!(
120 "VMEC key '{key}' must be finite, got {val}"
121 )));
122 }
123 Ok(val)
124}
125
126fn parse_int<T>(key: &str, text: &str) -> FusionResult<T>
127where
128 T: std::str::FromStr,
129 T::Err: std::fmt::Display,
130{
131 text.parse::<T>().map_err(|e| {
132 FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as integer: {e}"))
133 })
134}
135
136pub fn import_vmec_like_text(text: &str) -> FusionResult<VmecBoundaryState> {
137 let mut format_seen = false;
138 let mut r_axis: Option<f64> = None;
139 let mut z_axis: Option<f64> = None;
140 let mut a_minor: Option<f64> = None;
141 let mut kappa: Option<f64> = None;
142 let mut triangularity: Option<f64> = None;
143 let mut nfp: Option<usize> = None;
144 let mut modes: Vec<VmecFourierMode> = Vec::new();
145
146 for raw_line in text.lines() {
147 let line = raw_line.trim();
148 if line.is_empty() || line.starts_with('#') {
149 continue;
150 }
151 if let Some(format_name) = line.strip_prefix("format=") {
152 if format_seen {
153 return Err(FusionError::PhysicsViolation(
154 "Duplicate VMEC key: format".to_string(),
155 ));
156 }
157 if format_name.trim() != "vmec_like_v1" {
158 return Err(FusionError::PhysicsViolation(format!(
159 "Unsupported VMEC format: {}",
160 format_name.trim()
161 )));
162 }
163 format_seen = true;
164 continue;
165 }
166 if let Some(rest) = line.strip_prefix("mode,") {
167 let cols: Vec<&str> = rest.split(',').map(|v| v.trim()).collect();
168 if cols.len() != 6 {
169 return Err(FusionError::PhysicsViolation(
170 "VMEC mode line must contain exactly 6 columns".to_string(),
171 ));
172 }
173 modes.push(VmecFourierMode {
174 m: parse_int("mode.m", cols[0])?,
175 n: parse_int("mode.n", cols[1])?,
176 r_cos: parse_float("mode.r_cos", cols[2])?,
177 r_sin: parse_float("mode.r_sin", cols[3])?,
178 z_cos: parse_float("mode.z_cos", cols[4])?,
179 z_sin: parse_float("mode.z_sin", cols[5])?,
180 });
181 continue;
182 }
183 let (key, value) = line.split_once('=').ok_or_else(|| {
184 FusionError::PhysicsViolation(format!("Invalid VMEC line (missing '='): {line}"))
185 })?;
186 let key = key.trim();
187 let value = value.trim();
188 match key {
189 "r_axis" => {
190 if r_axis.is_some() {
191 return Err(FusionError::PhysicsViolation(
192 "Duplicate VMEC key: r_axis".to_string(),
193 ));
194 }
195 r_axis = Some(parse_float(key, value)?);
196 }
197 "z_axis" => {
198 if z_axis.is_some() {
199 return Err(FusionError::PhysicsViolation(
200 "Duplicate VMEC key: z_axis".to_string(),
201 ));
202 }
203 z_axis = Some(parse_float(key, value)?);
204 }
205 "a_minor" => {
206 if a_minor.is_some() {
207 return Err(FusionError::PhysicsViolation(
208 "Duplicate VMEC key: a_minor".to_string(),
209 ));
210 }
211 a_minor = Some(parse_float(key, value)?);
212 }
213 "kappa" => {
214 if kappa.is_some() {
215 return Err(FusionError::PhysicsViolation(
216 "Duplicate VMEC key: kappa".to_string(),
217 ));
218 }
219 kappa = Some(parse_float(key, value)?);
220 }
221 "triangularity" => {
222 if triangularity.is_some() {
223 return Err(FusionError::PhysicsViolation(
224 "Duplicate VMEC key: triangularity".to_string(),
225 ));
226 }
227 triangularity = Some(parse_float(key, value)?);
228 }
229 "nfp" => {
230 if nfp.is_some() {
231 return Err(FusionError::PhysicsViolation(
232 "Duplicate VMEC key: nfp".to_string(),
233 ));
234 }
235 nfp = Some(parse_int(key, value)?);
236 }
237 other => {
238 return Err(FusionError::PhysicsViolation(format!(
239 "Unknown VMEC key: {other}"
240 )));
241 }
242 }
243 }
244
245 let state = VmecBoundaryState {
246 r_axis: r_axis
247 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: r_axis".to_string()))?,
248 z_axis: z_axis
249 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: z_axis".to_string()))?,
250 a_minor: a_minor.ok_or_else(|| {
251 FusionError::PhysicsViolation("Missing VMEC key: a_minor".to_string())
252 })?,
253 kappa: kappa
254 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: kappa".to_string()))?,
255 triangularity: triangularity.ok_or_else(|| {
256 FusionError::PhysicsViolation("Missing VMEC key: triangularity".to_string())
257 })?,
258 nfp: nfp
259 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: nfp".to_string()))?,
260 modes,
261 };
262 state.validate()?;
263 Ok(state)
264}
265
/// Resolution and iteration controls for `vmec_fixed_boundary_solve`.
#[derive(Debug, Clone)]
pub struct VmecSolverConfig {
    /// Highest poloidal mode number; harmonics `m = 0..=m_pol` are stored.
    pub m_pol: usize,
    /// Highest toroidal mode number; `n = -n_tor..=n_tor`. `0` selects the
    /// axisymmetric layout.
    pub n_tor: usize,
    /// Number of radial surfaces on the uniform grid `s = js / (ns - 1)`; must be >= 3.
    pub ns: usize,
    /// Number of poloidal sample points per surface; must be >= 8.
    pub ntheta: usize,
    /// Toroidal sample points per field period. The solver ignores this value
    /// when `n_tor == 0` (it uses 1) and otherwise uses at least 4.
    pub nzeta: usize,
    /// Maximum number of relaxation iterations; must be >= 1.
    pub max_iter: usize,
    /// Convergence threshold on the RMS force residual; finite and > 0.
    pub tol: f64,
    /// Step length applied to the residual-projected update; finite and > 0.
    pub step_size: f64,
}
297
298impl Default for VmecSolverConfig {
299 fn default() -> Self {
300 Self {
301 m_pol: 6,
302 n_tor: 0,
303 ns: 25,
304 ntheta: 32,
305 nzeta: 1,
306 max_iter: 500,
307 tol: 1e-8,
308 step_size: 5e-3,
309 }
310 }
311}
312
313impl VmecSolverConfig {
314 pub fn validate(&self) -> FusionResult<()> {
315 if self.ns < 3 {
316 return Err(FusionError::PhysicsViolation(
317 "VMEC solver requires ns >= 3".into(),
318 ));
319 }
320 if self.ntheta < 8 {
321 return Err(FusionError::PhysicsViolation(
322 "VMEC solver requires ntheta >= 8".into(),
323 ));
324 }
325 if self.max_iter == 0 {
326 return Err(FusionError::PhysicsViolation(
327 "VMEC solver requires max_iter >= 1".into(),
328 ));
329 }
330 if !self.tol.is_finite() || self.tol <= 0.0 {
331 return Err(FusionError::PhysicsViolation(
332 "VMEC solver tol must be finite and > 0".into(),
333 ));
334 }
335 if !self.step_size.is_finite() || self.step_size <= 0.0 {
336 return Err(FusionError::PhysicsViolation(
337 "VMEC solver step_size must be finite and > 0".into(),
338 ));
339 }
340 Ok(())
341 }
342}
343
/// Result of `vmec_fixed_boundary_solve`.
#[derive(Debug, Clone)]
pub struct VmecEquilibrium {
    /// R cosine coefficients, shape `(ns_grid, vmec_n_modes(m_pol, n_tor))`;
    /// column layout follows `vmec_mode_idx`.
    pub rmnc: Array2<f64>,
    /// Z sine coefficients, same shape and layout as `rmnc`.
    pub zmns: Array2<f64>,
    /// Copy of the input rotational-transform profile (one value per surface).
    pub iota: Array1<f64>,
    /// Copy of the input pressure profile (one value per surface).
    pub pressure: Array1<f64>,
    /// Copy of the `phi_edge` input.
    pub phi_edge: f64,
    /// Volume estimate from the solver's Jacobian quadrature in its final iteration.
    pub volume: f64,
    /// `2 * mu0 * sum(p dV) / sum(|B|^2 dV)` from the final iteration (0 if the
    /// denominator is not positive).
    pub beta_avg: f64,
    /// RMS of the pointwise force residual over interior grid points.
    pub force_residual: f64,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether `force_residual` dropped below the configured tolerance.
    pub converged: bool,
    /// Number of radial surfaces (rows of `rmnc` / `zmns`).
    pub ns_grid: usize,
    /// Poloidal resolution used by the solve.
    pub m_pol: usize,
    /// Toroidal resolution used by the solve.
    pub n_tor: usize,
    /// Number of field periods, copied from the boundary.
    pub nfp: usize,
}
373
/// Number of Fourier harmonics stored per radial surface.
///
/// The count is `(m_pol + 1)` poloidal modes times `(2 * n_tor + 1)` toroidal
/// modes. The axisymmetric case `n_tor == 0` is just the general formula with
/// a single toroidal column, i.e. `m_pol + 1`.
pub fn vmec_n_modes(m_pol: usize, n_tor: usize) -> usize {
    let toroidal_count = 2 * n_tor + 1;
    (m_pol + 1) * toroidal_count
}
382
/// Column index of harmonic `(m, n)` in a coefficient row, or `None` if the
/// harmonic lies outside the `m_pol` / `n_tor` resolution.
///
/// Rows are laid out poloidal-major: `m * (2 * n_tor + 1) + (n + n_tor)`.
/// With `n_tor == 0` only `n == 0` is representable and the index reduces to `m`.
pub fn vmec_mode_idx(m: usize, n: i32, m_pol: usize, n_tor: usize) -> Option<usize> {
    let poloidal_ok = m <= m_pol;
    if !poloidal_ok || n.unsigned_abs() as usize > n_tor {
        return None;
    }
    let row_width = 2 * n_tor + 1;
    let toroidal_offset = (n + n_tor as i32) as usize;
    Some(m * row_width + toroidal_offset)
}
401
/// Evaluates the surface position `(R, Z)` at angles `(theta, zeta)` from one
/// row of cosine (`rmnc`) and sine (`zmns`) Fourier coefficients.
///
/// Coefficients use the `vmec_mode_idx` layout. Each harmonic's phase is
/// `m * theta - n * nfp * zeta`. Terms are accumulated in increasing `m`,
/// then increasing `n`.
///
/// # Panics
/// Panics if either slice is shorter than `vmec_n_modes(m_pol, n_tor)`.
fn eval_surface_point(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    if n_tor == 0 {
        // Axisymmetric layout: the harmonic index is just m.
        return (0..=m_pol).fold((0.0, 0.0), |(r_acc, z_acc), m| {
            let (sin_a, cos_a) = (m as f64 * theta).sin_cos();
            (r_acc + rmnc[m] * cos_a, z_acc + zmns[m] * sin_a)
        });
    }

    let width = 2 * n_tor + 1;
    let n_max = n_tor as i32;
    (0..=m_pol)
        .flat_map(|m| (-n_max..=n_max).map(move |nn| (m, nn)))
        .fold((0.0, 0.0), |(r_acc, z_acc), (m, nn)| {
            let idx = m * width + (nn + n_max) as usize;
            let angle = m as f64 * theta - nn as f64 * nfp as f64 * zeta;
            let (sin_a, cos_a) = angle.sin_cos();
            (r_acc + rmnc[idx] * cos_a, z_acc + zmns[idx] * sin_a)
        })
}
434
/// Evaluates the poloidal derivatives `(dR/dtheta, dZ/dtheta)` of the surface
/// described by one coefficient row.
///
/// This is the term-by-term derivative of `eval_surface_point`:
/// - `dR/dtheta = -sum m * rmnc * sin(phase)`;
/// - `dZ/dtheta = sum m * zmns * cos(phase)`;
/// - `phase = m * theta - n * nfp * zeta`.
///
/// # Panics
/// Panics if either slice is shorter than `vmec_n_modes(m_pol, n_tor)`.
fn eval_surface_deriv_theta(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    if n_tor == 0 {
        return (0..=m_pol).fold((0.0, 0.0), |(dr_acc, dz_acc), m| {
            let mf = m as f64;
            let (sin_a, cos_a) = (mf * theta).sin_cos();
            (dr_acc - mf * rmnc[m] * sin_a, dz_acc + mf * zmns[m] * cos_a)
        });
    }

    let width = 2 * n_tor + 1;
    let n_max = n_tor as i32;
    (0..=m_pol)
        .flat_map(|m| (-n_max..=n_max).map(move |nn| (m, nn)))
        .fold((0.0, 0.0), |(dr_acc, dz_acc), (m, nn)| {
            let mf = m as f64;
            let idx = m * width + (nn + n_max) as usize;
            let angle = mf * theta - nn as f64 * nfp as f64 * zeta;
            let (sin_a, cos_a) = angle.sin_cos();
            (dr_acc - mf * rmnc[idx] * sin_a, dz_acc + mf * zmns[idx] * cos_a)
        })
}
469
470pub fn vmec_fixed_boundary_solve(
479 boundary: &VmecBoundaryState,
480 config: &VmecSolverConfig,
481 pressure: &[f64],
482 iota: &[f64],
483 phi_edge: f64,
484) -> FusionResult<VmecEquilibrium> {
485 boundary.validate()?;
486 config.validate()?;
487
488 let ns = config.ns;
489 let m_pol = config.m_pol;
490 let n_tor = config.n_tor;
491 let nfp = boundary.nfp;
492 let nmodes = vmec_n_modes(m_pol, n_tor);
493 let ntheta = config.ntheta;
494 let nzeta = if n_tor == 0 { 1 } else { config.nzeta.max(4) };
495
496 if pressure.len() != ns || iota.len() != ns {
497 return Err(FusionError::PhysicsViolation(format!(
498 "pressure/iota length must match ns={ns}, got p={} iota={}",
499 pressure.len(),
500 iota.len()
501 )));
502 }
503 if pressure.iter().any(|v| !v.is_finite()) || iota.iter().any(|v| !v.is_finite()) {
504 return Err(FusionError::PhysicsViolation(
505 "pressure/iota profiles must be finite".into(),
506 ));
507 }
508 if !phi_edge.is_finite() || phi_edge <= 0.0 {
509 return Err(FusionError::PhysicsViolation(
510 "phi_edge must be finite and > 0".into(),
511 ));
512 }
513
514 let mut bnd_rmnc = vec![0.0; nmodes];
516 let mut bnd_zmns = vec![0.0; nmodes];
517
518 if let Some(idx) = vmec_mode_idx(0, 0, m_pol, n_tor) {
519 bnd_rmnc[idx] = boundary.r_axis;
520 }
521 if let Some(idx) = vmec_mode_idx(1, 0, m_pol, n_tor) {
522 bnd_rmnc[idx] = boundary.a_minor;
523 bnd_zmns[idx] = boundary.a_minor * boundary.kappa;
524 }
525 if m_pol >= 2 {
526 if let Some(idx) = vmec_mode_idx(2, 0, m_pol, n_tor) {
527 bnd_rmnc[idx] = -boundary.triangularity * boundary.a_minor * 0.5;
528 }
529 }
530 for mode in &boundary.modes {
531 let m = mode.m as usize;
532 if let Some(idx) = vmec_mode_idx(m, mode.n, m_pol, n_tor) {
533 bnd_rmnc[idx] += mode.r_cos;
534 bnd_zmns[idx] += mode.z_sin;
535 }
536 }
537
538 let mut rmnc = Array2::zeros((ns, nmodes));
540 let mut zmns = Array2::zeros((ns, nmodes));
541 let s_grid: Vec<f64> = (0..ns).map(|i| i as f64 / (ns - 1) as f64).collect();
542
543 for js in 0..ns {
544 let s = s_grid[js];
545 for k in 0..nmodes {
546 let axis_r = if k == vmec_mode_idx(0, 0, m_pol, n_tor).unwrap_or(usize::MAX) {
547 boundary.r_axis
548 } else {
549 0.0
550 };
551 rmnc[[js, k]] = axis_r * (1.0 - s) + bnd_rmnc[k] * s;
552 zmns[[js, k]] = bnd_zmns[k] * s;
553 }
554 }
555
556 let pi = std::f64::consts::PI;
557 let mu0 = 4.0e-7 * pi;
558 let ds = 1.0 / (ns - 1) as f64;
559 let dtheta = 2.0 * pi / ntheta as f64;
560 let dzeta = 2.0 * pi / (nzeta as f64 * nfp.max(1) as f64);
561 let n_angle_pts = (ntheta * nzeta) as f64;
562
563 let theta_arr: Vec<f64> = (0..ntheta)
565 .map(|i| 2.0 * pi * i as f64 / ntheta as f64)
566 .collect();
567 let zeta_arr: Vec<f64> = (0..nzeta)
568 .map(|i| 2.0 * pi * i as f64 / (nzeta as f64 * nfp.max(1) as f64))
569 .collect();
570
571 let mut global_force = f64::MAX;
573 let mut converged = false;
574 let mut iter_count = 0;
575 let mut total_volume = 0.0;
576 let mut total_beta_num = 0.0;
577 let mut total_b2_vol = 0.0;
578
579 for iteration in 0..config.max_iter {
580 let mut force_sq_sum = 0.0;
581 total_volume = 0.0;
582 total_beta_num = 0.0;
583 total_b2_vol = 0.0;
584
585 let mut grad_r: Array2<f64> = Array2::zeros((ns, nmodes));
586 let mut grad_z: Array2<f64> = Array2::zeros((ns, nmodes));
587
588 for js in 1..(ns - 1) {
590 let p = pressure[js];
591 let iota_s = iota[js];
592 let dp_ds =
593 (pressure[(js + 1).min(ns - 1)] - pressure[js.saturating_sub(1)]) / (2.0 * ds);
594
595 let rmnc_s: Vec<f64> = (0..nmodes).map(|k| rmnc[[js, k]]).collect();
596 let zmns_s: Vec<f64> = (0..nmodes).map(|k| zmns[[js, k]]).collect();
597 let rmnc_p: Vec<f64> = (0..nmodes).map(|k| rmnc[[js + 1, k]]).collect();
598 let zmns_p: Vec<f64> = (0..nmodes).map(|k| zmns[[js + 1, k]]).collect();
599 let rmnc_m: Vec<f64> = (0..nmodes).map(|k| rmnc[[js - 1, k]]).collect();
600 let zmns_m: Vec<f64> = (0..nmodes).map(|k| zmns[[js - 1, k]]).collect();
601
602 for &theta in theta_arr.iter().take(ntheta) {
603 for &zeta in zeta_arr.iter().take(nzeta) {
604 let (r_val, _z_val) =
605 eval_surface_point(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
606 let (r_p, _) =
607 eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
608 let (r_m, _) =
609 eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
610 let (_, z_p) =
611 eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
612 let (_, z_m) =
613 eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
614
615 let r_s = (r_p - r_m) / (2.0 * ds);
616 let z_s = (z_p - z_m) / (2.0 * ds);
617
618 let (r_theta, z_theta) =
619 eval_surface_deriv_theta(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
620
621 let jac = r_val * (r_s * z_theta - r_theta * z_s);
623 let jac_abs = jac.abs().max(1e-20);
624
625 let b_zeta = phi_edge / (2.0 * pi * jac_abs);
627 let b_theta = iota_s * b_zeta;
628
629 let g_tt = r_theta * r_theta + z_theta * z_theta;
631 let g_zz = r_val * r_val;
632 let b_sq = b_theta * b_theta * g_tt + b_zeta * b_zeta * g_zz;
633
634 let force_s = dp_ds + b_sq / (2.0 * mu0 * jac_abs.max(1e-10));
636 force_sq_sum += force_s * force_s;
637
638 let dvol = jac_abs * dtheta * dzeta * ds;
639 total_volume += dvol;
640 total_beta_num += p * dvol;
641 total_b2_vol += b_sq * dvol;
642
643 if n_tor == 0 {
645 for m in 0..=m_pol {
646 let angle = m as f64 * theta;
647 let (sin_a, cos_a) = angle.sin_cos();
648 grad_r[[js, m]] += force_s * cos_a * dtheta * dzeta;
649 grad_z[[js, m]] += force_s * sin_a * dtheta * dzeta;
650 }
651 } else {
652 for m in 0..=m_pol {
653 for nn in -(n_tor as i32)..=(n_tor as i32) {
654 let idx = m * (2 * n_tor + 1) + (nn + n_tor as i32) as usize;
655 let angle = m as f64 * theta - nn as f64 * nfp as f64 * zeta;
656 let (sin_a, cos_a) = angle.sin_cos();
657 grad_r[[js, idx]] += force_s * cos_a * dtheta * dzeta;
658 grad_z[[js, idx]] += force_s * sin_a * dtheta * dzeta;
659 }
660 }
661 }
662 }
663 }
664 }
665
666 let n_interior = ((ns - 2) as f64 * n_angle_pts).max(1.0);
667 global_force = (force_sq_sum / n_interior).sqrt();
668
669 for js in 1..(ns - 1) {
671 for k in 0..nmodes {
672 rmnc[[js, k]] -= config.step_size * grad_r[[js, k]] / n_angle_pts.max(1.0);
673 zmns[[js, k]] -= config.step_size * grad_z[[js, k]] / n_angle_pts.max(1.0);
674 }
675 }
676
677 iter_count = iteration + 1;
678 if global_force < config.tol {
679 converged = true;
680 break;
681 }
682 if !global_force.is_finite() {
683 return Err(FusionError::PhysicsViolation(format!(
684 "VMEC solver diverged at iteration {}: force={}",
685 iteration, global_force
686 )));
687 }
688 }
689
690 let beta_avg = if total_b2_vol > 0.0 {
691 2.0 * mu0 * total_beta_num / total_b2_vol
692 } else {
693 0.0
694 };
695
696 Ok(VmecEquilibrium {
697 rmnc,
698 zmns,
699 iota: Array1::from_vec(iota.to_vec()),
700 pressure: Array1::from_vec(pressure.to_vec()),
701 phi_edge,
702 volume: total_volume.abs(),
703 beta_avg,
704 force_residual: global_force,
705 iterations: iter_count,
706 converged,
707 ns_grid: ns,
708 m_pol,
709 n_tor,
710 nfp,
711 })
712}
713
714pub fn vmec_eval_geometry(
716 eq: &VmecEquilibrium,
717 s_idx: usize,
718 theta: f64,
719 zeta: f64,
720) -> FusionResult<(f64, f64)> {
721 if s_idx >= eq.ns_grid {
722 return Err(FusionError::PhysicsViolation(format!(
723 "Surface index {} >= ns_grid {}",
724 s_idx, eq.ns_grid
725 )));
726 }
727 let nmodes = vmec_n_modes(eq.m_pol, eq.n_tor);
728 let rmnc_row: Vec<f64> = (0..nmodes).map(|k| eq.rmnc[[s_idx, k]]).collect();
729 let zmns_row: Vec<f64> = (0..nmodes).map(|k| eq.zmns[[s_idx, k]]).collect();
730 Ok(eval_surface_point(
731 &rmnc_row, &zmns_row, theta, zeta, eq.m_pol, eq.n_tor, eq.nfp,
732 ))
733}
734
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary with two explicit `n = 1` modes, used by the text I/O tests.
    fn sample_state() -> VmecBoundaryState {
        VmecBoundaryState {
            r_axis: 6.2,
            z_axis: 0.0,
            a_minor: 2.0,
            kappa: 1.7,
            triangularity: 0.33,
            nfp: 1,
            modes: vec![
                VmecFourierMode {
                    m: 1,
                    n: 1,
                    r_cos: 0.03,
                    r_sin: -0.01,
                    z_cos: 0.0,
                    z_sin: 0.02,
                },
                VmecFourierMode {
                    m: 2,
                    n: 1,
                    r_cos: 0.015,
                    r_sin: 0.0,
                    z_cos: 0.0,
                    z_sin: 0.008,
                },
            ],
        }
    }

    #[test]
    fn test_vmec_text_roundtrip() {
        let state = sample_state();
        let text = export_vmec_like_text(&state).expect("export must succeed");
        let parsed = import_vmec_like_text(&text).expect("import must succeed");
        assert_eq!(parsed.nfp, state.nfp);
        assert_eq!(parsed.modes.len(), state.modes.len());
        assert!((parsed.r_axis - state.r_axis).abs() < 1e-12);
        assert!((parsed.kappa - state.kappa).abs() < 1e-12);
    }

    #[test]
    fn test_vmec_missing_key_errors() {
        let text = "format=vmec_like_v1\nr_axis=6.2\n";
        let err = import_vmec_like_text(text).expect_err("missing keys should error");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Missing VMEC key")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_invalid_minor_radius_errors() {
        let text = "\
format=vmec_like_v1
r_axis=6.2
z_axis=0.0
a_minor=0.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err = import_vmec_like_text(text).expect_err("a_minor=0 must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("a_minor")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_rejects_invalid_modes_and_duplicates() {
        let mut state = sample_state();
        state.modes[0].r_cos = f64::NAN;
        let err = export_vmec_like_text(&state).expect_err("non-finite mode coeff must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("mode")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let dup_text = "\
format=vmec_like_v1
r_axis=6.2
r_axis=6.3
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err = import_vmec_like_text(dup_text).expect_err("duplicate keys must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Duplicate VMEC key")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_rejects_duplicate_modes_and_bad_format() {
        let mut state = sample_state();
        state.modes.push(VmecFourierMode {
            m: state.modes[0].m,
            n: state.modes[0].n,
            r_cos: 0.001,
            r_sin: 0.0,
            z_cos: 0.0,
            z_sin: 0.001,
        });
        let err = export_vmec_like_text(&state).expect_err("duplicate mode index must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("duplicate mode")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let bad_format = "\
format=vmec_like_v2
r_axis=6.2
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err = import_vmec_like_text(bad_format).expect_err("unsupported format must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Unsupported VMEC format")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_rejects_malformed_mode_and_duplicate_format() {
        let malformed_mode = "\
format=vmec_like_v1
r_axis=6.2
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
mode,1,1,0.1,0.0,0.2
";
        let err =
            import_vmec_like_text(malformed_mode).expect_err("mode with wrong arity must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("exactly 6 columns")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let duplicate_format = "\
format=vmec_like_v1
format=vmec_like_v1
r_axis=6.2
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err =
            import_vmec_like_text(duplicate_format).expect_err("duplicate format key must fail");
        match err {
            FusionError::PhysicsViolation(msg) => {
                assert!(msg.contains("Duplicate VMEC key: format"))
            }
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    /// Axisymmetric boundary with no explicit modes, used by the solver tests.
    fn iter_like_boundary() -> VmecBoundaryState {
        VmecBoundaryState {
            r_axis: 6.2,
            z_axis: 0.0,
            a_minor: 2.0,
            kappa: 1.7,
            triangularity: 0.33,
            nfp: 1,
            modes: vec![],
        }
    }

    #[test]
    fn test_vmec_mode_indexing() {
        assert_eq!(vmec_n_modes(6, 0), 7);
        assert_eq!(vmec_n_modes(3, 2), 4 * 5);
        assert_eq!(vmec_mode_idx(0, 0, 6, 0), Some(0));
        assert_eq!(vmec_mode_idx(3, 0, 6, 0), Some(3));
        assert_eq!(vmec_mode_idx(0, 1, 6, 0), None);
        assert_eq!(vmec_mode_idx(7, 0, 6, 0), None);
        assert_eq!(vmec_mode_idx(1, -1, 3, 2), Some(6));
    }

    #[test]
    fn test_vmec_eval_surface_axis_is_circle() {
        let m_pol = 4;
        let n_tor = 0;
        let nmodes = vmec_n_modes(m_pol, n_tor);
        let mut rmnc = vec![0.0; nmodes];
        let mut zmns = vec![0.0; nmodes];
        rmnc[0] = 6.2;
        rmnc[1] = 2.0;
        zmns[1] = 2.0;

        let (r, z) = eval_surface_point(&rmnc, &zmns, 0.0, 0.0, m_pol, n_tor, 1);
        assert!((r - 8.2).abs() < 1e-12, "outboard midplane R");
        assert!(z.abs() < 1e-12, "midplane Z");

        let (r2, z2) = eval_surface_point(
            &rmnc,
            &zmns,
            std::f64::consts::FRAC_PI_2,
            0.0,
            m_pol,
            n_tor,
            1,
        );
        assert!((r2 - 6.2).abs() < 1e-12, "top R = R_axis");
        assert!((z2 - 2.0).abs() < 1e-12, "top Z = a_minor");
    }

    #[test]
    fn test_vmec_solver_runs_axisymmetric() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig {
            m_pol: 4,
            n_tor: 0,
            ns: 11,
            ntheta: 16,
            nzeta: 1,
            max_iter: 50,
            tol: 1e-6,
            step_size: 1e-3,
        };
        let ns = config.ns;
        let pressure: Vec<f64> = (0..ns)
            .map(|i| {
                let s = i as f64 / (ns - 1) as f64;
                1e5 * (1.0 - s * s)
            })
            .collect();
        let iota: Vec<f64> = (0..ns)
            .map(|i| {
                let s = i as f64 / (ns - 1) as f64;
                0.3 + 0.7 * s * s
            })
            .collect();

        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0)
            .expect("VMEC solve should succeed");

        assert!(eq.iterations > 0);
        assert!(eq.force_residual.is_finite());
        assert!(eq.volume > 0.0, "Volume must be positive");
        assert_eq!(eq.ns_grid, ns);
        assert_eq!(eq.m_pol, 4);
        assert_eq!(eq.nfp, 1);
    }

    #[test]
    fn test_vmec_solver_boundary_preserved() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig {
            m_pol: 4,
            n_tor: 0,
            ns: 11,
            ntheta: 16,
            nzeta: 1,
            max_iter: 10,
            tol: 1e-12,
            step_size: 1e-3,
        };
        let ns = config.ns;
        let pressure: Vec<f64> = (0..ns).map(|_| 1e4).collect();
        let iota: Vec<f64> = (0..ns).map(|_| 0.5).collect();

        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0).unwrap();

        // Column 1 is the (m=1, n=0) harmonic in the axisymmetric layout.
        let r10_bnd = eq.rmnc[[ns - 1, 1]];
        assert!(
            (r10_bnd - 2.0).abs() < 1e-10,
            "Boundary R_10 must be preserved: got {r10_bnd}"
        );
    }

    #[test]
    fn test_vmec_solver_rejects_invalid_inputs() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig::default();
        let ns = config.ns;
        let p: Vec<f64> = vec![0.0; ns];
        let iota: Vec<f64> = vec![0.5; ns];

        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, -1.0).is_err());
        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, f64::NAN).is_err());

        let short_p: Vec<f64> = vec![0.0; 3];
        assert!(vmec_fixed_boundary_solve(&boundary, &config, &short_p, &iota, 1.0).is_err());

        let bad_cfg = VmecSolverConfig {
            ns: 2,
            ..Default::default()
        };
        assert!(bad_cfg.validate().is_err());
    }

    #[test]
    fn test_vmec_eval_geometry_valid() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig {
            m_pol: 4,
            n_tor: 0,
            ns: 7,
            ntheta: 16,
            nzeta: 1,
            max_iter: 5,
            tol: 1e-12,
            step_size: 1e-3,
        };
        let ns = config.ns;
        let p: Vec<f64> = (0..ns)
            .map(|i| 1e4 * (1.0 - (i as f64 / (ns - 1) as f64)))
            .collect();
        let iota: Vec<f64> = vec![0.5; ns];

        let eq = vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, 1.0).unwrap();
        let (r, z) = vmec_eval_geometry(&eq, ns - 1, 0.0, 0.0).unwrap();
        assert!(r > 0.0 && r.is_finite());
        assert!(z.is_finite());

        assert!(vmec_eval_geometry(&eq, ns + 5, 0.0, 0.0).is_err());
    }

    #[test]
    fn test_vmec_deriv_theta_consistency() {
        let m_pol = 3;
        let n_tor = 0;
        let rmnc: Vec<f64> = vec![6.2, 2.0, 0.1, 0.05];
        let zmns: Vec<f64> = vec![0.0, 1.8, 0.08, 0.03];
        // The hand-written coefficient rows must match the expected layout size.
        assert_eq!(rmnc.len(), vmec_n_modes(m_pol, n_tor));
        assert_eq!(zmns.len(), vmec_n_modes(m_pol, n_tor));
        let theta = 1.0;
        let eps = 1e-7;

        let (r_p, z_p) = eval_surface_point(&rmnc, &zmns, theta + eps, 0.0, m_pol, n_tor, 1);
        let (r_m, z_m) = eval_surface_point(&rmnc, &zmns, theta - eps, 0.0, m_pol, n_tor, 1);
        let dr_num = (r_p - r_m) / (2.0 * eps);
        let dz_num = (z_p - z_m) / (2.0 * eps);

        let (dr_ana, dz_ana) = eval_surface_deriv_theta(&rmnc, &zmns, theta, 0.0, m_pol, n_tor, 1);

        assert!(
            (dr_ana - dr_num).abs() < 1e-5,
            "dR/dθ mismatch: analytic={dr_ana}, numerical={dr_num}"
        );
        assert!(
            (dz_ana - dz_num).abs() < 1e-5,
            "dZ/dθ mismatch: analytic={dz_ana}, numerical={dz_num}"
        );
    }
}