fusion_core/
mpi_domain.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later | Commercial license available
2// © Concepts 1996–2026 Miroslav Šotek. All rights reserved.
3// © Code 2020–2026 Miroslav Šotek. All rights reserved.
4// ORCID: 0009-0009-3560-0851
5// Contact: www.anulum.li | protoscience@anulum.li
6// SCPN Fusion Core — MPI Domain Scaffolding
7//! MPI-oriented domain decomposition scaffolding.
8//!
9//! This module defines deterministic domain partition metadata and halo
10//! packing/exchange primitives that can be wired to rsmpi in a later phase.
11
12use fusion_types::error::{FusionError, FusionResult};
13use ndarray::{s, Array2};
14
/// Contiguous band of global rows owned by one rank in a 1D (Z) decomposition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DomainSlice {
    /// Zero-based index of the owning rank.
    pub rank: usize,
    /// Total number of ranks in the decomposition.
    pub nranks: usize,
    /// Number of rows in the undecomposed global array.
    pub global_nz: usize,
    /// Number of rows owned by this rank.
    pub local_nz: usize,
    /// Requested halo (ghost-row) width.
    pub halo: usize,
    /// First owned global row (inclusive).
    pub z_start: usize,
    /// One past the last owned global row (exclusive).
    pub z_end: usize,
}

impl DomainSlice {
    /// Whether a rank with index `rank - 1` exists (this slice is not the first).
    pub fn has_upper_neighbor(&self) -> bool {
        self.rank != 0
    }

    /// Whether a rank with index `rank + 1` exists (this slice is not the last).
    pub fn has_lower_neighbor(&self) -> bool {
        self.nranks > self.rank + 1
    }
}
35
36pub fn decompose_z(global_nz: usize, nranks: usize, halo: usize) -> FusionResult<Vec<DomainSlice>> {
37    if global_nz < 2 {
38        return Err(FusionError::PhysicsViolation(
39            "MPI decomposition requires global_nz >= 2".to_string(),
40        ));
41    }
42    if nranks < 1 {
43        return Err(FusionError::PhysicsViolation(
44            "MPI decomposition requires nranks >= 1".to_string(),
45        ));
46    }
47    if nranks > global_nz {
48        return Err(FusionError::PhysicsViolation(format!(
49            "Cannot split global_nz={global_nz} across nranks={nranks}"
50        )));
51    }
52
53    let base = global_nz / nranks;
54    let rem = global_nz % nranks;
55    let mut out = Vec::with_capacity(nranks);
56    let mut cursor = 0usize;
57    for rank in 0..nranks {
58        let local_nz = base + usize::from(rank < rem);
59        let z_start = cursor;
60        let z_end = z_start + local_nz;
61        cursor = z_end;
62        out.push(DomainSlice {
63            rank,
64            nranks,
65            global_nz,
66            local_nz,
67            halo,
68            z_start,
69            z_end,
70        });
71    }
72    Ok(out)
73}
74
75pub fn pack_halo_rows(
76    local: &Array2<f64>,
77    halo: usize,
78) -> FusionResult<(Array2<f64>, Array2<f64>)> {
79    if halo == 0 {
80        return Err(FusionError::PhysicsViolation(
81            "Halo width must be >= 1".to_string(),
82        ));
83    }
84    if local.nrows() <= 2 * halo {
85        return Err(FusionError::PhysicsViolation(format!(
86            "Local block has insufficient rows {} for halo={halo}",
87            local.nrows()
88        )));
89    }
90    if local.iter().any(|v| !v.is_finite()) {
91        return Err(FusionError::PhysicsViolation(
92            "Local block contains non-finite values".to_string(),
93        ));
94    }
95    let top = local.slice(s![halo..(2 * halo), ..]).to_owned();
96    let bottom = local
97        .slice(s![(local.nrows() - 2 * halo)..(local.nrows() - halo), ..])
98        .to_owned();
99    if top.iter().any(|v| !v.is_finite()) || bottom.iter().any(|v| !v.is_finite()) {
100        return Err(FusionError::PhysicsViolation(
101            "Packed halo rows contain non-finite values".to_string(),
102        ));
103    }
104    Ok((top, bottom))
105}
106
107pub fn apply_halo_rows(
108    local: &mut Array2<f64>,
109    halo: usize,
110    recv_top: Option<&Array2<f64>>,
111    recv_bottom: Option<&Array2<f64>>,
112) -> FusionResult<()> {
113    if halo == 0 {
114        return Err(FusionError::PhysicsViolation(
115            "Halo width must be >= 1".to_string(),
116        ));
117    }
118    if local.nrows() <= 2 * halo {
119        return Err(FusionError::PhysicsViolation(format!(
120            "Local block has insufficient rows {} for halo={halo}",
121            local.nrows()
122        )));
123    }
124    if local.iter().any(|v| !v.is_finite()) {
125        return Err(FusionError::PhysicsViolation(
126            "Local block contains non-finite values".to_string(),
127        ));
128    }
129    if let Some(top) = recv_top {
130        if top.dim() != (halo, local.ncols()) {
131            return Err(FusionError::PhysicsViolation(format!(
132                "Top halo shape mismatch: expected ({halo}, {}), got {:?}",
133                local.ncols(),
134                top.dim()
135            )));
136        }
137        if top.iter().any(|v| !v.is_finite()) {
138            return Err(FusionError::PhysicsViolation(
139                "Top halo contains non-finite values".to_string(),
140            ));
141        }
142        local.slice_mut(s![0..halo, ..]).assign(top);
143    }
144    if let Some(bottom) = recv_bottom {
145        if bottom.dim() != (halo, local.ncols()) {
146            return Err(FusionError::PhysicsViolation(format!(
147                "Bottom halo shape mismatch: expected ({halo}, {}), got {:?}",
148                local.ncols(),
149                bottom.dim()
150            )));
151        }
152        if bottom.iter().any(|v| !v.is_finite()) {
153            return Err(FusionError::PhysicsViolation(
154                "Bottom halo contains non-finite values".to_string(),
155            ));
156        }
157        let n = local.nrows();
158        local.slice_mut(s![(n - halo)..n, ..]).assign(bottom);
159    }
160    if local.iter().any(|v| !v.is_finite()) {
161        return Err(FusionError::PhysicsViolation(
162            "Applying halo rows produced non-finite values".to_string(),
163        ));
164    }
165    Ok(())
166}
167
168pub fn split_with_halo(
169    global: &Array2<f64>,
170    slices: &[DomainSlice],
171) -> FusionResult<Vec<Array2<f64>>> {
172    if slices.is_empty() {
173        return Err(FusionError::PhysicsViolation(
174            "No slices provided for split_with_halo".to_string(),
175        ));
176    }
177    if global.iter().any(|v| !v.is_finite()) {
178        return Err(FusionError::PhysicsViolation(
179            "Global array contains non-finite values".to_string(),
180        ));
181    }
182    let mut out = Vec::with_capacity(slices.len());
183    for sdef in slices {
184        if sdef.global_nz != global.nrows() {
185            return Err(FusionError::PhysicsViolation(format!(
186                "Slice/global mismatch: slice.global_nz={} global.nrows()={}",
187                sdef.global_nz,
188                global.nrows()
189            )));
190        }
191        if sdef.z_start >= sdef.z_end || sdef.z_end > sdef.global_nz {
192            return Err(FusionError::PhysicsViolation(format!(
193                "Invalid slice bounds z_start={} z_end={} global_nz={}",
194                sdef.z_start, sdef.z_end, sdef.global_nz
195            )));
196        }
197        let start = sdef.z_start.saturating_sub(sdef.halo);
198        let end = (sdef.z_end + sdef.halo).min(sdef.global_nz);
199        let mut local = Array2::zeros((end - start, global.ncols()));
200        local.assign(&global.slice(s![start..end, ..]));
201        if local.iter().any(|v| !v.is_finite()) {
202            return Err(FusionError::PhysicsViolation(
203                "Split local block contains non-finite values".to_string(),
204            ));
205        }
206        out.push(local);
207    }
208    Ok(out)
209}
210
211pub fn stitch_without_halo(
212    locals: &[Array2<f64>],
213    slices: &[DomainSlice],
214    ncols: usize,
215) -> FusionResult<Array2<f64>> {
216    if locals.len() != slices.len() {
217        return Err(FusionError::PhysicsViolation(format!(
218            "locals/slices mismatch: {} vs {}",
219            locals.len(),
220            slices.len()
221        )));
222    }
223    if slices.is_empty() {
224        return Err(FusionError::PhysicsViolation(
225            "No slices provided for stitch_without_halo".to_string(),
226        ));
227    }
228    let global_nz = slices
229        .last()
230        .map(|s| s.global_nz)
231        .ok_or_else(|| FusionError::PhysicsViolation("No slices provided".to_string()))?;
232    let mut global = Array2::zeros((global_nz, ncols));
233    for (local, sdef) in locals.iter().zip(slices.iter()) {
234        if local.iter().any(|v| !v.is_finite()) {
235            return Err(FusionError::PhysicsViolation(
236                "Local block contains non-finite values".to_string(),
237            ));
238        }
239        if local.ncols() != ncols {
240            return Err(FusionError::PhysicsViolation(format!(
241                "Local ncols mismatch: expected {ncols}, got {}",
242                local.ncols()
243            )));
244        }
245        let core_start = usize::from(sdef.z_start > 0) * sdef.halo;
246        let core_end = core_start + sdef.local_nz;
247        if core_end > local.nrows() {
248            return Err(FusionError::PhysicsViolation(format!(
249                "Local core range out of bounds: rows={}, core_end={core_end}",
250                local.nrows()
251            )));
252        }
253        global
254            .slice_mut(s![sdef.z_start..sdef.z_end, ..])
255            .assign(&local.slice(s![core_start..core_end, ..]));
256    }
257    if global.iter().any(|v| !v.is_finite()) {
258        return Err(FusionError::PhysicsViolation(
259            "Stitched global array contains non-finite values".to_string(),
260        ));
261    }
262    Ok(global)
263}
264
265pub fn serial_halo_exchange(
266    locals: &mut [Array2<f64>],
267    slices: &[DomainSlice],
268) -> FusionResult<()> {
269    if locals.len() != slices.len() {
270        return Err(FusionError::PhysicsViolation(format!(
271            "locals/slices mismatch: {} vs {}",
272            locals.len(),
273            slices.len()
274        )));
275    }
276    let mut top_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
277    let mut bottom_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
278    for (i, (local, sdef)) in locals.iter().zip(slices.iter()).enumerate() {
279        if local.iter().any(|v| !v.is_finite()) {
280            return Err(FusionError::PhysicsViolation(format!(
281                "Local block at index {i} contains non-finite values"
282            )));
283        }
284        if sdef.halo == 0 {
285            continue;
286        }
287        let (top, bottom) = pack_halo_rows(local, sdef.halo)?;
288        top_send[i] = Some(top);
289        bottom_send[i] = Some(bottom);
290    }
291
292    for i in 0..locals.len() {
293        let halo = slices[i].halo;
294        if halo == 0 {
295            continue;
296        }
297        let recv_top = if i > 0 {
298            bottom_send[i - 1].as_ref()
299        } else {
300            None
301        };
302        let recv_bottom = if i + 1 < locals.len() {
303            top_send[i + 1].as_ref()
304        } else {
305            None
306        };
307        apply_halo_rows(&mut locals[i], halo, recv_top, recv_bottom)?;
308    }
309    if locals.iter().any(|arr| arr.iter().any(|v| !v.is_finite())) {
310        return Err(FusionError::PhysicsViolation(
311            "Serial halo exchange produced non-finite values".to_string(),
312        ));
313    }
314    Ok(())
315}
316
317pub fn l2_norm_delta(a: &Array2<f64>, b: &Array2<f64>) -> FusionResult<f64> {
318    if a.dim() != b.dim() {
319        return Err(FusionError::PhysicsViolation(format!(
320            "l2_norm_delta shape mismatch {:?} vs {:?}",
321            a.dim(),
322            b.dim()
323        )));
324    }
325    if a.iter().any(|v| !v.is_finite()) || b.iter().any(|v| !v.is_finite()) {
326        return Err(FusionError::PhysicsViolation(
327            "l2_norm_delta inputs must be finite".to_string(),
328        ));
329    }
330    let mut accum = 0.0f64;
331    for (av, bv) in a.iter().zip(b.iter()) {
332        let d = av - bv;
333        accum += d * d;
334    }
335    let out = accum.sqrt();
336    if !out.is_finite() {
337        return Err(FusionError::PhysicsViolation(
338            "l2_norm_delta produced non-finite result".to_string(),
339        ));
340    }
341    Ok(out)
342}
343
344// ═══════════════════════════════════════════════════════════════════════
345// 2D Cartesian Domain Decomposition — Exascale-Ready MPI Abstraction
346// ═══════════════════════════════════════════════════════════════════════
347//
348// The key innovation: we decompose the (nz × nr) GS grid into a 2D
349// Cartesian process grid (pz × pr), where each tile owns a contiguous
350// sub-block plus halo cells on all four faces. The Rayon threadpool
351// simulates distributed-memory ranks; replacing with rsmpi is a 1:1
352// swap of the halo exchange function.
353//
354// References:
355//   - Jardin, "Computational Methods in Plasma Physics", Ch. 12
356//   - Lao et al., "Equilibrium analysis of current profiles in tokamaks",
357//     Nucl. Fusion 30 (1990) 1035
358//   - EFIT-AI: Joung et al., Nucl. Fusion 63 (2023) 126058
359
/// 2D Cartesian tile descriptor — one per rank in a (pz × pr) topology.
///
/// The tile owns the global block `[z_start, z_end) × [r_start, r_end)`. Its
/// padded local array adds `halo` cells on each face that has a neighbouring
/// tile and none on faces lying on the global boundary.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CartesianTile {
    /// Linear rank index (0 .. pz*pr - 1); `decompose_2d` assigns
    /// `pz_idx * pr + pr_idx`.
    pub rank: usize,
    /// Process grid index along Z (row).
    pub pz_idx: usize,
    /// Process grid index along R (column).
    pub pr_idx: usize,
    /// Number of process-grid rows (tiles along Z).
    pub pz: usize,
    /// Number of process-grid columns (tiles along R).
    pub pr: usize,
    /// Global grid rows.
    pub global_nz: usize,
    /// Global grid columns.
    pub global_nr: usize,
    /// Halo width, applied on every face that has a neighbour.
    pub halo: usize,
    /// Owned Z range [z_start, z_end) in global indexing.
    pub z_start: usize,
    pub z_end: usize,
    /// Owned R range [r_start, r_end) in global indexing.
    pub r_start: usize,
    pub r_end: usize,
}

impl CartesianTile {
    /// Number of owned Z rows (excluding halo).
    pub fn local_nz(&self) -> usize {
        self.z_end - self.z_start
    }

    /// Number of owned R columns (excluding halo).
    pub fn local_nr(&self) -> usize {
        self.r_end - self.r_start
    }

    /// Total rows including the top and bottom halo bands that exist.
    pub fn padded_nz(&self) -> usize {
        let bottom = if self.has_neighbor_bottom() { self.halo } else { 0 };
        self.core_z_offset() + self.local_nz() + bottom
    }

    /// Total columns including the left and right halo bands that exist.
    pub fn padded_nr(&self) -> usize {
        let right = if self.has_neighbor_right() { self.halo } else { 0 };
        self.core_r_offset() + self.local_nr() + right
    }

    /// Offset of the first owned row within the padded local array.
    pub fn core_z_offset(&self) -> usize {
        if self.has_neighbor_top() {
            self.halo
        } else {
            0
        }
    }

    /// Offset of the first owned column within the padded local array.
    pub fn core_r_offset(&self) -> usize {
        if self.has_neighbor_left() {
            self.halo
        } else {
            0
        }
    }

    /// Whether a tile exists at process-grid row `pz_idx - 1`.
    pub fn has_neighbor_top(&self) -> bool {
        self.pz_idx > 0
    }

    /// Whether a tile exists at process-grid row `pz_idx + 1`.
    pub fn has_neighbor_bottom(&self) -> bool {
        self.pz_idx + 1 < self.pz
    }

    /// Whether a tile exists at process-grid column `pr_idx - 1`.
    pub fn has_neighbor_left(&self) -> bool {
        self.pr_idx > 0
    }

    /// Whether a tile exists at process-grid column `pr_idx + 1`.
    pub fn has_neighbor_right(&self) -> bool {
        self.pr_idx + 1 < self.pr
    }

    /// Rank of a neighbour at offset (dz, dr) in the process grid.
    ///
    /// Returns `None` if the offset leaves the grid. Arithmetic is checked,
    /// so extreme offsets (e.g. `i32::MAX`) yield `None` instead of
    /// overflowing, and indices are never truncated through an `i32` cast.
    pub fn neighbor_rank(&self, dz: i32, dr: i32) -> Option<usize> {
        let shift = |idx: usize, delta: i32| -> Option<usize> {
            let moved = i64::try_from(idx).ok()?.checked_add(i64::from(delta))?;
            usize::try_from(moved).ok()
        };
        let iz = shift(self.pz_idx, dz)?;
        let ir = shift(self.pr_idx, dr)?;
        if iz < self.pz && ir < self.pr {
            Some(iz * self.pr + ir)
        } else {
            None
        }
    }
}
453
454/// Decompose a 2D grid of shape (global_nz × global_nr) into a
455/// (pz × pr) Cartesian process topology.
456///
457/// Returns tiles in row-major order: tile[iz * pr + ir].
458pub fn decompose_2d(
459    global_nz: usize,
460    global_nr: usize,
461    pz: usize,
462    pr: usize,
463    halo: usize,
464) -> FusionResult<Vec<CartesianTile>> {
465    if pz == 0 || pr == 0 {
466        return Err(FusionError::PhysicsViolation(
467            "Process grid dimensions pz, pr must be >= 1".to_string(),
468        ));
469    }
470    if pz > global_nz || pr > global_nr {
471        return Err(FusionError::PhysicsViolation(format!(
472            "Cannot split ({global_nz}×{global_nr}) across ({pz}×{pr}) processes"
473        )));
474    }
475    if global_nz < 2 || global_nr < 2 {
476        return Err(FusionError::PhysicsViolation(
477            "Global grid must be at least 2×2".to_string(),
478        ));
479    }
480
481    // Distribute rows/columns as evenly as possible.
482    let z_splits = balanced_split(global_nz, pz);
483    let r_splits = balanced_split(global_nr, pr);
484
485    let mut tiles = Vec::with_capacity(pz * pr);
486    let mut z_cursor = 0usize;
487    for (iz, nz_local) in z_splits.iter().copied().enumerate().take(pz) {
488        let z_start = z_cursor;
489        let z_end = z_start + nz_local;
490        z_cursor = z_end;
491
492        let mut r_cursor = 0usize;
493        for (ir, nr_local) in r_splits.iter().copied().enumerate().take(pr) {
494            let r_start = r_cursor;
495            let r_end = r_start + nr_local;
496            r_cursor = r_end;
497
498            tiles.push(CartesianTile {
499                rank: iz * pr + ir,
500                pz_idx: iz,
501                pr_idx: ir,
502                pz,
503                pr,
504                global_nz,
505                global_nr,
506                halo,
507                z_start,
508                z_end,
509                r_start,
510                r_end,
511            });
512        }
513    }
514    Ok(tiles)
515}
516
/// Helper: split `n` items across `k` buckets as evenly as possible.
///
/// Every bucket gets `n / k` items and the first `n % k` buckets get one
/// more. Panics (division by zero) when `k == 0`; callers validate `k >= 1`.
fn balanced_split(n: usize, k: usize) -> Vec<usize> {
    let (quot, extra) = (n / k, n % k);
    let mut sizes = vec![quot; k];
    for size in &mut sizes[..extra] {
        *size += 1;
    }
    sizes
}
523
524/// Extract a padded local tile (with halo) from the global array.
525pub fn extract_tile(global: &Array2<f64>, tile: &CartesianTile) -> FusionResult<Array2<f64>> {
526    let (gnz, gnr) = global.dim();
527    if gnz != tile.global_nz || gnr != tile.global_nr {
528        return Err(FusionError::PhysicsViolation(format!(
529            "Global shape ({gnz},{gnr}) doesn't match tile expectation ({},{})",
530            tile.global_nz, tile.global_nr
531        )));
532    }
533    let z0 = tile.z_start.saturating_sub(tile.core_z_offset());
534    let z1 = (tile.z_end
535        + if tile.has_neighbor_bottom() {
536            tile.halo
537        } else {
538            0
539        })
540    .min(gnz);
541    let r0 = tile.r_start.saturating_sub(tile.core_r_offset());
542    let r1 = (tile.r_end
543        + if tile.has_neighbor_right() {
544            tile.halo
545        } else {
546            0
547        })
548    .min(gnr);
549
550    let local = global.slice(s![z0..z1, r0..r1]).to_owned();
551    if local.iter().any(|v| !v.is_finite()) {
552        return Err(FusionError::PhysicsViolation(
553            "Extracted tile contains non-finite values".to_string(),
554        ));
555    }
556    Ok(local)
557}
558
559/// Write the core (non-halo) region of a local tile back into the global array.
560pub fn inject_tile(
561    global: &mut Array2<f64>,
562    local: &Array2<f64>,
563    tile: &CartesianTile,
564) -> FusionResult<()> {
565    let cz = tile.core_z_offset();
566    let cr = tile.core_r_offset();
567    let lnz = tile.local_nz();
568    let lnr = tile.local_nr();
569    if cz + lnz > local.nrows() || cr + lnr > local.ncols() {
570        return Err(FusionError::PhysicsViolation(format!(
571            "Tile core out of bounds: local shape {:?}, core ({cz}+{lnz}, {cr}+{lnr})",
572            local.dim()
573        )));
574    }
575    let core = local.slice(s![cz..(cz + lnz), cr..(cr + lnr)]);
576    if core.iter().any(|v| !v.is_finite()) {
577        return Err(FusionError::PhysicsViolation(
578            "Tile core to inject contains non-finite values".to_string(),
579        ));
580    }
581    global
582        .slice_mut(s![tile.z_start..tile.z_end, tile.r_start..tile.r_end])
583        .assign(&core);
584    Ok(())
585}
586
587/// Serial 2D halo exchange across all tiles.
588///
589/// Copies owned boundary rows/columns from each tile into the halo
590/// region of its four face-neighbours. This is the serial reference
591/// implementation; the MPI version replaces this with non-blocking
592/// Isend/Irecv pairs.
593pub fn serial_halo_exchange_2d(
594    locals: &mut [Array2<f64>],
595    tiles: &[CartesianTile],
596) -> FusionResult<()> {
597    if locals.len() != tiles.len() {
598        return Err(FusionError::PhysicsViolation(format!(
599            "locals/tiles length mismatch: {} vs {}",
600            locals.len(),
601            tiles.len()
602        )));
603    }
604    let ntiles = tiles.len();
605    if ntiles == 0 {
606        return Ok(());
607    }
608    let halo = tiles[0].halo;
609    if halo == 0 {
610        return Ok(());
611    }
612
613    // Collect halo strips from all tiles first (immutable borrows).
614    // Each entry: (dest_rank, HaloFace, data).
615    #[derive(Clone, Copy)]
616    enum Face {
617        Top,
618        Bottom,
619        Left,
620        Right,
621    }
622    let mut messages: Vec<(usize, Face, Array2<f64>)> = Vec::new();
623
624    for (i, tile) in tiles.iter().enumerate() {
625        let loc = &locals[i];
626        let cz = tile.core_z_offset();
627        let cr = tile.core_r_offset();
628        let lnz = tile.local_nz();
629        let lnr = tile.local_nr();
630
631        // Send top face → neighbor above
632        if let Some(dest) = tile.neighbor_rank(-1, 0) {
633            let strip = loc.slice(s![cz..(cz + halo), cr..(cr + lnr)]).to_owned();
634            messages.push((dest, Face::Bottom, strip));
635        }
636        // Send bottom face → neighbor below
637        if let Some(dest) = tile.neighbor_rank(1, 0) {
638            let strip = loc
639                .slice(s![(cz + lnz - halo)..(cz + lnz), cr..(cr + lnr)])
640                .to_owned();
641            messages.push((dest, Face::Top, strip));
642        }
643        // Send left face → neighbor to the left
644        if let Some(dest) = tile.neighbor_rank(0, -1) {
645            let strip = loc.slice(s![cz..(cz + lnz), cr..(cr + halo)]).to_owned();
646            messages.push((dest, Face::Right, strip));
647        }
648        // Send right face → neighbor to the right
649        if let Some(dest) = tile.neighbor_rank(0, 1) {
650            let strip = loc
651                .slice(s![cz..(cz + lnz), (cr + lnr - halo)..(cr + lnr)])
652                .to_owned();
653            messages.push((dest, Face::Left, strip));
654        }
655    }
656
657    // Apply received halos.
658    for (dest, face, data) in messages {
659        let tile = &tiles[dest];
660        let loc = &mut locals[dest];
661        let cz = tile.core_z_offset();
662        let cr = tile.core_r_offset();
663        let lnz = tile.local_nz();
664        let lnr = tile.local_nr();
665
666        match face {
667            Face::Top => {
668                // Fill top halo rows.
669                if cz >= halo {
670                    loc.slice_mut(s![(cz - halo)..cz, cr..(cr + lnr)])
671                        .assign(&data);
672                }
673            }
674            Face::Bottom => {
675                // Fill bottom halo rows.
676                let row_start = cz + lnz;
677                let row_end = row_start + halo;
678                if row_end <= loc.nrows() {
679                    loc.slice_mut(s![row_start..row_end, cr..(cr + lnr)])
680                        .assign(&data);
681                }
682            }
683            Face::Left => {
684                // Fill left halo columns.
685                if cr >= halo {
686                    loc.slice_mut(s![cz..(cz + lnz), (cr - halo)..cr])
687                        .assign(&data);
688                }
689            }
690            Face::Right => {
691                // Fill right halo columns.
692                let col_start = cr + lnr;
693                let col_end = col_start + halo;
694                if col_end <= loc.ncols() {
695                    loc.slice_mut(s![cz..(cz + lnz), col_start..col_end])
696                        .assign(&data);
697                }
698            }
699        }
700    }
701
702    // Verify no non-finite values crept in.
703    for loc in locals.iter() {
704        if loc.iter().any(|v| !v.is_finite()) {
705            return Err(FusionError::PhysicsViolation(
706                "2D halo exchange produced non-finite values".to_string(),
707            ));
708        }
709    }
710    Ok(())
711}
712
/// Configuration for the distributed GS solver.
#[derive(Debug, Clone)]
pub struct DistributedSolverConfig {
    /// Number of tiles along Z in the (pz × pr) process grid.
    ///
    /// Tiles are swept in parallel on the Rayon thread pool. When `pz * pr`
    /// does not exceed the thread count, every tile can be swept concurrently.
    pub pz: usize,
    /// Number of tiles along R in the (pz × pr) process grid.
    pub pr: usize,
    /// Halo width (number of overlap rows/columns per face). 1 is
    /// sufficient for the 5-point GS stencil.
    pub halo: usize,
    /// SOR relaxation parameter ω. The solver accepts any value in the open
    /// interval (0, 2); ω ∈ (1, 2) is over-relaxation. Default 1.8.
    pub omega: f64,
    /// Maximum number of Schwarz (outer) iterations.
    pub max_outer_iters: usize,
    /// Convergence tolerance: the solve stops once the global L2 residual
    /// drops below this value.
    pub tol: f64,
    /// Number of local red-black SOR sweeps per tile per Schwarz iteration.
    pub inner_sweeps: usize,
}

impl Default for DistributedSolverConfig {
    /// 2×2 tiles, halo 1, ω = 1.8, up to 200 outer iterations, tol 1e-8,
    /// 5 inner sweeps.
    fn default() -> Self {
        Self {
            pz: 2,
            pr: 2,
            halo: 1,
            omega: 1.8,
            max_outer_iters: 200,
            tol: 1e-8,
            inner_sweeps: 5,
        }
    }
}
746
747/// Result of a distributed GS solve.
748#[derive(Debug, Clone)]
749pub struct DistributedSolveResult {
750    /// Final global Ψ array.
751    pub psi: Array2<f64>,
752    /// Achieved global L2 residual.
753    pub residual: f64,
754    /// Number of Schwarz outer iterations used.
755    pub iterations: usize,
756    /// Whether the solve converged within tolerance.
757    pub converged: bool,
758}
759
760/// Distributed Grad-Shafranov solver using additive Schwarz domain
761/// decomposition with Rayon thread-parallelism.
762///
763/// Each Schwarz iteration:
764/// 1. Splits the global Ψ into 2D tiles (with halo overlap).
765/// 2. Runs `inner_sweeps` local Red-Black SOR sweeps on each tile
766///    in parallel via Rayon.
767/// 3. Injects tile cores back into the global array.
768/// 4. Exchanges halos (serial reference — MPI-ready interface).
769/// 5. Computes global residual; checks convergence.
770///
771/// The 5-point GS stencil:
772///   R d/dR(1/R dΨ/dR) + d²Ψ/dZ² = -μ₀ R J_φ
773///
774/// discretises to the same operator as `fusion_math::sor::sor_step`,
775/// but applied independently on each tile with local boundary from
776/// halo data.
777pub fn distributed_gs_solve(
778    psi: &Array2<f64>,
779    source: &Array2<f64>,
780    r_axis: &[f64],
781    z_axis: &[f64],
782    dr: f64,
783    dz: f64,
784    cfg: &DistributedSolverConfig,
785) -> FusionResult<DistributedSolveResult> {
786    let (nz, nr) = psi.dim();
787    if source.dim() != (nz, nr) {
788        return Err(FusionError::PhysicsViolation(format!(
789            "psi/source shape mismatch: {:?} vs {:?}",
790            psi.dim(),
791            source.dim()
792        )));
793    }
794    if r_axis.len() != nr || z_axis.len() != nz {
795        return Err(FusionError::PhysicsViolation(format!(
796            "Axis lengths don't match grid: r_axis={} nr={nr}, z_axis={} nz={nz}",
797            r_axis.len(),
798            z_axis.len()
799        )));
800    }
801    if cfg.omega <= 0.0 || cfg.omega >= 2.0 {
802        return Err(FusionError::PhysicsViolation(format!(
803            "SOR omega must be in (0, 2), got {}",
804            cfg.omega
805        )));
806    }
807    if !dr.is_finite() || !dz.is_finite() || dr <= 0.0 || dz <= 0.0 {
808        return Err(FusionError::PhysicsViolation(format!(
809            "Grid spacing must be finite > 0: dr={dr}, dz={dz}"
810        )));
811    }
812
813    let tiles = decompose_2d(nz, nr, cfg.pz, cfg.pr, cfg.halo)?;
814
815    let mut global_psi = psi.clone();
816    let dr_sq = dr * dr;
817    let dz_sq = dz * dz;
818
819    let mut converged = false;
820    let mut outer_iter = 0usize;
821    let mut residual = f64::MAX;
822
823    for outer in 0..cfg.max_outer_iters {
824        outer_iter = outer + 1;
825
826        // 1. Extract tiles with halo from current global Ψ.
827        let mut locals: Vec<Array2<f64>> = tiles
828            .iter()
829            .map(|t| extract_tile(&global_psi, t))
830            .collect::<FusionResult<Vec<_>>>()?;
831        let sources: Vec<Array2<f64>> = tiles
832            .iter()
833            .map(|t| extract_tile(source, t))
834            .collect::<FusionResult<Vec<_>>>()?;
835
836        // 2. Run local SOR sweeps in parallel via Rayon.
837        //    Each tile applies inner_sweeps Red-Black SOR iterations.
838        use rayon::prelude::*;
839        locals
840            .par_iter_mut()
841            .zip(sources.par_iter())
842            .zip(tiles.par_iter())
843            .for_each(|((loc, src), tile)| {
844                let cz = tile.core_z_offset();
845                let cr = tile.core_r_offset();
846                let _lnz = tile.local_nz();
847                let _lnr = tile.local_nr();
848
849                for _sweep in 0..cfg.inner_sweeps {
850                    // Red pass then black pass.
851                    for color in 0..2u8 {
852                        for iz in 1..loc.nrows().saturating_sub(1) {
853                            for ir in 1..loc.ncols().saturating_sub(1) {
854                                if (iz + ir) % 2 != color as usize {
855                                    continue;
856                                }
857                                // Only update interior of the owned core
858                                // (and the halo interior that overlaps with
859                                // a neighbor's owned region).
860                                // Map local (iz, ir) to global indices.
861                                let gz = tile.z_start as i64 + iz as i64 - cz as i64;
862                                let gr = tile.r_start as i64 + ir as i64 - cr as i64;
863                                // Skip global boundaries.
864                                if gz <= 0 || gz >= (nz as i64 - 1) {
865                                    continue;
866                                }
867                                if gr <= 0 || gr >= (nr as i64 - 1) {
868                                    continue;
869                                }
870
871                                let r_val = r_axis[gr as usize];
872                                if r_val <= 0.0 {
873                                    continue;
874                                }
875
876                                // GS 5-point stencil with 1/R correction.
877                                // Coefficients follow sor.rs update_point():
878                                //   c_r_plus  = 1/dr² - 1/(2R·dr) → psi_east (ir+1)
879                                //   c_r_minus = 1/dr² + 1/(2R·dr) → psi_west (ir-1)
880                                let psi_n = loc[[iz - 1, ir]];
881                                let psi_s = loc[[iz + 1, ir]];
882                                let psi_w = loc[[iz, ir - 1]];
883                                let psi_e = loc[[iz, ir + 1]];
884
885                                let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
886                                let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
887                                let c_z = 1.0 / dz_sq;
888                                let center = 2.0 / dr_sq + 2.0 / dz_sq;
889
890                                let rhs_val = src[[iz, ir]];
891                                let numerator = c_z * (psi_n + psi_s)
892                                    + c_r_minus * psi_w
893                                    + c_r_plus * psi_e
894                                    + rhs_val;
895                                if center.abs() < 1e-30 {
896                                    continue;
897                                }
898                                let psi_gs = numerator / center;
899                                let old = loc[[iz, ir]];
900                                loc[[iz, ir]] = old + cfg.omega * (psi_gs - old);
901                            }
902                        }
903                    }
904                }
905            });
906
907        // 3. Inject tile cores back into global array.
908        for (i, tile) in tiles.iter().enumerate() {
909            inject_tile(&mut global_psi, &locals[i], tile)?;
910        }
911
912        // 4. Compute global L2 residual of the GS equation.
913        residual = gs_residual_l2(&global_psi, source, r_axis, dr, dz);
914
915        if residual < cfg.tol {
916            converged = true;
917            break;
918        }
919    }
920
921    if global_psi.iter().any(|v| !v.is_finite()) {
922        return Err(FusionError::PhysicsViolation(
923            "Distributed GS solve produced non-finite Ψ".to_string(),
924        ));
925    }
926
927    Ok(DistributedSolveResult {
928        psi: global_psi,
929        residual,
930        iterations: outer_iter,
931        converged,
932    })
933}
934
935/// Compute the L2 norm of the GS residual: ||LΨ - f||₂.
936///
937/// L is the 5-point GS operator:
938///   LΨ = (Ψ_{i-1,j} + Ψ_{i+1,j})/dz² + (1/dr² - 1/(2R dr))Ψ_{i,j-1}
939///        + (1/dr² + 1/(2R dr))Ψ_{i,j+1} - 2(1/dr² + 1/dz²)Ψ_{i,j}
940pub fn gs_residual_l2(
941    psi: &Array2<f64>,
942    source: &Array2<f64>,
943    r_axis: &[f64],
944    dr: f64,
945    dz: f64,
946) -> f64 {
947    let (nz, nr) = psi.dim();
948    let dr_sq = dr * dr;
949    let dz_sq = dz * dz;
950    let mut accum = 0.0f64;
951    let mut count = 0usize;
952
953    for iz in 1..nz - 1 {
954        for ir in 1..nr - 1 {
955            let r_val = r_axis[ir];
956            if r_val <= 0.0 {
957                continue;
958            }
959            let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
960            let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
961            let c_z = 1.0 / dz_sq;
962            let center = 2.0 / dr_sq + 2.0 / dz_sq;
963
964            let l_psi = center * psi[[iz, ir]]
965                - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
966                - c_r_plus * psi[[iz, ir + 1]]
967                - c_r_minus * psi[[iz, ir - 1]];
968            let res = l_psi - source[[iz, ir]];
969            accum += res * res;
970            count += 1;
971        }
972    }
973    if count == 0 {
974        return 0.0;
975    }
976    (accum / count as f64).sqrt()
977}
978
/// Optimal process-grid factorisation for a given (nz, nr) global grid
/// and total number of available ranks.
///
/// Among all `(pz, pr)` with `pz * pr == nranks`, `pz <= nz` and `pr <= nr`, this picks
/// the layout that minimises the perimeter-to-area ratio of one tile, a proxy for the
/// halo communication overhead. Ties go to the smaller `pz`.
///
/// If no factorisation of `nranks` fits the grid (for example a prime rank count larger
/// than both `nz` and `nr`), the largest rank count below `nranks` that does fit is used
/// instead. In that case `pz * pr < nranks`, so callers must not assume the product
/// equals `nranks`. A value of `nranks == 0` is treated as one rank. An empty grid
/// (`nz == 0` or `nr == 0`) yields `(1, 1)`.
pub fn optimal_process_grid(nz: usize, nr: usize, nranks: usize) -> (usize, usize) {
    // Every tile owns at least one cell, so no layout can use more than nz * nr
    // ranks. Starting the search there keeps the fallback loop bounded.
    let upper = nranks.min(nz.saturating_mul(nr)).max(1);
    (1..=upper)
        .rev()
        .find_map(|n| best_exact_process_grid(nz, nr, n))
        .unwrap_or((1, 1))
}

/// Best `(pz, pr)` with `pz * pr == nranks` that fits inside an `nz x nr` grid, if any.
fn best_exact_process_grid(nz: usize, nr: usize, nranks: usize) -> Option<(usize, usize)> {
    let mut best: Option<((usize, usize), f64)> = None;
    // A tile cannot be thinner than one row, so pz never needs to exceed nz.
    for pz in 1..=nranks.min(nz) {
        if !nranks.is_multiple_of(pz) {
            continue;
        }
        let pr = nranks / pz;
        if pr > nr {
            continue;
        }
        // Surface-to-volume ratio proxy: perimeter / area of each tile.
        let tile_nz = nz as f64 / pz as f64;
        let tile_nr = nr as f64 / pr as f64;
        let perimeter = 2.0 * (tile_nz + tile_nr);
        let area = tile_nz * tile_nr;
        let cost = perimeter / area;
        // Strict `<` keeps the first (smallest pz) candidate on ties.
        if best.is_none_or(|(_, best_cost)| cost < best_cost) {
            best = Some(((pz, pr), cost));
        }
    }
    best.map(|(grid, _)| grid)
}
1009
1010/// Convenience: solve GS with automatic process-grid selection.
1011///
1012/// Detects the Rayon thread count and picks the optimal (pz, pr)
1013/// factorisation. This is the top-level entry point for exascale-ready
1014/// distributed equilibrium solving.
1015#[allow(clippy::too_many_arguments)]
1016pub fn auto_distributed_gs_solve(
1017    psi: &Array2<f64>,
1018    source: &Array2<f64>,
1019    r_axis: &[f64],
1020    z_axis: &[f64],
1021    dr: f64,
1022    dz: f64,
1023    omega: f64,
1024    tol: f64,
1025    max_iters: usize,
1026) -> FusionResult<DistributedSolveResult> {
1027    let (nz, nr) = psi.dim();
1028    let nthreads = rayon::current_num_threads().max(1);
1029    let (pz, pr) = optimal_process_grid(nz, nr, nthreads);
1030
1031    let cfg = DistributedSolverConfig {
1032        pz,
1033        pr,
1034        halo: 1,
1035        omega,
1036        max_outer_iters: max_iters,
1037        tol,
1038        inner_sweeps: 5,
1039    };
1040    distributed_gs_solve(psi, source, r_axis, z_axis, dr, dz, &cfg)
1041}
1042
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-symmetric test grid: value = 10 * row + col.
    fn sample_grid(nz: usize, nr: usize) -> Array2<f64> {
        Array2::from_shape_fn((nz, nr), |(i, j)| (i as f64) * 10.0 + j as f64)
    }

    #[test]
    fn test_decompose_z_covers_domain() {
        let slices = decompose_z(17, 4, 1).expect("decomposition must succeed");
        assert_eq!(slices.len(), 4);
        assert_eq!(slices[0].z_start, 0);
        assert_eq!(slices.last().expect("slice expected").z_end, 17);
        let covered: usize = slices.iter().map(|s| s.local_nz).sum();
        assert_eq!(covered, 17);
    }

    #[test]
    fn test_serial_halo_exchange_and_stitch_roundtrip() {
        let global = sample_grid(24, 9);
        let slices = decompose_z(global.nrows(), 3, 1).expect("decompose");
        let mut locals = split_with_halo(&global, &slices).expect("split");
        serial_halo_exchange(&mut locals, &slices).expect("exchange");
        let stitched = stitch_without_halo(&locals, &slices, global.ncols()).expect("stitch");
        let delta = l2_norm_delta(&stitched, &global).expect("delta");
        assert!(
            delta < 1e-12,
            "Serial halo exchange should preserve core rows"
        );
    }

    #[test]
    fn test_pack_halo_errors_for_small_local_block() {
        let local = Array2::zeros((2, 4));
        let err = pack_halo_rows(&local, 1).expect_err("small local must error");
        match err {
            FusionError::PhysicsViolation(msg) => {
                assert!(msg.contains("insufficient rows"));
            }
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_apply_halo_shape_guard() {
        let mut local = Array2::zeros((6, 4));
        let bad_top = Array2::zeros((2, 5));
        let err = apply_halo_rows(&mut local, 1, Some(&bad_top), None).expect_err("shape mismatch");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("shape mismatch")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_l2_norm_delta_zero_for_identical_arrays() {
        let a = sample_grid(8, 8);
        let b = a.clone();
        let d = l2_norm_delta(&a, &b).expect("delta");
        assert!(d.abs() < 1e-12);
    }

    #[test]
    fn test_mpi_domain_rejects_non_finite_inputs() {
        let mut local = sample_grid(8, 4);
        local[[2, 2]] = f64::NAN;
        let err = pack_halo_rows(&local, 1).expect_err("non-finite local should fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("non-finite")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let mut a = sample_grid(4, 4);
        let b = sample_grid(4, 4);
        a[[0, 0]] = f64::INFINITY;
        let err = l2_norm_delta(&a, &b).expect_err("non-finite delta input should fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("finite")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    // ═══════════════════════════════════════════════════════════════════
    // 2D Cartesian Decomposition Tests
    // ═══════════════════════════════════════════════════════════════════

    #[test]
    fn test_decompose_2d_covers_full_domain() {
        let tiles = decompose_2d(32, 24, 4, 3, 1).expect("decompose_2d");
        // 4 tiles along Z times 3 tiles along R.
        assert_eq!(tiles.len(), 12);
        // Every global (iz, ir) must be owned by exactly one tile.
        let mut coverage = Array2::<u8>::zeros((32, 24));
        for t in &tiles {
            for iz in t.z_start..t.z_end {
                for ir in t.r_start..t.r_end {
                    coverage[[iz, ir]] += 1;
                }
            }
        }
        assert!(
            coverage.iter().all(|&c| c == 1),
            "Every cell owned by exactly one tile"
        );
    }

    #[test]
    fn test_decompose_2d_single_rank() {
        let tiles = decompose_2d(16, 16, 1, 1, 2).expect("single rank");
        assert_eq!(tiles.len(), 1);
        let t = &tiles[0];
        assert_eq!(t.z_start, 0);
        assert_eq!(t.z_end, 16);
        assert_eq!(t.r_start, 0);
        assert_eq!(t.r_end, 16);
        assert_eq!(t.padded_nz(), 16);
        assert_eq!(t.padded_nr(), 16);
    }

    #[test]
    fn test_decompose_2d_neighbor_ranks() {
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("2x2");
        // Top-left corner (0,0): has bottom and right neighbours.
        let tl = &tiles[0];
        assert_eq!(tl.neighbor_rank(1, 0), Some(2));
        assert_eq!(tl.neighbor_rank(0, 1), Some(1));
        assert_eq!(tl.neighbor_rank(-1, 0), None);
        assert_eq!(tl.neighbor_rank(0, -1), None);
        // Bottom-right corner (1,1).
        let br = &tiles[3];
        assert_eq!(br.neighbor_rank(-1, 0), Some(1));
        assert_eq!(br.neighbor_rank(0, -1), Some(2));
        assert_eq!(br.neighbor_rank(1, 0), None);
        assert_eq!(br.neighbor_rank(0, 1), None);
    }

    #[test]
    fn test_extract_inject_roundtrip() {
        let global = sample_grid(24, 18);
        let tiles = decompose_2d(24, 18, 3, 2, 1).expect("decompose");
        let mut reconstructed = Array2::<f64>::zeros((24, 18));
        for t in &tiles {
            let local = extract_tile(&global, t).expect("extract");
            inject_tile(&mut reconstructed, &local, t).expect("inject");
        }
        let delta = l2_norm_delta(&global, &reconstructed).expect("delta");
        assert!(
            delta < 1e-12,
            "Extract→inject roundtrip must be lossless, got delta={delta}"
        );
    }

    #[test]
    fn test_serial_halo_exchange_2d_correctness() {
        // Create a global array, split into tiles, exchange halos,
        // then verify that halo cells match the original global data.
        let global = Array2::from_shape_fn((16, 16), |(i, j)| (i as f64) * 100.0 + j as f64);
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("decompose");
        let mut locals: Vec<Array2<f64>> = tiles
            .iter()
            .map(|t| extract_tile(&global, t).expect("extract"))
            .collect();
        serial_halo_exchange_2d(&mut locals, &tiles).expect("halo exchange");

        // For each tile, verify that every cell (including halo)
        // matches the original global value.
        for (i, t) in tiles.iter().enumerate() {
            let loc = &locals[i];
            let z0 = t.z_start.saturating_sub(t.core_z_offset());
            let r0 = t.r_start.saturating_sub(t.core_r_offset());
            for lz in 0..loc.nrows() {
                for lr in 0..loc.ncols() {
                    let gz = z0 + lz;
                    let gr = r0 + lr;
                    if gz < 16 && gr < 16 {
                        let expect = global[[gz, gr]];
                        let got = loc[[lz, lr]];
                        assert!(
                            (expect - got).abs() < 1e-12,
                            "Tile {i} at local ({lz},{lr}) global ({gz},{gr}): expected {expect}, got {got}"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_optimal_process_grid_square() {
        let (pz, pr) = optimal_process_grid(64, 64, 4);
        assert_eq!(pz, 2);
        assert_eq!(pr, 2);
    }

    #[test]
    fn test_optimal_process_grid_rectangular() {
        // Tall grid (nz >> nr): should put more processes along Z.
        let (pz, pr) = optimal_process_grid(128, 32, 8);
        assert!(
            pz >= pr,
            "Tall grid should bias toward Z decomposition: pz={pz}, pr={pr}"
        );
        assert_eq!(pz * pr, 8);
    }

    #[test]
    fn test_gs_residual_l2_zero_for_exact_solution() {
        // Manufacture source = L Ψ for a known Ψ, then verify residual ≈ 0.
        // Using the same convention as sor.rs:
        //   L Ψ = center*Ψ - c_z*(Ψ_up + Ψ_dn) - c_r_plus*Ψ_right - c_r_minus*Ψ_left
        let nz = 16;
        let nr = 16;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.08 + i as f64 * dz).collect();
        let psi = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            z_axis[iz] * z_axis[iz] + r_axis[ir] * r_axis[ir]
        });
        let dr_sq = dr * dr;
        let dz_sq = dz * dz;
        let mut source = Array2::zeros((nz, nr));
        for iz in 1..nz - 1 {
            for ir in 1..nr - 1 {
                let r = r_axis[ir];
                let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r * dr);
                let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r * dr);
                let c_z = 1.0 / dz_sq;
                let center = 2.0 / dr_sq + 2.0 / dz_sq;
                source[[iz, ir]] = center * psi[[iz, ir]]
                    - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
                    - c_r_plus * psi[[iz, ir + 1]]
                    - c_r_minus * psi[[iz, ir - 1]];
            }
        }
        let res = gs_residual_l2(&psi, &source, &r_axis, dr, dz);
        assert!(
            res < 1e-10,
            "Residual should be ~0 for manufactured solution, got {res}"
        );
    }

    #[test]
    fn test_gs_residual_l2_zero_without_interior_nodes() {
        // A 2 x 2 grid has no interior node, so nothing is evaluated and the
        // residual is reported as zero regardless of psi and source.
        let psi = sample_grid(2, 2);
        let source = Array2::from_elem((2, 2), 1.0);
        let r_axis = [1.0, 1.01];
        let res = gs_residual_l2(&psi, &source, &r_axis, 0.01, 0.01);
        assert!(
            res.abs() < 1e-15,
            "Grid without interior nodes must give zero residual, got {res}"
        );
    }

    #[test]
    fn test_distributed_gs_solve_smoke() {
        // Solve LΨ = f on a small grid and verify residual decreases.
        let nz = 32;
        let nr = 32;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.16 + i as f64 * dz).collect();

        // Use a moderate source (amplitude 1.0, not 100).
        let psi = Array2::zeros((nz, nr));
        let r_mid = 1.0 + (nr as f64 / 2.0) * dr;
        let z_mid = 0.0;
        let source = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            let rr = r_axis[ir] - r_mid;
            let zz = z_axis[iz] - z_mid;
            -((-(rr * rr + zz * zz) / (0.05 * 0.05)).exp())
        });

        // Compute initial residual for comparison.
        let res_initial = gs_residual_l2(&psi, &source, &r_axis, dr, dz);

        let cfg = DistributedSolverConfig {
            pz: 2,
            pr: 2,
            halo: 1,
            omega: 1.6,
            max_outer_iters: 200,
            tol: 1e-6,
            inner_sweeps: 10,
        };
        let result = distributed_gs_solve(&psi, &source, &r_axis, &z_axis, dr, dz, &cfg)
            .expect("distributed solve");
        // Residual should have decreased significantly from initial.
        assert!(
            result.residual < res_initial * 0.5,
            "Residual should decrease: initial={res_initial:.4}, final={:.4}",
            result.residual
        );
        // Ψ should be non-trivial (negative source → negative interior values).
        let psi_absmax = result.psi.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
        assert!(
            psi_absmax > 1e-10,
            "Solution should be non-trivial, max |Ψ| = {psi_absmax}"
        );
    }

    #[test]
    fn test_decompose_2d_rejects_invalid_inputs() {
        assert!(decompose_2d(16, 16, 0, 2, 1).is_err());
        assert!(decompose_2d(16, 16, 2, 0, 1).is_err());
        assert!(decompose_2d(1, 16, 2, 2, 1).is_err());
        assert!(decompose_2d(16, 16, 20, 2, 1).is_err());
    }
}
1340}