
Commit

The user can specify after which steps or epochs the average model is updated
Seppo Enarvi committed Jan 15, 2025
1 parent 11423fd commit efc77dc
Showing 2 changed files with 53 additions and 13 deletions.
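
As a rough sketch of the new interface (the step/epoch thresholds, the EMA decay, and the use of torch.optim.swa_utils.get_ema_avg_fn are illustrative assumptions, not part of this commit), the callback's update schedule can now be driven by user-supplied predicates:

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

# EMA weights kept on the CPU, refreshed every 10 optimizer steps once 100 steps have been taken.
ema_callback = WeightAveraging(
    device=torch.device("cpu"),
    avg_fn=get_ema_avg_fn(decay=0.999),
    update_on_step=lambda step: step >= 100 and step % 10 == 0,
)

# Equally weighted (SWA-style) average, updated once per epoch from epoch 5 onwards.
swa_callback = WeightAveraging(update_on_epoch=lambda epoch: epoch >= 5)

trainer = Trainer(max_epochs=10, callbacks=[ema_callback])

If neither predicate is given, the diff below falls back to updating after every optimizer step.
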
54 changes: 51 additions & 3 deletions src/lightning/pytorch/callbacks/weight_averaging.py
@@ -30,10 +30,22 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT


def _return_true(x: int) -> bool:
return True


def _return_false(x: int) -> bool:
return False


class WeightAveraging(Callback):
r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
(EMA) after each training step.
The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average
model should be updated. If neither function is provided, the average model will be updated after every optimizer
step.
During validation and after the training finishes, the current model parameters will be replaced with the averaged
values.
@@ -43,22 +55,39 @@ class WeightAveraging(Callback):
avg_fn: The averaging function used to update the parameters. The function must take in an
:class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
``None``, an equally weighted average will be used.
update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
model should be updated.
update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
should be updated.
"""

def __init__(
self,
device: torch.device | str | None = torch.device("cpu"),
device: torch.device | int | None = torch.device("cpu"),
avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
update_on_step: Callable[[int], bool] | None = None,
update_on_epoch: Callable[[int], bool] | None = None,
):
self._device = device
self._avg_fn = avg_fn

if (update_on_step is None) and (update_on_epoch is None):
self._update_on_step: Callable[[int], bool] = _return_true
self._update_on_epoch: Callable[[int], bool] = _return_false
else:
self._update_on_step = _return_false if update_on_step is None else update_on_step
self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch

self._average_model: AveragedModel | None = None

# Number of optimizer steps taken when the average model was last updated. Initializing this with zero ensures
# that the average model will be first updated after the first optimizer step, which takes place after N batches
# when using accumulate_grad_batches=N.
self._latest_update_step = 0
# The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
# negative value means that if update_on_epoch(0) returns True, the first update is after the first epoch.
self._latest_update_epoch = -1

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
"""Called when fit, validate, test, predict, or tune begins.
@@ -80,7 +109,7 @@ def on_train_batch_end(
) -> None:
"""Called when a training batch ends.
Updates the :class:`AveragedModel` parameters.
Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.
Args:
trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -90,11 +119,26 @@ def on_train_batch_end(
batch_idx: Index of the training batch.
"""
if trainer.global_step > self._latest_update_step:
if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
assert self._average_model is not None
self._average_model.update_parameters(pl_module)
self._latest_update_step = trainer.global_step

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when a training epoch ends.
Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
Args:
trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
"""
if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
assert self._average_model is not None
self._average_model.update_parameters(pl_module)
self._latest_update_epoch = trainer.current_epoch

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when training ends.
@@ -173,6 +217,7 @@ def on_save_checkpoint(
checkpoint: The checkpoint dictionary that will be saved.
"""
assert self._average_model is not None
rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
average_model_state = self._average_model.state_dict()
checkpoint["current_model_state"] = checkpoint["state_dict"]
@@ -196,6 +241,7 @@ def on_load_checkpoint(
checkpoint: The full checkpoint dictionary that got loaded by the Trainer.
"""
assert self._average_model is not None
if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
@@ -216,6 +262,7 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
"""
assert self._average_model is not None
average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
for average_param, current_param in zip(average_params, current_params):
@@ -230,6 +277,7 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
"""
assert self._average_model is not None
average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
for average_param, current_param in zip(average_params, current_params):
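
The checkpoint hooks in the hunks above save the averaged parameters into the checkpoint's regular state_dict and stash the raw training weights under current_model_state, alongside an averaging_state entry that is used when resuming. A minimal sketch of inspecting such a checkpoint offline, assuming a hypothetical file name:

import torch

# Hypothetical checkpoint produced by a Trainer run that used the WeightAveraging callback.
# weights_only=False because Lightning checkpoints contain more than plain tensors.
checkpoint = torch.load("example.ckpt", map_location="cpu", weights_only=False)

averaged_weights = checkpoint["state_dict"]  # averaged parameters, i.e. what gets loaded by default
raw_training_weights = checkpoint["current_model_state"]  # the un-averaged weights the optimizer was updating
averaging_state = checkpoint["averaging_state"]  # bookkeeping the callback needs to resume averaging
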
12 changes: 2 additions & 10 deletions tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import AbstractContextManager
from pathlib import Path
from typing import Any, Optional
from unittest import mock

import pytest
import torch
@@ -25,7 +23,6 @@
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from lightning.pytorch.strategies import Strategy
from tests_pytorch.helpers.runif import RunIf


@@ -209,10 +206,9 @@ def _train(
)

if crash_on_epoch is None:
with _backward_patch(trainer):
trainer.fit(model, ckpt_path=checkpoint_path)
trainer.fit(model, ckpt_path=checkpoint_path)
else:
with _backward_patch(trainer), pytest.raises(Exception, match="CRASH TEST"):
with pytest.raises(Exception, match="CRASH TEST"):
trainer.fit(model, ckpt_path=checkpoint_path)

assert trainer.lightning_module == model
@@ -230,7 +226,3 @@ def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False)
checkpoint_path = str(checkpoint_dir / checkpoint_names[0])

_train(tmp_path, strategy=strategy, devices=devices, checkpoint_path=checkpoint_path)


def _backward_patch(trainer: Trainer) -> AbstractContextManager:
return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
