Skip to content

Commit 4f39f10

Browse files
committed
added clipping mr and tmle, fixed bug on mediator density and tests
1 parent 79e05a3 commit 4f39f10

File tree

7 files changed

+333
-98
lines changed

7 files changed

+333
-98
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Judith Abécassis, Julie Josse and Bertrand Thirion (2022). **Causal mediation a
1212
med_bench can be installed by executing
1313
```
1414
python setup.py install
15+
pip install -e .
1516
```
1617

1718
Or the package can be directly installed from the GitHub repository using

src/med_bench/estimation/base.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class Estimator:
1414

1515
__metaclass__ = ABCMeta
1616

17-
def __init__(self, verbose: bool = True, mediator_cardinality_threshold: int=10):
17+
def __init__(self, verbose: bool = True, mediator_cardinality_threshold: int = 10):
1818
"""Initializes Estimator base class
1919
2020
Parameters
@@ -25,11 +25,11 @@ def __init__(self, verbose: bool = True, mediator_cardinality_threshold: int=10)
2525
mediator_cardinality_threshold: int
2626
default 10
2727
maximal number of categories in a mediator to treat it as discrete or continuous
28-
if the mediator is 1-dimensional and
28+
if the mediator is 1-dimensional and
2929
if the number of distinct values in the mediator is lower than
3030
mediator_cardinality_threshold, the mediator is going to be considered
3131
discrete when estimating the mediator probability function, otherwise
32-
the mediator is considered continuous,
32+
the mediator is considered continuous,
3333
"""
3434
self._verbose = verbose
3535
self._fitted = False
@@ -39,7 +39,6 @@ def __init__(self, verbose: bool = True, mediator_cardinality_threshold: int=10)
3939
def verbose(self):
4040
return self._verbose
4141

42-
4342
@abstractmethod
4443
def fit(self, t, m, x, y):
4544
"""Fits nuisance parameters to data
@@ -231,7 +230,9 @@ def _resize(self, t, m, x, y):
231230

232231
def _fit_mediator_discretizer(self, m):
233232
"""Fits the discretization procedure of mediators"""
234-
if (is_array_integer(m)) and (len(np.unique(m)) <= self.mediator_cardinality_threshold):
233+
if (is_array_integer(m)) and (
234+
len(np.unique(m)) <= self.mediator_cardinality_threshold
235+
):
235236
self.discretizer = LabelEncoder()
236237
self.discretizer.fit(m.ravel())
237238
self.mediator_bins = self.discretizer.classes_
@@ -362,12 +363,15 @@ def _estimate_discrete_mediator_probability(self, x, m_label):
362363
t0 = np.zeros((n, 1))
363364
t1 = np.ones((n, 1))
364365

365-
366366
t0_x = np.hstack([t0.reshape(-1, 1), x])
367367
t1_x = np.hstack([t1.reshape(-1, 1), x])
368368

369-
f_m0x = self._classifier_m.predict_proba(t0_x)[np.arange(m_label.shape[0]), m_label]
370-
f_m1x = self._classifier_m.predict_proba(t1_x)[np.arange(m_label.shape[0]), m_label]
369+
f_m0x = self._classifier_m.predict_proba(t0_x)[
370+
np.arange(m_label.shape[0]), m_label
371+
]
372+
f_m1x = self._classifier_m.predict_proba(t1_x)[
373+
np.arange(m_label.shape[0]), m_label
374+
]
371375

372376
return f_m0x, f_m1x
373377

@@ -392,8 +396,8 @@ def _estimate_mediator_density(self, x, m):
392396
t0_x = np.hstack([t0.reshape(-1, 1), x])
393397
t1_x = np.hstack([t1.reshape(-1, 1), x])
394398

395-
f_m0x = self._density_m.pdf(t0_x, m)
396-
f_m1x = self._density_m.pdf(t1_x, m)
399+
f_m0x = self._density_m.pdf(t0_x, m.squeeze())
400+
f_m1x = self._density_m.pdf(t1_x, m.squeeze())
397401

398402
return f_m0x, f_m1x
399403

src/med_bench/estimation/mediation_mr.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ def __init__(
1313
self,
1414
regressor,
1515
classifier,
16+
clip: float,
17+
trim: float,
1618
prop_ratio="treatment",
1719
integration="implicit",
1820
normalized=True,
@@ -26,6 +28,10 @@ def __init__(
2628
Regressor used for mu estimation, can be any object with a fit and predict method
2729
classifier
2830
Classifier used for propensity estimation, can be any object with a fit and predict_proba method
31+
clip : float
32+
Clipping value for propensity scores
33+
trim : float
34+
Trimming value for propensity scores
2935
prop_ratio : str
3036
prop_ratio to use for estimation, can be either 'mediator' or 'treatment'
3137
integration : str
@@ -47,6 +53,8 @@ def __init__(
4753
self.regressor = regressor
4854
self.classifier = classifier
4955

56+
self._clip = clip
57+
self._trim = trim
5058
assert prop_ratio in ["mediator", "treatment"]
5159
assert integration in ["implicit", "explicit"]
5260
self._prop_ratio = prop_ratio
@@ -96,15 +104,19 @@ def _pointwise_estimate(self, t, m, x, y):
96104
# Format checking
97105
t, m, x, y = self._resize(t, m, x, y)
98106

107+
p_x = self._estimate_treatment_propensity_x(x)
108+
p_x = np.clip(p_x, self._clip, 1 - self._clip)
109+
99110
if self._prop_ratio == "mediator":
100111
f_m0x, f_m1x = self._estimate_mediator_probability(x, m)
101-
p_x = self._estimate_treatment_propensity_x(x)
112+
f_m0x = np.clip(f_m0x, self._clip, None)
113+
f_m1x = np.clip(f_m1x, self._clip, None)
102114
prop_ratio_t1_m0 = f_m0x / (p_x * f_m1x)
103115
prop_ratio_t0_m1 = f_m1x / ((1 - p_x) * f_m0x)
104116

105117
elif self._prop_ratio == "treatment":
106-
p_x = self._estimate_treatment_propensity_x(x)
107118
p_xm = self._estimate_treatment_propensity_xm(m, x)
119+
p_xm = np.clip(p_xm, self._clip, 1 - self._clip)
108120
prop_ratio_t1_m0 = (1 - p_xm) / ((1 - p_x) * p_xm)
109121
prop_ratio_t0_m1 = p_xm / ((1 - p_xm) * p_x)
110122

@@ -133,13 +145,28 @@ def _pointwise_estimate(self, t, m, x, y):
133145
y1m0 = y1m0.mean()
134146
y0m1 = y0m1.mean()
135147

136-
mu_0mx, mu_1mx = self._estimate_conditional_mean_outcome(x, m_discrete_value)
148+
mu_0mx, mu_1mx = self._estimate_conditional_mean_outcome(
149+
x, m_discrete_value
150+
)
137151

138152
elif self._integration == "implicit":
139153
mu_0mx, mu_1mx, y0m0, y0m1, y1m0, y1m1 = (
140154
self._estimate_cross_conditional_mean_outcome(m, x)
141155
)
142156

157+
ind = (p_x > self._trim) & (p_x < (1 - self._trim))
158+
y, t, p_x, prop_ratio_t1_m0, prop_ratio_t0_m1 = (
159+
y[ind],
160+
t[ind],
161+
p_x[ind],
162+
prop_ratio_t1_m0[ind],
163+
prop_ratio_t0_m1[ind],
164+
)
165+
mu_0mx, mu_1mx = (
166+
mu_0mx[ind],
167+
mu_1mx[ind],
168+
)
169+
143170
# score computing
144171
if self._normalized:
145172
sum_score_m1 = np.mean(t / p_x)

src/med_bench/estimation/mediation_tmle.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
class TMLE(Estimator):
1313
"""Implementation of targeted maximum likelihood estimation method class"""
1414

15-
def __init__(self, regressor, classifier, prop_ratio, **kwargs):
15+
def __init__(
16+
self, regressor, classifier, prop_ratio, clip: float, trim: float, **kwargs
17+
):
1618
"""_summary_
1719
1820
Parameters
@@ -23,6 +25,10 @@ def __init__(self, regressor, classifier, prop_ratio, **kwargs):
2325
Classifier used for propensity estimation, can be any object with a fit and predict_proba method
2426
prop_ratio : str
2527
prop_ratio to use for estimation, can be either 'mediator' or 'treatment'
28+
clip : float
29+
Clipping value for propensity scores
30+
trim : float
31+
Trimming value for propensity scores
2632
"""
2733
super().__init__(**kwargs)
2834

@@ -37,6 +43,8 @@ def __init__(self, regressor, classifier, prop_ratio, **kwargs):
3743
self.regressor = regressor
3844
self.classifier = classifier
3945

46+
self._clip = clip
47+
self._trim = trim
4048
assert prop_ratio in ["mediator", "treatment"]
4149
self._prop_ratio = prop_ratio
4250

@@ -65,7 +73,7 @@ def fit(self, t, m, x, y):
6573
print("Nuisance models fitted")
6674

6775
return self
68-
76+
6977
def _one_step_correction_direct(self, t, m, x, y):
7078
"""Implements the one step correction for the estimation of the natural
7179
direct effect with the prop_ratio of mediator densities or treatment
@@ -77,17 +85,31 @@ def _one_step_correction_direct(self, t, m, x, y):
7785
t0 = np.zeros((n))
7886
t1 = np.ones((n))
7987

88+
p_x = self._estimate_treatment_propensity_x(x)
89+
p_x = np.clip(p_x, self._clip, 1 - self._clip)
90+
8091
# estimate mediator densities
8192
if self._prop_ratio == "mediator":
8293
f_m0x, f_m1x = self._estimate_mediator_probability(x, m)
83-
p_x = self._estimate_treatment_propensity_x(x)
94+
f_m0x = np.clip(f_m0x, self._clip, None)
95+
f_m1x = np.clip(f_m1x, self._clip, None)
8496
prop_ratio = f_m0x / (p_x * f_m1x)
8597

8698
elif self._prop_ratio == "treatment":
87-
p_x = self._estimate_treatment_propensity_x(x)
8899
p_xm = self._estimate_treatment_propensity_xm(m, x)
100+
p_xm = np.clip(p_xm, self._clip, 1 - self._clip)
89101
prop_ratio = (1 - p_xm) / ((1 - p_x) * p_xm)
90102

103+
ind = (p_x > self._trim) & (p_x < (1 - self._trim))
104+
y, t, m, x, p_x, prop_ratio = (
105+
y[ind],
106+
t[ind],
107+
m[ind],
108+
x[ind],
109+
p_x[ind],
110+
prop_ratio[ind],
111+
)
112+
91113
# estimation of corrective features for the conditional mean outcome
92114
h_corrector = t * prop_ratio - (1 - t) / (1 - p_x)
93115

@@ -151,17 +173,30 @@ def _one_step_correction_indirect(self, t, m, x, y):
151173
t0 = np.zeros((n))
152174
t1 = np.ones((n))
153175

176+
p_x = self._estimate_treatment_propensity_x(x)
177+
p_x = np.clip(p_x, self._clip, 1 - self._clip)
178+
154179
# estimate mediator densities
155180
if self._prop_ratio == "mediator":
156181
f_m0x, f_m1x = self._estimate_mediator_probability(x, m)
157-
p_x = self._estimate_treatment_propensity_x(x)
182+
f_m0x = np.clip(f_m0x, self._clip, None)
183+
f_m1x = np.clip(f_m1x, self._clip, None)
158184
prop_ratio = f_m0x / (p_x * f_m1x)
159185

160186
elif self._prop_ratio == "treatment":
161-
p_x = self._estimate_treatment_propensity_x(x)
162187
p_xm = self._estimate_treatment_propensity_xm(m, x)
188+
p_xm = np.clip(p_xm, self._clip, 1 - self._clip)
163189
prop_ratio = (1 - p_xm) / ((1 - p_x) * p_xm)
164190

191+
ind = (p_x > self._trim) & (p_x < (1 - self._trim))
192+
y, t, m, x, p_x, prop_ratio = (
193+
y[ind],
194+
t[ind],
195+
m[ind],
196+
x[ind],
197+
p_x[ind],
198+
prop_ratio[ind],
199+
)
165200
# estimation of corrective features for the conditional mean outcome
166201
h_corrector = t / p_x - t * prop_ratio
167202

@@ -220,7 +255,6 @@ def _one_step_correction_indirect(self, t, m, x, y):
220255

221256
return delta_1
222257

223-
224258
@fitted
225259
def estimate(self, t, m, x, y):
226260
"""Estimates causal effect on data"""

0 commit comments

Comments
 (0)