Skip to content

Commit f814f07

Browse files
committed
fix issues in mediator discretization
1 parent 9a7f8b5 commit f814f07

File tree

5 files changed

+126
-92
lines changed

5 files changed

+126
-92
lines changed

src/med_bench/estimation/base.py

+42-20
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,34 @@
66
from med_bench.utils.decorators import fitted
77
from med_bench.utils.density import GaussianDensityEstimation
88
from sklearn.cluster import KMeans
9+
from sklearn.preprocessing import LabelEncoder
910

1011

1112
class Estimator:
1213
"""General abstract class for causal mediation Estimator"""
1314

1415
__metaclass__ = ABCMeta
1516

16-
def __init__(self, verbose: bool = True):
17+
def __init__(self, verbose: bool = True, mediator_cardinality_threshold: int=10):
1718
"""Initializes Estimator base class
1819
1920
Parameters
2021
----------
2122
verbose : bool
2223
will print some logs if True
24+
25+
mediator_cardinality_threshold: int
26+
default 10
27+
maximal number of categories in a mediator to treat it as discrete or continuous
28+
if the mediator is 1-dimensional and
29+
if the number of distinct values in the mediator is lower than
30+
mediator_cardinality_threshold, the mediator is going to be considered
31+
discrete when estimating the mediator probability function, otherwise
32+
the mediator is considered continuous,
2333
"""
2434
self._verbose = verbose
2535
self._fitted = False
26-
self.discretizer = KMeans(n_clusters=10, random_state=42, n_init="auto")
27-
self.mediator_bins = [0, 1]
36+
self.mediator_cardinality_threshold = mediator_cardinality_threshold
2837

2938
@property
3039
def verbose(self):
@@ -222,8 +231,16 @@ def _resize(self, t, m, x, y):
222231

223232
def _fit_mediator_discretizer(self, m):
224233
"""Fits the discretization procedure of mediators"""
225-
self.discretizer.fit(m)
226-
self.mediator_bins = self.discretize.cluster_centers_
234+
if (is_array_integer(m)) and (len(np.unique(m)) <= self.mediator_cardinality_threshold):
235+
self.discretizer = LabelEncoder()
236+
self.discretizer.fit(m.ravel())
237+
self.mediator_bins = self.discretizer.classes_
238+
self._mediator_considered_discrete = True
239+
else:
240+
self.discretizer = KMeans(n_clusters=10, random_state=42, n_init="auto")
241+
self.discretizer.fit(m)
242+
self.mediator_bins = self.discretizer.cluster_centers_
243+
self._mediator_considered_discrete = False
227244

228245
def _fit_treatment_propensity_x(self, t, x):
229246
"""Fits the nuisance parameter for the propensity P(T=1|X)"""
@@ -239,10 +256,12 @@ def _fit_treatment_propensity_xm(self, t, m, x):
239256
return self
240257

241258
def _fit_mediator_probability(self, t, m, x):
242-
if not is_array_integer(m):
259+
self._fit_mediator_discretizer(m)
260+
if not self._mediator_considered_discrete:
243261
self._fit_mediator_density(t, m, x)
244262
else:
245-
self._fit_discrete_mediator_probability(t, m, x)
263+
m_label, m_discrete_value = self._discretize_mediators(m)
264+
self._fit_discrete_mediator_probability(t, m_label, x)
246265

247266
def _fit_discrete_mediator_probability(self, t, m, x):
248267
"""Fits the nuisance parameter for the density f(M=m|T, X)"""
@@ -327,7 +346,7 @@ def _fit_cross_conditional_mean_outcome(self, t, m, x, y):
327346

328347
return self
329348

330-
def _estimate_discrete_mediator_probability(self, x, m):
349+
def _estimate_discrete_mediator_probability(self, x, m_label):
331350
"""
332351
Estimate mediator density P(M=m|T,X) for a binary M
333352
@@ -343,13 +362,12 @@ def _estimate_discrete_mediator_probability(self, x, m):
343362
t0 = np.zeros((n, 1))
344363
t1 = np.ones((n, 1))
345364

346-
m = m.ravel()
347365

348366
t0_x = np.hstack([t0.reshape(-1, 1), x])
349367
t1_x = np.hstack([t1.reshape(-1, 1), x])
350368

351-
f_m0x = self._classifier_m.predict_proba(t0_x)[np.arange(m.shape[0]), m]
352-
f_m1x = self._classifier_m.predict_proba(t1_x)[np.arange(m.shape[0]), m]
369+
f_m0x = self._classifier_m.predict_proba(t0_x)[np.arange(m_label.shape[0]), m_label]
370+
f_m1x = self._classifier_m.predict_proba(t1_x)[np.arange(m_label.shape[0]), m_label]
353371

354372
return f_m0x, f_m1x
355373

@@ -380,11 +398,11 @@ def _estimate_mediator_density(self, x, m):
380398
return f_m0x, f_m1x
381399

382400
def _estimate_mediator_probability(self, x, m):
383-
384-
if not is_array_integer(m):
401+
if not self._mediator_considered_discrete:
385402
return self._estimate_mediator_density(x, m)
386403
else:
387-
return self._estimate_discrete_mediator_probability(x, m)
404+
m_label, m_discrete_value = self._discretize_mediators(m)
405+
return self._estimate_discrete_mediator_probability(x, m_label)
388406

389407
def _estimate_discrete_mediator_probability_table(self, x):
390408
"""
@@ -411,9 +429,9 @@ def _estimate_discrete_mediator_probability_table(self, x):
411429
fm_0 = self._classifier_m.predict_proba(t0_x)
412430
fm_1 = self._classifier_m.predict_proba(t1_x)
413431

414-
for m in self.mediator_bins:
415-
f_0x.append(fm_0[:, m])
416-
f_1x.append(fm_1[:, m])
432+
for idx, m_anchor in enumerate(self.mediator_bins):
433+
f_0x.append(fm_0[:, idx])
434+
f_1x.append(fm_1[:, idx])
417435

418436
return f_0x, f_1x
419437

@@ -545,6 +563,10 @@ def _estimate_cross_conditional_mean_outcome(self, m, x):
545563

546564
def _discretize_mediators(self, m):
547565
"""Discretize mediators clustering if they are not explicit."""
548-
if not is_array_integer(m):
549-
m = np.expand_dims(self.discretizer.predict(m), axis=-1)
550-
return m
566+
if self._mediator_considered_discrete:
567+
m_label = self.discretizer.transform(m)
568+
m_discrete_value = m
569+
else:
570+
m_label = self.discretizer.predict(m)
571+
m_discrete_value = self.discretizer.cluster_centers_[m_label, :]
572+
return m_label, m_discrete_value

src/med_bench/estimation/mediation_g_computation.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ def fit(self, t, m, x, y):
4343
t, m, x, y = self._resize(t, m, x, y)
4444

4545
if self._integration == "explicit":
46-
self.discretizer.fit(m)
47-
m = self._discretize_mediators(m)
48-
self._fit_discrete_mediator_probability(t, m, x)
49-
self._fit_conditional_mean_outcome(t, m, x, y)
46+
self._fit_mediator_discretizer(m)
47+
m_label, m_discrete_value = self._discretize_mediators(m)
48+
self._fit_discrete_mediator_probability(t, m_label, x)
49+
self._fit_conditional_mean_outcome(t, m_discrete_value, x, y)
5050

5151
elif self._integration == "implicit":
5252
self._fit_cross_conditional_mean_outcome(t, m, x, y)
@@ -56,9 +56,13 @@ def fit(self, t, m, x, y):
5656
if self.verbose:
5757
print("Nuisance models fitted")
5858

59-
if self._integration == "explicit" and not is_array_integer(m):
59+
if self._integration == "explicit" and not self._mediator_considered_discrete:
6060
warnings.warn(
61-
"The explicit integration of the conditional mean outcome is strongly not advised for continuous mediators"
61+
"The explicit integration of the conditional mean outcome is "+
62+
"strongly not advised for continuous mediators,"+
63+
"or a discrete mediator with many classes (you can increase"+
64+
" the parameter `mediator_cardinality_threshold`" +
65+
"to treat your discrete mediator with many classes as discrete."+
6266
"It is advised to set integration to 'implicit'.",
6367
UserWarning,
6468
)
@@ -70,7 +74,7 @@ def _pointwise_estimate(self, t, m, x, y):
7074
t, m, x, y = self._resize(t, m, x, y)
7175

7276
if self._integration == "explicit":
73-
m = self._discretize_mediators(m) if not is_array_integer(m) else m
77+
m_label, m_discrete_value = self._discretize_mediators(m)
7478

7579
f_0x, f_1x = self._estimate_discrete_mediator_probability_table(x)
7680
mu_0x, mu_1x = self._estimate_conditional_mean_outcome_table(x)

src/med_bench/estimation/mediation_mr.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ def fit(self, t, m, x, y):
6767
self._fit_treatment_propensity_xm(t, m, x)
6868

6969
if self._integration == "explicit":
70-
self.discretizer.fit(m)
71-
m = self._discretize_mediators(m)
72-
self._fit_discrete_mediator_probability(t, m, x)
73-
self._fit_conditional_mean_outcome(t, m, x, y)
70+
if self._prop_ratio == "treatment":
71+
self._fit_mediator_discretizer(m)
72+
m_label, m_discrete_value = self._discretize_mediators(m)
73+
self._fit_discrete_mediator_probability(t, m_label, x)
74+
self._fit_conditional_mean_outcome(t, m_discrete_value, x, y)
7475

7576
elif self._integration == "implicit":
7677
self._fit_cross_conditional_mean_outcome(t, m, x, y)
@@ -108,7 +109,7 @@ def _pointwise_estimate(self, t, m, x, y):
108109
prop_ratio_t0_m1 = p_xm / ((1 - p_xm) * p_x)
109110

110111
if self._integration == "explicit":
111-
m = self._discretize_mediators(m) if not is_array_integer(m) else m
112+
m_label, m_discrete_value = self._discretize_mediators(m)
112113

113114
f_0x, f_1x = self._estimate_discrete_mediator_probability_table(x)
114115
mu_0x, mu_1x = self._estimate_conditional_mean_outcome_table(x)
@@ -132,7 +133,7 @@ def _pointwise_estimate(self, t, m, x, y):
132133
y1m0 = y1m0.mean()
133134
y0m1 = y0m1.mean()
134135

135-
mu_0mx, mu_1mx = self._estimate_conditional_mean_outcome(x, m)
136+
mu_0mx, mu_1mx = self._estimate_conditional_mean_outcome(x, m_discrete_value)
136137

137138
elif self._integration == "implicit":
138139
mu_0mx, mu_1mx, y0m0, y0m1, y1m0, y1m1 = (

src/med_bench/estimation/mediation_tmle.py

+26-25
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,32 @@ def __init__(self, regressor, classifier, prop_ratio, **kwargs):
4040
assert prop_ratio in ["mediator", "treatment"]
4141
self._prop_ratio = prop_ratio
4242

43+
def fit(self, t, m, x, y):
44+
"""Fits nuisance parameters to data"""
45+
# bucketize if needed
46+
t, m, x, y = self._resize(t, m, x, y)
47+
48+
if (not is_array_binary(m)) and (self._prop_ratio == "mediator"):
49+
raise ValueError(
50+
"The option mediator 'mediator' in TMLE is supported only for 1D binary mediator"
51+
)
52+
53+
self._fit_treatment_propensity_x(t, x)
54+
self._fit_conditional_mean_outcome(t, m, x, y)
55+
56+
if self._prop_ratio == "mediator":
57+
self._fit_mediator_probability(t, m, x)
58+
59+
elif self._prop_ratio == "treatment":
60+
self._fit_treatment_propensity_xm(t, m, x)
61+
62+
self._fitted = True
63+
64+
if self.verbose:
65+
print("Nuisance models fitted")
66+
67+
return self
68+
4369
def _one_step_correction_direct(self, t, m, x, y):
4470
"""Implements the one step correction for the estimation of the natural
4571
direct effect with the prop_ratio of mediator densities or treatment
@@ -194,31 +220,6 @@ def _one_step_correction_indirect(self, t, m, x, y):
194220

195221
return delta_1
196222

197-
def fit(self, t, m, x, y):
198-
"""Fits nuisance parameters to data"""
199-
# bucketize if needed
200-
t, m, x, y = self._resize(t, m, x, y)
201-
202-
if (not is_array_binary(m)) and (self._prop_ratio == "mediator"):
203-
raise ValueError(
204-
"The option mediator 'mediator' in TMLE is supported only for 1D binary mediator"
205-
)
206-
207-
self._fit_treatment_propensity_x(t, x)
208-
self._fit_conditional_mean_outcome(t, m, x, y)
209-
210-
if self._prop_ratio == "mediator":
211-
self._fit_mediator_probability(t, m, x)
212-
213-
elif self._prop_ratio == "treatment":
214-
self._fit_treatment_propensity_xm(t, m, x)
215-
216-
self._fitted = True
217-
218-
if self.verbose:
219-
print("Nuisance models fitted")
220-
221-
return self
222223

223224
@fitted
224225
def estimate(self, t, m, x, y):

0 commit comments

Comments
 (0)