1use fusion_math::amr::{estimate_error_field, AmrHierarchy};
11use fusion_math::sor::sor_solve;
12use fusion_types::error::{FusionError, FusionResult};
13use fusion_types::state::Grid2D;
14use ndarray::Array2;
15
/// Tuning parameters for [`AmrKernelSolver`].
///
/// Values are checked by `AmrKernelSolver::new`; constructing this struct directly performs
/// no validation.
#[derive(Clone, Debug)]
pub struct AmrKernelConfig {
    /// Maximum number of refinement levels handed to `AmrHierarchy::new` (must be >= 1).
    pub max_levels: usize,
    /// Refinement threshold handed to `AmrHierarchy::new` (must be >= 0; `f64::INFINITY` is
    /// used by this module's tests to obtain a hierarchy with no patches).
    pub refinement_threshold: f64,
    /// SOR relaxation factor shared by the coarse solve and every patch solve (finite, > 0).
    pub omega: f64,
    /// Number of SOR iterations on the base grid (>= 1).
    pub coarse_iters: usize,
    /// Number of SOR iterations on each refined patch (>= 1).
    pub patch_iters: usize,
    /// Weight of the patch value when folding patches back onto coarse nodes, in `[0, 1]`:
    /// the result is `(1 - blend) * coarse + blend * patch`.
    pub blend: f64,
}

impl Default for AmrKernelConfig {
    /// Two levels, threshold 0.1, omega 1.8, 400 coarse / 300 patch iterations, 50/50 blend.
    fn default() -> Self {
        AmrKernelConfig {
            omega: 1.8,
            blend: 0.5,
            max_levels: 2,
            refinement_threshold: 0.1,
            coarse_iters: 400,
            patch_iters: 300,
        }
    }
}
38
39#[derive(Debug, Clone)]
40pub struct AmrKernelSolver {
41 pub config: AmrKernelConfig,
42}
43
44impl AmrKernelSolver {
45 pub fn new(config: AmrKernelConfig) -> FusionResult<Self> {
46 if config.max_levels == 0 {
47 return Err(FusionError::ConfigError(
48 "AMR max_levels must be >= 1".to_string(),
49 ));
50 }
51 if config.refinement_threshold.is_nan() || config.refinement_threshold.is_sign_negative() {
52 return Err(FusionError::ConfigError(
53 "AMR refinement_threshold must be >= 0 (or +inf)".to_string(),
54 ));
55 }
56 if !config.blend.is_finite() || !(0.0..=1.0).contains(&config.blend) {
57 return Err(FusionError::ConfigError(
58 "AMR blend must be finite and in [0, 1]".to_string(),
59 ));
60 }
61 if !config.omega.is_finite() || config.omega <= 0.0 {
62 return Err(FusionError::ConfigError(
63 "AMR omega must be finite and > 0".to_string(),
64 ));
65 }
66 if config.coarse_iters == 0 || config.patch_iters == 0 {
67 return Err(FusionError::ConfigError(
68 "AMR iteration counts must be >= 1".to_string(),
69 ));
70 }
71 Ok(Self { config })
72 }
73
74 fn validate_source_inputs(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<()> {
75 if source.nrows() != base_grid.nz || source.ncols() != base_grid.nr {
76 return Err(FusionError::ConfigError(format!(
77 "AMR source shape mismatch: expected ({}, {}), got ({}, {})",
78 base_grid.nz,
79 base_grid.nr,
80 source.nrows(),
81 source.ncols()
82 )));
83 }
84 if source.iter().any(|v| !v.is_finite()) {
85 return Err(FusionError::ConfigError(
86 "AMR source must contain only finite values".to_string(),
87 ));
88 }
89 Ok(())
90 }
91
92 pub fn solve(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<Array2<f64>> {
93 self.solve_with_hierarchy(base_grid, source)
94 .map(|(psi, _)| psi)
95 }
96
97 pub fn solve_with_hierarchy(
98 &self,
99 base_grid: &Grid2D,
100 source: &Array2<f64>,
101 ) -> FusionResult<(Array2<f64>, AmrHierarchy)> {
102 self.validate_source_inputs(base_grid, source)?;
103 let mut psi = Array2::zeros((base_grid.nz, base_grid.nr));
104 sor_solve(
105 &mut psi,
106 source,
107 base_grid,
108 self.config.omega,
109 self.config.coarse_iters,
110 );
111 if psi.iter().any(|v| !v.is_finite()) {
112 return Err(FusionError::ConfigError(
113 "AMR coarse solve produced non-finite psi values".to_string(),
114 ));
115 }
116
117 let error = estimate_error_field(source, base_grid);
118 let mut hierarchy = AmrHierarchy::new(
119 base_grid.clone(),
120 self.config.max_levels,
121 self.config.refinement_threshold,
122 );
123 hierarchy.refine(&error);
124
125 if hierarchy.patches.is_empty() {
126 return Ok((psi, hierarchy));
127 }
128
129 fn refinement_scale(level: usize) -> usize {
130 let max_shift = usize::BITS as usize - 1;
131 let shift = level.min(max_shift) as u32;
132 1usize << shift
133 }
134
135 for patch in &mut hierarchy.patches {
136 let (iz_lo, _iz_hi, ir_lo, _ir_hi) = patch.bounds;
137 let scale = refinement_scale(patch.level);
138
139 for pz in 0..patch.grid.nz {
140 for pr in 0..patch.grid.nr {
141 let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
142 let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
143 patch.psi[[pz, pr]] = psi[[base_iz, base_ir]];
144 }
145 }
146
147 let patch_source = Array2::from_shape_fn((patch.grid.nz, patch.grid.nr), |(pz, pr)| {
148 let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
149 let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
150 source[[base_iz, base_ir]]
151 });
152
153 sor_solve(
154 &mut patch.psi,
155 &patch_source,
156 &patch.grid,
157 self.config.omega,
158 self.config.patch_iters,
159 );
160 if patch.psi.iter().any(|v| !v.is_finite()) {
161 return Err(FusionError::ConfigError(
162 "AMR patch solve produced non-finite psi values".to_string(),
163 ));
164 }
165 }
166
167 let blend = self.config.blend;
168 for patch in &hierarchy.patches {
169 let scale = refinement_scale(patch.level);
170 for iz in patch.bounds.0..=patch.bounds.1 {
171 for ir in patch.bounds.2..=patch.bounds.3 {
172 let pz = (iz - patch.bounds.0) * scale;
173 let pr = (ir - patch.bounds.2) * scale;
174 if pz >= patch.grid.nz || pr >= patch.grid.nr {
175 continue;
176 }
177 let patch_val = patch.psi[[pz, pr]];
178 psi[[iz, ir]] = (1.0 - blend) * psi[[iz, ir]] + blend * patch_val;
179 }
180 }
181 }
182
183 if psi.iter().any(|v| !v.is_finite()) {
184 return Err(FusionError::ConfigError(
185 "AMR blended solve produced non-finite psi values".to_string(),
186 ));
187 }
188
189 Ok((psi, hierarchy))
190 }
191}
192
193impl Default for AmrKernelSolver {
194 fn default() -> Self {
195 Self::new(AmrKernelConfig::default()).expect("default AMR config must be valid")
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Negative Gaussian bump `-exp(-(dr^2 / 0.02 + dz^2 / 0.04))`.
    ///
    /// It is centred at z = 0, at a radius 8% of the radial span inside the outermost column.
    fn gaussian_source(grid: &Grid2D) -> Array2<f64> {
        let r0 = grid.r[grid.nr - 1] - 0.08 * (grid.r[grid.nr - 1] - grid.r[0]);
        let z0 = 0.0;
        Array2::from_shape_fn((grid.nz, grid.nr), |(iz, ir)| {
            let dr = grid.rr[[iz, ir]] - r0;
            let dz = grid.zz[[iz, ir]] - z0;
            -(-((dr * dr) / 0.02 + (dz * dz) / 0.04)).exp()
        })
    }

    /// Relative L2 error `||a - b|| / ||b||`, with the denominator floored at 1e-12.
    fn rel_l2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let num = (a - b).mapv(|v| v * v).sum().sqrt();
        let den = b.mapv(|v| v * v).sum().sqrt().max(1e-12);
        num / den
    }

    /// Plain coarse SOR reference solve using the same omega / iteration count as `cfg`.
    fn coarse_reference(grid: &Grid2D, source: &Array2<f64>, cfg: &AmrKernelConfig) -> Array2<f64> {
        let mut coarse = Array2::zeros((grid.nz, grid.nr));
        sor_solve(&mut coarse, source, grid, cfg.omega, cfg.coarse_iters);
        coarse
    }

    /// Asserts that `err` is a `ConfigError` whose message mentions `needle`.
    fn assert_config_error(err: FusionError, needle: &str) {
        match err {
            FusionError::ConfigError(msg) => assert!(
                msg.contains(needle),
                "expected a message mentioning {needle:?}, got {msg:?}"
            ),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    /// Constructs a solver that is expected to be rejected and returns the error.
    fn rejected(config: AmrKernelConfig) -> FusionError {
        AmrKernelSolver::new(config).expect_err("invalid AMR config must error")
    }

    #[test]
    fn test_amr_kernel_runs_with_refinement() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);
        let solver = AmrKernelSolver::default();
        let (psi, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");

        assert!(psi.iter().all(|v| v.is_finite()));
        assert!(
            !hierarchy.patches.is_empty(),
            "Expected at least one AMR patch for pedestal-weighted source"
        );
    }

    #[test]
    fn test_amr_kernel_no_refinement_matches_coarse() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);

        let coarse_cfg = AmrKernelConfig {
            refinement_threshold: f64::INFINITY,
            ..Default::default()
        };
        let solver = AmrKernelSolver::new(coarse_cfg.clone()).expect("valid AMR config");
        let (amr_off, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(hierarchy.patches.is_empty());

        let coarse = coarse_reference(&grid, &source, &coarse_cfg);
        let err = rel_l2(&amr_off, &coarse);
        assert!(err < 1e-12, "No-refinement path should match coarse solve");
    }

    #[test]
    fn test_amr_kernel_zero_blend_matches_coarse() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);

        let cfg = AmrKernelConfig {
            blend: 0.0,
            ..Default::default()
        };
        let solver = AmrKernelSolver::new(cfg.clone()).expect("valid AMR config");
        let (psi, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(
            !hierarchy.patches.is_empty(),
            "Expected patches so the blend path is actually exercised"
        );

        let coarse = coarse_reference(&grid, &source, &cfg);
        assert!(
            rel_l2(&psi, &coarse) < 1e-12,
            "blend = 0 must leave the coarse solution untouched"
        );
    }

    #[test]
    fn test_amr_kernel_multilevel_hierarchy_when_enabled() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);
        let solver = AmrKernelSolver::new(AmrKernelConfig {
            max_levels: 3,
            refinement_threshold: 0.05,
            ..Default::default()
        })
        .expect("valid AMR config");

        let (_psi, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(
            hierarchy.patches.len() >= 2,
            "Expected multi-level AMR hierarchy for max_levels=3"
        );
        assert_eq!(
            hierarchy.patches.iter().map(|p| p.level).max().unwrap_or(0),
            2
        );
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_constructor_config() {
        for bad_blend in [f64::NAN, f64::INFINITY, -0.1, 1.1] {
            assert_config_error(
                rejected(AmrKernelConfig {
                    blend: bad_blend,
                    ..Default::default()
                }),
                "blend",
            );
        }

        for bad_omega in [0.0, -1.0, f64::NAN, f64::INFINITY] {
            assert_config_error(
                rejected(AmrKernelConfig {
                    omega: bad_omega,
                    ..Default::default()
                }),
                "omega",
            );
        }

        assert_config_error(
            rejected(AmrKernelConfig {
                max_levels: 0,
                ..Default::default()
            }),
            "max_levels",
        );

        for bad_threshold in [f64::NAN, -0.01, f64::NEG_INFINITY] {
            assert_config_error(
                rejected(AmrKernelConfig {
                    refinement_threshold: bad_threshold,
                    ..Default::default()
                }),
                "refinement_threshold",
            );
        }
    }

    #[test]
    fn test_amr_kernel_rejects_zero_iteration_counts() {
        for (coarse_iters, patch_iters) in [(0, 300), (400, 0), (0, 0)] {
            assert_config_error(
                rejected(AmrKernelConfig {
                    coarse_iters,
                    patch_iters,
                    ..Default::default()
                }),
                "iteration",
            );
        }
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_source_shape_or_values() {
        let grid = Grid2D::new(17, 17, 1.0, 2.0, -1.0, 1.0);
        let solver = AmrKernelSolver::default();

        let bad_shape = Array2::zeros((16, 17));
        assert!(solver.solve_with_hierarchy(&grid, &bad_shape).is_err());

        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let mut bad_values = Array2::zeros((17, 17));
            bad_values[[0, 0]] = bad;
            assert!(solver.solve_with_hierarchy(&grid, &bad_values).is_err());
        }
    }
}