Skip to content

Commit

Permalink
Use dummy spectra in test_train-all_models and optimize ms2deepscore …
Browse files Browse the repository at this point in the history
…settings for running fast
  • Loading branch information
niekdejonge committed Jun 24, 2024
1 parent 6c260d9 commit 626b5d9
Showing 1 changed file with 42 additions and 20 deletions.
62 changes: 42 additions & 20 deletions tests/test_train_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,52 @@
import numpy as np
import pytest
from ms2deepscore import SettingsMS2Deepscore

from ms2query.create_new_library.train_models import clean_and_train_models
import string
from matchms.Spectrum import Spectrum
from ms2query.create_new_library.train_models import train_all_models, SettingsTrainingModels
from ms2query.ms2library import MS2Library, create_library_object_from_one_dir


def create_test_spectra(num_of_unique_inchikeys):
    """Create dummy spectra, two for every unique fake inchikey.

    Each fake inchikey is built by repeating one uppercase ASCII letter
    (14 letters, a dash, 10 letters, then "-N"), so the number of unique
    inchikeys is limited by the length of the alphabet. Every spectrum has
    three peaks of equal intensity whose m/z values are shifted by the
    spectrum's position in the list, and a SMILES string of repeated "C"
    whose length also depends on that position (so the two spectra sharing
    an inchikey still get different SMILES strings).

    Args:
        num_of_unique_inchikeys: Number of distinct fake inchikeys to
            generate; must be between 0 and 26 (inclusive).

    Returns:
        A list with ``2 * num_of_unique_inchikeys`` Spectrum objects. The
        second half of the list reuses the inchikeys of the first half.

    Raises:
        ValueError: If ``num_of_unique_inchikeys`` is negative or larger
            than the number of uppercase ASCII letters. Without this check
            the slice below would silently return fewer letters than asked.
    """
    max_unique_inchikeys = len(string.ascii_uppercase)
    if not 0 <= num_of_unique_inchikeys <= max_unique_inchikeys:
        raise ValueError(
            f"num_of_unique_inchikeys must be between 0 and "
            f"{max_unique_inchikeys}, got {num_of_unique_inchikeys}"
        )

    mz, intens = 100.0, 0.1
    # Repeat the letters so that every inchikey is shared by two spectra.
    letters = list(string.ascii_uppercase[:num_of_unique_inchikeys]) * 2

    spectrums = []
    for i, letter in enumerate(letters):
        # Position-dependent shift keeps the peaks of every spectrum distinct.
        shift = (i + 1) * 1.0
        spectrums.append(
            Spectrum(
                mz=np.array([mz + shift, mz + 100 + shift, mz + 200 + shift]),
                intensities=np.array([intens, intens, intens]),
                metadata={
                    "precursor_mz": 111.1,
                    "inchikey": f"{14 * letter}-{10 * letter}-N",
                    "smiles": "C" * (i + 1),
                },
            )
        )
    return spectrums


# Integration test: trains all models into a "models" folder under tmp_path
# and asserts that an MS2Library object can be created from that folder.
# NOTE(review): this span is a rendered diff without +/- markers, so lines of
# the removed and the added version are interleaved. Presumably the signature
# taking path_to_general_test_files, the .mgf path and the
# clean_and_train_models(...) call are the removed side, and the
# create_test_spectra / train_all_models(...) lines are the added side --
# confirm against the repository before editing.
@pytest.mark.integration
def test_train_all_models(path_to_general_test_files, tmp_path):
path_to_test_spectra = os.path.join(path_to_general_test_files, "2000_negative_test_spectra.mgf")
def test_train_all_models(tmp_path):
# Generated dummy spectra (11 unique inchikeys) are used as input.
test_spectra = create_test_spectra(11)

models_folder = os.path.join(tmp_path, "models")
clean_and_train_models(path_to_test_spectra,
"negative",
models_folder,
{"ms2ds_fraction_validation_spectra": 2,
"ms2ds_training_settings": SettingsMS2Deepscore(
mz_bin_width=1.0,
epochs=2,
base_dims=(100, 100),
embedding_dim=50,
same_prob_bins=np.array([(0, 0.5), (0.5, 1.0)]),
average_pairs_per_bin=2,
batch_size=2),
"spec2vec_iterations": 2,
"ms2query_fraction_for_making_pairs": 400,
"add_compound_classes": False}
)
# The same dummy spectra are passed as both the first and the second
# positional argument; the settings use small values (2 epochs, batch size 2,
# 2 spec2vec iterations).
train_all_models(test_spectra, test_spectra, output_folder=models_folder,
settings=SettingsTrainingModels({"ms2ds_fraction_validation_spectra": 2,
"ms2ds_training_settings": SettingsMS2Deepscore(
mz_bin_width=1.0,
epochs=2,
base_dims=(100, 100),
embedding_dim=50,
same_prob_bins=np.array([(0, 1.0)]),
average_pairs_per_bin=2,
batch_size=2),
"spec2vec_iterations": 2,
"ms2query_fraction_for_making_pairs": 10,
"add_compound_classes": False}))
# Build a library object from the training output folder and check its type.
ms2library = create_library_object_from_one_dir(models_folder)
assert isinstance(ms2library, MS2Library)

0 comments on commit 626b5d9

Please sign in to comment.