1use fusion_math::amr::{estimate_error_field, AmrHierarchy};
10use fusion_math::sor::sor_solve;
11use fusion_types::error::{FusionError, FusionResult};
12use fusion_types::state::Grid2D;
13use ndarray::Array2;
14
/// Tuning parameters for [`AmrKernelSolver`].
///
/// All fields are validated by [`AmrKernelSolver::new`]; building a solver from an
/// out-of-range configuration returns a `ConfigError`.
#[derive(Clone, Debug)]
pub struct AmrKernelConfig {
    /// Maximum refinement depth handed to `AmrHierarchy::new`; must be at least 1.
    pub max_levels: usize,
    /// Error-estimate threshold handed to `AmrHierarchy::new`; must not be NaN or
    /// negative. `f64::INFINITY` is accepted.
    pub refinement_threshold: f64,
    /// SOR relaxation factor shared by the coarse solve and every patch solve.
    pub omega: f64,
    /// Number of SOR iterations for the base-grid solve; must be at least 1.
    pub coarse_iters: usize,
    /// Number of SOR iterations for each refinement patch; must be at least 1.
    pub patch_iters: usize,
    /// Weight given to a patch value when merged back into the base solution,
    /// `psi = (1 - blend) * psi + blend * patch_psi`; must be finite and in `[0, 1]`.
    pub blend: f64,
}
24
25impl Default for AmrKernelConfig {
26 fn default() -> Self {
27 Self {
28 max_levels: 2,
29 refinement_threshold: 0.1,
30 omega: 1.8,
31 coarse_iters: 400,
32 patch_iters: 300,
33 blend: 0.5,
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
39pub struct AmrKernelSolver {
40 pub config: AmrKernelConfig,
41}
42
43impl AmrKernelSolver {
44 pub fn new(config: AmrKernelConfig) -> FusionResult<Self> {
45 if config.max_levels == 0 {
46 return Err(FusionError::ConfigError(
47 "AMR max_levels must be >= 1".to_string(),
48 ));
49 }
50 if config.refinement_threshold.is_nan() || config.refinement_threshold.is_sign_negative() {
51 return Err(FusionError::ConfigError(
52 "AMR refinement_threshold must be >= 0 (or +inf)".to_string(),
53 ));
54 }
55 if !config.blend.is_finite() || !(0.0..=1.0).contains(&config.blend) {
56 return Err(FusionError::ConfigError(
57 "AMR blend must be finite and in [0, 1]".to_string(),
58 ));
59 }
60 if !config.omega.is_finite() || config.omega <= 0.0 {
61 return Err(FusionError::ConfigError(
62 "AMR omega must be finite and > 0".to_string(),
63 ));
64 }
65 if config.coarse_iters == 0 || config.patch_iters == 0 {
66 return Err(FusionError::ConfigError(
67 "AMR iteration counts must be >= 1".to_string(),
68 ));
69 }
70 Ok(Self { config })
71 }
72
73 fn validate_source_inputs(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<()> {
74 if source.nrows() != base_grid.nz || source.ncols() != base_grid.nr {
75 return Err(FusionError::ConfigError(format!(
76 "AMR source shape mismatch: expected ({}, {}), got ({}, {})",
77 base_grid.nz,
78 base_grid.nr,
79 source.nrows(),
80 source.ncols()
81 )));
82 }
83 if source.iter().any(|v| !v.is_finite()) {
84 return Err(FusionError::ConfigError(
85 "AMR source must contain only finite values".to_string(),
86 ));
87 }
88 Ok(())
89 }
90
91 pub fn solve(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<Array2<f64>> {
92 self.solve_with_hierarchy(base_grid, source)
93 .map(|(psi, _)| psi)
94 }
95
96 pub fn solve_with_hierarchy(
97 &self,
98 base_grid: &Grid2D,
99 source: &Array2<f64>,
100 ) -> FusionResult<(Array2<f64>, AmrHierarchy)> {
101 self.validate_source_inputs(base_grid, source)?;
102 let mut psi = Array2::zeros((base_grid.nz, base_grid.nr));
103 sor_solve(
104 &mut psi,
105 source,
106 base_grid,
107 self.config.omega,
108 self.config.coarse_iters,
109 );
110 if psi.iter().any(|v| !v.is_finite()) {
111 return Err(FusionError::ConfigError(
112 "AMR coarse solve produced non-finite psi values".to_string(),
113 ));
114 }
115
116 let error = estimate_error_field(source, base_grid);
117 let mut hierarchy = AmrHierarchy::new(
118 base_grid.clone(),
119 self.config.max_levels,
120 self.config.refinement_threshold,
121 );
122 hierarchy.refine(&error);
123
124 if hierarchy.patches.is_empty() {
125 return Ok((psi, hierarchy));
126 }
127
128 fn refinement_scale(level: usize) -> usize {
129 let max_shift = usize::BITS as usize - 1;
130 let shift = level.min(max_shift) as u32;
131 1usize << shift
132 }
133
134 for patch in &mut hierarchy.patches {
135 let (iz_lo, _iz_hi, ir_lo, _ir_hi) = patch.bounds;
136 let scale = refinement_scale(patch.level);
137
138 for pz in 0..patch.grid.nz {
139 for pr in 0..patch.grid.nr {
140 let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
141 let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
142 patch.psi[[pz, pr]] = psi[[base_iz, base_ir]];
143 }
144 }
145
146 let patch_source = Array2::from_shape_fn((patch.grid.nz, patch.grid.nr), |(pz, pr)| {
147 let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
148 let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
149 source[[base_iz, base_ir]]
150 });
151
152 sor_solve(
153 &mut patch.psi,
154 &patch_source,
155 &patch.grid,
156 self.config.omega,
157 self.config.patch_iters,
158 );
159 if patch.psi.iter().any(|v| !v.is_finite()) {
160 return Err(FusionError::ConfigError(
161 "AMR patch solve produced non-finite psi values".to_string(),
162 ));
163 }
164 }
165
166 let blend = self.config.blend;
167 for patch in &hierarchy.patches {
168 let scale = refinement_scale(patch.level);
169 for iz in patch.bounds.0..=patch.bounds.1 {
170 for ir in patch.bounds.2..=patch.bounds.3 {
171 let pz = (iz - patch.bounds.0) * scale;
172 let pr = (ir - patch.bounds.2) * scale;
173 if pz >= patch.grid.nz || pr >= patch.grid.nr {
174 continue;
175 }
176 let patch_val = patch.psi[[pz, pr]];
177 psi[[iz, ir]] = (1.0 - blend) * psi[[iz, ir]] + blend * patch_val;
178 }
179 }
180 }
181
182 if psi.iter().any(|v| !v.is_finite()) {
183 return Err(FusionError::ConfigError(
184 "AMR blended solve produced non-finite psi values".to_string(),
185 ));
186 }
187
188 Ok((psi, hierarchy))
189 }
190}
191
192impl Default for AmrKernelSolver {
193 fn default() -> Self {
194 Self::new(AmrKernelConfig::default()).expect("default AMR config must be valid")
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Negative unit-depth Gaussian bump centred at `z = 0`, placed 8% of the radial span
    /// inside the outermost radius so that refinement concentrates near the edge.
    fn gaussian_source(grid: &Grid2D) -> Array2<f64> {
        let r_first = grid.r[0];
        let r_last = grid.r[grid.nr - 1];
        let center_r = r_last - 0.08 * (r_last - r_first);
        let center_z = 0.0;
        Array2::from_shape_fn((grid.nz, grid.nr), |(iz, ir)| {
            let dr = grid.rr[[iz, ir]] - center_r;
            let dz = grid.zz[[iz, ir]] - center_z;
            let exponent = (dr * dr) / 0.02 + (dz * dz) / 0.04;
            -(-exponent).exp()
        })
    }

    /// Relative L2 distance `||a - b|| / max(||b||, 1e-12)`.
    fn rel_l2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let l2 = |m: &Array2<f64>| m.mapv(|v| v * v).sum().sqrt();
        l2(&(a - b)) / l2(b).max(1e-12)
    }

    /// The 33 x 33 grid shared by the solve tests.
    fn standard_grid() -> Grid2D {
        Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0)
    }

    /// Asserts that `AmrKernelSolver::new(cfg)` fails with a `ConfigError` mentioning `field`.
    fn assert_rejected(cfg: AmrKernelConfig, expectation: &str, field: &str) {
        match AmrKernelSolver::new(cfg).expect_err(expectation) {
            FusionError::ConfigError(msg) => assert!(msg.contains(field)),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_amr_kernel_runs_with_refinement() {
        let grid = standard_grid();
        let source = gaussian_source(&grid);
        let (psi, hierarchy) = AmrKernelSolver::default()
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");

        assert!(psi.iter().all(|v| v.is_finite()));
        assert!(
            !hierarchy.patches.is_empty(),
            "Expected at least one AMR patch for pedestal-weighted source"
        );
    }

    #[test]
    fn test_amr_kernel_no_refinement_matches_coarse() {
        let grid = standard_grid();
        let source = gaussian_source(&grid);
        // With an infinite threshold the hierarchy is expected to stay empty.
        let coarse_cfg = AmrKernelConfig {
            refinement_threshold: f64::INFINITY,
            ..Default::default()
        };
        let (omega, iters) = (coarse_cfg.omega, coarse_cfg.coarse_iters);

        let (amr_off, hierarchy) = AmrKernelSolver::new(coarse_cfg)
            .expect("valid AMR config")
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(hierarchy.patches.is_empty());

        let mut reference = Array2::zeros((grid.nz, grid.nr));
        sor_solve(&mut reference, &source, &grid, omega, iters);
        assert!(
            rel_l2(&amr_off, &reference) < 1e-12,
            "No-refinement path should match coarse solve"
        );
    }

    #[test]
    fn test_amr_kernel_multilevel_hierarchy_when_enabled() {
        let grid = standard_grid();
        let source = gaussian_source(&grid);
        let cfg = AmrKernelConfig {
            max_levels: 3,
            refinement_threshold: 0.05,
            ..Default::default()
        };
        let solver = AmrKernelSolver::new(cfg).expect("valid AMR config");

        let (_, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(
            hierarchy.patches.len() >= 2,
            "Expected multi-level AMR hierarchy for max_levels=3"
        );
        let deepest_level = hierarchy.patches.iter().map(|p| p.level).max().unwrap_or(0);
        assert_eq!(deepest_level, 2);
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_constructor_config() {
        for blend in [f64::NAN, -0.1, 1.1] {
            assert_rejected(
                AmrKernelConfig {
                    blend,
                    ..Default::default()
                },
                "invalid blend must error",
                "blend",
            );
        }

        assert_rejected(
            AmrKernelConfig {
                omega: 0.0,
                ..Default::default()
            },
            "invalid omega must error",
            "omega",
        );

        assert_rejected(
            AmrKernelConfig {
                max_levels: 0,
                ..Default::default()
            },
            "invalid max_levels must error",
            "max_levels",
        );

        for refinement_threshold in [f64::NAN, -0.01, f64::NEG_INFINITY] {
            assert_rejected(
                AmrKernelConfig {
                    refinement_threshold,
                    ..Default::default()
                },
                "invalid refinement_threshold must error",
                "refinement_threshold",
            );
        }
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_source_shape_or_values() {
        let grid = Grid2D::new(17, 17, 1.0, 2.0, -1.0, 1.0);
        let solver = AmrKernelSolver::default();

        // 16 rows where 17 are expected.
        let wrong_shape = Array2::<f64>::zeros((16, 17));
        assert!(solver.solve_with_hierarchy(&grid, &wrong_shape).is_err());

        let mut with_nan = Array2::<f64>::zeros((17, 17));
        with_nan[[0, 0]] = f64::NAN;
        assert!(solver.solve_with_hierarchy(&grid, &with_nan).is_err());
    }
}