6
6
from med_bench .utils .decorators import fitted
7
7
from med_bench .utils .density import GaussianDensityEstimation
8
8
from sklearn .cluster import KMeans
9
+ from sklearn .preprocessing import LabelEncoder
9
10
10
11
11
12
class Estimator :
12
13
"""General abstract class for causal mediation Estimator"""
13
14
14
15
__metaclass__ = ABCMeta
15
16
16
- def __init__ (self , verbose : bool = True ):
17
+ def __init__ (self , verbose : bool = True , mediator_cardinality_threshold : int = 10 ):
17
18
"""Initializes Estimator base class
18
19
19
20
Parameters
20
21
----------
21
22
verbose : bool
22
23
will print some logs if True
24
+
25
+ mediator_cardinality_threshold: int
26
+ default 10
27
+ maximal number of categories in a mediator to treat it as discrete or continuous
28
+ if the mediator is 1-dimensional and
29
+ if the number of distinct values in the mediator is lower than
30
+ mediator_cardinality_threshold, the mediator is going to be considered
31
+ discrete when estimating the mediator probability function, otherwise
32
+ the mediator is considered continuous,
23
33
"""
24
34
self ._verbose = verbose
25
35
self ._fitted = False
26
- self .discretizer = KMeans (n_clusters = 10 , random_state = 42 , n_init = "auto" )
27
- self .mediator_bins = [0 , 1 ]
36
+ self .mediator_cardinality_threshold = mediator_cardinality_threshold
28
37
29
38
@property
30
39
def verbose (self ):
@@ -222,8 +231,16 @@ def _resize(self, t, m, x, y):
222
231
223
232
def _fit_mediator_discretizer (self , m ):
224
233
"""Fits the discretization procedure of mediators"""
225
- self .discretizer .fit (m )
226
- self .mediator_bins = self .discretize .cluster_centers_
234
+ if (is_array_integer (m )) and (len (np .unique (m )) <= self .mediator_cardinality_threshold ):
235
+ self .discretizer = LabelEncoder ()
236
+ self .discretizer .fit (m .ravel ())
237
+ self .mediator_bins = self .discretizer .classes_
238
+ self ._mediator_considered_discrete = True
239
+ else :
240
+ self .discretizer = KMeans (n_clusters = 10 , random_state = 42 , n_init = "auto" )
241
+ self .discretizer .fit (m )
242
+ self .mediator_bins = self .discretizer .cluster_centers_
243
+ self ._mediator_considered_discrete = False
227
244
228
245
def _fit_treatment_propensity_x (self , t , x ):
229
246
"""Fits the nuisance parameter for the propensity P(T=1|X)"""
@@ -239,10 +256,12 @@ def _fit_treatment_propensity_xm(self, t, m, x):
239
256
return self
240
257
241
258
def _fit_mediator_probability (self , t , m , x ):
242
- if not is_array_integer (m ):
259
+ self ._fit_mediator_discretizer (m )
260
+ if not self ._mediator_considered_discrete :
243
261
self ._fit_mediator_density (t , m , x )
244
262
else :
245
- self ._fit_discrete_mediator_probability (t , m , x )
263
+ m_label , m_discrete_value = self ._discretize_mediators (m )
264
+ self ._fit_discrete_mediator_probability (t , m_label , x )
246
265
247
266
def _fit_discrete_mediator_probability (self , t , m , x ):
248
267
"""Fits the nuisance parameter for the density f(M=m|T, X)"""
@@ -327,7 +346,7 @@ def _fit_cross_conditional_mean_outcome(self, t, m, x, y):
327
346
328
347
return self
329
348
330
- def _estimate_discrete_mediator_probability (self , x , m ):
349
+ def _estimate_discrete_mediator_probability (self , x , m_label ):
331
350
"""
332
351
Estimate mediator density P(M=m|T,X) for a binary M
333
352
@@ -343,13 +362,12 @@ def _estimate_discrete_mediator_probability(self, x, m):
343
362
t0 = np .zeros ((n , 1 ))
344
363
t1 = np .ones ((n , 1 ))
345
364
346
- m = m .ravel ()
347
365
348
366
t0_x = np .hstack ([t0 .reshape (- 1 , 1 ), x ])
349
367
t1_x = np .hstack ([t1 .reshape (- 1 , 1 ), x ])
350
368
351
- f_m0x = self ._classifier_m .predict_proba (t0_x )[np .arange (m .shape [0 ]), m ]
352
- f_m1x = self ._classifier_m .predict_proba (t1_x )[np .arange (m .shape [0 ]), m ]
369
+ f_m0x = self ._classifier_m .predict_proba (t0_x )[np .arange (m_label .shape [0 ]), m_label ]
370
+ f_m1x = self ._classifier_m .predict_proba (t1_x )[np .arange (m_label .shape [0 ]), m_label ]
353
371
354
372
return f_m0x , f_m1x
355
373
@@ -380,11 +398,11 @@ def _estimate_mediator_density(self, x, m):
380
398
return f_m0x , f_m1x
381
399
382
400
def _estimate_mediator_probability (self , x , m ):
383
-
384
- if not is_array_integer (m ):
401
+ if not self ._mediator_considered_discrete :
385
402
return self ._estimate_mediator_density (x , m )
386
403
else :
387
- return self ._estimate_discrete_mediator_probability (x , m )
404
+ m_label , m_discrete_value = self ._discretize_mediators (m )
405
+ return self ._estimate_discrete_mediator_probability (x , m_label )
388
406
389
407
def _estimate_discrete_mediator_probability_table (self , x ):
390
408
"""
@@ -411,9 +429,9 @@ def _estimate_discrete_mediator_probability_table(self, x):
411
429
fm_0 = self ._classifier_m .predict_proba (t0_x )
412
430
fm_1 = self ._classifier_m .predict_proba (t1_x )
413
431
414
- for m in self .mediator_bins :
415
- f_0x .append (fm_0 [:, m ])
416
- f_1x .append (fm_1 [:, m ])
432
+ for idx , m_anchor in enumerate ( self .mediator_bins ) :
433
+ f_0x .append (fm_0 [:, idx ])
434
+ f_1x .append (fm_1 [:, idx ])
417
435
418
436
return f_0x , f_1x
419
437
@@ -545,6 +563,10 @@ def _estimate_cross_conditional_mean_outcome(self, m, x):
545
563
546
564
def _discretize_mediators (self , m ):
547
565
"""Discretize mediators clustering if they are not explicit."""
548
- if not is_array_integer (m ):
549
- m = np .expand_dims (self .discretizer .predict (m ), axis = - 1 )
550
- return m
566
+ if self ._mediator_considered_discrete :
567
+ m_label = self .discretizer .transform (m )
568
+ m_discrete_value = m
569
+ else :
570
+ m_label = self .discretizer .predict (m )
571
+ m_discrete_value = self .discretizer .cluster_centers_ [m_label , :]
572
+ return m_label , m_discrete_value
0 commit comments