fusion_core/
amr_kernel.rs

1// ─────────────────────────────────────────────────────────────────────
2// SCPN Fusion Core — AMR-Aware Kernel Solver
3// © 1998–2026 Miroslav Šotek. All rights reserved.
4// Contact: www.anulum.li | protoscience@anulum.li
5// ORCID: https://orcid.org/0009-0009-3560-0851
6// License: GNU AGPL v3 | Commercial licensing available
7// ─────────────────────────────────────────────────────────────────────
8//! AMR-assisted equilibrium solve wrapper.
9
10use fusion_math::amr::{estimate_error_field, AmrHierarchy};
11use fusion_math::sor::sor_solve;
12use fusion_types::error::{FusionError, FusionResult};
13use fusion_types::state::Grid2D;
14use ndarray::Array2;
15
/// Tuning parameters for the AMR kernel solver.
///
/// The values are checked by `AmrKernelSolver::new`; constructing a solver through
/// that function is what rejects out-of-range settings.
#[derive(Clone, Debug)]
pub struct AmrKernelConfig {
    /// Maximum number of refinement levels passed to `AmrHierarchy::new` (must be >= 1).
    pub max_levels: usize,
    /// Error-estimate threshold passed to `AmrHierarchy::new`; must be >= 0, `+inf` is allowed.
    pub refinement_threshold: f64,
    /// SOR relaxation factor used for both the coarse solve and every patch solve.
    pub omega: f64,
    /// Number of SOR iterations on the base grid (must be >= 1).
    pub coarse_iters: usize,
    /// Number of SOR iterations on each refinement patch (must be >= 1).
    pub patch_iters: usize,
    /// Weight given to the patch value when folding it back onto the base grid:
    /// `psi = (1 - blend) * psi + blend * patch_psi`; must lie in `[0, 1]`.
    pub blend: f64,
}

impl Default for AmrKernelConfig {
    /// Two levels, threshold 0.1, omega 1.8, 400 coarse / 300 patch iterations, blend 0.5.
    fn default() -> Self {
        AmrKernelConfig {
            omega: 1.8,
            blend: 0.5,
            max_levels: 2,
            refinement_threshold: 0.1,
            coarse_iters: 400,
            patch_iters: 300,
        }
    }
}
38
39#[derive(Debug, Clone)]
40pub struct AmrKernelSolver {
41    pub config: AmrKernelConfig,
42}
43
44impl AmrKernelSolver {
45    pub fn new(config: AmrKernelConfig) -> FusionResult<Self> {
46        if config.max_levels == 0 {
47            return Err(FusionError::ConfigError(
48                "AMR max_levels must be >= 1".to_string(),
49            ));
50        }
51        if config.refinement_threshold.is_nan() || config.refinement_threshold.is_sign_negative() {
52            return Err(FusionError::ConfigError(
53                "AMR refinement_threshold must be >= 0 (or +inf)".to_string(),
54            ));
55        }
56        if !config.blend.is_finite() || !(0.0..=1.0).contains(&config.blend) {
57            return Err(FusionError::ConfigError(
58                "AMR blend must be finite and in [0, 1]".to_string(),
59            ));
60        }
61        if !config.omega.is_finite() || config.omega <= 0.0 {
62            return Err(FusionError::ConfigError(
63                "AMR omega must be finite and > 0".to_string(),
64            ));
65        }
66        if config.coarse_iters == 0 || config.patch_iters == 0 {
67            return Err(FusionError::ConfigError(
68                "AMR iteration counts must be >= 1".to_string(),
69            ));
70        }
71        Ok(Self { config })
72    }
73
74    fn validate_source_inputs(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<()> {
75        if source.nrows() != base_grid.nz || source.ncols() != base_grid.nr {
76            return Err(FusionError::ConfigError(format!(
77                "AMR source shape mismatch: expected ({}, {}), got ({}, {})",
78                base_grid.nz,
79                base_grid.nr,
80                source.nrows(),
81                source.ncols()
82            )));
83        }
84        if source.iter().any(|v| !v.is_finite()) {
85            return Err(FusionError::ConfigError(
86                "AMR source must contain only finite values".to_string(),
87            ));
88        }
89        Ok(())
90    }
91
92    pub fn solve(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<Array2<f64>> {
93        self.solve_with_hierarchy(base_grid, source)
94            .map(|(psi, _)| psi)
95    }
96
97    pub fn solve_with_hierarchy(
98        &self,
99        base_grid: &Grid2D,
100        source: &Array2<f64>,
101    ) -> FusionResult<(Array2<f64>, AmrHierarchy)> {
102        self.validate_source_inputs(base_grid, source)?;
103        let mut psi = Array2::zeros((base_grid.nz, base_grid.nr));
104        sor_solve(
105            &mut psi,
106            source,
107            base_grid,
108            self.config.omega,
109            self.config.coarse_iters,
110        );
111        if psi.iter().any(|v| !v.is_finite()) {
112            return Err(FusionError::ConfigError(
113                "AMR coarse solve produced non-finite psi values".to_string(),
114            ));
115        }
116
117        let error = estimate_error_field(source, base_grid);
118        let mut hierarchy = AmrHierarchy::new(
119            base_grid.clone(),
120            self.config.max_levels,
121            self.config.refinement_threshold,
122        );
123        hierarchy.refine(&error);
124
125        if hierarchy.patches.is_empty() {
126            return Ok((psi, hierarchy));
127        }
128
129        fn refinement_scale(level: usize) -> usize {
130            let max_shift = usize::BITS as usize - 1;
131            let shift = level.min(max_shift) as u32;
132            1usize << shift
133        }
134
135        for patch in &mut hierarchy.patches {
136            let (iz_lo, _iz_hi, ir_lo, _ir_hi) = patch.bounds;
137            let scale = refinement_scale(patch.level);
138
139            for pz in 0..patch.grid.nz {
140                for pr in 0..patch.grid.nr {
141                    let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
142                    let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
143                    patch.psi[[pz, pr]] = psi[[base_iz, base_ir]];
144                }
145            }
146
147            let patch_source = Array2::from_shape_fn((patch.grid.nz, patch.grid.nr), |(pz, pr)| {
148                let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
149                let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
150                source[[base_iz, base_ir]]
151            });
152
153            sor_solve(
154                &mut patch.psi,
155                &patch_source,
156                &patch.grid,
157                self.config.omega,
158                self.config.patch_iters,
159            );
160            if patch.psi.iter().any(|v| !v.is_finite()) {
161                return Err(FusionError::ConfigError(
162                    "AMR patch solve produced non-finite psi values".to_string(),
163                ));
164            }
165        }
166
167        let blend = self.config.blend;
168        for patch in &hierarchy.patches {
169            let scale = refinement_scale(patch.level);
170            for iz in patch.bounds.0..=patch.bounds.1 {
171                for ir in patch.bounds.2..=patch.bounds.3 {
172                    let pz = (iz - patch.bounds.0) * scale;
173                    let pr = (ir - patch.bounds.2) * scale;
174                    if pz >= patch.grid.nz || pr >= patch.grid.nr {
175                        continue;
176                    }
177                    let patch_val = patch.psi[[pz, pr]];
178                    psi[[iz, ir]] = (1.0 - blend) * psi[[iz, ir]] + blend * patch_val;
179                }
180            }
181        }
182
183        if psi.iter().any(|v| !v.is_finite()) {
184            return Err(FusionError::ConfigError(
185                "AMR blended solve produced non-finite psi values".to_string(),
186            ));
187        }
188
189        Ok((psi, hierarchy))
190    }
191}
192
193impl Default for AmrKernelSolver {
194    fn default() -> Self {
195        Self::new(AmrKernelConfig::default()).expect("default AMR config must be valid")
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Grid shared by the solve tests: 33 x 33 nodes on R in [1, 2], Z in [-1, 1].
    fn test_grid() -> Grid2D {
        Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0)
    }

    /// Negative Gaussian bump at z = 0, centred 8% of the radial span inside the outer edge.
    fn gaussian_source(grid: &Grid2D) -> Array2<f64> {
        let r_inner = grid.r[0];
        let r_outer = grid.r[grid.nr - 1];
        let r_peak = r_outer - 0.08 * (r_outer - r_inner);
        let z_peak = 0.0;
        Array2::from_shape_fn((grid.nz, grid.nr), |(iz, ir)| {
            let dr = grid.rr[[iz, ir]] - r_peak;
            let dz = grid.zz[[iz, ir]] - z_peak;
            let exponent = (dr * dr) / 0.02 + (dz * dz) / 0.04;
            -(-exponent).exp()
        })
    }

    /// Relative L2 distance `||a - b|| / ||b||`, with the denominator floored at 1e-12.
    fn rel_l2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let squared = |v: f64| v * v;
        let diff_norm = (a - b).mapv(squared).sum().sqrt();
        let ref_norm = b.mapv(squared).sum().sqrt().max(1e-12);
        diff_norm / ref_norm
    }

    /// Asserts that `config` is rejected with a `ConfigError` whose message names `field`.
    fn assert_rejected(config: AmrKernelConfig, reason: &str, field: &str) {
        match AmrKernelSolver::new(config).expect_err(reason) {
            FusionError::ConfigError(msg) => assert!(msg.contains(field)),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_amr_kernel_runs_with_refinement() {
        let grid = test_grid();
        let source = gaussian_source(&grid);
        let (psi, hierarchy) = AmrKernelSolver::default()
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");

        assert!(psi.iter().all(|v| v.is_finite()));
        assert!(
            !hierarchy.patches.is_empty(),
            "Expected at least one AMR patch for pedestal-weighted source"
        );
    }

    #[test]
    fn test_amr_kernel_no_refinement_matches_coarse() {
        let grid = test_grid();
        let source = gaussian_source(&grid);
        let config = AmrKernelConfig {
            refinement_threshold: f64::INFINITY,
            ..Default::default()
        };
        let (omega, iters) = (config.omega, config.coarse_iters);

        let (amr_off, hierarchy) = AmrKernelSolver::new(config)
            .expect("valid AMR config")
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(hierarchy.patches.is_empty());

        let mut reference = Array2::zeros((grid.nz, grid.nr));
        sor_solve(&mut reference, &source, &grid, omega, iters);
        assert!(
            rel_l2(&amr_off, &reference) < 1e-12,
            "No-refinement path should match coarse solve"
        );
    }

    #[test]
    fn test_amr_kernel_multilevel_hierarchy_when_enabled() {
        let grid = test_grid();
        let source = gaussian_source(&grid);
        let config = AmrKernelConfig {
            max_levels: 3,
            refinement_threshold: 0.05,
            ..Default::default()
        };

        let (_psi, hierarchy) = AmrKernelSolver::new(config)
            .expect("valid AMR config")
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(
            hierarchy.patches.len() >= 2,
            "Expected multi-level AMR hierarchy for max_levels=3"
        );
        let deepest = hierarchy.patches.iter().map(|p| p.level).max().unwrap_or(0);
        assert_eq!(deepest, 2);
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_constructor_config() {
        for blend in [f64::NAN, -0.1, 1.1] {
            assert_rejected(
                AmrKernelConfig {
                    blend,
                    ..Default::default()
                },
                "invalid blend must error",
                "blend",
            );
        }

        assert_rejected(
            AmrKernelConfig {
                omega: 0.0,
                ..Default::default()
            },
            "invalid omega must error",
            "omega",
        );

        assert_rejected(
            AmrKernelConfig {
                max_levels: 0,
                ..Default::default()
            },
            "invalid max_levels must error",
            "max_levels",
        );

        for refinement_threshold in [f64::NAN, -0.01, f64::NEG_INFINITY] {
            assert_rejected(
                AmrKernelConfig {
                    refinement_threshold,
                    ..Default::default()
                },
                "invalid refinement_threshold must error",
                "refinement_threshold",
            );
        }
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_source_shape_or_values() {
        let grid = Grid2D::new(17, 17, 1.0, 2.0, -1.0, 1.0);
        let solver = AmrKernelSolver::default();

        let short_rows = Array2::zeros((16, 17));
        assert!(solver.solve_with_hierarchy(&grid, &short_rows).is_err());

        let mut with_nan = Array2::zeros((17, 17));
        with_nan[[0, 0]] = f64::NAN;
        assert!(solver.solve_with_hierarchy(&grid, &with_nan).is_err());
    }
}