From 79f6fcc383bcc0f165f9c51e4a32675bf6f2f8c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez=20Santiago?= <45119610+ireneisdoomed@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:34:54 +0000 Subject: [PATCH] feat(l2g)!: implement new training strategy splitting between EFO/gene pairs and with cross validation (#938) * feat(gold_standard): add traitFromSourceMappedId to schema * chore: adapt tests * feat(feature_matrix): consider `traitFromSourceMappedId` a static column * feat(feature_matrix): consider `traitFromSourceMappedId` an optional column * feat: update l2g config with best hyperparams * feat(trainer): new train runs when cross_validate=False * chore(model): add default hyperparams based on best params * chore: debug sweep, one single run * feat(trainer): new train runs when cross_validate=True * feat(cross_validate): sweep runs are now together * chore: pre-commit auto fixes [...] * chore: improve error message --- src/gentropy/config.py | 10 +- src/gentropy/l2g.py | 15 +- src/gentropy/method/l2g/model.py | 15 +- src/gentropy/method/l2g/trainer.py | 398 ++++++++++++++++++++--------- 4 files changed, 309 insertions(+), 129 deletions(-) diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 65fdb5897..9c454d41b 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -264,20 +264,24 @@ class LocusToGeneConfig(StepConfig): "geneCount500kb", "proteinGeneCount500kb", "credibleSetConfidence", - # "isProteinCoding", ] ) hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, - "max_depth": 5, - "loss": "log_loss", + "max_depth": 10, + "ccp_alpha": 0, + "learning_rate": 0.1, + "min_samples_leaf": 5, + "min_samples_split": 5, + "subsample": 1, } ) wandb_run_name: str | None = None hf_hub_repo_id: str | None = "opentargets/locus_to_gene" hf_model_commit_message: str | None = "chore: update model" download_from_hub: bool = True + cross_validate: bool = True _target_: str = "gentropy.l2g.LocusToGeneStep" diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index ff0d47f58..3b73a377d 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -100,11 +100,12 @@ class LocusToGeneStep: def __init__( self, session: Session, - hyperparameters: dict[str, Any], *, run_mode: str, features_list: list[str], + hyperparameters: dict[str, Any], download_from_hub: bool, + cross_validate: bool, wandb_run_name: str, credible_set_path: str, feature_matrix_path: str, @@ -113,18 +114,19 @@ def __init__( variant_index_path: str | None = None, gene_interactions_path: str | None = None, predictions_path: str | None = None, - l2g_threshold: float | None, - hf_hub_repo_id: str | None, + l2g_threshold: float | None = None, + hf_hub_repo_id: str | None = None, hf_model_commit_message: str | None = "chore: update model", ) -> None: """Initialise the step and run the logic based on mode. Args: session (Session): Session object that contains the Spark session - hyperparameters (dict[str, Any]): Hyperparameters for the model run_mode (str): Run mode, either 'train' or 'predict' features_list (list[str]): List of features to use for the model + hyperparameters (dict[str, Any]): Hyperparameters for the model download_from_hub (bool): Whether to download the model from Hugging Face Hub + cross_validate (bool): Whether to run cross validation (5-fold by default) to train the model. wandb_run_name (str): Name of the run to track model training in Weights and Biases credible_set_path (str): Path to the credible set dataset necessary to build the feature matrix feature_matrix_path (str): Path to the L2G feature matrix input dataset @@ -152,6 +154,7 @@ def __init__( self.features_list = list(features_list) self.hyperparameters = dict(hyperparameters) self.wandb_run_name = wandb_run_name + self.cross_validate = cross_validate self.hf_hub_repo_id = hf_hub_repo_id self.download_from_hub = download_from_hub self.hf_model_commit_message = hf_model_commit_message @@ -300,7 +303,7 @@ def run_train(self) -> None: # Instantiate classifier and train model l2g_model = LocusToGeneModel( - model=GradientBoostingClassifier(random_state=42), + model=GradientBoostingClassifier(random_state=42, loss="log_loss"), hyperparameters=self.hyperparameters, ) @@ -310,7 +313,7 @@ def run_train(self) -> None: # Run the training trained_model = LocusToGeneTrainer( model=l2g_model, feature_matrix=feature_matrix - ).train(self.wandb_run_name) + ).train(self.wandb_run_name, cross_validate=self.cross_validate) # Export the model if trained_model.training_data and trained_model.model and self.model_path: diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 336efeb7f..1f18f227f 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -27,7 +27,17 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) - hyperparameters: dict[str, Any] | None = None + hyperparameters: dict[str, Any] = field( + default_factory=lambda: { + "n_estimators": 100, + "max_depth": 10, + "ccp_alpha": 0, + "learning_rate": 0.1, + "min_samples_leaf": 5, + "min_samples_split": 5, + "subsample": 1, + } + ) training_data: L2GFeatureMatrix | None = None label_encoder: dict[str, int] = field( default_factory=lambda: { @@ -38,8 +48,7 @@ class LocusToGeneModel: def __post_init__(self: LocusToGeneModel) -> None: """Post-initialisation to fit the estimator with the provided params.""" - if self.hyperparameters: - self.model.set_params(**self.hyperparameters_dict) + self.model.set_params(**self.hyperparameters_dict) @classmethod def load_from_disk(cls: Type[LocusToGeneModel], path: str) -> LocusToGeneModel: diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py index ab2a3fa7e..a43d6609d 100644 --- a/src/gentropy/method/l2g/trainer.py +++ b/src/gentropy/method/l2g/trainer.py @@ -4,23 +4,26 @@ import os from dataclasses import dataclass -from functools import partial from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt +import numpy as np import pandas as pd import shap +from sklearn.base import clone from sklearn.metrics import ( accuracy_score, + average_precision_score, f1_score, precision_score, recall_score, roc_auc_score, ) -from sklearn.model_selection import train_test_split +from sklearn.model_selection import GroupKFold, GroupShuffleSplit from wandb.data_types import Image, Table from wandb.errors.term import termlog as wandb_termlog from wandb.sdk.wandb_init import init as wandb_init +from wandb.sdk.wandb_setup import _setup from wandb.sdk.wandb_sweep import sweep as wandb_sweep from wandb.sklearn import plot_classifier from wandb.wandb_agent import agent as wandb_agent @@ -34,6 +37,21 @@ from wandb.sdk.wandb_run import Run +def reset_wandb_env() -> None: + """Reset Wandb environment variables except for project, entity and API key. + + This is necessary to log multiple runs in the same sweep without overwriting. More context here: https://github.com/wandb/wandb/issues/5119 + """ + exclude = { + "WANDB_PROJECT", + "WANDB_ENTITY", + "WANDB_API_KEY", + } + for key in list(os.environ.keys()): + if key.startswith("WANDB_") and key not in exclude: + del os.environ[key] + + @dataclass class LocusToGeneTrainer: """Modelling of what is the most likely causal gene associated with a given locus.""" @@ -44,10 +62,11 @@ class LocusToGeneTrainer: # Initialise vars features_list: list[str] | None = None label_col: str = "goldStandardSet" - x_train: pd.DataFrame | None = None - y_train: pd.Series | None = None - x_test: pd.DataFrame | None = None - y_test: pd.Series | None = None + x_train: np.ndarray | None = None + y_train: np.ndarray | None = None + x_test: np.ndarray | None = None + y_test: np.ndarray | None = None + groups_train: np.ndarray | None = None run: Run | None = None wandb_l2g_project_name: str = "gentropy-locus-to-gene" @@ -72,9 +91,9 @@ def fit( """ if self.x_train is not None and self.y_train is not None: assert ( - not self.x_train.empty and not self.y_train.empty + self.x_train.size != 0 and self.y_train.size != 0 ), "Train data not set, nothing to fit." - fitted_model = self.model.model.fit(X=self.x_train.values, y=self.y_train) + fitted_model = self.model.model.fit(X=self.x_train, y=self.y_train) self.model = LocusToGeneModel( model=fitted_model, hyperparameters=fitted_model.get_params(), @@ -100,7 +119,10 @@ def _get_shap_explanation( Exception: (ExplanationError) When the additivity check fails. """ if self.x_train is not None and self.x_test is not None: - training_data = pd.concat([self.x_train, self.x_test], ignore_index=True) + training_data = pd.DataFrame( + np.vstack((self.x_train, self.x_test)), + columns=self.features_list, + ) explainer = shap.TreeExplainer( model.model, data=training_data, @@ -152,151 +174,293 @@ def log_to_wandb( wandb_run_name (str): Name of the W&B run Raises: - ValueError: If dependencies are not available. + RuntimeError: If dependencies are not available. """ if ( - self.x_train is not None - and self.x_test is not None - and self.y_train is not None - and self.y_test is not None - and self.features_list is not None + self.x_train is None + or self.x_test is None + or self.y_train is None + or self.y_test is None + or self.features_list is None ): - assert ( - not self.x_train.empty and not self.y_train.empty - ), "Train data not set, nothing to evaluate." - fitted_classifier = self.model.model - y_predicted = fitted_classifier.predict(self.x_test.values) - y_probas = fitted_classifier.predict_proba(self.x_test.values) - self.run = wandb_init( - project=self.wandb_l2g_project_name, - name=wandb_run_name, - config=fitted_classifier.get_params(), - ) - # Track classification plots - plot_classifier( - self.model.model, - self.x_train.values, - self.x_test.values, - self.y_train, - self.y_test, - y_predicted, - y_probas, - labels=list(self.model.label_encoder.values()), - model_name="L2G-classifier", - feature_names=self.features_list, - is_binary=True, - ) - # Track evaluation metrics - self.run.log( - { - "areaUnderROC": roc_auc_score( - self.y_test, y_probas[:, 1], average="weighted" - ) - } - ) - self.run.log({"accuracy": accuracy_score(self.y_test, y_predicted)}) - self.run.log( - { - "weightedPrecision": precision_score( - self.y_test, y_predicted, average="weighted" - ) - } - ) - self.run.log( - { - "weightedRecall": recall_score( - self.y_test, y_predicted, average="weighted" - ) - } - ) - self.run.log({"f1": f1_score(self.y_test, y_predicted, average="weighted")}) - # Track gold standards and their features - self.run.log( - {"featureMatrix": Table(dataframe=self.feature_matrix._df.toPandas())} - ) - # Log feature missingness - self.run.log( - { - "missingnessRates": self.feature_matrix.calculate_feature_missingness_rate() - } - ) - # Plot marginal contribution of each feature - explanation = self._get_shap_explanation(self.model) - self.log_plot_image_to_wandb( - "Feature Contribution", - shap.plots.bar( - explanation, max_display=len(self.x_train.columns), show=False - ), - ) + raise RuntimeError("Train data not set, we cannot log to W&B.") + assert ( + self.x_train.size != 0 and self.y_train.size != 0 + ), "Train data not set, nothing to evaluate." + fitted_classifier = self.model.model + y_predicted = fitted_classifier.predict(self.x_test) + y_probas = fitted_classifier.predict_proba(self.x_test) + self.run = wandb_init( + project=self.wandb_l2g_project_name, + name=wandb_run_name, + config=fitted_classifier.get_params(), + ) + # Track classification plots + plot_classifier( + self.model.model, + self.x_train, + self.x_test, + self.y_train, + self.y_test, + y_predicted, + y_probas, + labels=list(self.model.label_encoder.values()), + model_name="L2G-classifier", + feature_names=self.features_list, + is_binary=True, + ) + # Track evaluation metrics + self.run.log( + { + "areaUnderROC": roc_auc_score( + self.y_test, y_probas[:, 1], average="weighted" + ) + } + ) + self.run.log({"accuracy": accuracy_score(self.y_test, y_predicted)}) + self.run.log( + { + "weightedPrecision": precision_score( + self.y_test, y_predicted, average="weighted" + ) + } + ) + self.run.log( + { + "averagePrecision": average_precision_score( + self.y_test, y_predicted, average="weighted" + ) + } + ) + self.run.log( + { + "weightedRecall": recall_score( + self.y_test, y_predicted, average="weighted" + ) + } + ) + self.run.log({"f1": f1_score(self.y_test, y_predicted, average="weighted")}) + # Track gold standards and their features + self.run.log( + {"featureMatrix": Table(dataframe=self.feature_matrix._df.toPandas())} + ) + # Log feature missingness + self.run.log( + { + "missingnessRates": self.feature_matrix.calculate_feature_missingness_rate() + } + ) + # Plot marginal contribution of each feature + explanation = self._get_shap_explanation(self.model) + self.log_plot_image_to_wandb( + "Feature Contribution", + shap.plots.bar( + explanation, max_display=len(self.features_list), show=False + ), + ) + self.log_plot_image_to_wandb( + "Beeswarm Plot", + shap.plots.beeswarm( + explanation, max_display=len(self.features_list), show=False + ), + ) + # Plot correlation between feature values and their importance + for feature in self.features_list: self.log_plot_image_to_wandb( - "Beeswarm Plot", - shap.plots.beeswarm( - explanation, max_display=len(self.x_train.columns), show=False + f"Effect of {feature} on the predictions", + shap.plots.scatter( + explanation[:, feature], + show=False, ), ) - # Plot correlation between feature values and their importance - for feature in self.features_list: - self.log_plot_image_to_wandb( - f"Effect of {feature} on the predictions", - shap.plots.scatter( - explanation[:, feature], - show=False, - ), - ) - wandb_termlog("Logged Shapley contributions.") - self.run.finish() - else: - raise ValueError("Something went wrong, couldn't log to W&B.") + wandb_termlog("Logged Shapley contributions.") + self.run.finish() def train( self: LocusToGeneTrainer, wandb_run_name: str, + cross_validate: bool = True, + n_splits: int = 5, + hyperparameter_grid: dict[str, Any] | None = None, ) -> LocusToGeneModel: """Train the Locus to Gene model. + If cross_validation is set to True, we implement the following strategy: + 1. Create held-out test set + 2. Perform cross-validation on training set + 3. Train final model on full training set + 4. Evaluate once on test set + Args: wandb_run_name (str): Name of the W&B run. Unless this is provided, the model will not be logged to W&B. + cross_validate (bool): Whether to run cross-validation. Defaults to True. + n_splits(int): Number of folds the data is splitted in. The model is trained and evaluated `k - 1` times. Defaults to 5. + hyperparameter_grid (dict[str, Any] | None): Hyperparameter grid to sweep over. Defaults to None. Returns: LocusToGeneModel: Fitted model """ - data_df = self.feature_matrix._df.drop("geneId", "studyLocusId").toPandas() + data_df = self.feature_matrix._df.toPandas() + # enforce that data_df is a Pandas DataFrame # Encode labels in `goldStandardSet` to a numeric value data_df[self.label_col] = data_df[self.label_col].map(self.model.label_encoder) - # Ensure all columns are numeric and split - data_df = data_df.apply(pd.to_numeric) - X = data_df[self.features_list].copy() - y = data_df[self.label_col].copy() - self.x_train, self.x_test, self.y_train, self.y_test = train_test_split( - X, y, test_size=0.2, random_state=42 - ) + X = data_df[self.features_list].apply(pd.to_numeric).values + y = data_df[self.label_col].apply(pd.to_numeric).values + gene_trait_groups = ( + data_df["traitFromSourceMappedId"].astype(str) + + "_" + + data_df["geneId"].astype(str) + ) # Group identifier has to be a single string + + # Create hold-out test set separating EFO/Gene pairs between train/test + train_test_split = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42) + for train_idx, test_idx in train_test_split.split(X, y, gene_trait_groups): + self.x_train, self.x_test = X[train_idx], X[test_idx] + self.y_train, self.y_test = y[train_idx], y[test_idx] + self.groups_train = gene_trait_groups[train_idx] + + # Cross-validation + if cross_validate: + self.cross_validate( + wandb_run_name=f"{wandb_run_name}-cv", + parameter_grid=hyperparameter_grid, + n_splits=n_splits, + ) - # Train - model = self.fit() + # Train final model on full training set + self.fit() - # Evaluate + # Evaluate once on hold out test set self.log_to_wandb( - wandb_run_name=wandb_run_name, + wandb_run_name=f"{wandb_run_name}-holdout", ) - return model + return self.model - def hyperparameter_tuning( - self: LocusToGeneTrainer, wandb_run_name: str, parameter_grid: dict[str, Any] + def cross_validate( + self: LocusToGeneTrainer, + wandb_run_name: str, + parameter_grid: dict[str, Any] | None = None, + n_splits: int = 5, ) -> None: - """Perform hyperparameter tuning on the model with W&B Sweeps. Metrics for every combination of hyperparameters will be logged to W&B for comparison. + """Log results of cross validation and hyperparameter tuning with W&B Sweeps. Metrics for every combination of hyperparameters will be logged to W&B for comparison. Args: wandb_run_name (str): Name of the W&B run - parameter_grid (dict[str, Any]): Dictionary containing the hyperparameters to sweep over. The keys are the hyperparameter names, and the values are dictionaries containing the values to sweep over. + parameter_grid (dict[str, Any] | None): Dictionary containing the hyperparameters to sweep over. The keys are the hyperparameter names, and the values are dictionaries containing the values to sweep over. + n_splits (int): Number of folds the data is splitted in. The model is trained and evaluated `k - 1` times. Defaults to 5. """ + + def cross_validate_single_fold( + fold_index: int, + sweep_id: str, + sweep_run_name: str, + config: dict[str, Any], + ) -> None: + """Run cross-validation for a single fold. + + Args: + fold_index (int): Index of the fold to run + sweep_id (str): ID of the sweep + sweep_run_name (str): Name of the sweep run + config (dict[str, Any]): Configuration from the sweep + + Raises: + ValueError: If training data is not set + """ + reset_wandb_env() + train_idx, val_idx = cv_splits[fold_index] + + if ( + self.x_train is None + or self.y_train is None + or self.groups_train is None + ): + raise ValueError("Training data not set") + + # Initialize a new run for this fold + os.environ["WANDB_SWEEP_ID"] = sweep_id + run = wandb_init( + project=self.wandb_l2g_project_name, + name=sweep_run_name, + config=config, + group=sweep_run_name, + job_type="fold", + reinit=True, + ) + + x_fold_train, x_fold_val = ( + self.x_train[train_idx], + self.x_train[val_idx], + ) + y_fold_train, y_fold_val = ( + self.y_train[train_idx], + self.y_train[val_idx], + ) + + fold_model = clone(self.model.model) + fold_model.set_params(**config) + fold_model.fit(x_fold_train, y_fold_train) + y_pred_proba = fold_model.predict_proba(x_fold_val)[:, 1] + y_pred = (y_pred_proba >= 0.5).astype(int) + + # Log metrics + metrics = { + "weightedPrecision": precision_score(y_fold_val, y_pred), + "averagePrecision": average_precision_score(y_fold_val, y_pred_proba), + "areaUnderROC": roc_auc_score(y_fold_val, y_pred_proba), + "accuracy": accuracy_score(y_fold_val, y_pred), + "weightedRecall": recall_score(y_fold_val, y_pred, average="weighted"), + "f1": f1_score(y_fold_val, y_pred, average="weighted"), + } + + run.log(metrics) + wandb_termlog(f"Logged metrics for fold {fold_index + 1}.") + run.finish() + + # If no grid is provided, use default ones set in the model + parameter_grid = parameter_grid or { + param: {"values": [value]} + for param, value in self.model.hyperparameters.items() + } sweep_config = { "method": "grid", - "metric": {"name": "roc", "goal": "maximize"}, + "name": wandb_run_name, # Add name to sweep config + "metric": {"name": "areaUnderROC", "goal": "maximize"}, "parameters": parameter_grid, } sweep_id = wandb_sweep(sweep_config, project=self.wandb_l2g_project_name) - wandb_agent(sweep_id, partial(self.train, wandb_run_name=wandb_run_name)) + gkf = GroupKFold(n_splits=n_splits) + cv_splits = list(gkf.split(self.x_train, self.y_train, self.groups_train)) + + def run_all_folds() -> None: + """Run cross-validation for all folds within a sweep.""" + # Initialize the sweep run and get metadata + sweep_run = wandb_init(name=wandb_run_name) + sweep_id = sweep_run.sweep_id or "unknown" + sweep_url = sweep_run.get_sweep_url() + project_url = sweep_run.get_project_url() + sweep_group_url = f"{project_url}/groups/{sweep_id}" + sweep_run.notes = sweep_group_url + sweep_run.save() + config = dict(sweep_run.config) + + # Reset wandb setup to ensure clean state + _setup(_reset=True) + + # Run all folds + for fold_index in range(len(cv_splits)): + cross_validate_single_fold( + fold_index=fold_index, + sweep_id=sweep_id, + sweep_run_name=f"{wandb_run_name}-fold{fold_index+1}", + config=config, + ) + + wandb_termlog(f"Sweep URL: {sweep_url}") + wandb_termlog(f"Sweep Group URL: {sweep_group_url}") + + wandb_agent(sweep_id, run_all_folds)