1use fusion_math::interp::gradient_2d;
16use fusion_types::error::{FusionError, FusionResult};
17use fusion_types::state::Grid2D;
18use ndarray::Array2;
19
/// Lower bound applied to the radius `R` before its reciprocal is taken in
/// `compute_b_field`, so that `1/R` stays bounded for radii below this value.
const R_SAFE_MIN: f64 = 1.0e-6;
22
23pub fn compute_b_field(
35 psi: &Array2<f64>,
36 grid: &Grid2D,
37) -> FusionResult<(Array2<f64>, Array2<f64>)> {
38 if grid.nz < 2 || grid.nr < 2 {
39 return Err(FusionError::ConfigError(format!(
40 "b-field grid requires nz,nr >= 2, got nz={} nr={}",
41 grid.nz, grid.nr
42 )));
43 }
44 if !grid.dr.is_finite()
45 || !grid.dz.is_finite()
46 || grid.dr.abs() <= f64::EPSILON
47 || grid.dz.abs() <= f64::EPSILON
48 {
49 return Err(FusionError::ConfigError(format!(
50 "b-field grid spacing must be finite and non-zero, got dr={} dz={}",
51 grid.dr, grid.dz
52 )));
53 }
54 if psi.nrows() != grid.nz || psi.ncols() != grid.nr {
55 return Err(FusionError::ConfigError(format!(
56 "b-field psi shape mismatch: expected ({}, {}), got ({}, {})",
57 grid.nz,
58 grid.nr,
59 psi.nrows(),
60 psi.ncols()
61 )));
62 }
63 if grid.rr.nrows() != grid.nz || grid.rr.ncols() != grid.nr {
64 return Err(FusionError::ConfigError(format!(
65 "b-field grid.rr shape mismatch: expected ({}, {}), got ({}, {})",
66 grid.nz,
67 grid.nr,
68 grid.rr.nrows(),
69 grid.rr.ncols()
70 )));
71 }
72 if grid.zz.nrows() != grid.nz || grid.zz.ncols() != grid.nr {
73 return Err(FusionError::ConfigError(format!(
74 "b-field grid.zz shape mismatch: expected ({}, {}), got ({}, {})",
75 grid.nz,
76 grid.nr,
77 grid.zz.nrows(),
78 grid.zz.ncols()
79 )));
80 }
81 if psi.iter().any(|v| !v.is_finite()) {
82 return Err(FusionError::ConfigError(
83 "b-field psi contains non-finite values".to_string(),
84 ));
85 }
86
87 let (dpsi_dz, dpsi_dr) = gradient_2d(psi, grid);
88
89 let nz = grid.nz;
90 let nr = grid.nr;
91 let mut b_r = Array2::zeros((nz, nr));
92 let mut b_z = Array2::zeros((nz, nr));
93
94 for iz in 0..nz {
95 for ir in 0..nr {
96 let r = grid.rr[[iz, ir]];
97 if !r.is_finite() || r <= 0.0 {
98 return Err(FusionError::ConfigError(format!(
99 "b-field radius must be finite and > 0 at ({iz}, {ir}), got {r}"
100 )));
101 }
102 let inv_r = 1.0 / r.max(R_SAFE_MIN);
103 let br = -inv_r * dpsi_dz[[iz, ir]];
104 let bz = inv_r * dpsi_dr[[iz, ir]];
105 if !br.is_finite() || !bz.is_finite() {
106 return Err(FusionError::ConfigError(format!(
107 "b-field output became non-finite at ({iz}, {ir})"
108 )));
109 }
110 b_r[[iz, ir]] = br;
111 b_z[[iz, ir]] = bz;
112 }
113 }
114
115 Ok((b_r, b_z))
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 16x16 grid shared by most tests in this module.
    fn default_grid() -> Grid2D {
        Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0)
    }

    /// Asserts that `result` is a `ConfigError` whose message contains `needle`,
    /// so each rejection test pins the validation branch that actually fired.
    fn assert_config_error(result: FusionResult<(Array2<f64>, Array2<f64>)>, needle: &str) {
        let err = result.expect_err("invalid b-field inputs must be rejected");
        assert!(
            matches!(&err, FusionError::ConfigError(msg) if msg.contains(needle)),
            "expected ConfigError containing {needle:?}, got {err:?}"
        );
    }

    #[test]
    fn test_b_field_uniform_psi() {
        let grid = default_grid();
        let psi = Array2::from_elem((16, 16), 1.0);
        let (b_r, b_z) = compute_b_field(&psi, &grid).expect("valid b-field inputs");

        for iz in 0..16 {
            for ir in 0..16 {
                assert!(
                    b_r[[iz, ir]].abs() < 1e-10,
                    "B_R should be zero for uniform Ψ"
                );
                assert!(
                    b_z[[iz, ir]].abs() < 1e-10,
                    "B_Z should be zero for uniform Ψ"
                );
            }
        }
    }

    #[test]
    fn test_b_field_no_nan() {
        let grid = Grid2D::new(33, 33, 1.0, 9.0, -5.0, 5.0);
        // Gaussian flux bump centred at (R, Z) = (5, 0).
        let psi = Array2::from_shape_fn((33, 33), |(iz, ir)| {
            let r = grid.rr[[iz, ir]];
            let z = grid.zz[[iz, ir]];
            (-(((r - 5.0).powi(2) + z.powi(2)) / 4.0)).exp()
        });
        let (b_r, b_z) = compute_b_field(&psi, &grid).expect("valid b-field inputs");
        // A successful call must be fully finite (this also rules out NaN).
        assert!(b_r.iter().all(|v| v.is_finite()), "B_R contains non-finite values");
        assert!(b_z.iter().all(|v| v.is_finite()), "B_Z contains non-finite values");
    }

    #[test]
    fn test_b_field_linear_psi_in_z() {
        // Ψ = Z gives dΨ/dZ = 1 and dΨ/dR = 0, which a consistent finite-difference
        // gradient reproduces up to rounding, so B_R = -1/R and B_Z = 0.
        let grid = default_grid();
        let psi = grid.zz.clone();
        let (b_r, b_z) = compute_b_field(&psi, &grid).expect("valid b-field inputs");

        for ((iz, ir), &r) in grid.rr.indexed_iter() {
            let expected_br = -1.0 / r;
            let got_br = b_r[[iz, ir]];
            let got_bz = b_z[[iz, ir]];
            assert!(
                (got_br - expected_br).abs() < 1e-9,
                "B_R at ({iz}, {ir}): got {got_br}, expected {expected_br}"
            );
            assert!(
                got_bz.abs() < 1e-9,
                "B_Z at ({iz}, {ir}) should vanish, got {got_bz}"
            );
        }
    }

    #[test]
    fn test_b_field_rejects_invalid_runtime_inputs() {
        let mut grid = default_grid();
        let psi = Array2::from_elem((16, 16), 1.0);

        grid.rr[[0, 0]] = 0.0;
        assert_config_error(
            compute_b_field(&psi, &grid),
            "radius must be finite and > 0",
        );

        let bad_shape = Array2::from_elem((15, 16), 0.0);
        assert_config_error(
            compute_b_field(&bad_shape, &default_grid()),
            "psi shape mismatch",
        );
    }

    #[test]
    fn test_b_field_rejects_bad_radius_values() {
        let psi = Array2::from_elem((16, 16), 1.0);
        // Corrupt the first visited cell so the radius check fires before any
        // other per-point check.
        for bad in [f64::NAN, f64::INFINITY, -1.0] {
            let mut grid = default_grid();
            grid.rr[[0, 0]] = bad;
            assert_config_error(
                compute_b_field(&psi, &grid),
                "radius must be finite and > 0",
            );
        }
    }

    #[test]
    fn test_b_field_rejects_non_finite_psi() {
        let grid = default_grid();
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let mut psi = Array2::from_elem((16, 16), 1.0);
            psi[[3, 4]] = bad;
            assert_config_error(
                compute_b_field(&psi, &grid),
                "psi contains non-finite values",
            );
        }
    }

    #[test]
    fn test_b_field_rejects_bad_spacing() {
        let psi = Array2::from_elem((16, 16), 1.0);

        let mut grid = default_grid();
        grid.dr = 0.0;
        assert_config_error(
            compute_b_field(&psi, &grid),
            "grid spacing must be finite and non-zero",
        );

        let mut grid = default_grid();
        grid.dz = f64::NAN;
        assert_config_error(
            compute_b_field(&psi, &grid),
            "grid spacing must be finite and non-zero",
        );
    }

    #[test]
    fn test_b_field_rejects_too_small_grid() {
        let psi = Array2::from_elem((16, 16), 1.0);
        let mut grid = default_grid();
        grid.nz = 1;
        assert_config_error(
            compute_b_field(&psi, &grid),
            "grid requires nz,nr >= 2",
        );
    }

    #[test]
    fn test_b_field_rejects_coordinate_shape_mismatch() {
        let psi = Array2::from_elem((16, 16), 1.0);

        let mut grid = default_grid();
        grid.rr = Array2::zeros((16, 15));
        assert_config_error(compute_b_field(&psi, &grid), "grid.rr shape mismatch");

        let mut grid = default_grid();
        grid.zz = Array2::zeros((15, 16));
        assert_config_error(compute_b_field(&psi, &grid), "grid.zz shape mismatch");
    }
}