Skip to content

Commit

Permalink
unit tests fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Oct 2, 2024
1 parent b2bdce2 commit b8c5bf7
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import collections
import copy
import datetime
import functools
import hashlib
import heapq
import io
import itertools
import os
import pathlib
import platform
Expand Down Expand Up @@ -864,6 +864,9 @@ def test_spectrum_id_mgf(mgf_small, tmp_path):
data_module.valid_dataset,
data_module.test_dataset,
]:
for batch in dataset:
print(batch)

for i, (filename, scan_id) in enumerate(
[
(mgf_small, "0"),
Expand Down Expand Up @@ -901,19 +904,27 @@ def test_spectrum_id_mzml(mzml_small, tmp_path):

def test_train_val_step_functions():
"""Test train and validation step functions operating on batches."""
tokenizer = depthcharge.tokenizers.peptides.MskbPeptideTokenizer()
model = Spec2Pep(
n_beams=1,
residues="massivekb",
min_peptide_len=4,
train_label_smoothing=0.1,
tokenizer=tokenizer,
)
spectra = torch.zeros(1, 5, 2)
precursors = torch.tensor([[469.25364, 2.0, 235.63410]])
peptides = ["PEPK"]
batch = (spectra, precursors, peptides)

train_step_loss = model.training_step(batch)
val_step_loss = model.validation_step(batch)
batch = {
"mz_array": torch.zeros(1, 5),
"intensity_array": torch.zeros(1, 5),
"precursor_mz": torch.tensor(235.63410).unsqueeze(0),
"precursor_charge": torch.tensor(2.0).unsqueeze(0),
"seq": tokenizer.tokenize(["PEPK"]),
}
train_batch = {key: val.unsqueeze(0) for key, val in batch.items()}
val_batch = copy.deepcopy(train_batch)

train_step_loss = model.training_step(train_batch)
val_step_loss = model.validation_step(val_batch)

# Check if valid loss value returned
assert train_step_loss > 0
Expand All @@ -929,12 +940,8 @@ def test_run_map(mgf_small):
out_writer = ms_io.MztabWriter("dummy.mztab")
# Set peak file by base file name only.
out_writer.set_ms_run([os.path.basename(mgf_small.name)])
assert os.path.basename(mgf_small.name) not in out_writer._run_map
assert os.path.abspath(mgf_small.name) in out_writer._run_map
# Set peak file by full path.
out_writer.set_ms_run([os.path.abspath(mgf_small.name)])
assert os.path.basename(mgf_small.name) not in out_writer._run_map
assert os.path.abspath(mgf_small.name) in out_writer._run_map
assert mgf_small.name in out_writer._run_map
assert os.path.abspath(mgf_small.name) not in out_writer._run_map


def test_check_dir(tmp_path):
Expand Down

0 comments on commit b8c5bf7

Please sign in to comment.