Skip to content

Commit

Permalink
log metrics refactor, additional log metrics test case
Browse files (browse the repository at this point in the history)
  • Loading branch information
Lilferrit committed Sep 19, 2024
1 parent 9d4109e commit 86747d9
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 40 deletions.
2 changes: 1 addition & 1 deletion casanovo/data/ms_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .. import __version__
from ..config import Config
from .pep_spec_match import PepSpecMatch
from .psm import PepSpecMatch


class MztabWriter:
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from . import evaluate
from .. import config
from ..data import ms_io, pep_spec_match
from ..data import ms_io, psm

logger = logging.getLogger("casanovo")

Expand Down Expand Up @@ -914,7 +914,7 @@ def on_predict_batch_end(
if len(peptide) == 0:
continue
self.out_writer.psms.append(
pep_spec_match.PepSpecMatch(
psm.PepSpecMatch(
sequence=peptide,
spectrum_id=tuple(spectrum_i),
peptide_score=peptide_score,
Expand Down
33 changes: 13 additions & 20 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,32 +163,25 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None:
Index containing the annotated spectra used to generate model
predictions
"""
model_output = []
spectrum_annotations = []
psms = iter(self.writer.psms)
curr_psm = next(psms, None)
seq_pred = []
seq_true = []
pred_idx = 0

with test_index as t_ind:
for i in range(t_ind.n_spectra):
if curr_psm is None:
break

spectrum_annotations.append(t_ind[i][4])
if curr_psm.spectrum_id == t_ind.get_spectrum_id(i):
model_output.append(curr_psm.sequence)
curr_psm = next(psms, None)
for true_idx in range(t_ind.n_spectra):
seq_true.append(t_ind[true_idx][4])
if pred_idx < len(self.writer.psms) and self.writer.psms[
pred_idx
].spectrum_id == t_ind.get_spectrum_id(true_idx):
seq_pred.append(self.writer.psms[pred_idx].sequence)
pred_idx += 1
else:
model_output.append("")

if curr_psm is not None:
logger.warning(
"Some spectra were not matched to annotations during evaluation."
)
seq_pred.append("")

aa_precision, _, pep_precision = aa_match_metrics(
*aa_match_batch(
spectrum_annotations,
model_output,
seq_true,
seq_pred,
depthcharge.masses.PeptideMass().masses,
)
)
Expand Down
2 changes: 1 addition & 1 deletion casanovo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import psutil
import torch

from .data.pep_spec_match import PepSpecMatch
from .data.psm import PepSpecMatch


SCORE_BINS = [0.0, 0.5, 0.9, 0.95, 0.99]
Expand Down
49 changes: 33 additions & 16 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch

from casanovo.config import Config
from casanovo.data.psm import PepSpecMatch
from casanovo.denovo.model_runner import ModelRunner
from casanovo.data.ms_io import PepSpecMatch


def test_initialize_model(tmp_path, mgf_small):
Expand Down Expand Up @@ -287,8 +287,6 @@ def test_evaluate(


def test_log_metrics(monkeypatch, tiny_config):
TEST_EPSILON = 10**-5

def get_mock_index(psm_list):
mock_test_index = unittest.mock.MagicMock()
mock_test_index.__enter__.return_value = mock_test_index
Expand Down Expand Up @@ -338,8 +336,8 @@ def get_mock_psm(sequence, spectrum_id):

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert abs(pep_precision - 100) < TEST_EPSILON
assert abs(aa_precision - 100) < TEST_EPSILON
assert pep_precision == pytest.approx(100)
assert aa_precision == pytest.approx(100)

# Test 50% peptide precision (one wrong)
infer_psms = [
Expand All @@ -358,8 +356,8 @@ def get_mock_psm(sequence, spectrum_id):

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert abs(pep_precision - 100 * (1 / 2)) < TEST_EPSILON
assert abs(aa_precision - 100 * (5 / 6)) < TEST_EPSILON
assert pep_precision == pytest.approx(100 * (1 / 2))
assert aa_precision == pytest.approx(100 * (5 / 6))

# Test skipped spectra
act_psms = [
Expand All @@ -383,8 +381,32 @@ def get_mock_psm(sequence, spectrum_id):

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert abs(pep_precision - 100 * (4 / 5)) < TEST_EPSILON
assert abs(aa_precision - 100 * (12 / 13)) < TEST_EPSILON
assert pep_precision == pytest.approx(100 * (4 / 5))
assert aa_precision == pytest.approx(100 * (12 / 13))

act_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
get_mock_psm("PEI", ("foo", "index=3")),
get_mock_psm("PEG", ("foo", "index=4")),
get_mock_psm("PEA", ("foo", "index=5")),
]

infer_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
get_mock_psm("PEI", ("foo", "index=3")),
get_mock_psm("PEG", ("foo", "index=4")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(100 * (4 / 5))
assert aa_precision == pytest.approx(100 * (12 / 13))

# Test un-inferred spectra
act_psms = [
Expand All @@ -408,10 +430,5 @@ def get_mock_psm(sequence, spectrum_id):

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
last_warning_msg = mock_logger.warning.call_args_list[-1][0][0]
assert abs(pep_precision) < TEST_EPSILON
assert abs(aa_precision - 100) < TEST_EPSILON
assert (
last_warning_msg
== "Some spectra were not matched to annotations during evaluation."
)
assert pep_precision == pytest.approx(0)
assert aa_precision == pytest.approx(100)

0 comments on commit 86747d9

Please sign in to comment.