Skip to content

Commit

Permalink
fix(L2GPrediction): schema validation (#642)
Browse files Browse the repository at this point in the history
* feat(dataset): schema mismatch issue
* feat(L2GPrediction): schema unification
* fix: swapped data types

---------

Co-authored-by: Szymon Szyszkowski <[email protected]>
  • Loading branch information
project-defiant and Szymon Szyszkowski authored Jun 14, 2024
1 parent 3c8ce58 commit 7625a79
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/gentropy/dataset/l2g_feature_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Feature matrix of study locus pairs annotated with their functional genomics features."""

from __future__ import annotations

from dataclasses import dataclass
Expand Down Expand Up @@ -35,6 +36,7 @@ def __post_init__(self: L2GFeatureMatrix) -> None:
self.features_list = self.features_list or [
col for col in self._df.columns if col not in fixed_cols
]
self.validate_schema()

@classmethod
def generate_features(
Expand Down
6 changes: 4 additions & 2 deletions src/gentropy/dataset/l2g_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ def from_credible_set(
gwas_fm = L2GFeatureMatrix(
_df=(
fm.df.join(
credible_set.filter_by_study_type("gwas", study_index).df,
credible_set.filter_by_study_type("gwas", study_index).df.select(
"studyLocusId"
),
on="studyLocusId",
)
),
_schema=cls.get_schema(),
_schema=L2GFeatureMatrix.get_schema(),
)
return (
L2GPrediction(
Expand Down
2 changes: 2 additions & 0 deletions src/gentropy/dataset/pairwise_ld.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pairwise LD dataset."""

from __future__ import annotations

from dataclasses import dataclass, field
Expand Down Expand Up @@ -37,6 +38,7 @@ def __post_init__(self: PairwiseLD) -> None:
), f"The number of rows in a pairwise LD table has to be square. Got: {row_count}"

self.dimension = (int(sqrt(row_count)), int(sqrt(row_count)))
self.validate_schema()

@classmethod
def get_schema(cls: type[PairwiseLD]) -> StructType:
Expand Down
24 changes: 23 additions & 1 deletion tests/gentropy/dataset/test_l2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import TYPE_CHECKING

import pytest
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
from gentropy.dataset.l2g_prediction import L2GPrediction
Expand Down Expand Up @@ -152,6 +153,27 @@ def test_remove_false_negatives(spark: SparkSession) -> None:
assert observed_df.collect() == expected_df.collect()


def test_l2g_feature_constructor_with_schema_mismatch(spark: SparkSession) -> None:
"""Test if provided shema mismatch results in error in L2GFeatureMatrix constructor.
distanceTssMean is expected to be FLOAT by schema in src.gentropy.assets.schemas and is actualy DOUBLE.
"""
with pytest.raises(ValueError) as e:
L2GFeatureMatrix(
_df=spark.createDataFrame(
[
(1, "gene1", 100.0),
(2, "gene2", 1000.0),
],
"studyLocusId LONG, geneId STRING, distanceTssMean DOUBLE",
),
_schema=L2GFeatureMatrix.get_schema(),
)
assert e.value.args[0] == (
"The following fields present differences in their datatypes: ['distanceTssMean']."
)


def test_calculate_feature_missingness_rate(spark: SparkSession) -> None:
"""Test L2GFeatureMatrix.calculate_feature_missingness_rate."""
fm = L2GFeatureMatrix(
Expand All @@ -160,7 +182,7 @@ def test_calculate_feature_missingness_rate(spark: SparkSession) -> None:
(1, "gene1", 100.0, None),
(2, "gene2", 1000.0, 0.0),
],
"studyLocusId LONG, geneId STRING, distanceTssMean DOUBLE, distanceTssMinimum DOUBLE",
"studyLocusId LONG, geneId STRING, distanceTssMean FLOAT, distanceTssMinimum FLOAT",
),
_schema=L2GFeatureMatrix.get_schema(),
)
Expand Down

0 comments on commit 7625a79

Please sign in to comment.