1use fusion_math::interp::gradient_2d;
15use fusion_types::error::{FusionError, FusionResult};
16use fusion_types::state::Grid2D;
17use ndarray::Array2;
18
19const R_SAFE_MIN: f64 = 1e-6;
21
22pub fn compute_b_field(
34 psi: &Array2<f64>,
35 grid: &Grid2D,
36) -> FusionResult<(Array2<f64>, Array2<f64>)> {
37 if grid.nz < 2 || grid.nr < 2 {
38 return Err(FusionError::ConfigError(format!(
39 "b-field grid requires nz,nr >= 2, got nz={} nr={}",
40 grid.nz, grid.nr
41 )));
42 }
43 if !grid.dr.is_finite()
44 || !grid.dz.is_finite()
45 || grid.dr.abs() <= f64::EPSILON
46 || grid.dz.abs() <= f64::EPSILON
47 {
48 return Err(FusionError::ConfigError(format!(
49 "b-field grid spacing must be finite and non-zero, got dr={} dz={}",
50 grid.dr, grid.dz
51 )));
52 }
53 if psi.nrows() != grid.nz || psi.ncols() != grid.nr {
54 return Err(FusionError::ConfigError(format!(
55 "b-field psi shape mismatch: expected ({}, {}), got ({}, {})",
56 grid.nz,
57 grid.nr,
58 psi.nrows(),
59 psi.ncols()
60 )));
61 }
62 if grid.rr.nrows() != grid.nz || grid.rr.ncols() != grid.nr {
63 return Err(FusionError::ConfigError(format!(
64 "b-field grid.rr shape mismatch: expected ({}, {}), got ({}, {})",
65 grid.nz,
66 grid.nr,
67 grid.rr.nrows(),
68 grid.rr.ncols()
69 )));
70 }
71 if grid.zz.nrows() != grid.nz || grid.zz.ncols() != grid.nr {
72 return Err(FusionError::ConfigError(format!(
73 "b-field grid.zz shape mismatch: expected ({}, {}), got ({}, {})",
74 grid.nz,
75 grid.nr,
76 grid.zz.nrows(),
77 grid.zz.ncols()
78 )));
79 }
80 if psi.iter().any(|v| !v.is_finite()) {
81 return Err(FusionError::ConfigError(
82 "b-field psi contains non-finite values".to_string(),
83 ));
84 }
85
86 let (dpsi_dz, dpsi_dr) = gradient_2d(psi, grid);
87
88 let nz = grid.nz;
89 let nr = grid.nr;
90 let mut b_r = Array2::zeros((nz, nr));
91 let mut b_z = Array2::zeros((nz, nr));
92
93 for iz in 0..nz {
94 for ir in 0..nr {
95 let r = grid.rr[[iz, ir]];
96 if !r.is_finite() || r <= 0.0 {
97 return Err(FusionError::ConfigError(format!(
98 "b-field radius must be finite and > 0 at ({iz}, {ir}), got {r}"
99 )));
100 }
101 let inv_r = 1.0 / r.max(R_SAFE_MIN);
102 let br = -inv_r * dpsi_dz[[iz, ir]];
103 let bz = inv_r * dpsi_dr[[iz, ir]];
104 if !br.is_finite() || !bz.is_finite() {
105 return Err(FusionError::ConfigError(format!(
106 "b-field output became non-finite at ({iz}, {ir})"
107 )));
108 }
109 b_r[[iz, ir]] = br;
110 b_z[[iz, ir]] = bz;
111 }
112 }
113
114 Ok((b_r, b_z))
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// The 16x16 grid shared by the small tests.
    fn small_grid() -> Grid2D {
        Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0)
    }

    /// Asserts that `compute_b_field` rejects the inputs with a `ConfigError`.
    fn assert_config_error(psi: &Array2<f64>, grid: &Grid2D, why: &str) {
        let err = compute_b_field(psi, grid).expect_err(why);
        assert!(matches!(err, FusionError::ConfigError(_)), "{why}");
    }

    #[test]
    fn test_b_field_uniform_psi() {
        let grid = small_grid();
        let psi = Array2::from_elem((16, 16), 1.0);
        let (b_r, b_z) = compute_b_field(&psi, &grid).expect("valid b-field inputs");

        for iz in 0..16 {
            for ir in 0..16 {
                assert!(
                    b_r[[iz, ir]].abs() < 1e-10,
                    "B_R should be zero for uniform Ψ"
                );
                assert!(
                    b_z[[iz, ir]].abs() < 1e-10,
                    "B_Z should be zero for uniform Ψ"
                );
            }
        }
    }

    #[test]
    fn test_b_field_no_nan() {
        let grid = Grid2D::new(33, 33, 1.0, 9.0, -5.0, 5.0);
        let psi = Array2::from_shape_fn((33, 33), |(iz, ir)| {
            let r = grid.rr[[iz, ir]];
            let z = grid.zz[[iz, ir]];
            (-(((r - 5.0).powi(2) + z.powi(2)) / 4.0)).exp()
        });
        let (b_r, b_z) = compute_b_field(&psi, &grid).expect("valid b-field inputs");
        assert!(!b_r.iter().any(|v| v.is_nan()), "B_R contains NaN");
        assert!(!b_z.iter().any(|v| v.is_nan()), "B_Z contains NaN");
    }

    #[test]
    fn test_b_field_linear_psi_matches_analytic() {
        let grid = small_grid();

        // ψ = Z ⇒ dψ/dZ = 1, dψ/dR = 0 ⇒ B_R = -1/R, B_Z = 0. Finite differences
        // of a linear function are exact, so this pins the sign convention.
        let psi_z = Array2::from_shape_fn((16, 16), |(iz, ir)| grid.zz[[iz, ir]]);
        let (b_r, b_z) = compute_b_field(&psi_z, &grid).expect("valid b-field inputs");
        for ((iz, ir), &r) in grid.rr.indexed_iter() {
            assert!(
                (b_r[[iz, ir]] + 1.0 / r).abs() < 1e-9,
                "B_R should equal -1/R for Ψ = Z at ({iz}, {ir})"
            );
            assert!(
                b_z[[iz, ir]].abs() < 1e-9,
                "B_Z should be zero for Ψ = Z at ({iz}, {ir})"
            );
        }

        // ψ = R ⇒ dψ/dZ = 0, dψ/dR = 1 ⇒ B_R = 0, B_Z = 1/R.
        let psi_r = Array2::from_shape_fn((16, 16), |(iz, ir)| grid.rr[[iz, ir]]);
        let (b_r, b_z) = compute_b_field(&psi_r, &grid).expect("valid b-field inputs");
        for ((iz, ir), &r) in grid.rr.indexed_iter() {
            assert!(
                b_r[[iz, ir]].abs() < 1e-9,
                "B_R should be zero for Ψ = R at ({iz}, {ir})"
            );
            assert!(
                (b_z[[iz, ir]] - 1.0 / r).abs() < 1e-9,
                "B_Z should equal 1/R for Ψ = R at ({iz}, {ir})"
            );
        }
    }

    #[test]
    fn test_b_field_rejects_invalid_runtime_inputs() {
        let psi = Array2::from_elem((16, 16), 1.0);

        let mut grid = small_grid();
        grid.rr[[0, 0]] = 0.0;
        assert_config_error(&psi, &grid, "non-positive radius must fail");

        let mut grid = small_grid();
        grid.rr[[3, 5]] = f64::NAN;
        assert_config_error(&psi, &grid, "non-finite radius must fail");

        let bad_shape = Array2::from_elem((15, 16), 0.0);
        assert_config_error(&bad_shape, &small_grid(), "shape mismatch must fail");
    }

    #[test]
    fn test_b_field_rejects_non_finite_psi() {
        let grid = small_grid();

        let mut psi = Array2::from_elem((16, 16), 1.0);
        psi[[4, 7]] = f64::INFINITY;
        assert_config_error(&psi, &grid, "infinite psi must fail");

        let mut psi = Array2::from_elem((16, 16), 1.0);
        psi[[0, 0]] = f64::NAN;
        assert_config_error(&psi, &grid, "NaN psi must fail");
    }

    #[test]
    fn test_b_field_rejects_degenerate_grid() {
        let psi = Array2::from_elem((16, 16), 1.0);

        let mut grid = small_grid();
        grid.dr = 0.0;
        assert_config_error(&psi, &grid, "zero radial spacing must fail");

        let mut grid = small_grid();
        grid.dz = f64::NAN;
        assert_config_error(&psi, &grid, "non-finite vertical spacing must fail");

        let mut grid = small_grid();
        grid.nz = 1;
        assert_config_error(&psi, &grid, "grid with fewer than 2 rows must fail");
    }
}