1use fusion_types::error::{FusionError, FusionResult};
13use ndarray::{s, Array2};
14
/// One rank's share of a grid that has been cut into contiguous row bands along Z.
///
/// Values are normally produced by [`decompose_z`], which fills every field consistently;
/// all fields are public, so hand-built values carry no such guarantee.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DomainSlice {
    /// Index of the rank that owns this band.
    pub rank: usize,
    /// Total number of bands in the decomposition.
    pub nranks: usize,
    /// Row count of the undecomposed grid.
    pub global_nz: usize,
    /// Number of rows owned by this rank (`z_end - z_start` for `decompose_z` output).
    pub local_nz: usize,
    /// Halo width in rows, stored exactly as it was passed to the decomposition.
    pub halo: usize,
    /// First owned global row (inclusive).
    pub z_start: usize,
    /// One past the last owned global row (exclusive).
    pub z_end: usize,
}

impl DomainSlice {
    /// Returns `true` for every rank except rank 0, i.e. when a band with a smaller rank
    /// index exists.
    pub fn has_upper_neighbor(&self) -> bool {
        self.rank != 0
    }

    /// Returns `true` for every rank except the last one, i.e. when a band with a larger
    /// rank index exists.
    pub fn has_lower_neighbor(&self) -> bool {
        self.nranks > self.rank + 1
    }
}
35
36pub fn decompose_z(global_nz: usize, nranks: usize, halo: usize) -> FusionResult<Vec<DomainSlice>> {
37 if global_nz < 2 {
38 return Err(FusionError::PhysicsViolation(
39 "MPI decomposition requires global_nz >= 2".to_string(),
40 ));
41 }
42 if nranks < 1 {
43 return Err(FusionError::PhysicsViolation(
44 "MPI decomposition requires nranks >= 1".to_string(),
45 ));
46 }
47 if nranks > global_nz {
48 return Err(FusionError::PhysicsViolation(format!(
49 "Cannot split global_nz={global_nz} across nranks={nranks}"
50 )));
51 }
52
53 let base = global_nz / nranks;
54 let rem = global_nz % nranks;
55 let mut out = Vec::with_capacity(nranks);
56 let mut cursor = 0usize;
57 for rank in 0..nranks {
58 let local_nz = base + usize::from(rank < rem);
59 let z_start = cursor;
60 let z_end = z_start + local_nz;
61 cursor = z_end;
62 out.push(DomainSlice {
63 rank,
64 nranks,
65 global_nz,
66 local_nz,
67 halo,
68 z_start,
69 z_end,
70 });
71 }
72 Ok(out)
73}
74
75pub fn pack_halo_rows(
76 local: &Array2<f64>,
77 halo: usize,
78) -> FusionResult<(Array2<f64>, Array2<f64>)> {
79 if halo == 0 {
80 return Err(FusionError::PhysicsViolation(
81 "Halo width must be >= 1".to_string(),
82 ));
83 }
84 if local.nrows() <= 2 * halo {
85 return Err(FusionError::PhysicsViolation(format!(
86 "Local block has insufficient rows {} for halo={halo}",
87 local.nrows()
88 )));
89 }
90 if local.iter().any(|v| !v.is_finite()) {
91 return Err(FusionError::PhysicsViolation(
92 "Local block contains non-finite values".to_string(),
93 ));
94 }
95 let top = local.slice(s![halo..(2 * halo), ..]).to_owned();
96 let bottom = local
97 .slice(s![(local.nrows() - 2 * halo)..(local.nrows() - halo), ..])
98 .to_owned();
99 if top.iter().any(|v| !v.is_finite()) || bottom.iter().any(|v| !v.is_finite()) {
100 return Err(FusionError::PhysicsViolation(
101 "Packed halo rows contain non-finite values".to_string(),
102 ));
103 }
104 Ok((top, bottom))
105}
106
107pub fn apply_halo_rows(
108 local: &mut Array2<f64>,
109 halo: usize,
110 recv_top: Option<&Array2<f64>>,
111 recv_bottom: Option<&Array2<f64>>,
112) -> FusionResult<()> {
113 if halo == 0 {
114 return Err(FusionError::PhysicsViolation(
115 "Halo width must be >= 1".to_string(),
116 ));
117 }
118 if local.nrows() <= 2 * halo {
119 return Err(FusionError::PhysicsViolation(format!(
120 "Local block has insufficient rows {} for halo={halo}",
121 local.nrows()
122 )));
123 }
124 if local.iter().any(|v| !v.is_finite()) {
125 return Err(FusionError::PhysicsViolation(
126 "Local block contains non-finite values".to_string(),
127 ));
128 }
129 if let Some(top) = recv_top {
130 if top.dim() != (halo, local.ncols()) {
131 return Err(FusionError::PhysicsViolation(format!(
132 "Top halo shape mismatch: expected ({halo}, {}), got {:?}",
133 local.ncols(),
134 top.dim()
135 )));
136 }
137 if top.iter().any(|v| !v.is_finite()) {
138 return Err(FusionError::PhysicsViolation(
139 "Top halo contains non-finite values".to_string(),
140 ));
141 }
142 local.slice_mut(s![0..halo, ..]).assign(top);
143 }
144 if let Some(bottom) = recv_bottom {
145 if bottom.dim() != (halo, local.ncols()) {
146 return Err(FusionError::PhysicsViolation(format!(
147 "Bottom halo shape mismatch: expected ({halo}, {}), got {:?}",
148 local.ncols(),
149 bottom.dim()
150 )));
151 }
152 if bottom.iter().any(|v| !v.is_finite()) {
153 return Err(FusionError::PhysicsViolation(
154 "Bottom halo contains non-finite values".to_string(),
155 ));
156 }
157 let n = local.nrows();
158 local.slice_mut(s![(n - halo)..n, ..]).assign(bottom);
159 }
160 if local.iter().any(|v| !v.is_finite()) {
161 return Err(FusionError::PhysicsViolation(
162 "Applying halo rows produced non-finite values".to_string(),
163 ));
164 }
165 Ok(())
166}
167
168pub fn split_with_halo(
169 global: &Array2<f64>,
170 slices: &[DomainSlice],
171) -> FusionResult<Vec<Array2<f64>>> {
172 if slices.is_empty() {
173 return Err(FusionError::PhysicsViolation(
174 "No slices provided for split_with_halo".to_string(),
175 ));
176 }
177 if global.iter().any(|v| !v.is_finite()) {
178 return Err(FusionError::PhysicsViolation(
179 "Global array contains non-finite values".to_string(),
180 ));
181 }
182 let mut out = Vec::with_capacity(slices.len());
183 for sdef in slices {
184 if sdef.global_nz != global.nrows() {
185 return Err(FusionError::PhysicsViolation(format!(
186 "Slice/global mismatch: slice.global_nz={} global.nrows()={}",
187 sdef.global_nz,
188 global.nrows()
189 )));
190 }
191 if sdef.z_start >= sdef.z_end || sdef.z_end > sdef.global_nz {
192 return Err(FusionError::PhysicsViolation(format!(
193 "Invalid slice bounds z_start={} z_end={} global_nz={}",
194 sdef.z_start, sdef.z_end, sdef.global_nz
195 )));
196 }
197 let start = sdef.z_start.saturating_sub(sdef.halo);
198 let end = (sdef.z_end + sdef.halo).min(sdef.global_nz);
199 let mut local = Array2::zeros((end - start, global.ncols()));
200 local.assign(&global.slice(s![start..end, ..]));
201 if local.iter().any(|v| !v.is_finite()) {
202 return Err(FusionError::PhysicsViolation(
203 "Split local block contains non-finite values".to_string(),
204 ));
205 }
206 out.push(local);
207 }
208 Ok(out)
209}
210
211pub fn stitch_without_halo(
212 locals: &[Array2<f64>],
213 slices: &[DomainSlice],
214 ncols: usize,
215) -> FusionResult<Array2<f64>> {
216 if locals.len() != slices.len() {
217 return Err(FusionError::PhysicsViolation(format!(
218 "locals/slices mismatch: {} vs {}",
219 locals.len(),
220 slices.len()
221 )));
222 }
223 if slices.is_empty() {
224 return Err(FusionError::PhysicsViolation(
225 "No slices provided for stitch_without_halo".to_string(),
226 ));
227 }
228 let global_nz = slices
229 .last()
230 .map(|s| s.global_nz)
231 .ok_or_else(|| FusionError::PhysicsViolation("No slices provided".to_string()))?;
232 let mut global = Array2::zeros((global_nz, ncols));
233 for (local, sdef) in locals.iter().zip(slices.iter()) {
234 if local.iter().any(|v| !v.is_finite()) {
235 return Err(FusionError::PhysicsViolation(
236 "Local block contains non-finite values".to_string(),
237 ));
238 }
239 if local.ncols() != ncols {
240 return Err(FusionError::PhysicsViolation(format!(
241 "Local ncols mismatch: expected {ncols}, got {}",
242 local.ncols()
243 )));
244 }
245 let core_start = usize::from(sdef.z_start > 0) * sdef.halo;
246 let core_end = core_start + sdef.local_nz;
247 if core_end > local.nrows() {
248 return Err(FusionError::PhysicsViolation(format!(
249 "Local core range out of bounds: rows={}, core_end={core_end}",
250 local.nrows()
251 )));
252 }
253 global
254 .slice_mut(s![sdef.z_start..sdef.z_end, ..])
255 .assign(&local.slice(s![core_start..core_end, ..]));
256 }
257 if global.iter().any(|v| !v.is_finite()) {
258 return Err(FusionError::PhysicsViolation(
259 "Stitched global array contains non-finite values".to_string(),
260 ));
261 }
262 Ok(global)
263}
264
265pub fn serial_halo_exchange(
266 locals: &mut [Array2<f64>],
267 slices: &[DomainSlice],
268) -> FusionResult<()> {
269 if locals.len() != slices.len() {
270 return Err(FusionError::PhysicsViolation(format!(
271 "locals/slices mismatch: {} vs {}",
272 locals.len(),
273 slices.len()
274 )));
275 }
276 let mut top_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
277 let mut bottom_send: Vec<Option<Array2<f64>>> = vec![None; locals.len()];
278 for (i, (local, sdef)) in locals.iter().zip(slices.iter()).enumerate() {
279 if local.iter().any(|v| !v.is_finite()) {
280 return Err(FusionError::PhysicsViolation(format!(
281 "Local block at index {i} contains non-finite values"
282 )));
283 }
284 if sdef.halo == 0 {
285 continue;
286 }
287 let (top, bottom) = pack_halo_rows(local, sdef.halo)?;
288 top_send[i] = Some(top);
289 bottom_send[i] = Some(bottom);
290 }
291
292 for i in 0..locals.len() {
293 let halo = slices[i].halo;
294 if halo == 0 {
295 continue;
296 }
297 let recv_top = if i > 0 {
298 bottom_send[i - 1].as_ref()
299 } else {
300 None
301 };
302 let recv_bottom = if i + 1 < locals.len() {
303 top_send[i + 1].as_ref()
304 } else {
305 None
306 };
307 apply_halo_rows(&mut locals[i], halo, recv_top, recv_bottom)?;
308 }
309 if locals.iter().any(|arr| arr.iter().any(|v| !v.is_finite())) {
310 return Err(FusionError::PhysicsViolation(
311 "Serial halo exchange produced non-finite values".to_string(),
312 ));
313 }
314 Ok(())
315}
316
317pub fn l2_norm_delta(a: &Array2<f64>, b: &Array2<f64>) -> FusionResult<f64> {
318 if a.dim() != b.dim() {
319 return Err(FusionError::PhysicsViolation(format!(
320 "l2_norm_delta shape mismatch {:?} vs {:?}",
321 a.dim(),
322 b.dim()
323 )));
324 }
325 if a.iter().any(|v| !v.is_finite()) || b.iter().any(|v| !v.is_finite()) {
326 return Err(FusionError::PhysicsViolation(
327 "l2_norm_delta inputs must be finite".to_string(),
328 ));
329 }
330 let mut accum = 0.0f64;
331 for (av, bv) in a.iter().zip(b.iter()) {
332 let d = av - bv;
333 accum += d * d;
334 }
335 let out = accum.sqrt();
336 if !out.is_finite() {
337 return Err(FusionError::PhysicsViolation(
338 "l2_norm_delta produced non-finite result".to_string(),
339 ));
340 }
341 Ok(out)
342}
343
/// One rectangular tile of a 2-D (Z × R) block decomposition.
///
/// Tiles are normally produced by [`decompose_2d`], which numbers them row-major:
/// `rank = pz_idx * pr + pr_idx`. All fields are public, so hand-built values are not
/// guaranteed to be consistent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CartesianTile {
    /// Linear tile index.
    pub rank: usize,
    /// Position of the tile along Z in the process grid (`0..pz`).
    pub pz_idx: usize,
    /// Position of the tile along R in the process grid (`0..pr`).
    pub pr_idx: usize,
    /// Number of tiles along Z.
    pub pz: usize,
    /// Number of tiles along R.
    pub pr: usize,
    /// Row count of the undecomposed grid.
    pub global_nz: usize,
    /// Column count of the undecomposed grid.
    pub global_nr: usize,
    /// Halo width, applied on every side that faces a neighbouring tile.
    pub halo: usize,
    /// First owned global row (inclusive).
    pub z_start: usize,
    /// One past the last owned global row (exclusive).
    pub z_end: usize,
    /// First owned global column (inclusive).
    pub r_start: usize,
    /// One past the last owned global column (exclusive).
    pub r_end: usize,
}

impl CartesianTile {
    /// Number of owned rows. Requires `z_start <= z_end`.
    pub fn local_nz(&self) -> usize {
        self.z_end - self.z_start
    }

    /// Number of owned columns. Requires `r_start <= r_end`.
    pub fn local_nr(&self) -> usize {
        self.r_end - self.r_start
    }

    /// Owned rows plus one halo band for each Z side that has a neighbour.
    pub fn padded_nz(&self) -> usize {
        let bottom = if self.has_neighbor_bottom() {
            self.halo
        } else {
            0
        };
        self.core_z_offset() + self.local_nz() + bottom
    }

    /// Owned columns plus one halo band for each R side that has a neighbour.
    pub fn padded_nr(&self) -> usize {
        let right = if self.has_neighbor_right() {
            self.halo
        } else {
            0
        };
        self.core_r_offset() + self.local_nr() + right
    }

    /// Row offset of the owned region inside the padded block: `halo` when a top neighbour
    /// exists, otherwise 0.
    pub fn core_z_offset(&self) -> usize {
        if self.has_neighbor_top() {
            self.halo
        } else {
            0
        }
    }

    /// Column offset of the owned region inside the padded block: `halo` when a left
    /// neighbour exists, otherwise 0.
    pub fn core_r_offset(&self) -> usize {
        if self.has_neighbor_left() {
            self.halo
        } else {
            0
        }
    }

    /// `true` when a tile exists at `pz_idx - 1`.
    pub fn has_neighbor_top(&self) -> bool {
        self.pz_idx > 0
    }

    /// `true` when a tile exists at `pz_idx + 1`.
    pub fn has_neighbor_bottom(&self) -> bool {
        self.pz_idx + 1 < self.pz
    }

    /// `true` when a tile exists at `pr_idx - 1`.
    pub fn has_neighbor_left(&self) -> bool {
        self.pr_idx > 0
    }

    /// `true` when a tile exists at `pr_idx + 1`.
    pub fn has_neighbor_right(&self) -> bool {
        self.pr_idx + 1 < self.pr
    }

    /// Linear rank of the tile `dz` steps along Z and `dr` steps along R from this one.
    ///
    /// Returns `None` when the shifted position falls outside the `pz × pr` process grid.
    /// Every step of the arithmetic is checked, so extreme offsets such as `i32::MAX`
    /// yield `None` instead of overflowing.
    pub fn neighbor_rank(&self, dz: i32, dr: i32) -> Option<usize> {
        let nz = Self::shifted_index(self.pz_idx, dz, self.pz)?;
        let nr = Self::shifted_index(self.pr_idx, dr, self.pr)?;
        nz.checked_mul(self.pr)?.checked_add(nr)
    }

    /// Moves `index` by `delta` and returns the result only when it lies in `0..extent`.
    fn shifted_index(index: usize, delta: i32, extent: usize) -> Option<usize> {
        let step = isize::try_from(delta).ok()?;
        let moved = index.checked_add_signed(step)?;
        (moved < extent).then_some(moved)
    }
}
453
454pub fn decompose_2d(
459 global_nz: usize,
460 global_nr: usize,
461 pz: usize,
462 pr: usize,
463 halo: usize,
464) -> FusionResult<Vec<CartesianTile>> {
465 if pz == 0 || pr == 0 {
466 return Err(FusionError::PhysicsViolation(
467 "Process grid dimensions pz, pr must be >= 1".to_string(),
468 ));
469 }
470 if pz > global_nz || pr > global_nr {
471 return Err(FusionError::PhysicsViolation(format!(
472 "Cannot split ({global_nz}×{global_nr}) across ({pz}×{pr}) processes"
473 )));
474 }
475 if global_nz < 2 || global_nr < 2 {
476 return Err(FusionError::PhysicsViolation(
477 "Global grid must be at least 2×2".to_string(),
478 ));
479 }
480
481 let z_splits = balanced_split(global_nz, pz);
483 let r_splits = balanced_split(global_nr, pr);
484
485 let mut tiles = Vec::with_capacity(pz * pr);
486 let mut z_cursor = 0usize;
487 for (iz, nz_local) in z_splits.iter().copied().enumerate().take(pz) {
488 let z_start = z_cursor;
489 let z_end = z_start + nz_local;
490 z_cursor = z_end;
491
492 let mut r_cursor = 0usize;
493 for (ir, nr_local) in r_splits.iter().copied().enumerate().take(pr) {
494 let r_start = r_cursor;
495 let r_end = r_start + nr_local;
496 r_cursor = r_end;
497
498 tiles.push(CartesianTile {
499 rank: iz * pr + ir,
500 pz_idx: iz,
501 pr_idx: ir,
502 pz,
503 pr,
504 global_nz,
505 global_nr,
506 halo,
507 z_start,
508 z_end,
509 r_start,
510 r_end,
511 });
512 }
513 }
514 Ok(tiles)
515}
516
/// Splits `n` items into `k` part sizes that differ by at most one.
///
/// The first `n % k` parts receive the extra item. Panics (division by zero) when `k == 0`;
/// callers are expected to reject that case beforehand.
fn balanced_split(n: usize, k: usize) -> Vec<usize> {
    let quotient = n / k;
    let extra = n % k;
    let mut sizes = vec![quotient; k];
    // `extra < k`, so the slice below is always in range.
    for size in &mut sizes[..extra] {
        *size += 1;
    }
    sizes
}
523
524pub fn extract_tile(global: &Array2<f64>, tile: &CartesianTile) -> FusionResult<Array2<f64>> {
526 let (gnz, gnr) = global.dim();
527 if gnz != tile.global_nz || gnr != tile.global_nr {
528 return Err(FusionError::PhysicsViolation(format!(
529 "Global shape ({gnz},{gnr}) doesn't match tile expectation ({},{})",
530 tile.global_nz, tile.global_nr
531 )));
532 }
533 let z0 = tile.z_start.saturating_sub(tile.core_z_offset());
534 let z1 = (tile.z_end
535 + if tile.has_neighbor_bottom() {
536 tile.halo
537 } else {
538 0
539 })
540 .min(gnz);
541 let r0 = tile.r_start.saturating_sub(tile.core_r_offset());
542 let r1 = (tile.r_end
543 + if tile.has_neighbor_right() {
544 tile.halo
545 } else {
546 0
547 })
548 .min(gnr);
549
550 let local = global.slice(s![z0..z1, r0..r1]).to_owned();
551 if local.iter().any(|v| !v.is_finite()) {
552 return Err(FusionError::PhysicsViolation(
553 "Extracted tile contains non-finite values".to_string(),
554 ));
555 }
556 Ok(local)
557}
558
559pub fn inject_tile(
561 global: &mut Array2<f64>,
562 local: &Array2<f64>,
563 tile: &CartesianTile,
564) -> FusionResult<()> {
565 let cz = tile.core_z_offset();
566 let cr = tile.core_r_offset();
567 let lnz = tile.local_nz();
568 let lnr = tile.local_nr();
569 if cz + lnz > local.nrows() || cr + lnr > local.ncols() {
570 return Err(FusionError::PhysicsViolation(format!(
571 "Tile core out of bounds: local shape {:?}, core ({cz}+{lnz}, {cr}+{lnr})",
572 local.dim()
573 )));
574 }
575 let core = local.slice(s![cz..(cz + lnz), cr..(cr + lnr)]);
576 if core.iter().any(|v| !v.is_finite()) {
577 return Err(FusionError::PhysicsViolation(
578 "Tile core to inject contains non-finite values".to_string(),
579 ));
580 }
581 global
582 .slice_mut(s![tile.z_start..tile.z_end, tile.r_start..tile.r_end])
583 .assign(&core);
584 Ok(())
585}
586
587pub fn serial_halo_exchange_2d(
594 locals: &mut [Array2<f64>],
595 tiles: &[CartesianTile],
596) -> FusionResult<()> {
597 if locals.len() != tiles.len() {
598 return Err(FusionError::PhysicsViolation(format!(
599 "locals/tiles length mismatch: {} vs {}",
600 locals.len(),
601 tiles.len()
602 )));
603 }
604 let ntiles = tiles.len();
605 if ntiles == 0 {
606 return Ok(());
607 }
608 let halo = tiles[0].halo;
609 if halo == 0 {
610 return Ok(());
611 }
612
613 #[derive(Clone, Copy)]
616 enum Face {
617 Top,
618 Bottom,
619 Left,
620 Right,
621 }
622 let mut messages: Vec<(usize, Face, Array2<f64>)> = Vec::new();
623
624 for (i, tile) in tiles.iter().enumerate() {
625 let loc = &locals[i];
626 let cz = tile.core_z_offset();
627 let cr = tile.core_r_offset();
628 let lnz = tile.local_nz();
629 let lnr = tile.local_nr();
630
631 if let Some(dest) = tile.neighbor_rank(-1, 0) {
633 let strip = loc.slice(s![cz..(cz + halo), cr..(cr + lnr)]).to_owned();
634 messages.push((dest, Face::Bottom, strip));
635 }
636 if let Some(dest) = tile.neighbor_rank(1, 0) {
638 let strip = loc
639 .slice(s![(cz + lnz - halo)..(cz + lnz), cr..(cr + lnr)])
640 .to_owned();
641 messages.push((dest, Face::Top, strip));
642 }
643 if let Some(dest) = tile.neighbor_rank(0, -1) {
645 let strip = loc.slice(s![cz..(cz + lnz), cr..(cr + halo)]).to_owned();
646 messages.push((dest, Face::Right, strip));
647 }
648 if let Some(dest) = tile.neighbor_rank(0, 1) {
650 let strip = loc
651 .slice(s![cz..(cz + lnz), (cr + lnr - halo)..(cr + lnr)])
652 .to_owned();
653 messages.push((dest, Face::Left, strip));
654 }
655 }
656
657 for (dest, face, data) in messages {
659 let tile = &tiles[dest];
660 let loc = &mut locals[dest];
661 let cz = tile.core_z_offset();
662 let cr = tile.core_r_offset();
663 let lnz = tile.local_nz();
664 let lnr = tile.local_nr();
665
666 match face {
667 Face::Top => {
668 if cz >= halo {
670 loc.slice_mut(s![(cz - halo)..cz, cr..(cr + lnr)])
671 .assign(&data);
672 }
673 }
674 Face::Bottom => {
675 let row_start = cz + lnz;
677 let row_end = row_start + halo;
678 if row_end <= loc.nrows() {
679 loc.slice_mut(s![row_start..row_end, cr..(cr + lnr)])
680 .assign(&data);
681 }
682 }
683 Face::Left => {
684 if cr >= halo {
686 loc.slice_mut(s![cz..(cz + lnz), (cr - halo)..cr])
687 .assign(&data);
688 }
689 }
690 Face::Right => {
691 let col_start = cr + lnr;
693 let col_end = col_start + halo;
694 if col_end <= loc.ncols() {
695 loc.slice_mut(s![cz..(cz + lnz), col_start..col_end])
696 .assign(&data);
697 }
698 }
699 }
700 }
701
702 for loc in locals.iter() {
704 if loc.iter().any(|v| !v.is_finite()) {
705 return Err(FusionError::PhysicsViolation(
706 "2D halo exchange produced non-finite values".to_string(),
707 ));
708 }
709 }
710 Ok(())
711}
712
/// Tuning parameters for [`distributed_gs_solve`].
#[derive(Debug, Clone)]
pub struct DistributedSolverConfig {
    /// Number of tiles along Z.
    pub pz: usize,
    /// Number of tiles along R.
    pub pr: usize,
    /// Halo width handed to the tile decomposition.
    pub halo: usize,
    /// SOR relaxation factor; the solver rejects values outside the open interval (0, 2).
    pub omega: f64,
    /// Upper bound on the number of outer (extract → relax → inject) iterations.
    pub max_outer_iters: usize,
    /// The solve stops as soon as the global residual drops below this value.
    pub tol: f64,
    /// Number of relaxation sweeps run on each tile per outer iteration.
    pub inner_sweeps: usize,
}

impl Default for DistributedSolverConfig {
    /// A 2×2 tile grid with a one-cell halo, `omega = 1.8`, at most 200 outer iterations,
    /// tolerance `1e-8` and 5 sweeps per tile.
    fn default() -> Self {
        DistributedSolverConfig {
            // decomposition
            pz: 2,
            pr: 2,
            halo: 1,
            // relaxation
            omega: 1.8,
            inner_sweeps: 5,
            // termination
            max_outer_iters: 200,
            tol: 1e-8,
        }
    }
}
746
747#[derive(Debug, Clone)]
749pub struct DistributedSolveResult {
750 pub psi: Array2<f64>,
752 pub residual: f64,
754 pub iterations: usize,
756 pub converged: bool,
758}
759
760pub fn distributed_gs_solve(
778 psi: &Array2<f64>,
779 source: &Array2<f64>,
780 r_axis: &[f64],
781 z_axis: &[f64],
782 dr: f64,
783 dz: f64,
784 cfg: &DistributedSolverConfig,
785) -> FusionResult<DistributedSolveResult> {
786 let (nz, nr) = psi.dim();
787 if source.dim() != (nz, nr) {
788 return Err(FusionError::PhysicsViolation(format!(
789 "psi/source shape mismatch: {:?} vs {:?}",
790 psi.dim(),
791 source.dim()
792 )));
793 }
794 if r_axis.len() != nr || z_axis.len() != nz {
795 return Err(FusionError::PhysicsViolation(format!(
796 "Axis lengths don't match grid: r_axis={} nr={nr}, z_axis={} nz={nz}",
797 r_axis.len(),
798 z_axis.len()
799 )));
800 }
801 if cfg.omega <= 0.0 || cfg.omega >= 2.0 {
802 return Err(FusionError::PhysicsViolation(format!(
803 "SOR omega must be in (0, 2), got {}",
804 cfg.omega
805 )));
806 }
807 if !dr.is_finite() || !dz.is_finite() || dr <= 0.0 || dz <= 0.0 {
808 return Err(FusionError::PhysicsViolation(format!(
809 "Grid spacing must be finite > 0: dr={dr}, dz={dz}"
810 )));
811 }
812
813 let tiles = decompose_2d(nz, nr, cfg.pz, cfg.pr, cfg.halo)?;
814
815 let mut global_psi = psi.clone();
816 let dr_sq = dr * dr;
817 let dz_sq = dz * dz;
818
819 let mut converged = false;
820 let mut outer_iter = 0usize;
821 let mut residual = f64::MAX;
822
823 for outer in 0..cfg.max_outer_iters {
824 outer_iter = outer + 1;
825
826 let mut locals: Vec<Array2<f64>> = tiles
828 .iter()
829 .map(|t| extract_tile(&global_psi, t))
830 .collect::<FusionResult<Vec<_>>>()?;
831 let sources: Vec<Array2<f64>> = tiles
832 .iter()
833 .map(|t| extract_tile(source, t))
834 .collect::<FusionResult<Vec<_>>>()?;
835
836 use rayon::prelude::*;
839 locals
840 .par_iter_mut()
841 .zip(sources.par_iter())
842 .zip(tiles.par_iter())
843 .for_each(|((loc, src), tile)| {
844 let cz = tile.core_z_offset();
845 let cr = tile.core_r_offset();
846 let _lnz = tile.local_nz();
847 let _lnr = tile.local_nr();
848
849 for _sweep in 0..cfg.inner_sweeps {
850 for color in 0..2u8 {
852 for iz in 1..loc.nrows().saturating_sub(1) {
853 for ir in 1..loc.ncols().saturating_sub(1) {
854 if (iz + ir) % 2 != color as usize {
855 continue;
856 }
857 let gz = tile.z_start as i64 + iz as i64 - cz as i64;
862 let gr = tile.r_start as i64 + ir as i64 - cr as i64;
863 if gz <= 0 || gz >= (nz as i64 - 1) {
865 continue;
866 }
867 if gr <= 0 || gr >= (nr as i64 - 1) {
868 continue;
869 }
870
871 let r_val = r_axis[gr as usize];
872 if r_val <= 0.0 {
873 continue;
874 }
875
876 let psi_n = loc[[iz - 1, ir]];
881 let psi_s = loc[[iz + 1, ir]];
882 let psi_w = loc[[iz, ir - 1]];
883 let psi_e = loc[[iz, ir + 1]];
884
885 let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
886 let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
887 let c_z = 1.0 / dz_sq;
888 let center = 2.0 / dr_sq + 2.0 / dz_sq;
889
890 let rhs_val = src[[iz, ir]];
891 let numerator = c_z * (psi_n + psi_s)
892 + c_r_minus * psi_w
893 + c_r_plus * psi_e
894 + rhs_val;
895 if center.abs() < 1e-30 {
896 continue;
897 }
898 let psi_gs = numerator / center;
899 let old = loc[[iz, ir]];
900 loc[[iz, ir]] = old + cfg.omega * (psi_gs - old);
901 }
902 }
903 }
904 }
905 });
906
907 for (i, tile) in tiles.iter().enumerate() {
909 inject_tile(&mut global_psi, &locals[i], tile)?;
910 }
911
912 residual = gs_residual_l2(&global_psi, source, r_axis, dr, dz);
914
915 if residual < cfg.tol {
916 converged = true;
917 break;
918 }
919 }
920
921 if global_psi.iter().any(|v| !v.is_finite()) {
922 return Err(FusionError::PhysicsViolation(
923 "Distributed GS solve produced non-finite Ψ".to_string(),
924 ));
925 }
926
927 Ok(DistributedSolveResult {
928 psi: global_psi,
929 residual,
930 iterations: outer_iter,
931 converged,
932 })
933}
934
935pub fn gs_residual_l2(
941 psi: &Array2<f64>,
942 source: &Array2<f64>,
943 r_axis: &[f64],
944 dr: f64,
945 dz: f64,
946) -> f64 {
947 let (nz, nr) = psi.dim();
948 let dr_sq = dr * dr;
949 let dz_sq = dz * dz;
950 let mut accum = 0.0f64;
951 let mut count = 0usize;
952
953 for iz in 1..nz - 1 {
954 for ir in 1..nr - 1 {
955 let r_val = r_axis[ir];
956 if r_val <= 0.0 {
957 continue;
958 }
959 let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r_val * dr);
960 let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r_val * dr);
961 let c_z = 1.0 / dz_sq;
962 let center = 2.0 / dr_sq + 2.0 / dz_sq;
963
964 let l_psi = center * psi[[iz, ir]]
965 - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
966 - c_r_plus * psi[[iz, ir + 1]]
967 - c_r_minus * psi[[iz, ir - 1]];
968 let res = l_psi - source[[iz, ir]];
969 accum += res * res;
970 count += 1;
971 }
972 }
973 if count == 0 {
974 return 0.0;
975 }
976 (accum / count as f64).sqrt()
977}
978
/// Chooses a `(pz, pr)` factorisation of `nranks` that minimises the perimeter-to-area
/// ratio of an `nz × nr` grid's tiles.
///
/// Only exact factorisations with `pz <= nz` and `pr <= nr` are considered; on a tie the
/// candidate with the smaller `pz` wins.
///
/// When no factorisation fits, the fallback `(1, nranks)` is returned unchanged, so callers
/// must still check the result against the grid. `nranks == 0` yields the trivial grid
/// `(1, 1)` rather than a grid with a zero dimension.
pub fn optimal_process_grid(nz: usize, nr: usize, nranks: usize) -> (usize, usize) {
    if nranks == 0 {
        return (1, 1);
    }

    let mut best = (1, nranks);
    let mut best_cost = f64::MAX;

    for pz in 1..=nranks {
        let pr = nranks / pz;
        // `pz * pr == nranks` holds exactly when `pz` divides `nranks`.
        if pz * pr != nranks || pz > nz || pr > nr {
            continue;
        }
        let tile_nz = nz as f64 / pz as f64;
        let tile_nr = nr as f64 / pr as f64;
        let perimeter = 2.0 * (tile_nz + tile_nr);
        let area = tile_nz * tile_nr;
        let cost = perimeter / area;
        if cost < best_cost {
            best_cost = cost;
            best = (pz, pr);
        }
    }
    best
}
1009
1010#[allow(clippy::too_many_arguments)]
1016pub fn auto_distributed_gs_solve(
1017 psi: &Array2<f64>,
1018 source: &Array2<f64>,
1019 r_axis: &[f64],
1020 z_axis: &[f64],
1021 dr: f64,
1022 dz: f64,
1023 omega: f64,
1024 tol: f64,
1025 max_iters: usize,
1026) -> FusionResult<DistributedSolveResult> {
1027 let (nz, nr) = psi.dim();
1028 let nthreads = rayon::current_num_threads().max(1);
1029 let (pz, pr) = optimal_process_grid(nz, nr, nthreads);
1030
1031 let cfg = DistributedSolverConfig {
1032 pz,
1033 pr,
1034 halo: 1,
1035 omega,
1036 max_outer_iters: max_iters,
1037 tol,
1038 inner_sweeps: 5,
1039 };
1040 distributed_gs_solve(psi, source, r_axis, z_axis, dr, dz, &cfg)
1041}
1042
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic grid whose entry at `(row, col)` is `10 * row + col`.
    fn sample_grid(nz: usize, nr: usize) -> Array2<f64> {
        Array2::from_shape_fn((nz, nr), |(row, col)| (row as f64) * 10.0 + col as f64)
    }

    /// Unwraps the message of a `PhysicsViolation`; any other variant fails the test.
    fn violation_message(err: FusionError) -> String {
        match err {
            FusionError::PhysicsViolation(msg) => msg,
            other => panic!("Unexpected error: {other:?}"),
        }
    }

    // ---- 1-D (Z-only) decomposition ------------------------------------------------

    #[test]
    fn test_decompose_z_covers_domain() {
        let slices = decompose_z(17, 4, 1).expect("decomposition must succeed");
        assert_eq!(slices.len(), 4);
        assert_eq!(slices[0].z_start, 0);
        assert_eq!(slices.last().expect("slice expected").z_end, 17);
        let covered = slices.iter().fold(0usize, |rows, s| rows + s.local_nz);
        assert_eq!(covered, 17);
    }

    #[test]
    fn test_serial_halo_exchange_and_stitch_roundtrip() {
        let global = sample_grid(24, 9);
        let slices = decompose_z(global.nrows(), 3, 1).expect("decompose");
        let mut locals = split_with_halo(&global, &slices).expect("split");
        serial_halo_exchange(&mut locals, &slices).expect("exchange");
        let stitched = stitch_without_halo(&locals, &slices, global.ncols()).expect("stitch");
        let delta = l2_norm_delta(&stitched, &global).expect("delta");
        assert!(
            delta < 1e-12,
            "Serial halo exchange should preserve core rows"
        );
    }

    #[test]
    fn test_pack_halo_errors_for_small_local_block() {
        let local = Array2::zeros((2, 4));
        let err = pack_halo_rows(&local, 1).expect_err("small local must error");
        assert!(violation_message(err).contains("insufficient rows"));
    }

    #[test]
    fn test_apply_halo_shape_guard() {
        let mut local = Array2::zeros((6, 4));
        let bad_top = Array2::zeros((2, 5));
        let err = apply_halo_rows(&mut local, 1, Some(&bad_top), None).expect_err("shape mismatch");
        assert!(violation_message(err).contains("shape mismatch"));
    }

    #[test]
    fn test_l2_norm_delta_zero_for_identical_arrays() {
        let a = sample_grid(8, 8);
        let b = a.clone();
        let d = l2_norm_delta(&a, &b).expect("delta");
        assert!(d.abs() < 1e-12);
    }

    #[test]
    fn test_mpi_domain_rejects_non_finite_inputs() {
        // A NaN anywhere in the block must be caught before packing.
        let mut local = sample_grid(8, 4);
        local[[2, 2]] = f64::NAN;
        let err = pack_halo_rows(&local, 1).expect_err("non-finite local should fail");
        assert!(violation_message(err).contains("non-finite"));

        // An infinity in either operand must be caught by the norm helper.
        let mut a = sample_grid(4, 4);
        let b = sample_grid(4, 4);
        a[[0, 0]] = f64::INFINITY;
        let err = l2_norm_delta(&a, &b).expect_err("non-finite delta input should fail");
        assert!(violation_message(err).contains("finite"));
    }

    // ---- 2-D Cartesian decomposition -----------------------------------------------

    #[test]
    fn test_decompose_2d_covers_full_domain() {
        let tiles = decompose_2d(32, 24, 4, 3, 1).expect("decompose_2d");
        assert_eq!(tiles.len(), 12);

        // Count how many tiles own each cell; the answer must be exactly one.
        let mut coverage = Array2::<u8>::zeros((32, 24));
        for t in &tiles {
            for iz in t.z_start..t.z_end {
                for ir in t.r_start..t.r_end {
                    coverage[[iz, ir]] += 1;
                }
            }
        }
        assert!(
            coverage.iter().all(|&c| c == 1),
            "Every cell owned by exactly one tile"
        );
    }

    #[test]
    fn test_decompose_2d_single_rank() {
        let tiles = decompose_2d(16, 16, 1, 1, 2).expect("single rank");
        assert_eq!(tiles.len(), 1);
        let t = &tiles[0];
        assert_eq!(t.z_start, 0);
        assert_eq!(t.z_end, 16);
        assert_eq!(t.r_start, 0);
        assert_eq!(t.r_end, 16);
        // A lone tile has no neighbours, hence no halo padding at all.
        assert_eq!(t.padded_nz(), 16);
        assert_eq!(t.padded_nr(), 16);
    }

    #[test]
    fn test_decompose_2d_neighbor_ranks() {
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("2x2");

        // Rank 0 sits at (pz_idx, pr_idx) = (0, 0).
        let tl = &tiles[0];
        assert_eq!(tl.neighbor_rank(1, 0), Some(2));
        assert_eq!(tl.neighbor_rank(0, 1), Some(1));
        assert_eq!(tl.neighbor_rank(-1, 0), None);
        assert_eq!(tl.neighbor_rank(0, -1), None);

        // Rank 3 sits at (1, 1).
        let br = &tiles[3];
        assert_eq!(br.neighbor_rank(-1, 0), Some(1));
        assert_eq!(br.neighbor_rank(0, -1), Some(2));
        assert_eq!(br.neighbor_rank(1, 0), None);
        assert_eq!(br.neighbor_rank(0, 1), None);
    }

    #[test]
    fn test_extract_inject_roundtrip() {
        let global = sample_grid(24, 18);
        let tiles = decompose_2d(24, 18, 3, 2, 1).expect("decompose");
        let mut reconstructed = Array2::<f64>::zeros((24, 18));
        for t in &tiles {
            let local = extract_tile(&global, t).expect("extract");
            inject_tile(&mut reconstructed, &local, t).expect("inject");
        }
        let delta = l2_norm_delta(&global, &reconstructed).expect("delta");
        assert!(
            delta < 1e-12,
            "Extract→inject roundtrip must be lossless, got delta={delta}"
        );
    }

    #[test]
    fn test_serial_halo_exchange_2d_correctness() {
        let global = Array2::from_shape_fn((16, 16), |(i, j)| (i as f64) * 100.0 + j as f64);
        let tiles = decompose_2d(16, 16, 2, 2, 1).expect("decompose");
        let mut locals: Vec<Array2<f64>> = Vec::with_capacity(tiles.len());
        for t in &tiles {
            locals.push(extract_tile(&global, t).expect("extract"));
        }
        serial_halo_exchange_2d(&mut locals, &tiles).expect("halo exchange");

        // After the exchange every local cell must still equal the global cell it maps to.
        for (i, (t, loc)) in tiles.iter().zip(locals.iter()).enumerate() {
            let z0 = t.z_start.saturating_sub(t.core_z_offset());
            let r0 = t.r_start.saturating_sub(t.core_r_offset());
            for ((lz, lr), &got) in loc.indexed_iter() {
                let gz = z0 + lz;
                let gr = r0 + lr;
                if gz >= 16 || gr >= 16 {
                    continue;
                }
                let expect = global[[gz, gr]];
                assert!(
                    (expect - got).abs() < 1e-12,
                    "Tile {i} at local ({lz},{lr}) global ({gz},{gr}): expected {expect}, got {got}"
                );
            }
        }
    }

    // ---- process-grid selection ----------------------------------------------------

    #[test]
    fn test_optimal_process_grid_square() {
        let (pz, pr) = optimal_process_grid(64, 64, 4);
        assert_eq!(pz, 2);
        assert_eq!(pr, 2);
    }

    #[test]
    fn test_optimal_process_grid_rectangular() {
        let (pz, pr) = optimal_process_grid(128, 32, 8);
        assert!(
            pz >= pr,
            "Tall grid should bias toward Z decomposition: pz={pz}, pr={pr}"
        );
        assert_eq!(pz * pr, 8);
    }

    // ---- residual and solver -------------------------------------------------------

    #[test]
    fn test_gs_residual_l2_zero_for_exact_solution() {
        let nz = 16;
        let nr = 16;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.08 + i as f64 * dz).collect();
        let psi = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            z_axis[iz] * z_axis[iz] + r_axis[ir] * r_axis[ir]
        });

        // Manufacture the source by applying the same stencil to `psi`, so the residual
        // vanishes by construction.
        let dr_sq = dr * dr;
        let dz_sq = dz * dz;
        let c_z = 1.0 / dz_sq;
        let center = 2.0 / dr_sq + 2.0 / dz_sq;
        let mut source = Array2::zeros((nz, nr));
        for iz in 1..nz - 1 {
            for ir in 1..nr - 1 {
                let r = r_axis[ir];
                let c_r_plus = 1.0 / dr_sq - 1.0 / (2.0 * r * dr);
                let c_r_minus = 1.0 / dr_sq + 1.0 / (2.0 * r * dr);
                source[[iz, ir]] = center * psi[[iz, ir]]
                    - c_z * (psi[[iz - 1, ir]] + psi[[iz + 1, ir]])
                    - c_r_plus * psi[[iz, ir + 1]]
                    - c_r_minus * psi[[iz, ir - 1]];
            }
        }

        let res = gs_residual_l2(&psi, &source, &r_axis, dr, dz);
        assert!(
            res < 1e-10,
            "Residual should be ~0 for manufactured solution, got {res}"
        );
    }

    #[test]
    fn test_distributed_gs_solve_smoke() {
        let nz = 32;
        let nr = 32;
        let dr = 0.01;
        let dz = 0.01;
        let r_axis: Vec<f64> = (0..nr).map(|i| 1.0 + i as f64 * dr).collect();
        let z_axis: Vec<f64> = (0..nz).map(|i| -0.16 + i as f64 * dz).collect();

        // Zero initial guess and a negative Gaussian source centred in the grid.
        let psi = Array2::zeros((nz, nr));
        let r_mid = 1.0 + (nr as f64 / 2.0) * dr;
        let z_mid = 0.0;
        let source = Array2::from_shape_fn((nz, nr), |(iz, ir)| {
            let rr = r_axis[ir] - r_mid;
            let zz = z_axis[iz] - z_mid;
            -((-(rr * rr + zz * zz) / (0.05 * 0.05)).exp())
        });

        let res_initial = gs_residual_l2(&psi, &source, &r_axis, dr, dz);

        let cfg = DistributedSolverConfig {
            pz: 2,
            pr: 2,
            halo: 1,
            omega: 1.6,
            max_outer_iters: 200,
            tol: 1e-6,
            inner_sweeps: 10,
        };
        let result = distributed_gs_solve(&psi, &source, &r_axis, &z_axis, dr, dz, &cfg)
            .expect("distributed solve");

        assert!(
            result.residual < res_initial * 0.5,
            "Residual should decrease: initial={res_initial:.4}, final={:.4}",
            result.residual
        );
        let psi_absmax = result.psi.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
        assert!(
            psi_absmax > 1e-10,
            "Solution should be non-trivial, max |Ψ| = {psi_absmax}"
        );
    }

    #[test]
    fn test_decompose_2d_rejects_invalid_inputs() {
        // Zero-sized process grid along either axis.
        assert!(decompose_2d(16, 16, 0, 2, 1).is_err());
        assert!(decompose_2d(16, 16, 2, 0, 1).is_err());
        // Grid too small, and more tiles than rows.
        assert!(decompose_2d(1, 16, 2, 2, 1).is_err());
        assert!(decompose_2d(16, 16, 20, 2, 1).is_err());
    }
}