Skip to content

Commit

Permalink
Merge branch 'dev' into dev_db_search
Browse files Browse the repository at this point in the history
  • Loading branch information
VarunAnanth2003 authored Sep 17, 2024
2 parents 4e696b4 + aefa73c commit 9a24817
Show file tree
Hide file tree
Showing 12 changed files with 599 additions and 297 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Changed

- Removed the `evaluate` sub-command, and all model evaluation functionality has been moved to the `sequence` command using the new `--evaluate` flag.
- The `--output` option has been split into two options, `--output_dir` and `--output_root`.
- The `--validation_peak_path` is now optional when training; if `--validation_peak_path` is not set then the `train_peak_path` will also be used for validation.

### Fixed

- Precursor charges are exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.

### Removed

- Removed the `save_top_k` option from the Casanovo config, the model with the lowest validation loss during training will now be saved to a fixed filename `<output_root>.best.ckpt`.
- Removed the `save_top_k` option from the Casanovo config, the model with the lowest validation loss during training will now be saved to a fixed filename `<output_root>.best.ckpt`.
- The `model_save_folder_path` config option has been eliminated; model checkpoints will now be saved to `--output_dir` during training.

## [4.2.1] - 2024-06-25

Expand Down
182 changes: 136 additions & 46 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import urllib.parse
import warnings
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, List

warnings.formatwarning = lambda message, category, *args, **kwargs: (
f"{category.__name__}: {message}"
Expand Down Expand Up @@ -67,8 +67,13 @@ def __init__(self, *args, **kwargs) -> None:
""",
),
click.Option(
("-o", "--output"),
help="The mzTab file to which results will be written.",
("-d", "--output_dir"),
help="The destination directory for output files",
type=click.Path(dir_okay=True),
),
click.Option(
("-o", "--output_root"),
help="The root name for all output files",
type=click.Path(dir_okay=False),
),
click.Option(
Expand All @@ -90,6 +95,13 @@ def __init__(self, *args, **kwargs) -> None:
),
default="info",
),
click.Option(
("-f", "--force_overwrite"),
help="Whether to overwrite output files.",
is_flag=True,
show_default=True,
default=False,
),
]


Expand Down Expand Up @@ -144,8 +156,10 @@ def sequence(
peak_path: Tuple[str],
model: Optional[str],
config: Optional[str],
output: Optional[str],
output_dir: Optional[str],
output_root: Optional[str],
verbosity: str,
force_overwrite: bool,
evaluate: bool,
) -> None:
"""De novo sequence peptides from tandem mass spectra.
Expand All @@ -154,18 +168,33 @@ def sequence(
to sequence peptides. If evaluate is set to True PEAK_PATH must be
one or more annotated MGF file.
"""
output = setup_logging(output, verbosity)
config, model = setup_model(model, config, output, False)
output_path, output_root_name = _setup_output(
output_dir, output_root, force_overwrite, verbosity
)
utils.check_dir_file_exists(output_path, f"{output_root}.mztab")
config, model = setup_model(
model, config, output_path, output_root_name, False
)
start_time = time.time()
with ModelRunner(config, model) as runner:
with ModelRunner(
config,
model,
output_path,
output_root_name if output_root is not None else None,
False,
) as runner:
logger.info(
"Sequencing %speptides from:",
"and evaluating " if evaluate else "",
)
for peak_file in peak_path:
logger.info(" %s", peak_file)

runner.predict(peak_path, output, evaluate=evaluate)
runner.predict(
peak_path,
str((output_path / output_root).with_suffix(".mztab")),
evaluate=evaluate,
)
psms = runner.writer.psms
utils.log_sequencing_report(
psms, start_time=start_time, end_time=time.time()
Expand Down Expand Up @@ -230,31 +259,46 @@ def db_search(
An annotated MGF file for validation, like from MassIVE-KB. Use this
option multiple times to specify multiple files.
""",
required=True,
required=False,
multiple=True,
type=click.Path(exists=True, dir_okay=False),
)
def train(
train_peak_path: Tuple[str],
validation_peak_path: Tuple[str],
validation_peak_path: Optional[Tuple[str]],
model: Optional[str],
config: Optional[str],
output: Optional[str],
output_dir: Optional[str],
output_root: Optional[str],
verbosity: str,
force_overwrite: bool,
) -> None:
"""Train a Casanovo model on your own data.
TRAIN_PEAK_PATH must be one or more annoated MGF files, such as those
provided by MassIVE-KB, from which to train a new Casnovo model.
"""
output = setup_logging(output, verbosity)
config, model = setup_model(model, config, output, True)
output_path, output_root_name = _setup_output(
output_dir, output_root, force_overwrite, verbosity
)
config, model = setup_model(
model, config, output_path, output_root_name, True
)
start_time = time.time()
with ModelRunner(config, model) as runner:
with ModelRunner(
config,
model,
output_path,
output_root_name if output_root is not None else None,
not force_overwrite,
) as runner:
logger.info("Training a model from:")
for peak_file in train_peak_path:
logger.info(" %s", peak_file)

if len(validation_peak_path) == 0:
validation_peak_path = train_peak_path

logger.info("Using the following validation files:")
for peak_file in validation_peak_path:
logger.info(" %s", peak_file)
Expand Down Expand Up @@ -294,7 +338,7 @@ def configure(output: str) -> None:


def setup_logging(
output: Optional[str],
log_file_path: Path,
verbosity: str,
) -> Path:
"""Set up the logger.
Expand All @@ -303,21 +347,11 @@ def setup_logging(
Parameters
----------
output : Optional[str]
The provided output file name.
log_file_path: Path
The log file path.
verbosity : str
The logging level to use in the console.
Return
------
output : Path
The output file path.
"""
if output is None:
output = f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"

output = Path(output).expanduser().resolve()

logging_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
Expand All @@ -344,9 +378,7 @@ def setup_logging(
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
warnings_logger.addHandler(console_handler)
file_handler = logging.FileHandler(
output.with_suffix(".log"), encoding="utf8"
)
file_handler = logging.FileHandler(log_file_path, encoding="utf8")
file_handler.setFormatter(log_formatter)
root_logger.addHandler(file_handler)
warnings_logger.addHandler(file_handler)
Expand All @@ -363,33 +395,38 @@ def setup_logging(
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

return output


def setup_model(
model: Optional[str],
config: Optional[str],
output: Optional[Path],
model: str | None,
config: str | None,
output_dir: Path | str,
output_root_name: str,
is_train: bool,
) -> Config:
"""Setup Casanovo for most commands.
) -> Tuple[Config, Path | None]:
"""Setup Casanovo config and resolve model weights (.ckpt) path
Parameters
----------
model : Optional[str]
The provided model weights file.
config : Optional[str]
The provided configuration file.
output : Optional[Path]
The provided output file name.
model : str | None
May be a file system path, a URL pointing to a .ckpt file, or None.
If `model` is a URL the weights will be downloaded and cached from
`model`. If `model` is `None` the weights from the latest matching
official release will be used (downloaded and cached).
config : str | None
Config file path. If None the default config will be used.
output_dir: : Path | str
The path to the output directory.
output_root_name : str,
The base name for the output files.
is_train : bool
Are we training? If not, we need to retrieve weights when the model is
None.
Return
------
config : Config
The parsed configuration
Tuple[Config, Path]
Initialized Casanovo config, local path to model weights if any (may be
`None` if training using random starting weights).
"""
# Read parameters from the config file.
config = Config(config)
Expand Down Expand Up @@ -429,7 +466,8 @@ def setup_model(
logger.info("Casanovo version %s", str(__version__))
logger.debug("model = %s", model)
logger.debug("config = %s", config.file)
logger.debug("output = %s", output)
logger.debug("output directory = %s", output_dir)
logger.debug("output root name = %s", output_root_name)
for key, value in config.items():
logger.debug("%s = %s", str(key), str(value))

Expand Down Expand Up @@ -533,6 +571,58 @@ def _get_model_weights(cache_dir: Path) -> str:
)


def _setup_output(
output_dir: str | None,
output_root: str | None,
overwrite: bool,
verbosity: str,
) -> Tuple[Path, str]:
"""
Set up the output directory, output file root name, and logging.
Parameters:
-----------
output_dir : str | None
The path to the output directory. If `None`, the output directory will
be resolved to the current working directory.
output_root : str | None
The base name for the output files. If `None` the output root name will
be resolved to casanovo_<current date and time>
overwrite: bool
Whether to overwrite log file if it already exists in the output
directory.
verbosity : str
The verbosity level for logging.
Returns:
--------
Tuple[Path, str]
A tuple containing the resolved output directory and root name for
output files.
"""
if output_root is None:
output_root = (
f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
)

if output_dir is None:
output_path = Path.cwd()
else:
output_path = Path(output_dir).expanduser().resolve()
if not output_path.is_dir():
output_path.mkdir(parents=True)
logger.warning(
"Target output directory %s does not exists, so it will be created.",
output_path,
)

if not overwrite:
utils.check_dir_file_exists(output_path, f"{output_root}.log")

setup_logging((output_path / output_root).with_suffix(".log"), verbosity)
return output_path, output_root


def _get_weights_from_url(
file_url: str,
cache_dir: Path,
Expand Down
4 changes: 2 additions & 2 deletions casanovo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
max_iters="cosine_schedule_period_iters",
max_length="max_peptide_len",
save_top_k=None,
model_save_folder_path=None,
)


Expand Down Expand Up @@ -64,7 +65,7 @@ class Config:
max_peptide_len=int,
residues=dict,
n_log=int,
tb_summarywriter=str,
tb_summarywriter=bool,
train_label_smoothing=float,
warmup_iters=int,
cosine_schedule_period_iters=int,
Expand All @@ -76,7 +77,6 @@ class Config:
top_match=int,
max_epochs=int,
num_sanity_val_steps=int,
model_save_folder_path=str,
val_check_interval=int,
calculate_precision=bool,
accelerator=str,
Expand Down
6 changes: 2 additions & 4 deletions casanovo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,8 @@ random_seed: 454
# OUTPUT OPTIONS
# Logging frequency in training steps.
n_log: 1
# Tensorboard directory to use for keeping track of training metrics.
tb_summarywriter:
# Path to saved checkpoints.
model_save_folder_path: ""
# Whether to create tensorboard directory
tb_summarywriter: false
# Model validation and checkpointing frequency in training steps.
val_check_interval: 50_000

Expand Down
Loading

0 comments on commit 9a24817

Please sign in to comment.