Skip to content

Commit

Permalink
Save best model (#365)
Browse files Browse the repository at this point in the history
* save best model

* save best model

* updated unit tests

* remove save top k config item

* added save_top_k to deprecated config options

* changelog entry

* test case, formatting

* requested changes
  • Loading branch information
Lilferrit authored Aug 12, 2024
1 parent a46f995 commit ba58668
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 25 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

- Precursor charges are exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.

### Removed

- Removed the `save_top_k` option from the Casanovo config; the model with the lowest validation loss during training will now be saved to a fixed filename `<output_root>.best.ckpt`.

## [4.2.1] - 2024-06-25

### Fixed
Expand Down
22 changes: 15 additions & 7 deletions casanovo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_config_deprecated = dict(
every_n_train_steps="val_check_interval",
max_iters="cosine_schedule_period_iters",
save_top_k=None,
)


Expand Down Expand Up @@ -74,7 +75,6 @@ class Config:
top_match=int,
max_epochs=int,
num_sanity_val_steps=int,
save_top_k=int,
model_save_folder_path=str,
val_check_interval=int,
calculate_precision=bool,
Expand All @@ -96,12 +96,20 @@ def __init__(self, config_file: Optional[str] = None):
# Remap deprecated config entries.
for old, new in _config_deprecated.items():
if old in self._user_config:
self._user_config[new] = self._user_config.pop(old)
warnings.warn(
f"Deprecated config option '{old}' remapped to "
f"'{new}'",
DeprecationWarning,
)
if new is not None:
self._user_config[new] = self._user_config.pop(old)
warning_msg = (
f"Deprecated config option '{old}' "
f"remapped to '{new}'"
)
else:
del self._user_config[old]
warning_msg = (
f"Deprecated config option '{old}' "
"is no longer in use"
)

warnings.warn(warning_msg, DeprecationWarning)
# Check for missing entries in config file.
config_missing = self._params.keys() - self._user_config.keys()
if len(config_missing) > 0:
Expand Down
3 changes: 0 additions & 3 deletions casanovo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ random_seed: 454
n_log: 1
# Tensorboard directory to use for keeping track of training metrics.
tb_summarywriter:
# Save the top k model checkpoints during training. -1 saves all, and leaving
# this field empty saves none.
save_top_k: 5
# Path to saved checkpoints.
model_save_folder_path: ""
# Model validation and checkpointing frequency in training steps.
Expand Down
24 changes: 13 additions & 11 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ class ModelRunner:
model_filename : str, optional
The model filename is required for eval and de novo modes,
but not for training a model from scratch.
output_rootname : str, optional
The rootname for all output files (e.g. checkpoints or results)
"""

def __init__(
self,
config: Config,
model_filename: Optional[str] = None,
output_rootname: Optional[str] = None,
) -> None:
"""Initialize a ModelRunner"""
self.config = config
Expand All @@ -54,24 +57,23 @@ def __init__(
self.loaders = None
self.writer = None

best_filename = "best"
if output_rootname is not None:
best_filename = f"{output_rootname}.{best_filename}"

# Configure checkpoints.
self.callbacks = [
ModelCheckpoint(
dirpath=config.model_save_folder_path,
save_on_train_epoch_end=True,
)
),
ModelCheckpoint(
dirpath=config.model_save_folder_path,
monitor="valid_CELoss",
filename=best_filename,
),
]

if config.save_top_k is not None:
self.callbacks.append(
ModelCheckpoint(
dirpath=config.model_save_folder_path,
monitor="valid_CELoss",
mode="min",
save_top_k=config.save_top_k,
)
)

def __enter__(self):
"""Enter the context manager"""
self.tmp_dir = tempfile.TemporaryDirectory()
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def tiny_config(tmp_path):
"random_seed": 454,
"n_log": 1,
"tb_summarywriter": None,
"save_top_k": 5,
"n_peaks": 150,
"min_mz": 50.0,
"max_mz": 2500.0,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def test_train_and_run(

result = run(train_args)
model_file = tmp_path / "epoch=19-step=20.ckpt"
best_model = tmp_path / "best.ckpt"
assert result.exit_code == 0
assert model_file.exists()
assert best_model.exists()

# Try evaluating:
eval_args = [
Expand Down
11 changes: 10 additions & 1 deletion tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,18 @@ def test_deprecated(tmp_path, tiny_config):
filename = str(tmp_path / "config_deprecated.yml")
with open(tiny_config, "r") as f_in, open(filename, "w") as f_out:
cfg = yaml.safe_load(f_in)
# Insert deprecated config option.
# Insert remapped deprecated config option.
cfg["max_iters"] = 1
yaml.safe_dump(cfg, f_out)

with pytest.warns(DeprecationWarning):
Config(filename)

with open(tiny_config, "r") as f_in, open(filename, "w") as f_out:
cfg = yaml.safe_load(f_in)
# Insert non-remapped deprecated config option.
cfg["save_top_k"] = 5
yaml.safe_dump(cfg, f_out)

with pytest.warns(DeprecationWarning):
Config(filename)
4 changes: 2 additions & 2 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def test_save_final_model(tmp_path, mgf_small, tiny_config):

# Test checkpoint saving when val_check_interval is not a factor of training steps
config.val_check_interval = 15
validation_file = tmp_path / "epoch=14-step=15.ckpt"
with ModelRunner(config) as runner:
validation_file = tmp_path / "foobar.best.ckpt"
with ModelRunner(config, output_rootname="foobar") as runner:
runner.train([mgf_small], [mgf_small])

assert model_file.exists()
Expand Down

0 comments on commit ba58668

Please sign in to comment.