Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to feed through partial calculated source profiles to other sources #699

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 23 additions & 121 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,132 +282,19 @@ def get_ion_density_and_charge_states(
return ni, nimp, Zi, Zi_face, Zimp, Zimp_face


def _prescribe_currents_no_bootstrap(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
source_models: source_models_lib.SourceModels,
) -> state.Currents:
"""Creates the initial Currents without the bootstrap current.

Args:
static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice: General runtime parameters at t_initial.
geo: Geometry of the tokamak.
core_profiles: Core profiles.
source_models: All TORAX source/sink functions.

Returns:
currents: Initial Currents
"""
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style

# Calculate splitting of currents depending on input runtime params.
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

# Set zero bootstrap current
bootstrap_profile = source_profiles_lib.BootstrapCurrentProfile.zero_profile(
geo
)

# calculate "External" current profile (e.g. ECCD)
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
Iext = (
math_utils.cell_integration(external_current * geo.spr, geo) / 10**6
)
# Total Ohmic current.
Iohm = Ip - Iext

# construct prescribed current formula on grid.
jformula = (
1 - geo.rho_norm**2
) ** dynamic_runtime_params_slice.profile_conditions.nu
# calculate total and Ohmic current profiles
denom = _trapz(jformula * geo.spr, geo.rho_norm)
if dynamic_runtime_params_slice.profile_conditions.initial_j_is_total_current:
Ctot = Ip * 1e6 / denom
jtot = jformula * Ctot
johm = jtot - external_current
else:
Cohm = Iohm * 1e6 / denom
johm = jformula * Cohm
jtot = johm + external_current

jtot_hires = _get_jtot_hires(
dynamic_runtime_params_slice,
geo,
bootstrap_profile,
Iohm,
external_current=external_current,
)

currents = state.Currents(
jtot=jtot,
jtot_face=math_utils.cell_to_face(
jtot,
geo,
preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE,
),
jtot_hires=jtot_hires,
johm=johm,
external_current_source=external_current,
j_bootstrap=bootstrap_profile.j_bootstrap,
j_bootstrap_face=bootstrap_profile.j_bootstrap_face,
I_bootstrap=bootstrap_profile.I_bootstrap,
Ip_profile_face=jnp.zeros(geo.rho_face.shape), # psi not yet calculated
sigma=bootstrap_profile.sigma,
)

return currents


def _prescribe_currents_with_bootstrap(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
def _prescribe_currents(
bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile,
external_current: jax.Array,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
source_models: source_models_lib.SourceModels,
) -> state.Currents:
"""Creates the initial Currents.

Args:
static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice: General runtime parameters at t_initial.
geo: Geometry of the tokamak.
core_profiles: Core profiles.
source_models: All TORAX source/sink functions. If not provided, uses the
default sources.

Returns:
currents: Plasma currents
"""
"""Creates the initial Currents from a given bootstrap profile."""

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6)

# calculate "External" current profile (e.g. ECCD)
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
Iext = (
math_utils.cell_integration(external_current * geo.spr, geo) / 10**6
)
Expand Down Expand Up @@ -653,12 +540,21 @@ def _init_psi_and_current(
isinstance(geo, circular_geometry.CircularAnalyticalGeometry)
or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
currents = _prescribe_currents_no_bootstrap(
# First calculate currents without bootstrap.
bootstrap = source_profiles_lib.BootstrapCurrentProfile.zero_profile(
geo
)
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
)
currents = _prescribe_currents(
bootstrap_profile=bootstrap,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
)
psi = _update_psi_from_j(
dynamic_runtime_params_slice,
Expand All @@ -668,12 +564,18 @@ def _init_psi_and_current(
core_profiles = dataclasses.replace(
core_profiles, currents=currents, psi=psi
)
currents = _prescribe_currents_with_bootstrap(
# Now calculate currents with bootstrap.
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
)
currents = _prescribe_currents(
bootstrap_profile=bootstrap_profile,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
)
psi = _update_psi_from_j(
dynamic_runtime_params_slice,
Expand Down
1 change: 1 addition & 0 deletions torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def get_value(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
calculated_source_profiles: source_profiles.SourceProfiles | None,
) -> chex.Array:
raise NotImplementedError('Call `get_bootstrap` instead.')

Expand Down
8 changes: 3 additions & 5 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models
from torax.sources import source_profiles


@dataclasses.dataclass(kw_only=True)
Expand Down Expand Up @@ -121,18 +122,15 @@ def calc_relativistic_correction() -> jax.Array:


def bremsstrahlung_model_func(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_caculated_source_profiles: source_profiles.SourceProfiles | None,
unused_model_func: source_models.SourceModels | None,
) -> jax.Array:
"""Model function for the Bremsstrahlung heat sink."""
del (
static_runtime_params_slice,
unused_model_func,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
Expand Down
10 changes: 6 additions & 4 deletions torax/sources/cyclotron_radiation_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from torax.geometry import geometry
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models as source_models_lib
from torax.sources import source_models
from torax.sources import source_profiles


@dataclasses.dataclass(kw_only=True)
Expand Down Expand Up @@ -283,7 +284,8 @@ def cyclotron_radiation_albajar(
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
source_models: source_models_lib.SourceModels,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels,
) -> array_typing.ArrayFloat:
"""Calculates the cyclotron radiation heat sink contribution to the electron heat equation.

Expand Down Expand Up @@ -311,13 +313,13 @@ def cyclotron_radiation_albajar(
geo: The geometry object.
source_name: The name of the source.
core_profiles: The core profiles object.
source_models: Collections of source models.
unused_calculated_source_profiles: Unused.
unused_source_models: Unused.

Returns:
The cyclotron radiation heat sink contribution to the electron heat
equation.
"""
del (source_models,)
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
Expand Down
5 changes: 4 additions & 1 deletion torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models
from torax.sources import source_profiles

InterpolatedVarTimeRhoInput = (
runtime_params_lib.interpolated_param.InterpolatedVarTimeRhoInput
Expand Down Expand Up @@ -109,6 +110,7 @@ def calc_heating_and_current(
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Model function for the electron-cyclotron source.
Expand All @@ -122,7 +124,8 @@ def calc_heating_and_current(
geo: Magnetic geometry.
source_name: Name of the source.
core_profiles: CoreProfiles component of the state.
unused_model_func: (unused) source models used in the simulation.
unused_calculated_source_profiles: Unused.
unused_source_models: Unused.

Returns:
2D array of electron cyclotron heating power density and current density.
Expand Down
10 changes: 7 additions & 3 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models
from torax.sources import source_profiles


# pylint: disable=invalid-name
Expand Down Expand Up @@ -81,7 +82,8 @@ def calc_puff_source(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_state: state.CoreProfiles | None = None,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Calculates external source term for n from puffs."""
Expand Down Expand Up @@ -177,7 +179,8 @@ def calc_generic_particle_source(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_state: state.CoreProfiles | None = None,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Calculates external source term for n from SBI."""
Expand Down Expand Up @@ -266,7 +269,8 @@ def calc_pellet_source(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_state: state.CoreProfiles | None = None,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Calculates external source term for n from pellets."""
Expand Down
5 changes: 3 additions & 2 deletions torax/sources/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torax.geometry import geometry
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_profiles


def calc_fusion(
Expand Down Expand Up @@ -144,13 +145,13 @@ def fusion_heat_model_func(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_source_name: str,
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: Optional['source_models.SourceModels'],
) -> jax.Array:
"""Model function for fusion heating."""
# pytype: enable=name-error
del source_name
# pylint: disable=invalid-name
_, Pfus_i, Pfus_e = calc_fusion(
geo,
Expand Down
27 changes: 5 additions & 22 deletions torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torax.geometry import geometry
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_profiles
# pylint: disable=invalid-name


Expand Down Expand Up @@ -101,33 +102,15 @@ def __post_init__(self):
# pytype bug: does not treat 'source_models.SourceModels' as a forward reference
# pytype: disable=name-error
def calculate_generic_current(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_state: state.CoreProfiles | None = None,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: Optional['source_models.SourceModels'] = None,
) -> jax.Array:
"""Calculates the external current density profiles.

Args:
static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice: Parameter configuration at present timestep.
geo: Tokamak geometry.
source_name: Name of the source.
unused_state: State argument not used in this function but is present to
adhere to the source API.
unused_source_models: Source models argument not used in this function but
is present to adhere to the source API.

Returns:
External current density profile along the cell grid.
"""
del (
static_runtime_params_slice,
unused_state,
unused_source_models,
) # Unused.
"""Calculates the external current density profiles on the cell grid."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
Expand Down
Loading