diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bd936a4..c4559d33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Precursor charges are exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification. +### Removed + +- Removed the `save_top_k` option from the Casanovo config, the model with the lowest validation loss during training will now be saved to a fixed filename `.best.ckpt`. + ## [4.2.1] - 2024-06-25 ### Fixed diff --git a/casanovo/config.py b/casanovo/config.py index 792da35a..453f7b15 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -18,6 +18,7 @@ _config_deprecated = dict( every_n_train_steps="val_check_interval", max_iters="cosine_schedule_period_iters", + save_top_k=None, ) @@ -74,7 +75,6 @@ class Config: top_match=int, max_epochs=int, num_sanity_val_steps=int, - save_top_k=int, model_save_folder_path=str, val_check_interval=int, calculate_precision=bool, @@ -96,12 +96,20 @@ def __init__(self, config_file: Optional[str] = None): # Remap deprecated config entries. for old, new in _config_deprecated.items(): if old in self._user_config: - self._user_config[new] = self._user_config.pop(old) - warnings.warn( - f"Deprecated config option '{old}' remapped to " - f"'{new}'", - DeprecationWarning, - ) + if new is not None: + self._user_config[new] = self._user_config.pop(old) + warning_msg = ( + f"Deprecated config option '{old}' " + f"remapped to '{new}'" + ) + else: + del self._user_config[old] + warning_msg = ( + f"Deprecated config option '{old}' " + "is no longer in use" + ) + + warnings.warn(warning_msg, DeprecationWarning) # Check for missing entries in config file. config_missing = self._params.keys() - self._user_config.keys() if len(config_missing) > 0: diff --git a/casanovo/config.yaml b/casanovo/config.yaml index c7186ff7..3beb5f30 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -42,9 +42,6 @@ random_seed: 454 n_log: 1 # Tensorboard directory to use for keeping track of training metrics. tb_summarywriter: -# Save the top k model checkpoints during training. -1 saves all, and leaving -# this field empty saves none. -save_top_k: 5 # Path to saved checkpoints. model_save_folder_path: "" # Model validation and checkpointing frequency in training steps. diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index d5acacb3..07ec4166 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -36,12 +36,15 @@ class ModelRunner: model_filename : str, optional The model filename is required for eval and de novo modes, but not for training a model from scratch. + output_rootname : str, optional + The rootname for all output files (e.g. checkpoints or results) """ def __init__( self, config: Config, model_filename: Optional[str] = None, + output_rootname: Optional[str] = None, ) -> None: """Initialize a ModelRunner""" self.config = config @@ -54,24 +57,23 @@ def __init__( self.loaders = None self.writer = None + best_filename = "best" + if output_rootname is not None: + best_filename = f"{output_rootname}.{best_filename}" + # Configure checkpoints. self.callbacks = [ ModelCheckpoint( dirpath=config.model_save_folder_path, save_on_train_epoch_end=True, - ) + ), + ModelCheckpoint( + dirpath=config.model_save_folder_path, + monitor="valid_CELoss", + filename=best_filename, + ), ] - if config.save_top_k is not None: - self.callbacks.append( - ModelCheckpoint( - dirpath=config.model_save_folder_path, - monitor="valid_CELoss", - mode="min", - save_top_k=config.save_top_k, - ) - ) - def __enter__(self): """Enter the context manager""" self.tmp_dir = tempfile.TemporaryDirectory() diff --git a/tests/conftest.py b/tests/conftest.py index 02a6d0f2..54c8d03c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -208,7 +208,6 @@ def tiny_config(tmp_path): "random_seed": 454, "n_log": 1, "tb_summarywriter": None, - "save_top_k": 5, "n_peaks": 150, "min_mz": 50.0, "max_mz": 2500.0, diff --git a/tests/test_integration.py b/tests/test_integration.py index a622b188..3c6718f5 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -36,8 +36,10 @@ def test_train_and_run( result = run(train_args) model_file = tmp_path / "epoch=19-step=20.ckpt" + best_model = tmp_path / "best.ckpt" assert result.exit_code == 0 assert model_file.exists() + assert best_model.exists() # Try evaluating: eval_args = [ diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index fbe10eee..14d2abe5 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -43,9 +43,18 @@ def test_deprecated(tmp_path, tiny_config): filename = str(tmp_path / "config_deprecated.yml") with open(tiny_config, "r") as f_in, open(filename, "w") as f_out: cfg = yaml.safe_load(f_in) - # Insert deprecated config option. + # Insert remapped deprecated config option. cfg["max_iters"] = 1 yaml.safe_dump(cfg, f_out) with pytest.warns(DeprecationWarning): Config(filename) + + with open(tiny_config, "r") as f_in, open(filename, "w") as f_out: + cfg = yaml.safe_load(f_in) + # Insert non-remapped deprecated config option. + cfg["save_top_k"] = 5 + yaml.safe_dump(cfg, f_out) + + with pytest.warns(DeprecationWarning): + Config(filename) diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py index 2d0513bd..b39c3758 100644 --- a/tests/unit_tests/test_runner.py +++ b/tests/unit_tests/test_runner.py @@ -168,8 +168,8 @@ def test_save_final_model(tmp_path, mgf_small, tiny_config): # Test checkpoint saving when val_check_interval is not a factor of training steps config.val_check_interval = 15 - validation_file = tmp_path / "epoch=14-step=15.ckpt" - with ModelRunner(config) as runner: + validation_file = tmp_path / "foobar.best.ckpt" + with ModelRunner(config, output_rootname="foobar") as runner: runner.train([mgf_small], [mgf_small]) assert model_file.exists()