18
18
_estimate_mediator_density ,
19
19
_estimate_treatment_probabilities ,
20
20
_get_classifier , _get_regressor )
21
- from .utils .utils import r_dependency_required
21
+ from .utils .utils import r_dependency_required , _check_input
22
22
23
23
ALPHAS = np .logspace (- 5 , 5 , 8 )
24
24
CV_FOLDS = 5
@@ -90,6 +90,9 @@ def mediation_IPW(y, t, m, x, trim, regularization=True, forest=False,
90
90
int
91
91
number of used observations (non trimmed)
92
92
"""
93
+ # check input
94
+ y , t , m , x = _check_input (y , t , m , x , setting = 'multidimensional' )
95
+
93
96
# estimate propensities
94
97
classifier_t_x = _get_classifier (regularization , forest , calibration )
95
98
classifier_t_xm = _get_classifier (regularization , forest , calibration )
@@ -179,12 +182,13 @@ def mediation_coefficient_product(y, t, m, x, interaction=False,
179
182
alphas = ALPHAS
180
183
else :
181
184
alphas = [TINY ]
182
- if len ( x . shape ) == 1 :
183
- x = x . reshape ( - 1 , 1 )
184
- if len ( m . shape ) == 1 :
185
- m = m . reshape ( - 1 , 1 )
185
+
186
+ # check input
187
+ y , t , m , x = _check_input ( y , t , m , x , setting = 'multidimensional' )
188
+
186
189
if len (t .shape ) == 1 :
187
190
t = t .reshape (- 1 , 1 )
191
+
188
192
coef_t_m = np .zeros (m .shape [1 ])
189
193
for i in range (m .shape [1 ]):
190
194
m_reg = RidgeCV (alphas = alphas , cv = CV_FOLDS )\
@@ -248,17 +252,20 @@ def mediation_g_formula(y, t, m, x, interaction=False, forest=False,
248
252
calibration : str, default=sigmoid
249
253
calibration mode; for example using a sigmoid function
250
254
"""
255
+ # check input
256
+ y , t , m , x = _check_input (y , t , m , x , setting = 'binary' )
257
+
251
258
# estimate mediator densities
252
259
classifier_m = _get_classifier (regularization , forest , calibration )
253
- f_00x , f_01x , f_10x , f_11x , _ , _ = _estimate_mediator_density (t , m , x , y ,
260
+ f_00x , f_01x , f_10x , f_11x , _ , _ = _estimate_mediator_density (y , t , m , x ,
254
261
crossfit ,
255
262
classifier_m ,
256
263
interaction )
257
264
258
265
# estimate conditional mean outcomes
259
266
regressor_y = _get_regressor (regularization , forest )
260
267
mu_00x , mu_01x , mu_10x , mu_11x , _ , _ = (
261
- _estimate_conditional_mean_outcome (t , m , x , y , crossfit , regressor_y ,
268
+ _estimate_conditional_mean_outcome (y , t , m , x , crossfit , regressor_y ,
262
269
interaction ))
263
270
264
271
# G computation
@@ -319,10 +326,9 @@ def alternative_estimator(y, t, m, x, regularization=True):
319
326
alphas = ALPHAS
320
327
else :
321
328
alphas = [TINY ]
322
- if len (x .shape ) == 1 :
323
- x = x .reshape (- 1 , 1 )
324
- if len (m .shape ) == 1 :
325
- m = m .reshape (- 1 , 1 )
329
+
330
+ # check input
331
+ y , t , m , x = _check_input (y , t , m , x , setting = 'multidimensional' )
326
332
treated = (t == 1 )
327
333
328
334
# computation of direct effect
@@ -433,29 +439,9 @@ def mediation_multiply_robust(y, t, m, x, interaction=False, forest=False,
433
439
- If x, t, m, or y don't have the same length.
434
440
- If m is not binary.
435
441
"""
436
- # Format checking
437
- if len (y ) != len (y .ravel ()):
438
- raise ValueError ("Multidimensional y is not supported" )
439
- if len (t ) != len (t .ravel ()):
440
- raise ValueError ("Multidimensional t is not supported" )
441
- if len (m ) != len (m .ravel ()):
442
- raise ValueError ("Multidimensional m is not supported" )
443
-
444
- n = len (y )
445
- if len (x .shape ) == 1 :
446
- x .reshape (n , 1 )
447
- if len (m .shape ) == 1 :
448
- m .reshape (n , 1 )
449
-
450
- dim_m = m .shape [1 ]
451
- if n * dim_m != sum (m .ravel () == 1 ) + sum (m .ravel () == 0 ):
452
- raise ValueError ("m is not binary" )
442
+ # check input
443
+ y , t , m , x = _check_input (y , t , m , x , setting = 'binary' )
453
444
454
- y = y .ravel ()
455
- t = t .ravel ()
456
- m = m .ravel ()
457
- if n != len (x ) or n != len (m ) or n != len (t ):
458
- raise ValueError ("Inputs don't have the same number of observations" )
459
445
460
446
# estimate propensities
461
447
classifier_t_x = _get_classifier (regularization , forest , calibration )
@@ -466,15 +452,15 @@ def mediation_multiply_robust(y, t, m, x, interaction=False, forest=False,
466
452
# estimate mediator densities
467
453
classifier_m = _get_classifier (regularization , forest , calibration )
468
454
f_00x , f_01x , f_10x , f_11x , f_m0x , f_m1x = (
469
- _estimate_mediator_density (t , m , x , y , crossfit ,
455
+ _estimate_mediator_density (y , t , m , x , crossfit ,
470
456
classifier_m , interaction ))
471
457
f = f_00x , f_01x , f_10x , f_11x
472
458
473
459
# estimate conditional mean outcomes
474
460
regressor_y = _get_regressor (regularization , forest )
475
461
regressor_cross_y = _get_regressor (regularization , forest )
476
462
mu_0mx , mu_1mx , E_mu_t0_t0 , E_mu_t0_t1 , E_mu_t1_t0 , E_mu_t1_t1 = (
477
- _estimate_cross_conditional_mean_outcome (t , m , x , y , crossfit ,
463
+ _estimate_cross_conditional_mean_outcome (y , t , m , x , crossfit ,
478
464
regressor_y ,
479
465
regressor_cross_y , f ,
480
466
interaction ))
@@ -574,7 +560,10 @@ def r_mediate(y, t, m, x, interaction=False):
574
560
Rstats = rpackages .importr ('stats' )
575
561
base = rpackages .importr ('base' )
576
562
563
+ # check input
564
+ y , t , m , x = _check_input (y , t , m , x , setting = 'binary' )
577
565
m = m .ravel ()
566
+
578
567
var_names = [[y , 'y' ],
579
568
[t , 't' ],
580
569
[m , 'm' ],
@@ -629,7 +618,10 @@ def r_mediation_g_estimator(y, t, m, x):
629
618
plmed = rpackages .importr ('plmed' )
630
619
base = rpackages .importr ('base' )
631
620
621
+ # check input
622
+ y , t , m , x = _check_input (y , t , m , x , setting = 'binary' )
632
623
m = m .ravel ()
624
+
633
625
var_names = [[y , 'y' ],
634
626
[t , 't' ],
635
627
[m , 'm' ],
@@ -713,6 +705,9 @@ def r_mediation_dml(y, t, m, x, trim=0.05, order=1):
713
705
causalweight = rpackages .importr ('causalweight' )
714
706
base = rpackages .importr ('base' )
715
707
708
+ # check input
709
+ y , t , m , x = _check_input (y , t , m , x , setting = 'multidimensional' )
710
+
716
711
x_r , t_r , m_r , y_r = [base .as_matrix (_convert_array_to_R (uu )) for uu in
717
712
(x , t , m , y )]
718
713
res = causalweight .medDML (y_r , t_r , m_r , x_r , trim = trim , order = order )
@@ -805,25 +800,9 @@ def mediation_dml(y, t, m, x, forest=False, crossfit=0, trim=0.05, clip=1e-6,
805
800
- If t or y are multidimensional.
806
801
- If x, t, m, or y don't have the same length.
807
802
"""
808
- # check format
809
- if len (y ) != len (y .ravel ()):
810
- raise ValueError ("Multidimensional y is not supported" )
811
-
812
- if len (t ) != len (t .ravel ()):
813
- raise ValueError ("Multidimensional t is not supported" )
814
-
803
+ # check input
804
+ y , t , m , x = _check_input (y , t , m , x , setting = 'multidimensional' )
815
805
n = len (y )
816
- t = t .ravel ()
817
- y = y .ravel ()
818
-
819
- if n != len (x ) or n != len (m ) or n != len (t ):
820
- raise ValueError ("Inputs don't have the same number of observations" )
821
-
822
- if len (x .shape ) == 1 :
823
- x .reshape (n , 1 )
824
-
825
- if len (m .shape ) == 1 :
826
- m .reshape (n , 1 )
827
806
828
807
nobs = 0
829
808
@@ -850,7 +829,7 @@ def mediation_dml(y, t, m, x, forest=False, crossfit=0, trim=0.05, clip=1e-6,
850
829
regressor_cross_y = _get_regressor (regularization , forest )
851
830
852
831
mu_0mx , mu_1mx , E_mu_t0_t0 , E_mu_t0_t1 , E_mu_t1_t0 , E_mu_t1_t1 = (
853
- _estimate_cross_conditional_mean_outcome_nesting (t , m , x , y , crossfit ,
832
+ _estimate_cross_conditional_mean_outcome_nesting (y , t , m , x , crossfit ,
854
833
regressor_y ,
855
834
regressor_cross_y ))
856
835
0 commit comments