Call configure_model() in LM.load_from_checkpoint() (#19036)
awaelchli authored Nov 21, 2023
1 parent aebac09 commit 49caddd
Showing 5 changed files with 32 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `Trainer.fit()` loop no longer calls `LightningModule.train()` at the start; it now preserves the user's configuration of frozen layers ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))
 
 
+- The `LightningModule.load_from_checkpoint()` function now calls `.configure_model()` on the model if it is overridden, to ensure all layers can be loaded from the checkpoint ([#19036](https://github.com/Lightning-AI/lightning/pull/19036))
+
+
 ### Deprecated
 
 - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))
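To illustrate the changelog entry above: with this commit, a module whose layers are created in `configure_model()` rather than `__init__()` can be restored directly with `load_from_checkpoint()`. The sketch below is not part of the commit; `DemoModel`, its layer sizes, and the checkpoint path are illustrative assumptions.

```python
import torch
import lightning.pytorch as pl


class DemoModel(pl.LightningModule):
    """Hypothetical module whose layer is built in configure_model(), not __init__()."""

    def __init__(self):
        super().__init__()
        self.layer = None  # created lazily by configure_model()

    def configure_model(self):
        # Build the layer only once so the hook stays idempotent.
        if self.layer is None:
            self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# After this commit, load_from_checkpoint() runs configure_model() right after
# instantiating the class, so `self.layer` exists before the state dict is applied:
# model = DemoModel.load_from_checkpoint("path/to/checkpoint.ckpt")
```

Before this change, the same call could fail because the checkpoint contains weights for layers that had not been created yet at load time.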
3 changes: 2 additions & 1 deletion src/lightning/pytorch/core/hooks.py
@@ -336,7 +336,8 @@ def configure_model(self) -> None:
         :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.
 
         This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
-        implementation of this hook is idempotent.
+        implementation of this hook is **idempotent**, i.e., after the first time the hook is called, subsequent calls
+        to it should be a no-op.
         """

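Because the hook can now run more than once per process (for example once via `load_from_checkpoint()` and again when the `Trainer` sets the model up), the docstring change above stresses idempotency. A minimal sketch of the guard pattern, with illustrative class and attribute names:

```python
import torch
import lightning.pytorch as pl


class GuardedDemoModel(pl.LightningModule):
    """Hypothetical module: configure_model() may be invoked several times."""

    def __init__(self):
        super().__init__()
        self.block = None

    def configure_model(self):
        if self.block is not None:
            return  # already built -- subsequent calls are a no-op
        self.block = torch.nn.TransformerEncoderLayer(d_model=128, nhead=4)


model = GuardedDemoModel()
model.configure_model()
first = model.block
model.configure_model()          # e.g. called again by a later stage
assert model.block is first      # the guard keeps the hook idempotent
```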
7 changes: 7 additions & 0 deletions src/lightning/pytorch/core/module.py
@@ -1519,6 +1519,13 @@ def load_from_checkpoint(
             **class** to call it instead of the :class:`LightningModule` instance, or a
             ``TypeError`` will be raised.
 
+        Note:
+            To ensure all layers can be loaded from the checkpoint, this function will call
+            :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` directly after instantiating the
+            model if this hook is overridden in your LightningModule. However, note that ``load_from_checkpoint`` does
+            not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this
+            case, consider loading through the Trainer via ``.fit(ckpt_path=...)``.
+
         Example::
 
             # load weights without mapping ...
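The note added above points users with sharded or very large checkpoints to the Trainer path instead. A brief sketch of that alternative, reusing the hypothetical `DemoModel` from the earlier example and a placeholder checkpoint path:

```python
import lightning.pytorch as pl

model = DemoModel()  # hypothetical module defined in the earlier sketch
trainer = pl.Trainer(accelerator="auto", devices="auto")

# The Trainer sets the model up (including configure_model()) on each rank and then
# restores the checkpoint state, which is the route the note above recommends for
# sharded checkpoints or models too large to load in a single process.
trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")
```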
4 changes: 4 additions & 0 deletions src/lightning/pytorch/core/saving.py
@@ -38,6 +38,7 @@
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning.pytorch.utilities.migration import pl_legacy_patch
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
+from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.parsing import parse_class_init_keys
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn

@@ -157,6 +158,9 @@ def _load_state(
     obj = cls(**_cls_kwargs)
 
     if isinstance(obj, pl.LightningModule):
+        if is_overridden("configure_model", obj):
+            obj.configure_model()
+
         # give model a chance to load something
         obj.on_load_checkpoint(checkpoint)
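The guard relies on Lightning's `is_overridden` helper, which checks whether the instance's class provides its own implementation of a hook. A small sketch of how the new check in `_load_state` decides whether to call the hook (both demo classes are illustrative):

```python
import lightning.pytorch as pl
from lightning.pytorch.utilities.model_helpers import is_overridden


class PlainModel(pl.LightningModule):
    pass  # inherits the default configure_model() from LightningModule


class LazyModel(pl.LightningModule):
    def configure_model(self):
        pass  # a real module would build its layers here


# Mirrors the condition added in _load_state(): the hook is only invoked
# when the user actually overrode it.
assert not is_overridden("configure_model", PlainModel())
assert is_overridden("configure_model", LazyModel())
```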
20 changes: 16 additions & 4 deletions tests/tests_pytorch/models/test_hooks.py
@@ -877,13 +877,22 @@ def test_trainer_datamodule_hook_system(tmpdir):
     assert called == expected
 
 
-def test_load_from_checkpoint_hook_calls(tmpdir):
+@pytest.mark.parametrize("override_configure_model", [True, False])
+def test_load_from_checkpoint_hook_calls(override_configure_model, tmpdir):
     class CustomHookedDataModule(HookedDataModule):
         def state_dict(self):
             return {"foo": "bar"}
 
+    class CustomHookedModel(HookedModel):
+        pass
+
+    if not override_configure_model:
+        CustomHookedModel.configure_model = None
+
     lm_called, ldm_called = [], []
-    model = HookedModel(lm_called)
+    model = CustomHookedModel(lm_called)
+    assert is_overridden("configure_model", model) == override_configure_model
+
     datamodule = CustomHookedDataModule(ldm_called)
     trainer = Trainer()
     trainer.strategy.connect(model)
@@ -908,9 +917,12 @@ def state_dict(self):
     assert ldm_called == [{"name": "state_dict"}]
 
     lm_called, ldm_called = [], []
-    _ = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
+    _ = CustomHookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
     _ = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
-    assert lm_called == [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
+
+    expected_lm_called = [{"name": "configure_model"}] if override_configure_model else []
+    expected_lm_called += [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
+    assert lm_called == expected_lm_called
     assert ldm_called == [{"name": "load_state_dict", "args": (saved_ckpt[datamodule_state_dict_key],)}]


