Skip to content

Commit

Permalink
feat(l2g)!: implement new training strategy splitting between EFO/gen…
Browse files Browse the repository at this point in the history
…e pairs and with cross validation (#938)

* feat(gold_standard): add traitFromSourceMappedId to schema

* chore: adapt tests

* feat(feature_matrix): consider `traitFromSourceMappedId` a static column

* feat(feature_matrix): consider `traitFromSourceMappedId` an optional column

* feat: update l2g config with best hyperparams

* feat(trainer): new train runs when cross_validate=False

* chore(model): add default hyperparams based on best params

* chore: debug sweep, one single run

* feat(trainer): new train runs when cross_validate=True

* feat(cross_validate): sweep runs are now together

* chore: pre-commit auto fixes [...]

* chore: improve error message
  • Loading branch information
ireneisdoomed authored Dec 9, 2024
1 parent 43f047a commit 79f6fcc
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 129 deletions.
10 changes: 7 additions & 3 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,24 @@ class LocusToGeneConfig(StepConfig):
"geneCount500kb",
"proteinGeneCount500kb",
"credibleSetConfidence",
# "isProteinCoding",
]
)
hyperparameters: dict[str, Any] = field(
default_factory=lambda: {
"n_estimators": 100,
"max_depth": 5,
"loss": "log_loss",
"max_depth": 10,
"ccp_alpha": 0,
"learning_rate": 0.1,
"min_samples_leaf": 5,
"min_samples_split": 5,
"subsample": 1,
}
)
wandb_run_name: str | None = None
hf_hub_repo_id: str | None = "opentargets/locus_to_gene"
hf_model_commit_message: str | None = "chore: update model"
download_from_hub: bool = True
cross_validate: bool = True
_target_: str = "gentropy.l2g.LocusToGeneStep"


Expand Down
15 changes: 9 additions & 6 deletions src/gentropy/l2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ class LocusToGeneStep:
def __init__(
self,
session: Session,
hyperparameters: dict[str, Any],
*,
run_mode: str,
features_list: list[str],
hyperparameters: dict[str, Any],
download_from_hub: bool,
cross_validate: bool,
wandb_run_name: str,
credible_set_path: str,
feature_matrix_path: str,
Expand All @@ -113,18 +114,19 @@ def __init__(
variant_index_path: str | None = None,
gene_interactions_path: str | None = None,
predictions_path: str | None = None,
l2g_threshold: float | None,
hf_hub_repo_id: str | None,
l2g_threshold: float | None = None,
hf_hub_repo_id: str | None = None,
hf_model_commit_message: str | None = "chore: update model",
) -> None:
"""Initialise the step and run the logic based on mode.
Args:
session (Session): Session object that contains the Spark session
hyperparameters (dict[str, Any]): Hyperparameters for the model
run_mode (str): Run mode, either 'train' or 'predict'
features_list (list[str]): List of features to use for the model
hyperparameters (dict[str, Any]): Hyperparameters for the model
download_from_hub (bool): Whether to download the model from Hugging Face Hub
cross_validate (bool): Whether to run cross validation (5-fold by default) to train the model.
wandb_run_name (str): Name of the run to track model training in Weights and Biases
credible_set_path (str): Path to the credible set dataset necessary to build the feature matrix
feature_matrix_path (str): Path to the L2G feature matrix input dataset
Expand Down Expand Up @@ -152,6 +154,7 @@ def __init__(
self.features_list = list(features_list)
self.hyperparameters = dict(hyperparameters)
self.wandb_run_name = wandb_run_name
self.cross_validate = cross_validate
self.hf_hub_repo_id = hf_hub_repo_id
self.download_from_hub = download_from_hub
self.hf_model_commit_message = hf_model_commit_message
Expand Down Expand Up @@ -300,7 +303,7 @@ def run_train(self) -> None:

# Instantiate classifier and train model
l2g_model = LocusToGeneModel(
model=GradientBoostingClassifier(random_state=42),
model=GradientBoostingClassifier(random_state=42, loss="log_loss"),
hyperparameters=self.hyperparameters,
)

Expand All @@ -310,7 +313,7 @@ def run_train(self) -> None:
# Run the training
trained_model = LocusToGeneTrainer(
model=l2g_model, feature_matrix=feature_matrix
).train(self.wandb_run_name)
).train(self.wandb_run_name, cross_validate=self.cross_validate)

# Export the model
if trained_model.training_data and trained_model.model and self.model_path:
Expand Down
15 changes: 12 additions & 3 deletions src/gentropy/method/l2g/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,17 @@ class LocusToGeneModel:
"""Wrapper for the Locus to Gene classifier."""

model: Any = GradientBoostingClassifier(random_state=42)
hyperparameters: dict[str, Any] | None = None
hyperparameters: dict[str, Any] = field(
default_factory=lambda: {
"n_estimators": 100,
"max_depth": 10,
"ccp_alpha": 0,
"learning_rate": 0.1,
"min_samples_leaf": 5,
"min_samples_split": 5,
"subsample": 1,
}
)
training_data: L2GFeatureMatrix | None = None
label_encoder: dict[str, int] = field(
default_factory=lambda: {
Expand All @@ -38,8 +48,7 @@ class LocusToGeneModel:

def __post_init__(self: LocusToGeneModel) -> None:
"""Post-initialisation to fit the estimator with the provided params."""
if self.hyperparameters:
self.model.set_params(**self.hyperparameters_dict)
self.model.set_params(**self.hyperparameters_dict)

@classmethod
def load_from_disk(cls: Type[LocusToGeneModel], path: str) -> LocusToGeneModel:
Expand Down
Loading

0 comments on commit 79f6fcc

Please sign in to comment.