Skip to content

Commit

Permalink
fix(l2g_predictions): annotate based on list of features + filter out…
Browse files Browse the repository at this point in the history
… missing annotation (#925)

* fix(prediction): do not annotate all features from matrix

* fix(prediction): filter out features with 0

* chore: pre-commit auto fixes [...]
  • Loading branch information
ireneisdoomed authored Dec 5, 2024
1 parent a02f9c1 commit 43f047a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 24 deletions.
29 changes: 9 additions & 20 deletions src/gentropy/dataset/l2g_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,13 @@ def to_disease_target_evidence(
)

def add_locus_to_gene_features(
self: L2GPrediction, feature_matrix: L2GFeatureMatrix
self: L2GPrediction, feature_matrix: L2GFeatureMatrix, features_list: list[str]
) -> L2GPrediction:
"""Add features to the L2G predictions.
"""Add features used to extract the L2G predictions.
Args:
feature_matrix (L2GFeatureMatrix): Feature matrix dataset
features_list (list[str]): List of features used in the model
Returns:
L2GPrediction: L2G predictions with additional features
Expand All @@ -143,38 +144,26 @@ def add_locus_to_gene_features(
if "locusToGeneFeatures" in self.df.columns:
self.df = self.df.drop("locusToGeneFeatures")

# Columns identifying a studyLocus/gene pair
prediction_id_columns = ["studyLocusId", "geneId"]

# L2G matrix columns to build the map:
columns_to_map = [
column
for column in feature_matrix._df.columns
if column not in prediction_id_columns
]

# Aggregating all features into a single map column:
aggregated_features = (
feature_matrix._df.withColumn(
"locusToGeneFeatures",
f.create_map(
*sum(
[
(f.lit(colname), f.col(colname))
for colname in columns_to_map
],
((f.lit(feature), f.col(feature)) for feature in features_list),
(),
)
),
)
# from the freshly created map, we filter out the null values
.withColumn(
"locusToGeneFeatures",
f.expr("map_filter(locusToGeneFeatures, (k, v) -> v is not null)"),
f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"),
)
.drop(*columns_to_map)
.drop(*features_list)
)
return L2GPrediction(
_df=self.df.join(aggregated_features, on=prediction_id_columns, how="left"),
_df=self.df.join(
aggregated_features, on=["studyLocusId", "geneId"], how="left"
),
_schema=self.get_schema(),
)
10 changes: 6 additions & 4 deletions src/gentropy/l2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pyspark.sql.functions as f
from sklearn.ensemble import GradientBoostingClassifier
from wandb import login as wandb_login
from wandb.sdk.wandb_login import login as wandb_login

from gentropy.common.schemas import compare_struct_schemas
from gentropy.common.session import Session
Expand Down Expand Up @@ -285,9 +285,11 @@ def run_predict(self) -> None:
)
predictions.filter(
f.col("score") >= self.l2g_threshold
).add_locus_to_gene_features(self.feature_matrix).df.coalesce(
self.session.output_partitions
).write.mode(self.session.write_mode).parquet(self.predictions_path)
).add_locus_to_gene_features(
self.feature_matrix, self.features_list
).df.coalesce(self.session.output_partitions).write.mode(
self.session.write_mode
).parquet(self.predictions_path)
self.session.logger.info("L2G predictions saved successfully.")

def run_train(self) -> None:
Expand Down

0 comments on commit 43f047a

Please sign in to comment.