Skip to content

Commit

Permalink
Add the ability to feed through partial calculated source profiles to…
Browse files Browse the repository at this point in the history
… other sources

These are not yet used, that will happen in follow-up CLs and reduce the recalculation of sources

PiperOrigin-RevId: 721784558
  • Loading branch information
tamaranorman authored and Torax team committed Feb 3, 2025
1 parent 88a3943 commit 05cdaee
Show file tree
Hide file tree
Showing 40 changed files with 732 additions and 730 deletions.
4 changes: 2 additions & 2 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def _prescribe_currents_with_bootstrap(
# notational conventions rather than on Google Python style
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

bootstrap_profile = source_models.j_bootstrap.get_value(
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -481,7 +481,7 @@ def _calculate_currents_from_psi(
core_profiles.psi,
)

bootstrap_profile = source_models.j_bootstrap.get_value(
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down
40 changes: 8 additions & 32 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_operations
from torax.sources import source_profile_builders
Expand Down Expand Up @@ -78,7 +77,7 @@ def __call__(
# Checks if reduced calc_coeffs for explicit terms when theta_imp=1
# should be called
explicit_call: bool = False,
):
) -> block_1d_coeffs.Block1DCoeffs:
# Update core_profiles with the subset of new values of evolving variables
replace = {k: v for k, v in zip(self.evolving_names, x)}
core_profiles = config_args.recursive_replace(core_profiles, **replace)
Expand Down Expand Up @@ -397,18 +396,10 @@ def _calc_coeffs_full(
else:
j_bootstrap = implicit_source_profiles.j_bootstrap

external_current = jnp.zeros_like(geo.rho)
# Sum over all psi sources (except the bootstrap current).
for source_name, source in source_models.psi_sources.items():
if static_runtime_params_slice.sources[source_name].is_explicit:
profiles = explicit_source_profiles.profiles
else:
profiles = implicit_source_profiles.profiles
external_current += source.get_source_profile_for_affected_core_profile(
profile=profiles[source_name],
affected_core_profile=source_lib.AffectedCoreProfile.PSI.value,
geo=geo,
)
external_current = sum(explicit_source_profiles.psi.values()) + sum(
implicit_source_profiles.psi.values()
)

currents = dataclasses.replace(
core_profiles.currents,
Expand All @@ -430,14 +421,8 @@ def _calc_coeffs_full(

# fill source vector based on both original and updated core profiles
source_psi = source_operations.sum_sources_psi(
geo,
implicit_source_profiles,
source_models,
) + source_operations.sum_sources_psi(
geo,
explicit_source_profiles,
source_models,
)
geo, implicit_source_profiles
) + source_operations.sum_sources_psi(geo, explicit_source_profiles)

true_ne = core_profiles.ne.value * dynamic_runtime_params_slice.numerics.nref
true_ni = core_profiles.ni.value * dynamic_runtime_params_slice.numerics.nref
Expand Down Expand Up @@ -627,11 +612,9 @@ def _calc_coeffs_full(
source_ne = source_operations.sum_sources_ne(
geo,
explicit_source_profiles,
source_models,
) + source_operations.sum_sources_ne(
geo,
implicit_source_profiles,
source_models,
)

source_ne += jnp.where(
Expand Down Expand Up @@ -717,34 +700,27 @@ def _calc_coeffs_full(

# Fill heat transport equation sources. Initialize source matrices to zero

source_mat_ii = jnp.zeros_like(geo.rho)
source_mat_ee = jnp.zeros_like(geo.rho)

source_i = source_operations.sum_sources_temp_ion(
geo,
explicit_source_profiles,
source_models,
) + source_operations.sum_sources_temp_ion(
geo,
implicit_source_profiles,
source_models,
)

source_e = source_operations.sum_sources_temp_el(
geo,
explicit_source_profiles,
source_models,
) + source_operations.sum_sources_temp_el(
geo,
implicit_source_profiles,
source_models,
)

# Add the Qei effects.
qei = implicit_source_profiles.qei
source_mat_ii += qei.implicit_ii * geo.vpr
source_mat_ii = qei.implicit_ii * geo.vpr
source_i += qei.explicit_i * geo.vpr
source_mat_ee += qei.implicit_ee * geo.vpr
source_mat_ee = qei.implicit_ee * geo.vpr
source_e += qei.explicit_e * geo.vpr
source_mat_ie = qei.implicit_ie * geo.vpr
source_mat_ei = qei.implicit_ei * geo.vpr
Expand Down
1 change: 0 additions & 1 deletion torax/fvm/discrete_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
Block1DCoeffsCallback: TypeAlias = block_1d_coeffs.Block1DCoeffsCallback


def calc_c(
Expand Down
2 changes: 1 addition & 1 deletion torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def _update_current_distribution(
) -> state.CoreProfiles:
"""Update bootstrap current based on the new core_profiles."""

bootstrap_profile = source_models.j_bootstrap.get_value(
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
Expand Down
38 changes: 21 additions & 17 deletions torax/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,26 +334,30 @@ def _save_core_sources(
)

# Add source profiles
for profile in self.core_sources.profiles:
if profile in self.source_models.ion_el_sources:
xr_dict[f"{profile}_ion"] = self.core_sources.profiles[profile][
:, 0, ...
]
xr_dict[f"{profile}_el"] = self.core_sources.profiles[profile][
:, 1, ...
]
# TODO(b/376010694): better automation of splitting profiles into
# separate variables.
elif profile == "electron_cyclotron_source":
xr_dict[f"{profile}_el"] = self.core_sources.profiles[profile][
:, 0, ...
]
xr_dict[f"{profile}_j"] = self.core_sources.profiles[profile][:, 1, ...]
# TODO(b/381543891): Simplify this to always add a suffix to remove the
# need for special cases.
# Current complexity is due to the fact that we want to keep the same
# variable names as the previous version of TORAX but this will be changed
# in the future.
for profile in self.core_sources.temp_ion:
xr_dict[f"{profile}_ion"] = self.core_sources.temp_ion[profile]
for profile in self.core_sources.temp_el:
if profile == "electron_cyclotron_source":
xr_dict[f"{profile}_el"] = self.core_sources.temp_el[profile]
elif profile in self.core_sources.temp_ion:
xr_dict[f"{profile}_el"] = self.core_sources.temp_el[profile]
else:
xr_dict[profile] = self.core_sources.profiles[profile]
xr_dict[profile] = self.core_sources.temp_el[profile]
for profile in self.core_sources.psi:
if profile == "electron_cyclotron_source":
xr_dict[f"{profile}_j"] = self.core_sources.psi[profile]
else:
xr_dict[profile] = self.core_sources.psi[profile]
for profile in self.core_sources.ne:
xr_dict[profile] = self.core_sources.ne[profile]

xr_dict = {
name: self._pack_into_data_array(name, data,)
name: self._pack_into_data_array(name, data)
for name, data in xr_dict.items()
}

Expand Down
30 changes: 10 additions & 20 deletions torax/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,10 @@ def _calculate_integrated_sources(
# TORAX internal names.
for key, value in ION_EL_HEAT_SOURCE_TRANSFORMATIONS.items():
# Only populate integrated dict with sources that exist.
if key in core_sources.profiles:
profile_ion, profile_el = core_sources.profiles[key]
ion_profiles = core_sources.temp_ion
el_profiles = core_sources.temp_el
if key in ion_profiles and key in el_profiles:
profile_ion, profile_el = ion_profiles[key], el_profiles[key]
integrated[f'{value}_ion'] = math_utils.cell_integration(
profile_ion * geo.vpr, geo
)
Expand All @@ -274,33 +276,21 @@ def _calculate_integrated_sources(

for key, value in EL_HEAT_SOURCE_TRANSFORMATIONS.items():
# Only populate integrated dict with sources that exist.
if key in core_sources.profiles:
# TODO(b/376010694): better automation of splitting profiles into
# separate variables.
# index 0 corresponds to the electron heating source profile.
if key == 'electron_cyclotron_source':
profile = core_sources.profiles[key][0, :]
else:
profile = core_sources.profiles[key]
profiles = core_sources.temp_el
if key in profiles:
integrated[f'{value}'] = math_utils.cell_integration(
profile * geo.vpr, geo
profiles[key] * geo.vpr, geo
)
integrated['P_sol_el'] += integrated[f'{value}']
if key in EXTERNAL_HEATING_SOURCES:
integrated['P_external_el'] += integrated[f'{value}']

for key, value in CURRENT_SOURCE_TRANSFORMATIONS.items():
# Only populate integrated dict with sources that exist.
if key in core_sources.profiles:
# TODO(b/376010694): better automation of splitting profiles into
# separate variables.
# index 1 corresponds to the current source profile.
if key == 'electron_cyclotron_source':
profile = core_sources.profiles[key][1, :]
else:
profile = core_sources.profiles[key]
profiles = core_sources.psi
if key in profiles:
integrated[f'{value}'] = math_utils.cell_integration(
profile * geo.spr, geo
profiles[key] * geo.spr, geo
)

integrated['P_sol_tot'] = integrated['P_sol_ion'] + integrated['P_sol_el']
Expand Down
Loading

0 comments on commit 05cdaee

Please sign in to comment.