Skip to content

Commit

Permalink
Update ms2library to use new ms2deepscore methods
Browse files (browse the repository at this point in the history)
  • Loading branch information
niekdejonge committed Jun 24, 2024
1 parent d3a5f35 commit 501c86a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ms2query/ms2library.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,7 @@
from gensim.models import Word2Vec
from matchms.Spectrum import Spectrum
from ms2deepscore import MS2DeepScore
from ms2deepscore.models import load_model as load_ms2ds_model
from ms2deepscore.models import load_model as load_ms2ds_model, compute_embedding_array
from onnxruntime import InferenceSession
from spec2vec.vector_operations import calc_vector, cosine_similarity_matrix
from tqdm import tqdm
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(self,
self.s2v_embeddings: pd.DataFrame = load_df_from_parquet_file(s2v_embeddings_file_name)
self.ms2ds_embeddings: pd.DataFrame = load_df_from_parquet_file(ms2ds_embeddings_file_name)

assert self.ms2ds_model.base.output_shape[1] == self.ms2ds_embeddings.shape[1], \
assert self.ms2ds_model.model_settings.embedding_dim == self.ms2ds_embeddings.shape[1], \
"Dimension of pre-computed MS2DeepScore embeddings does not fit given model."

# load precursor mz's
Expand Down Expand Up @@ -251,7 +251,9 @@ def _get_all_ms2ds_scores(self, query_spectrum: Spectrum
spectra in the ms2ds embeddings file.
"""
ms2ds = MS2DeepScore(self.ms2ds_model, progress_bar=False)
query_embeddings = ms2ds.calculate_vectors([query_spectrum])

query_embeddings = compute_embedding_array(self.ms2ds_model, [query_spectrum])

library_ms2ds_embeddings_numpy = self.ms2ds_embeddings.to_numpy()
ms2ds_scores = cosine_similarity_matrix(library_ms2ds_embeddings_numpy,
query_embeddings)
Expand Down Expand Up @@ -397,7 +399,7 @@ def get_ms2query_model_prediction_single_spectrum(
def select_files_for_ms2query(file_names: List[str], files_to_select=None):
"""Selects the files needed for MS2Library based on their file extensions. """
dict_with_file_extensions = \
{"sqlite": ".sqlite", "s2v_model": ".model", "ms2ds_model": ".hdf5",
{"sqlite": ".sqlite", "s2v_model": ".model", "ms2ds_model": ".pt",
"ms2query_model": ".onnx", "s2v_embeddings": "s2v_embeddings.parquet",
"ms2ds_embeddings": "ms2ds_embeddings.parquet"}
if files_to_select is not None:
Expand Down

0 comments on commit 501c86a

Please sign in to comment.