1use fusion_types::error::{FusionError, FusionResult};
9use fusion_types::state::Grid2D;
10use ndarray::Array2;
11
/// Blend weight applied to the `R * profile` current term; the
/// `profile / (mu0 * R)` term is weighted by `1 - DEFAULT_BETA_MIX`.
const DEFAULT_BETA_MIX: f64 = 0.5;

/// Smallest accepted magnitude of `psi_boundary - psi_axis`. Input validation
/// rejects anything smaller, because the flux normalization divides by it.
const MIN_FLUX_DENOMINATOR: f64 = 1e-9;

/// Threshold on the magnitude of the integrated current. At or below it the
/// source array is zeroed instead of being rescaled to the target current.
const MIN_CURRENT_INTEGRAL: f64 = 1e-9;
23
/// Shape parameters of the profile evaluated by `mtanh_profile`.
///
/// The profile is the sum of a `tanh` step term and a clamped parabolic core
/// term; each field below feeds one part of that expression.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProfileParams {
    /// Normalized-flux position at which the `tanh` argument is zero; its
    /// magnitude also scales the parabolic core term.
    pub ped_top: f64,
    /// Divisor of the `tanh` argument (smaller values give a sharper step).
    pub ped_width: f64,
    /// Amplitude multiplying the `tanh` step term.
    pub ped_height: f64,
    /// Amplitude multiplying the parabolic core term.
    pub core_alpha: f64,
}

impl Default for ProfileParams {
    /// Returns `ped_top = 0.9`, `ped_width = 0.08`, `ped_height = 1.0`,
    /// `core_alpha = 0.2`.
    fn default() -> Self {
        let (ped_top, ped_width) = (0.9, 0.08);
        let (ped_height, core_alpha) = (1.0, 0.2);
        ProfileParams {
            ped_top,
            ped_width,
            ped_height,
            core_alpha,
        }
    }
}
43
/// Bundle of inputs consumed by `update_plasma_source_with_profiles`.
///
/// The fields correspond one-to-one to the positional arguments of
/// `update_plasma_source_nonlinear` and pass through the same validation.
#[derive(Debug, Clone, Copy)]
pub struct SourceProfileContext<'a> {
    /// Flux values; validation requires shape `(grid.nz, grid.nr)` and finite entries.
    pub psi: &'a Array2<f64>,
    /// Grid supplying `nz`, `nr`, `dr`, `dz` and the per-cell `rr` array
    /// (read as `R` by the update routine).
    pub grid: &'a Grid2D,
    /// Flux value that maps to normalized flux 0.
    pub psi_axis: f64,
    /// Flux value that maps to normalized flux 1; must differ from `psi_axis`
    /// by at least `MIN_FLUX_DENOMINATOR` in magnitude.
    pub psi_boundary: f64,
    /// Constant appearing as `1 / (mu0 * R)` in the second current term;
    /// must be finite and strictly positive.
    pub mu0: f64,
    /// Value the integrated current (`sum(j) * dr * dz`) is rescaled to; must be finite.
    pub i_target: f64,
}
54
55fn validate_source_inputs(
56 psi: &Array2<f64>,
57 grid: &Grid2D,
58 psi_axis: f64,
59 psi_boundary: f64,
60 mu0: f64,
61 i_target: f64,
62) -> FusionResult<()> {
63 if grid.nz == 0 || grid.nr == 0 {
64 return Err(FusionError::ConfigError(
65 "source update grid requires nz,nr >= 1".to_string(),
66 ));
67 }
68 if !grid.dr.is_finite()
69 || !grid.dz.is_finite()
70 || grid.dr.abs() <= f64::EPSILON
71 || grid.dz.abs() <= f64::EPSILON
72 {
73 return Err(FusionError::ConfigError(format!(
74 "source update requires finite non-zero grid spacing, got dr={} dz={}",
75 grid.dr, grid.dz
76 )));
77 }
78 if psi.nrows() != grid.nz || psi.ncols() != grid.nr {
79 return Err(FusionError::ConfigError(format!(
80 "source update psi shape mismatch: expected ({}, {}), got ({}, {})",
81 grid.nz,
82 grid.nr,
83 psi.nrows(),
84 psi.ncols()
85 )));
86 }
87 if grid.rr.nrows() != grid.nz || grid.rr.ncols() != grid.nr {
88 return Err(FusionError::ConfigError(format!(
89 "source update grid.rr shape mismatch: expected ({}, {}), got ({}, {})",
90 grid.nz,
91 grid.nr,
92 grid.rr.nrows(),
93 grid.rr.ncols()
94 )));
95 }
96 if psi.iter().any(|v| !v.is_finite()) || grid.rr.iter().any(|v| !v.is_finite()) {
97 return Err(FusionError::ConfigError(
98 "source update inputs must be finite".to_string(),
99 ));
100 }
101 if !psi_axis.is_finite() || !psi_boundary.is_finite() {
102 return Err(FusionError::ConfigError(
103 "source update psi_axis/psi_boundary must be finite".to_string(),
104 ));
105 }
106 if !mu0.is_finite() || mu0 <= 0.0 {
107 return Err(FusionError::ConfigError(format!(
108 "source update requires finite mu0 > 0, got {mu0}"
109 )));
110 }
111 if !i_target.is_finite() {
112 return Err(FusionError::ConfigError(
113 "source update target current must be finite".to_string(),
114 ));
115 }
116 let denom = psi_boundary - psi_axis;
117 if !denom.is_finite() || denom.abs() < MIN_FLUX_DENOMINATOR {
118 return Err(FusionError::ConfigError(format!(
119 "source update flux denominator must satisfy |psi_boundary-psi_axis| >= {MIN_FLUX_DENOMINATOR}, got {}",
120 denom
121 )));
122 }
123 Ok(())
124}
125
126fn validate_profile_params(params: &ProfileParams, label: &str) -> FusionResult<()> {
127 if !params.ped_top.is_finite() || params.ped_top <= 0.0 {
128 return Err(FusionError::ConfigError(format!(
129 "{label}.ped_top must be finite and > 0, got {}",
130 params.ped_top
131 )));
132 }
133 if !params.ped_width.is_finite() || params.ped_width <= 0.0 {
134 return Err(FusionError::ConfigError(format!(
135 "{label}.ped_width must be finite and > 0, got {}",
136 params.ped_width
137 )));
138 }
139 if !params.ped_height.is_finite() || !params.core_alpha.is_finite() {
140 return Err(FusionError::ConfigError(format!(
141 "{label}.ped_height/core_alpha must be finite"
142 )));
143 }
144 Ok(())
145}
146
147pub fn mtanh_profile(psi_norm: f64, params: &ProfileParams) -> f64 {
152 let w = params.ped_width.abs().max(1e-8);
153 let ped_top = params.ped_top.abs().max(1e-8);
154 let y = (params.ped_top - psi_norm) / w;
155 let tanh_y = y.tanh();
156 let core = (1.0 - (psi_norm / ped_top).powi(2)).max(0.0);
157 0.5 * params.ped_height * (1.0 + tanh_y) + params.core_alpha * core
158}
159
160pub fn mtanh_profile_derivatives(psi_norm: f64, params: &ProfileParams) -> [f64; 4] {
163 let w = params.ped_width.abs().max(1e-8);
164 let ped_top = params.ped_top.abs().max(1e-8);
165 let y = (params.ped_top - psi_norm) / w;
166 let tanh_y = y.tanh();
167 let sech2 = 1.0 - tanh_y * tanh_y;
168 let core = (1.0 - (psi_norm / ped_top).powi(2)).max(0.0);
169
170 let d_core_d_ped_top = if psi_norm.abs() < ped_top {
171 2.0 * psi_norm.powi(2) / ped_top.powi(3)
172 } else {
173 0.0
174 };
175
176 let d_ped_height = 0.5 * (1.0 + tanh_y);
177 let d_ped_top = 0.5 * params.ped_height * sech2 / w + params.core_alpha * d_core_d_ped_top;
178 let d_ped_width = -0.5 * params.ped_height * sech2 * y / w;
179 let d_core_alpha = core;
180
181 [d_ped_height, d_ped_top, d_ped_width, d_core_alpha]
182}
183
184pub fn update_plasma_source_nonlinear(
196 psi: &Array2<f64>,
197 grid: &Grid2D,
198 psi_axis: f64,
199 psi_boundary: f64,
200 mu0: f64,
201 i_target: f64,
202) -> FusionResult<Array2<f64>> {
203 validate_source_inputs(psi, grid, psi_axis, psi_boundary, mu0, i_target)?;
204 let nz = grid.nz;
205 let nr = grid.nr;
206
207 let denom = psi_boundary - psi_axis;
209
210 let mut j_phi = Array2::zeros((nz, nr));
211
212 for iz in 0..nz {
213 for ir in 0..nr {
214 let psi_norm = (psi[[iz, ir]] - psi_axis) / denom;
215 if !psi_norm.is_finite() {
216 return Err(FusionError::ConfigError(format!(
217 "source update produced non-finite psi_norm at ({iz}, {ir})"
218 )));
219 }
220
221 if (0.0..1.0).contains(&psi_norm) {
223 let profile = 1.0 - psi_norm;
224
225 let r = grid.rr[[iz, ir]];
226 if r <= 0.0 {
227 return Err(FusionError::ConfigError(format!(
228 "source update requires R > 0 inside plasma at ({iz}, {ir}), got {r}"
229 )));
230 }
231
232 let j_p = r * profile;
234
235 let j_f = (1.0 / (mu0 * r)) * profile;
237 if !j_p.is_finite() || !j_f.is_finite() {
238 return Err(FusionError::ConfigError(format!(
239 "source update produced non-finite current components at ({iz}, {ir})"
240 )));
241 }
242
243 j_phi[[iz, ir]] = DEFAULT_BETA_MIX * j_p + (1.0 - DEFAULT_BETA_MIX) * j_f;
245 }
246 }
247 }
248
249 let i_current: f64 = j_phi.iter().sum::<f64>() * grid.dr * grid.dz;
251 if !i_current.is_finite() {
252 return Err(FusionError::ConfigError(
253 "source update current integral became non-finite".to_string(),
254 ));
255 }
256
257 if i_current.abs() > MIN_CURRENT_INTEGRAL {
258 let scale = i_target / i_current;
259 if !scale.is_finite() {
260 return Err(FusionError::ConfigError(
261 "source update renormalization scale became non-finite".to_string(),
262 ));
263 }
264 j_phi.mapv_inplace(|v| v * scale);
265 } else {
266 j_phi.fill(0.0);
267 }
268
269 if j_phi.iter().any(|v| !v.is_finite()) {
270 return Err(FusionError::ConfigError(
271 "source update output contains non-finite values".to_string(),
272 ));
273 }
274 Ok(j_phi)
275}
276
277pub fn update_plasma_source_with_profiles(
282 ctx: SourceProfileContext<'_>,
283 params_p: &ProfileParams,
284 params_ff: &ProfileParams,
285) -> FusionResult<Array2<f64>> {
286 let SourceProfileContext {
287 psi,
288 grid,
289 psi_axis,
290 psi_boundary,
291 mu0,
292 i_target,
293 } = ctx;
294 validate_source_inputs(psi, grid, psi_axis, psi_boundary, mu0, i_target)?;
295 validate_profile_params(params_p, "params_p")?;
296 validate_profile_params(params_ff, "params_ff")?;
297 let nz = grid.nz;
298 let nr = grid.nr;
299
300 let denom = psi_boundary - psi_axis;
301
302 let mut j_phi = Array2::zeros((nz, nr));
303 for iz in 0..nz {
304 for ir in 0..nr {
305 let psi_norm = (psi[[iz, ir]] - psi_axis) / denom;
306 if !psi_norm.is_finite() {
307 return Err(FusionError::ConfigError(format!(
308 "profile source update produced non-finite psi_norm at ({iz}, {ir})"
309 )));
310 }
311 if (0.0..1.0).contains(&psi_norm) {
312 let r = grid.rr[[iz, ir]];
313 if r <= 0.0 {
314 return Err(FusionError::ConfigError(format!(
315 "profile source update requires R > 0 inside plasma at ({iz}, {ir}), got {r}"
316 )));
317 }
318 let p_profile = mtanh_profile(psi_norm, params_p);
319 let ff_profile = mtanh_profile(psi_norm, params_ff);
320 if !p_profile.is_finite() || !ff_profile.is_finite() {
321 return Err(FusionError::ConfigError(format!(
322 "profile source update produced non-finite profile values at ({iz}, {ir})"
323 )));
324 }
325
326 let j_p = r * p_profile;
327 let j_f = (1.0 / (mu0 * r)) * ff_profile;
328 if !j_p.is_finite() || !j_f.is_finite() {
329 return Err(FusionError::ConfigError(format!(
330 "profile source update produced non-finite current components at ({iz}, {ir})"
331 )));
332 }
333 j_phi[[iz, ir]] = DEFAULT_BETA_MIX * j_p + (1.0 - DEFAULT_BETA_MIX) * j_f;
334 }
335 }
336 }
337
338 let i_current: f64 = j_phi.iter().sum::<f64>() * grid.dr * grid.dz;
339 if !i_current.is_finite() {
340 return Err(FusionError::ConfigError(
341 "profile source update current integral became non-finite".to_string(),
342 ));
343 }
344 if i_current.abs() > MIN_CURRENT_INTEGRAL {
345 let scale = i_target / i_current;
346 if !scale.is_finite() {
347 return Err(FusionError::ConfigError(
348 "profile source update renormalization scale became non-finite".to_string(),
349 ));
350 }
351 j_phi.mapv_inplace(|v| v * scale);
352 } else {
353 j_phi.fill(0.0);
354 }
355 if j_phi.iter().any(|v| !v.is_finite()) {
356 return Err(FusionError::ConfigError(
357 "profile source update output contains non-finite values".to_string(),
358 ));
359 }
360 Ok(j_phi)
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Gaussian-shaped `n x n` flux map sampled on `grid`, equal to 1 at
    /// `(R, Z) = (5, 0)` and decaying away from that point.
    fn gaussian_flux(grid: &Grid2D, n: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, n), |(iz, ir)| {
            let r = grid.rr[[iz, ir]];
            let z = grid.zz[[iz, ir]];
            (-(((r - 5.0).powi(2) + z.powi(2)) / 4.0)).exp()
        })
    }

    /// With `psi == 0`, `psi_axis == 1` and `psi_boundary == 0`, every cell has
    /// `psi_norm == 1`, which lies outside the half-open `[0, 1)` range, so no
    /// cell may carry current.
    #[test]
    fn test_source_zero_outside_plasma() {
        let grid = Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0);
        let psi = Array2::zeros((16, 16));
        let j = update_plasma_source_nonlinear(&psi, &grid, 1.0, 0.0, 1.0, 1e6)
            .expect("valid source-update inputs");

        let max_j = j.iter().cloned().fold(0.0_f64, |a, b| a.max(b.abs()));
        assert!(max_j < 1e-15, "Should be zero outside plasma: {max_j}");
    }

    /// The integrated current of the returned source must equal the target.
    #[test]
    fn test_source_renormalization() {
        let grid = Grid2D::new(33, 33, 1.0, 9.0, -5.0, 5.0);
        let psi = gaussian_flux(&grid, 33);

        let psi_axis = 1.0;
        let psi_boundary = 0.0;
        let i_target = 15e6;
        let j = update_plasma_source_nonlinear(&psi, &grid, psi_axis, psi_boundary, 1.0, i_target)
            .expect("valid source-update inputs");

        let i_actual: f64 = j.iter().sum::<f64>() * grid.dr * grid.dz;
        let rel_error = ((i_actual - i_target) / i_target).abs();
        assert!(
            rel_error < 1e-10,
            "Current integral {i_actual} should match target {i_target}"
        );
    }

    /// Analytic parameter derivatives must agree with forward finite differences.
    #[test]
    fn test_mtanh_derivatives_match_finite_difference() {
        let params = ProfileParams {
            ped_top: 0.92,
            ped_width: 0.07,
            ped_height: 1.2,
            core_alpha: 0.3,
        };
        let psi = 0.35;
        let analytic = mtanh_profile_derivatives(psi, &params);
        let eps = 1e-6;

        let baseline = mtanh_profile(psi, &params);
        let forward_diff =
            |perturbed: ProfileParams| (mtanh_profile(psi, &perturbed) - baseline) / eps;

        // Same ordering as the array returned by `mtanh_profile_derivatives`.
        let fd = [
            forward_diff(ProfileParams {
                ped_height: params.ped_height + eps,
                ..params
            }),
            forward_diff(ProfileParams {
                ped_top: params.ped_top + eps,
                ..params
            }),
            forward_diff(ProfileParams {
                ped_width: params.ped_width + eps,
                ..params
            }),
            forward_diff(ProfileParams {
                core_alpha: params.core_alpha + eps,
                ..params
            }),
        ];

        for (i, (&exact, &approx)) in analytic.iter().zip(fd.iter()).enumerate() {
            // Floor the denominator so a near-zero derivative cannot blow up the ratio.
            let denom = approx.abs().max(1e-8);
            let rel = (exact - approx).abs() / denom;
            assert!(
                rel < 1e-3,
                "Derivative mismatch at index {i}: analytic={}, fd={}, rel={}",
                exact,
                approx,
                rel
            );
        }
    }

    /// The profile-based source must be finite everywhere and hit the target current.
    #[test]
    fn test_source_with_profiles_finite() {
        let grid = Grid2D::new(33, 33, 1.0, 9.0, -5.0, 5.0);
        let psi = gaussian_flux(&grid, 33);

        let params_p = ProfileParams {
            ped_top: 0.9,
            ped_width: 0.08,
            ped_height: 1.1,
            core_alpha: 0.25,
        };
        let params_ff = ProfileParams {
            ped_top: 0.85,
            ped_width: 0.06,
            ped_height: 0.95,
            core_alpha: 0.1,
        };

        let j = update_plasma_source_with_profiles(
            SourceProfileContext {
                psi: &psi,
                grid: &grid,
                psi_axis: 1.0,
                psi_boundary: 0.0,
                mu0: 1.0,
                i_target: 15e6,
            },
            &params_p,
            &params_ff,
        )
        .expect("valid profile-source-update inputs");
        assert!(
            j.iter().all(|v| v.is_finite()),
            "Profile source contains non-finite values"
        );
        let i_actual: f64 = j.iter().sum::<f64>() * grid.dr * grid.dz;
        let rel_error = ((i_actual - 15e6) / 15e6).abs();
        assert!(rel_error < 1e-10, "Current mismatch after renormalization");
    }

    /// Degenerate flux span, non-positive radius inside the plasma, and invalid
    /// profile parameters must all be rejected with `ConfigError`.
    #[test]
    fn test_source_rejects_invalid_runtime_inputs() {
        let mut grid = Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0);
        let psi = Array2::zeros((16, 16));

        // psi_axis == psi_boundary makes the normalization denominator zero.
        let err = update_plasma_source_nonlinear(&psi, &grid, 1.0, 1.0, 1.0, 1.0)
            .expect_err("degenerate flux normalization must fail");
        assert!(matches!(err, FusionError::ConfigError(_)));

        // psi == 0.5 puts every cell inside [0, 1), so the zeroed radius is reached.
        grid.rr[[3, 3]] = 0.0;
        let psi_inside = Array2::from_elem((16, 16), 0.5);
        let err = update_plasma_source_nonlinear(&psi_inside, &grid, 1.0, 0.0, 1.0, 1.0)
            .expect_err("non-positive radius inside plasma must fail");
        assert!(matches!(err, FusionError::ConfigError(_)));

        // ped_width must be strictly positive.
        let params_bad = ProfileParams {
            ped_top: 0.9,
            ped_width: 0.0,
            ped_height: 1.0,
            core_alpha: 0.2,
        };
        let err = update_plasma_source_with_profiles(
            SourceProfileContext {
                psi: &Array2::from_elem((16, 16), 0.5),
                grid: &Grid2D::new(16, 16, 1.0, 9.0, -5.0, 5.0),
                psi_axis: 1.0,
                psi_boundary: 0.0,
                mu0: 1.0,
                i_target: 1.0,
            },
            &params_bad,
            &ProfileParams::default(),
        )
        .expect_err("invalid profile params must fail");
        assert!(matches!(err, FusionError::ConfigError(_)));
    }
}