Skip to content

Commit

Permalink
Added working tests for interval + nbh features
Browse files Browse the repository at this point in the history
  • Loading branch information
xyg123 committed Dec 11, 2024
1 parent 1de5fcf commit c332d93
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,17 +260,17 @@ class LocusToGeneConfig(StepConfig):
"vepMaximumNeighbourhood",
"vepMean",
"vepMeanNeighbourhood",
# other
"geneCount500kb",
"proteinGeneCount500kb",
"credibleSetConfidence",
# intervals
"pchicMean",
"pchicMeanNeighbourhood",
"enhTssMean",
"enhTssMeanNeighbourhood",
"dhsPmtrMean",
"dhsPmtrMeanNeighbourhood",
# other
"geneCount500kb",
"proteinGeneCount500kb",
"credibleSetConfidence",
]
)
hyperparameters: dict[str, Any] = field(
Expand Down
12 changes: 7 additions & 5 deletions src/gentropy/dataset/l2g_features/intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def common_interval_feature_logic(
),
)
.join(
intervals.df.filter(f.col("datasourceId") == interval_source)
.withColumnRenamed("variantId", "variantInLocusId")
.withColumnRenamed("targetId", "geneId"),
on=["variantInLocusId", "geneId"],
intervals.df.filter(
f.col("datasourceId") == interval_source
).withColumnRenamed("variantId", "variantInLocusId"),
# .withColumnRenamed("targetId", "geneId"),
# on=["variantInLocusId", "geneId"],
on="variantInLocusId",
how="inner",
)
.withColumn(
Expand Down Expand Up @@ -101,7 +103,7 @@ def common_neighbourhood_interval_feature_logic(
f.mean(local_feature_name).over(Window.partitionBy("studyLocusId")),
)
.withColumn(feature_name, f.col(local_feature_name) - f.col("regional_mean"))
.drop("regional_mean")
.drop("regional_mean", local_feature_name)
)


Expand Down
162 changes: 162 additions & 0 deletions tests/gentropy/dataset/test_l2g_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@
common_neighbourhood_vep_feature_logic,
common_vep_feature_logic,
)
from gentropy.dataset.l2g_features.intervals import (
PchicMeanFeature,
PchicMeanNeighbourhoodFeature,
EnhTssMeanFeature,
EnhTssMeanNeighbourhoodFeature,
common_interval_feature_logic,
common_neighbourhood_interval_feature_logic,
)
from gentropy.dataset.l2g_features.other import (
common_genecount_feature_logic,
is_protein_coding_feature_logic,
Expand All @@ -72,6 +80,7 @@
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.variant_index import VariantIndex
from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader
from gentropy.dataset.intervals import Intervals

if TYPE_CHECKING:
from pyspark.sql import SparkSession
Expand Down Expand Up @@ -104,6 +113,10 @@
VepMeanFeature,
VepMaximumNeighbourhoodFeature,
VepMeanNeighbourhoodFeature,
PchicMeanFeature,
PchicMeanNeighbourhoodFeature,
EnhTssMeanFeature,
EnhTssMeanNeighbourhoodFeature,
GeneCountFeature,
ProteinGeneCountFeature,
CredibleSetConfidenceFeature,
Expand All @@ -117,6 +130,7 @@ def test_feature_factory_return_type(
mock_study_index: StudyIndex,
mock_variant_index: VariantIndex,
mock_gene_index: GeneIndex,
mock_intervals: Intervals,
) -> None:
"""Test that every feature factory returns a L2GFeature dataset."""
loader = L2GFeatureInputLoader(
Expand All @@ -125,6 +139,7 @@ def test_feature_factory_return_type(
variant_index=mock_variant_index,
study_locus=mock_study_locus,
gene_index=mock_gene_index,
intervals=mock_intervals,
)
feature_dataset = feature_class.compute(
study_loci_to_annotate=mock_study_locus,
Expand Down Expand Up @@ -681,6 +696,153 @@ def _setup(
)


class TestCommonIntervalFeatureLogic:
    """Tests for the interval-based L2G feature logic.

    Covers `common_interval_feature_logic` (per-locus mean of interval scores)
    and `common_neighbourhood_interval_feature_logic` (score relative to the
    regional mean across genes in the locus), using a two-gene fixture locus.
    """

    @pytest.mark.parametrize(
        ("feature_name", "interval_source", "expected_data"),
        [
            (
                "pchicMean",
                "javierre2016",
                [
                    {
                        "studyLocusId": "1",
                        "geneId": "gene1",
                        "pchicMean": 0.4,
                    },
                    {
                        "studyLocusId": "1",
                        "geneId": "gene2",
                        "pchicMean": 0.6,
                    },
                ],
            ),
        ],
    )
    def test_common_interval_feature_logic(
        self: TestCommonIntervalFeatureLogic,
        spark: SparkSession,
        feature_name: str,
        interval_source: str,
        # NOTE: this is a list of row dicts fed to spark.createDataFrame,
        # not a single dict.
        expected_data: list[dict[str, Any]],
    ) -> None:
        """Test the logic of the function that extracts features from intervals."""
        observed_df = (
            common_interval_feature_logic(
                self.sample_study_locus,
                intervals=self.sample_intervals,
                feature_name=feature_name,
                interval_source=interval_source,
            )
            # Round to 2 decimals so float arithmetic noise does not break
            # the row-by-row comparison below.
            .withColumn(feature_name, f.round(f.col(feature_name), 2))
            # Deterministic row order is required for collect() equality.
            .orderBy(feature_name)
        )
        expected_df = (
            spark.createDataFrame(expected_data)
            .select("studyLocusId", "geneId", feature_name)
            .orderBy(feature_name)
        )
        assert (
            observed_df.collect() == expected_df.collect()
        ), f"Expected and observed dataframes are not equal for feature {feature_name}."

    def test_common_neighbourhood_interval_feature_logic(
        self: TestCommonIntervalFeatureLogic,
        spark: SparkSession,
    ) -> None:
        """Test the logic of the function that computes neighbourhood interval scores.

        With per-gene means of 0.4 and 0.6 (regional mean 0.5), the expected
        neighbourhood deltas are -0.1 and +0.1 respectively.
        """
        feature_name = "pchicMeanNeighbourhood"
        observed_df = (
            common_neighbourhood_interval_feature_logic(
                self.sample_study_locus,
                intervals=self.sample_intervals,
                feature_name=feature_name,
                interval_source="javierre2016",
            )
            # Round to 2 decimals to make the float comparison stable.
            .withColumn(feature_name, f.round(f.col(feature_name), 2))
            # Deterministic row order is required for collect() equality.
            .orderBy(f.col(feature_name).asc())
        )
        expected_df = (
            spark.createDataFrame(
                [
                    {"studyLocusId": "1", "geneId": "gene1", feature_name: -0.1},
                    {"studyLocusId": "1", "geneId": "gene2", feature_name: 0.1},
                ]
            )
            .orderBy(feature_name)
            .select("studyLocusId", "geneId", feature_name)
        )
        assert (
            observed_df.collect() == expected_df.collect()
        ), "Output doesn't meet the expectation."

    @pytest.fixture(autouse=True)
    def _setup(
        self: TestCommonIntervalFeatureLogic,
        spark: SparkSession,
    ) -> None:
        """Set up testing fixtures.

        Builds a single study locus ("1") whose credible set contains two
        variants (lead1, tag1), and an Intervals dataset linking each variant
        to a different gene from the javierre2016 source.
        """
        self.sample_study_locus = StudyLocus(
            _df=spark.createDataFrame(
                [
                    {
                        "studyLocusId": "1",
                        "variantId": "lead1",
                        "studyId": "study1",
                        "locus": [
                            {
                                "variantId": "lead1",
                                "posteriorProbability": 0.5,
                            },
                            {
                                "variantId": "tag1",
                                "posteriorProbability": 0.5,
                            },
                        ],
                        "chromosome": "1",
                    },
                ],
                StudyLocus.get_schema(),
            ),
            _schema=StudyLocus.get_schema(),
        )
        # NOTE(review): start/end are given as strings here — confirm this
        # matches Intervals.get_schema(); if those columns are integer-typed
        # the cast behaviour should be verified.
        self.sample_intervals = Intervals(
            _df=spark.createDataFrame(
                [
                    {
                        "chromosome": "1",
                        "start": "1000000",
                        "end": "1005000",
                        "geneId": "gene1",
                        "variantId": "lead1",
                        "resourceScore": 0.8,
                        "score": 0.95,
                        "datasourceId": "javierre2016",
                        "datatypeId": "pchic",
                        "pmid": "12345678",
                        "biofeature": "enhancer",
                    },
                    {
                        "chromosome": "1",
                        "start": "1000000",
                        "end": "1005000",
                        "geneId": "gene2",
                        "variantId": "tag1",
                        "resourceScore": 1.2,
                        "score": 1.1,
                        "datasourceId": "javierre2016",
                        "datatypeId": "pchic",
                        "pmid": "87654321",
                        "biofeature": "tss",
                    },
                ],
                Intervals.get_schema(),
            ),
            _schema=Intervals.get_schema(),
        )


class TestCommonVepFeatureLogic:
"""Test the common_vep_feature_logic methods."""

Expand Down

0 comments on commit c332d93

Please sign in to comment.