From c12b4d3eb003757ab4429173c154216438bfac07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20K=C3=B6hler?= <27728103+Ceyron@users.noreply.github.com> Date: Fri, 2 Aug 2024 10:28:01 +0200 Subject: [PATCH] Markdown rework (#2) * Add markdown based documentation * Show all of base loss * Add constructor doc * Update base config to markdown * Format line length * Rework supervised configuration doc * Translate residuum config to markdown based syntax * translate leftover configuration to new docstrings * Translate supervised trainer * Update the other two convinience trainers * Update documentation to mixer components * Update docs of general trainer * Display additional methods in documentation * Allow supplying optstate * Allow option to deactivate the tqdm progress meter * Add documentation to all callbacks --- docs/api/general_trainer.md | 2 + docs/api/loss.md | 6 +- trainax/_general_trainer.py | 165 +++++++++------- trainax/_mixer.py | 187 ++++++++++-------- trainax/callback/_base.py | 18 ++ trainax/callback/_composite.py | 4 + trainax/callback/_get_network.py | 5 +- trainax/callback/_grad_norm.py | 22 ++- trainax/callback/_loss.py | 31 ++- trainax/callback/_save_network.py | 13 +- trainax/callback/_weight_norm.py | 12 +- trainax/configuration/_base_configuration.py | 26 +-- trainax/configuration/_composite.py | 38 ++-- trainax/configuration/_diverted_chain.py | 95 ++++----- .../_diverted_chain_branch_one.py | 72 +++---- .../configuration/_mix_chain_post_physics.py | 85 ++++---- trainax/configuration/_residuum.py | 74 +++---- trainax/configuration/_supervised.py | 79 ++++---- trainax/loss/_base_loss.py | 58 ++++++ trainax/trainer/_diverted_chain_branch_one.py | 98 +++++---- trainax/trainer/_residuum.py | 94 +++++---- trainax/trainer/_supervised.py | 91 ++++----- 22 files changed, 736 insertions(+), 539 deletions(-) diff --git a/docs/api/general_trainer.md b/docs/api/general_trainer.md index c4eec5d..9e0fb43 100644 --- a/docs/api/general_trainer.md +++ b/docs/api/general_trainer.md @@ -5,3 +5,5 @@ members: - __init__ - __call__ + - full_loss + - step_fn diff --git a/docs/api/loss.md b/docs/api/loss.md index b5bb980..1fa831d 100644 --- a/docs/api/loss.md +++ b/docs/api/loss.md @@ -32,8 +32,4 @@ --- -::: trainax.loss.BaseLoss - options: - members: - - __init__ - - __call__ \ No newline at end of file +::: trainax.loss.BaseLoss \ No newline at end of file diff --git a/trainax/_general_trainer.py b/trainax/_general_trainer.py index a0a865c..e232116 100644 --- a/trainax/_general_trainer.py +++ b/trainax/_general_trainer.py @@ -1,9 +1,9 @@ -from typing import Optional +from typing import Optional, Union import equinox as eqx import jax.numpy as jnp import optax -from jaxtyping import PRNGKeyArray, PyTree +from jaxtyping import Array, Float, PRNGKeyArray, PyTree from tqdm.autonotebook import tqdm from ._mixer import PermutationMixer, TrajectorySubStacker @@ -34,31 +34,35 @@ def __init__( callback_fn: Optional[BaseCallback] = None, ): """ - Abstract training for an autoregressive neural emulator on a collection of - trajectories. - - The length of (sub-)trajectories returned by `trajectory_sub_stacker` must - match the requires length of reference for the used `loss_configuration`. - - Args: - trajectory_sub_stacker (TrajectorySubStacker): A callable that takes a - list of indices and returns a collection of (sub-)trajectories. - loss_configuration (BaseConfiguration): A configuration that defines the - loss function to be minimized. - ref_stepper (eqx.Module, optional): A reference stepper that is used to - compute the residuum. Supply this if the loss configuration requires - a reference stepper. Defaults to None. - residuum_fn (eqx.Module, optional): A residuum function that computes the - discrete residuum between two consecutive states. Supply this if the - loss configuration requires a residuum function. Defaults to None. - optimizer (optax.GradientTransformation): An optimizer that updates the - parameters of the stepper given the gradient. - num_minibatches (int): The number of minibatches to train on. This equals - the total number of update steps performed. The number of epochs is - determined based on this and the `batch_size`. - batch_size (int): The size of each batch. - callback_fn (BaseCallback, optional): A callback function that is called - at the end of each minibatch. Defaults to None. + Abstract training for an autoregressive neural emulator on a collection + of trajectories. + + !!! info + The length of (sub-)trajectories returned by + `trajectory_sub_stacker` must match the required length of reference + for the used `loss_configuration`. + + **Arguments:** + + - `trajectory_sub_stacker`: A callable that takes a + list of indices and returns a collection of (sub-)trajectories. + - `loss_configuration`: A configuration that defines the + loss function to be minimized. + - `ref_stepper`: A reference stepper that is used to + compute the residuum. Supply this if the loss configuration requires + a reference stepper. + - `residuum_fn`: A residuum function that computes the + discrete residuum between two consecutive states. Supply this if the + loss configuration requires a residuum function. Defaults to None. + - `optimizer`: An optimizer that updates the + parameters of the stepper given the gradient. + - `num_minibatches`: The number of minibatches to train on. This equals + the total number of update steps performed. The number of epochs is + automatically determined based on this and the `batch_size`. + - `batch_size`: The size of each minibatch, i.e., how many samples are + included within. + - `callback_fn`: A callback function that is called + at the end of each minibatch. Defaults to None. """ self.trajectory_sub_stacker = trajectory_sub_stacker self.loss_configuration = loss_configuration @@ -75,6 +79,17 @@ def full_loss( ) -> float: """ Compute the loss on the entire dataset. + + !!! warning + This can lead to out of memory errors if the dataset is too large. + + **Arguments:** + + - `stepper`: The stepper to compute the loss with. + + **Returns:** + + - The loss value. """ return self.loss_configuration( stepper, @@ -87,19 +102,22 @@ def step_fn( self, stepper: eqx.Module, opt_state: optax.OptState, - data: PyTree, + data: PyTree[float[Array, "batch_size sub_trj_len ..."]], ) -> tuple[eqx.Module, optax.OptState, float]: """ Perform a single update step to the `stepper`'s parameters. - Args: - stepper (eqx.Module): The stepper to be updated. - opt_state (optax.OptState): The optimizer state. - data (PyTree): The data for the current minibatch. + **Arguments:** + + - `stepper`: The equinox module to be updated. + - `opt_state`: The current optimizer state. + - `data`: The data for the current minibatch. + + **Returns:** - Returns: - tuple[eqx.Module, optax.OptState, float]: The updated stepper, the - updated optimizer state, and the loss value. + - The updated equinox module + - The updated optimizer state + - The loss value """ loss, grad = eqx.filter_value_and_grad( lambda m: self.loss_configuration( @@ -114,16 +132,20 @@ def __call__( self, stepper: eqx.Module, key: PRNGKeyArray, + opt_state: Optional[optax.OptState] = None, *, return_loss_history: bool = True, record_loss_every: int = 1, - ): + spawn_tqdm: bool = True, + ) -> Union[ + tuple[eqx.Module, Float[Array, "num_minibatches"]], + eqx.Module, + tuple[eqx.Module, Float[Array, "num_minibatches"], list], + tuple[eqx.Module, list], + ]: """ - Perform the entire training of an autoregressive neural emulator - `stepper`. - - This method spawns a `tqdm` progress meter showing the current update - step and displaying the epoch with its respetive minibatch counter. + Perform the entire training of an autoregressive neural emulator given + in an initial state as `stepper`. This method's return signature depends on the presence of a callback function. If a callback function is provided, this function has at max @@ -133,25 +155,32 @@ def __call__( values of the callback function at each minibatch. If no callback function is provided, this function has at max two return values. The first return value is the trained stepper, and the second return value - is the loss history. + is the loss history. If `return_loss_history` is set to `False`, the + loss history will not be returned. + + **Arguments:** + + - `stepper`: The equinox Module to be trained. + - `key`: The random key to be used for shuffling the minibatches. + - `opt_state`: The initial optimizer state. Defaults to None, meaning + the optimizer will be reinitialized. + - `return_loss_history`: Whether to return the loss history. + - `record_loss_every`: Record the loss every `record_loss_every` + minibatches. Defaults to 1, i.e., record every minibatch. + - `spawn_tqdm`: Whether to spawn the tqdm progress meter showing the + current update step and displaying the epoch with its respetive + minibatch counter. - Args: - stepper (eqx.Module): The stepper to be trained. key (PRNGKeyArray): - The random key to be used for shuffling the - minibatches. - return_loss_history (bool, optional): Whether to return the loss - history. Defaults to True. - record_loss_every (int, optional): Record the loss every - `record_loss_every` minibatches. Defaults to 1. + **Returns:** - Returns: - Varying, see above. + - Varying, see above. It will always return the trained stepper as the + first return value. - Tipp: + !!! tip You can use `equinox.filter_vmap` to train mulitple networks (of the - same architecture) at the same time. For example, if your GPU - is not fully utilized yet, this will give you a init-seed - statistic basically for free. + same architecture) at the same time. For example, if your GPU is not + fully utilized yet, this will give you a init-seed statistic + basically for free. """ loss_history = [] if self.callback_fn is not None: @@ -164,15 +193,17 @@ def __call__( shuffle_key=key, ) - p_meter = tqdm( - total=self.num_minibatches, - desc=f"E: {0:05d}, B: {0:05d}", - ) + if spawn_tqdm: + p_meter = tqdm( + total=self.num_minibatches, + desc=f"E: {0:05d}, B: {0:05d}", + ) update_fn = eqx.filter_jit(self.step_fn) trained_stepper = stepper - opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array)) + if opt_state is None: + opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array)) for update_i in range(self.num_minibatches): batch_indices, (expoch_id, batch_id) = mixer(update_i, return_info=True) @@ -185,13 +216,15 @@ def __call__( ) if update_i % record_loss_every == 0: loss_history.append(loss) - p_meter.update(1) + if spawn_tqdm: + p_meter.update(1) - p_meter.set_description( - f"E: {expoch_id:05d}, B: {batch_id:05d}", - ) + p_meter.set_description( + f"E: {expoch_id:05d}, B: {batch_id:05d}", + ) - p_meter.close() + if spawn_tqdm: + p_meter.close() loss_history = jnp.array(loss_history) diff --git a/trainax/_mixer.py b/trainax/_mixer.py index 8cf1b66..7511445 100644 --- a/trainax/_mixer.py +++ b/trainax/_mixer.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp import jax.tree_util as jtu -from jaxtyping import Array, Float, PRNGKeyArray, PyTree +from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree from ._utils import stack_sub_trajectories @@ -23,30 +23,38 @@ def __init__( """ Slice a batch of trajectories into sub-trajectories. - Useful to create windows of specific length for (rollout) training + Useful to create windows of specific length for (unrolled) training methodologies of autoregressive neural emulators. - Args: - data_trajectories (PyTree[Float[Array, "num_samples trj_len ..."]]): - The batch of trajectories to slice. This must be a PyTree of - Arrays who have at least two leading axes: a batch-axis and a - time axis. For example, the zeroth axis can be associated with - multiple initial conditions or constitutive parameters and the - first axis represents all temporal snapshots. A PyTree can also - just be an array. You can provide additional leafs in the - PyTree, e.g., for the corresponding constitutive parameters etc. - Make sure that the emulator has the corresponding signature. - sub_trajectory_len (int): The length of the sub-trajectories. This - must be smaller equal to the length of the trajectories - (`trj_len`). For rollout training with `t` steps, set this to - `t+1` to include the necessary initial condition. - do_sub_stacking (bool, optional): Whether to slice out all possible - (overlapping) windows out of the `trj_len` or just slice the - `trj_len` axis from `0:sub_trajectory_len`. Defaults to True. - only_store_ic (bool, optional): Whether to only store the initial - condition of the sub-trajectories. This can be helpful for - configurations that do not need the reference trajectory like - residuum-based learning strategies. Defaults to False. + **Arguments:** + + - `data_trajectories`: The batch of trajectories to slice. This must be + a PyTree of Arrays who have at least two leading axes: a batch-axis + and a time axis. For example, the zeroth axis can be associated with + multiple initial conditions or constitutive parameters and the first + axis represents all temporal snapshots. A PyTree can also just be an + array. You can provide additional leafs in the PyTree, e.g., for the + corresponding constitutive parameters etc. Make sure that the + emulator has the corresponding signature. + - `sub_trajectory_len`: The length of the sub-trajectories. This + must be smaller equal to the length of the trajectories (`trj_len`). + For unrolled training with `t` steps, set this to `t+1` to include + the necessary initial condition. + - `do_sub_stacking`: Whether to slice out all possible + (overlapping) windows out of the `trj_len` or just slice the + `trj_len` axis from `0:sub_trajectory_len`. + - `only_store_ic`: Whether to only store the initial + condition of the sub-trajectories. This can be helpful for + configurations that do not need the reference trajectory like + residuum-based learning strategies. + + !!! info + * Since the windows sliced out are overlapping, the produces + internal array can be large, especially if `sub_trajectory_len` + is large. Certainly, this is not the most memory-efficient + solution but is sufficient if your problem easily fits into + memory. Consider overwriting this class with a more memory + efficient implementation if you run into memory issues. """ if do_sub_stacking: # return shape is (num_samples, num_stacks, sub_trj_len, ...) @@ -75,10 +83,21 @@ def __init__( def __call__( self, - indices, - ): + indices: slice, + ) -> PyTree[Float[Array, "len(indices) sub_trj_len ..."]]: """ Slice out sub-samples based on the given indices. + + **Arguments:** + + - `indices`: The indices to slice out the sub-trajectories, e.g., this + can be `[0, 4, 5]` to slice out the zeroth, fourth, and fifth + sub-trajectories or it can be a `slice` object. + + **Returns:** + + - `PyTree[Float[Array, "len(indices) sub_trj_len ..."]]`: The sliced + sub-trajectories. """ return jtu.tree_map(lambda x: x[indices], self.data_sub_trajectories) @@ -101,21 +120,23 @@ def __init__( ): """ Precompute permuations for a given number of minibatches within a - dataset. Automatically determines the number of epochs necessary. Upon - calling returns a collection of indices to produce a new minibatch. + dataset. Automatically determines the number of necessary epochs (runs + over the entire dataset). Upon calling returns a collection of indices + to produce a new minibatch. If the remainder minibatch in one epoch is smaller than the batch size, it will **not** be extended using data from the next epoch, but returned as smaller list of indices. - Args: - num_total_samples (int): The total number of samples in the dataset. - num_minibatches (int): The size of minibatches to train on. - batch_size (int): The size of the minibatches. - shuffle_key (PRNGKeyArray): The key to create the permutation; needed for - deterministic reproducibility. + **Arguments:** + + - `num_total_samples`: The total number of samples in the dataset. + - `num_minibatches`: The size of minibatches to train on. + - `batch_size`: The size of the minibatches. + - `shuffle_key`: The key to create the permutation; needed for + deterministic reproducibility. - Raises: + !!! warning ValueError: If the batch size is larger than the total number of samples for one epoch. """ @@ -154,21 +175,23 @@ def __call__( i: int, *, return_info: bool = False, - ): + ) -> Int[Array, "batch_size"]: """ Given the batch index `i`, return the corresponding indices to slice out the minibatch. - Args: - i (int): The batch index. - return_info (bool, optional): Whether to return additional - information about the current epoch and batch index. Defaults to - False. + **Arguments:** + + - `i`: The batch index. + - `return_info`: Whether to return additional information about the + current epoch and batch index. + + **Returns:** - Returns: - Array: The indices to slice out the minibatch. + - The indices to slice out the minibatch in form of an array of + integers. - Raises: + !!! warning ValueError: If the batch index is larger than the number of minibatches (because likely there will be no permuation for it) """ @@ -206,34 +229,38 @@ def __init__( ): """ Convenience class to combine `TrajectorySubStacker` and - `PermutationMixer`. Please prefer using the `TrajectorySubStacker` and - `PermutationMixer` directly. - - Args: - data_trajectories (PyTree[Float[Array, "num_samples trj_len ..."]]): - The batch of trajectories to slice. This must be a PyTree of - Arrays who have at least two leading axes: a batch-axis and a - time axis. For example, the zeroth axis can be associated with - multiple initial conditions or constitutive parameters and the - first axis represents all temporal snapshots. A PyTree can also - just be an array. You can provide additional leafs in the - PyTree, e.g., for the corresponding constitutive parameters etc. - Make sure that the emulator has the corresponding signature. - sub_trajectory_len (int): The length of the sub-trajectories. This - must be smaller equal to the length of the trajectories - (`trj_len`). For rollout training with `t` steps, set this to - `t+1` to include the necessary initial condition. - num_minibatches (int): The number of minibatches to train on. - batch_size (int): The size of the minibatches. - shuffle_key (PRNGKeyArray): The key to create the permutation; needed for - deterministic reproducibility. - do_sub_stacking (bool, optional): Whether to slice out all possible - (overlapping) windows out of the `trj_len` or just slice the - `trj_len` axis from `0:sub_trajectory_len`. Defaults to True. - only_store_ic (bool, optional): Whether to only store the initial - condition of the sub-trajectories. This can be helpful for - configurations that do not need the reference trajectory like - residuum-based learning strategies. Defaults to False. + `PermutationMixer`. + + !!! info + Please prefer using the `TrajectorySubStacker` and + `PermutationMixer` directly as this is more amendable to `jax.vmap` + transformation in case of training multiple networks in parallel. + + **Arguments:** + + - `data_trajectories`: The batch of trajectories to slice. This must be + a PyTree of Arrays who have at least two leading axes: a batch-axis + and a time axis. For example, the zeroth axis can be associated with + multiple initial conditions or constitutive parameters and the first + axis represents all temporal snapshots. A PyTree can also just be an + array. You can provide additional leafs in the PyTree, e.g., for the + corresponding constitutive parameters etc. Make sure that the + emulator has the corresponding signature. + - `sub_trajectory_len`: The length of the sub-trajectories. This + must be smaller equal to the length of the trajectories (`trj_len`). + For rollout training with `t` steps, set this to `t+1` to include + the necessary initial condition. + - `num_minibatches`: The number of minibatches to train on. + - `batch_size`: The size of the minibatches. + - `shuffle_key`: The key to create the permutation; needed for + deterministic reproducibility. + - `do_sub_stacking`: Whether to slice out all possible (overlapping) + windows out of the `trj_len` or just slice the `trj_len` axis from + `0:sub_trajectory_len`. + - `only_store_ic`: Whether to only store the initial condition of the + sub-trajectories. This can be helpful for configurations that do not + need the reference trajectory like residuum-based learning + strategies. """ self.trajectory_sub_stacker = TrajectorySubStacker( data_trajectories, @@ -254,19 +281,19 @@ def __call__( i: int, *, return_info: bool = False, - ): + ) -> PyTree[Float[Array, "batch_size sub_trj_len ..."]]: """ Given the batch index `i`, return the corresponding sub-trajectories. - Args: - i (int): The batch index. - return_info (bool, optional): Whether to return additional - information about the current epoch and batch index. Defaults to - False. + **Arguments:** + + - `i`: The batch index. + - `return_info`: Whether to return additional information about the + current epoch and batch index. + + **Returns:** - Returns: - PyTree[Float[Array, "batch_size sub_trj_len ..."]]: The - sub-trajectories corresponding to the batch index. + - The sub-trajectories corresponding to the batch index. """ if return_info: batch_indices, permutation_info = self.permutation_mixer( diff --git a/trainax/callback/_base.py b/trainax/callback/_base.py index e64cd02..d31538b 100644 --- a/trainax/callback/_base.py +++ b/trainax/callback/_base.py @@ -9,6 +9,11 @@ class BaseCallback(eqx.Module, ABC): every: int name: str + def __init__(self, every: int, name: str): + """Base class for callbacks.""" + self.every = every + self.name = name + @abstractmethod def callback( self, @@ -24,6 +29,19 @@ def __call__( stepper: eqx.Module, data: PyTree, ) -> Dict[str, Any]: + """ + Evaluate the Callback. + + **Arguments:** + + - `update_i`: The current update step. + - `stepper`: The equinox.Module to evaluate the callback on. + - `data`: The data to evaluate the callback on. + + **Returns:** + + - The result of the callback wrapped into a dictionary. + """ if update_i % self.every == 0: res = self.callback(update_i, stepper, data) return {self.name: res} diff --git a/trainax/callback/_composite.py b/trainax/callback/_composite.py index d8134af..5c1199f 100644 --- a/trainax/callback/_composite.py +++ b/trainax/callback/_composite.py @@ -9,6 +9,10 @@ class CompositeCallback(eqx.Module): callbacks: list[BaseCallback] + def __init__(self, callbacks: list[BaseCallback]): + """Callback to combine multiple callbacks.""" + self.callbacks = callbacks + def __call__( self, update_i: int, diff --git a/trainax/callback/_get_network.py b/trainax/callback/_get_network.py index 2049dc8..550bb28 100644 --- a/trainax/callback/_get_network.py +++ b/trainax/callback/_get_network.py @@ -8,8 +8,8 @@ class GetNetwork(BaseCallback): def __init__(self, every: int, name: str = "network"): - self.every = every - self.name = name + """Callback to write out the network state `every` update step.""" + super().__init__(every, name) def callback( self, @@ -17,4 +17,5 @@ def callback( stepper: eqx.Module, data: PyTree, ) -> Any: + """Write out the network state.""" return stepper diff --git a/trainax/callback/_grad_norm.py b/trainax/callback/_grad_norm.py index a5205cd..5002f5f 100644 --- a/trainax/callback/_grad_norm.py +++ b/trainax/callback/_grad_norm.py @@ -24,12 +24,29 @@ def __init__( residuum_fn: eqx.Module = None, name: str, ): - self.every = every + """ + Callback to save the gradient norm associated with `loss_configuration` + `every` update steps. + + **Arguments:** + + - `every`: The frequency of the callback. + - `loss_configuration`: The loss configuration to compute the gradient + norm. If the gradient norm associated with the training loss is + desired, the corresponding loss configuration has to be re-supplied. + - `squared`: Whether to return the squared gradient norm. + - `ref_stepper`: A reference stepper that is used to compute the residuum. + Supply this if the loss configuration requires a reference stepper. + - `residuum_fn`: A residuum function that computes the discrete residuum + between two consecutive states. Supply this if the loss configuration + requires a residuum function. + - `name`: The name of the callback. + """ self.loss_configuration = loss_configuration self.squared = squared self.ref_stepper = ref_stepper self.residuum_fn = residuum_fn - self.name = name + super().__init__(every, name) def callback( self, @@ -37,6 +54,7 @@ def callback( stepper: eqx.Module, data: PyTree, ) -> eqx.Module: + """Compute the gradient norm.""" grad = eqx.filter_grad(self.loss_configuration)( stepper, data, diff --git a/trainax/callback/_loss.py b/trainax/callback/_loss.py index 266ee29..ea286a5 100644 --- a/trainax/callback/_loss.py +++ b/trainax/callback/_loss.py @@ -1,3 +1,5 @@ +from typing import Union + import equinox as eqx from jaxtyping import PyTree @@ -22,19 +24,42 @@ def __init__( residuum_fn: eqx.Module = None, name: str, ): - self.every = every + """ + Callback to save the loss associated with `loss_configuration` `every` + update steps. + + Use this to measure a stepper performance on a difference configuration + than the training loss. + + **Arguments:** + + - `every`: The frequency of the callback. + - `loss_configuration`: The loss configuration to compute the loss. + - `with_grad`: Whether to also return the associated gradient. If only + the gradient norm is desired, set this to `False` and consider using + [`trainax.callback.GradNorm`](). + - `ref_stepper`: A reference stepper that is used to compute the residuum. + Supply this if the loss configuration requires a reference stepper. + - `residuum_fn`: A residuum function that computes the discrete residuum + between two consecutive states. Supply this if the loss configuration + requires a residuum function. + - `name`: The name of the callback. + """ self.loss_configuration = loss_configuration self.with_grad = with_grad self.ref_stepper = ref_stepper self.residuum_fn = residuum_fn - self.name = name + super().__init__(every, name) def callback( self, update_i: int, stepper: eqx.Module, data: PyTree, - ) -> eqx.Module: + ) -> Union[eqx.Module, tuple[eqx.Module, eqx.Module]]: + """ + Compute the loss and optionally the associated gradient. + """ if self.with_grad: loss, grad = eqx.filter_value_and_grad(self.loss_configuration)( stepper, diff --git a/trainax/callback/_save_network.py b/trainax/callback/_save_network.py index c77a5a0..081621f 100644 --- a/trainax/callback/_save_network.py +++ b/trainax/callback/_save_network.py @@ -17,10 +17,19 @@ def __init__( file_name: str, name: str = "network_saved", ): - self.every = every + """ + Callback to write the network state to a file `every` update step. + + **Arguments:** + + - `every`: The frequency of the callback. + - `path`: The path to save the network state. + - `file_name`: The file name to save the network state. + - `name`: The name of the callback + """ self.path = path self.file_name = file_name - self.name = name + super().__init__(every, name) def callback( self, diff --git a/trainax/callback/_weight_norm.py b/trainax/callback/_weight_norm.py index 3a3bf55..68a84ae 100644 --- a/trainax/callback/_weight_norm.py +++ b/trainax/callback/_weight_norm.py @@ -10,9 +10,17 @@ class WeightNorm(BaseCallback): squared: bool = False def __init__(self, every: int, squared: bool = False, name: str = "weight_norm"): - self.every = every + """ + Callback to save the weight norm `every` update steps. + + **Arguments:** + + - `every`: The frequency of the callback. + - `squared`: Whether to return the squared weight norm. + - `name`: The name of the callback + """ self.squared = squared - self.name = name + super().__init__(every, name) def callback( self, diff --git a/trainax/configuration/_base_configuration.py b/trainax/configuration/_base_configuration.py index ed66bf3..616e298 100644 --- a/trainax/configuration/_base_configuration.py +++ b/trainax/configuration/_base_configuration.py @@ -17,18 +17,20 @@ def __call__( """ Evaluate the configuration on the given data. - Args: - stepper (eqx.Module): The stepper to use for the configuration. Must - have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. - data (PyTree): The data to evaluate the configuration on. This - depends on the concrete configuration. In the most reduced case, - it just contains the set of initial states. - ref_stepper (eqx.Module): The reference stepper to use for some - configurations. (keyword-only argument) - residuum_fn (eqx.Module): The residuum function to use for some - configurations. (keyword-only argument) + **Arguments:** - Returns: - float: The loss value computed by this configuration. + - `stepper`: The stepper to use for the configuration. Must + have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. + - `data`: The data to evaluate the configuration on. This + depends on the concrete configuration. In the most reduced case, it + just contains the set of initial states. + - `ref_stepper`: The reference stepper to use for some + configurations. (keyword-only argument) + - `residuum_fn`: The residuum function to use for some + configurations. (keyword-only argument) + + **Returns:** + + - The loss value computed by this configuration. """ pass diff --git a/trainax/configuration/_composite.py b/trainax/configuration/_composite.py index 892ccdd..99de534 100644 --- a/trainax/configuration/_composite.py +++ b/trainax/configuration/_composite.py @@ -15,11 +15,10 @@ def __init__( """ Compose configurations with respective weights. - Args: - configurations (list[BaseConfiguration]): The list of configurations - to compose. - weights (list[float]): The list of weights to apply to the - configurations. + **Arguments:** + + - `configurations`: The list of configurations to compose. + - `weights`: The list of weights to apply to the configurations. """ self.configurations = configurations self.weights = weights @@ -38,20 +37,21 @@ def __call__( Based on the underlying configurations, `ref_stepper` or `residuum_fn` or both have to be supplied (as keyword-only arguments). - Args: - stepper (Module): The stepper to use for the configuration. Must - have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. - data (PyTree): The data to evaluate the configuration on. This - depends on the concrete configuration. In the most reduced case, - it just contains the set of initial states. - ref_stepper (Module): The reference stepper to use for some - configurations. Defaults to None. - residuum_fn (Module): The residuum function to use for some - configurations. Defaults to None. - - Returns: - float: The loss value computed by all configurations combined and - weighted. + **Arguments:** + + - `stepper`: The stepper to use for the configuration. Must + have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. + - `data`: The data to evaluate the configuration on. This + depends on the concrete configuration. In the most reduced case, it + just contains the set of initial states. + - `ref_stepper`: The reference stepper to use for some + configurations. + - `residuum_fn`: The residuum function to use for some + configurations. + + **Returns:** + + - The loss value computed by all configurations combined and weighted. """ loss = sum( weight diff --git a/trainax/configuration/_diverted_chain.py b/trainax/configuration/_diverted_chain.py index 9e3b57e..1bd3faf 100644 --- a/trainax/configuration/_diverted_chain.py +++ b/trainax/configuration/_diverted_chain.py @@ -39,38 +39,40 @@ def __init__( `num_branch_steps=num_rollout_steps` and the `DivertedChainBranchOne` configuration as special case of `num_branch_steps=1`. - Args: - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model. Defaults to 1. - num_branch_steps (int): The number of time steps to branch off the - main chain. Must be less than or equal to `num_rollout_steps`. - Defaults to 1. - time_level_loss (BaseLoss): The loss function to use at - each time step. Defaults to MSELoss(). - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. Defaults to False. - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). - cut_div_chain (bool): Whether to cut the diverted chain, i.e., - insert a `jax.lax.stop_gradient` to not have cotangents flow - over the `ref_stepper`. In this case, the `ref_stepper` does not - have to be differentiable. Defaults to False. - time_level_weights (array[float], optional): An array of length - `num_rollout_steps` that contains the weights for each time - step. Defaults to None, which means that all time steps have the - same weight (=1.0). - branch_level_weights (array[float], optional): An array of length - `num_branch_steps` that contains the weights for each branch - step. Defaults to None, which means that all branch steps have - the same weight (=1.0). - - Raises: - ValueError: If `num_branch_steps` is greater than - `num_rollout_steps`. - - Info: + **Arguments:** + + - `num_rollout_steps`: The number of time steps to + autoregressively roll out the model. Defaults to 1. + - `num_branch_steps`: The number of time steps to branch off the + main chain. Must be less than or equal to `num_rollout_steps`. + Defaults to 1. + - `time_level_loss`: The loss function to use at + each time step. Defaults to MSELoss(). + - `cut_bptt`: Whether to cut the backpropagation through time + (BPTT), i.e., insert a `jax.lax.stop_gradient` into the + autoregressive network main chain. Defaults to False. + - `cut_bptt_every`: The frequency at which to cut the BPTT. + Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after + each step). + - `cut_div_chain`: Whether to cut the diverted chain, i.e., + insert a `jax.lax.stop_gradient` to not have cotangents flow over + the `ref_stepper`. In this case, the `ref_stepper` does not have to + be differentiable. Defaults to False. + - `time_level_weights`: An array of length + `num_rollout_steps` that contains the weights for each time step. + Defaults to None, which means that all time steps have the same + weight (=1.0). + - `branch_level_weights`: An array of length + `num_branch_steps` that contains the weights for each branch step. + Defaults to None, which means that all branch steps have the same + weight (=1.0). + + **Raises:** + + - ValueError: If `num_branch_steps` is greater than + `num_rollout_steps`. + + !!! info * The `ref_stepper` is called on-the-fly. If its forward (and vjp) evaluation is expensive, this will dominate the computational cost of this configuration. @@ -114,20 +116,21 @@ def __call__( The data only has to contain one time level, the initial condition. - Args: - stepper (eqx.Module): The stepper to use for the configuration. Must - have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. - data (PyTree): The data to evaluate the configuration on. This - depends on the concrete configuration. In this case, it only - has to contain the set of initial states. - ref_stepper (eqx.Module): The reference stepper to use for the - diverted chain. This is called on-the-fly. (keyword-only - argument) - residuum_fn (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) - - Returns: - float: The loss value computed by this configuration. + **Arguments:** + + - `stepper`: The stepper to use for the configuration. Must + have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. + - `data`: The data to evaluate the configuration on. This + depends on the concrete configuration. In this case, it only + has to contain the set of initial states. + - `ref_stepper`: The reference stepper to use for the + diverted chain. This is called on-the-fly. + - `residuum_fn`: For compatibility with other + configurations; not used. + + **Returns:** + + - The loss value computed by this configuration. """ # Data is supposed to contain the initial condition, trj is not used ic, _ = extract_ic_and_trj(data) diff --git a/trainax/configuration/_diverted_chain_branch_one.py b/trainax/configuration/_diverted_chain_branch_one.py index b4c8013..ea7ee52 100644 --- a/trainax/configuration/_diverted_chain_branch_one.py +++ b/trainax/configuration/_diverted_chain_branch_one.py @@ -36,27 +36,28 @@ def __init__( classical one-step supervised training for `num_rollout_steps=1` (default). - Args: - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model.Defaults to 1. - time_level_loss (BaseLoss): The loss function to use at - each time step. Defaults to MSELoss(). - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. Defaults to False. - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). - cut_div_chain (bool): Whether to cut the diverted chain, i.e., - insert a `jax.lax.stop_gradient` to not have cotangents flow - over the `ref_stepper`. In this case, the `ref_stepper` does not - have to be differentiable. Defaults to False. - time_level_weights (array[float], optional): An array of length - `num_rollout_steps` that contains the weights for each time - step. Defaults to None, which means that all time steps have the - same weight (=1.0). - - Info: + **Arguments:** + + - `num_rollout_steps`: The number of time steps to + autoregressively roll out the model.Defaults to 1. + - `time_level_loss`: The loss function to use at + each time step. Defaults to MSELoss(). + - `cut_bptt`: Whether to cut the backpropagation through time + (BPTT), i.e., insert a `jax.lax.stop_gradient` into the + autoregressive network main chain. Defaults to False. + - `cut_bptt_every`: The frequency at which to cut the BPTT. + Only relevant if `cut_bptt` is True. Defaults to 1 (meaning + after each step). + - `cut_div_chain`: Whether to cut the diverted chain, i.e., + insert a `jax.lax.stop_gradient` to not have cotangents flow + over the `ref_stepper`. In this case, the `ref_stepper` does not + have to be differentiable. + - `time_level_weights`: An array of length + `num_rollout_steps` that contains the weights for each time + step. Defaults to None, which means that all time steps have the + same weight (=1.0). + + !!! info * The `ref_stepper` is called on-the-fly. If its forward (and vjp) execution are expensive, this will dominate the computational cost of this configuration. @@ -89,20 +90,21 @@ def __call__( The data only has to contain one time level, the initial condition. - Args: - stepper (eqx.Module): The stepper to use for the configuration. Must - have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. - data (PyTree): The data to evaluate the configuration on. This - depends on the concrete configuration. In this case, it only - contains the set of initial states. - ref_stepper (eqx.Module): The reference stepper to use for the - diverted chain. This is called on-the-fly. (keyword-only - argument) - residuum_fn (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) - - Returns: - float: The loss value computed by this configuration. + **Arguments:** + + - `stepper`: The stepper to use for the configuration. Must + have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. + - `data`: The data to evaluate the configuration on. This + depends on the concrete configuration. In this case, it only + contains the set of initial states. + - `ref_stepper`: The reference stepper to use for the + diverted chain. This is called on-the-fly. + - `residuum_fn`: For compatibility with other + configurations; not used. + + **Returns:** + + - The loss value computed by this configuration. """ # Data is supposed to contain the initial condition, trj is not used ic, _ = extract_ic_and_trj(data) diff --git a/trainax/configuration/_mix_chain_post_physics.py b/trainax/configuration/_mix_chain_post_physics.py index 91f2bb8..0117c38 100644 --- a/trainax/configuration/_mix_chain_post_physics.py +++ b/trainax/configuration/_mix_chain_post_physics.py @@ -34,30 +34,30 @@ def __init__( Mix chain (rollout) configuration with autoregressive physics steps after the autoregressive emulator steps in the main chain. - THIS IS A SPECIAL CASE TODO... - - Args: - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model. Defaults to 1. - num_post_physics_steps (int): The number of time steps to - autoregressively roll physics **after** the model in the main - chain. Defaults to 1. Hence, in the default config, the main - chain is model -> physics - time_level_loss (BaseLoss): The loss function to use at - each time step. Defaults to MSELoss(). (keyword-only argument) - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. This excludes the - post-physics steps; those are not cutted. Defaults to False. - (keyword-only argument) - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). (keyword-only argument) - time_level_weights (array[float], optional): An array of length - `num_rollout_steps+num_post_physics_steps` that contains the - weights for each time step. Defaults to None, which means that - all time steps have the same weight (=1.0). (keyword-only - argument) + This is a special case of potentially more complicated combitations of + neural stepper with reference physics stepper in the main chain. + + **Arguments:** + + - `num_rollout_steps`: The number of time steps to + autoregressively roll out the model. Defaults to 1. + - `num_post_physics_steps`: The number of time steps to + autoregressively roll physics **after** the model in the main chain. + Defaults to 1. Hence, in the default config, the main chain is model + -> physics + - `time_level_loss`: The loss function to use at + each time step. Defaults to `trainax.loss.MSELoss`. + - `cut_bptt`: Whether to cut the backpropagation through time + (BPTT), i.e., insert a `jax.lax.stop_gradient` into the + autoregressive network main chain. This excludes the post-physics + steps; those are not cutted. Defaults to False. + - `cut_bptt_every`: The frequency at which to cut the BPTT. + Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after + each step). + - `time_level_weights`: An array of length + `num_rollout_steps+num_post_physics_steps` that contains the weights + for each time step. Defaults to None, which means that all time + steps have the same weight (=1.0). """ self.num_rollout_steps = num_rollout_steps self.num_post_physics_steps = num_post_physics_steps @@ -85,23 +85,26 @@ def __call__( The data only has to contain as many time levels as the sum of the number of rollout steps and post physics steps plus one. - Args: - stepper (eqx.Module): The stepper to use for the configuration. Must - have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. - data (PyTree): The data to evaluate the configuration on. Has to - contain the initial condition and the target trajectory. - ref_stepper (eqx.Module): The reference stepper to use for the - configuration. Must have the signature `ref_stepper(u_prev: - PyTree) -> u_next: PyTree`. (keyword-only argument) - residuum_fn (eqx.Module): For compatibility with other - configurations; not used here. (keyword-only argument) - - Returns: - float: The loss value computed by this configuration. - - Raises: - ValueError: If the number of snapshots in the trajectory is less than - the number of rollout steps and post physics steps plus one. + **Arguments:** + + - `stepper`: The stepper to use for the configuration. Must + have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. + - `data`: The data to evaluate the configuration on. Has to + contain the initial condition and the target trajectory. + - `ref_stepper`: The reference stepper to use for the + configuration. Must have the signature `ref_stepper(u_prev: PyTree) + -> u_next: PyTree`. + - `residuum_fn`: For compatibility with other + configurations; not used here. + + **Returns:** + + - The loss value computed by this configuration. + + **Raises:** + + - ValueError: If the number of snapshots in the trajectory is less than + the number of rollout steps and post physics steps plus one. """ # Data is supposed to contain both the initial condition and the target ic, trj = extract_ic_and_trj(data) diff --git a/trainax/configuration/_residuum.py b/trainax/configuration/_residuum.py index 5546c65..935724b 100644 --- a/trainax/configuration/_residuum.py +++ b/trainax/configuration/_residuum.py @@ -40,26 +40,26 @@ def __init__( different optimization trajectories (and different local optima) because the residuum-based loss is conditioned worse. - Args: - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model. Defaults to 1. - time_level_loss (BaseLoss): The loss function to use at - each time step. Must operate based on a single input. Defaults - to MSELoss(). - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. Defaults to False. - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). - cut_prev (bool): Whether to cut the previous time level contribution - to `residuum_fn`. Defaults to False. - cut_next (bool): Whether to cut the next time level contribution - to `residuum_fn`. Defaults to False. - time_level_weights (array[float], optional): An array of length - `num_rollout_steps` that contains the weights for each time - step. Defaults to None, which means that all time steps have the - same weight (=1.0). + **Arguments:** + + - `num_rollout_steps`: The number of time steps to + autoregressively roll out the model. Defaults to 1. + - `time_level_loss`: The loss function to use at + each time step. Must operate based on a single input. Defaults to + MSELoss(). + - `cut_bptt`: Whether to cut the backpropagation through time + (BPTT), i.e., insert a `jax.lax.stop_gradient` into the + autoregressive network main chain. Defaults to False. + - `cut_bptt_every`: The frequency at which to cut the BPTT. + Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after + each step). + - `cut_prev`: Whether to cut the previous time level contribution + to `residuum_fn`. Defaults to False. + - `cut_next`: Whether to cut the next time level contribution + to `residuum_fn`. Defaults to False. + - `time_level_weights`: An array of length `num_rollout_steps` that + contains the weights for each time step. Defaults to None, which + means that all time steps have the same weight (=1.0). """ self.num_rollout_steps = num_rollout_steps self.time_level_loss = time_level_loss @@ -78,7 +78,7 @@ def __call__( data: PyTree[Float[Array, "batch num_snapshots ..."]], *, ref_stepper: eqx.Module = None, # unused - residuum_fn: eqx.Module, # unused + residuum_fn: eqx.Module, ) -> float: """ Evaluate the residuum (rollout) configuration on the given data. @@ -87,21 +87,23 @@ def __call__( `residuum_fn` will be used to compute a loss based on two consecutive time levels. - Args: - stepper (eqx.Module): The stepper to use for the configuration. Must - have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. - data (PyTree): The data to evaluate the configuration on. This - depends on the concrete configuration. In this case, it only - contains the initial condition. - ref_stepper (eqx.Module): The reference stepper to use for the - configuration. Must have the signature - `ref_stepper(u_prev: PyTree) -> u_next: PyTree`. Defaults to None. - residuum_fn (eqx.Module): The residuum function to use for the - configuration. Must have the signature - `residuum_fn(u_next: PyTree, u_prev: PyTree) -> residuum: PyTree`. - - Returns: - float: The loss of the configuration. + **Arguments:** + + - `stepper`: The stepper to use for the configuration. Must + have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. + - `data`: The data to evaluate the configuration on. This + depends on the concrete configuration. In this case, it only + contains the initial condition. + - `ref_stepper`: The reference stepper to use for the + configuration. Must have the signature `ref_stepper(u_prev: PyTree) + -> u_next: PyTree`. Defaults to None. + - `residuum_fn`: The residuum function to use for the + configuration. Must have the signature `residuum_fn(u_next: PyTree, + u_prev: PyTree) -> residuum: PyTree`. + + **Returns:** + + - The loss of the configuration. """ # Data is supposed to contain the initial condition, trj is not used ic, _ = extract_ic_and_trj(data) diff --git a/trainax/configuration/_supervised.py b/trainax/configuration/_supervised.py index 49498f9..477b9e2 100644 --- a/trainax/configuration/_supervised.py +++ b/trainax/configuration/_supervised.py @@ -32,27 +32,29 @@ def __init__( Falls back to classical one-step supervised training for `num_rollout_steps=1` (default). - Args: - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model. During calling this - configuration, it requires a similarly long reference trajectory - to be available. Defaults to 1. - time_level_loss (BaseLoss): The loss function to use at - each time step. Defaults to MSELoss(). - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. Defaults to False. - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). - time_level_weights (array[float], optional): An array of length - `num_rollout_steps` that contains the weights for each time - step. Defaults to None, which means that all time steps have the - same weight (=1.0). - - Info: - * Under reverse-mode automatic differentiation memory usage grows - linearly with `num_rollout_steps`. + **Arguments:** + + - `num_rollout_steps`: The number of time steps to + autoregressively roll out the model. During calling this + configuration, it requires a similarly long reference trajectory to + be available. Defaults to 1. + - `time_level_loss`: The loss function to use at + each time step. Defaults to MSELoss(). + - `cut_bptt`: Whether to cut the backpropagation through time + (BPTT), i.e., insert a `jax.lax.stop_gradient` into the + autoregressive network main chain. Defaults to False. + - `cut_bptt_every`: The frequency at which to cut the BPTT. + Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after + each step). + - `time_level_weights`: An array of length + `num_rollout_steps` that contains the weights for each time step. + Defaults to None, which means that all time steps have the same + weight (=1.0). + + + !!! warning + Under reverse-mode automatic differentiation memory usage grows + linearly with `num_rollout_steps`. """ self.num_rollout_steps = num_rollout_steps self.time_level_loss = time_level_loss @@ -77,22 +79,25 @@ def __call__( The data is supposed to have as many time steps as the number of rollout steps plus one. No `ref_stepper` or `residuum_fn` is needed. - Args: - stepper (eqx.Module): The stepper to use for the configuration. Must - have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. - data (PyTree): The data to evaluate the configuration on. This - should contain the initial condition and the target trajectory. - ref_stepper (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) residuum_fn - residuum_fn (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) - - Returns: - float: The loss value computed by this configuration. - - Raises: - ValueError: If the number of snapshots in the trajectory is less than - the number of rollout steps plus one. + **Arguments:** + + - `stepper`: The stepper to use for the configuration. Must + have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`. + - `data`: The data to evaluate the configuration on. This + should contain the initial condition and the target trajectory. + - `ref_stepper`: For compatibility with other + configurations; not used. + - `residuum_fn`: For compatibility with other + configurations; not used. + + **Returns:** + + - The loss value computed by this configuration. + + **Raises:** + + - ValueError: If the number of snapshots in the trajectory is less than + the number of rollout steps plus one. """ # Data is supposed to contain both the initial condition and the target ic, trj = extract_ic_and_trj(data) diff --git a/trainax/loss/_base_loss.py b/trainax/loss/_base_loss.py index 7eb0e26..0e94702 100644 --- a/trainax/loss/_base_loss.py +++ b/trainax/loss/_base_loss.py @@ -11,6 +11,7 @@ class BaseLoss(eqx.Module, ABC): batch_reduction: Callable def __init__(self, *, batch_reduction: Callable = jnp.mean): + """Base class for loss functions.""" self.batch_reduction = batch_reduction @abstractmethod @@ -19,6 +20,27 @@ def single_batch( prediction: Float[Array, "num_channels ..."], target: Optional[Float[Array, "num_channels ..."]] = None, ) -> float: + """ + Evaluate the loss for a single sample. + + Inputs must be PyTrees of identical structure with array leafs having at + least a channel/feature axis, and optionally one or more subsequent axes + (e.g., spatial axes). There should be **no batch axis**. + + !!! info + + To operate on a batch of inputs, either use `multi_batch` or use + `jax.vmap` on this method. + + **Arguments:** + + - `prediction`: The predicted values. + - `target`: The target values. + + **Returns:** + + - The loss value. + """ pass def multi_batch( @@ -26,6 +48,24 @@ def multi_batch( prediction: Float[Array, "num_batches num_channels ..."], target: Optional[Float[Array, "num_batches num_channels ..."]] = None, ) -> float: + """ + Evaluate the loss for a batch of samples. + + Inputs must be PyTrees of identical structure with array leafs having a + leading batch axis, a subsequent channel/feature axis, and optionally one + or more subsequent axes (e.g., spatial axes). + + Uses the batch aggregator function specified during initialization. + + **Arguments:** + + - `prediction`: The predicted values. + - `target`: The target values. + + **Returns:** + + - The loss value. + """ if target is None: return self.batch_reduction( jax.vmap( @@ -46,4 +86,22 @@ def __call__( prediction: Float[Array, "num_batches num_channels ..."], target: Optional[Float[Array, "num_batches num_channels ..."]] = None, ) -> float: + """ + Evaluate the loss for a batch of samples. + + Inputs must be PyTrees of identical structure with array leafs having a + leading batch axis, a subsequent channel/feature axis, and optionally one + or more subsequent axes (e.g., spatial axes). + + Uses the batch aggregator function specified during initialization. + + **Arguments:** + + - `prediction`: The predicted values. + - `target`: The target values. + + **Returns:** + + - The loss value. + """ return self.multi_batch(prediction, target) diff --git a/trainax/trainer/_diverted_chain_branch_one.py b/trainax/trainer/_diverted_chain_branch_one.py index 62f9d00..62bc272 100644 --- a/trainax/trainer/_diverted_chain_branch_one.py +++ b/trainax/trainer/_diverted_chain_branch_one.py @@ -1,6 +1,7 @@ from typing import Optional import equinox as eqx +import optax from jaxtyping import Array, Float from .._general_trainer import GeneralTrainer @@ -17,7 +18,7 @@ def __init__( *, ref_stepper: eqx.Module, residuum_fn: eqx.Module = None, # for compatibility - optimizer, + optimizer: optax.GradientTransformation, callback_fn: Optional[BaseCallback] = None, num_training_steps: int, batch_size: int, @@ -35,63 +36,54 @@ def __init__( Diverted chain (rollout) configuration with branch length fixed to one. Essentially, this amounts to a one-step difference to a reference - (create on the fly by the differentiable `ref_stepper`). Falls back to + (created on the fly by the differentiable `ref_stepper`). Falls back to classical one-step supervised training for `num_rollout_steps=1` (default). - Args: - data_trajectories (PyTree[Float[Array, "num_samples trj_len ..."]]): - The batch of trajectories to slice. This must be a PyTree of - Arrays who have at least two leading axes: a batch-axis and a - time axis. For example, the zeroth axis can be associated with - multiple initial conditions or constitutive parameters and the - first axis represents all temporal snapshots. A PyTree can also - just be an array. You can provide additional leafs in the - PyTree, e.g., for the corresponding constitutive parameters etc. - Make sure that the emulator has the corresponding signature. - ref_stepper (eqx.Module): The reference stepper to use for the - diverted chain. This is called on-the-fly. (keyword-only - argument) - residuum_fn (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) - optimizer (optax.GradientTransformation): The optimizer to use for - training. For example, this can be `optax.adam(LEARNING_RATE)`. - Also use this to supply an optimizer with learning rate decay, - for example `optax.adam(optax.exponential_decay(...))`. If your - learning rate decay is designed for a certain number of update - steps, make sure that it aligns with `num_training_steps`. - (keyword-only argument) - callback_fn (BaseCallback, optional): A callback to use during - training. Defaults to None. (keyword-only argument) - num_training_steps (int): The number of training steps to perform. - (keyword-only argument) - batch_size (int): The batch size to use for training. Batches are - randomly sampled across both multiple trajectories, but also over - different windows within one trajectory. (keyword-only) - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model. Defaults to 1. (keyword-only - argument) - time_level_loss (BaseLoss): The loss function to use at - each time step. Defaults to MSELoss(). (keyword-only argument) - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. Defaults to False. - (keyword-only argument) - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). (keyword-only argument) - cut_div_chain (bool): Whether to cut the diverted chain, i.e., - insert a `jax.lax.stop_gradient` to not have cotangents flow - over the `ref_stepper`. In this case, the `ref_stepper` does not - have to be differentiable. Defaults to False. (keyword-only - argument) - time_level_weights (array[float], optional): An array of length - `num_rollout_steps` that contains the weights for each time - step. Defaults to None, which means that all time steps have the - same weight (=1.0). (keyword-only argument) + **Arguments:** + - `data_trajectories`: The batch of trajectories to slice. This must be + a PyTree of Arrays who have at least two leading axes: a batch-axis + and a time axis. For example, the zeroth axis can be associated with + multiple initial conditions or constitutive parameters and the first + axis represents all temporal snapshots. A PyTree can also just be an + array. You can provide additional leafs in the PyTree, e.g., for the + corresponding constitutive parameters etc. Make sure that the + emulator has the corresponding signature. + - `ref_stepper`: The reference stepper to use for the diverted chain. + This is called on-the-fly. + - `residuum_fn`: For compatibility with other configurations; not used. + - `optimizer`: The optimizer to use for training. For example, this can + be `optax.adam(LEARNING_RATE)`. Also use this to supply an optimizer + with learning rate decay, for example + `optax.adam(optax.exponential_decay(...))`. If your learning rate + decay is designed for a certain number of update steps, make sure + that it aligns with `num_training_steps`. + - `callback_fn`: A callback to use during training. Defaults to None. + - `num_training_steps`: The number of training steps to perform. + - `batch_size`: The batch size to use for training. Batches are + randomly sampled across both multiple trajectories, but also over + different windows within one trajectory. + - `num_rollout_steps: The number of time steps to autoregressively + roll out the model. + - `time_level_loss`: The loss function to use at each time step. + - `cut_bptt`: Whether to cut the backpropagation through time (BPTT), + i.e., insert a `jax.lax.stop_gradient` into the autoregressive + network main chain. + - `cut_bptt_every`: The frequency at which to cut the BPTT. Only + relevant if `cut_bptt` is True. Defaults to 1 (meaning after each + step). + - `cut_div_chain`: Whether to cut the diverted chain, i.e., + insert a `jax.lax.stop_gradient` to not have cotangents flow over + the `ref_stepper`. In this case, the `ref_stepper` does not have to + be differentiable. + - `time_level_weights`: An array of length `num_rollout_steps` that + contains the weights for each time step. Defaults to None, which + means that all time steps have the same weight (=1.0). (keyword-only + argument) - Info: + + !!! info * The `ref_stepper` is called on-the-fly. If its forward (and vjp) execution are expensive, this will dominate the computational cost of this configuration. diff --git a/trainax/trainer/_residuum.py b/trainax/trainer/_residuum.py index 060ed29..b5595b9 100644 --- a/trainax/trainer/_residuum.py +++ b/trainax/trainer/_residuum.py @@ -43,55 +43,51 @@ def __init__( different optimization trajectories (and different local optima) because the residuum-based loss is conditioned worse. - Args: - data_trajectories (PyTree[Float[Array, "num_samples trj_len ..."]]): - The batch of trajectories to slice. This must be a PyTree of - Arrays who have at least two leading axes: a batch-axis and a - time axis. For example, the zeroth axis can be associated with - multiple initial conditions or constitutive parameters and the - first axis represents all temporal snapshots. A PyTree can also - just be an array. You can provide additional leafs in the - PyTree, e.g., for the corresponding constitutive parameters etc. - Make sure that the emulator has the corresponding signature. - ref_stepper (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) - residuum_fn (eqx.Module): The residuum function to use for the - configuration. Must have the signature - `residuum_fn(u_next: PyTree, u_prev: PyTree) -> residuum: PyTree`. - (keyword-only argument) - optimizer (optax.GradientTransformation): The optimizer to use for - training. For example, this can be `optax.adam(LEARNING_RATE)`. - Also use this to supply an optimizer with learning rate decay, - for example `optax.adam(optax.exponential_decay(...))`. If your - learning rate decay is designed for a certain number of update - steps, make sure that it aligns with `num_training_steps`. - (keyword-only argument) - callback_fn (BaseCallback, optional): A callback to use during - training. Defaults to None. (keyword-only argument) - num_training_steps (int): The number of training steps to perform. - (keyword-only argument) - batch_size (int): The batch size to use for training. Batches are - randomly sampled across both multiple trajectories, but also over - different windows within one trajectory. (keyword-only) - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model. (keyword-only argument) - time_level_loss (BaseLoss): The loss function to use at - each time step. Defaults to MSELoss(). (keyword-only argument) - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. Defaults to False. - (keyword-only argument) - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). (keyword-only argument) - cut_prev (bool): Whether to cut the previous time level contribution - to `residuum_fn`. Defaults to False. - cut_next (bool): Whether to cut the next time level contribution - to `residuum_fn`. Defaults to False. - time_level_weights (array[float], optional): An array of length - `num_rollout_steps` that contains the weights for each time - step. Defaults to None, which means that all time steps have the - same weight (=1.0). + **Arguments:** + + - `data_trajectories`: The batch of trajectories to slice. This must be + a PyTree of Arrays who have at least two leading axes: a batch-axis + and a time axis. For example, the zeroth axis can be associated with + multiple initial conditions or constitutive parameters and the first + axis represents all temporal snapshots. A PyTree can also just be an + array. You can provide additional leafs in the PyTree, e.g., for the + corresponding constitutive parameters etc. Make sure that the + emulator has the corresponding signature. + - `ref_stepper`: For compatibility with other configurations; not used. + - `residuum_fn`: The residuum function to use for the configuration. + Must have the signature `residuum_fn(u_next: PyTree, u_prev: PyTree) + -> residuum: PyTree`. + - `optimizer`: The optimizer to use for training. For example, this can + be `optax.adam(LEARNING_RATE)`. Also use this to supply an optimizer + with learning rate decay, for example + `optax.adam(optax.exponential_decay(...))`. If your learning rate + decay is designed for a certain number of update steps, make sure + that it aligns with `num_training_steps`. + - `callback_fn`: A callback to use during training. Defaults to None. + - `num_training_steps`: The number of training steps to perform. + - `batch_size`: The batch size to use for training. Batches are + randomly sampled across both multiple trajectories, but also over + different windows within one trajectory. + - `num_rollout_steps`: The number of time steps to autoregressively roll + out the model during training. + - `time_level_loss`: The loss function to use at each time step. + - `cut_bptt`: Whether to cut the backpropagation through time (BPTT), + i.e., insert a `jax.lax.stop_gradient` into the autoregressive + network main chain. + - `cut_bptt_every`: The frequency at which to cut the BPTT. + Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after + each step). + - `cut_prev`: Whether to cut the previous time level contribution + to `residuum_fn`. + - `cut_next`: Whether to cut the next time level contribution + to `residuum_fn`. + - `time_level_weights: An array of length `num_rollout_steps` that + contains the weights for each time step. Defaults to None, which + means that all time steps have the same weight (=1.0). + + !!! info + * Under reverse-mode automatic differentiation memory usage grows + linearly with `num_rollout_steps`. """ trajectory_sub_stacker = TrajectorySubStacker( data_trajectories, diff --git a/trainax/trainer/_supervised.py b/trainax/trainer/_supervised.py index c7381a3..4c2c090 100644 --- a/trainax/trainer/_supervised.py +++ b/trainax/trainer/_supervised.py @@ -42,56 +42,49 @@ def __init__( long as `num_rollout_steps + 1` (the additional step is for the initial condition). - Args: - data_trajectories (PyTree[Float[Array, "num_samples trj_len ..."]]): - The batch of trajectories to slice. This must be a PyTree of - Arrays who have at least two leading axes: a batch-axis and a - time axis. For example, the zeroth axis can be associated with - multiple initial conditions or constitutive parameters and the - first axis represents all temporal snapshots. A PyTree can also - just be an array. You can provide additional leafs in the - PyTree, e.g., for the corresponding constitutive parameters etc. - Make sure that the emulator has the corresponding signature. - ref_stepper (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) - residuum_fn (eqx.Module): For compatibility with other - configurations; not used. (keyword-only argument) - optimizer (optax.GradientTransformation): The optimizer to use for - training. For example, this can be `optax.adam(LEARNING_RATE)`. - Also use this to supply an optimizer with learning rate decay, - for example `optax.adam(optax.exponential_decay(...))`. If your - learning rate decay is designed for a certain number of update - steps, make sure that it aligns with `num_training_steps`. - (keyword-only argument) - callback_fn (BaseCallback, optional): A callback to use during - training. Defaults to None. (keyword-only argument) - num_training_steps (int): The number of training steps to perform. - (keyword-only argument) - batch_size (int): The batch size to use for training. Batches are - randomly sampled across both multiple trajectories, but also over - different windows within one trajectory. (keyword-only) - num_rollout_steps (int): The number of time steps to - autoregressively roll out the model. Defaults to 1 (which is one-step - supervised training). (keyword-only argument) - time_level_loss (BaseLoss): The loss function to use at - each time step. Defaults to MSELoss(). (keyword-only argument) - cut_bptt (bool): Whether to cut the backpropagation through time - (BPTT), i.e., insert a `jax.lax.stop_gradient` into the - autoregressive network main chain. Defaults to False. - (keyword-only argument) - cut_bptt_every (int): The frequency at which to cut the BPTT. - Only relevant if `cut_bptt` is True. Defaults to 1 (meaning - after each step). (keyword-only argument) - time_level_weights (array[float], optional): An array of length - `num_rollout_steps` that contains the weights for each time - step. Defaults to None, which means that all time steps have the - same weight (=1.0). (keyword-only argument) - do_sub_stacking (bool): Whether to use sub-stacking. If `False`, then - the given reference trajectory will not be sliced into windows - of length `num_rollout_steps + 1`. Defaults to True. - (keyword-only argument) + **Arguments:** - Info: + - `data_trajectories`: The batch of trajectories to slice. This must be + a PyTree of Arrays who have at least two leading axes: a batch-axis + and a time axis. For example, the zeroth axis can be associated with + multiple initial conditions or constitutive parameters and the first + axis represents all temporal snapshots. A PyTree can also just be an + array. You can provide additional leafs in the PyTree, e.g., for the + corresponding constitutive parameters etc. Make sure that the + emulator has the corresponding signature. + - `ref_stepper`: For compatibility with other configurations; not used. + - `residuum_fn`: For compatibility with other configurations; not used. + - `optimizer`: The optimizer to use for training. For example, this can + be `optax.adam(LEARNING_RATE)`. Also use this to supply an optimizer + with learning rate decay, for example + `optax.adam(optax.exponential_decay(...))`. If your learning rate + decay is designed for a certain number of update steps, make sure + that it aligns with `num_training_steps`. + - `callback_fn`: A callback to use during training. Defaults to None. + - `num_training_steps`: The number of training steps to perform. + - `batch_size: The batch size to use for training. Batches are + randomly sampled across both multiple trajectories, but also over + different windows within one trajectory. + - `num_rollout_steps`: The number of time steps to autoregressively roll + out the model. Defaults to 1 (which is one-step supervised + training). + - `time_level_loss`: The loss function to use at each time step. + Defaults to MSELoss(). + - `cut_bptt`: Whether to cut the backpropagation through time + (BPTT), i.e., insert a `jax.lax.stop_gradient` into the + autoregressive network main chain. + - `cut_bptt_every`: The frequency at which to cut the BPTT. + Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after + each step). + - `time_level_weights`: An array of length `num_rollout_steps` that + contains the weights for each time step. Defaults to None, which + means that all time steps have the same weight (=1.0). (keyword-only + argument) + - `do_sub_stacking`: Whether to use sub-stacking. If `False`, then the + given reference trajectory will not be sliced into windows of length + `num_rollout_steps + 1`. + + !!! info * Under reverse-mode automatic differentiation memory usage grows linearly with `num_rollout_steps`. """