Skip to content

Commit

Permalink
refactor: rename gene_index to target_index in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vivienho committed Dec 6, 2024
1 parent 26abb1b commit 19bb7c9
Show file tree
Hide file tree
Showing 14 changed files with 92 additions and 90 deletions.
12 changes: 6 additions & 6 deletions tests/gentropy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from gentropy.common.session import Session
from gentropy.dataset.biosample_index import BiosampleIndex
from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.intervals import Intervals
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
Expand All @@ -25,6 +24,7 @@
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.dataset.target_index import TargetIndex
from gentropy.dataset.variant_index import VariantIndex
from gentropy.datasource.eqtl_catalogue.finemapping import EqtlCatalogueFinemapping
from gentropy.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex
Expand Down Expand Up @@ -379,7 +379,7 @@ def mock_summary_statistics(

@pytest.fixture()
def mock_ld_index(spark: SparkSession) -> LDIndex:
"""Mock gene index."""
"""Mock ld index."""
ld_schema = LDIndex.get_schema()

data_spec = (
Expand Down Expand Up @@ -519,9 +519,9 @@ def sample_target_index(spark: SparkSession) -> DataFrame:


@pytest.fixture()
def mock_gene_index(spark: SparkSession) -> GeneIndex:
"""Mock gene index dataset."""
gi_schema = GeneIndex.get_schema()
def mock_target_index(spark: SparkSession) -> TargetIndex:
"""Mock target index dataset."""
gi_schema = TargetIndex.get_schema()

data_spec = (
dg.DataGenerator(
Expand All @@ -540,7 +540,7 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
.withColumnSpec("strand", percentNulls=0.1)
)

return GeneIndex(_df=data_spec.build(), _schema=gi_schema)
return TargetIndex(_df=data_spec.build(), _schema=gi_schema)


@pytest.fixture()
Expand Down
54 changes: 27 additions & 27 deletions tests/gentropy/dataset/test_l2g_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.target_index import TargetIndex
from gentropy.dataset.l2g_features.colocalisation import (
EQtlColocClppMaximumFeature,
EQtlColocClppMaximumNeighbourhoodFeature,
Expand Down Expand Up @@ -116,15 +116,15 @@ def test_feature_factory_return_type(
mock_colocalisation: Colocalisation,
mock_study_index: StudyIndex,
mock_variant_index: VariantIndex,
mock_gene_index: GeneIndex,
mock_target_index: TargetIndex,
) -> None:
"""Test that every feature factory returns a L2GFeature dataset."""
loader = L2GFeatureInputLoader(
colocalisation=mock_colocalisation,
study_index=mock_study_index,
variant_index=mock_variant_index,
study_locus=mock_study_locus,
gene_index=mock_gene_index,
target_index=mock_target_index,
)
feature_dataset = feature_class.compute(
study_loci_to_annotate=mock_study_locus,
Expand All @@ -136,9 +136,9 @@ def test_feature_factory_return_type(


@pytest.fixture(scope="module")
def sample_gene_index(spark: SparkSession) -> GeneIndex:
"""Create a sample gene index for testing."""
return GeneIndex(
def sample_target_index(spark: SparkSession) -> TargetIndex:
"""Create a sample target index for testing."""
return TargetIndex(
_df=spark.createDataFrame(
[
{
Expand All @@ -157,9 +157,9 @@ def sample_gene_index(spark: SparkSession) -> GeneIndex:
"chromosome": "1",
},
],
GeneIndex.get_schema(),
TargetIndex.get_schema(),
),
_schema=GeneIndex.get_schema(),
_schema=TargetIndex.get_schema(),
)


Expand Down Expand Up @@ -294,7 +294,7 @@ def test__common_colocalisation_feature_logic(
def test_extend_missing_colocalisation_to_neighbourhood_genes(
self: TestCommonColocalisationFeatureLogic,
spark: SparkSession,
sample_gene_index: GeneIndex,
sample_target_index: TargetIndex,
sample_variant_index: VariantIndex,
) -> None:
"""Test the extend_missing_colocalisation_to_neighbourhood_genes function."""
Expand All @@ -316,7 +316,7 @@ def test_extend_missing_colocalisation_to_neighbourhood_genes(
feature_name="eQtlColocH4Maximum",
local_features=local_features,
variant_index=sample_variant_index,
gene_index=sample_gene_index,
target_index=sample_target_index,
study_locus=self.sample_study_locus,
).select("studyLocusId", "geneId", "eQtlColocH4Maximum")
expected_df = spark.createDataFrame(
Expand All @@ -329,7 +329,7 @@ def test_extend_missing_colocalisation_to_neighbourhood_genes(
def test_common_neighbourhood_colocalisation_feature_logic(
self: TestCommonColocalisationFeatureLogic,
spark: SparkSession,
sample_gene_index: GeneIndex,
sample_target_index: TargetIndex,
sample_variant_index: VariantIndex,
) -> None:
"""Test the common logic of the neighbourhood colocalisation features."""
Expand All @@ -343,7 +343,7 @@ def test_common_neighbourhood_colocalisation_feature_logic(
colocalisation=self.sample_colocalisation,
study_index=self.sample_studies,
study_locus=self.sample_study_locus,
gene_index=sample_gene_index,
target_index=sample_target_index,
variant_index=sample_variant_index,
).withColumn(feature_name, f.round(f.col(feature_name), 3))
# expected max is 0.81
Expand Down Expand Up @@ -561,7 +561,7 @@ def test_common_neighbourhood_distance_feature_logic(
common_neighbourhood_distance_feature_logic(
self.sample_study_locus,
variant_index=self.sample_variant_index,
gene_index=self.sample_gene_index,
target_index=self.sample_target_index,
feature_name=feature_name,
distance_type=self.distance_type,
genomic_window=10,
Expand Down Expand Up @@ -653,7 +653,7 @@ def _setup(
),
_schema=VariantIndex.get_schema(),
)
self.sample_gene_index = GeneIndex(
self.sample_target_index = TargetIndex(
_df=spark.createDataFrame(
[
{
Expand All @@ -675,9 +675,9 @@ def _setup(
"biotype": "non_coding",
},
],
GeneIndex.get_schema(),
TargetIndex.get_schema(),
),
_schema=GeneIndex.get_schema(),
_schema=TargetIndex.get_schema(),
)


Expand Down Expand Up @@ -760,7 +760,7 @@ def test_common_vep_feature_logic(
def test_common_neighbourhood_vep_feature_logic(
self: TestCommonVepFeatureLogic,
spark: SparkSession,
sample_gene_index: GeneIndex,
sample_target_index: TargetIndex,
sample_variant_index: VariantIndex,
) -> None:
"""Test the logic of the function that extracts the maximum severity score for a gene given the maximum of the maximum scores for all protein coding genes in the vicinity."""
Expand All @@ -769,7 +769,7 @@ def test_common_neighbourhood_vep_feature_logic(
common_neighbourhood_vep_feature_logic(
self.sample_study_locus,
variant_index=sample_variant_index,
gene_index=sample_gene_index,
target_index=sample_target_index,
feature_name=feature_name,
)
.withColumn(feature_name, f.round(f.col(feature_name), 2))
Expand Down Expand Up @@ -859,7 +859,7 @@ def test_common_genecount_feature_logic(
"""Test the common logic of the gene count features."""
observed_df = common_genecount_feature_logic(
study_loci_to_annotate=self.sample_study_locus,
gene_index=self.sample_gene_index,
target_index=self.sample_target_index,
feature_name=feature_name,
genomic_window=500000,
protein_coding_only=protein_coding_only,
Expand Down Expand Up @@ -892,7 +892,7 @@ def _setup(self: TestCommonGeneCountFeatureLogic, spark: SparkSession) -> None:
),
_schema=StudyLocus.get_schema(),
)
self.sample_gene_index = GeneIndex(
self.sample_target_index = TargetIndex(
_df=spark.createDataFrame(
[
{
Expand All @@ -914,9 +914,9 @@ def _setup(self: TestCommonGeneCountFeatureLogic, spark: SparkSession) -> None:
"biotype": "non_coding",
},
],
GeneIndex.get_schema(),
TargetIndex.get_schema(),
),
_schema=GeneIndex.get_schema(),
_schema=TargetIndex.get_schema(),
)


Expand Down Expand Up @@ -944,7 +944,7 @@ def test_is_protein_coding_feature_logic(
observed_df = (
is_protein_coding_feature_logic(
study_loci_to_annotate=self.sample_study_locus,
gene_index=self.sample_gene_index,
target_index=self.sample_target_index,
feature_name="isProteinCoding500kb",
genomic_window=500000,
)
Expand Down Expand Up @@ -981,8 +981,8 @@ def _setup(self: TestCommonProteinCodingFeatureLogic, spark: SparkSession) -> No
_schema=StudyLocus.get_schema(),
)

# Sample gene index data with biotype
self.sample_gene_index = GeneIndex(
# Sample target index data with biotype
self.sample_target_index = TargetIndex(
_df=spark.createDataFrame(
[
{
Expand All @@ -1004,9 +1004,9 @@ def _setup(self: TestCommonProteinCodingFeatureLogic, spark: SparkSession) -> No
"biotype": "non_coding",
},
],
GeneIndex.get_schema(),
TargetIndex.get_schema(),
),
_schema=GeneIndex.get_schema(),
_schema=TargetIndex.get_schema(),
)


Expand Down
8 changes: 4 additions & 4 deletions tests/gentropy/dataset/test_l2g_feature_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
)

from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.target_index import TargetIndex
from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_study_locus(
colocalisation=self.sample_colocalisation,
study_index=self.sample_study_index,
study_locus=self.sample_study_locus,
gene_index=self.sample_gene_index,
target_index=self.sample_target_index,
)
fm = L2GFeatureMatrix.from_features_list(
self.sample_study_locus, features_list, loader
Expand Down Expand Up @@ -170,7 +170,7 @@ def _setup(self: TestFromFeaturesList, spark: SparkSession) -> None:
),
_schema=Colocalisation.get_schema(),
)
self.sample_gene_index = GeneIndex(
self.sample_target_index = TargetIndex(
_df=spark.createDataFrame(
[
("g1", "X", "protein_coding", 200),
Expand All @@ -183,7 +183,7 @@ def _setup(self: TestFromFeaturesList, spark: SparkSession) -> None:
"tss",
],
),
_schema=GeneIndex.get_schema(),
_schema=TargetIndex.get_schema(),
)


Expand Down
14 changes: 7 additions & 7 deletions tests/gentropy/dataset/test_study_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from pyspark.sql import functions as f

from gentropy.dataset.biosample_index import BiosampleIndex
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.target_index import TargetIndex


def test_study_index_creation(mock_study_index: StudyIndex) -> None:
Expand Down Expand Up @@ -188,9 +188,9 @@ def create_study_index(drop_column: str) -> StudyIndex:
self.study_index_no_gene = create_study_index("geneId")
self.study_index_no_biosample_id = create_study_index("biosampleFromSourceId")

self.gene_index = GeneIndex(
self.target_index = TargetIndex(
_df=spark.createDataFrame(self.GENE_DATA, self.GENE_COLUMNS),
_schema=GeneIndex.get_schema(),
_schema=TargetIndex.get_schema(),
)
self.biosample_index = BiosampleIndex(
_df=spark.createDataFrame(self.BIOSAMPLE_DATA, self.BIOSAMPLE_COLUMNS),
Expand All @@ -199,7 +199,7 @@ def create_study_index(drop_column: str) -> StudyIndex:

def test_gene_validation_type(self: TestQTLValidation) -> None:
"""Testing if the target validation runs and returns the expected type."""
validated = self.study_index.validate_target(self.gene_index)
validated = self.study_index.validate_target(self.target_index)
assert isinstance(validated, StudyIndex)

def test_biosample_validation_type(self: TestQTLValidation) -> None:
Expand All @@ -211,7 +211,7 @@ def test_biosample_validation_type(self: TestQTLValidation) -> None:
def test_qtl_validation_correctness(self: TestQTLValidation, test: str) -> None:
"""Testing if the QTL validation only flags the expected studies."""
if test == "gene":
validated = self.study_index.validate_target(self.gene_index).persist()
validated = self.study_index.validate_target(self.target_index).persist()
bad_study = "s2"
if test == "biosample":
validated = self.study_index.validate_biosample(
Expand Down Expand Up @@ -252,15 +252,15 @@ def test_qtl_validation_drop_relevant_column(
"""Testing what happens if an expected column is not present."""
if drop == "gene":
if test == "gene":
validated = self.study_index_no_gene.validate_target(self.gene_index)
validated = self.study_index_no_gene.validate_target(self.target_index)
if test == "biosample":
validated = self.study_index_no_gene.validate_biosample(
self.biosample_index
)
if drop == "biosample":
if test == "gene":
validated = self.study_index_no_biosample_id.validate_target(
self.gene_index
self.target_index
)
if test == "biosample":
validated = self.study_index_no_biosample_id.validate_biosample(
Expand Down
2 changes: 1 addition & 1 deletion tests/gentropy/dataset/test_summary_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def test_summary_statistics__creation(
mock_summary_statistics: SummaryStatistics,
) -> None:
"""Test gene index creation with mock gene index."""
"""Test summary statistics creation with mock summary statistics."""
assert isinstance(mock_summary_statistics, SummaryStatistics)


Expand Down
28 changes: 14 additions & 14 deletions tests/gentropy/dataset/test_target_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@

from pyspark.sql import DataFrame

from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.target_index import TargetIndex


def test_gene_index_creation(mock_gene_index: GeneIndex) -> None:
"""Test gene index creation with mock gene index."""
assert isinstance(mock_gene_index, GeneIndex)
def test_target_index_creation(mock_target_index: TargetIndex) -> None:
"""Test target index creation with mock target index."""
assert isinstance(mock_target_index, TargetIndex)


def test_gene_index_location_lut(mock_gene_index: GeneIndex) -> None:
"""Test gene index location lut."""
assert isinstance(mock_gene_index.locations_lut(), DataFrame)
def test_target_index_location_lut(mock_target_index: TargetIndex) -> None:
"""Test target index location lut."""
assert isinstance(mock_target_index.locations_lut(), DataFrame)


def test_gene_index_symbols_lut(mock_gene_index: GeneIndex) -> None:
"""Test gene index symbols lut."""
assert isinstance(mock_gene_index.symbols_lut(), DataFrame)
def test_target_index_symbols_lut(mock_target_index: TargetIndex) -> None:
"""Test target index symbols lut."""
assert isinstance(mock_target_index.symbols_lut(), DataFrame)


def test_gene_index_filter_by_biotypes(mock_gene_index: GeneIndex) -> None:
"""Test gene index filter by biotypes."""
def test_target_index_filter_by_biotypes(mock_target_index: TargetIndex) -> None:
"""Test target index filter by biotypes."""
assert isinstance(
mock_gene_index.filter_by_biotypes(
mock_target_index.filter_by_biotypes(
biotypes=["protein_coding", "3prime_overlapping_ncRNA", "antisense"]
),
GeneIndex,
TargetIndex,
)
Loading

0 comments on commit 19bb7c9

Please sign in to comment.