fusion_core/
mpi_domain.rs

1// ─────────────────────────────────────────────────────────────────────
2// SCPN Fusion Core — MPI Domain Scaffolding
3// © 1998–2026 Miroslav Šotek. All rights reserved.
4// Contact: www.anulum.li | protoscience@anulum.li
5// ORCID: https://orcid.org/0009-0009-3560-0851
6// License: GNU AGPL v3 | Commercial licensing available
7// ─────────────────────────────────────────────────────────────────────
8//! MPI-oriented domain decomposition scaffolding.
9//!
10//! This module defines deterministic domain partition metadata and halo
11//! packing/exchange primitives that can be wired to rsmpi in a later phase.
12
13use fusion_types::error::{FusionError, FusionResult};
14use ndarray::{s, Array2};
15
/// One rank's contiguous slab of rows in a 1D (Z) domain decomposition.
///
/// The rank owns global rows `[z_start, z_end)`; up to `halo` ghost rows
/// may be attached on each side that borders another slab.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DomainSlice {
    /// Zero-based index of this rank.
    pub rank: usize,
    /// Total number of ranks taking part in the decomposition.
    pub nranks: usize,
    /// Row count of the undecomposed global array.
    pub global_nz: usize,
    /// Number of rows owned by this rank (ghost rows excluded).
    pub local_nz: usize,
    /// Requested ghost-row width per side.
    pub halo: usize,
    /// First owned global row (inclusive).
    pub z_start: usize,
    /// One past the last owned global row (exclusive).
    pub z_end: usize,
}
26
27impl DomainSlice {
28    pub fn has_upper_neighbor(&self) -> bool {
29        self.rank > 0
30    }
31
32    pub fn has_lower_neighbor(&self) -> bool {
33        self.rank + 1 < self.nranks
34    }
35}
36
37pub fn decompose_z(global_nz: usize, nranks: usize, halo: usize) -> FusionResult<Vec<DomainSlice>> {
38    if global_nz < 2 {
39        return Err(FusionError::PhysicsViolation(
40            "MPI decomposition requires global_nz >= 2".to_string(),
41        ));
42    }
43    if nranks < 1 {
44        return Err(FusionError::PhysicsViolation(
45            "MPI decomposition requires nranks >= 1".to_string(),
46        ));
47    }
48    if nranks > global_nz {
49        return Err(FusionError::PhysicsViolation(format!(
50            "Cannot split global_nz={global_nz} across nranks={nranks}"
51        )));
52    }
53
54    let base = global_nz / nranks;
55    let rem = global_nz % nranks;
56    let mut out = Vec::with_capacity(nranks);
57    let mut cursor = 0usize;
58    for rank in 0..nranks {
59        let local_nz = base + usize::from(rank < rem);
60        let z_start = cursor;
61        let z_end = z_start + local_nz;
62        cursor = z_end;
63        out.push(DomainSlice {
64            rank,
65            nranks,
66            global_nz,
67            local_nz,
68            halo,
69            z_start,
70            z_end,
71        });
72    }
73    Ok(out)
74}
75
76pub fn pack_halo_rows(
77    local: &Array2<f64>,
78    halo: usize,
79) -> FusionResult<(Array2<f64>, Array2<f64>)> {
80    if halo == 0 {
81        return Err(FusionError::PhysicsViolation(
82            "Halo width must be >= 1".to_string(),
83        ));
84    }
85    if local.nrows() <= 2 * halo {
86        return Err(FusionError::PhysicsViolation(format!(
87            "Local block has insufficient rows {} for halo={halo}",
88            local.nrows()
89        )));
90    }
91    if local.iter().any(|v| !v.is_finite()) {
92        return Err(FusionError::PhysicsViolation(
93            "Local block contains non-finite values".to_string(),
94        ));
95    }
96    let top = local.slice(s![halo..(2 * halo), ..]).to_owned();
97    let bottom = local
98        .slice(s![(local.nrows() - 2 * halo)..(local.nrows() - halo), ..])
99        .to_owned();
100    if top.iter().any(|v| !v.is_finite()) || bottom.iter().any(|v| !v.is_finite()) {
101        return Err(FusionError::PhysicsViolation(
102            "Packed halo rows contain non-finite values".to_string(),
103        ));
104    }
105    Ok((top, bottom))
106}
107
108pub fn apply_halo_rows(
109    local: &mut Array2<f64>,
110    halo: usize,
111    recv_top: Option<&Array2<f64>>,
112    recv_bottom: Option<&Array2<f64>>,
113) -> FusionResult<()> {
114    if halo == 0 {
115        return Err(FusionError::PhysicsViolation(
116            "Halo width must be >= 1".to_string(),
117        ));
118    }
119    if local.nrows() <= 2 * halo {
120        return Err(FusionError::PhysicsViolation(format!(
121            "Local block has insufficient rows {} for halo={halo}",
122            local.nrows()
123        )));
124    }
125    if local.iter().any(|v| !v.is_finite()) {
126        return Err(FusionError::PhysicsViolation(
127            "Local block contains non-finite values".to_string(),
128        ));
129    }
130    if let Some(top) = recv_top {
131        if top.dim() != (halo, local.ncols()) {
132            return Err(FusionError::PhysicsViolation(format!(
133                "Top halo shape mismatch: expected ({halo}, {}), got {:?}",
134                local.ncols(),
135                top.dim()
136            )));
137        }
138        if top.iter().any(|v| !v.is_finite()) {
139            return Err(FusionError::PhysicsViolation(
140                "Top halo contains non-finite values".to_string(),
141            ));
142        }
143        local.slice_mut(s![0..halo, ..]).assign(top);
144    }
145    if let Some(bottom) = recv_bottom {
146        if bottom.dim() != (halo, local.ncols()) {
147            return Err(FusionError::PhysicsViolation(format!(
148                "Bottom halo shape mismatch: expected ({halo}, {}), got {:?}",
149                local.ncols(),
150                bottom.dim()
151            )));
152        }
153        if bottom.iter().any(|v| !v.is_finite()) {
154            return Err(FusionError::PhysicsViolation(
155                "Bottom halo contains non-finite values".to_string(),
156            ));
157        }
158        let n = local.nrows();
159        local.slice_mut(s![(n - halo)..n, ..]).assign(bottom);
160    }
161    if local.iter().any(|v| !v.is_finite()) {
162        return Err(FusionError::PhysicsViolation(
163            "Applying halo rows produced non-finite values".to_string(),
164        ));
165    }
166    Ok(())
167}
168
169pub fn split_with_halo(
170    global: &Array2<f64>,
171    slices: &[DomainSlice],
172) -> FusionResult<Vec<Array2<f64>>> {
173    if slices.is_empty() {
174        return Err(FusionError::PhysicsViolation(
175            "No slices provided for split_with_halo".to_string(),
176        ));
177    }
178    if global.iter().any(|v| !v.is_finite()) {
179        return Err(FusionError::PhysicsViolation(
180            "Global array contains non-finite values".to_string(),
181        ));
182    }
183    let mut out = Vec::with_capacity(slices.len());
184    for sdef in slices {
185        if sdef.global_nz != global.nrows() {
186            return Err(FusionError::PhysicsViolation(format!(
187                "Slice/global mismatch: slice.global_nz={} global.nrows()={}",
188                sdef.global_nz,
189                global.nrows()
190            )));
191        }
192        if sdef.z_start >= sdef.z_end || sdef.z_end > sdef.global_nz {
193            return Err(FusionError::PhysicsViolation(format!(
194                "Invalid slice bounds z_start={} z_end={} global_nz={}",
195                sdef.z_start, sdef.z_end, sdef.global_nz
196            )));
197        }
198        let start = sdef.z_start.saturating_sub(sdef.halo);
199        let end = (sdef.z_end + sdef.halo).min(sdef.global_nz);
200        let mut local = Array2::zeros((end - start, global.ncols()));
201        local.assign(&global.slice(s![start..end, ..]));
202        if local.iter().any(|v| !v.is_finite()) {
203            return Err(FusionError::PhysicsViolation(
204                "Split local block contains non-finite values".to_string(),
205            ));
206        }
207        out.push(local);
208    }
209    Ok(out)
210}
211
212pub fn stitch_without_halo(
213    locals: &[Array2<f64>],
214    slices: &[DomainSlice],
215    ncols: usize,
216) -> FusionResult<Array2<f64>> {
217    if locals.len() != slices.len() {
218        return Err(FusionError::PhysicsViolation(format!(
219            "locals/slices mismatch: {} vs {}",
220            locals.len(),
221            slices.len()
222        )));
223    }
224    if slices.is_empty() {
225        return Err(FusionError::PhysicsViolation(
226            "No slices provided for stitch_without_halo".to_string(),
227        ));
228    }
229    let global_nz = slices
230        .last()
231        .map(|s| s.global_nz)
232        .ok_or_else(|| FusionError::PhysicsViolation("No slices provided".to_string()))?;
233    let mut global = Array2::zeros((global_nz, ncols));
234    for (local, sdef) in locals.iter().zip(slices.iter()) {
235        if local.iter().any(|v| !v.is_finite()) {
236            return Err(FusionError::PhysicsViolation(
237                "Local block contains non-finite values".to_string(),
238            ));
239        }
240        if local.ncols() != ncols {
241            return Err(FusionError::PhysicsViolation(format!(
242                "Local ncols mismatch: expected {ncols}, got {}",
243                local.ncols()
244            )));
245        }
246        let core_start = usize::from(sdef.z_start > 0) * sdef.halo;
247        let core_end = core_start + sdef.local_nz;
248        if core_end > local.nrows() {
249            return Err(FusionError::PhysicsViolation(format!(
250                "Local core range out of bounds: rows={}, core_end={core_end}",
251                local.nrows()
252            )));
253        }
254        global
255            .slice_mut(s![sdef.z_start..sdef.z_end, ..])
256            .assign(&local.slice(s![core_start..core_end, ..]));
257    }
258    if global.iter().any(|v| !v.is_finite()) {
259        return Err(FusionError::PhysicsViolation(
260            "Stitched global array contains non-finite values".to_string(),
261        ));
262    }
263    Ok(global)
264}
265
266pub fn serial_halo_exchange(
267    locals: &mut [Array2<f64>],
268    slices: &[DomainSlice],
269) -> FusionResult<()> {
270    if locals.len() != slices.len() {
271        return Err(FusionError::PhysicsViolation(format!(
272            "locals/slices mismatch: {} vs {}",
273            locals.len(),
274            slices.len()
275        )));
276    }
277    let mut top_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
278    let mut bottom_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
279    for (i, (local, sdef)) in locals.iter().zip(slices.iter()).enumerate() {
280        if local.iter().any(|v| !v.is_finite()) {
281            return Err(FusionError::PhysicsViolation(format!(
282                "Local block at index {i} contains non-finite values"
283            )));
284        }
285        if sdef.halo == 0 {
286            continue;
287        }
288        let (top, bottom) = pack_halo_rows(local, sdef.halo)?;
289        top_send[i] = Some(top);
290        bottom_send[i] = Some(bottom);
291    }
292
293    for i in 0..locals.len() {
294        let halo = slices[i].halo;
295        if halo == 0 {
296            continue;
297        }
298        let recv_top = if i > 0 {
299            bottom_send[i - 1].as_ref()
300        } else {
301            None
302        };
303        let recv_bottom = if i + 1 < locals.len() {
304            top_send[i + 1].as_ref()
305        } else {
306            None
307        };
308        apply_halo_rows(&mut locals[i], halo, recv_top, recv_bottom)?;
309    }
310    if locals.iter().any(|arr| arr.iter().any(|v| !v.is_finite())) {
311        return Err(FusionError::PhysicsViolation(
312            "Serial halo exchange produced non-finite values".to_string(),
313        ));
314    }
315    Ok(())
316}
317
318pub fn l2_norm_delta(a: &Array2<f64>, b: &Array2<f64>) -> FusionResult<f64> {
319    if a.dim() != b.dim() {
320        return Err(FusionError::PhysicsViolation(format!(
321            "l2_norm_delta shape mismatch {:?} vs {:?}",
322            a.dim(),
323            b.dim()
324        )));
325    }
326    if a.iter().any(|v| !v.is_finite()) || b.iter().any(|v| !v.is_finite()) {
327        return Err(FusionError::PhysicsViolation(
328            "l2_norm_delta inputs must be finite".to_string(),
329        ));
330    }
331    let mut accum = 0.0f64;
332    for (av, bv) in a.iter().zip(b.iter()) {
333        let d = av - bv;
334        accum += d * d;
335    }
336    let out = accum.sqrt();
337    if !out.is_finite() {
338        return Err(FusionError::PhysicsViolation(
339            "l2_norm_delta produced non-finite result".to_string(),
340        ));
341    }
342    Ok(out)
343}
344
345// ═══════════════════════════════════════════════════════════════════════
346// 2D Cartesian Domain Decomposition — Exascale-Ready MPI Abstraction
347// ═══════════════════════════════════════════════════════════════════════
348//
349// The key innovation: we decompose the (nz × nr) GS grid into a 2D
350// Cartesian process grid (pz × pr), where each tile owns a contiguous
351// sub-block plus halo cells on all four faces. The Rayon threadpool
352// simulates distributed-memory ranks; replacing with rsmpi is a 1:1
353// swap of the halo exchange function.
354//
355// References:
356//   - Jardin, "Computational Methods in Plasma Physics", Ch. 12
357//   - Lao et al., "Equilibrium analysis of current profiles in tokamaks",
358//     Nucl. Fusion 30 (1990) 1035
359//   - EFIT-AI: Joung et al., Nucl. Fusion 63 (2023) 126058
360
/// One tile of a (pz × pr) Cartesian process grid over a 2D (Z × R) array;
/// there is one tile per rank.
///
/// The tile owns global rows `[z_start, z_end)` and columns
/// `[r_start, r_end)`. A `halo`-wide band of ghost cells is added on each
/// face that borders another tile (see the `padded_*` and `core_*_offset`
/// methods).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CartesianTile {
    /// Linear rank index in `0..pz*pr`.
    pub rank: usize,
    /// Row of this tile in the process grid (Z direction).
    pub pz_idx: usize,
    /// Column of this tile in the process grid (R direction).
    pub pr_idx: usize,
    /// Number of process-grid rows.
    pub pz: usize,
    /// Number of process-grid columns.
    pub pr: usize,
    /// Row count of the global array.
    pub global_nz: usize,
    /// Column count of the global array.
    pub global_nr: usize,
    /// Ghost width, identical on every face.
    pub halo: usize,
    /// First owned global row (inclusive).
    pub z_start: usize,
    /// One past the last owned global row.
    pub z_end: usize,
    /// First owned global column (inclusive).
    pub r_start: usize,
    /// One past the last owned global column.
    pub r_end: usize,
}
385
386impl CartesianTile {
387    /// Number of owned Z rows (excluding halo).
388    pub fn local_nz(&self) -> usize {
389        self.z_end - self.z_start
390    }
391    /// Number of owned R columns (excluding halo).
392    pub fn local_nr(&self) -> usize {
393        self.r_end - self.r_start
394    }
395    /// Total rows including halo (top + bottom).
396    pub fn padded_nz(&self) -> usize {
397        let top = if self.pz_idx > 0 { self.halo } else { 0 };
398        let bot = if self.pz_idx + 1 < self.pz {
399            self.halo
400        } else {
401            0
402        };
403        self.local_nz() + top + bot
404    }
405    /// Total columns including halo (left + right).
406    pub fn padded_nr(&self) -> usize {
407        let left = if self.pr_idx > 0 { self.halo } else { 0 };
408        let right = if self.pr_idx + 1 < self.pr {
409            self.halo
410        } else {
411            0
412        };
413        self.local_nr() + left + right
414    }
415    /// Offset of the first owned row within the padded local array.
416    pub fn core_z_offset(&self) -> usize {
417        if self.pz_idx > 0 {
418            self.halo
419        } else {
420            0
421        }
422    }
423    /// Offset of the first owned column within the padded local array.
424    pub fn core_r_offset(&self) -> usize {
425        if self.pr_idx > 0 {
426            self.halo
427        } else {
428            0
429        }
430    }
431    pub fn has_neighbor_top(&self) -> bool {
432        self.pz_idx > 0
433    }
434    pub fn has_neighbor_bottom(&self) -> bool {
435        self.pz_idx + 1 < self.pz
436    }
437    pub fn has_neighbor_left(&self) -> bool {
438        self.pr_idx > 0
439    }
440    pub fn has_neighbor_right(&self) -> bool {
441        self.pr_idx + 1 < self.pr
442    }
443    /// Rank of a neighbour at offset (dz, dr) in the process grid.
444    /// Returns None if out of bounds.
445    pub fn neighbor_rank(&self, dz: i32, dr: i32) -> Option<usize> {
446        let nz = self.pz_idx as i32 + dz;
447        let nr = self.pr_idx as i32 + dr;
448        if nz < 0 || nz >= self.pz as i32 || nr < 0 || nr >= self.pr as i32 {
449            return None;
450        }
451        Some(nz as usize * self.pr + nr as usize)
452    }
453}
454
455/// Decompose a 2D grid of shape (global_nz × global_nr) into a
456/// (pz × pr) Cartesian process topology.
457///
458/// Returns tiles in row-major order: tile[iz * pr + ir].
459pub fn decompose_2d(
460    global_nz: usize,
461    global_nr: usize,
462    pz: usize,
463    pr: usize,
464    halo: usize,
465) -> FusionResult<Vec<CartesianTile>> {
466    if pz == 0 || pr == 0 {
467        return Err(FusionError::PhysicsViolation(
468            "Process grid dimensions pz, pr must be >= 1".to_string(),
469        ));
470    }
471    if pz > global_nz || pr > global_nr {
472        return Err(FusionError::PhysicsViolation(format!(
473            "Cannot split ({global_nz}×{global_nr}) across ({pz}×{pr}) processes"
474        )));
475    }
476    if global_nz < 2 || global_nr < 2 {
477        return Err(FusionError::PhysicsViolation(
478            "Global grid must be at least 2×2".to_string(),
479        ));
480    }
481
482    // Distribute rows/columns as evenly as possible.
483    let z_splits = balanced_split(global_nz, pz);
484    let r_splits = balanced_split(global_nr, pr);
485
486    let mut tiles = Vec::with_capacity(pz * pr);
487    let mut z_cursor = 0usize;
488    for (iz, nz_local) in z_splits.iter().copied().enumerate().take(pz) {
489        let z_start = z_cursor;
490        let z_end = z_start + nz_local;
491        z_cursor = z_end;
492
493        let mut r_cursor = 0usize;
494        for (ir, nr_local) in r_splits.iter().copied().enumerate().take(pr) {
495            let r_start = r_cursor;
496            let r_end = r_start + nr_local;
497            r_cursor = r_end;
498
499            tiles.push(CartesianTile {
500                rank: iz * pr + ir,
501                pz_idx: iz,
502                pr_idx: ir,
503                pz,
504                pr,
505                global_nz,
506                global_nr,
507                halo,
508                z_start,
509                z_end,
510                r_start,
511                r_end,
512            });
513        }
514    }
515    Ok(tiles)
516}
517
/// Sizes of `k` near-equal buckets that together hold `n` items.
///
/// The first `n % k` buckets get one extra item. Panics (division by zero)
/// if `k == 0`.
fn balanced_split(n: usize, k: usize) -> Vec<usize> {
    let (quotient, remainder) = (n / k, n % k);
    let mut sizes = vec![quotient; k];
    for size in sizes.iter_mut().take(remainder) {
        *size += 1;
    }
    sizes
}
524
525/// Extract a padded local tile (with halo) from the global array.
526pub fn extract_tile(global: &Array2<f64>, tile: &CartesianTile) -> FusionResult<Array2<f64>> {
527    let (gnz, gnr) = global.dim();
528    if gnz != tile.global_nz || gnr != tile.global_nr {
529        return Err(FusionError::PhysicsViolation(format!(
530            "Global shape ({gnz},{gnr}) doesn't match tile expectation ({},{})",
531            tile.global_nz, tile.global_nr
532        )));
533    }
534    let z0 = tile.z_start.saturating_sub(tile.core_z_offset());
535    let z1 = (tile.z_end
536        + if tile.has_neighbor_bottom() {
537            tile.halo
538        } else {
539            0
540        })
541    .min(gnz);
542    let r0 = tile.r_start.saturating_sub(tile.core_r_offset());
543    let r1 = (tile.r_end
544        + if tile.has_neighbor_right() {
545            tile.halo
546        } else {
547            0
548        })
549    .min(gnr);
550
551    let local = global.slice(s![z0..z1, r0..r1]).to_owned();
552    if local.iter().any(|v| !v.is_finite()) {
553        return Err(FusionError::PhysicsViolation(
554            "Extracted tile contains non-finite values".to_string(),
555        ));
556    }
557    Ok(local)
558}
559
560/// Write the core (non-halo) region of a local tile back into the global array.
561pub fn inject_tile(
562    global: &mut Array2<f64>,
563    local: &Array2<f64>,
564    tile: &CartesianTile,
565) -> FusionResult<()> {
566    let cz = tile.core_z_offset();
567    let cr = tile.core_r_offset();
568    let lnz = tile.local_nz();
569    let lnr = tile.local_nr();
570    if cz + lnz > local.nrows() || cr + lnr > local.ncols() {
571        return Err(FusionError::PhysicsViolation(format!(
572            "Tile core out of bounds: local shape {:?}, core ({cz}+{lnz}, {cr}+{lnr})",
573            local.dim()
574        )));
575    }
576    let core = local.slice(s![cz..(cz + lnz), cr..(cr + lnr)]);
577    if core.iter().any(|v| !v.is_finite()) {
578        return Err(FusionError::PhysicsViolation(
579            "Tile core to inject contains non-finite values".to_string(),
580        ));
581    }
582    global
583        .slice_mut(s![tile.z_start..tile.z_end, tile.r_start..tile.r_end])
584        .assign(&core);
585    Ok(())
586}
587
588/// Serial 2D halo exchange across all tiles.
589///
590/// Copies owned boundary rows/columns from each tile into the halo
591/// region of its four face-neighbours. This is the serial reference
592/// implementation; the MPI version replaces this with non-blocking
593/// Isend/Irecv pairs.
594pub fn serial_halo_exchange_2d(
595    locals: &mut [Array2<f64>],
596    tiles: &[CartesianTile],
597) -> FusionResult<()> {
598    if locals.len() != tiles.len() {
599        return Err(FusionError::PhysicsViolation(format!(
600            "locals/tiles length mismatch: {} vs {}",
601            locals.len(),
602            tiles.len()
603        )));
604    }
605    let ntiles = tiles.len();
606    if ntiles == 0 {
607        return Ok(());
608    }
609    let halo = tiles[0].halo;
610    if halo == 0 {
611        return Ok(());
612    }
613
614    // Collect halo strips from all tiles first (immutable borrows).
615    // Each entry: (dest_rank, HaloFace, data).
616    #[derive(Clone, Copy)]
617    enum Face {
618        Top,
619        Bottom,
620        Left,
621        Right,
622    }
623    let mut messages: Vec<(usize, Face, Array2<f64>)> = Vec::new();
624
625    for (i, tile) in tiles.iter().enumerate() {
626        let loc = &locals[i];
627        let cz = tile.core_z_offset();
628        let cr = tile.core_r_offset();
629        let lnz = tile.local_nz();
630        let lnr = tile.local_nr();
631
632        // Send top face → neighbor above
633        if let Some(dest) = tile.neighbor_rank(-1, 0) {
634            let strip = loc.slice(s![cz..(cz + halo), cr..(cr + lnr)]).to_owned();
635            messages.push((dest, Face::Bottom, strip));
636        }
637        // Send bottom face → neighbor below
638        if let Some(dest) = tile.neighbor_rank(1, 0) {
639            let strip = loc
640                .slice(s![(cz + lnz - halo)..(cz + lnz), cr..(cr + lnr)])
641                .to_owned();
642            messages.push((dest, Face::Top, strip));
643        }
644        // Send left face → neighbor to the left
645        if let Some(dest) = tile.neighbor_rank(0, -1) {
646            let strip = loc.slice(s![cz..(cz + lnz), cr..(cr + halo)]).to_owned();
647            messages.push((dest, Face::Right, strip));
648        }
649        // Send right face → neighbor to the right
650        if let Some(dest) = tile.neighbor_rank(0, 1) {
651            let strip = loc
652                .slice(s![cz..(cz + lnz), (cr + lnr - halo)..(cr + lnr)])
653                .to_owned();
654            messages.push((dest, Face::Left, strip));
655        }
656    }
657
658    // Apply received halos.
659    for (dest, face, data) in messages {
660        let tile = &tiles[dest];
661        let loc = &mut locals[dest];
662        let cz = tile.core_z_offset();
663        let cr = tile.core_r_offset();
664        let lnz = tile.local_nz();
665        let lnr = tile.local_nr();
666
667        match face {
668            Face::Top => {
669                // Fill top halo rows.
670                if cz >= halo {
671                    loc.slice_mut(s![(cz - halo)..cz, cr..(cr + lnr)])
672                        .assign(&data);
673                }
674            }
675            Face::Bottom => {
676                // Fill bottom halo rows.
677                let row_start = cz + lnz;
678                let row_end = row_start + halo;
679                if row_end <= loc.nrows() {
680                    loc.slice_mut(s![row_start..row_end, cr..(cr + lnr)])
681                        .assign(&data);
682                }
683            }
684            Face::Left => {
685                // Fill left halo columns.
686                if cr >= halo {
687                    loc.slice_mut(s![cz..(cz + lnz), (cr - halo)..cr])
688                        .assign(&data);
689                }
690            }
691            Face::Right => {
692                // Fill right halo columns.
693                let col_start = cr + lnr;
694                let col_end = col_start + halo;
695                if col_end <= loc.ncols() {
696                    loc.slice_mut(s![cz..(cz + lnz), col_start..col_end])
697                        .assign(&data);
698                }
699            }
700        }
701    }
702
703    // Verify no non-finite values crept in.
704    for loc in locals.iter() {
705        if loc.iter().any(|v| !v.is_finite()) {
706            return Err(FusionError::PhysicsViolation(
707                "2D halo exchange produced non-finite values".to_string(),
708            ));
709        }
710    }
711    Ok(())
712}
713
/// Tuning parameters for `distributed_gs_solve`.
#[derive(Clone, Debug)]
pub struct DistributedSolverConfig {
    /// Process-grid rows (Z direction). Each of the `pz * pr` tiles is
    /// relaxed as a separate Rayon task, so at most `pz * pr` tiles run
    /// concurrently.
    pub pz: usize,
    /// Process-grid columns (R direction).
    pub pr: usize,
    /// Ghost width in rows/columns per face. A width of 1 covers the
    /// 5-point GS stencil.
    pub halo: usize,
    /// SOR relaxation factor ω. `distributed_gs_solve` rejects values
    /// outside the open interval (0, 2); over-relaxation (ω > 1, e.g. the
    /// default 1.8) is the usual choice.
    pub omega: f64,
    /// Upper bound on Schwarz (outer) iterations.
    pub max_outer_iters: usize,
    /// The solve stops once the global L2 residual falls below this value.
    pub tol: f64,
    /// Local SOR sweeps performed per tile in each Schwarz iteration.
    pub inner_sweeps: usize,
}
733
734impl Default for DistributedSolverConfig {
735    fn default() -> Self {
736        Self {
737            pz: 2,
738            pr: 2,
739            halo: 1,
740            omega: 1.8,
741            max_outer_iters: 200,
742            tol: 1e-8,
743            inner_sweeps: 5,
744        }
745    }
746}
747
748/// Result of a distributed GS solve.
749#[derive(Debug, Clone)]
750pub struct DistributedSolveResult {
751    /// Final global Ψ array.
752    pub psi: Array2<f64>,
753    /// Achieved global L2 residual.
754    pub residual: f64,
755    /// Number of Schwarz outer iterations used.
756    pub iterations: usize,
757    /// Whether the solve converged within tolerance.
758    pub converged: bool,
759}
760
761/// Distributed Grad-Shafranov solver using additive Schwarz domain
762/// decomposition with Rayon thread-parallelism.
763///
764/// Each Schwarz iteration:
765/// 1. Splits the global Ψ into 2D tiles (with halo overlap).
766/// 2. Runs `inner_sweeps` local Red-Black SOR sweeps on each tile
767///    in parallel via Rayon.
768/// 3. Injects tile cores back into the global array.
769/// 4. Exchanges halos (serial reference — MPI-ready interface).
770/// 5. Computes global residual; checks convergence.
771///
772/// The 5-point GS stencil:
773///   R d/dR(1/R dΨ/dR) + d²Ψ/dZ² = -μ₀ R J_φ
774///
775/// discretises to the same operator as `fusion_math::sor::sor_step`,
776/// but applied independently on each tile with local boundary from
777/// halo data.
778pub fn distributed_gs_solve(
779    psi: &Array2<f64>,
780    source: &Array2<f64>,
781    r_axis: &[f64],
782    z_axis: &[f64],
783    dr: f64,
784    dz: f64,
785    cfg: &DistributedSolverConfig,
786) -> FusionResult<DistributedSolveResult> {
787    let (nz, nr) = psi.dim();
788    if source.dim() != (nz, nr) {
789        return Err(FusionError::PhysicsViolation(format!(
790            "psi/source shape mismatch: {:?} vs {:?}",
791            psi.dim(),
792            source.dim()
793        )));
794    }
795    if r_axis.len() != nr || z_axis.len() != nz {
796        return Err(FusionError::PhysicsViolation(format!(
797            "Axis lengths don't match grid: r_axis={} nr={nr}, z_axis={} nz={nz}",
798            r_axis.len(),
799            z_axis.len()
800        )));
801    }
802    if cfg.omega <= 0.0 || cfg.omega >= 2.0 {
803        return Err(FusionError::PhysicsViolation(format!(
804            "SOR omega must be in (0, 2), got {}",
805            cfg.omega
806        )));
807    }
808    if !dr.is_finite() || !dz.is_finite() || dr <= 0.0 || dz <= 0.0 {
809        return Err(FusionError::PhysicsViolation(format!(
810            "Grid spacing must be finite > 0: dr={dr}, dz={dz}"
811        )));
812    }
813
814    let tiles = decompose_2d(nz, nr, cfg.pz, cfg.pr, cfg.halo)?;
815
816    let mut global_psi = psi.clone();
817    let dr_sq = dr * dr;
818    let dz_sq = dz * dz;
819
820    let mut converged = false;
821    let mut outer_iter = 0usize;
822    let mut residual = f64::MAX;
823
824    for outer in 0..cfg.max_outer_iters {
825        outer_iter = outer + 1;
826
827        // 1. Extract tiles with halo from current global Ψ.
828        let mut locals: Vec<Array2<f64>> = tiles
829            .iter()
830            .map(|t| extract_tile(&global_psi, t))
831            .collect::<FusionResult<Vec<_>>>()?;
832        let sources: Vec<Array2<f64>> = tiles
833            .iter()
834            .map(|t| extract_tile(source, t))
835            .collect::<FusionResult<Vec<_>>>()?;
836
837        // 2. Run local SOR sweeps in parallel via Rayon.
838        //    Each tile applies inner_sweeps Red-Black SOR iterations.
839        use rayon::prelude::*;
840        locals
841            .par_iter_mut()
842            .zip(sources.par_iter())
843            .zip(tiles.par_iter())
844            .for_each(|((loc, src), tile)| {
845                let cz = tile.core_z_offset();
846                let cr = tile.core_r_offset();
847                let _lnz = tile.local_nz();
848                let _lnr = tile.local_nr();
849
850                for _sweep in 0..cfg.inner_sweeps {
851                    // Red pass then black pass.
852                    for color in 0..2u8 {
853                        for iz in 1..loc.nrows().saturating_sub(1) {
854                            for ir in 1..loc.ncols().saturating_sub(1) {
855                                if (iz + ir) % 2 != color as usize {
856                                    continue;
857                                }
858                                // Only update interior of the owned core
859                                // (and the halo interior that overlaps with
860                                // a neighbor's owned region).
861                                // Map local (iz, ir) to global indices.
862                                let gz = tile.z_start as i64 + iz as i64 - cz as i64;
863                                let gr = tile.r_start as i64 + ir as i64 - cr as i64;
864                                // Skip global boundaries.
865                                if gz <= 0 || gz >= (nz as i64 - 1) {
866                                    continue;
867                                }
868                                if gr <= 0 || gr >= (nr as i64 - 1) {
869                                    continue;
870                                }
871
872                                let r_val = r_axis[gr as usize];
873                                if r_val <= 0.0 {
874                                    continue;
875                                }
876
877                                // GS 5-point stencil with 1/R correction.
878                                // Coefficients follow sor.rs update_point():
879                                //   c_r_plus  = 1/dr² - 1/(2R·dr) → psi_east (ir+1)
880                                //   c_r_minus = 1/dr² + 1/(2R·dr) → psi_west (ir-1)
881                                let psi_n = loc[[iz - 1, ir]];
882                                let psi_s = loc[[iz + 1, ir]];
883                                let psi_w = loc[[iz, ir - 1]];
884                                let psi_e = loc[[iz, ir + 1]];
885
886                                let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
887                                let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
888                                let c_z = 1.0 / dz_sq;
889                                let center = 2.0 / dr_sq + 2.0 / dz_sq;
890
891                                let rhs_val = src[[iz, ir]];
892                                let numerator = c_z * (psi_n + psi_s)
893                                    + c_r_minus * psi_w
894                                    + c_r_plus * psi_e
895                                    + rhs_val;
896                                if center.abs() < 1e-30 {
897                                    continue;
898                                }
899                                let psi_gs = numerator / center;
900                                let old = loc[[iz, ir]];
901                                loc[[iz, ir]] = old + cfg.omega * (psi_gs - old);
902                            }
903                        }
904                    }
905                }
906            });
907
908        // 3. Inject tile cores back into global array.
909        for (i, tile) in tiles.iter().enumerate() {
910            inject_tile(&mut global_psi, &locals[i], tile)?;
911        }
912
913        // 4. Compute global L2 residual of the GS equation.
914        residual = gs_residual_l2(&global_psi, source, r_axis, dr, dz);
915
916        if residual < cfg.tol {
917            converged = true;
918            break;
919        }
920    }
921
922    if global_psi.iter().any(|v| !v.is_finite()) {
923        return Err(FusionError::PhysicsViolation(
924            "Distributed GS solve produced non-finite Ψ".to_string(),
925        ));
926    }
927
928    Ok(DistributedSolveResult {
929        psi: global_psi,
930        residual,
931        iterations: outer_iter,
932        converged,
933    })
934}
935
936/// Compute the L2 norm of the GS residual: ||LΨ - f||₂.
937///
938/// L is the 5-point GS operator:
939///   LΨ = (Ψ_{i-1,j} + Ψ_{i+1,j})/dz² + (1/dr² - 1/(2R dr))Ψ_{i,j-1}
940///        + (1/dr² + 1/(2R dr))Ψ_{i,j+1} - 2(1/dr² + 1/dz²)Ψ_{i,j}
941pub fn gs_residual_l2(
942    psi: &Array2<f64>,
943    source: &Array2<f64>,
944    r_axis: &[f64],
945    dr: f64,
946    dz: f64,
947) -> f64 {
948    let (nz, nr) = psi.dim();
949    let dr_sq = dr * dr;
950    let dz_sq = dz * dz;
951    let mut accum = 0.0f64;
952    let mut count = 0usize;
953
954    for iz in 1..nz - 1 {
955        for ir in 1..nr - 1 {
956            let r_val = r_axis[ir];
957            if r_val <= 0.0 {
958                continue;
959            }
960            let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
961            let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
962            let c_z = 1.0 / dz_sq;
963            let center = 2.0 / dr_sq + 2.0 / dz_sq;
964
965            let l_psi = center * psi[[iz, ir]]
966                - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
967                - c_r_plus * psi[[iz, ir + 1]]
968                - c_r_minus * psi[[iz, ir - 1]];
969            let res = l_psi - source[[iz, ir]];
970            accum += res * res;
971            count += 1;
972        }
973    }
974    if count == 0 {
975        return 0.0;
976    }
977    (accum / count as f64).sqrt()
978}
979
/// Choose a `(pz, pr)` process grid with `pz * pr == nranks` for an
/// `(nz, nr)` global grid.
///
/// Candidates are the exact factorisations in which each tile keeps at least
/// one row and one column (`pz <= nz`, `pr <= nr`). Each candidate is scored
/// by the perimeter-to-area ratio of one tile, a proxy for halo-communication
/// overhead relative to local work. The lowest score wins; on a tie, the
/// candidate with the smaller `pz` is kept.
///
/// If no factorisation fits (including `nranks == 0`), `(1, nranks)` is
/// returned as-is. Callers should check it against the grid before use.
pub fn optimal_process_grid(nz: usize, nr: usize, nranks: usize) -> (usize, usize) {
    (1..=nranks)
        .filter_map(|pz| {
            let pr = nranks / pz;
            // `pz * pr == nranks` exactly when `pz` divides `nranks`.
            let usable = pz * pr == nranks && pz <= nz && pr <= nr;
            usable.then(|| {
                let tile_z = nz as f64 / pz as f64;
                let tile_r = nr as f64 / pr as f64;
                (pz, pr, 2.0 * (tile_z + tile_r) / (tile_z * tile_r))
            })
        })
        // `min_by` yields the first of several equal minima, i.e. the
        // smallest `pz` among equally good layouts.
        .min_by(|a, b| a.2.total_cmp(&b.2))
        .map(|(pz, pr, _)| (pz, pr))
        .unwrap_or((1, nranks))
}
1010
1011/// Convenience: solve GS with automatic process-grid selection.
1012///
1013/// Detects the Rayon thread count and picks the optimal (pz, pr)
1014/// factorisation. This is the top-level entry point for exascale-ready
1015/// distributed equilibrium solving.
1016#[allow(clippy::too_many_arguments)]
1017pub fn auto_distributed_gs_solve(
1018    psi: &Array2<f64>,
1019    source: &Array2<f64>,
1020    r_axis: &[f64],
1021    z_axis: &[f64],
1022    dr: f64,
1023    dz: f64,
1024    omega: f64,
1025    tol: f64,
1026    max_iters: usize,
1027) -> FusionResult<DistributedSolveResult> {
1028    let (nz, nr) = psi.dim();
1029    let nthreads = rayon::current_num_threads().max(1);
1030    let (pz, pr) = optimal_process_grid(nz, nr, nthreads);
1031
1032    let cfg = DistributedSolverConfig {
1033        pz,
1034        pr,
1035        halo: 1,
1036        omega,
1037        max_outer_iters: max_iters,
1038        tol,
1039        inner_sweeps: 5,
1040    };
1041    distributed_gs_solve(psi, source, r_axis, z_axis, dr, dz, &cfg)
1042}
1043
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-constant grid: value = 10·row + column.
    fn sample_grid(nz: usize, nr: usize) -> Array2<f64> {
        Array2::from_shape_fn((nz, nr), |(i, j)| (i as f64) * 10.0 + j as f64)
    }

    #[test]
    fn test_decompose_z_covers_domain() {
        let slices = decompose_z(17, 4, 1).expect("decomposition must succeed");
        assert_eq!(slices.len(), 4);
        assert_eq!(slices[0].z_start, 0);
        assert_eq!(slices.last().expect("slice expected").z_end, 17);
        let covered: usize = slices.iter().map(|s| s.local_nz).sum();
        assert_eq!(covered, 17);
    }

    #[test]
    fn test_serial_halo_exchange_and_stitch_roundtrip() {
        let global = sample_grid(24, 9);
        let slices = decompose_z(global.nrows(), 3, 1).expect("decompose");
        let mut locals = split_with_halo(&global, &slices).expect("split");
        serial_halo_exchange(&mut locals, &slices).expect("exchange");
        let stitched = stitch_without_halo(&locals, &slices, global.ncols()).expect("stitch");
        let delta = l2_norm_delta(&stitched, &global).expect("delta");
        assert!(
            delta < 1e-12,
            "Serial halo exchange should preserve core rows"
        );
    }

    #[test]
    fn test_pack_halo_errors_for_small_local_block() {
        let local = Array2::zeros((2, 4));
        let err = pack_halo_rows(&local, 1).expect_err("small local must error");
        match err {
            FusionError::PhysicsViolation(msg) => {
                assert!(msg.contains("insufficient rows"));
            }
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_apply_halo_shape_guard() {
        let mut local = Array2::zeros((6, 4));
        let bad_top = Array2::zeros((2, 5));
        let err = apply_halo_rows(&mut local, 1, Some(&bad_top), None).expect_err("shape mismatch");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("shape mismatch")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    #[test]
    fn test_l2_norm_delta_zero_for_identical_arrays() {
        let a = sample_grid(8, 8);
        let b = a.clone();
        let d = l2_norm_delta(&a, &b).expect("delta");
        assert!(d.abs() < 1e-12);
    }

    #[test]
    fn test_mpi_domain_rejects_non_finite_inputs() {
        let mut local = sample_grid(8, 4);
        local[[2, 2]] = f64::NAN;
        let err = pack_halo_rows(&local, 1).expect_err("non-finite local should fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("non-finite")),
            other => panic!("Unexpected error: {other:?}"),
        }

        let mut a = sample_grid(4, 4);
        let b = sample_grid(4, 4);
        a[[0, 0]] = f64::INFINITY;
        let err = l2_norm_delta(&a, &b).expect_err("non-finite delta input should fail");
        match err {
            FusionError::PhysicsViolation(msg) => assert!(msg.contains("finite")),
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    // ═══════════════════════════════════════════════════════════════════
    // 2D Cartesian Decomposition Tests
    // ═══════════════════════════════════════════════════════════════════

    #[test]
    fn test_decompose_2d_covers_full_domain() {
        let tiles = decompose_2d(32, 24, 4, 3, 1).expect("decompose_2d");
        assert_eq!(tiles.len(), 12); // 4*3

        // Every global (iz, ir) must be owned by exactly one tile.
        let mut coverage = Array2::<u8>::zeros((32, 24));
        for t in &tiles {
            for iz in t.z_start..t.z_end {
                for ir in t.r_start..t.r_end {
                    coverage[[iz, ir]] += 1;
                }
            }
        }
        assert!(
            coverage.iter().all(|&c| c == 1),
            "Every cell owned by exactly one tile"
        );
    }

    #[test]
    fn test_decompose_2d_single_rank() {
        let tiles = decompose_2d(16, 16, 1, 1, 2).expect("single rank");
        assert_eq!(tiles.len(), 1);
        let t = &tiles[0];
        assert_eq!(t.z_start, 0);
        assert_eq!(t.z_end, 16);
        assert_eq!(t.r_start, 0);
        assert_eq!(t.r_end, 16);
        assert_eq!(t.padded_nz(), 16);
        assert_eq!(t.padded_nr(), 16);
    }

    #[test]
    fn test_decompose_2d_neighbor_ranks() {
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("2x2");
        // Top-left corner (0,0): has bottom and right neighbours.
        let tl = &tiles[0];
        assert_eq!(tl.neighbor_rank(1, 0), Some(2));
        assert_eq!(tl.neighbor_rank(0, 1), Some(1));
        assert_eq!(tl.neighbor_rank(-1, 0), None);
        assert_eq!(tl.neighbor_rank(0, -1), None);
        // Bottom-right corner (1,1).
        let br = &tiles[3];
        assert_eq!(br.neighbor_rank(-1, 0), Some(1));
        assert_eq!(br.neighbor_rank(0, -1), Some(2));
        assert_eq!(br.neighbor_rank(1, 0), None);
        assert_eq!(br.neighbor_rank(0, 1), None);
    }

    #[test]
    fn test_extract_inject_roundtrip() {
        let global = sample_grid(24, 18);
        let tiles = decompose_2d(24, 18, 3, 2, 1).expect("decompose");
        let mut reconstructed = Array2::<f64>::zeros((24, 18));
        for t in &tiles {
            let local = extract_tile(&global, t).expect("extract");
            inject_tile(&mut reconstructed, &local, t).expect("inject");
        }
        let delta = l2_norm_delta(&global, &reconstructed).expect("delta");
        assert!(
            delta < 1e-12,
            "Extract→inject roundtrip must be lossless, got delta={delta}"
        );
    }

    #[test]
    fn test_serial_halo_exchange_2d_correctness() {
        // Create a global array, split into tiles, exchange halos,
        // then verify that halo cells match the original global data.
        let global = Array2::from_shape_fn((16, 16), |(i, j)| (i as f64) * 100.0 + j as f64);
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("decompose");
        let mut locals: Vec<Array2<f64>> = tiles
            .iter()
            .map(|t| extract_tile(&global, t).expect("extract"))
            .collect();
        serial_halo_exchange_2d(&mut locals, &tiles).expect("halo exchange");

        // For each tile, verify that every cell (including halo)
        // matches the original global value.
        for (i, t) in tiles.iter().enumerate() {
            let loc = &locals[i];
            let z0 = t.z_start.saturating_sub(t.core_z_offset());
            let r0 = t.r_start.saturating_sub(t.core_r_offset());
            for lz in 0..loc.nrows() {
                for lr in 0..loc.ncols() {
                    let gz = z0 + lz;
                    let gr = r0 + lr;
                    if gz < 16 && gr < 16 {
                        let expect = global[[gz, gr]];
                        let got = loc[[lz, lr]];
                        assert!(
                            (expect - got).abs() < 1e-12,
                            "Tile {i} at local ({lz},{lr}) global ({gz},{gr}): expected {expect}, got {got}"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_optimal_process_grid_square() {
        let (pz, pr) = optimal_process_grid(64, 64, 4);
        assert_eq!(pz, 2);
        assert_eq!(pr, 2);
    }

    #[test]
    fn test_optimal_process_grid_rectangular() {
        // Tall grid (nz >> nr): should put more processes along Z.
        let (pz, pr) = optimal_process_grid(128, 32, 8);
        assert!(
            pz >= pr,
            "Tall grid should bias toward Z decomposition: pz={pz}, pr={pr}"
        );
        assert_eq!(pz * pr, 8);
    }

    #[test]
    fn test_optimal_process_grid_prefers_smaller_pz_on_ties() {
        // On 128x32 with 8 ranks, (4, 2) -> 32x16 tiles and (8, 1) -> 16x32
        // tiles have identical perimeter/area; the first candidate must win.
        assert_eq!(optimal_process_grid(128, 32, 8), (4, 2));
    }

    #[test]
    fn test_gs_residual_l2_zero_for_exact_solution() {
        // Manufacture source = L Ψ for a known Ψ, then verify residual ≈ 0.
        // Using the same convention as sor.rs:
        //   L Ψ = center*Ψ - c_z*(Ψ_up + Ψ_dn) - c_r_plus*Ψ_right - c_r_minus*Ψ_left
        let nz = 16;
        let nr = 16;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.08 + i as f64 * dz).collect();
        let psi = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            z_axis[iz] * z_axis[iz] + r_axis[ir] * r_axis[ir]
        });
        let dr_sq = dr * dr;
        let dz_sq = dz * dz;
        let mut source = Array2::zeros((nz, nr));
        for iz in 1..nz - 1 {
            for ir in 1..nr - 1 {
                let r = r_axis[ir];
                let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r * dr);
                let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r * dr);
                let c_z = 1.0 / dz_sq;
                let center = 2.0 / dr_sq + 2.0 / dz_sq;
                source[[iz, ir]] = center * psi[[iz, ir]]
                    - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
                    - c_r_plus * psi[[iz, ir + 1]]
                    - c_r_minus * psi[[iz, ir - 1]];
            }
        }
        let res = gs_residual_l2(&psi, &source, &r_axis, dr, dz);
        assert!(
            res < 1e-10,
            "Residual should be ~0 for manufactured solution, got {res}"
        );
    }

    #[test]
    fn test_gs_residual_l2_zero_without_interior() {
        // A 2x2 grid has no interior nodes, so the residual is defined as 0.
        let psi = Array2::<f64>::zeros((2, 2));
        let source = Array2::<f64>::zeros((2, 2));
        let r_axis = [1.0, 1.01];
        assert_eq!(gs_residual_l2(&psi, &source, &r_axis, 0.01, 0.01), 0.0);
    }

    #[test]
    fn test_distributed_gs_solve_smoke() {
        // Solve LΨ = f on a small grid and verify residual decreases.
        let nz = 32;
        let nr = 32;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.16 + i as f64 * dz).collect();

        // Use a moderate source (amplitude 1.0, not 100).
        let psi = Array2::zeros((nz, nr));
        let r_mid = 1.0 + (nr as f64 / 2.0) * dr;
        let z_mid = 0.0;
        let source = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            let rr = r_axis[ir] - r_mid;
            let zz = z_axis[iz] - z_mid;
            -((-(rr * rr + zz * zz) / (0.05 * 0.05)).exp())
        });

        // Compute initial residual for comparison.
        let res_initial = gs_residual_l2(&psi, &source, &r_axis, dr, dz);

        let cfg = DistributedSolverConfig {
            pz: 2,
            pr: 2,
            halo: 1,
            omega: 1.6,
            max_outer_iters: 200,
            tol: 1e-6,
            inner_sweeps: 10,
        };
        let result = distributed_gs_solve(&psi, &source, &r_axis, &z_axis, dr, dz, &cfg)
            .expect("distributed solve");
        // Residual should have decreased significantly from initial.
        assert!(
            result.residual < res_initial * 0.5,
            "Residual should decrease: initial={res_initial:.4}, final={:.4}",
            result.residual
        );
        // Ψ should be non-trivial (negative source → negative interior values).
        let psi_absmax = result.psi.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
        assert!(
            psi_absmax > 1e-10,
            "Solution should be non-trivial, max |Ψ| = {psi_absmax}"
        );
    }

    #[test]
    fn test_decompose_2d_rejects_invalid_inputs() {
        assert!(decompose_2d(16, 16, 0, 2, 1).is_err());
        assert!(decompose_2d(16, 16, 2, 0, 1).is_err());
        assert!(decompose_2d(1, 16, 2, 2, 1).is_err());
        assert!(decompose_2d(16, 16, 20, 2, 1).is_err());
    }
}