diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index 899343998..3e4e91c74 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -104,8 +104,12 @@ class Coloc: Coloc requires the availability of Bayes factors (BF) for each variant in the credible set (`logBF` column). + Attributes: + PSEUDOCOUNT (float): Pseudocount to avoid log(0). Defaults to 1e-10. """ + PSEUDOCOUNT: float = 1e-10 + @staticmethod def _get_posteriors(all_bfs: NDArray[np.float64]) -> DenseVector: """Calculate posterior probabilities for each hypothesis. @@ -137,6 +141,7 @@ def colocalise( Args: overlapping_signals (StudyLocusOverlap): overlapping peaks + priorc1 (float): Prior on variant being causal for trait 1. Defaults to 1e-4. priorc2 (float): Prior on variant being causal for trait 2. Defaults to 1e-4. priorc12 (float): Prior on variant being causal for traits 1 and 2. Defaults to 1e-5. @@ -188,12 +193,12 @@ def colocalise( .withColumn("lH2bf", f.log(f.col("priorc2")) + f.col("logsum2")) # h3 .withColumn("sumlogsum", f.col("logsum1") + f.col("logsum2")) - # exclude null H3/H4s: due to sumlogsum == logsum12 - .filter(f.col("sumlogsum") != f.col("logsum12")) .withColumn("max", f.greatest("sumlogsum", "logsum12")) .withColumn( "logdiff", - ( + f.when( + f.col("sumlogsum") == f.col("logsum12"), Coloc.PSEUDOCOUNT + ).otherwise( f.col("max") + f.log( f.exp(f.col("sumlogsum") - f.col("max")) diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index a4fe07646..f90d54b3f 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -41,6 +41,140 @@ def test_coloc_colocalise( assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 +def test_single_snp_coloc( + spark: SparkSession, + threshold: float = 1e-5, +) -> None: + """Test edge case of coloc where only one causal SNP is present in the StudyLocusOverlap.""" + test_overlap_df = spark.createDataFrame( + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp", + "left_logBF": 10.3, + "right_logBF": 10.5, + } + ] + ) + test_overlap = StudyLocusOverlap( + test_overlap_df.select( + "leftStudyLocusId", + "rightStudyLocusId", + "chromosome", + "tagVariantId", + f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"), + ), + StudyLocusOverlap.get_schema(), + ) + test_result = Coloc.colocalise(test_overlap) + + expected = spark.createDataFrame( + [ + { + "h0": 9.254841951638903e-5, + "h1": 2.7517068829182966e-4, + "h2": 3.3609423764447284e-4, + "h3": 9.254841952564387e-13, + "h4": 0.9992961866536217, + } + ] + ) + difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) + for col in difference.columns: + assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 + + +def test_single_snp_coloc_one_negative( + spark: SparkSession, + threshold: float = 1e-5, +) -> None: + """Test edge case of coloc where only one causal SNP is present (On one side!) in the StudyLocusOverlap.""" + test_overlap_df = spark.createDataFrame( + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp", + "left_logBF": 18.3, + "right_logBF": 0.01, + } + ] + ) + test_overlap = StudyLocusOverlap( + test_overlap_df.select( + "leftStudyLocusId", + "rightStudyLocusId", + "chromosome", + "tagVariantId", + f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"), + ), + StudyLocusOverlap.get_schema(), + ) + test_result = Coloc.colocalise(test_overlap) + test_result.df.show(1, False) + expected = spark.createDataFrame( + [ + { + "h0": 1.0246538505087709e-4, + "h1": 0.9081680002273896, + "h2": 1.0349517929098209e-8, + "h3": 1.0246538506112363e-12, + "h4": 0.09172952403701702, + } + ] + ) + difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) + for col in difference.columns: + assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 + + +def test_single_snp_coloc_both_negative( + spark: SparkSession, + threshold: float = 1e-5, +) -> None: + """Test edge case of coloc where only one non-causal SNP overlaps in the StudyLocusOverlap.""" + test_overlap_df = spark.createDataFrame( + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp", + "left_logBF": 0.03, + "right_logBF": 0.01, + } + ] + ) + test_overlap = StudyLocusOverlap( + test_overlap_df.select( + "leftStudyLocusId", + "rightStudyLocusId", + "chromosome", + "tagVariantId", + f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"), + ), + StudyLocusOverlap.get_schema(), + ) + test_result = Coloc.colocalise(test_overlap) + expected = spark.createDataFrame( + [ + { + "h0": 0.9997855774090624, + "h1": 1.0302335812225042e-4, + "h2": 1.0098335895103664e-4, + "h3": 9.9978557750904e-9, + "h4": 1.0405876008495098e-5, + } + ] + ) + difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) + for col in difference.columns: + assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 + + def test_ecaviar(mock_study_locus_overlap: StudyLocusOverlap) -> None: """Test eCAVIAR.""" assert isinstance(ECaviar.colocalise(mock_study_locus_overlap), Colocalisation)