feat(SusieFineMapperStep): add new function with boundaries (#645)
* feat(SusieFineMapperStep): add new function that takes boundaries as input

* fix: typo in function
addramir authored Jun 17, 2024
1 parent 7625a79 commit 79a6cb5
Showing 2 changed files with 240 additions and 53 deletions.
42 changes: 42 additions & 0 deletions src/gentropy/datasource/gnomad/ld.py
@@ -510,3 +510,45 @@ def get_numpy_matrix(
)

return (half_matrix + half_matrix.T) - np.diag(np.diag(half_matrix))
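
Side note on the context line above (not part of this commit): gnomAD LD matrices are stored as a single triangle, so the full symmetric matrix is recovered by adding the transpose and subtracting the diagonal once, since it would otherwise be counted twice. A minimal NumPy sketch of that identity, with hypothetical values:

import numpy as np

# Lower triangle of a 2x2 LD matrix, diagonal included.
half_matrix = np.array([[1.0, 0.0],
                        [0.5, 1.0]])
full = (half_matrix + half_matrix.T) - np.diag(np.diag(half_matrix))
# full == [[1.0, 0.5], [0.5, 1.0]]: symmetric, diagonal counted once.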

def get_locus_index_boundaries(
self: GnomADLDMatrix,
study_locus_row: Row,
major_population: str = "nfe",
) -> DataFrame:
"""Extract hail matrix index from StudyLocus rows.
Args:
study_locus_row (Row): Study-locus row
major_population (str): Major population to extract from gnomad matrix, default is "nfe"
Returns:
DataFrame: Returns the index of the gnomad matrix for the locus
"""
chromosome = str("chr" + study_locus_row["chromosome"])
start = int(study_locus_row["locusStart"])
end = int(study_locus_row["locusEnd"])

liftover_ht = hl.read_table(self.liftover_ht_path)
liftover_ht = (
liftover_ht.filter(
(liftover_ht.locus.contig == chromosome)
& (liftover_ht.locus.position >= start)
& (liftover_ht.locus.position <= end)
)
.key_by()
.select("locus", "alleles", "original_locus")
.key_by("original_locus", "alleles")
.naive_coalesce(20)
)

hail_index = hl.read_table(
self.ld_index_raw_template.format(POP=major_population)
)

joined_index = (
liftover_ht.join(hail_index, how="inner").order_by("idx").to_spark()
)

return joined_index
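
A minimal usage sketch for the new method (not from the commit), assuming Hail has been initialised against the gnomAD LD release and that study_locus_row is a Spark Row carrying chromosome, locusStart and locusEnd:

from gentropy.datasource.gnomad.ld import GnomADLDMatrix

# Hypothetically taken from a StudyLocus dataset, e.g.:
# study_locus_row = study_locus.df.first()
index_df = GnomADLDMatrix().get_locus_index_boundaries(
    study_locus_row=study_locus_row,
    major_population="nfe",  # default population in the gnomAD matrix
)
# Columns include the flattened locus fields (e.g. `locus.position`), alleles and idx.
index_df.show(5)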
251 changes: 198 additions & 53 deletions src/gentropy/susie_finemapper.py
@@ -55,7 +55,6 @@ def __init__(
purity_min_r2_threshold: float = 0,
cs_lbf_thr: float = 2,
sum_pips: float = 0.99,
logging: bool = False,
susie_est_tausq: bool = False,
run_carma: bool = False,
run_sumstat_imputation: bool = False,
@@ -80,7 +79,6 @@
purity_min_r2_threshold (float): threshold for purity min r2 qc metrics for filtering credible sets, default is 0.25
cs_lbf_thr (float): credible set logBF threshold for filtering credible sets, default is 2
sum_pips (float): the expected sum of posterior probabilities in the locus, default is 0.99 (99% credible set)
logging (bool): enable logging, default is False; runs a different FM wrapper
susie_est_tausq (bool): estimate tau squared, default is False
run_carma (bool): run CARMA, default is False
run_sumstat_imputation (bool): run summary statistics imputation, default is False
@@ -99,59 +97,37 @@
)
study_index = StudyIndex.from_parquet(session, study_index_path)
# Run fine-mapping
if logging:
result_logging = (
self.susie_finemapper_one_studylocus_row_v3_dev_ss_gathered(
session=session,
study_locus_row=study_locus,
study_index=study_index,
radius=locus_radius,
max_causal_snps=max_causal_snps,
primary_signal_pval_threshold=primary_signal_pval_threshold,
secondary_signal_pval_threshold=secondary_signal_pval_threshold,
purity_mean_r2_threshold=purity_mean_r2_threshold,
purity_min_r2_threshold=purity_min_r2_threshold,
cs_lbf_thr=cs_lbf_thr,
sum_pips=sum_pips,
susie_est_tausq=susie_est_tausq,
run_carma=run_carma,
run_sumstat_imputation=run_sumstat_imputation,
carma_time_limit=carma_time_limit,
imputed_r2_threshold=imputed_r2_threshold,
ld_score_threshold=ld_score_threshold,
)
)

if result_logging is not None:
# Write result
result_logging["study_locus"].df.write.mode(session.write_mode).parquet(
output_path + "/" + study_locus_to_finemap
)
# Write log
result_logging["log"].to_parquet(
output_path_log + "/" + study_locus_to_finemap + ".parquet",
engine="pyarrow",
index=False,
)
else:
result = self.susie_finemapper_ss_gathered(
session=session,
study_locus_row=study_locus,
study_index=study_index,
radius=locus_radius,
max_causal_snps=max_causal_snps,
primary_signal_pval_threshold=primary_signal_pval_threshold,
secondary_signal_pval_threshold=secondary_signal_pval_threshold,
purity_mean_r2_threshold=purity_mean_r2_threshold,
purity_min_r2_threshold=purity_min_r2_threshold,
cs_lbf_thr=cs_lbf_thr,
sum_pips=sum_pips,
)
result_logging = self.susie_finemapper_one_sl_row_v4_ss_gathered_boundaries(
session=session,
study_locus_row=study_locus,
study_index=study_index,
max_causal_snps=max_causal_snps,
primary_signal_pval_threshold=primary_signal_pval_threshold,
secondary_signal_pval_threshold=secondary_signal_pval_threshold,
purity_mean_r2_threshold=purity_mean_r2_threshold,
purity_min_r2_threshold=purity_min_r2_threshold,
cs_lbf_thr=cs_lbf_thr,
sum_pips=sum_pips,
susie_est_tausq=susie_est_tausq,
run_carma=run_carma,
run_sumstat_imputation=run_sumstat_imputation,
carma_time_limit=carma_time_limit,
imputed_r2_threshold=imputed_r2_threshold,
ld_score_threshold=ld_score_threshold,
)

if result_logging is not None:
# Write result
if result is not None:
result.df.write.mode(session.write_mode).parquet(
output_path + "/" + study_locus_to_finemap
)
result_logging["study_locus"].df.write.mode(session.write_mode).parquet(
output_path + "/" + study_locus_to_finemap
)
# Write log
result_logging["log"].to_parquet(
output_path_log + "/" + study_locus_to_finemap + ".parquet",
engine="pyarrow",
index=False,
)

@staticmethod
def susie_finemapper_one_studylocus_row(
@@ -1211,3 +1187,172 @@ def credible_set_qc(
)

return cred_sets

@staticmethod
def susie_finemapper_one_sl_row_v4_ss_gathered_boundaries(
session: Session,
study_locus_row: Row,
study_index: StudyIndex,
max_causal_snps: int = 10,
susie_est_tausq: bool = False,
run_carma: bool = False,
run_sumstat_imputation: bool = False,
carma_time_limit: int = 600,
imputed_r2_threshold: float = 0.9,
ld_score_threshold: float = 5,
sum_pips: float = 0.99,
primary_signal_pval_threshold: float = 5e-8,
secondary_signal_pval_threshold: float = 1e-7,
purity_mean_r2_threshold: float = 0,
purity_min_r2_threshold: float = 0.25,
cs_lbf_thr: float = 2,
) -> dict[str, Any] | None:
"""Susie fine-mapper function that uses study-locus row with collected locus, chromosome and position as inputs.
Args:
session (Session): Spark session
study_locus_row (Row): StudyLocus row with collected locus
study_index (StudyIndex): StudyIndex object
max_causal_snps (int): maximum number of causal variants
susie_est_tausq (bool): estimate tau squared, default is False
run_carma (bool): run CARMA, default is False
run_sumstat_imputation (bool): run summary statistics imputation, default is False
carma_time_limit (int): CARMA time limit, default is 600 seconds
imputed_r2_threshold (float): imputed R2 threshold, default is 0.8
ld_score_threshold (float): LD score threshold ofr imputation, default is 4
sum_pips (float): the expected sum of posterior probabilities in the locus, default is 0.99 (99% credible set)
primary_signal_pval_threshold (float): p-value threshold for the lead variant from the primary signal (credibleSetIndex==1)
secondary_signal_pval_threshold (float): p-value threshold for the lead variant from the secondary signals
purity_mean_r2_threshold (float): thrshold for purity mean r2 qc metrics for filtering credible sets
purity_min_r2_threshold (float): thrshold for purity min r2 qc metrics for filtering credible sets
cs_lbf_thr (float): credible set logBF threshold for filtering credible sets, default is 2
Returns:
dict[str, Any] | None: dictionary with study locus, number of GWAS variants, number of LD variants, number of variants after merge, number of outliers, number of imputed variants, number of variants to fine-map, or None
"""
# PLEASE DO NOT REMOVE THIS LINE: pandas 2.x removed DataFrame.iteritems,
# but the Spark <-> pandas conversion path still calls it, so alias it back.
pd.DataFrame.iteritems = pd.DataFrame.items

chromosome = study_locus_row["chromosome"]
studyId = study_locus_row["studyId"]
locusStart = study_locus_row["locusStart"]
locusEnd = study_locus_row["locusEnd"]

study_index_df = study_index._df
study_index_df = study_index_df.filter(f.col("studyId") == studyId)
major_population = study_index_df.select(
"studyId",
f.array_max(f.col("ldPopulationStructure"))
.getItem("ldPopulation")
.alias("majorPopulation"),
).collect()[0]["majorPopulation"]

region = chromosome + ":" + str(int(locusStart)) + "-" + str(int(locusEnd))

schema = StudyLocus.get_schema()
gwas_df = session.spark.createDataFrame([study_locus_row], schema=schema)
exploded_df = gwas_df.select(f.explode("locus").alias("locus"))

result_df = exploded_df.select(
"locus.variantId", "locus.beta", "locus.standardError"
)
gwas_df = (
result_df.withColumn("z", f.col("beta") / f.col("standardError"))
.withColumn(
"chromosome", f.split(f.col("variantId"), "_")[0].cast("string")
)
.withColumn("position", f.split(f.col("variantId"), "_")[1].cast("int"))
.filter(f.col("chromosome") == chromosome)
.filter(f.col("position") >= int(locusStart))
.filter(f.col("position") <= int(locusEnd))
.filter(f.col("z").isNotNull())
)

# Remove ALL duplicated variants from GWAS DataFrame - we don't know which is correct
variant_counts = gwas_df.groupBy("variantId").count()
unique_variants = variant_counts.filter(f.col("count") == 1)
gwas_df = gwas_df.join(unique_variants, on="variantId", how="left_semi")

ld_index = (
GnomADLDMatrix()
.get_locus_index_boundaries(
study_locus_row=study_locus_row,
major_population=major_population,
)
.withColumn(
"variantId",
f.concat(
f.lit(chromosome),
f.lit("_"),
f.col("`locus.position`"),
f.lit("_"),
f.col("alleles").getItem(0),
f.lit("_"),
f.col("alleles").getItem(1),
).cast("string"),
)
)
# Remove ALL duplicated variants from ld_index DataFrame - we don't know which is correct
variant_counts = ld_index.groupBy("variantId").count()
unique_variants = variant_counts.filter(f.col("count") == 1)
ld_index = ld_index.join(unique_variants, on="variantId", how="left_semi").sort(
"idx"
)

if not run_sumstat_imputation:
# Filtering out the variants that are not in the LD matrix, we don't need them
gwas_index = gwas_df.join(
ld_index.select("variantId", "alleles", "idx"), on="variantId"
).sort("idx")
gwas_df = gwas_index.select(
    "variantId",
    "z",
    "chromosome",
    "position",
    "beta",
    "standardError",
)
gwas_index = gwas_index.drop(
    "z", "chromosome", "position", "beta", "standardError"
)
if gwas_index.rdd.isEmpty():
logging.warning("No overlapping variants in the LD Index")
return None
gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
gwas_index, gnomad_ancestry=major_population
)
else:
gwas_index = gwas_df.join(
ld_index.select("variantId", "alleles", "idx"), on="variantId"
).sort("idx")
if gwas_index.rdd.isEmpty():
logging.warning("No overlapping variants in the LD Index")
return None
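# For imputation, keep the full LD index rather than just the GWAS overlap,
# so z-scores for variants missing from the GWAS can be imputed downstream.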
gwas_index = ld_index
gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
gwas_index, gnomad_ancestry=major_population
)

out = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes(
GWAS_df=gwas_df,
ld_index=gwas_index,
gnomad_ld=gnomad_ld,
L=max_causal_snps,
session=session,
studyId=studyId,
region=region,
susie_est_tausq=susie_est_tausq,
run_carma=run_carma,
run_sumstat_imputation=run_sumstat_imputation,
carma_time_limit=carma_time_limit,
imputed_r2_threshold=imputed_r2_threshold,
ld_score_threshold=ld_score_threshold,
sum_pips=sum_pips,
primary_signal_pval_threshold=primary_signal_pval_threshold,
secondary_signal_pval_threshold=secondary_signal_pval_threshold,
purity_mean_r2_threshold=purity_mean_r2_threshold,
purity_min_r2_threshold=purity_min_r2_threshold,
cs_lbf_thr=cs_lbf_thr,
)

return out
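
A hypothetical invocation of the new entry point (a sketch, not from the commit; it assumes a configured gentropy Session, a StudyIndex, and a study-locus Row with locus boundaries):

result = SusieFineMapperStep.susie_finemapper_one_sl_row_v4_ss_gathered_boundaries(
    session=session,
    study_locus_row=study_locus_row,
    study_index=study_index,
    max_causal_snps=10,
)
if result is not None:
    result["study_locus"].df.show()  # fine-mapped credible sets as a StudyLocus
    print(result["log"])  # per-locus diagnostics (pandas DataFrame written as the log)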
