Skip to content

Commit

Permalink
comprehensive configure integration test, shared file io commands\
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Dec 6, 2024
1 parent 4f76f5c commit a18473d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 53 deletions.
84 changes: 34 additions & 50 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,13 @@
click.rich_click.SHOW_ARGUMENTS = True


class _SharedParams(click.RichCommand):
"""Options shared between most Casanovo commands"""
class _SharedFileIOParams(click.RichCommand):
"""File IO options shared between most Casanovo commands"""

def __init__(self, *args, **kwargs) -> None:
"""Define shared options."""
super().__init__(*args, **kwargs)
self.params += [
click.Option(
("-m", "--model"),
help="""
Either the model weights (.ckpt file) or a URL pointing to the
model weights file. If not provided, Casanovo will try to
download the latest release automatically.
""",
),
click.Option(
("-d", "--output_dir"),
help="The destination directory for output files.",
Expand All @@ -77,30 +69,44 @@ def __init__(self, *args, **kwargs) -> None:
type=click.Path(dir_okay=False),
),
click.Option(
("-c", "--config"),
help="""
The YAML configuration file overriding the default options.
""",
type=click.Path(exists=True, dir_okay=False),
("-f", "--force_overwrite"),
help="Whether to overwrite output files.",
is_flag=True,
show_default=True,
default=False,
),
click.Option(
("-v", "--verbosity"),
help="""
Set the verbosity of console logging messages. Log files are
always set to 'debug'.
""",
help=(
"Set the verbosity of console logging messages."
" Log files are always set to 'debug'."
),
type=click.Choice(
["debug", "info", "warning", "error"],
case_sensitive=False,
),
default="info",
),
]


class _SharedParams(_SharedFileIOParams):
"""Options shared between main Casanovo commands"""

def __init__(self, *args, **kwargs) -> None:
"""Define shared options."""
super().__init__(*args, **kwargs)
self.params += [
click.Option(
("-f", "--force_overwrite"),
help="Whether to overwrite output files.",
is_flag=True,
show_default=True,
default=False,
("-m", "--model"),
help="""Either the model weights (.ckpt file) or a URL pointing to
the model weights file. If not provided, Casanovo will try to
download the latest release automatically.""",
),
click.Option(
("-c", "--config"),
help="The YAML configuration file overriding the default options.",
type=click.Path(exists=True, dir_okay=False),
),
]

Expand Down Expand Up @@ -335,38 +341,16 @@ def version() -> None:
sys.stdout.write("\n".join(versions) + "\n")


@main.command()
@click.option(
"-d",
"--output_dir",
help="The destination directory for log and config file.",
type=click.Path(dir_okay=True),
required=False,
)
@click.option(
"-o",
"--output_root",
help="The root name for log and config file.",
type=click.Path(dir_okay=False),
required=False,
)
@click.option(
"-f",
"--force_overwrite",
help="Whether to overwrite output files.",
is_flag=True,
show_default=True,
default=False,
)
@main.command(cls=_SharedFileIOParams)
def configure(
output_dir: str, output_root: str, force_overwrite: bool
output_dir: str, output_root: str, verbosity: str, force_overwrite: bool
) -> None:
"""Generate a Casanovo configuration file to customize.
The casanovo configuration file is in the YAML format.
"""
output_path, _ = _setup_output(
output_dir, output_root, force_overwrite, "info"
output_dir, output_root, force_overwrite, verbosity
)
config_fname = output_root if output_root is not None else "casanovo"
config_fname = Path(config_fname).with_suffix(".yaml")
Expand All @@ -375,7 +359,7 @@ def configure(

config_path = str(output_path / config_fname)
Config.copy_default(config_path)
logger.info(f"Wrote {config_path}\n")
logger.info(f"Wrote {config_path}")


def setup_logging(
Expand Down
28 changes: 27 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import subprocess
import yaml
from pathlib import Path

import pyteomics.mztab
Expand Down Expand Up @@ -215,7 +216,7 @@ def test_train_and_run(
assert output_filename.is_file()


def test_auxilliary_cli(tmp_path, monkeypatch):
def test_auxilliary_cli(tmp_path, mgf_small, monkeypatch):
"""Test the secondary CLI commands"""
run = functools.partial(
CliRunner().invoke, casanovo.main, catch_exceptions=False
Expand All @@ -231,5 +232,30 @@ def test_auxilliary_cli(tmp_path, monkeypatch):
with pytest.raises(FileExistsError):
run(["configure", "-o", "test.yaml"])

with open("casanovo.yaml") as f_in:
config = yaml.safe_load(f_in)

config["max_epochs"] = 1
config["n_layers"] = 1

with open("small.yaml", "w") as f_out:
yaml.dump(config, f_out)

train_args = [
"train",
"--validation_peak_path",
str(mgf_small),
"--config",
"small.yaml",
"--output_dir",
str(tmp_path),
"--output_root",
"train",
str(mgf_small),
]

result = run(train_args)
assert result.exit_code == 0

res = run("version")
assert res.output
2 changes: 0 additions & 2 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,6 @@ def test_calc_match_score():


def test_digest_fasta_cleave(tiny_fasta_file, residues_dict):

# No missed cleavages
expected_normal = [
"ATSIPAR",
Expand Down Expand Up @@ -1092,7 +1091,6 @@ def test_get_candidates(tiny_fasta_file, residues_dict):


def test_get_candidates_isotope_error(tiny_fasta_file, residues_dict):

# Tide isotope error windows for 496.2, 2+:
# 0: [980.481617, 1000.289326]
# 1: [979.491114, 999.278813]
Expand Down

0 comments on commit a18473d

Please sign in to comment.