diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index f9d0ffa8..fa8b57a1 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -282,132 +282,19 @@ def get_ion_density_and_charge_states( return ni, nimp, Zi, Zi_face, Zimp, Zimp_face -def _prescribe_currents_no_bootstrap( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels, -) -> state.Currents: - """Creates the initial Currents without the bootstrap current. - - Args: - static_runtime_params_slice: Static runtime parameters. - dynamic_runtime_params_slice: General runtime parameters at t_initial. - geo: Geometry of the tokamak. - core_profiles: Core profiles. - source_models: All TORAX source/sink functions. - - Returns: - currents: Initial Currents - """ - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style - - # Calculate splitting of currents depending on input runtime params. - Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot - - # Set zero bootstrap current - bootstrap_profile = source_profiles_lib.BootstrapCurrentProfile.zero_profile( - geo - ) - - # calculate "External" current profile (e.g. ECCD) - external_current = source_models.external_current_source( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) - Iext = ( - math_utils.cell_integration(external_current * geo.spr, geo) / 10**6 - ) - # Total Ohmic current. - Iohm = Ip - Iext - - # construct prescribed current formula on grid. - jformula = ( - 1 - geo.rho_norm**2 - ) ** dynamic_runtime_params_slice.profile_conditions.nu - # calculate total and Ohmic current profiles - denom = _trapz(jformula * geo.spr, geo.rho_norm) - if dynamic_runtime_params_slice.profile_conditions.initial_j_is_total_current: - Ctot = Ip * 1e6 / denom - jtot = jformula * Ctot - johm = jtot - external_current - else: - Cohm = Iohm * 1e6 / denom - johm = jformula * Cohm - jtot = johm + external_current - - jtot_hires = _get_jtot_hires( - dynamic_runtime_params_slice, - geo, - bootstrap_profile, - Iohm, - external_current=external_current, - ) - - currents = state.Currents( - jtot=jtot, - jtot_face=math_utils.cell_to_face( - jtot, - geo, - preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE, - ), - jtot_hires=jtot_hires, - johm=johm, - external_current_source=external_current, - j_bootstrap=bootstrap_profile.j_bootstrap, - j_bootstrap_face=bootstrap_profile.j_bootstrap_face, - I_bootstrap=bootstrap_profile.I_bootstrap, - Ip_profile_face=jnp.zeros(geo.rho_face.shape), # psi not yet calculated - sigma=bootstrap_profile.sigma, - ) - - return currents - - -def _prescribe_currents_with_bootstrap( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, +def _prescribe_currents( + bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile, + external_current: jax.Array, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels, ) -> state.Currents: - """Creates the initial Currents. - - Args: - static_runtime_params_slice: Static runtime parameters. - dynamic_runtime_params_slice: General runtime parameters at t_initial. - geo: Geometry of the tokamak. - core_profiles: Core profiles. - source_models: All TORAX source/sink functions. If not provided, uses the - default sources. - - Returns: - currents: Plasma currents - """ + """Creates the initial Currents from a given bootstrap profile.""" # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot - - bootstrap_profile = source_models.j_bootstrap.get_bootstrap( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6) - # calculate "External" current profile (e.g. ECCD) - external_current = source_models.external_current_source( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) Iext = ( math_utils.cell_integration(external_current * geo.spr, geo) / 10**6 ) @@ -653,12 +540,21 @@ def _init_psi_and_current( isinstance(geo, circular_geometry.CircularAnalyticalGeometry) or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j ): - currents = _prescribe_currents_no_bootstrap( + # First calculate currents without bootstrap. + bootstrap = source_profiles_lib.BootstrapCurrentProfile.zero_profile( + geo + ) + external_current = source_models.external_current_source( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles=core_profiles, - source_models=source_models, + ) + currents = _prescribe_currents( + bootstrap_profile=bootstrap, + external_current=external_current, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, ) psi = _update_psi_from_j( dynamic_runtime_params_slice, @@ -668,12 +564,18 @@ def _init_psi_and_current( core_profiles = dataclasses.replace( core_profiles, currents=currents, psi=psi ) - currents = _prescribe_currents_with_bootstrap( + # Now calculate currents with bootstrap. + bootstrap_profile = source_models.j_bootstrap.get_bootstrap( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles=core_profiles, - source_models=source_models, + ) + currents = _prescribe_currents( + bootstrap_profile=bootstrap_profile, + external_current=external_current, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, ) psi = _update_psi_from_j( dynamic_runtime_params_slice, diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index 891b6e02..b105a64b 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -148,6 +148,7 @@ def get_value( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + calculated_source_profiles: source_profiles.SourceProfiles | None, ) -> chex.Array: raise NotImplementedError('Call `get_bootstrap` instead.') diff --git a/torax/sources/bremsstrahlung_heat_sink.py b/torax/sources/bremsstrahlung_heat_sink.py index 9033c5b7..0a26ca55 100644 --- a/torax/sources/bremsstrahlung_heat_sink.py +++ b/torax/sources/bremsstrahlung_heat_sink.py @@ -30,6 +30,7 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_models +from torax.sources import source_profiles @dataclasses.dataclass(kw_only=True) @@ -121,18 +122,15 @@ def calc_relativistic_correction() -> jax.Array: def bremsstrahlung_model_func( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, + unused_caculated_source_profiles: source_profiles.SourceProfiles | None, unused_model_func: source_models.SourceModels | None, ) -> jax.Array: """Model function for the Bremsstrahlung heat sink.""" - del ( - static_runtime_params_slice, - unused_model_func, - ) # Unused. dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ source_name ] diff --git a/torax/sources/cyclotron_radiation_heat_sink.py b/torax/sources/cyclotron_radiation_heat_sink.py index c244c5de..04014377 100644 --- a/torax/sources/cyclotron_radiation_heat_sink.py +++ b/torax/sources/cyclotron_radiation_heat_sink.py @@ -31,7 +31,8 @@ from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib from torax.sources import source -from torax.sources import source_models as source_models_lib +from torax.sources import source_models +from torax.sources import source_profiles @dataclasses.dataclass(kw_only=True) @@ -283,7 +284,8 @@ def cyclotron_radiation_albajar( geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, + unused_source_models: source_models.SourceModels, ) -> array_typing.ArrayFloat: """Calculates the cyclotron radiation heat sink contribution to the electron heat equation. @@ -311,13 +313,13 @@ def cyclotron_radiation_albajar( geo: The geometry object. source_name: The name of the source. core_profiles: The core profiles object. - source_models: Collections of source models. + unused_calculated_source_profiles: Unused. + unused_source_models: Unused. Returns: The cyclotron radiation heat sink contribution to the electron heat equation. """ - del (source_models,) dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ source_name ] diff --git a/torax/sources/electron_cyclotron_source.py b/torax/sources/electron_cyclotron_source.py index e51c5c6b..a5f0cfc5 100644 --- a/torax/sources/electron_cyclotron_source.py +++ b/torax/sources/electron_cyclotron_source.py @@ -32,6 +32,7 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_models +from torax.sources import source_profiles InterpolatedVarTimeRhoInput = ( runtime_params_lib.interpolated_param.InterpolatedVarTimeRhoInput @@ -109,6 +110,7 @@ def calc_heating_and_current( geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Model function for the electron-cyclotron source. @@ -122,7 +124,8 @@ def calc_heating_and_current( geo: Magnetic geometry. source_name: Name of the source. core_profiles: CoreProfiles component of the state. - unused_model_func: (unused) source models used in the simulation. + unused_calculated_source_profiles: Unused. + unused_source_models: Unused. Returns: 2D array of electron cyclotron heating power density and current density. diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index e03652e2..cc1ed5a3 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -30,6 +30,7 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_models +from torax.sources import source_profiles # pylint: disable=invalid-name @@ -81,7 +82,8 @@ def calc_puff_source( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, - unused_state: state.CoreProfiles | None = None, + unused_state: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from puffs.""" @@ -177,7 +179,8 @@ def calc_generic_particle_source( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, - unused_state: state.CoreProfiles | None = None, + unused_state: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from SBI.""" @@ -266,7 +269,8 @@ def calc_pellet_source( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, - unused_state: state.CoreProfiles | None = None, + unused_state: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from pellets.""" diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index 0043c505..3a251e1c 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -28,6 +28,7 @@ from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib from torax.sources import source +from torax.sources import source_profiles def calc_fusion( @@ -144,13 +145,13 @@ def fusion_heat_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, - source_name: str, + unused_source_name: str, core_profiles: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Model function for fusion heating.""" # pytype: enable=name-error - del source_name # pylint: disable=invalid-name _, Pfus_i, Pfus_e = calc_fusion( geo, diff --git a/torax/sources/generic_current_source.py b/torax/sources/generic_current_source.py index 9a64392d..300c58bf 100644 --- a/torax/sources/generic_current_source.py +++ b/torax/sources/generic_current_source.py @@ -32,6 +32,7 @@ from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib from torax.sources import source +from torax.sources import source_profiles # pylint: disable=invalid-name @@ -101,33 +102,15 @@ def __post_init__(self): # pytype bug: does not treat 'source_models.SourceModels' as a forward reference # pytype: disable=name-error def calculate_generic_current( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, - unused_state: state.CoreProfiles | None = None, + unused_state: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: - """Calculates the external current density profiles. - - Args: - static_runtime_params_slice: Static runtime parameters. - dynamic_runtime_params_slice: Parameter configuration at present timestep. - geo: Tokamak geometry. - source_name: Name of the source. - unused_state: State argument not used in this function but is present to - adhere to the source API. - unused_source_models: Source models argument not used in this function but - is present to adhere to the source API. - - Returns: - External current density profile along the cell grid. - """ - del ( - static_runtime_params_slice, - unused_state, - unused_source_models, - ) # Unused. + """Calculates the external current density profiles on the cell grid.""" dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ source_name ] diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index ce445745..f7d285ac 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -30,6 +30,7 @@ from torax.sources import formulas from torax.sources import runtime_params as runtime_params_lib from torax.sources import source +from torax.sources import source_profiles # Many variables throughout this function are capitalized based on physics # notational conventions rather than on Google Python style # pylint: disable=invalid-name @@ -114,20 +115,16 @@ def calc_generic_heat_source( # pytype: disable=name-error def default_formula( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, - core_profiles: state.CoreProfiles, + unused_core_profiles: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Returns the default formula-based ion/electron heat source profile.""" # pytype: enable=name-error - del ( - core_profiles, - static_runtime_params_slice, - unused_source_models, - ) # Unused. dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ source_name ] diff --git a/torax/sources/impurity_radiation_heat_sink/impurity_radiation_constant_fraction.py b/torax/sources/impurity_radiation_heat_sink/impurity_radiation_constant_fraction.py index 3338bcf9..354f0fe2 100644 --- a/torax/sources/impurity_radiation_heat_sink/impurity_radiation_constant_fraction.py +++ b/torax/sources/impurity_radiation_heat_sink/impurity_radiation_constant_fraction.py @@ -28,6 +28,7 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profiles as source_profiles_lib MODEL_FUNCTION_NAME = "radially_constant_fraction_of_Pin" @@ -38,6 +39,7 @@ def radially_constant_fraction_of_Pin( # pylint: disable=invalid-name geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, + calculated_source_profiles: source_profiles_lib.SourceProfiles | None, source_models: source_models_lib.SourceModels, ) -> jax.Array: """Model function for radiation heat sink from impurities. @@ -51,6 +53,8 @@ def radially_constant_fraction_of_Pin( # pylint: disable=invalid-name geo: Geometry object. source_name: Name of the source. core_profiles: Core profiles object. + calculated_source_profiles: Source profiles which have already been + calculated and can be used to avoid recomputing them. source_models: Source models object. Returns: @@ -73,6 +77,7 @@ def get_heat_source_profile(source: source_lib.Source) -> jax.Array: static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=calculated_source_profiles, ) return source.get_source_profile_for_affected_core_profile( profile, source_lib.AffectedCoreProfile.TEMP_EL.value, geo diff --git a/torax/sources/impurity_radiation_heat_sink/impurity_radiation_mavrin_fit.py b/torax/sources/impurity_radiation_heat_sink/impurity_radiation_mavrin_fit.py index da50a572..23408339 100644 --- a/torax/sources/impurity_radiation_heat_sink/impurity_radiation_mavrin_fit.py +++ b/torax/sources/impurity_radiation_heat_sink/impurity_radiation_mavrin_fit.py @@ -31,6 +31,7 @@ from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib from torax.sources import source_models as source_models_lib +from torax.sources import source_profiles MODEL_FUNCTION_NAME = 'impurity_radiation_mavrin_fit' @@ -190,6 +191,7 @@ def impurity_radiation_mavrin_fit( geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: source_models_lib.SourceModels | None = None, ) -> jax.Array: """Model function for impurity radiation heat sink.""" diff --git a/torax/sources/ion_cyclotron_source.py b/torax/sources/ion_cyclotron_source.py index 69e72021..3dcff1dd 100644 --- a/torax/sources/ion_cyclotron_source.py +++ b/torax/sources/ion_cyclotron_source.py @@ -37,6 +37,7 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_models +from torax.sources import source_profiles from typing_extensions import override # Internal import. @@ -366,19 +367,16 @@ def _helium3_tail_temperature( def icrh_model_func( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, unused_source_models: source_models.SourceModels | None, toric_nn: ToricNNWrapper, ) -> jax.Array: """Compute ion/electron heat source terms.""" - del ( - unused_source_models, - static_runtime_params_slice, - ) # Unused. dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[ source_name ] diff --git a/torax/sources/ohmic_heat_source.py b/torax/sources/ohmic_heat_source.py index 4763f9ac..4234eedc 100644 --- a/torax/sources/ohmic_heat_source.py +++ b/torax/sources/ohmic_heat_source.py @@ -34,6 +34,7 @@ from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib from torax.sources import source_operations +from torax.sources import source_profiles @functools.partial( @@ -167,12 +168,12 @@ def ohmic_model_func( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, - source_name: str, + unused_source_name: str, core_profiles: state.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles | None, source_models: source_models_lib.SourceModels, ) -> jax.Array: """Returns the Ohmic source for electron heat equation.""" - del source_name # Unused. if source_models is None: raise TypeError('source_models is a required argument for ohmic_model_func') diff --git a/torax/sources/source.py b/torax/sources/source.py index 8be242eb..5c7a47b6 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -41,6 +41,7 @@ from torax.config import runtime_params_slice from torax.geometry import geometry from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source_profiles # pytype bug: 'source_models.SourceModels' not treated as forward reference @@ -56,6 +57,7 @@ def __call__( geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, + caculated_source_profiles: source_profiles.SourceProfiles | None, source_models: Optional['source_models.SourceModels'], ) -> chex.Array: ... @@ -136,6 +138,7 @@ def get_value( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + calculated_source_profiles: source_profiles.SourceProfiles | None, ) -> chex.Array: """Returns the cell grid profile for this source during one time step. @@ -150,6 +153,12 @@ def get_value( sources get the core profiles at the start of the time step, implicit sources get the "live" profiles that is updated through the course of the time step as the solver converges. + calculated_source_profiles: The source profiles which have already been + calculated for this time step if they exist. This is used to avoid + recalculating profiles that are used as inputs to other sources. These + profiles will only exist if the source is explicit and then also + depends on the source type, see source_profile_builders.py for more + details. Returns: An array of shape (num affected core profiles, cell grid length) @@ -167,6 +176,7 @@ def get_value( static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=calculated_source_profiles, model_func=self.model_func, prescribed_values=dynamic_source_runtime_params.prescribed_values, output_shape=output_shape, @@ -239,6 +249,7 @@ def _get_source_profiles( geo: geometry.Geometry, source_name: str, core_profiles: state.CoreProfiles, + calculated_source_profiles: source_profiles.SourceProfiles | None, model_func: SourceProfileFunction | None, prescribed_values: chex.Array, output_shape: tuple[int, ...], @@ -259,6 +270,9 @@ def _get_source_profiles( source_name: The name of the source. core_profiles: Core plasma profiles. Used as input to the source profile functions. + calculated_source_profiles: The source profiles which have already been + calculated for this time step. This is used to avoid recalculating + profiles that are used as inputs to other sources.` model_func: Model function. prescribed_values: Array of values for this timeslice, interpolated onto the grid (ie with shape output_shape) @@ -282,6 +296,7 @@ def _get_source_profiles( geo, source_name, core_profiles, + calculated_source_profiles, source_models, ) case runtime_params_lib.Mode.PRESCRIBED.value: diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index 0734557f..bac39c41 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -235,6 +235,7 @@ def external_current_source( static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) total += source.get_source_profile_for_affected_core_profile( source_value, source_lib.AffectedCoreProfile.PSI, geo, diff --git a/torax/sources/source_profile_builders.py b/torax/sources/source_profile_builders.py index 92aeb99e..724aef46 100644 --- a/torax/sources/source_profile_builders.py +++ b/torax/sources/source_profile_builders.py @@ -209,6 +209,7 @@ def build_standard_source_profiles( dynamic_runtime_params_slice, geo, core_profiles, + None, ) if len(source.affected_core_profiles) == 1: computed_source_profiles[source.affected_core_profiles[0]][ diff --git a/torax/sources/tests/bootstrap_current_source_test.py b/torax/sources/tests/bootstrap_current_source_test.py index 0750e9b4..517b383c 100644 --- a/torax/sources/tests/bootstrap_current_source_test.py +++ b/torax/sources/tests/bootstrap_current_source_test.py @@ -178,6 +178,7 @@ def test_raise_error_on_get_value(self): mock.ANY, mock.ANY, mock.ANY, + None, ) def test_raise_error_on_get_source_profile_for_affected_core_profile(self): diff --git a/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py b/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py index 1bcb5208..5377dd8c 100644 --- a/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py +++ b/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py @@ -127,6 +127,7 @@ def test_source_value(self): dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) ) diff --git a/torax/sources/tests/electron_cyclotron_source_test.py b/torax/sources/tests/electron_cyclotron_source_test.py index 84e4a56c..a7116142 100644 --- a/torax/sources/tests/electron_cyclotron_source_test.py +++ b/torax/sources/tests/electron_cyclotron_source_test.py @@ -80,6 +80,7 @@ def test_source_value(self): static_runtime_params_slice=static_runtime_params_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) # ElectronCyclotronSource provides TEMP_EL and PSI chex.assert_rank(value, 2) diff --git a/torax/sources/tests/generic_current_source_test.py b/torax/sources/tests/generic_current_source_test.py index 5a1eff95..1a6ba3ec 100644 --- a/torax/sources/tests/generic_current_source_test.py +++ b/torax/sources/tests/generic_current_source_test.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest -from torax.config import runtime_params as general_runtime_params -from torax.config import runtime_params_slice -from torax.geometry import circular_geometry from torax.sources import generic_current_source from torax.sources.tests import test_lib -class GenericCurrentSourceTest(test_lib.SourceTestCase): +class GenericCurrentSourceTest(test_lib.SingleProfileSourceTestCase): """Tests for GenericCurrentSource.""" @classmethod @@ -31,47 +28,6 @@ def setUpClass(cls): model_func=generic_current_source.calculate_generic_current, ) - def test_profile_is_on_cell_grid(self): - """Tests that the profile is given on the cell grid.""" - geo = circular_geometry.build_circular_geometry() - torax_mesh = geo.torax_mesh - source_builder = self._source_class_builder() - source = source_builder() - self.assertEqual( - source.output_shape(torax_mesh), - torax_mesh.cell_centers.shape, - ) - runtime_params = general_runtime_params.GeneralRuntimeParams() - dynamic_runtime_params_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources={ - generic_current_source.GenericCurrentSource.SOURCE_NAME: ( - source_builder.runtime_params - ), - }, - torax_mesh=torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - static_slice = runtime_params_slice.build_static_runtime_params_slice( - runtime_params=runtime_params, - source_runtime_params={ - generic_current_source.GenericCurrentSource.SOURCE_NAME: ( - source_builder.runtime_params - ), - }, - torax_mesh=torax_mesh, - ) - self.assertEqual( - source.get_value( - static_slice, - dynamic_runtime_params_slice, - geo, - core_profiles=None, - ).shape, - torax_mesh.cell_centers.shape, - ) - if __name__ == '__main__': absltest.main() diff --git a/torax/sources/tests/ion_cyclotron_source_test.py b/torax/sources/tests/ion_cyclotron_source_test.py index ea1268dd..85e39a9a 100644 --- a/torax/sources/tests/ion_cyclotron_source_test.py +++ b/torax/sources/tests/ion_cyclotron_source_test.py @@ -169,6 +169,7 @@ def test_source_value(self, mock_path): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) chex.assert_rank(ion_and_el, 2) diff --git a/torax/sources/tests/source_models_test.py b/torax/sources/tests/source_models_test.py index 2e8dd9a4..27cee0a8 100644 --- a/torax/sources/tests/source_models_test.py +++ b/torax/sources/tests/source_models_test.py @@ -95,6 +95,7 @@ def source_name(self) -> str: dynamic_runtime_params_slice, geo, core_profiles, + None, ) external_current_source = source_models.external_current_source( diff --git a/torax/sources/tests/source_operations_test.py b/torax/sources/tests/source_operations_test.py index 9e6c4ab7..b0261b44 100644 --- a/torax/sources/tests/source_operations_test.py +++ b/torax/sources/tests/source_operations_test.py @@ -101,6 +101,7 @@ def foo_formula( geo: geometry.Geometry, unused_source_name: str, unused_state, + unused_calculated_source_profiles, unused_source_models, ): return jnp.stack([ diff --git a/torax/sources/tests/source_test.py b/torax/sources/tests/source_test.py index 69eb5ef7..27a9bd02 100644 --- a/torax/sources/tests/source_test.py +++ b/torax/sources/tests/source_test.py @@ -189,6 +189,7 @@ def test_zero_profile_works_by_default(self): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) np.testing.assert_allclose( profile, @@ -214,7 +215,8 @@ def test_output_shape_works_multiple_profiles(self): def test_correct_mode_called(self, mode, expected_profile): source_builder = source_lib.make_source_builder( test_lib.TestSource, - model_func=lambda _0, _1, _2, _3, _4, _5: jnp.full(geo.rho.shape, 2), + model_func=lambda _0, _1, _2, _3, _4, _5, _6: jnp.full( + geo.rho.shape, 2), )() source_models_builder = source_models_lib.SourceModelsBuilder( {'foo': source_builder}, @@ -254,6 +256,7 @@ def test_correct_mode_called(self, mode, expected_profile): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) np.testing.assert_allclose( profile, @@ -306,6 +309,7 @@ def test_defaults_output_zeros(self): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) with self.subTest('prescribed'): static_slice = runtime_params_slice.build_static_runtime_params_slice( @@ -323,6 +327,7 @@ def test_defaults_output_zeros(self): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) np.testing.assert_allclose( profile, @@ -335,7 +340,7 @@ def test_overriding_model(self): expected_output = jnp.ones_like(geo.rho) source_builder = source_lib.make_source_builder( IonElTestSource, - model_func=lambda _0, _1, _2, _3, _4, _5: expected_output, + model_func=lambda _0, _1, _2, _3, _4, _5, _6: expected_output, )() source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED source_models_builder = source_models_lib.SourceModelsBuilder( @@ -369,6 +374,7 @@ def test_overriding_model(self): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) np.testing.assert_allclose(profile, expected_output) @@ -415,6 +421,7 @@ def test_overriding_prescribed_values(self): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) np.testing.assert_allclose(profile, expected_output) diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 68f007d9..a31cc33e 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -166,6 +166,7 @@ def test_source_value_on_the_cell_grid(self): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) chex.assert_rank(value, 1) self.assertEqual(value.shape, geo.rho.shape) @@ -212,6 +213,7 @@ def test_source_values_on_the_cell_grid(self): static_runtime_params_slice=static_slice, geo=geo, core_profiles=core_profiles, + calculated_source_profiles=None, ) chex.assert_rank(ion_and_el, 2) self.assertEqual(ion_and_el.shape, (2, geo.torax_mesh.nx)) diff --git a/torax/tests/physics.py b/torax/tests/physics.py index 35a1c366..514610e9 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -32,6 +32,7 @@ from torax.sources import generic_current_source from torax.sources import runtime_params as source_runtime_params from torax.sources import source_models as source_models_lib +from torax.sources import source_profiles from torax.tests.test_lib import torax_refs @@ -137,12 +138,15 @@ def test_update_psi_from_j( # pylint: disable=protected-access if isinstance(geo, circular_geometry.CircularAnalyticalGeometry): - currents = core_profile_setters._prescribe_currents_no_bootstrap( - static_slice, - dynamic_runtime_params_slice, - geo, - source_models=source_models, - core_profiles=initial_core_profiles, + bootstrap = source_profiles.BootstrapCurrentProfile.zero_profile(geo) + external_current = source_models.external_current_source( + geo, initial_core_profiles, dynamic_runtime_params_slice, static_slice + ) + currents = core_profile_setters._prescribe_currents( + bootstrap_profile=bootstrap, + external_current=external_current, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, ) psi = core_profile_setters._update_psi_from_j( dynamic_runtime_params_slice, geo, currents.jtot_hires @@ -152,7 +156,6 @@ def test_update_psi_from_j( else: raise ValueError(f'Unknown geometry type: {geo.geometry_type}') # pylint: enable=protected-access - print(psi) np.testing.assert_allclose(psi, references.psi.value) diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index c017b656..0c4f1d95 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -37,6 +37,7 @@ from torax.sources import electron_density_sources from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib +from torax.sources import source_profiles from torax.sources.tests import test_lib from torax.stepper import linear_theta_method from torax.tests.test_lib import default_sources @@ -103,29 +104,32 @@ def custom_source_formula( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, unused_source_name: str, - unused_state: state_lib.CoreProfiles | None, + unused_state: state_lib.CoreProfiles, + unused_calculated_source_profiles: source_profiles.SourceProfiles, unused_source_models: ..., ): # Combine the outputs. # pylint: disable=protected-access + kwargs = dict( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo, + unused_state=unused_state, + unused_calculated_source_profiles=unused_calculated_source_profiles, + unused_source_models=unused_source_models, + ) return ( electron_density_sources.calc_puff_source( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, source_name=electron_density_sources.GasPuffSource.SOURCE_NAME, + **kwargs ) + electron_density_sources.calc_generic_particle_source( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, source_name=electron_density_sources.GenericParticleSource.SOURCE_NAME, + **kwargs ) + electron_density_sources.calc_pellet_source( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, source_name=electron_density_sources.PelletSource.SOURCE_NAME, + **kwargs ) ) # pylint: enable=protected-access diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 68b3e45c..b7b793cb 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -122,6 +122,7 @@ def custom_source_formula( unused_geo, source_name, unused_state, + unused_calculated_source_profiles, unused_source_models, ): dynamic_source_params = dynamic_runtime_params.sources[source_name]