Skip to content

Commit

Permalink
Merge pull request #237 from iomega/update_ms2deepscore_version
Browse files Browse the repository at this point in the history
Update ms2deepscore to 0.5.0
  • Loading branch information
niekdejonge authored Apr 11, 2024
2 parents a3ed83d + 0644750 commit eeed096
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 100 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## 1.4.0
### Changed
- Made compatible with MS2Deepscore 0.5.0

## 1.3.0
### Changed
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- pyarrow=12.0.1
- tensorflow=2.12.1
- scikit-learn=1.3.2
- ms2deepscore=0.4.0
- ms2deepscore=0.5.0
- pandas=2.0.3
- matplotlib=3.7.3
- skl2onnx=1.16.0
Expand Down
99 changes: 20 additions & 79 deletions ms2query/create_new_library/train_ms2deepscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,102 +4,43 @@
"""

import os
from typing import Dict, List, Optional
import numpy as np
import tensorflow as tf
from typing import List
from matchms import Spectrum
from matplotlib import pyplot as plt
from ms2deepscore import SpectrumBinner
from ms2deepscore.data_generators import DataGeneratorAllInchikeys
from ms2deepscore.models import SiameseModel
from tensorflow.keras.callbacks import ( # pylint: disable=import-error
EarlyStopping, ModelCheckpoint)
from tensorflow.keras.optimizers import Adam # pylint: disable=import-error
from ms2deepscore.train_new_model.train_ms2deepscore import (plot_history,
train_ms2ds_model)
from ms2query.create_new_library.calculate_tanimoto_scores import \
calculate_tanimoto_scores_unique_inchikey
from ms2query.create_new_library.split_data_for_training import \
split_spectra_on_inchikeys


def train_ms2ds_model(training_spectra,
validation_spectra,
tanimoto_df,
output_model_file_name,
epochs=150):
def train_ms2deepscore_wrapper(spectra: List[Spectrum],
output_model_file_name,
fraction_validation_spectra,
epochs,
ms2ds_history_file_name=None):
assert not os.path.isfile(output_model_file_name), "The MS2Deepscore output model file name already exists"
# assert len(validation_spectra) >= 100, \
# "Expected more validation spectra; too few validation spectra cause Keras to crash"
# Bin training spectra
training_spectra, validation_spectra = split_spectra_on_inchikeys(spectra,
fraction_validation_spectra)
tanimoto_score_df = calculate_tanimoto_scores_unique_inchikey(spectra, spectra)
spectrum_binner = SpectrumBinner(10000, mz_min=10.0, mz_max=1000.0, peak_scaling=0.5,
allowed_missing_percentage=100.0)
binned_spectrums_training = spectrum_binner.fit_transform(training_spectra)
# Bin validation spectra using the binner based on the training spectra.
# Peaks that do not occur in the training spectra will not be binned in the validation spectra.
binned_spectrums_val = spectrum_binner.transform(validation_spectra)

same_prob_bins = list(zip(np.linspace(0, 0.9, 10), np.linspace(0.1, 1, 10)))

training_generator = DataGeneratorAllInchikeys(
binned_spectrums_training,
selected_inchikeys=list({s.get("inchikey")[:14] for s in training_spectra}),
reference_scores_df=tanimoto_df,
dim=len(spectrum_binner.known_bins), # The number of bins created
same_prob_bins=same_prob_bins,
num_turns=2,
augment_noise_max=10,
augment_noise_intensity=0.01)

validation_generator = DataGeneratorAllInchikeys(
binned_spectrums_val,
selected_inchikeys=list({s.get("inchikey")[:14] for s in binned_spectrums_val}),
reference_scores_df=tanimoto_df,
dim=len(spectrum_binner.known_bins), # The number of bins created
same_prob_bins=same_prob_bins,
num_turns=10, # Number of pairs for each InChiKey14 during each epoch.
# To prevent data augmentation
augment_removal_max=0, augment_removal_intensity=0, augment_intensity=0, augment_noise_max=0, use_fixed_set=True
history = train_ms2ds_model(
binned_spectrums_training,
binned_spectrums_val,
spectrum_binner,
tanimoto_score_df,
output_model_file_name,
epochs=epochs,
base_dims=(500, 500),
embedding_dim=200,
)

model = SiameseModel(spectrum_binner, base_dims=(500, 500), embedding_dim=200, dropout_rate=0.2)

model.compile(loss='mse', optimizer=Adam(lr=0.001), metrics=["mae", tf.keras.metrics.RootMeanSquaredError()])

# Save best model and include early stopping
checkpointer = ModelCheckpoint(filepath=output_model_file_name, monitor='val_loss', mode="min", verbose=1, save_best_only=True)
earlystopper_scoring_net = EarlyStopping(monitor='val_loss', mode="min", patience=10, verbose=1)
# Fit model and save history
history = model.model.fit(training_generator, validation_data=validation_generator, epochs=epochs, verbose=1,
callbacks=[checkpointer, earlystopper_scoring_net])
model.load_weights(output_model_file_name)
model.save(output_model_file_name)
return history.history


def plot_history(history: Dict[str, List[float]],
file_name: Optional[str] = None):
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
if file_name:
plt.savefig(file_name)
else:
plt.show()


def train_ms2deepscore_wrapper(spectra: List[Spectrum],
output_model_file_name,
fraction_validation_spectra,
epochs,
ms2ds_history_file_name=None):
assert not os.path.isfile(output_model_file_name), "The MS2Deepscore output model file name already exists"
training_spectra, validation_spectra = split_spectra_on_inchikeys(spectra,
fraction_validation_spectra)
tanimoto_score_df = calculate_tanimoto_scores_unique_inchikey(spectra, spectra)
history = train_ms2ds_model(training_spectra, validation_spectra,
tanimoto_score_df, output_model_file_name,
epochs)
print(f"The training history is: {history}")
plot_history(history, ms2ds_history_file_name)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"h5py",
"pyarrow",
"scikit-learn",
"ms2deepscore<=0.4.0",
"ms2deepscore==0.5.0",
"gensim>=4.0.0",
"pandas",
"tqdm",
Expand Down
19 changes: 0 additions & 19 deletions tests/test_train_ms2deepscore.py

This file was deleted.

0 comments on commit eeed096

Please sign in to comment.