Skip to content

Commit

Permalink
feat: implement posterior prob filter for COLOC at small overlaps N<10 (
Browse files Browse the repository at this point in the history
  • Loading branch information
xyg123 authored Feb 6, 2025
1 parent 8622b5e commit 8d018de
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 31 deletions.
58 changes: 56 additions & 2 deletions src/gentropy/method/colocalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,15 @@ class Coloc(ColocalisationMethodInterface):
Attributes:
PSEUDOCOUNT (float): Pseudocount to avoid log(0). Defaults to 1e-10.
OVERLAP_SIZE_CUTOFF (int): Minimum number of overlapping variants bfore filtering. Defaults to 5.
POSTERIOR_CUTOFF (float): Minimum overlapping Posterior probability cutoff for small overlaps. Defaults to 0.5.
"""

METHOD_NAME: str = "COLOC"
METHOD_METRIC: str = "h4"
PSEUDOCOUNT: float = 1e-10
OVERLAP_SIZE_CUTOFF: int = 5
POSTERIOR_CUTOFF: float = 0.5

@staticmethod
def _get_posteriors(all_bfs: NDArray[np.float64]) -> DenseVector:
Expand Down Expand Up @@ -277,7 +281,15 @@ def colocalise(
)
.select("*", "statistics.*")
# Before summing log_BF columns nulls need to be filled with 0:
.fillna(0, subset=["left_logBF", "right_logBF"])
.fillna(
0,
subset=[
"left_logBF",
"right_logBF",
"left_posteriorProbability",
"right_posteriorProbability",
],
)
# Sum of log_BFs for each pair of signals
.withColumn(
"sum_log_bf",
Expand Down Expand Up @@ -305,9 +317,18 @@ def colocalise(
fml.array_to_vector(f.collect_list(f.col("right_logBF"))).alias(
"right_logBF"
),
fml.array_to_vector(
f.collect_list(f.col("left_posteriorProbability"))
).alias("left_posteriorProbability"),
fml.array_to_vector(
f.collect_list(f.col("right_posteriorProbability"))
).alias("right_posteriorProbability"),
fml.array_to_vector(f.collect_list(f.col("sum_log_bf"))).alias(
"sum_log_bf"
),
f.collect_list(f.col("tagVariantSource")).alias(
"tagVariantSourceList"
),
)
.withColumn("logsum1", logsum(f.col("left_logBF")))
.withColumn("logsum2", logsum(f.col("right_logBF")))
Expand All @@ -327,10 +348,39 @@ def colocalise(
# h3
.withColumn("sumlogsum", f.col("logsum1") + f.col("logsum2"))
.withColumn("max", f.greatest("sumlogsum", "logsum12"))
.withColumn(
"anySnpBothSidesHigh",
f.aggregate(
f.transform(
f.arrays_zip(
fml.vector_to_array(f.col("left_posteriorProbability")),
fml.vector_to_array(
f.col("right_posteriorProbability")
),
f.col("tagVariantSourceList"),
),
# row["0"] = left PP, row["1"] = right PP, row["tagVariantSourceList"]
lambda row: f.when(
(row["tagVariantSourceList"] == "both")
& (row["0"] > Coloc.POSTERIOR_CUTOFF)
& (row["1"] > Coloc.POSTERIOR_CUTOFF),
1.0,
).otherwise(0.0),
),
f.lit(0.0),
lambda acc, x: acc + x,
)
> 0, # True if sum of these 1.0's > 0
)
.filter(
(f.col("numberColocalisingVariants") > Coloc.OVERLAP_SIZE_CUTOFF)
| (f.col("anySnpBothSidesHigh"))
)
.withColumn(
"logdiff",
f.when(
f.col("sumlogsum") == f.col("logsum12"), Coloc.PSEUDOCOUNT
(f.col("sumlogsum") == f.col("logsum12")),
Coloc.PSEUDOCOUNT,
).otherwise(
f.col("max")
+ f.log(
Expand Down Expand Up @@ -382,6 +432,10 @@ def colocalise(
"lH2bf",
"lH3bf",
"lH4bf",
"left_posteriorProbability",
"right_posteriorProbability",
"tagVariantSourceList",
"anySnpBothSidesHigh",
)
.withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME))
.join(
Expand Down
Loading

0 comments on commit 8d018de

Please sign in to comment.