Skip to content

Commit 79e05a3

Browse files
committed
fix clipping in IPW
1 parent f814f07 commit 79e05a3

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/med_bench/estimation/mediation_ipw.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,18 @@ def _pointwise_estimate(self, t, m, x, y):
5959

6060
t, m, x, y = self._resize(t, m, x, y)
6161
p_x = self._estimate_treatment_propensity_x(x)
62+
p_x = np.clip(p_x, self._clip, 1 - self._clip)
6263

6364
if self._prop_ratio == "treatment":
6465
p_xm = self._estimate_treatment_propensity_xm(m, x)
66+
p_xm = np.clip(p_xm, self._clip, 1 - self._clip)
6567
prop_ratio_t1_m0 = (1 - p_xm) / ((1 - p_x) * p_xm)
6668
prop_ratio_t0_m1 = p_xm / ((1 - p_xm) * p_x)
6769

6870
elif self._prop_ratio == "mediator":
6971
f_m0x, f_m1x = self._estimate_mediator_probability(x, m)
72+
f_m0x = np.clip(f_m0x, self._clip, None)
73+
f_m1x = np.clip(f_m1x, self._clip, None)
7074
prop_ratio_t1_m0 = f_m0x / (p_x * f_m1x)
7175
prop_ratio_t0_m1 = f_m1x / ((1 - p_x) * f_m0x)
7276

@@ -79,8 +83,6 @@ def _pointwise_estimate(self, t, m, x, y):
7983
prop_ratio_t0_m1[ind],
8084
)
8185

82-
p_x = np.clip(p_x, self._clip, 1 - self._clip)
83-
8486
# importance weighting
8587
y1m1 = (y * t / p_x) / np.mean(t / p_x)
8688
y1m0 = (y * t * prop_ratio_t1_m0) / np.mean(t * prop_ratio_t1_m0)

0 commit comments

Comments
 (0)