fusion_core/
amr_kernel.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
2// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
3// © Code 2020–2026 Miroslav Šotek. All rights reserved.
4// ORCID: 0009-0009-3560-0851
5// Contact: www.anulum.li | protoscience@anulum.li
6// SCPN Fusion Core — AMR-Aware Kernel Solver
7//! AMR-assisted equilibrium solve wrapper.
8
9use fusion_math::amr::{estimate_error_field, AmrHierarchy};
10use fusion_math::sor::sor_solve;
11use fusion_types::error::{FusionError, FusionResult};
12use fusion_types::state::Grid2D;
13use ndarray::Array2;
14
/// Tuning knobs for [`AmrKernelSolver`].
///
/// Values are checked by `AmrKernelSolver::new`; constructing this struct
/// directly performs no validation.
#[derive(Debug, Clone)]
pub struct AmrKernelConfig {
    /// Maximum refinement depth handed to `AmrHierarchy::new`; must be >= 1.
    pub max_levels: usize,
    /// Error threshold handed to `AmrHierarchy::new`. Must be non-negative;
    /// `f64::INFINITY` produces no patches, i.e. a plain coarse solve.
    pub refinement_threshold: f64,
    /// Relaxation factor passed to `sor_solve` for both coarse and patch solves.
    pub omega: f64,
    /// Number of SOR iterations on the base grid; must be >= 1.
    pub coarse_iters: usize,
    /// Number of SOR iterations on each refined patch; must be >= 1.
    pub patch_iters: usize,
    /// Weight of the patch solution when blended into the coarse one:
    /// `0.0` keeps the coarse value, `1.0` takes the patch value. Must be in `[0, 1]`.
    pub blend: f64,
}
24
25impl Default for AmrKernelConfig {
26    fn default() -> Self {
27        Self {
28            max_levels: 2,
29            refinement_threshold: 0.1,
30            omega: 1.8,
31            coarse_iters: 400,
32            patch_iters: 300,
33            blend: 0.5,
34        }
35    }
36}
37
38#[derive(Debug, Clone)]
39pub struct AmrKernelSolver {
40    pub config: AmrKernelConfig,
41}
42
43impl AmrKernelSolver {
44    pub fn new(config: AmrKernelConfig) -> FusionResult<Self> {
45        if config.max_levels == 0 {
46            return Err(FusionError::ConfigError(
47                "AMR max_levels must be >= 1".to_string(),
48            ));
49        }
50        if config.refinement_threshold.is_nan() || config.refinement_threshold.is_sign_negative() {
51            return Err(FusionError::ConfigError(
52                "AMR refinement_threshold must be >= 0 (or +inf)".to_string(),
53            ));
54        }
55        if !config.blend.is_finite() || !(0.0..=1.0).contains(&config.blend) {
56            return Err(FusionError::ConfigError(
57                "AMR blend must be finite and in [0, 1]".to_string(),
58            ));
59        }
60        if !config.omega.is_finite() || config.omega <= 0.0 {
61            return Err(FusionError::ConfigError(
62                "AMR omega must be finite and > 0".to_string(),
63            ));
64        }
65        if config.coarse_iters == 0 || config.patch_iters == 0 {
66            return Err(FusionError::ConfigError(
67                "AMR iteration counts must be >= 1".to_string(),
68            ));
69        }
70        Ok(Self { config })
71    }
72
73    fn validate_source_inputs(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<()> {
74        if source.nrows() != base_grid.nz || source.ncols() != base_grid.nr {
75            return Err(FusionError::ConfigError(format!(
76                "AMR source shape mismatch: expected ({}, {}), got ({}, {})",
77                base_grid.nz,
78                base_grid.nr,
79                source.nrows(),
80                source.ncols()
81            )));
82        }
83        if source.iter().any(|v| !v.is_finite()) {
84            return Err(FusionError::ConfigError(
85                "AMR source must contain only finite values".to_string(),
86            ));
87        }
88        Ok(())
89    }
90
91    pub fn solve(&self, base_grid: &Grid2D, source: &Array2<f64>) -> FusionResult<Array2<f64>> {
92        self.solve_with_hierarchy(base_grid, source)
93            .map(|(psi, _)| psi)
94    }
95
96    pub fn solve_with_hierarchy(
97        &self,
98        base_grid: &Grid2D,
99        source: &Array2<f64>,
100    ) -> FusionResult<(Array2<f64>, AmrHierarchy)> {
101        self.validate_source_inputs(base_grid, source)?;
102        let mut psi = Array2::zeros((base_grid.nz, base_grid.nr));
103        sor_solve(
104            &mut psi,
105            source,
106            base_grid,
107            self.config.omega,
108            self.config.coarse_iters,
109        );
110        if psi.iter().any(|v| !v.is_finite()) {
111            return Err(FusionError::ConfigError(
112                "AMR coarse solve produced non-finite psi values".to_string(),
113            ));
114        }
115
116        let error = estimate_error_field(source, base_grid);
117        let mut hierarchy = AmrHierarchy::new(
118            base_grid.clone(),
119            self.config.max_levels,
120            self.config.refinement_threshold,
121        );
122        hierarchy.refine(&error);
123
124        if hierarchy.patches.is_empty() {
125            return Ok((psi, hierarchy));
126        }
127
128        fn refinement_scale(level: usize) -> usize {
129            let max_shift = usize::BITS as usize - 1;
130            let shift = level.min(max_shift) as u32;
131            1usize << shift
132        }
133
134        for patch in &mut hierarchy.patches {
135            let (iz_lo, _iz_hi, ir_lo, _ir_hi) = patch.bounds;
136            let scale = refinement_scale(patch.level);
137
138            for pz in 0..patch.grid.nz {
139                for pr in 0..patch.grid.nr {
140                    let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
141                    let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
142                    patch.psi[[pz, pr]] = psi[[base_iz, base_ir]];
143                }
144            }
145
146            let patch_source = Array2::from_shape_fn((patch.grid.nz, patch.grid.nr), |(pz, pr)| {
147                let base_iz = (iz_lo + pz / scale).min(base_grid.nz - 1);
148                let base_ir = (ir_lo + pr / scale).min(base_grid.nr - 1);
149                source[[base_iz, base_ir]]
150            });
151
152            sor_solve(
153                &mut patch.psi,
154                &patch_source,
155                &patch.grid,
156                self.config.omega,
157                self.config.patch_iters,
158            );
159            if patch.psi.iter().any(|v| !v.is_finite()) {
160                return Err(FusionError::ConfigError(
161                    "AMR patch solve produced non-finite psi values".to_string(),
162                ));
163            }
164        }
165
166        let blend = self.config.blend;
167        for patch in &hierarchy.patches {
168            let scale = refinement_scale(patch.level);
169            for iz in patch.bounds.0..=patch.bounds.1 {
170                for ir in patch.bounds.2..=patch.bounds.3 {
171                    let pz = (iz - patch.bounds.0) * scale;
172                    let pr = (ir - patch.bounds.2) * scale;
173                    if pz >= patch.grid.nz || pr >= patch.grid.nr {
174                        continue;
175                    }
176                    let patch_val = patch.psi[[pz, pr]];
177                    psi[[iz, ir]] = (1.0 - blend) * psi[[iz, ir]] + blend * patch_val;
178                }
179            }
180        }
181
182        if psi.iter().any(|v| !v.is_finite()) {
183            return Err(FusionError::ConfigError(
184                "AMR blended solve produced non-finite psi values".to_string(),
185            ));
186        }
187
188        Ok((psi, hierarchy))
189    }
190}
191
192impl Default for AmrKernelSolver {
193    fn default() -> Self {
194        Self::new(AmrKernelConfig::default()).expect("default AMR config must be valid")
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Negative Gaussian source centred near the outer radial edge
    /// (at 92% of the R range) and at z = 0.
    fn gaussian_source(grid: &Grid2D) -> Array2<f64> {
        let r0 = grid.r[grid.nr - 1] - 0.08 * (grid.r[grid.nr - 1] - grid.r[0]);
        let z0 = 0.0;
        Array2::from_shape_fn((grid.nz, grid.nr), |(iz, ir)| {
            let dr = grid.rr[[iz, ir]] - r0;
            let dz = grid.zz[[iz, ir]] - z0;
            -(-((dr * dr) / 0.02 + (dz * dz) / 0.04)).exp()
        })
    }

    /// Relative L2 distance `||a - b|| / ||b||`. The denominator is floored
    /// at 1e-12 to avoid dividing by zero.
    fn rel_l2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let num = (a - b).mapv(|v| v * v).sum().sqrt();
        let den = b.mapv(|v| v * v).sum().sqrt().max(1e-12);
        num / den
    }

    /// Reference coarse-only SOR solve using the iteration settings of `cfg`.
    fn coarse_reference(grid: &Grid2D, source: &Array2<f64>, cfg: &AmrKernelConfig) -> Array2<f64> {
        let mut coarse = Array2::zeros((grid.nz, grid.nr));
        sor_solve(&mut coarse, source, grid, cfg.omega, cfg.coarse_iters);
        coarse
    }

    #[test]
    fn test_amr_kernel_runs_with_refinement() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);
        let solver = AmrKernelSolver::default();
        let (psi, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");

        assert!(psi.iter().all(|v| v.is_finite()));
        assert!(
            !hierarchy.patches.is_empty(),
            "Expected at least one AMR patch for pedestal-weighted source"
        );
    }

    #[test]
    fn test_amr_kernel_no_refinement_matches_coarse() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);

        let coarse_cfg = AmrKernelConfig {
            refinement_threshold: f64::INFINITY,
            ..Default::default()
        };
        let solver = AmrKernelSolver::new(coarse_cfg.clone()).expect("valid AMR config");
        let (amr_off, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(hierarchy.patches.is_empty());

        let coarse = coarse_reference(&grid, &source, &coarse_cfg);
        let err = rel_l2(&amr_off, &coarse);
        assert!(err < 1e-12, "No-refinement path should match coarse solve");
    }

    /// With `blend = 0` patches are still built and solved, but the returned
    /// psi must equal the coarse solve exactly.
    #[test]
    fn test_amr_kernel_zero_blend_matches_coarse_even_with_patches() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);

        let cfg = AmrKernelConfig {
            blend: 0.0,
            ..Default::default()
        };
        let solver = AmrKernelSolver::new(cfg.clone()).expect("valid AMR config");
        let (psi, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(
            !hierarchy.patches.is_empty(),
            "Expected patches so the blend path is exercised"
        );

        let coarse = coarse_reference(&grid, &source, &cfg);
        let err = rel_l2(&psi, &coarse);
        assert!(err < 1e-12, "blend = 0 must leave the coarse solution untouched");
    }

    /// `solve` must return exactly the psi from `solve_with_hierarchy`.
    #[test]
    fn test_amr_kernel_solve_matches_solve_with_hierarchy() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);
        let solver = AmrKernelSolver::default();

        let psi = solver.solve(&grid, &source).expect("valid AMR solve inputs");
        let (psi_with_hierarchy, _) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert_eq!(psi, psi_with_hierarchy);
    }

    #[test]
    fn test_amr_kernel_multilevel_hierarchy_when_enabled() {
        let grid = Grid2D::new(33, 33, 1.0, 2.0, -1.0, 1.0);
        let source = gaussian_source(&grid);
        let solver = AmrKernelSolver::new(AmrKernelConfig {
            max_levels: 3,
            refinement_threshold: 0.05,
            ..Default::default()
        })
        .expect("valid AMR config");

        let (_psi, hierarchy) = solver
            .solve_with_hierarchy(&grid, &source)
            .expect("valid AMR solve inputs");
        assert!(
            hierarchy.patches.len() >= 2,
            "Expected multi-level AMR hierarchy for max_levels=3"
        );
        assert_eq!(
            hierarchy.patches.iter().map(|p| p.level).max().unwrap_or(0),
            2
        );
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_constructor_config() {
        for bad_blend in [f64::NAN, -0.1, 1.1] {
            let err = AmrKernelSolver::new(AmrKernelConfig {
                blend: bad_blend,
                ..Default::default()
            })
            .expect_err("invalid blend must error");
            match err {
                FusionError::ConfigError(msg) => {
                    assert!(msg.contains("blend"));
                }
                other => panic!("Unexpected error: {other:?}"),
            }
        }
        let err = AmrKernelSolver::new(AmrKernelConfig {
            omega: 0.0,
            ..Default::default()
        })
        .expect_err("invalid omega must error");
        match err {
            FusionError::ConfigError(msg) => {
                assert!(msg.contains("omega"));
            }
            other => panic!("Unexpected error: {other:?}"),
        }

        let err = AmrKernelSolver::new(AmrKernelConfig {
            max_levels: 0,
            ..Default::default()
        })
        .expect_err("invalid max_levels must error");
        match err {
            FusionError::ConfigError(msg) => {
                assert!(msg.contains("max_levels"));
            }
            other => panic!("Unexpected error: {other:?}"),
        }

        for bad_threshold in [f64::NAN, -0.01, f64::NEG_INFINITY] {
            let err = AmrKernelSolver::new(AmrKernelConfig {
                refinement_threshold: bad_threshold,
                ..Default::default()
            })
            .expect_err("invalid refinement_threshold must error");
            match err {
                FusionError::ConfigError(msg) => {
                    assert!(msg.contains("refinement_threshold"));
                }
                other => panic!("Unexpected error: {other:?}"),
            }
        }
    }

    #[test]
    fn test_amr_kernel_rejects_invalid_source_shape_or_values() {
        let grid = Grid2D::new(17, 17, 1.0, 2.0, -1.0, 1.0);
        let solver = AmrKernelSolver::default();

        let bad_shape = Array2::zeros((16, 17));
        assert!(solver.solve_with_hierarchy(&grid, &bad_shape).is_err());

        let mut bad_values = Array2::zeros((17, 17));
        bad_values[[0, 0]] = f64::NAN;
        assert!(solver.solve_with_hierarchy(&grid, &bad_values).is_err());
    }
}