Skip to content

Commit

Permalink
feat(l2g): better l2g training, evaluation, and integration (#576)
Browse files Browse the repository at this point in the history
* chore: checkpoint

* chore: checkpoint

* chore: deprecate spark evaluator

* chore: checkpoint

* chore: resolve conflicts with dev

* chore: resolve conflicts with dev

* chore(model): add parameters class property

* feat: add module to export model to hub

* refactor: make model agnostic of features list

* chore: add wandb to gitignore

* feat: download model from hub

* chore(model): adapt predict method

* feat(trainer): add hyperparameter tuning

* chore: deprecate trainer tests

* refactor: modularise step

* feat: download model from hub by default

* fix: convert omegaconfig defaults to python objects

* fix: write serialised model to disk and then upload to gcs

* fix(matrix): drop goldStandardSet when in predict mode

* chore: pass token to access private model

* chore: pass token to access private model

* fix: pass right schema

* chore: pre-commit auto fixes [...]

* chore: fix mypy issues

* build: remove xgboost

* chore: merge

* chore: pre-commit auto fixes [...]

* chore: address comments
  • Loading branch information
ireneisdoomed authored Jun 24, 2024
1 parent a2a7a82 commit 0d9160f
Show file tree
Hide file tree
Showing 16 changed files with 961 additions and 1,060 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ src/airflow/logs/*
site/
.env
.coverage*
wandb/
2 changes: 1 addition & 1 deletion config/datasets/ot_gcp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ from_sumstats_pics: ${datasets.credible_set}/from_sumstats

# ETL output datasets:
l2g_gold_standard_curation: ${datasets.release_folder}/locus_to_gene_gold_standard.json
l2g_model: ${datasets.release_folder}/locus_to_gene_model
l2g_model: ${datasets.release_folder}/locus_to_gene_model/classifier.skops
l2g_predictions: ${datasets.release_folder}/locus_to_gene_predictions
l2g_feature_matrix: ${datasets.release_folder}/locus_to_gene_feature_matrix
colocalisation: ${datasets.release_folder}/colocalisation
Expand Down
2 changes: 1 addition & 1 deletion config/step/ot_locus_to_gene_predict.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- locus_to_gene

run_mode: predict
model_path: ${datasets.l2g_model}
model_path: null
predictions_path: ${datasets.l2g_predictions}
feature_matrix_path: ${datasets.l2g_feature_matrix}
credible_set_path: ${datasets.credible_set}
Expand Down
6 changes: 4 additions & 2 deletions config/step/ot_locus_to_gene_train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defaults:

run_mode: train
wandb_run_name: null
perform_cross_validation: false
hf_hub_repo_id: opentargets/locus_to_gene
model_path: ${datasets.l2g_model}
predictions_path: ${datasets.l2g_predictions}
credible_set_path: ${datasets.credible_set}
Expand All @@ -13,5 +13,7 @@ study_index_path: ${datasets.study_index}
gold_standard_curation_path: ${datasets.l2g_gold_standard_curation}
gene_interactions_path: ${datasets.gene_interactions}
hyperparameters:
n_estimators: 100
max_depth: 5
loss_function: binary:logistic
loss: log_loss
download_from_hub: true
5 changes: 0 additions & 5 deletions docs/python_api/methods/l2g/evaluator.md

This file was deleted.

592 changes: 320 additions & 272 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ pyspark = "3.3.4"
scipy = "^1.11.4"
hydra-core = "^1.3.2"
pyliftover = "^0.4"
xgboost = "^1.7.3"
numpy = "^1.26.2"
hail = "0.2.127"
wandb = ">=0.16.2,<0.18.0"
Expand All @@ -33,6 +32,7 @@ omegaconf = "^2.3.0"
typing-extensions = "^4.9.0"
scikit-learn = "^1.3.2"
pandas = {extras = ["gcp", "parquet"], version = "^2.2.2"}
skops = "^0.9.0"
google-cloud-secret-manager = "^2.20.0"

[tool.poetry.dev-dependencies]
Expand Down
23 changes: 23 additions & 0 deletions src/gentropy/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,26 @@ def access_gcp_secret(secret_id: str, project_id: str) -> str:
name = f"projects/{project_id}/secrets/{secret_id}/versions/latest"
response = client.access_secret_version(name=name)
return response.payload.data.decode("UTF-8")


def copy_to_gcs(source_path: str, destination_blob: str) -> None:
"""Copy a file to a Google Cloud Storage bucket.
Args:
source_path (str): Path to the local file to copy
destination_blob (str): GS path to the destination blob in the GCS bucket
Raises:
ValueError: If the path is a directory
"""
import os
from urllib.parse import urlparse

from google.cloud import storage

if os.path.isdir(source_path):
raise ValueError("Path should be a file, not a directory.")
client = storage.Client()
bucket = client.bucket(bucket_name=urlparse(destination_blob).hostname)
blob = bucket.blob(blob_name=urlparse(destination_blob).path.lstrip("/"))
blob.upload_from_filename(source_path)
36 changes: 21 additions & 15 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,12 @@ class LocusToGeneConfig(StepConfig):
}
)
run_mode: str = MISSING
model_path: str = MISSING
predictions_path: str = MISSING
credible_set_path: str = MISSING
variant_gene_path: str = MISSING
colocalisation_path: str = MISSING
study_index_path: str = MISSING
model_path: str | None = None
feature_matrix_path: str | None = None
gold_standard_curation_path: str | None = None
gene_interactions_path: str | None = None
Expand Down Expand Up @@ -248,28 +248,34 @@ class LocusToGeneConfig(StepConfig):
"tuqtlColocClppMaximum",
# max clpp for each (study, locus, gene) aggregating over all tuQTLs
"tuqtlColocClppMaximumNeighborhood",
# # max log-likelihood ratio value for each (study, locus, gene) aggregating over all eQTLs
# "eqtlColocLlrLocalMaximum",
# # max log-likelihood ratio value for each (study, locus) aggregating over all eQTLs
# "eqtlColocLlpMaximumNeighborhood",
# # max log-likelihood ratio value for each (study, locus, gene) aggregating over all pQTLs
# "pqtlColocLlrLocalMaximum",
# # max log-likelihood ratio value for each (study, locus) aggregating over all pQTLs
# "pqtlColocLlpMaximumNeighborhood",
# # max log-likelihood ratio value for each (study, locus, gene) aggregating over all sQTLs
# "sqtlColocLlrLocalMaximum",
# # max log-likelihood ratio value for each (study, locus) aggregating over all sQTLs
# "sqtlColocLlpMaximumNeighborhood",
# max log-likelihood ratio value for each (study, locus, gene) aggregating over all eQTLs
"eqtlColocLlrMaximum",
# max log-likelihood ratio value for each (study, locus) aggregating over all eQTLs
"eqtlColocLlrMaximumNeighborhood",
# max log-likelihood ratio value for each (study, locus, gene) aggregating over all pQTLs
"pqtlColocLlrMaximum",
# max log-likelihood ratio value for each (study, locus) aggregating over all pQTLs
"pqtlColocLlrMaximumNeighborhood",
# max log-likelihood ratio value for each (study, locus, gene) aggregating over all sQTLs
"sqtlColocLlrMaximum",
# max log-likelihood ratio value for each (study, locus) aggregating over all sQTLs
"sqtlColocLlrMaximumNeighborhood",
# max log-likelihood ratio value for each (study, locus, gene) aggregating over all tuQTLs
"tuqtlColocLlrMaximum",
# max log-likelihood ratio value for each (study, locus) aggregating over all tuQTLs
"tuqtlColocLlrMaximumNeighborhood",
]
)
hyperparameters: dict[str, Any] = field(
default_factory=lambda: {
"n_estimators": 100,
"max_depth": 5,
"loss_function": "binary:logistic",
"loss": "log_loss",
}
)
wandb_run_name: str | None = None
perform_cross_validation: bool = False
hf_hub_repo_id: str | None = "opentargets/locus_to_gene"
download_from_hub: bool = True
_target_: str = "gentropy.l2g.LocusToGeneStep"


Expand Down
48 changes: 23 additions & 25 deletions src/gentropy/dataset/l2g_feature_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import reduce
from typing import TYPE_CHECKING, Type

Expand All @@ -26,15 +26,26 @@ class L2GFeatureMatrix(Dataset):
Attributes:
features_list (list[str] | None): List of features to use. If None, all possible features are used.
fixed_cols (list[str]): Columns that should be kept fixed in the feature matrix, although not considered as features.
mode (str): Mode of the feature matrix. Defaults to "train". Can be either "train" or "predict".
"""

features_list: list[str] | None = None
fixed_cols: list[str] = field(default_factory=lambda: ["studyLocusId", "geneId"])
mode: str = "train"

def __post_init__(self: L2GFeatureMatrix) -> None:
"""Post-initialisation to set the features list. If not provided, all columns except the fixed ones are used."""
fixed_cols = ["studyLocusId", "geneId", "goldStandardSet"]
"""Post-initialisation to set the features list. If not provided, all columns except the fixed ones are used.
Raises:
ValueError: If the mode is neither 'train' nor 'predict'.
"""
if self.mode not in ["train", "predict"]:
raise ValueError("Mode should be either 'train' or 'predict'")
if self.mode == "train":
self.fixed_cols = self.fixed_cols + ["goldStandardSet"]
self.features_list = self.features_list or [
col for col in self._df.columns if col not in fixed_cols
col for col in self._df.columns if col not in self.fixed_cols
]
self.validate_schema()

Expand Down Expand Up @@ -138,7 +149,8 @@ def fill_na(
return self

def select_features(
self: L2GFeatureMatrix, features_list: list[str] | None
self: L2GFeatureMatrix,
features_list: list[str] | None,
) -> L2GFeatureMatrix:
"""Select a subset of features from the feature matrix.
Expand All @@ -147,25 +159,11 @@ def select_features(
Returns:
L2GFeatureMatrix: L2G feature matrix dataset
"""
features_list = features_list or self.features_list
fixed_cols = ["studyLocusId", "geneId", "goldStandardSet"]
self.df = self._df.select(fixed_cols + features_list) # type: ignore
return self
def train_test_split(
self: L2GFeatureMatrix, fraction: float
) -> tuple[L2GFeatureMatrix, L2GFeatureMatrix]:
"""Split the dataset into training and test sets.
Args:
fraction (float): Fraction of the dataset to use for training
Returns:
tuple[L2GFeatureMatrix, L2GFeatureMatrix]: Training and test datasets
Raises:
ValueError: If no features have been selected.
"""
train, test = self._df.randomSplit([fraction, 1 - fraction], seed=42)
return (
L2GFeatureMatrix(_df=train, _schema=L2GFeatureMatrix.get_schema()),
L2GFeatureMatrix(_df=test, _schema=L2GFeatureMatrix.get_schema()),
)
if features_list := features_list or self.features_list:
self.df = self._df.select(self.fixed_cols + features_list)
return self
raise ValueError("features_list cannot be None")
65 changes: 33 additions & 32 deletions src/gentropy/dataset/l2g_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Type

import pyspark.sql.functions as f
from pyspark.ml.functions import vector_to_array

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.session import Session
from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
Expand Down Expand Up @@ -42,26 +40,41 @@ def get_schema(cls: type[L2GPrediction]) -> StructType:
@classmethod
def from_credible_set(
cls: Type[L2GPrediction],
model_path: str,
features_list: list[str],
credible_set: StudyLocus,
study_index: StudyIndex,
v2g: V2G,
coloc: Colocalisation,
session: Session,
model_path: str | None,
hf_token: str | None = None,
download_from_hub: bool = True,
) -> tuple[L2GPrediction, L2GFeatureMatrix]:
"""Extract L2G predictions for a set of credible sets derived from GWAS.
Args:
model_path (str): Path to the fitted model
features_list (list[str]): List of features to use for the model
credible_set (StudyLocus): Credible set dataset
study_index (StudyIndex): Study index dataset
v2g (V2G): Variant to gene dataset
coloc (Colocalisation): Colocalisation dataset
session (Session): Session object that contains the Spark session
model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name).
hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private.
download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True.
Returns:
tuple[L2GPrediction, L2GFeatureMatrix]: L2G dataset and feature matrix limited to GWAS study only.
"""
# Load the model
if download_from_hub:
# Model ID defaults to "opentargets/locus_to_gene" and it assumes the name of the classifier is "classifier.skops".
model_id = model_path or "opentargets/locus_to_gene"
l2g_model = LocusToGeneModel.load_from_hub(model_id, hf_token)
elif model_path:
l2g_model = LocusToGeneModel.load_from_disk(model_path)

# Prepare data
fm = L2GFeatureMatrix.generate_features(
features_list=features_list,
credible_set=credible_set,
Expand All @@ -70,35 +83,23 @@ def from_credible_set(
colocalisation=coloc,
).fill_na()

gwas_fm = L2GFeatureMatrix(
_df=(
fm.df.join(
credible_set.filter_by_study_type("gwas", study_index).df.select(
"studyLocusId"
),
on="studyLocusId",
)
),
_schema=L2GFeatureMatrix.get_schema(),
)
return (
L2GPrediction(
# Load and apply fitted model
gwas_fm = (
L2GFeatureMatrix(
_df=(
LocusToGeneModel.load_from_disk(
model_path,
features_list=features_list,
)
.predict(gwas_fm)
# the probability of the positive class is the second element inside the probability array
# - this is selected as the L2G probability
.select(
"studyLocusId",
"geneId",
vector_to_array(f.col("probability"))[1].alias("score"),
fm.df.join(
credible_set.filter_by_study_type(
"gwas", study_index
).df.select("studyLocusId"),
on="studyLocusId",
)
),
_schema=cls.get_schema(),
),
_schema=L2GFeatureMatrix.get_schema(),
mode="predict",
)
.select_features(features_list)
.persist()
)
return (
l2g_model.predict(gwas_fm, session),
gwas_fm,
)
Loading

0 comments on commit 0d9160f

Please sign in to comment.