Fix ModelCheckpoint alternating between versioned and unversioned file (#19064)
awaelchli authored Nov 27, 2023
1 parent e30401a commit 482da0a
Showing 3 changed files with 26 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863))
 
 
+- Fixed an edge case where `ModelCheckpoint` would alternate between versioned and unversioned filename ([#19064](https://github.com/Lightning-AI/lightning/pull/19064))
+
 
 ## [2.1.2] - 2023-11-15
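
For context: with `monitor=None`, `save_top_k=1`, and a fixed `filename`, the callback keeps a single checkpoint and overwrites it on every save. Before this fix it could write `model.ckpt` on one save and `model-v1.ckpt` on the next, alternating between the two names. The sketch below is illustrative only (the `dirpath` value and the Trainer limits are chosen to mirror the regression test added in this commit), not part of the commit itself.

# Minimal sketch of a configuration that previously produced alternating
# "model.ckpt" / "model-v1.ckpt" files (mirrors the regression test below).
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

checkpoint = ModelCheckpoint(
    dirpath="checkpoints/",  # illustrative output directory
    monitor=None,            # no monitored metric: only the latest checkpoint is kept
    save_top_k=1,            # keep a single file
    filename="model",        # fixed name, so a name clash triggers the -vN suffix
)
trainer = Trainer(
    callbacks=checkpoint,
    limit_train_batches=1,
    limit_val_batches=0,
    max_epochs=3,
    logger=False,
)
trainer.fit(BoringModel())
# After the fix, "checkpoints/" contains only "model.ckpt" at the end of every epoch.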

4 changes: 2 additions & 2 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -703,7 +703,7 @@ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Di
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path)
# set the best model path before saving because it will be part of the state.
previous, self.best_model_path = self.best_model_path, filepath
self._save_checkpoint(trainer, filepath)
@@ -773,7 +773,7 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
"""Checks if the previous checkpoint should be deleted.
A checkpoint won't be deleted if any of the cases apply:
- The previous checkpoint is the same as the current checkpoint
- The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
- The previous checkpoint is not in the current checkpoint directory and the filesystem is local
- The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
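
The one-line change above passes the current `best_model_path` into `_get_metric_interpolated_filepath_name`, so the name-collision check ignores the file that is about to be replaced and the callback keeps reusing the unversioned name. The following is a minimal sketch of that versioning idea; the function `interpolated_filepath` and its `existing_best` parameter are illustrative stand-ins, not the library's actual internals.

import os
import tempfile

def interpolated_filepath(dirpath: str, base_name: str, existing_best: str = "") -> str:
    # Sketch: bump a "-vN" suffix only when the clash is with a file other than the
    # previous best checkpoint, which will be removed once the new one is saved.
    filepath = os.path.join(dirpath, f"{base_name}.ckpt")
    version = 1
    while os.path.exists(filepath) and filepath != existing_best:
        filepath = os.path.join(dirpath, f"{base_name}-v{version}.ckpt")
        version += 1
    return filepath

with tempfile.TemporaryDirectory() as tmp:
    best = os.path.join(tmp, "model.ckpt")
    open(best, "w").close()  # pretend the previous best checkpoint is on disk
    # Ignoring the previous best (the old behaviour) forces a versioned name:
    assert interpolated_filepath(tmp, "model").endswith("model-v1.ckpt")
    # Passing the previous best (this fix) keeps the stable, unversioned name:
    assert interpolated_filepath(tmp, "model", existing_best=best) == best
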
22 changes: 22 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -583,6 +583,28 @@ def test_none_monitor_top_k(tmpdir):
     ModelCheckpoint(dirpath=tmpdir, save_top_k=1)
 
 
+def test_none_monitor_not_alternating(tmp_path):
+    """Regression test for the case where the callback saved alternating `model.ckpt` and `model-v1.ckpt` files."""
+
+    class ListDirModel(BoringModel):
+        def on_train_epoch_start(self):
+            if self.current_epoch > 0:
+                assert os.listdir(tmp_path) == ["model.ckpt"]
+
+    model = ListDirModel()
+    model_checkpoint = ModelCheckpoint(dirpath=tmp_path, monitor=None, save_top_k=1, filename="model")
+    trainer = Trainer(
+        callbacks=model_checkpoint,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        max_epochs=3,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        logger=False,
+    )
+    trainer.fit(model)
+
+
 def test_invalid_every_n_epochs(tmpdir):
     """Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument."""
     with pytest.raises(MisconfigurationException, match=r".*Must be >= 0"):
