Commit

Markdown rework (#2)
* Add markdown based documentation

* Show all of base loss

* Add constructor doc

* Update base config to markdown

* Format line length

* Rework supervised configuration doc

* Translate residuum config to markdown based syntax

* Translate leftover configuration to new docstrings

* Translate supervised trainer

* Update the other two convenience trainers

* Update documentation to mixer components

* Update docs of general trainer

* Display additional methods in documentation

* Allow supplying optstate

* Allow option to deactivate the tqdm progress meter

* Add documentation to all callbacks
Ceyron authored Aug 2, 2024
1 parent 9c9c363 commit c12b4d3
Showing 22 changed files with 736 additions and 539 deletions.
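
The thrust of this commit is converting every docstring from the old Google-style `Args:`/`Returns:` blocks to the markdown-based syntax that mkdocstrings renders directly, including `!!! info`/`!!! tip`/`!!! warning` admonitions. As a purely illustrative sketch (the function and its arguments below are hypothetical and not part of trainax), the target style looks like this:

```python
def rollout(stepper, initial_state, num_steps: int = 10):
    """
    Autoregressively roll out `stepper` starting from `initial_state`.

    !!! info
        Hypothetical example that only demonstrates the docstring style
        introduced by this commit.

    **Arguments:**

    - `stepper`: The neural emulator that is applied repeatedly.
    - `initial_state`: The state the rollout starts from.
    - `num_steps`: How many autoregressive steps to perform.

    **Returns:**

    - A list of all intermediate states, including `initial_state`.
    """
    trajectory = [initial_state]
    state = initial_state
    for _ in range(num_steps):
        state = stepper(state)
        trajectory.append(state)
    return trajectory
```

Written this way, the API pages can show proper bullet lists and call-out boxes instead of plain-text `Args:` blocks.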
2 changes: 2 additions & 0 deletions docs/api/general_trainer.md
@@ -5,3 +5,5 @@
members:
- __init__
- __call__
- full_loss
- step_fn
6 changes: 1 addition & 5 deletions docs/api/loss.md
@@ -32,8 +32,4 @@

---

::: trainax.loss.BaseLoss
options:
members:
- __init__
- __call__
::: trainax.loss.BaseLoss
165 changes: 99 additions & 66 deletions trainax/_general_trainer.py
@@ -1,9 +1,9 @@
from typing import Optional
from typing import Optional, Union

import equinox as eqx
import jax.numpy as jnp
import optax
from jaxtyping import PRNGKeyArray, PyTree
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
from tqdm.autonotebook import tqdm

from ._mixer import PermutationMixer, TrajectorySubStacker
@@ -34,31 +34,35 @@ def __init__(
callback_fn: Optional[BaseCallback] = None,
):
"""
Abstract training for an autoregressive neural emulator on a collection of
trajectories.
The length of (sub-)trajectories returned by `trajectory_sub_stacker` must
match the required length of reference for the used `loss_configuration`.
Args:
trajectory_sub_stacker (TrajectorySubStacker): A callable that takes a
list of indices and returns a collection of (sub-)trajectories.
loss_configuration (BaseConfiguration): A configuration that defines the
loss function to be minimized.
ref_stepper (eqx.Module, optional): A reference stepper that is used to
compute the residuum. Supply this if the loss configuration requires
a reference stepper. Defaults to None.
residuum_fn (eqx.Module, optional): A residuum function that computes the
discrete residuum between two consecutive states. Supply this if the
loss configuration requires a residuum function. Defaults to None.
optimizer (optax.GradientTransformation): An optimizer that updates the
parameters of the stepper given the gradient.
num_minibatches (int): The number of minibatches to train on. This equals
the total number of update steps performed. The number of epochs is
determined based on this and the `batch_size`.
batch_size (int): The size of each batch.
callback_fn (BaseCallback, optional): A callback function that is called
at the end of each minibatch. Defaults to None.
Abstract training for an autoregressive neural emulator on a collection
of trajectories.
!!! info
The length of (sub-)trajectories returned by
`trajectory_sub_stacker` must match the required length of reference
for the used `loss_configuration`.
**Arguments:**
- `trajectory_sub_stacker`: A callable that takes a
list of indices and returns a collection of (sub-)trajectories.
- `loss_configuration`: A configuration that defines the
loss function to be minimized.
- `ref_stepper`: A reference stepper that is used to
compute the residuum. Supply this if the loss configuration requires
a reference stepper.
- `residuum_fn`: A residuum function that computes the
discrete residuum between two consecutive states. Supply this if the
loss configuration requires a residuum function. Defaults to None.
- `optimizer`: An optimizer that updates the
parameters of the stepper given the gradient.
- `num_minibatches`: The number of minibatches to train on. This equals
the total number of update steps performed. The number of epochs is
automatically determined based on this and the `batch_size`.
- `batch_size`: The size of each minibatch, i.e., how many samples are
included within.
- `callback_fn`: A callback function that is called
at the end of each minibatch. Defaults to None.
"""
self.trajectory_sub_stacker = trajectory_sub_stacker
self.loss_configuration = loss_configuration
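
For orientation, here is a minimal construction sketch of the trainer this docstring documents. Every name and keyword below is an assumption pieced together from identifiers visible in this diff (`TrajectorySubStacker`, the loss configurations, the general trainer itself) rather than a verified trainax API:

```python
import jax.numpy as jnp
import optax

import trainax  # assumed package layout

# Dummy data: 20 trajectories of length 51, one channel, 64 grid points.
data = jnp.zeros((20, 51, 1, 64))

# Hypothetical names/keywords throughout; check the actual docs.
trainer = trainax.GeneralTrainer(
    trainax.TrajectorySubStacker(data, sub_trajectory_len=3),
    loss_configuration=trainax.configuration.Supervised(num_rollout_steps=2),
    optimizer=optax.adam(3e-4),
    num_minibatches=10_000,
    batch_size=32,
)
```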
@@ -75,6 +79,17 @@ def full_loss(
) -> float:
"""
Compute the loss on the entire dataset.
!!! warning
This can lead to out of memory errors if the dataset is too large.
**Arguments:**
- `stepper`: The stepper to compute the loss with.
**Returns:**
- The loss value.
"""
return self.loss_configuration(
stepper,
@@ -87,19 +102,22 @@ def step_fn(
self,
stepper: eqx.Module,
opt_state: optax.OptState,
data: PyTree,
data: PyTree[Float[Array, "batch_size sub_trj_len ..."]],
) -> tuple[eqx.Module, optax.OptState, float]:
"""
Perform a single update step to the `stepper`'s parameters.
Args:
stepper (eqx.Module): The stepper to be updated.
opt_state (optax.OptState): The optimizer state.
data (PyTree): The data for the current minibatch.
**Arguments:**
- `stepper`: The equinox module to be updated.
- `opt_state`: The current optimizer state.
- `data`: The data for the current minibatch.
**Returns:**
Returns:
tuple[eqx.Module, optax.OptState, float]: The updated stepper, the
updated optimizer state, and the loss value.
- The updated equinox module
- The updated optimizer state
- The loss value
"""
loss, grad = eqx.filter_value_and_grad(
lambda m: self.loss_configuration(
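
The hunk above cuts off inside `eqx.filter_value_and_grad`. For readers following along, this is what a generic Equinox/Optax single-step update of the kind described by the docstring typically looks like; it is a sketch under the assumption that the loss configuration is callable as `loss_configuration(stepper, data)`, not the literal continuation of the file:

```python
import equinox as eqx


def step_fn(stepper, opt_state, data, *, loss_configuration, optimizer):
    # Differentiate the loss w.r.t. the array leaves of the Equinox module.
    loss, grad = eqx.filter_value_and_grad(
        lambda m: loss_configuration(m, data)
    )(stepper)
    # Let optax turn the raw gradient into parameter updates.
    updates, new_opt_state = optimizer.update(
        grad, opt_state, eqx.filter(stepper, eqx.is_array)
    )
    new_stepper = eqx.apply_updates(stepper, updates)
    return new_stepper, new_opt_state, loss
```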
@@ -114,16 +132,20 @@ def __call__(
self,
stepper: eqx.Module,
key: PRNGKeyArray,
opt_state: Optional[optax.OptState] = None,
*,
return_loss_history: bool = True,
record_loss_every: int = 1,
):
spawn_tqdm: bool = True,
) -> Union[
tuple[eqx.Module, Float[Array, "num_minibatches"]],
eqx.Module,
tuple[eqx.Module, Float[Array, "num_minibatches"], list],
tuple[eqx.Module, list],
]:
"""
Perform the entire training of an autoregressive neural emulator
`stepper`.
This method spawns a `tqdm` progress meter showing the current update
step and displaying the epoch with its respective minibatch counter.
Perform the entire training of an autoregressive neural emulator whose
initial state is given as `stepper`.
This method's return signature depends on the presence of a callback
function. If a callback function is provided, this function has at max
@@ -133,25 +155,32 @@
values of the callback function at each minibatch. If no callback
function is provided, this function has at max two return values. The
first return value is the trained stepper, and the second return value
is the loss history.
is the loss history. If `return_loss_history` is set to `False`, the
loss history will not be returned.
**Arguments:**
- `stepper`: The equinox Module to be trained.
- `key`: The random key to be used for shuffling the minibatches.
- `opt_state`: The initial optimizer state. Defaults to None, meaning
a fresh optimizer state is initialized from the stepper's parameters.
- `return_loss_history`: Whether to return the loss history.
- `record_loss_every`: Record the loss every `record_loss_every`
minibatches. Defaults to 1, i.e., record every minibatch.
- `spawn_tqdm`: Whether to spawn the tqdm progress meter showing the
current update step and displaying the epoch with its respective
minibatch counter.
Args:
stepper (eqx.Module): The stepper to be trained.
key (PRNGKeyArray): The random key to be used for shuffling the
minibatches.
return_loss_history (bool, optional): Whether to return the loss
history. Defaults to True.
record_loss_every (int, optional): Record the loss every
`record_loss_every` minibatches. Defaults to 1.
**Returns:**
Returns:
Varying, see above.
- Varying, see above. It will always return the trained stepper as the
first return value.
Tip:
!!! tip
You can use `equinox.filter_vmap` to train multiple networks (of the
same architecture) at the same time. For example, if your GPU
is not fully utilized yet, this will give you an init-seed
statistic basically for free.
same architecture) at the same time. For example, if your GPU is not
fully utilized yet, this will give you an init-seed statistic
basically for free.
"""
loss_history = []
if self.callback_fn is not None:
@@ -164,15 +193,17 @@ def __call__(
shuffle_key=key,
)

p_meter = tqdm(
total=self.num_minibatches,
desc=f"E: {0:05d}, B: {0:05d}",
)
if spawn_tqdm:
p_meter = tqdm(
total=self.num_minibatches,
desc=f"E: {0:05d}, B: {0:05d}",
)

update_fn = eqx.filter_jit(self.step_fn)

trained_stepper = stepper
opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))
if opt_state is None:
opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))

for update_i in range(self.num_minibatches):
batch_indices, (expoch_id, batch_id) = mixer(update_i, return_info=True)
@@ -185,13 +216,15 @@
)
if update_i % record_loss_every == 0:
loss_history.append(loss)
p_meter.update(1)
if spawn_tqdm:
p_meter.update(1)

p_meter.set_description(
f"E: {expoch_id:05d}, B: {batch_id:05d}",
)
p_meter.set_description(
f"E: {expoch_id:05d}, B: {batch_id:05d}",
)

p_meter.close()
if spawn_tqdm:
p_meter.close()

loss_history = jnp.array(loss_history)
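
Putting the two new `__call__` options from this commit together with `full_loss`, a usage sketch (continuing the hypothetical `trainer` from the construction example above, with `stepper` being any Equinox-module emulator) could read:

```python
import equinox as eqx
import jax

key = jax.random.PRNGKey(0)

# Loss over the entire dataset; the docstring warns this can exhaust memory.
print(trainer.full_loss(stepper))

# Default behaviour: fresh optimizer state, tqdm progress meter, loss history.
trained_stepper, loss_history = trainer(stepper, key)

# Resume training from an explicit optimizer state and without a progress
# meter: the two options introduced in this commit.  Here the state is simply
# re-created from the trainer's own optimizer; in practice it would be carried
# over from a previous run.
opt_state = trainer.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))
trained_stepper = trainer(
    trained_stepper,
    key,
    opt_state=opt_state,
    return_loss_history=False,
    spawn_tqdm=False,
)
```

As the docstring's tip notes, wrapping such a call in `eqx.filter_vmap` over a batch of initialization seeds is a cheap way to train an ensemble of identically architected emulators at once.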
