1use fusion_types::error::{FusionError, FusionResult};
14use ndarray::{s, Array2};
15
/// One rank's share of a 1D decomposition along the Z (row) axis.
///
/// [`decompose_z`] produces these so that the `z_start..z_end` ranges of all
/// ranks tile `0..global_nz` contiguously and `local_nz == z_end - z_start`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DomainSlice {
    /// Position of this slice within the decomposition (`0..nranks`).
    pub rank: usize,
    /// Total number of slices in the decomposition.
    pub nranks: usize,
    /// Row count of the undecomposed global grid.
    pub global_nz: usize,
    /// Number of rows owned by this slice, halo rows excluded.
    pub local_nz: usize,
    /// Halo width in rows.
    pub halo: usize,
    /// First owned global row (inclusive).
    pub z_start: usize,
    /// One past the last owned global row (exclusive).
    pub z_end: usize,
}

impl DomainSlice {
    /// Returns `true` unless this is the first slice (`rank == 0`).
    pub fn has_upper_neighbor(&self) -> bool {
        self.rank != 0
    }

    /// Returns `true` unless this is the last slice (`rank == nranks - 1`).
    pub fn has_lower_neighbor(&self) -> bool {
        self.nranks > self.rank + 1
    }
}
36
37pub fn decompose_z(global_nz: usize, nranks: usize, halo: usize) -> FusionResult<Vec<DomainSlice>> {
38 if global_nz < 2 {
39 return Err(FusionError::PhysicsViolation(
40 "MPI decomposition requires global_nz >= 2".to_string(),
41 ));
42 }
43 if nranks < 1 {
44 return Err(FusionError::PhysicsViolation(
45 "MPI decomposition requires nranks >= 1".to_string(),
46 ));
47 }
48 if nranks > global_nz {
49 return Err(FusionError::PhysicsViolation(format!(
50 "Cannot split global_nz={global_nz} across nranks={nranks}"
51 )));
52 }
53
54 let base = global_nz / nranks;
55 let rem = global_nz % nranks;
56 let mut out = Vec::with_capacity(nranks);
57 let mut cursor = 0usize;
58 for rank in 0..nranks {
59 let local_nz = base + usize::from(rank < rem);
60 let z_start = cursor;
61 let z_end = z_start + local_nz;
62 cursor = z_end;
63 out.push(DomainSlice {
64 rank,
65 nranks,
66 global_nz,
67 local_nz,
68 halo,
69 z_start,
70 z_end,
71 });
72 }
73 Ok(out)
74}
75
76pub fn pack_halo_rows(
77 local: &Array2<f64>,
78 halo: usize,
79) -> FusionResult<(Array2<f64>, Array2<f64>)> {
80 if halo == 0 {
81 return Err(FusionError::PhysicsViolation(
82 "Halo width must be >= 1".to_string(),
83 ));
84 }
85 if local.nrows() <= 2 * halo {
86 return Err(FusionError::PhysicsViolation(format!(
87 "Local block has insufficient rows {} for halo={halo}",
88 local.nrows()
89 )));
90 }
91 if local.iter().any(|v| !v.is_finite()) {
92 return Err(FusionError::PhysicsViolation(
93 "Local block contains non-finite values".to_string(),
94 ));
95 }
96 let top = local.slice(s![halo..(2 * halo), ..]).to_owned();
97 let bottom = local
98 .slice(s![(local.nrows() - 2 * halo)..(local.nrows() - halo), ..])
99 .to_owned();
100 if top.iter().any(|v| !v.is_finite()) || bottom.iter().any(|v| !v.is_finite()) {
101 return Err(FusionError::PhysicsViolation(
102 "Packed halo rows contain non-finite values".to_string(),
103 ));
104 }
105 Ok((top, bottom))
106}
107
108pub fn apply_halo_rows(
109 local: &mut Array2<f64>,
110 halo: usize,
111 recv_top: Option<&Array2<f64>>,
112 recv_bottom: Option<&Array2<f64>>,
113) -> FusionResult<()> {
114 if halo == 0 {
115 return Err(FusionError::PhysicsViolation(
116 "Halo width must be >= 1".to_string(),
117 ));
118 }
119 if local.nrows() <= 2 * halo {
120 return Err(FusionError::PhysicsViolation(format!(
121 "Local block has insufficient rows {} for halo={halo}",
122 local.nrows()
123 )));
124 }
125 if local.iter().any(|v| !v.is_finite()) {
126 return Err(FusionError::PhysicsViolation(
127 "Local block contains non-finite values".to_string(),
128 ));
129 }
130 if let Some(top) = recv_top {
131 if top.dim() != (halo, local.ncols()) {
132 return Err(FusionError::PhysicsViolation(format!(
133 "Top halo shape mismatch: expected ({halo}, {}), got {:?}",
134 local.ncols(),
135 top.dim()
136 )));
137 }
138 if top.iter().any(|v| !v.is_finite()) {
139 return Err(FusionError::PhysicsViolation(
140 "Top halo contains non-finite values".to_string(),
141 ));
142 }
143 local.slice_mut(s![0..halo, ..]).assign(top);
144 }
145 if let Some(bottom) = recv_bottom {
146 if bottom.dim() != (halo, local.ncols()) {
147 return Err(FusionError::PhysicsViolation(format!(
148 "Bottom halo shape mismatch: expected ({halo}, {}), got {:?}",
149 local.ncols(),
150 bottom.dim()
151 )));
152 }
153 if bottom.iter().any(|v| !v.is_finite()) {
154 return Err(FusionError::PhysicsViolation(
155 "Bottom halo contains non-finite values".to_string(),
156 ));
157 }
158 let n = local.nrows();
159 local.slice_mut(s![(n - halo)..n, ..]).assign(bottom);
160 }
161 if local.iter().any(|v| !v.is_finite()) {
162 return Err(FusionError::PhysicsViolation(
163 "Applying halo rows produced non-finite values".to_string(),
164 ));
165 }
166 Ok(())
167}
168
169pub fn split_with_halo(
170 global: &Array2<f64>,
171 slices: &[DomainSlice],
172) -> FusionResult<Vec<Array2<f64>>> {
173 if slices.is_empty() {
174 return Err(FusionError::PhysicsViolation(
175 "No slices provided for split_with_halo".to_string(),
176 ));
177 }
178 if global.iter().any(|v| !v.is_finite()) {
179 return Err(FusionError::PhysicsViolation(
180 "Global array contains non-finite values".to_string(),
181 ));
182 }
183 let mut out = Vec::with_capacity(slices.len());
184 for sdef in slices {
185 if sdef.global_nz != global.nrows() {
186 return Err(FusionError::PhysicsViolation(format!(
187 "Slice/global mismatch: slice.global_nz={} global.nrows()={}",
188 sdef.global_nz,
189 global.nrows()
190 )));
191 }
192 if sdef.z_start >= sdef.z_end || sdef.z_end > sdef.global_nz {
193 return Err(FusionError::PhysicsViolation(format!(
194 "Invalid slice bounds z_start={} z_end={} global_nz={}",
195 sdef.z_start, sdef.z_end, sdef.global_nz
196 )));
197 }
198 let start = sdef.z_start.saturating_sub(sdef.halo);
199 let end = (sdef.z_end + sdef.halo).min(sdef.global_nz);
200 let mut local = Array2::zeros((end - start, global.ncols()));
201 local.assign(&global.slice(s![start..end, ..]));
202 if local.iter().any(|v| !v.is_finite()) {
203 return Err(FusionError::PhysicsViolation(
204 "Split local block contains non-finite values".to_string(),
205 ));
206 }
207 out.push(local);
208 }
209 Ok(out)
210}
211
212pub fn stitch_without_halo(
213 locals: &[Array2<f64>],
214 slices: &[DomainSlice],
215 ncols: usize,
216) -> FusionResult<Array2<f64>> {
217 if locals.len() != slices.len() {
218 return Err(FusionError::PhysicsViolation(format!(
219 "locals/slices mismatch: {} vs {}",
220 locals.len(),
221 slices.len()
222 )));
223 }
224 if slices.is_empty() {
225 return Err(FusionError::PhysicsViolation(
226 "No slices provided for stitch_without_halo".to_string(),
227 ));
228 }
229 let global_nz = slices
230 .last()
231 .map(|s| s.global_nz)
232 .ok_or_else(|| FusionError::PhysicsViolation("No slices provided".to_string()))?;
233 let mut global = Array2::zeros((global_nz, ncols));
234 for (local, sdef) in locals.iter().zip(slices.iter()) {
235 if local.iter().any(|v| !v.is_finite()) {
236 return Err(FusionError::PhysicsViolation(
237 "Local block contains non-finite values".to_string(),
238 ));
239 }
240 if local.ncols() != ncols {
241 return Err(FusionError::PhysicsViolation(format!(
242 "Local ncols mismatch: expected {ncols}, got {}",
243 local.ncols()
244 )));
245 }
246 let core_start = usize::from(sdef.z_start > 0) * sdef.halo;
247 let core_end = core_start + sdef.local_nz;
248 if core_end > local.nrows() {
249 return Err(FusionError::PhysicsViolation(format!(
250 "Local core range out of bounds: rows={}, core_end={core_end}",
251 local.nrows()
252 )));
253 }
254 global
255 .slice_mut(s![sdef.z_start..sdef.z_end, ..])
256 .assign(&local.slice(s![core_start..core_end, ..]));
257 }
258 if global.iter().any(|v| !v.is_finite()) {
259 return Err(FusionError::PhysicsViolation(
260 "Stitched global array contains non-finite values".to_string(),
261 ));
262 }
263 Ok(global)
264}
265
266pub fn serial_halo_exchange(
267 locals: &mut [Array2<f64>],
268 slices: &[DomainSlice],
269) -> FusionResult<()> {
270 if locals.len() != slices.len() {
271 return Err(FusionError::PhysicsViolation(format!(
272 "locals/slices mismatch: {} vs {}",
273 locals.len(),
274 slices.len()
275 )));
276 }
277 let mut top_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
278 let mut bottom_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
279 for (i, (local, sdef)) in locals.iter().zip(slices.iter()).enumerate() {
280 if local.iter().any(|v| !v.is_finite()) {
281 return Err(FusionError::PhysicsViolation(format!(
282 "Local block at index {i} contains non-finite values"
283 )));
284 }
285 if sdef.halo == 0 {
286 continue;
287 }
288 let (top, bottom) = pack_halo_rows(local, sdef.halo)?;
289 top_send[i] = Some(top);
290 bottom_send[i] = Some(bottom);
291 }
292
293 for i in 0..locals.len() {
294 let halo = slices[i].halo;
295 if halo == 0 {
296 continue;
297 }
298 let recv_top = if i > 0 {
299 bottom_send[i - 1].as_ref()
300 } else {
301 None
302 };
303 let recv_bottom = if i + 1 < locals.len() {
304 top_send[i + 1].as_ref()
305 } else {
306 None
307 };
308 apply_halo_rows(&mut locals[i], halo, recv_top, recv_bottom)?;
309 }
310 if locals.iter().any(|arr| arr.iter().any(|v| !v.is_finite())) {
311 return Err(FusionError::PhysicsViolation(
312 "Serial halo exchange produced non-finite values".to_string(),
313 ));
314 }
315 Ok(())
316}
317
318pub fn l2_norm_delta(a: &Array2<f64>, b: &Array2<f64>) -> FusionResult<f64> {
319 if a.dim() != b.dim() {
320 return Err(FusionError::PhysicsViolation(format!(
321 "l2_norm_delta shape mismatch {:?} vs {:?}",
322 a.dim(),
323 b.dim()
324 )));
325 }
326 if a.iter().any(|v| !v.is_finite()) || b.iter().any(|v| !v.is_finite()) {
327 return Err(FusionError::PhysicsViolation(
328 "l2_norm_delta inputs must be finite".to_string(),
329 ));
330 }
331 let mut accum = 0.0f64;
332 for (av, bv) in a.iter().zip(b.iter()) {
333 let d = av - bv;
334 accum += d * d;
335 }
336 let out = accum.sqrt();
337 if !out.is_finite() {
338 return Err(FusionError::PhysicsViolation(
339 "l2_norm_delta produced non-finite result".to_string(),
340 ));
341 }
342 Ok(out)
343}
344
/// One tile of a 2D (Z × R) Cartesian decomposition.
///
/// `decompose_2d` lays tiles out row-major over a `pz × pr` process grid, so
/// the tile at grid position `(pz_idx, pr_idx)` has rank
/// `pz_idx * pr + pr_idx`. The tile owns the global cells
/// `z_start..z_end × r_start..r_end`; a padded local array additionally
/// carries `halo` cells on every side that faces a neighbouring tile.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CartesianTile {
    /// Row-major rank of this tile in the process grid.
    pub rank: usize,
    /// Position of this tile along Z in the process grid (`0..pz`).
    pub pz_idx: usize,
    /// Position of this tile along R in the process grid (`0..pr`).
    pub pr_idx: usize,
    /// Number of tiles along Z.
    pub pz: usize,
    /// Number of tiles along R.
    pub pr: usize,
    /// Row count of the global grid.
    pub global_nz: usize,
    /// Column count of the global grid.
    pub global_nr: usize,
    /// Halo width in cells.
    pub halo: usize,
    /// First owned global row (inclusive).
    pub z_start: usize,
    /// One past the last owned global row (exclusive).
    pub z_end: usize,
    /// First owned global column (inclusive).
    pub r_start: usize,
    /// One past the last owned global column (exclusive).
    pub r_end: usize,
}

impl CartesianTile {
    /// Number of rows owned by this tile (halo excluded).
    pub fn local_nz(&self) -> usize {
        self.z_end - self.z_start
    }

    /// Number of columns owned by this tile (halo excluded).
    pub fn local_nr(&self) -> usize {
        self.r_end - self.r_start
    }

    /// Row count of the padded local array: owned rows plus `halo` rows for
    /// each Z side that has a neighbour.
    pub fn padded_nz(&self) -> usize {
        let below = if self.has_neighbor_bottom() {
            self.halo
        } else {
            0
        };
        self.local_nz() + self.core_z_offset() + below
    }

    /// Column count of the padded local array: owned columns plus `halo`
    /// columns for each R side that has a neighbour.
    pub fn padded_nr(&self) -> usize {
        let right = if self.has_neighbor_right() {
            self.halo
        } else {
            0
        };
        self.local_nr() + self.core_r_offset() + right
    }

    /// Row index of the first owned row inside the padded local array.
    pub fn core_z_offset(&self) -> usize {
        if self.has_neighbor_top() {
            self.halo
        } else {
            0
        }
    }

    /// Column index of the first owned column inside the padded local array.
    pub fn core_r_offset(&self) -> usize {
        if self.has_neighbor_left() {
            self.halo
        } else {
            0
        }
    }

    /// `true` when a tile exists at `pz_idx - 1`.
    pub fn has_neighbor_top(&self) -> bool {
        self.pz_idx > 0
    }

    /// `true` when a tile exists at `pz_idx + 1`.
    pub fn has_neighbor_bottom(&self) -> bool {
        self.pz_idx + 1 < self.pz
    }

    /// `true` when a tile exists at `pr_idx - 1`.
    pub fn has_neighbor_left(&self) -> bool {
        self.pr_idx > 0
    }

    /// `true` when a tile exists at `pr_idx + 1`.
    pub fn has_neighbor_right(&self) -> bool {
        self.pr_idx + 1 < self.pr
    }

    /// Returns the rank of the tile `dz` grid steps along Z and `dr` grid
    /// steps along R from this one, or `None` when that position lies outside
    /// the `pz × pr` process grid.
    ///
    /// Offsets of any magnitude are accepted; the arithmetic is checked, so
    /// extreme values yield `None` rather than overflowing.
    pub fn neighbor_rank(&self, dz: i32, dr: i32) -> Option<usize> {
        let nz = self
            .pz_idx
            .checked_add_signed(isize::try_from(dz).ok()?)?;
        let nr = self
            .pr_idx
            .checked_add_signed(isize::try_from(dr).ok()?)?;
        if nz >= self.pz || nr >= self.pr {
            return None;
        }
        Some(nz * self.pr + nr)
    }
}
454
455pub fn decompose_2d(
460 global_nz: usize,
461 global_nr: usize,
462 pz: usize,
463 pr: usize,
464 halo: usize,
465) -> FusionResult<Vec<CartesianTile>> {
466 if pz == 0 || pr == 0 {
467 return Err(FusionError::PhysicsViolation(
468 "Process grid dimensions pz, pr must be >= 1".to_string(),
469 ));
470 }
471 if pz > global_nz || pr > global_nr {
472 return Err(FusionError::PhysicsViolation(format!(
473 "Cannot split ({global_nz}×{global_nr}) across ({pz}×{pr}) processes"
474 )));
475 }
476 if global_nz < 2 || global_nr < 2 {
477 return Err(FusionError::PhysicsViolation(
478 "Global grid must be at least 2×2".to_string(),
479 ));
480 }
481
482 let z_splits = balanced_split(global_nz, pz);
484 let r_splits = balanced_split(global_nr, pr);
485
486 let mut tiles = Vec::with_capacity(pz * pr);
487 let mut z_cursor = 0usize;
488 for (iz, nz_local) in z_splits.iter().copied().enumerate().take(pz) {
489 let z_start = z_cursor;
490 let z_end = z_start + nz_local;
491 z_cursor = z_end;
492
493 let mut r_cursor = 0usize;
494 for (ir, nr_local) in r_splits.iter().copied().enumerate().take(pr) {
495 let r_start = r_cursor;
496 let r_end = r_start + nr_local;
497 r_cursor = r_end;
498
499 tiles.push(CartesianTile {
500 rank: iz * pr + ir,
501 pz_idx: iz,
502 pr_idx: ir,
503 pz,
504 pr,
505 global_nz,
506 global_nr,
507 halo,
508 z_start,
509 z_end,
510 r_start,
511 r_end,
512 });
513 }
514 }
515 Ok(tiles)
516}
517
/// Splits `n` items into `k` part sizes that differ by at most one; the first
/// `n % k` parts receive the extra item.
///
/// Panics on division by zero when `k == 0`.
fn balanced_split(n: usize, k: usize) -> Vec<usize> {
    let mut parts = vec![n / k; k];
    for part in parts.iter_mut().take(n % k) {
        *part += 1;
    }
    parts
}
524
525pub fn extract_tile(global: &Array2<f64>, tile: &CartesianTile) -> FusionResult<Array2<f64>> {
527 let (gnz, gnr) = global.dim();
528 if gnz != tile.global_nz || gnr != tile.global_nr {
529 return Err(FusionError::PhysicsViolation(format!(
530 "Global shape ({gnz},{gnr}) doesn't match tile expectation ({},{})",
531 tile.global_nz, tile.global_nr
532 )));
533 }
534 let z0 = tile.z_start.saturating_sub(tile.core_z_offset());
535 let z1 = (tile.z_end
536 + if tile.has_neighbor_bottom() {
537 tile.halo
538 } else {
539 0
540 })
541 .min(gnz);
542 let r0 = tile.r_start.saturating_sub(tile.core_r_offset());
543 let r1 = (tile.r_end
544 + if tile.has_neighbor_right() {
545 tile.halo
546 } else {
547 0
548 })
549 .min(gnr);
550
551 let local = global.slice(s![z0..z1, r0..r1]).to_owned();
552 if local.iter().any(|v| !v.is_finite()) {
553 return Err(FusionError::PhysicsViolation(
554 "Extracted tile contains non-finite values".to_string(),
555 ));
556 }
557 Ok(local)
558}
559
560pub fn inject_tile(
562 global: &mut Array2<f64>,
563 local: &Array2<f64>,
564 tile: &CartesianTile,
565) -> FusionResult<()> {
566 let cz = tile.core_z_offset();
567 let cr = tile.core_r_offset();
568 let lnz = tile.local_nz();
569 let lnr = tile.local_nr();
570 if cz + lnz > local.nrows() || cr + lnr > local.ncols() {
571 return Err(FusionError::PhysicsViolation(format!(
572 "Tile core out of bounds: local shape {:?}, core ({cz}+{lnz}, {cr}+{lnr})",
573 local.dim()
574 )));
575 }
576 let core = local.slice(s![cz..(cz + lnz), cr..(cr + lnr)]);
577 if core.iter().any(|v| !v.is_finite()) {
578 return Err(FusionError::PhysicsViolation(
579 "Tile core to inject contains non-finite values".to_string(),
580 ));
581 }
582 global
583 .slice_mut(s![tile.z_start..tile.z_end, tile.r_start..tile.r_end])
584 .assign(&core);
585 Ok(())
586}
587
588pub fn serial_halo_exchange_2d(
595 locals: &mut [Array2<f64>],
596 tiles: &[CartesianTile],
597) -> FusionResult<()> {
598 if locals.len() != tiles.len() {
599 return Err(FusionError::PhysicsViolation(format!(
600 "locals/tiles length mismatch: {} vs {}",
601 locals.len(),
602 tiles.len()
603 )));
604 }
605 let ntiles = tiles.len();
606 if ntiles == 0 {
607 return Ok(());
608 }
609 let halo = tiles[0].halo;
610 if halo == 0 {
611 return Ok(());
612 }
613
614 #[derive(Clone, Copy)]
617 enum Face {
618 Top,
619 Bottom,
620 Left,
621 Right,
622 }
623 let mut messages: Vec<(usize, Face, Array2<f64>)> = Vec::new();
624
625 for (i, tile) in tiles.iter().enumerate() {
626 let loc = &locals[i];
627 let cz = tile.core_z_offset();
628 let cr = tile.core_r_offset();
629 let lnz = tile.local_nz();
630 let lnr = tile.local_nr();
631
632 if let Some(dest) = tile.neighbor_rank(-1, 0) {
634 let strip = loc.slice(s![cz..(cz + halo), cr..(cr + lnr)]).to_owned();
635 messages.push((dest, Face::Bottom, strip));
636 }
637 if let Some(dest) = tile.neighbor_rank(1, 0) {
639 let strip = loc
640 .slice(s![(cz + lnz - halo)..(cz + lnz), cr..(cr + lnr)])
641 .to_owned();
642 messages.push((dest, Face::Top, strip));
643 }
644 if let Some(dest) = tile.neighbor_rank(0, -1) {
646 let strip = loc.slice(s![cz..(cz + lnz), cr..(cr + halo)]).to_owned();
647 messages.push((dest, Face::Right, strip));
648 }
649 if let Some(dest) = tile.neighbor_rank(0, 1) {
651 let strip = loc
652 .slice(s![cz..(cz + lnz), (cr + lnr - halo)..(cr + lnr)])
653 .to_owned();
654 messages.push((dest, Face::Left, strip));
655 }
656 }
657
658 for (dest, face, data) in messages {
660 let tile = &tiles[dest];
661 let loc = &mut locals[dest];
662 let cz = tile.core_z_offset();
663 let cr = tile.core_r_offset();
664 let lnz = tile.local_nz();
665 let lnr = tile.local_nr();
666
667 match face {
668 Face::Top => {
669 if cz >= halo {
671 loc.slice_mut(s![(cz - halo)..cz, cr..(cr + lnr)])
672 .assign(&data);
673 }
674 }
675 Face::Bottom => {
676 let row_start = cz + lnz;
678 let row_end = row_start + halo;
679 if row_end <= loc.nrows() {
680 loc.slice_mut(s![row_start..row_end, cr..(cr + lnr)])
681 .assign(&data);
682 }
683 }
684 Face::Left => {
685 if cr >= halo {
687 loc.slice_mut(s![cz..(cz + lnz), (cr - halo)..cr])
688 .assign(&data);
689 }
690 }
691 Face::Right => {
692 let col_start = cr + lnr;
694 let col_end = col_start + halo;
695 if col_end <= loc.ncols() {
696 loc.slice_mut(s![cz..(cz + lnz), col_start..col_end])
697 .assign(&data);
698 }
699 }
700 }
701 }
702
703 for loc in locals.iter() {
705 if loc.iter().any(|v| !v.is_finite()) {
706 return Err(FusionError::PhysicsViolation(
707 "2D halo exchange produced non-finite values".to_string(),
708 ));
709 }
710 }
711 Ok(())
712}
713
/// Tuning parameters for `distributed_gs_solve`.
#[derive(Debug, Clone)]
pub struct DistributedSolverConfig {
    /// Number of tiles along Z.
    pub pz: usize,
    /// Number of tiles along R.
    pub pr: usize,
    /// Halo width, in cells, around each tile.
    pub halo: usize,
    /// Relaxation factor; the solver requires it to lie strictly in `(0, 2)`.
    pub omega: f64,
    /// Upper bound on the number of outer (extract → sweep → inject) iterations.
    pub max_outer_iters: usize,
    /// The solver stops as soon as the global residual drops below this value.
    pub tol: f64,
    /// Number of sweeps performed on each tile per outer iteration.
    pub inner_sweeps: usize,
}

impl Default for DistributedSolverConfig {
    /// A 2×2 tile grid with a one-cell halo, `omega = 1.8`, up to 200 outer
    /// iterations of 5 sweeps each, and a tolerance of `1e-8`.
    fn default() -> Self {
        let (pz, pr) = (2, 2);
        DistributedSolverConfig {
            pz,
            pr,
            halo: 1,
            omega: 1.8,
            max_outer_iters: 200,
            tol: 1e-8,
            inner_sweeps: 5,
        }
    }
}
747
748#[derive(Debug, Clone)]
750pub struct DistributedSolveResult {
751 pub psi: Array2<f64>,
753 pub residual: f64,
755 pub iterations: usize,
757 pub converged: bool,
759}
760
761pub fn distributed_gs_solve(
779 psi: &Array2<f64>,
780 source: &Array2<f64>,
781 r_axis: &[f64],
782 z_axis: &[f64],
783 dr: f64,
784 dz: f64,
785 cfg: &DistributedSolverConfig,
786) -> FusionResult<DistributedSolveResult> {
787 let (nz, nr) = psi.dim();
788 if source.dim() != (nz, nr) {
789 return Err(FusionError::PhysicsViolation(format!(
790 "psi/source shape mismatch: {:?} vs {:?}",
791 psi.dim(),
792 source.dim()
793 )));
794 }
795 if r_axis.len() != nr || z_axis.len() != nz {
796 return Err(FusionError::PhysicsViolation(format!(
797 "Axis lengths don't match grid: r_axis={} nr={nr}, z_axis={} nz={nz}",
798 r_axis.len(),
799 z_axis.len()
800 )));
801 }
802 if cfg.omega <= 0.0 || cfg.omega >= 2.0 {
803 return Err(FusionError::PhysicsViolation(format!(
804 "SOR omega must be in (0, 2), got {}",
805 cfg.omega
806 )));
807 }
808 if !dr.is_finite() || !dz.is_finite() || dr <= 0.0 || dz <= 0.0 {
809 return Err(FusionError::PhysicsViolation(format!(
810 "Grid spacing must be finite > 0: dr={dr}, dz={dz}"
811 )));
812 }
813
814 let tiles = decompose_2d(nz, nr, cfg.pz, cfg.pr, cfg.halo)?;
815
816 let mut global_psi = psi.clone();
817 let dr_sq = dr * dr;
818 let dz_sq = dz * dz;
819
820 let mut converged = false;
821 let mut outer_iter = 0usize;
822 let mut residual = f64::MAX;
823
824 for outer in 0..cfg.max_outer_iters {
825 outer_iter = outer + 1;
826
827 let mut locals: Vec<Array2<f64>> = tiles
829 .iter()
830 .map(|t| extract_tile(&global_psi, t))
831 .collect::<FusionResult<Vec<_>>>()?;
832 let sources: Vec<Array2<f64>> = tiles
833 .iter()
834 .map(|t| extract_tile(source, t))
835 .collect::<FusionResult<Vec<_>>>()?;
836
837 use rayon::prelude::*;
840 locals
841 .par_iter_mut()
842 .zip(sources.par_iter())
843 .zip(tiles.par_iter())
844 .for_each(|((loc, src), tile)| {
845 let cz = tile.core_z_offset();
846 let cr = tile.core_r_offset();
847 let _lnz = tile.local_nz();
848 let _lnr = tile.local_nr();
849
850 for _sweep in 0..cfg.inner_sweeps {
851 for color in 0..2u8 {
853 for iz in 1..loc.nrows().saturating_sub(1) {
854 for ir in 1..loc.ncols().saturating_sub(1) {
855 if (iz + ir) % 2 != color as usize {
856 continue;
857 }
858 let gz = tile.z_start as i64 + iz as i64 - cz as i64;
863 let gr = tile.r_start as i64 + ir as i64 - cr as i64;
864 if gz <= 0 || gz >= (nz as i64 - 1) {
866 continue;
867 }
868 if gr <= 0 || gr >= (nr as i64 - 1) {
869 continue;
870 }
871
872 let r_val = r_axis[gr as usize];
873 if r_val <= 0.0 {
874 continue;
875 }
876
877 let psi_n = loc[[iz - 1, ir]];
882 let psi_s = loc[[iz + 1, ir]];
883 let psi_w = loc[[iz, ir - 1]];
884 let psi_e = loc[[iz, ir + 1]];
885
886 let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
887 let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
888 let c_z = 1.0 / dz_sq;
889 let center = 2.0 / dr_sq + 2.0 / dz_sq;
890
891 let rhs_val = src[[iz, ir]];
892 let numerator = c_z * (psi_n + psi_s)
893 + c_r_minus * psi_w
894 + c_r_plus * psi_e
895 + rhs_val;
896 if center.abs() < 1e-30 {
897 continue;
898 }
899 let psi_gs = numerator / center;
900 let old = loc[[iz, ir]];
901 loc[[iz, ir]] = old + cfg.omega * (psi_gs - old);
902 }
903 }
904 }
905 }
906 });
907
908 for (i, tile) in tiles.iter().enumerate() {
910 inject_tile(&mut global_psi, &locals[i], tile)?;
911 }
912
913 residual = gs_residual_l2(&global_psi, source, r_axis, dr, dz);
915
916 if residual < cfg.tol {
917 converged = true;
918 break;
919 }
920 }
921
922 if global_psi.iter().any(|v| !v.is_finite()) {
923 return Err(FusionError::PhysicsViolation(
924 "Distributed GS solve produced non-finite Ψ".to_string(),
925 ));
926 }
927
928 Ok(DistributedSolveResult {
929 psi: global_psi,
930 residual,
931 iterations: outer_iter,
932 converged,
933 })
934}
935
936pub fn gs_residual_l2(
942 psi: &Array2<f64>,
943 source: &Array2<f64>,
944 r_axis: &[f64],
945 dr: f64,
946 dz: f64,
947) -> f64 {
948 let (nz, nr) = psi.dim();
949 let dr_sq = dr * dr;
950 let dz_sq = dz * dz;
951 let mut accum = 0.0f64;
952 let mut count = 0usize;
953
954 for iz in 1..nz - 1 {
955 for ir in 1..nr - 1 {
956 let r_val = r_axis[ir];
957 if r_val <= 0.0 {
958 continue;
959 }
960 let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
961 let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
962 let c_z = 1.0 / dz_sq;
963 let center = 2.0 / dr_sq + 2.0 / dz_sq;
964
965 let l_psi = center * psi[[iz, ir]]
966 - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
967 - c_r_plus * psi[[iz, ir + 1]]
968 - c_r_minus * psi[[iz, ir - 1]];
969 let res = l_psi - source[[iz, ir]];
970 accum += res * res;
971 count += 1;
972 }
973 }
974 if count == 0 {
975 return 0.0;
976 }
977 (accum / count as f64).sqrt()
978}
979
980pub fn optimal_process_grid(nz: usize, nr: usize, nranks: usize) -> (usize, usize) {
984 let mut best_pz = 1;
985 let mut best_pr = nranks;
986 let mut best_cost = f64::MAX;
987
988 for pz in 1..=nranks {
989 if !nranks.is_multiple_of(pz) {
990 continue;
991 }
992 let pr = nranks / pz;
993 if pz > nz || pr > nr {
994 continue;
995 }
996 let tile_nz = nz as f64 / pz as f64;
998 let tile_nr = nr as f64 / pr as f64;
999 let perimeter = 2.0 * (tile_nz + tile_nr);
1000 let area = tile_nz * tile_nr;
1001 let cost = perimeter / area;
1002 if cost < best_cost {
1003 best_cost = cost;
1004 best_pz = pz;
1005 best_pr = pr;
1006 }
1007 }
1008 (best_pz, best_pr)
1009}
1010
1011#[allow(clippy::too_many_arguments)]
1017pub fn auto_distributed_gs_solve(
1018 psi: &Array2<f64>,
1019 source: &Array2<f64>,
1020 r_axis: &[f64],
1021 z_axis: &[f64],
1022 dr: f64,
1023 dz: f64,
1024 omega: f64,
1025 tol: f64,
1026 max_iters: usize,
1027) -> FusionResult<DistributedSolveResult> {
1028 let (nz, nr) = psi.dim();
1029 let nthreads = rayon::current_num_threads().max(1);
1030 let (pz, pr) = optimal_process_grid(nz, nr, nthreads);
1031
1032 let cfg = DistributedSolverConfig {
1033 pz,
1034 pr,
1035 halo: 1,
1036 omega,
1037 max_outer_iters: max_iters,
1038 tol,
1039 inner_sweeps: 5,
1040 };
1041 distributed_gs_solve(psi, source, r_axis, z_axis, dr, dz, &cfg)
1042}
1043
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic ramp grid: cell `(i, j)` holds `10 * i + j`.
    fn sample_grid(nz: usize, nr: usize) -> Array2<f64> {
        Array2::from_shape_fn((nz, nr), |(row, col)| (row as f64) * 10.0 + col as f64)
    }

    /// Unwraps the message of a `PhysicsViolation`, panicking on any other variant.
    fn violation_message(err: FusionError) -> String {
        match err {
            FusionError::PhysicsViolation(msg) => msg,
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    // ---- 1D (Z-only) decomposition -------------------------------------

    #[test]
    fn test_decompose_z_covers_domain() {
        let slices = decompose_z(17, 4, 1).expect("decomposition must succeed");
        assert_eq!(slices.len(), 4);
        assert_eq!(slices[0].z_start, 0);
        let last = slices.last().expect("slice expected");
        assert_eq!(last.z_end, 17);
        let owned_rows = slices.iter().fold(0usize, |acc, s| acc + s.local_nz);
        assert_eq!(owned_rows, 17);
    }

    #[test]
    fn test_serial_halo_exchange_and_stitch_roundtrip() {
        let global = sample_grid(24, 9);
        let slices = decompose_z(global.nrows(), 3, 1).expect("decompose");
        let mut locals = split_with_halo(&global, &slices).expect("split");
        serial_halo_exchange(&mut locals, &slices).expect("exchange");
        let stitched = stitch_without_halo(&locals, &slices, global.ncols()).expect("stitch");
        let delta = l2_norm_delta(&stitched, &global).expect("delta");
        assert!(
            delta < 1e-12,
            "Serial halo exchange should preserve core rows"
        );
    }

    #[test]
    fn test_pack_halo_errors_for_small_local_block() {
        let tiny = Array2::zeros((2, 4));
        let err = pack_halo_rows(&tiny, 1).expect_err("small local must error");
        assert!(violation_message(err).contains("insufficient rows"));
    }

    #[test]
    fn test_apply_halo_shape_guard() {
        let mut local = Array2::zeros((6, 4));
        let wrong_shape = Array2::zeros((2, 5));
        let err = apply_halo_rows(&mut local, 1, Some(&wrong_shape), None)
            .expect_err("shape mismatch");
        assert!(violation_message(err).contains("shape mismatch"));
    }

    #[test]
    fn test_l2_norm_delta_zero_for_identical_arrays() {
        let grid = sample_grid(8, 8);
        let twin = grid.clone();
        let d = l2_norm_delta(&grid, &twin).expect("delta");
        assert!(d.abs() < 1e-12);
    }

    #[test]
    fn test_mpi_domain_rejects_non_finite_inputs() {
        // NaN inside a block must be caught when packing.
        let mut poisoned = sample_grid(8, 4);
        poisoned[[2, 2]] = f64::NAN;
        let err = pack_halo_rows(&poisoned, 1).expect_err("non-finite local should fail");
        assert!(violation_message(err).contains("non-finite"));

        // Infinity in either operand must be caught by the norm.
        let mut lhs = sample_grid(4, 4);
        let rhs = sample_grid(4, 4);
        lhs[[0, 0]] = f64::INFINITY;
        let err = l2_norm_delta(&lhs, &rhs).expect_err("non-finite delta input should fail");
        assert!(violation_message(err).contains("finite"));
    }

    // ---- 2D Cartesian decomposition ------------------------------------

    #[test]
    fn test_decompose_2d_covers_full_domain() {
        let tiles = decompose_2d(32, 24, 4, 3, 1).expect("decompose_2d");
        assert_eq!(tiles.len(), 12);

        // Count how many tiles claim each cell.
        let mut coverage = Array2::<u8>::zeros((32, 24));
        for t in &tiles {
            coverage
                .slice_mut(s![t.z_start..t.z_end, t.r_start..t.r_end])
                .mapv_inplace(|hits| hits + 1);
        }
        assert!(
            coverage.iter().all(|&c| c == 1),
            "Every cell owned by exactly one tile"
        );
    }

    #[test]
    fn test_decompose_2d_single_rank() {
        let tiles = decompose_2d(16, 16, 1, 1, 2).expect("single rank");
        assert_eq!(tiles.len(), 1);
        let only = &tiles[0];
        assert_eq!(only.z_start, 0);
        assert_eq!(only.z_end, 16);
        assert_eq!(only.r_start, 0);
        assert_eq!(only.r_end, 16);
        // No neighbours, so no padding despite halo = 2.
        assert_eq!(only.padded_nz(), 16);
        assert_eq!(only.padded_nr(), 16);
    }

    #[test]
    fn test_decompose_2d_neighbor_ranks() {
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("2x2");

        let top_left = &tiles[0];
        assert_eq!(top_left.neighbor_rank(1, 0), Some(2));
        assert_eq!(top_left.neighbor_rank(0, 1), Some(1));
        assert_eq!(top_left.neighbor_rank(-1, 0), None);
        assert_eq!(top_left.neighbor_rank(0, -1), None);

        let bottom_right = &tiles[3];
        assert_eq!(bottom_right.neighbor_rank(-1, 0), Some(1));
        assert_eq!(bottom_right.neighbor_rank(0, -1), Some(2));
        assert_eq!(bottom_right.neighbor_rank(1, 0), None);
        assert_eq!(bottom_right.neighbor_rank(0, 1), None);
    }

    #[test]
    fn test_extract_inject_roundtrip() {
        let global = sample_grid(24, 18);
        let tiles = decompose_2d(24, 18, 3, 2, 1).expect("decompose");
        let mut reconstructed = Array2::<f64>::zeros((24, 18));
        for tile in &tiles {
            let local = extract_tile(&global, tile).expect("extract");
            inject_tile(&mut reconstructed, &local, tile).expect("inject");
        }
        let delta = l2_norm_delta(&global, &reconstructed).expect("delta");
        assert!(
            delta < 1e-12,
            "Extract→inject roundtrip must be lossless, got delta={delta}"
        );
    }

    #[test]
    fn test_serial_halo_exchange_2d_correctness() {
        let global = Array2::from_shape_fn((16, 16), |(i, j)| (i as f64) * 100.0 + j as f64);
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("decompose");
        let mut locals: Vec<Array2<f64>> = tiles
            .iter()
            .map(|t| extract_tile(&global, t).expect("extract"))
            .collect();
        serial_halo_exchange_2d(&mut locals, &tiles).expect("halo exchange");

        // After the exchange every local cell (core and halo) must still equal
        // the global cell it maps to.
        for (i, (t, loc)) in tiles.iter().zip(&locals).enumerate() {
            let z0 = t.z_start.saturating_sub(t.core_z_offset());
            let r0 = t.r_start.saturating_sub(t.core_r_offset());
            for ((lz, lr), &got) in loc.indexed_iter() {
                let gz = z0 + lz;
                let gr = r0 + lr;
                if gz >= 16 || gr >= 16 {
                    continue;
                }
                let expect = global[[gz, gr]];
                assert!(
                    (expect - got).abs() < 1e-12,
                    "Tile {i} at local ({lz},{lr}) global ({gz},{gr}): expected {expect}, got {got}"
                );
            }
        }
    }

    // ---- Process-grid selection ----------------------------------------

    #[test]
    fn test_optimal_process_grid_square() {
        let grid = optimal_process_grid(64, 64, 4);
        assert_eq!(grid.0, 2);
        assert_eq!(grid.1, 2);
    }

    #[test]
    fn test_optimal_process_grid_rectangular() {
        let (pz, pr) = optimal_process_grid(128, 32, 8);
        assert!(
            pz >= pr,
            "Tall grid should bias toward Z decomposition: pz={pz}, pr={pr}"
        );
        assert_eq!(pz * pr, 8);
    }

    // ---- Residual and solver -------------------------------------------

    #[test]
    fn test_gs_residual_l2_zero_for_exact_solution() {
        let nz = 16;
        let nr = 16;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.08 + i as f64 * dz).collect();
        let psi = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            z_axis[iz] * z_axis[iz] + r_axis[ir] * r_axis[ir]
        });

        // Manufacture the source by applying the same stencil to psi, so the
        // residual must vanish up to rounding.
        let dr_sq = dr * dr;
        let dz_sq = dz * dz;
        let mut source = Array2::zeros((nz, nr));
        for iz in 1..nz - 1 {
            for ir in 1..nr - 1 {
                let r = r_axis[ir];
                let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r * dr);
                let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r * dr);
                let c_z = 1.0 / dz_sq;
                let center = 2.0 / dr_sq + 2.0 / dz_sq;
                source[[iz, ir]] = center * psi[[iz, ir]]
                    - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
                    - c_r_plus * psi[[iz, ir + 1]]
                    - c_r_minus * psi[[iz, ir - 1]];
            }
        }
        let res = gs_residual_l2(&psi, &source, &r_axis, dr, dz);
        assert!(
            res < 1e-10,
            "Residual should be ~0 for manufactured solution, got {res}"
        );
    }

    #[test]
    fn test_distributed_gs_solve_smoke() {
        let nz = 32;
        let nr = 32;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.16 + i as f64 * dz).collect();

        // Zero initial guess with a Gaussian source centred in the grid.
        let psi = Array2::zeros((nz, nr));
        let r_mid = 1.0 + (nr as f64 / 2.0) * dr;
        let z_mid = 0.0;
        let source = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            let rr = r_axis[ir] - r_mid;
            let zz = z_axis[iz] - z_mid;
            -((-(rr * rr + zz * zz) / (0.05 * 0.05)).exp())
        });

        let res_initial = gs_residual_l2(&psi, &source, &r_axis, dr, dz);

        let cfg = DistributedSolverConfig {
            pz: 2,
            pr: 2,
            halo: 1,
            omega: 1.6,
            max_outer_iters: 200,
            tol: 1e-6,
            inner_sweeps: 10,
        };
        let result = distributed_gs_solve(&psi, &source, &r_axis, &z_axis, dr, dz, &cfg)
            .expect("distributed solve");
        assert!(
            result.residual < res_initial * 0.5,
            "Residual should decrease: initial={res_initial:.4}, final={:.4}",
            result.residual
        );
        let psi_absmax = result.psi.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
        assert!(
            psi_absmax > 1e-10,
            "Solution should be non-trivial, max |Ψ| = {psi_absmax}"
        );
    }

    #[test]
    fn test_decompose_2d_rejects_invalid_inputs() {
        // (global_nz, global_nr, pz, pr): zero process counts, a grid that is
        // too small, and more tiles than rows.
        let rejected = [(16, 16, 0, 2), (16, 16, 2, 0), (1, 16, 2, 2), (16, 16, 20, 2)];
        for (gnz, gnr, pz, pr) in rejected {
            assert!(decompose_2d(gnz, gnr, pz, pr, 1).is_err());
        }
    }
}