1use fusion_types::error::{FusionError, FusionResult};
14use fusion_types::state::Grid2D;
15use ndarray::Array2;
16
/// Blend weight between the pressure-like term (`R * p`) and the FF-like term
/// (`ff / (mu0 * R)`) when assembling `j_phi`: the source is
/// `mix * j_p + (1 - mix) * j_f`, so 0.5 weights both terms equally.
const DEFAULT_BETA_MIX: f64 = 0.5;

/// Smallest accepted `|psi_boundary - psi_axis|`; anything below this makes the
/// flux normalization degenerate and is rejected as a configuration error.
const MIN_FLUX_DENOMINATOR: f64 = 1.0e-9;

/// Raw current integrals with magnitude at or below this value are treated as
/// "no plasma current": the source is cleared rather than rescaled by a huge factor.
const MIN_CURRENT_INTEGRAL: f64 = 1.0e-9;
28
/// Shape parameters of the modified-tanh (mtanh) profile evaluated by
/// `mtanh_profile`: a tanh pedestal step plus a parabolic core term.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProfileParams {
    /// Normalized-flux position of the pedestal (centre of the tanh step).
    pub ped_top: f64,
    /// Width of the tanh transition around `ped_top`.
    pub ped_width: f64,
    /// Amplitude of the pedestal step.
    pub ped_height: f64,
    /// Weight of the core term `max(0, 1 - (psi_norm / ped_top)^2)`.
    pub core_alpha: f64,
}

impl Default for ProfileParams {
    /// Pedestal at `psi_norm = 0.9` with width 0.08 and unit height, plus a
    /// core contribution weighted by 0.2.
    fn default() -> Self {
        let (ped_top, ped_width) = (0.9, 0.08);
        let (ped_height, core_alpha) = (1.0, 0.2);
        ProfileParams {
            ped_top,
            ped_width,
            ped_height,
            core_alpha,
        }
    }
}
48
49#[derive(Debug, Clone, Copy)]
51pub struct SourceProfileContext<'a> {
52 pub psi: &'a Array2<f64>,
53 pub grid: &'a Grid2D,
54 pub psi_axis: f64,
55 pub psi_boundary: f64,
56 pub mu0: f64,
57 pub i_target: f64,
58}
59
60fn validate_source_inputs(
61 psi: &Array2<f64>,
62 grid: &Grid2D,
63 psi_axis: f64,
64 psi_boundary: f64,
65 mu0: f64,
66 i_target: f64,
67) -> FusionResult<()> {
68 if grid.nz == 0 || grid.nr == 0 {
69 return Err(FusionError::ConfigError(
70 "source update grid requires nz,nr >= 1".to_string(),
71 ));
72 }
73 if !grid.dr.is_finite()
74 || !grid.dz.is_finite()
75 || grid.dr.abs() <= f64::EPSILON
76 || grid.dz.abs() <= f64::EPSILON
77 {
78 return Err(FusionError::ConfigError(format!(
79 "source update requires finite non-zero grid spacing, got dr={} dz={}",
80 grid.dr, grid.dz
81 )));
82 }
83 if psi.nrows() != grid.nz || psi.ncols() != grid.nr {
84 return Err(FusionError::ConfigError(format!(
85 "source update psi shape mismatch: expected ({}, {}), got ({}, {})",
86 grid.nz,
87 grid.nr,
88 psi.nrows(),
89 psi.ncols()
90 )));
91 }
92 if grid.rr.nrows() != grid.nz || grid.rr.ncols() != grid.nr {
93 return Err(FusionError::ConfigError(format!(
94 "source update grid.rr shape mismatch: expected ({}, {}), got ({}, {})",
95 grid.nz,
96 grid.nr,
97 grid.rr.nrows(),
98 grid.rr.ncols()
99 )));
100 }
101 if psi.iter().any(|v| !v.is_finite()) || grid.rr.iter().any(|v| !v.is_finite()) {
102 return Err(FusionError::ConfigError(
103 "source update inputs must be finite".to_string(),
104 ));
105 }
106 if !psi_axis.is_finite() || !psi_boundary.is_finite() {
107 return Err(FusionError::ConfigError(
108 "source update psi_axis/psi_boundary must be finite".to_string(),
109 ));
110 }
111 if !mu0.is_finite() || mu0 <= 0.0 {
112 return Err(FusionError::ConfigError(format!(
113 "source update requires finite mu0 > 0, got {mu0}"
114 )));
115 }
116 if !i_target.is_finite() {
117 return Err(FusionError::ConfigError(
118 "source update target current must be finite".to_string(),
119 ));
120 }
121 let denom = psi_boundary - psi_axis;
122 if !denom.is_finite() || denom.abs() < MIN_FLUX_DENOMINATOR {
123 return Err(FusionError::ConfigError(format!(
124 "source update flux denominator must satisfy |psi_boundary-psi_axis| >= {MIN_FLUX_DENOMINATOR}, got {}",
125 denom
126 )));
127 }
128 Ok(())
129}
130
131fn validate_profile_params(params: &ProfileParams, label: &str) -> FusionResult<()> {
132 if !params.ped_top.is_finite() || params.ped_top <= 0.0 {
133 return Err(FusionError::ConfigError(format!(
134 "{label}.ped_top must be finite and > 0, got {}",
135 params.ped_top
136 )));
137 }
138 if !params.ped_width.is_finite() || params.ped_width <= 0.0 {
139 return Err(FusionError::ConfigError(format!(
140 "{label}.ped_width must be finite and > 0, got {}",
141 params.ped_width
142 )));
143 }
144 if !params.ped_height.is_finite() || !params.core_alpha.is_finite() {
145 return Err(FusionError::ConfigError(format!(
146 "{label}.ped_height/core_alpha must be finite"
147 )));
148 }
149 Ok(())
150}
151
152pub fn mtanh_profile(psi_norm: f64, params: &ProfileParams) -> f64 {
157 let w = params.ped_width.abs().max(1e-8);
158 let ped_top = params.ped_top.abs().max(1e-8);
159 let y = (params.ped_top - psi_norm) / w;
160 let tanh_y = y.tanh();
161 let core = (1.0 - (psi_norm / ped_top).powi(2)).max(0.0);
162 0.5 * params.ped_height * (1.0 + tanh_y) + params.core_alpha * core
163}
164
165pub fn mtanh_profile_derivatives(psi_norm: f64, params: &ProfileParams) -> [f64; 4] {
168 let w = params.ped_width.abs().max(1e-8);
169 let ped_top = params.ped_top.abs().max(1e-8);
170 let y = (params.ped_top - psi_norm) / w;
171 let tanh_y = y.tanh();
172 let sech2 = 1.0 - tanh_y * tanh_y;
173 let core = (1.0 - (psi_norm / ped_top).powi(2)).max(0.0);
174
175 let d_core_d_ped_top = if psi_norm.abs() < ped_top {
176 2.0 * psi_norm.powi(2) / ped_top.powi(3)
177 } else {
178 0.0
179 };
180
181 let d_ped_height = 0.5 * (1.0 + tanh_y);
182 let d_ped_top = 0.5 * params.ped_height * sech2 / w + params.core_alpha * d_core_d_ped_top;
183 let d_ped_width = -0.5 * params.ped_height * sech2 * y / w;
184 let d_core_alpha = core;
185
186 [d_ped_height, d_ped_top, d_ped_width, d_core_alpha]
187}
188
189pub fn update_plasma_source_nonlinear(
201 psi: &Array2<f64>,
202 grid: &Grid2D,
203 psi_axis: f64,
204 psi_boundary: f64,
205 mu0: f64,
206 i_target: f64,
207) -> FusionResult<Array2<f64>> {
208 validate_source_inputs(psi, grid, psi_axis, psi_boundary, mu0, i_target)?;
209 let nz = grid.nz;
210 let nr = grid.nr;
211
212 let denom = psi_boundary - psi_axis;
214
215 let mut j_phi = Array2::zeros((nz, nr));
216
217 for iz in 0..nz {
218 for ir in 0..nr {
219 let psi_norm = (psi[[iz, ir]] - psi_axis) / denom;
220 if !psi_norm.is_finite() {
221 return Err(FusionError::ConfigError(format!(
222 "source update produced non-finite psi_norm at ({iz}, {ir})"
223 )));
224 }
225
226 if (0.0..1.0).contains(&psi_norm) {
228 let profile = 1.0 - psi_norm;
229
230 let r = grid.rr[[iz, ir]];
231 if r <= 0.0 {
232 return Err(FusionError::ConfigError(format!(
233 "source update requires R > 0 inside plasma at ({iz}, {ir}), got {r}"
234 )));
235 }
236
237 let j_p = r * profile;
239
240 let j_f = (1.0 / (mu0 * r)) * profile;
242 if !j_p.is_finite() || !j_f.is_finite() {
243 return Err(FusionError::ConfigError(format!(
244 "source update produced non-finite current components at ({iz}, {ir})"
245 )));
246 }
247
248 j_phi[[iz, ir]] = DEFAULT_BETA_MIX * j_p + (1.0 - DEFAULT_BETA_MIX) * j_f;
250 }
251 }
252 }
253
254 let i_current: f64 = j_phi.iter().sum::<f64>() * grid.dr * grid.dz;
256 if !i_current.is_finite() {
257 return Err(FusionError::ConfigError(
258 "source update current integral became non-finite".to_string(),
259 ));
260 }
261
262 if i_current.abs() > MIN_CURRENT_INTEGRAL {
263 let scale = i_target / i_current;
264 if !scale.is_finite() {
265 return Err(FusionError::ConfigError(
266 "source update renormalization scale became non-finite".to_string(),
267 ));
268 }
269 j_phi.mapv_inplace(|v| v * scale);
270 } else {
271 j_phi.fill(0.0);
272 }
273
274 if j_phi.iter().any(|v| !v.is_finite()) {
275 return Err(FusionError::ConfigError(
276 "source update output contains non-finite values".to_string(),
277 ));
278 }
279 Ok(j_phi)
280}
281
282pub fn update_plasma_source_with_profiles(
287 ctx: SourceProfileContext<'_>,
288 params_p: &ProfileParams,
289 params_ff: &ProfileParams,
290) -> FusionResult<Array2<f64>> {
291 let SourceProfileContext {
292 psi,
293 grid,
294 psi_axis,
295 psi_boundary,
296 mu0,
297 i_target,
298 } = ctx;
299 validate_source_inputs(psi, grid, psi_axis, psi_boundary, mu0, i_target)?;
300 validate_profile_params(params_p, "params_p")?;
301 validate_profile_params(params_ff, "params_ff")?;
302 let nz = grid.nz;
303 let nr = grid.nr;
304
305 let denom = psi_boundary - psi_axis;
306
307 let mut j_phi = Array2::zeros((nz, nr));
308 for iz in 0..nz {
309 for ir in 0..nr {
310 let psi_norm = (psi[[iz, ir]] - psi_axis) / denom;
311 if !psi_norm.is_finite() {
312 return Err(FusionError::ConfigError(format!(
313 "profile source update produced non-finite psi_norm at ({iz}, {ir})"
314 )));
315 }
316 if (0.0..1.0).contains(&psi_norm) {
317 let r = grid.rr[[iz, ir]];
318 if r <= 0.0 {
319 return Err(FusionError::ConfigError(format!(
320 "profile source update requires R > 0 inside plasma at ({iz}, {ir}), got {r}"
321 )));
322 }
323 let p_profile = mtanh_profile(psi_norm, params_p);
324 let ff_profile = mtanh_profile(psi_norm, params_ff);
325 if !p_profile.is_finite() || !ff_profile.is_finite() {
326 return Err(FusionError::ConfigError(format!(
327 "profile source update produced non-finite profile values at ({iz}, {ir})"
328 )));
329 }
330
331 let j_p = r * p_profile;
332 let j_f = (1.0 / (mu0 * r)) * ff_profile;
333 if !j_p.is_finite() || !j_f.is_finite() {
334 return Err(FusionError::ConfigError(format!(
335 "profile source update produced non-finite current components at ({iz}, {ir})"
336 )));
337 }
338 j_phi[[iz, ir]] = DEFAULT_BETA_MIX * j_p + (1.0 - DEFAULT_BETA_MIX) * j_f;
339 }
340 }
341 }
342
343 let i_current: f64 = j_phi.iter().sum::<f64>() * grid.dr * grid.dz;
344 if !i_current.is_finite() {
345 return Err(FusionError::ConfigError(
346 "profile source update current integral became non-finite".to_string(),
347 ));
348 }
349 if i_current.abs() > MIN_CURRENT_INTEGRAL {
350 let scale = i_target / i_current;
351 if !scale.is_finite() {
352 return Err(FusionError::ConfigError(
353 "profile source update renormalization scale became non-finite".to_string(),
354 ));
355 }
356 j_phi.mapv_inplace(|v| v * scale);
357 } else {
358 j_phi.fill(0.0);
359 }
360 if j_phi.iter().any(|v| !v.is_finite()) {
361 return Err(FusionError::ConfigError(
362 "profile source update output contains non-finite values".to_string(),
363 ));
364 }
365 Ok(j_phi)
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_source_zero_outside_plasma() {
        let grid = Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0);
        // psi == psi_boundary everywhere, so psi_norm == 1 and no node is inside.
        let psi = Array2::zeros((16, 16));
        let j = update_plasma_source_nonlinear(&psi, &grid, 1.0, 0.0, 1.0, 1e6)
            .expect("valid source-update inputs");

        let max_j = j.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
        assert!(max_j < 1e-15, "Should be zero outside plasma: {max_j}");
    }

    #[test]
    fn test_source_renormalization() {
        let grid = Grid2D::new(33, 33, 1.0, 9.0, -5.0, 5.0);
        let psi = Array2::from_shape_fn((33, 33), |(iz, ir)| {
            let r = grid.rr[[iz, ir]];
            let z = grid.zz[[iz, ir]];
            (-(((r - 5.0).powi(2) + z.powi(2)) / 4.0)).exp()
        });

        let psi_axis = 1.0;
        let psi_boundary = 0.0;
        let i_target = 15e6;
        let j = update_plasma_source_nonlinear(&psi, &grid, psi_axis, psi_boundary, 1.0, i_target)
            .expect("valid source-update inputs");

        let i_actual: f64 = j.iter().sum::<f64>() * grid.dr * grid.dz;
        let rel_error = ((i_actual - i_target) / i_target).abs();
        assert!(
            rel_error < 1e-10,
            "Current integral {i_actual} should match target {i_target}"
        );
    }

    #[test]
    fn test_mtanh_derivatives_match_finite_difference() {
        let params = ProfileParams {
            ped_top: 0.92,
            ped_width: 0.07,
            ped_height: 1.2,
            core_alpha: 0.3,
        };
        let eps = 1e-6;

        // 0.35 lies inside the core term's support; 0.95 lies beyond ped_top,
        // where the core term and its ped_top derivative vanish.
        for psi in [0.35_f64, 0.95].iter().copied() {
            let analytic = mtanh_profile_derivatives(psi, &params);
            let base = mtanh_profile(psi, &params);

            let mut p = params;
            p.ped_height += eps;
            let fd_h = (mtanh_profile(psi, &p) - base) / eps;

            p = params;
            p.ped_top += eps;
            let fd_top = (mtanh_profile(psi, &p) - base) / eps;

            p = params;
            p.ped_width += eps;
            let fd_w = (mtanh_profile(psi, &p) - base) / eps;

            p = params;
            p.core_alpha += eps;
            let fd_a = (mtanh_profile(psi, &p) - base) / eps;

            let fd = [fd_h, fd_top, fd_w, fd_a];
            for (i, (a, f)) in analytic.iter().zip(fd.iter()).enumerate() {
                let denom = f.abs().max(1e-8);
                let rel = (a - f).abs() / denom;
                assert!(
                    rel < 1e-3,
                    "Derivative mismatch at psi={psi}, index {i}: analytic={a}, fd={f}, rel={rel}"
                );
            }
        }
    }

    #[test]
    fn test_source_with_profiles_finite() {
        let grid = Grid2D::new(33, 33, 1.0, 9.0, -5.0, 5.0);
        let psi = Array2::from_shape_fn((33, 33), |(iz, ir)| {
            let r = grid.rr[[iz, ir]];
            let z = grid.zz[[iz, ir]];
            (-(((r - 5.0).powi(2) + z.powi(2)) / 4.0)).exp()
        });

        let params_p = ProfileParams {
            ped_top: 0.9,
            ped_width: 0.08,
            ped_height: 1.1,
            core_alpha: 0.25,
        };
        let params_ff = ProfileParams {
            ped_top: 0.85,
            ped_width: 0.06,
            ped_height: 0.95,
            core_alpha: 0.1,
        };

        let j = update_plasma_source_with_profiles(
            SourceProfileContext {
                psi: &psi,
                grid: &grid,
                psi_axis: 1.0,
                psi_boundary: 0.0,
                mu0: 1.0,
                i_target: 15e6,
            },
            &params_p,
            &params_ff,
        )
        .expect("valid profile-source-update inputs");
        assert!(
            j.iter().all(|v| v.is_finite()),
            "Profile source contains non-finite values"
        );
        let i_actual: f64 = j.iter().sum::<f64>() * grid.dr * grid.dz;
        let rel_error = ((i_actual - 15e6) / 15e6).abs();
        assert!(rel_error < 1e-10, "Current mismatch after renormalization");
    }

    #[test]
    fn test_source_rejects_invalid_runtime_inputs() {
        let mut grid = Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0);
        let psi = Array2::zeros((16, 16));

        let err = update_plasma_source_nonlinear(&psi, &grid, 1.0, 1.0, 1.0, 1.0)
            .expect_err("degenerate flux normalization must fail");
        assert!(matches!(err, FusionError::ConfigError(_)));

        grid.rr[[3, 3]] = 0.0;
        let psi_inside = Array2::from_elem((16, 16), 0.5);
        let err = update_plasma_source_nonlinear(&psi_inside, &grid, 1.0, 0.0, 1.0, 1.0)
            .expect_err("non-positive radius inside plasma must fail");
        assert!(matches!(err, FusionError::ConfigError(_)));

        let params_bad = ProfileParams {
            ped_top: 0.9,
            ped_width: 0.0,
            ped_height: 1.0,
            core_alpha: 0.2,
        };
        let err = update_plasma_source_with_profiles(
            SourceProfileContext {
                psi: &Array2::from_elem((16, 16), 0.5),
                grid: &Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0),
                psi_axis: 1.0,
                psi_boundary: 0.0,
                mu0: 1.0,
                i_target: 1.0,
            },
            &params_bad,
            &ProfileParams::default(),
        )
        .expect_err("invalid profile params must fail");
        assert!(matches!(err, FusionError::ConfigError(_)));
    }
}