1use fusion_types::error::{FusionError, FusionResult};
14use ndarray::{Array1, Array2};
15use std::collections::HashSet;
16
/// A single Fourier harmonic `(m, n)` of the boundary surface.
///
/// All four coefficients are validated by `VmecBoundaryState::validate` and
/// round-tripped by the text export/import functions in this module;
/// `vmec_fixed_boundary_solve` reads only `r_cos` and `z_sin`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VmecFourierMode {
    /// Poloidal mode number; `VmecBoundaryState::validate` rejects negative values.
    pub m: i32,
    /// Toroidal mode number; may be negative.
    pub n: i32,
    /// Coefficient of `cos(m*theta - n*nfp*zeta)` in R (added to the solver's `rmnc`).
    pub r_cos: f64,
    /// Presumably the sine coefficient of R (stellarator-asymmetric term) — not read by the solver.
    pub r_sin: f64,
    /// Presumably the cosine coefficient of Z (stellarator-asymmetric term) — not read by the solver.
    pub z_cos: f64,
    /// Coefficient of `sin(m*theta - n*nfp*zeta)` in Z (added to the solver's `zmns`).
    pub z_sin: f64,
}
26
/// Fixed-boundary shape description: a shaped base surface plus optional
/// extra Fourier harmonics.
///
/// `vmec_fixed_boundary_solve` maps the scalars onto boundary coefficients as
/// `R_00 = r_axis`, `R_10 = a_minor`, `Z_10 = a_minor * kappa` and
/// `R_20 = -triangularity * a_minor / 2`, then adds each entry of `modes`.
#[derive(Debug, Clone, PartialEq)]
pub struct VmecBoundaryState {
    /// Axis major radius; also the R_00 boundary coefficient in the solver.
    pub r_axis: f64,
    /// Vertical axis offset. Validated and serialised, but not read by
    /// `vmec_fixed_boundary_solve` (its representation has no Z_00 term).
    pub z_axis: f64,
    /// Minor radius; `validate` requires > 0.
    pub a_minor: f64,
    /// Elongation; `validate` requires > 0.
    pub kappa: f64,
    /// Triangularity; sets the (m = 2, n = 0) cosine coefficient of R in the solver.
    pub triangularity: f64,
    /// Number of field periods; `validate` requires >= 1.
    pub nfp: usize,
    /// Extra harmonics; `validate` requires unique `(m, n)` pairs and `m >= 0`.
    pub modes: Vec<VmecFourierMode>,
}
37
38impl VmecBoundaryState {
39 pub fn validate(&self) -> FusionResult<()> {
40 if !self.r_axis.is_finite()
41 || !self.z_axis.is_finite()
42 || !self.a_minor.is_finite()
43 || !self.kappa.is_finite()
44 || !self.triangularity.is_finite()
45 {
46 return Err(FusionError::PhysicsViolation(
47 "VMEC boundary contains non-finite scalar".to_string(),
48 ));
49 }
50 if self.a_minor <= 0.0 {
51 return Err(FusionError::PhysicsViolation(
52 "VMEC boundary requires a_minor > 0".to_string(),
53 ));
54 }
55 if self.kappa <= 0.0 {
56 return Err(FusionError::PhysicsViolation(
57 "VMEC boundary requires kappa > 0".to_string(),
58 ));
59 }
60 if self.nfp < 1 {
61 return Err(FusionError::PhysicsViolation(
62 "VMEC boundary requires nfp >= 1".to_string(),
63 ));
64 }
65 let mut seen_modes: HashSet<(i32, i32)> = HashSet::with_capacity(self.modes.len());
66 for (idx, mode) in self.modes.iter().enumerate() {
67 if mode.m < 0 {
68 return Err(FusionError::PhysicsViolation(format!(
69 "VMEC boundary mode[{idx}] requires m >= 0, got {}",
70 mode.m
71 )));
72 }
73 if !mode.r_cos.is_finite()
74 || !mode.r_sin.is_finite()
75 || !mode.z_cos.is_finite()
76 || !mode.z_sin.is_finite()
77 {
78 return Err(FusionError::PhysicsViolation(format!(
79 "VMEC boundary mode[{idx}] contains non-finite coefficients"
80 )));
81 }
82 let key = (mode.m, mode.n);
83 if !seen_modes.insert(key) {
84 return Err(FusionError::PhysicsViolation(format!(
85 "VMEC boundary contains duplicate mode (m={}, n={})",
86 mode.m, mode.n
87 )));
88 }
89 }
90 Ok(())
91 }
92}
93
94pub fn export_vmec_like_text(state: &VmecBoundaryState) -> FusionResult<String> {
95 state.validate()?;
96 let mut out = String::new();
97 out.push_str("format=vmec_like_v1\n");
98 out.push_str(&format!("r_axis={:.16e}\n", state.r_axis));
99 out.push_str(&format!("z_axis={:.16e}\n", state.z_axis));
100 out.push_str(&format!("a_minor={:.16e}\n", state.a_minor));
101 out.push_str(&format!("kappa={:.16e}\n", state.kappa));
102 out.push_str(&format!("triangularity={:.16e}\n", state.triangularity));
103 out.push_str(&format!("nfp={}\n", state.nfp));
104 for mode in &state.modes {
105 out.push_str(&format!(
106 "mode,{},{},{:.16e},{:.16e},{:.16e},{:.16e}\n",
107 mode.m, mode.n, mode.r_cos, mode.r_sin, mode.z_cos, mode.z_sin
108 ));
109 }
110 Ok(out)
111}
112
113fn parse_float(key: &str, text: &str) -> FusionResult<f64> {
114 let val = text.parse::<f64>().map_err(|e| {
115 FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as float: {e}"))
116 })?;
117 if !val.is_finite() {
118 return Err(FusionError::PhysicsViolation(format!(
119 "VMEC key '{key}' must be finite, got {val}"
120 )));
121 }
122 Ok(val)
123}
124
125fn parse_int<T>(key: &str, text: &str) -> FusionResult<T>
126where
127 T: std::str::FromStr,
128 T::Err: std::fmt::Display,
129{
130 text.parse::<T>().map_err(|e| {
131 FusionError::PhysicsViolation(format!("Failed to parse VMEC key '{key}' as integer: {e}"))
132 })
133}
134
135pub fn import_vmec_like_text(text: &str) -> FusionResult<VmecBoundaryState> {
136 let mut format_seen = false;
137 let mut r_axis: Option<f64> = None;
138 let mut z_axis: Option<f64> = None;
139 let mut a_minor: Option<f64> = None;
140 let mut kappa: Option<f64> = None;
141 let mut triangularity: Option<f64> = None;
142 let mut nfp: Option<usize> = None;
143 let mut modes: Vec<VmecFourierMode> = Vec::new();
144
145 for raw_line in text.lines() {
146 let line = raw_line.trim();
147 if line.is_empty() || line.starts_with('#') {
148 continue;
149 }
150 if let Some(format_name) = line.strip_prefix("format=") {
151 if format_seen {
152 return Err(FusionError::PhysicsViolation(
153 "Duplicate VMEC key: format".to_string(),
154 ));
155 }
156 if format_name.trim() != "vmec_like_v1" {
157 return Err(FusionError::PhysicsViolation(format!(
158 "Unsupported VMEC format: {}",
159 format_name.trim()
160 )));
161 }
162 format_seen = true;
163 continue;
164 }
165 if let Some(rest) = line.strip_prefix("mode,") {
166 let cols: Vec<&str> = rest.split(',').map(|v| v.trim()).collect();
167 if cols.len() != 6 {
168 return Err(FusionError::PhysicsViolation(
169 "VMEC mode line must contain exactly 6 columns".to_string(),
170 ));
171 }
172 modes.push(VmecFourierMode {
173 m: parse_int("mode.m", cols[0])?,
174 n: parse_int("mode.n", cols[1])?,
175 r_cos: parse_float("mode.r_cos", cols[2])?,
176 r_sin: parse_float("mode.r_sin", cols[3])?,
177 z_cos: parse_float("mode.z_cos", cols[4])?,
178 z_sin: parse_float("mode.z_sin", cols[5])?,
179 });
180 continue;
181 }
182 let (key, value) = line.split_once('=').ok_or_else(|| {
183 FusionError::PhysicsViolation(format!("Invalid VMEC line (missing '='): {line}"))
184 })?;
185 let key = key.trim();
186 let value = value.trim();
187 match key {
188 "r_axis" => {
189 if r_axis.is_some() {
190 return Err(FusionError::PhysicsViolation(
191 "Duplicate VMEC key: r_axis".to_string(),
192 ));
193 }
194 r_axis = Some(parse_float(key, value)?);
195 }
196 "z_axis" => {
197 if z_axis.is_some() {
198 return Err(FusionError::PhysicsViolation(
199 "Duplicate VMEC key: z_axis".to_string(),
200 ));
201 }
202 z_axis = Some(parse_float(key, value)?);
203 }
204 "a_minor" => {
205 if a_minor.is_some() {
206 return Err(FusionError::PhysicsViolation(
207 "Duplicate VMEC key: a_minor".to_string(),
208 ));
209 }
210 a_minor = Some(parse_float(key, value)?);
211 }
212 "kappa" => {
213 if kappa.is_some() {
214 return Err(FusionError::PhysicsViolation(
215 "Duplicate VMEC key: kappa".to_string(),
216 ));
217 }
218 kappa = Some(parse_float(key, value)?);
219 }
220 "triangularity" => {
221 if triangularity.is_some() {
222 return Err(FusionError::PhysicsViolation(
223 "Duplicate VMEC key: triangularity".to_string(),
224 ));
225 }
226 triangularity = Some(parse_float(key, value)?);
227 }
228 "nfp" => {
229 if nfp.is_some() {
230 return Err(FusionError::PhysicsViolation(
231 "Duplicate VMEC key: nfp".to_string(),
232 ));
233 }
234 nfp = Some(parse_int(key, value)?);
235 }
236 other => {
237 return Err(FusionError::PhysicsViolation(format!(
238 "Unknown VMEC key: {other}"
239 )));
240 }
241 }
242 }
243
244 let state = VmecBoundaryState {
245 r_axis: r_axis
246 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: r_axis".to_string()))?,
247 z_axis: z_axis
248 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: z_axis".to_string()))?,
249 a_minor: a_minor.ok_or_else(|| {
250 FusionError::PhysicsViolation("Missing VMEC key: a_minor".to_string())
251 })?,
252 kappa: kappa
253 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: kappa".to_string()))?,
254 triangularity: triangularity.ok_or_else(|| {
255 FusionError::PhysicsViolation("Missing VMEC key: triangularity".to_string())
256 })?,
257 nfp: nfp
258 .ok_or_else(|| FusionError::PhysicsViolation("Missing VMEC key: nfp".to_string()))?,
259 modes,
260 };
261 state.validate()?;
262 Ok(state)
263}
264
/// Resolution and iteration controls for `vmec_fixed_boundary_solve`.
#[derive(Debug, Clone)]
pub struct VmecSolverConfig {
    /// Highest poloidal mode number kept in the Fourier expansion.
    pub m_pol: usize,
    /// Highest |n| toroidal mode number; 0 selects the axisymmetric code path.
    pub n_tor: usize,
    /// Number of surfaces on the uniform `s` grid from 0 (axis) to 1 (boundary); `validate` requires >= 3.
    pub ns: usize,
    /// Poloidal sample points per surface; `validate` requires >= 8.
    pub ntheta: usize,
    /// Toroidal sample points per field period. The solver uses 1 when
    /// `n_tor == 0` and `max(nzeta, 4)` otherwise.
    pub nzeta: usize,
    /// Maximum number of relaxation iterations; `validate` requires >= 1.
    pub max_iter: usize,
    /// Convergence threshold on the RMS force residual; must be finite and > 0.
    pub tol: f64,
    /// Step size of the explicit coefficient update; must be finite and > 0.
    pub step_size: f64,
}
296
297impl Default for VmecSolverConfig {
298 fn default() -> Self {
299 Self {
300 m_pol: 6,
301 n_tor: 0,
302 ns: 25,
303 ntheta: 32,
304 nzeta: 1,
305 max_iter: 500,
306 tol: 1e-8,
307 step_size: 5e-3,
308 }
309 }
310}
311
312impl VmecSolverConfig {
313 pub fn validate(&self) -> FusionResult<()> {
314 if self.ns < 3 {
315 return Err(FusionError::PhysicsViolation(
316 "VMEC solver requires ns >= 3".into(),
317 ));
318 }
319 if self.ntheta < 8 {
320 return Err(FusionError::PhysicsViolation(
321 "VMEC solver requires ntheta >= 8".into(),
322 ));
323 }
324 if self.max_iter == 0 {
325 return Err(FusionError::PhysicsViolation(
326 "VMEC solver requires max_iter >= 1".into(),
327 ));
328 }
329 if !self.tol.is_finite() || self.tol <= 0.0 {
330 return Err(FusionError::PhysicsViolation(
331 "VMEC solver tol must be finite and > 0".into(),
332 ));
333 }
334 if !self.step_size.is_finite() || self.step_size <= 0.0 {
335 return Err(FusionError::PhysicsViolation(
336 "VMEC solver step_size must be finite and > 0".into(),
337 ));
338 }
339 Ok(())
340 }
341}
342
/// Result returned by `vmec_fixed_boundary_solve`.
#[derive(Debug, Clone)]
pub struct VmecEquilibrium {
    /// R cosine coefficients, shape `(ns_grid, vmec_n_modes(m_pol, n_tor))`;
    /// row 0 is the axis (`s = 0`), the last row is the fixed boundary (`s = 1`).
    pub rmnc: Array2<f64>,
    /// Z sine coefficients, same shape and row layout as `rmnc`.
    pub zmns: Array2<f64>,
    /// Rotational-transform profile passed to the solver (copied verbatim).
    pub iota: Array1<f64>,
    /// Pressure profile passed to the solver (copied verbatim).
    pub pressure: Array1<f64>,
    /// Toroidal-flux normalisation passed to the solver.
    pub phi_edge: f64,
    /// Volume estimate reported by the solver (Jacobian integrated over interior surfaces).
    pub volume: f64,
    /// Volume-averaged beta, `2 mu0 sum(p dV) / sum(|B|^2 dV)`, from the last evaluated iteration.
    pub beta_avg: f64,
    /// RMS of the heuristic radial force over interior grid points at the last iteration.
    pub force_residual: f64,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether `force_residual` fell below the configured tolerance.
    pub converged: bool,
    /// Number of radial surfaces (rows of `rmnc` / `zmns`).
    pub ns_grid: usize,
    /// Poloidal truncation defining the coefficient layout.
    pub m_pol: usize,
    /// Toroidal truncation defining the coefficient layout.
    pub n_tor: usize,
    /// Number of field periods taken from the boundary.
    pub nfp: usize,
}
372
/// Number of Fourier modes `(m, n)` with `0 <= m <= m_pol` and `|n| <= n_tor`.
///
/// For the axisymmetric case (`n_tor == 0`) this is simply `m_pol + 1`.
pub fn vmec_n_modes(m_pol: usize, n_tor: usize) -> usize {
    let toroidal_count = 2 * n_tor + 1;
    (m_pol + 1) * toroidal_count
}
381
/// Flat index of mode `(m, n)` in a coefficient row, or `None` when the mode
/// lies outside the `(m_pol, n_tor)` truncation.
///
/// Layout is m-major: `index = m * (2 * n_tor + 1) + (n + n_tor)`. With
/// `n_tor == 0` this reduces to `m` and only `n == 0` is accepted.
pub fn vmec_mode_idx(m: usize, n: i32, m_pol: usize, n_tor: usize) -> Option<usize> {
    let n_abs = n.unsigned_abs() as usize;
    if m > m_pol || n_abs > n_tor {
        return None;
    }
    // |n| <= n_tor, so the shifted toroidal index is non-negative.
    let toroidal_offset = (n + n_tor as i32) as usize;
    Some(m * (2 * n_tor + 1) + toroidal_offset)
}
400
/// Evaluates `(R, Z)` on one flux surface at angles `(theta, zeta)`.
///
/// `R = sum rmnc[k] cos(m theta - n nfp zeta)` and
/// `Z = sum zmns[k] sin(m theta - n nfp zeta)`, using the m-major mode layout
/// of `vmec_mode_idx`. Both slices must hold at least
/// `vmec_n_modes(m_pol, n_tor)` entries, otherwise indexing panics.
/// With `n_tor == 0`, `zeta` and `nfp` are not read.
fn eval_surface_point(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    if n_tor == 0 {
        return (0..=m_pol).fold((0.0, 0.0), |(r, z), m| {
            let (s, c) = (m as f64 * theta).sin_cos();
            (r + rmnc[m] * c, z + zmns[m] * s)
        });
    }

    let n_max = n_tor as i32;
    let stride = 2 * n_tor + 1;
    let mut acc = (0.0, 0.0);
    for m in 0..=m_pol {
        for nn in -n_max..=n_max {
            let k = m * stride + (nn + n_max) as usize;
            let phase = m as f64 * theta - nn as f64 * nfp as f64 * zeta;
            let (s, c) = phase.sin_cos();
            acc.0 += rmnc[k] * c;
            acc.1 += zmns[k] * s;
        }
    }
    acc
}
433
/// Evaluates `(dR/dtheta, dZ/dtheta)` on one flux surface at `(theta, zeta)`.
///
/// These are the analytic theta-derivatives of the series summed by
/// `eval_surface_point`:
/// `dR/dtheta = -sum m rmnc sin(phase)` and `dZ/dtheta = sum m zmns cos(phase)`,
/// with `phase = m theta - n nfp zeta`. Slice requirements are the same as
/// for `eval_surface_point`.
fn eval_surface_deriv_theta(
    rmnc: &[f64],
    zmns: &[f64],
    theta: f64,
    zeta: f64,
    m_pol: usize,
    n_tor: usize,
    nfp: usize,
) -> (f64, f64) {
    if n_tor == 0 {
        return (0..=m_pol).fold((0.0, 0.0), |(dr, dz), m| {
            let mf = m as f64;
            let (s, c) = (mf * theta).sin_cos();
            (dr - mf * rmnc[m] * s, dz + mf * zmns[m] * c)
        });
    }

    let n_max = n_tor as i32;
    let stride = 2 * n_tor + 1;
    let mut dr = 0.0;
    let mut dz = 0.0;
    for m in 0..=m_pol {
        let mf = m as f64;
        for nn in -n_max..=n_max {
            let k = m * stride + (nn + n_max) as usize;
            let (s, c) = (mf * theta - nn as f64 * nfp as f64 * zeta).sin_cos();
            dr -= mf * rmnc[k] * s;
            dz += mf * zmns[k] * c;
        }
    }
    (dr, dz)
}
468
469pub fn vmec_fixed_boundary_solve(
478 boundary: &VmecBoundaryState,
479 config: &VmecSolverConfig,
480 pressure: &[f64],
481 iota: &[f64],
482 phi_edge: f64,
483) -> FusionResult<VmecEquilibrium> {
484 boundary.validate()?;
485 config.validate()?;
486
487 let ns = config.ns;
488 let m_pol = config.m_pol;
489 let n_tor = config.n_tor;
490 let nfp = boundary.nfp;
491 let nmodes = vmec_n_modes(m_pol, n_tor);
492 let ntheta = config.ntheta;
493 let nzeta = if n_tor == 0 { 1 } else { config.nzeta.max(4) };
494
495 if pressure.len() != ns || iota.len() != ns {
496 return Err(FusionError::PhysicsViolation(format!(
497 "pressure/iota length must match ns={ns}, got p={} iota={}",
498 pressure.len(),
499 iota.len()
500 )));
501 }
502 if pressure.iter().any(|v| !v.is_finite()) || iota.iter().any(|v| !v.is_finite()) {
503 return Err(FusionError::PhysicsViolation(
504 "pressure/iota profiles must be finite".into(),
505 ));
506 }
507 if !phi_edge.is_finite() || phi_edge <= 0.0 {
508 return Err(FusionError::PhysicsViolation(
509 "phi_edge must be finite and > 0".into(),
510 ));
511 }
512
513 let mut bnd_rmnc = vec![0.0; nmodes];
515 let mut bnd_zmns = vec![0.0; nmodes];
516
517 if let Some(idx) = vmec_mode_idx(0, 0, m_pol, n_tor) {
518 bnd_rmnc[idx] = boundary.r_axis;
519 }
520 if let Some(idx) = vmec_mode_idx(1, 0, m_pol, n_tor) {
521 bnd_rmnc[idx] = boundary.a_minor;
522 bnd_zmns[idx] = boundary.a_minor * boundary.kappa;
523 }
524 if m_pol >= 2 {
525 if let Some(idx) = vmec_mode_idx(2, 0, m_pol, n_tor) {
526 bnd_rmnc[idx] = -boundary.triangularity * boundary.a_minor * 0.5;
527 }
528 }
529 for mode in &boundary.modes {
530 let m = mode.m as usize;
531 if let Some(idx) = vmec_mode_idx(m, mode.n, m_pol, n_tor) {
532 bnd_rmnc[idx] += mode.r_cos;
533 bnd_zmns[idx] += mode.z_sin;
534 }
535 }
536
537 let mut rmnc = Array2::zeros((ns, nmodes));
539 let mut zmns = Array2::zeros((ns, nmodes));
540 let s_grid: Vec<f64> = (0..ns).map(|i| i as f64 / (ns - 1) as f64).collect();
541
542 for js in 0..ns {
543 let s = s_grid[js];
544 for k in 0..nmodes {
545 let axis_r = if k == vmec_mode_idx(0, 0, m_pol, n_tor).unwrap_or(usize::MAX) {
546 boundary.r_axis
547 } else {
548 0.0
549 };
550 rmnc[[js, k]] = axis_r * (1.0 - s) + bnd_rmnc[k] * s;
551 zmns[[js, k]] = bnd_zmns[k] * s;
552 }
553 }
554
555 let pi = std::f64::consts::PI;
556 let mu0 = 4.0e-7 * pi;
557 let ds = 1.0 / (ns - 1) as f64;
558 let dtheta = 2.0 * pi / ntheta as f64;
559 let dzeta = 2.0 * pi / (nzeta as f64 * nfp.max(1) as f64);
560 let n_angle_pts = (ntheta * nzeta) as f64;
561
562 let theta_arr: Vec<f64> = (0..ntheta)
564 .map(|i| 2.0 * pi * i as f64 / ntheta as f64)
565 .collect();
566 let zeta_arr: Vec<f64> = (0..nzeta)
567 .map(|i| 2.0 * pi * i as f64 / (nzeta as f64 * nfp.max(1) as f64))
568 .collect();
569
570 let mut global_force = f64::MAX;
572 let mut converged = false;
573 let mut iter_count = 0;
574 let mut total_volume = 0.0;
575 let mut total_beta_num = 0.0;
576 let mut total_b2_vol = 0.0;
577
578 for iteration in 0..config.max_iter {
579 let mut force_sq_sum = 0.0;
580 total_volume = 0.0;
581 total_beta_num = 0.0;
582 total_b2_vol = 0.0;
583
584 let mut grad_r: Array2<f64> = Array2::zeros((ns, nmodes));
585 let mut grad_z: Array2<f64> = Array2::zeros((ns, nmodes));
586
587 for js in 1..(ns - 1) {
589 let p = pressure[js];
590 let iota_s = iota[js];
591 let dp_ds =
592 (pressure[(js + 1).min(ns - 1)] - pressure[js.saturating_sub(1)]) / (2.0 * ds);
593
594 let rmnc_s: Vec<f64> = (0..nmodes).map(|k| rmnc[[js, k]]).collect();
595 let zmns_s: Vec<f64> = (0..nmodes).map(|k| zmns[[js, k]]).collect();
596 let rmnc_p: Vec<f64> = (0..nmodes).map(|k| rmnc[[js + 1, k]]).collect();
597 let zmns_p: Vec<f64> = (0..nmodes).map(|k| zmns[[js + 1, k]]).collect();
598 let rmnc_m: Vec<f64> = (0..nmodes).map(|k| rmnc[[js - 1, k]]).collect();
599 let zmns_m: Vec<f64> = (0..nmodes).map(|k| zmns[[js - 1, k]]).collect();
600
601 for &theta in theta_arr.iter().take(ntheta) {
602 for &zeta in zeta_arr.iter().take(nzeta) {
603 let (r_val, _z_val) =
604 eval_surface_point(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
605 let (r_p, _) =
606 eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
607 let (r_m, _) =
608 eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
609 let (_, z_p) =
610 eval_surface_point(&rmnc_p, &zmns_p, theta, zeta, m_pol, n_tor, nfp);
611 let (_, z_m) =
612 eval_surface_point(&rmnc_m, &zmns_m, theta, zeta, m_pol, n_tor, nfp);
613
614 let r_s = (r_p - r_m) / (2.0 * ds);
615 let z_s = (z_p - z_m) / (2.0 * ds);
616
617 let (r_theta, z_theta) =
618 eval_surface_deriv_theta(&rmnc_s, &zmns_s, theta, zeta, m_pol, n_tor, nfp);
619
620 let jac = r_val * (r_s * z_theta - r_theta * z_s);
622 let jac_abs = jac.abs().max(1e-20);
623
624 let b_zeta = phi_edge / (2.0 * pi * jac_abs);
626 let b_theta = iota_s * b_zeta;
627
628 let g_tt = r_theta * r_theta + z_theta * z_theta;
630 let g_zz = r_val * r_val;
631 let b_sq = b_theta * b_theta * g_tt + b_zeta * b_zeta * g_zz;
632
633 let force_s = dp_ds + b_sq / (2.0 * mu0 * jac_abs.max(1e-10));
635 force_sq_sum += force_s * force_s;
636
637 let dvol = jac_abs * dtheta * dzeta * ds;
638 total_volume += dvol;
639 total_beta_num += p * dvol;
640 total_b2_vol += b_sq * dvol;
641
642 if n_tor == 0 {
644 for m in 0..=m_pol {
645 let angle = m as f64 * theta;
646 let (sin_a, cos_a) = angle.sin_cos();
647 grad_r[[js, m]] += force_s * cos_a * dtheta * dzeta;
648 grad_z[[js, m]] += force_s * sin_a * dtheta * dzeta;
649 }
650 } else {
651 for m in 0..=m_pol {
652 for nn in -(n_tor as i32)..=(n_tor as i32) {
653 let idx = m * (2 * n_tor + 1) + (nn + n_tor as i32) as usize;
654 let angle = m as f64 * theta - nn as f64 * nfp as f64 * zeta;
655 let (sin_a, cos_a) = angle.sin_cos();
656 grad_r[[js, idx]] += force_s * cos_a * dtheta * dzeta;
657 grad_z[[js, idx]] += force_s * sin_a * dtheta * dzeta;
658 }
659 }
660 }
661 }
662 }
663 }
664
665 let n_interior = ((ns - 2) as f64 * n_angle_pts).max(1.0);
666 global_force = (force_sq_sum / n_interior).sqrt();
667
668 for js in 1..(ns - 1) {
670 for k in 0..nmodes {
671 rmnc[[js, k]] -= config.step_size * grad_r[[js, k]] / n_angle_pts.max(1.0);
672 zmns[[js, k]] -= config.step_size * grad_z[[js, k]] / n_angle_pts.max(1.0);
673 }
674 }
675
676 iter_count = iteration + 1;
677 if global_force < config.tol {
678 converged = true;
679 break;
680 }
681 if !global_force.is_finite() {
682 return Err(FusionError::PhysicsViolation(format!(
683 "VMEC solver diverged at iteration {}: force={}",
684 iteration, global_force
685 )));
686 }
687 }
688
689 let beta_avg = if total_b2_vol > 0.0 {
690 2.0 * mu0 * total_beta_num / total_b2_vol
691 } else {
692 0.0
693 };
694
695 Ok(VmecEquilibrium {
696 rmnc,
697 zmns,
698 iota: Array1::from_vec(iota.to_vec()),
699 pressure: Array1::from_vec(pressure.to_vec()),
700 phi_edge,
701 volume: total_volume.abs(),
702 beta_avg,
703 force_residual: global_force,
704 iterations: iter_count,
705 converged,
706 ns_grid: ns,
707 m_pol,
708 n_tor,
709 nfp,
710 })
711}
712
713pub fn vmec_eval_geometry(
715 eq: &VmecEquilibrium,
716 s_idx: usize,
717 theta: f64,
718 zeta: f64,
719) -> FusionResult<(f64, f64)> {
720 if s_idx >= eq.ns_grid {
721 return Err(FusionError::PhysicsViolation(format!(
722 "Surface index {} >= ns_grid {}",
723 s_idx, eq.ns_grid
724 )));
725 }
726 let nmodes = vmec_n_modes(eq.m_pol, eq.n_tor);
727 let rmnc_row: Vec<f64> = (0..nmodes).map(|k| eq.rmnc[[s_idx, k]]).collect();
728 let zmns_row: Vec<f64> = (0..nmodes).map(|k| eq.zmns[[s_idx, k]]).collect();
729 Ok(eval_surface_point(
730 &rmnc_row, &zmns_row, theta, zeta, eq.m_pol, eq.n_tor, eq.nfp,
731 ))
732}
733
#[cfg(test)]
mod tests {
    //! Unit tests for the VMEC-like text format, Fourier helpers and solver.
    use super::*;

    fn sample_state() -> VmecBoundaryState {
        VmecBoundaryState {
            r_axis: 6.2,
            z_axis: 0.0,
            a_minor: 2.0,
            kappa: 1.7,
            triangularity: 0.33,
            nfp: 1,
            modes: vec![
                VmecFourierMode {
                    m: 1,
                    n: 1,
                    r_cos: 0.03,
                    r_sin: -0.01,
                    z_cos: 0.0,
                    z_sin: 0.02,
                },
                VmecFourierMode {
                    m: 2,
                    n: 1,
                    r_cos: 0.015,
                    r_sin: 0.0,
                    z_cos: 0.0,
                    z_sin: 0.008,
                },
            ],
        }
    }

    #[test]
    fn test_vmec_text_roundtrip() {
        let state = sample_state();
        let text = export_vmec_like_text(&state).expect("export must succeed");
        let parsed = import_vmec_like_text(&text).expect("import must succeed");
        // `{:.16e}` keeps 17 significant digits, so every field must survive exactly.
        assert_eq!(parsed, state);
    }

    #[test]
    fn test_vmec_missing_key_errors() {
        let text = "format=vmec_like_v1\nr_axis=6.2\n";
        let err = import_vmec_like_text(text).expect_err("missing keys should error");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Missing VMEC key")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_invalid_minor_radius_errors() {
        let text = "\
format=vmec_like_v1
r_axis=6.2
z_axis=0.0
a_minor=0.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err = import_vmec_like_text(text).expect_err("a_minor=0 must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("a_minor")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_rejects_invalid_modes_and_duplicates() {
        let mut state = sample_state();
        state.modes[0].r_cos = f64::NAN;
        let err = export_vmec_like_text(&state).expect_err("non-finite mode coeff must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("mode")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let dup_text = "\
format=vmec_like_v1
r_axis=6.2
r_axis=6.3
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err = import_vmec_like_text(dup_text).expect_err("duplicate keys must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Duplicate VMEC key")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_rejects_duplicate_modes_and_bad_format() {
        let mut state = sample_state();
        state.modes.push(VmecFourierMode {
            m: state.modes[0].m,
            n: state.modes[0].n,
            r_cos: 0.001,
            r_sin: 0.0,
            z_cos: 0.0,
            z_sin: 0.001,
        });
        let err = export_vmec_like_text(&state).expect_err("duplicate mode index must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("duplicate mode")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let bad_format = "\
format=vmec_like_v2
r_axis=6.2
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err = import_vmec_like_text(bad_format).expect_err("unsupported format must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("Unsupported VMEC format")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_vmec_rejects_malformed_mode_and_duplicate_format() {
        let malformed_mode = "\
format=vmec_like_v1
r_axis=6.2
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
mode,1,1,0.1,0.0,0.2
";
        let err =
            import_vmec_like_text(malformed_mode).expect_err("mode with wrong arity must fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("exactly 6 columns")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let duplicate_format = "\
format=vmec_like_v1
format=vmec_like_v1
r_axis=6.2
z_axis=0.0
a_minor=2.0
kappa=1.7
triangularity=0.2
nfp=1
";
        let err =
            import_vmec_like_text(duplicate_format).expect_err("duplicate format key must fail");
        match err {
            FusionError::PhysicsViolation(msg) => {
                assert!(msg.contains("Duplicate VMEC key: format"))
            }
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    fn iter_like_boundary() -> VmecBoundaryState {
        VmecBoundaryState {
            r_axis: 6.2,
            z_axis: 0.0,
            a_minor: 2.0,
            kappa: 1.7,
            triangularity: 0.33,
            nfp: 1,
            modes: vec![],
        }
    }

    #[test]
    fn test_vmec_mode_indexing() {
        assert_eq!(vmec_n_modes(6, 0), 7);
        assert_eq!(vmec_n_modes(3, 2), 4 * 5);
        assert_eq!(vmec_mode_idx(0, 0, 6, 0), Some(0));
        assert_eq!(vmec_mode_idx(3, 0, 6, 0), Some(3));
        assert_eq!(vmec_mode_idx(0, 1, 6, 0), None);
        assert_eq!(vmec_mode_idx(7, 0, 6, 0), None);
        assert_eq!(vmec_mode_idx(1, -1, 3, 2), Some(6));
    }

    #[test]
    fn test_vmec_eval_surface_axis_is_circle() {
        let m_pol = 4;
        let n_tor = 0;
        let nmodes = vmec_n_modes(m_pol, n_tor);
        let mut rmnc = vec![0.0; nmodes];
        let mut zmns = vec![0.0; nmodes];
        rmnc[0] = 6.2;
        rmnc[1] = 2.0;
        zmns[1] = 2.0;

        let (r, z) = eval_surface_point(&rmnc, &zmns, 0.0, 0.0, m_pol, n_tor, 1);
        assert!((r - 8.2).abs() < 1e-12, "outboard midplane R");
        assert!(z.abs() < 1e-12, "midplane Z");

        let (r2, z2) = eval_surface_point(
            &rmnc,
            &zmns,
            std::f64::consts::FRAC_PI_2,
            0.0,
            m_pol,
            n_tor,
            1,
        );
        assert!((r2 - 6.2).abs() < 1e-12, "top R = R_axis");
        assert!((z2 - 2.0).abs() < 1e-12, "top Z = a_minor");
    }

    #[test]
    fn test_vmec_solver_runs_axisymmetric() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig {
            m_pol: 4,
            n_tor: 0,
            ns: 11,
            ntheta: 16,
            nzeta: 1,
            max_iter: 50,
            tol: 1e-6,
            step_size: 1e-3,
        };
        let ns = config.ns;
        let pressure: Vec<f64> = (0..ns)
            .map(|i| {
                let s = i as f64 / (ns - 1) as f64;
                1e5 * (1.0 - s * s)
            })
            .collect();
        let iota: Vec<f64> = (0..ns)
            .map(|i| {
                let s = i as f64 / (ns - 1) as f64;
                0.3 + 0.7 * s * s
            })
            .collect();

        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0)
            .expect("VMEC solve should succeed");

        assert!(eq.iterations > 0);
        assert!(eq.force_residual.is_finite());
        assert!(eq.volume > 0.0, "Volume must be positive");
        assert_eq!(eq.ns_grid, ns);
        assert_eq!(eq.m_pol, 4);
        assert_eq!(eq.nfp, 1);
    }

    #[test]
    fn test_vmec_solver_boundary_preserved() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig {
            m_pol: 4,
            n_tor: 0,
            ns: 11,
            ntheta: 16,
            nzeta: 1,
            max_iter: 10,
            tol: 1e-12,
            step_size: 1e-3,
        };
        let ns = config.ns;
        let pressure: Vec<f64> = (0..ns).map(|_| 1e4).collect();
        let iota: Vec<f64> = (0..ns).map(|_| 0.5).collect();

        let eq = vmec_fixed_boundary_solve(&boundary, &config, &pressure, &iota, 1.0).unwrap();

        let r10_bnd = eq.rmnc[[ns - 1, 1]];
        assert!(
            (r10_bnd - 2.0).abs() < 1e-10,
            "Boundary R_10 must be preserved: got {r10_bnd}"
        );
    }

    #[test]
    fn test_vmec_solver_rejects_invalid_inputs() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig::default();
        let ns = config.ns;
        let p: Vec<f64> = vec![0.0; ns];
        let iota: Vec<f64> = vec![0.5; ns];

        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, -1.0).is_err());
        assert!(vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, f64::NAN).is_err());

        let short_p: Vec<f64> = vec![0.0; 3];
        assert!(vmec_fixed_boundary_solve(&boundary, &config, &short_p, &iota, 1.0).is_err());

        let bad_cfg = VmecSolverConfig {
            ns: 2,
            ..Default::default()
        };
        assert!(bad_cfg.validate().is_err());
    }

    #[test]
    fn test_vmec_eval_geometry_valid() {
        let boundary = iter_like_boundary();
        let config = VmecSolverConfig {
            m_pol: 4,
            n_tor: 0,
            ns: 7,
            ntheta: 16,
            nzeta: 1,
            max_iter: 5,
            tol: 1e-12,
            step_size: 1e-3,
        };
        let ns = config.ns;
        let p: Vec<f64> = (0..ns)
            .map(|i| 1e4 * (1.0 - (i as f64 / (ns - 1) as f64)))
            .collect();
        let iota: Vec<f64> = vec![0.5; ns];

        let eq = vmec_fixed_boundary_solve(&boundary, &config, &p, &iota, 1.0).unwrap();
        let (r, z) = vmec_eval_geometry(&eq, ns - 1, 0.0, 0.0).unwrap();
        assert!(r > 0.0 && r.is_finite());
        assert!(z.is_finite());

        assert!(vmec_eval_geometry(&eq, ns + 5, 0.0, 0.0).is_err());
    }

    #[test]
    fn test_vmec_deriv_theta_consistency() {
        let m_pol = 3;
        let n_tor = 0;
        let rmnc: Vec<f64> = vec![6.2, 2.0, 0.1, 0.05];
        let zmns: Vec<f64> = vec![0.0, 1.8, 0.08, 0.03];
        let theta = 1.0;
        let eps = 1e-7;

        let (r_p, z_p) = eval_surface_point(&rmnc, &zmns, theta + eps, 0.0, m_pol, n_tor, 1);
        let (r_m, z_m) = eval_surface_point(&rmnc, &zmns, theta - eps, 0.0, m_pol, n_tor, 1);
        let dr_num = (r_p - r_m) / (2.0 * eps);
        let dz_num = (z_p - z_m) / (2.0 * eps);

        let (dr_ana, dz_ana) = eval_surface_deriv_theta(&rmnc, &zmns, theta, 0.0, m_pol, n_tor, 1);

        assert!(
            (dr_ana - dr_num).abs() < 1e-5,
            "dR/dθ mismatch: analytic={dr_ana}, numerical={dr_num}"
        );
        assert!(
            (dz_ana - dz_num).abs() < 1e-5,
            "dZ/dθ mismatch: analytic={dz_ana}, numerical={dz_num}"
        );
    }
}