diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 6ab5a2223cd59..d929941f0f154 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863))
 
+- Fixed an edge case where `ModelCheckpoint` would alternate between versioned and unversioned filenames ([#19064](https://github.com/Lightning-AI/lightning/pull/19064))
+
 
 ## [2.1.2] - 2023-11-15
 
diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 1d3ec47e8782e..cad170791591b 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -703,7 +703,7 @@ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Di
             rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
 
     def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
-        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path)
         # set the best model path before saving because it will be part of the state.
         previous, self.best_model_path = self.best_model_path, filepath
         self._save_checkpoint(trainer, filepath)
@@ -773,7 +773,7 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
         """Checks if the previous checkpoint should be deleted.
 
         A checkpoint won't be deleted if any of the cases apply:
-        - The previous checkpoint is the same as the current checkpoint
+        - The previous checkpoint is the same as the current checkpoint (meaning the old one was already overwritten by the new one)
         - The previous checkpoint is not in the current checkpoint directory and the filesystem is local
         - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
 
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 66764c78303ff..3d25556316ed4 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -583,6 +583,28 @@ def test_none_monitor_top_k(tmpdir):
     ModelCheckpoint(dirpath=tmpdir, save_top_k=1)
 
 
+def test_none_monitor_not_alternating(tmp_path):
+    """Regression test for the case where the callback saved alternating `model.ckpt` and `model-v1.ckpt` files."""
+
+    class ListDirModel(BoringModel):
+        def on_train_epoch_start(self):
+            if self.current_epoch > 0:
+                assert os.listdir(tmp_path) == ["model.ckpt"]
+
+    model = ListDirModel()
+    model_checkpoint = ModelCheckpoint(dirpath=tmp_path, monitor=None, save_top_k=1, filename="model")
+    trainer = Trainer(
+        callbacks=model_checkpoint,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        max_epochs=3,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        logger=False,
+    )
+    trainer.fit(model)
+
+
 def test_invalid_every_n_epochs(tmpdir):
     """Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument."""
     with pytest.raises(MisconfigurationException, match=r".*Must be >= 0"):
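
For context outside the patch itself, here is a minimal user-facing sketch of the scenario the new regression test exercises. It is an illustration only: it assumes the public `BoringModel` from `lightning.pytorch.demos.boring_classes` (a stand-in for the test-suite helper) and a local `checkpoints/` directory. After this fix, every epoch overwrites the same `model.ckpt` instead of alternating with `model-v1.ckpt`.

```python
# Minimal reproduction sketch (assumptions: BoringModel from
# lightning.pytorch.demos.boring_classes, a local "checkpoints" dirpath).
import os

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

# monitor=None with a fixed filename: the callback keeps exactly one checkpoint.
checkpoint = ModelCheckpoint(dirpath="checkpoints", monitor=None, save_top_k=1, filename="model")
trainer = Trainer(
    callbacks=[checkpoint],
    max_epochs=3,
    limit_train_batches=1,
    limit_val_batches=0,
    logger=False,
    enable_progress_bar=False,
)
trainer.fit(BoringModel())

# With the fix, only "model.ckpt" should remain; previously the directory could
# end up holding "model-v1.ckpt" on alternating epochs.
print(os.listdir("checkpoints"))
```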