Skip to content

Commit 8530e26

Browse files
authored
Input check (#81)
enforce input check at the function level to avoid issues with input shape, fix ci with R Co-authored-by: houssamzenati <housszenati@gmail.com>
1 parent 43c118b commit 8530e26

File tree

8 files changed

+191
-75
lines changed

8 files changed

+191
-75
lines changed

.github/workflows/code-cov.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ jobs:
3939
dependencies: 'NA'
4040
install-pandoc: false
4141
packages: |
42+
Matrix@1.6-5
43+
MASS@7.3-60
4244
grf
4345
causalweight
4446
mediation
@@ -53,6 +55,7 @@ jobs:
5355
5456
- name: Run tests with coverage
5557
run: |
58+
export LD_LIBRARY_PATH=$(python -m rpy2.situation LD_LIBRARY_PATH):${LD_LIBRARY_PATH}
5659
pytest --cov=med_bench --cov-report=xml
5760
5861
- name: Upload coverage to Codecov

.github/workflows/tests-with-R.yaml

+4-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ jobs:
3939
dependencies: 'NA'
4040
install-pandoc: false
4141
packages: |
42+
Matrix@1.6-5
43+
MASS@7.3-60
4244
grf
4345
causalweight
4446
mediation
@@ -53,4 +55,5 @@ jobs:
5355
5456
- name: Run tests
5557
run: |
56-
pytest
58+
export LD_LIBRARY_PATH=$(python -m rpy2.situation LD_LIBRARY_PATH):${LD_LIBRARY_PATH}
59+
pytest

src/med_bench/mediation.py

+33-54
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
_estimate_mediator_density,
1919
_estimate_treatment_probabilities,
2020
_get_classifier, _get_regressor)
21-
from .utils.utils import r_dependency_required
21+
from .utils.utils import r_dependency_required, _check_input
2222

2323
ALPHAS = np.logspace(-5, 5, 8)
2424
CV_FOLDS = 5
@@ -90,6 +90,9 @@ def mediation_IPW(y, t, m, x, trim, regularization=True, forest=False,
9090
int
9191
number of used observations (non trimmed)
9292
"""
93+
# check input
94+
y, t, m, x = _check_input(y, t, m, x, setting='multidimensional')
95+
9396
# estimate propensities
9497
classifier_t_x = _get_classifier(regularization, forest, calibration)
9598
classifier_t_xm = _get_classifier(regularization, forest, calibration)
@@ -179,12 +182,13 @@ def mediation_coefficient_product(y, t, m, x, interaction=False,
179182
alphas = ALPHAS
180183
else:
181184
alphas = [TINY]
182-
if len(x.shape) == 1:
183-
x = x.reshape(-1, 1)
184-
if len(m.shape) == 1:
185-
m = m.reshape(-1, 1)
185+
186+
# check input
187+
y, t, m, x = _check_input(y, t, m, x, setting='multidimensional')
188+
186189
if len(t.shape) == 1:
187190
t = t.reshape(-1, 1)
191+
188192
coef_t_m = np.zeros(m.shape[1])
189193
for i in range(m.shape[1]):
190194
m_reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)\
@@ -248,17 +252,20 @@ def mediation_g_formula(y, t, m, x, interaction=False, forest=False,
248252
calibration : str, default=sigmoid
249253
calibration mode; for example using a sigmoid function
250254
"""
255+
# check input
256+
y, t, m, x = _check_input(y, t, m, x, setting='binary')
257+
251258
# estimate mediator densities
252259
classifier_m = _get_classifier(regularization, forest, calibration)
253-
f_00x, f_01x, f_10x, f_11x, _, _ = _estimate_mediator_density(t, m, x, y,
260+
f_00x, f_01x, f_10x, f_11x, _, _ = _estimate_mediator_density(y, t, m, x,
254261
crossfit,
255262
classifier_m,
256263
interaction)
257264

258265
# estimate conditional mean outcomes
259266
regressor_y = _get_regressor(regularization, forest)
260267
mu_00x, mu_01x, mu_10x, mu_11x, _, _ = (
261-
_estimate_conditional_mean_outcome(t, m, x, y, crossfit, regressor_y,
268+
_estimate_conditional_mean_outcome(y, t, m, x, crossfit, regressor_y,
262269
interaction))
263270

264271
# G computation
@@ -319,10 +326,9 @@ def alternative_estimator(y, t, m, x, regularization=True):
319326
alphas = ALPHAS
320327
else:
321328
alphas = [TINY]
322-
if len(x.shape) == 1:
323-
x = x.reshape(-1, 1)
324-
if len(m.shape) == 1:
325-
m = m.reshape(-1, 1)
329+
330+
# check input
331+
y, t, m, x = _check_input(y, t, m, x, setting='multidimensional')
326332
treated = (t == 1)
327333

328334
# computation of direct effect
@@ -433,29 +439,9 @@ def mediation_multiply_robust(y, t, m, x, interaction=False, forest=False,
433439
- If x, t, m, or y don't have the same length.
434440
- If m is not binary.
435441
"""
436-
# Format checking
437-
if len(y) != len(y.ravel()):
438-
raise ValueError("Multidimensional y is not supported")
439-
if len(t) != len(t.ravel()):
440-
raise ValueError("Multidimensional t is not supported")
441-
if len(m) != len(m.ravel()):
442-
raise ValueError("Multidimensional m is not supported")
443-
444-
n = len(y)
445-
if len(x.shape) == 1:
446-
x.reshape(n, 1)
447-
if len(m.shape) == 1:
448-
m.reshape(n, 1)
449-
450-
dim_m = m.shape[1]
451-
if n * dim_m != sum(m.ravel() == 1) + sum(m.ravel() == 0):
452-
raise ValueError("m is not binary")
442+
# check input
443+
y, t, m, x = _check_input(y, t, m, x, setting='binary')
453444

454-
y = y.ravel()
455-
t = t.ravel()
456-
m = m.ravel()
457-
if n != len(x) or n != len(m) or n != len(t):
458-
raise ValueError("Inputs don't have the same number of observations")
459445

460446
# estimate propensities
461447
classifier_t_x = _get_classifier(regularization, forest, calibration)
@@ -466,15 +452,15 @@ def mediation_multiply_robust(y, t, m, x, interaction=False, forest=False,
466452
# estimate mediator densities
467453
classifier_m = _get_classifier(regularization, forest, calibration)
468454
f_00x, f_01x, f_10x, f_11x, f_m0x, f_m1x = (
469-
_estimate_mediator_density(t, m, x, y, crossfit,
455+
_estimate_mediator_density(y, t, m, x, crossfit,
470456
classifier_m, interaction))
471457
f = f_00x, f_01x, f_10x, f_11x
472458

473459
# estimate conditional mean outcomes
474460
regressor_y = _get_regressor(regularization, forest)
475461
regressor_cross_y = _get_regressor(regularization, forest)
476462
mu_0mx, mu_1mx, E_mu_t0_t0, E_mu_t0_t1, E_mu_t1_t0, E_mu_t1_t1 = (
477-
_estimate_cross_conditional_mean_outcome(t, m, x, y, crossfit,
463+
_estimate_cross_conditional_mean_outcome(y, t, m, x, crossfit,
478464
regressor_y,
479465
regressor_cross_y, f,
480466
interaction))
@@ -574,7 +560,10 @@ def r_mediate(y, t, m, x, interaction=False):
574560
Rstats = rpackages.importr('stats')
575561
base = rpackages.importr('base')
576562

563+
# check input
564+
y, t, m, x = _check_input(y, t, m, x, setting='binary')
577565
m = m.ravel()
566+
578567
var_names = [[y, 'y'],
579568
[t, 't'],
580569
[m, 'm'],
@@ -629,7 +618,10 @@ def r_mediation_g_estimator(y, t, m, x):
629618
plmed = rpackages.importr('plmed')
630619
base = rpackages.importr('base')
631620

621+
# check input
622+
y, t, m, x = _check_input(y, t, m, x, setting='binary')
632623
m = m.ravel()
624+
633625
var_names = [[y, 'y'],
634626
[t, 't'],
635627
[m, 'm'],
@@ -713,6 +705,9 @@ def r_mediation_dml(y, t, m, x, trim=0.05, order=1):
713705
causalweight = rpackages.importr('causalweight')
714706
base = rpackages.importr('base')
715707

708+
# check input
709+
y, t, m, x = _check_input(y, t, m, x, setting='multidimensional')
710+
716711
x_r, t_r, m_r, y_r = [base.as_matrix(_convert_array_to_R(uu)) for uu in
717712
(x, t, m, y)]
718713
res = causalweight.medDML(y_r, t_r, m_r, x_r, trim=trim, order=order)
@@ -805,25 +800,9 @@ def mediation_dml(y, t, m, x, forest=False, crossfit=0, trim=0.05, clip=1e-6,
805800
- If t or y are multidimensional.
806801
- If x, t, m, or y don't have the same length.
807802
"""
808-
# check format
809-
if len(y) != len(y.ravel()):
810-
raise ValueError("Multidimensional y is not supported")
811-
812-
if len(t) != len(t.ravel()):
813-
raise ValueError("Multidimensional t is not supported")
814-
803+
# check input
804+
y, t, m, x = _check_input(y, t, m, x, setting='multidimensional')
815805
n = len(y)
816-
t = t.ravel()
817-
y = y.ravel()
818-
819-
if n != len(x) or n != len(m) or n != len(t):
820-
raise ValueError("Inputs don't have the same number of observations")
821-
822-
if len(x.shape) == 1:
823-
x.reshape(n, 1)
824-
825-
if len(m.shape) == 1:
826-
m.reshape(n, 1)
827806

828807
nobs = 0
829808

@@ -850,7 +829,7 @@ def mediation_dml(y, t, m, x, forest=False, crossfit=0, trim=0.05, clip=1e-6,
850829
regressor_cross_y = _get_regressor(regularization, forest)
851830

852831
mu_0mx, mu_1mx, E_mu_t0_t0, E_mu_t0_t1, E_mu_t1_t0, E_mu_t1_t1 = (
853-
_estimate_cross_conditional_mean_outcome_nesting(t, m, x, y, crossfit,
832+
_estimate_cross_conditional_mean_outcome_nesting(y, t, m, x, crossfit,
854833
regressor_y,
855834
regressor_cross_y))
856835

src/med_bench/utils/nuisances.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,6 @@ def _estimate_treatment_probabilities(t, m, x, crossfit, clf_t_x, clf_t_xm):
119119

120120
p_x, p_xm = [np.zeros(n) for h in range(2)]
121121
# compute propensity scores
122-
if len(x.shape) == 1:
123-
x = x.reshape(-1, 1)
124-
if len(m.shape) == 1:
125-
m = m.reshape(-1, 1)
126122
if len(t.shape) == 1:
127123
t = t.reshape(-1, 1)
128124

@@ -143,7 +139,7 @@ def _estimate_treatment_probabilities(t, m, x, crossfit, clf_t_x, clf_t_xm):
143139
return p_x, p_xm
144140

145141

146-
def _estimate_mediator_density(t, m, x, y, crossfit, clf_m, interaction):
142+
def _estimate_mediator_density(y, t, m, x, crossfit, clf_m, interaction):
147143
"""
148144
Estimate mediator density f(M|T,X)
149145
with train test lists from crossfitting
@@ -164,8 +160,6 @@ def _estimate_mediator_density(t, m, x, y, crossfit, clf_m, interaction):
164160
probabilities f(M|T=1,X)
165161
"""
166162
n = len(y)
167-
if len(x.shape) == 1:
168-
x = x.reshape(-1, 1)
169163

170164
if len(t.shape) == 1:
171165
t = t.reshape(-1, 1)
@@ -206,7 +200,7 @@ def _estimate_mediator_density(t, m, x, y, crossfit, clf_m, interaction):
206200
return f_00x, f_01x, f_10x, f_11x, f_m0x, f_m1x
207201

208202

209-
def _estimate_conditional_mean_outcome(t, m, x, y, crossfit, reg_y,
203+
def _estimate_conditional_mean_outcome(y, t, m, x, crossfit, reg_y,
210204
interaction):
211205
"""
212206
Estimate conditional mean outcome E[Y|T,M,X]
@@ -228,12 +222,7 @@ def _estimate_conditional_mean_outcome(t, m, x, y, crossfit, reg_y,
228222
conditional mean outcome estimates E[Y|T=1,M,X]
229223
"""
230224
n = len(y)
231-
if len(x.shape) == 1:
232-
x = x.reshape(-1, 1)
233-
if len(m.shape) == 1:
234-
mr = m.reshape(-1, 1)
235-
else:
236-
mr = np.copy(m)
225+
mr = np.copy(m)
237226
if len(t.shape) == 1:
238227
t = t.reshape(-1, 1)
239228

@@ -275,7 +264,7 @@ def _estimate_conditional_mean_outcome(t, m, x, y, crossfit, reg_y,
275264
return mu_00x, mu_01x, mu_10x, mu_11x, mu_0mx, mu_1mx
276265

277266

278-
def _estimate_cross_conditional_mean_outcome(t, m, x, y, crossfit, reg_y,
267+
def _estimate_cross_conditional_mean_outcome(y, t, m, x, crossfit, reg_y,
279268
reg_cross_y, f, interaction):
280269
"""
281270
Estimate the conditional mean outcome,
@@ -397,7 +386,7 @@ def _estimate_cross_conditional_mean_outcome(t, m, x, y, crossfit, reg_y,
397386
return mu_0mx, mu_1mx, E_mu_t0_t0, E_mu_t0_t1, E_mu_t1_t0, E_mu_t1_t1
398387

399388

400-
def _estimate_cross_conditional_mean_outcome_nesting(t, m, x, y, crossfit,
389+
def _estimate_cross_conditional_mean_outcome_nesting(y, t, m, x, crossfit,
401390
reg_y, reg_cross_y):
402391
"""
403392
Estimate treatment probabilities and the conditional mean outcome,

src/med_bench/utils/utils.py

+80
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pandas as pd
33

4+
45
import subprocess
56
import warnings
67

@@ -158,3 +159,82 @@ def _convert_array_to_R(x):
158159
elif len(x.shape) == 2:
159160
return robjects.r.matrix(robjects.FloatVector(x.ravel()),
160161
nrow=x.shape[0], byrow='TRUE')
162+
163+
164+
def _check_input(y, t, m, x, setting):
165+
"""
166+
internal function to check inputs. `_check_input` adjusts the dimension
167+
of the input (matrix or vectors), and raises an error
168+
- if the size of input is not adequate,
169+
- or if the type of input is not supported (cotinuous treatment or
170+
non-binary one-dimensional mediator if the specified setting parameter
171+
is binary)
172+
173+
Parameters
174+
----------
175+
y : array-like, shape (n_samples)
176+
Outcome value for each unit, continuous
177+
178+
t : array-like, shape (n_samples)
179+
Treatment value for each unit, binary
180+
181+
m : array-like, shape (n_samples, n_mediators)
182+
Mediator value for each unit, binary and unidimensional
183+
184+
x : array-like, shape (n_samples, n_features_covariates)
185+
Covariates value for each unit, continuous
186+
187+
setting : string
188+
('binary', 'continuous', 'multidimensional') value for the mediator
189+
190+
Returns
191+
-------
192+
y_converted : array-like, shape (n_samples,)
193+
Outcome value for each unit, continuous
194+
195+
t_converted : array-like, shape (n_samples,)
196+
Treatment value for each unit, binary
197+
198+
m_converted : array-like, shape (n_samples, n_mediators)
199+
Mediator value for each unit, binary and unidimensional
200+
201+
x_converted : array-like, shape (n_samples, n_features_covariates)
202+
Covariates value for each unit, continuous
203+
"""
204+
# check format
205+
if len(y) != len(y.ravel()):
206+
raise ValueError("Multidimensional y (outcome) is not supported")
207+
208+
if len(t) != len(t.ravel()):
209+
raise ValueError("Multidimensional t (exposure) is not supported")
210+
211+
if len(np.unique(t)) != 2:
212+
raise ValueError("Only a binary t (exposure) is supported")
213+
214+
n = len(y)
215+
t_converted = t.ravel()
216+
y_converted = y.ravel()
217+
218+
if n != len(x) or n != len(m) or n != len(t):
219+
raise ValueError("Inputs don't have the same number of observations")
220+
221+
if len(x.shape) == 1:
222+
x_converted = x.reshape(n, 1)
223+
else:
224+
x_converted = x
225+
226+
if len(m.shape) == 1:
227+
m_converted = m.reshape(n, 1)
228+
else:
229+
m_converted = m
230+
231+
if (m_converted.shape[1] >1) and (setting != 'multidimensional'):
232+
raise ValueError("Multidimensional m (mediator) is not supported")
233+
234+
if (setting == 'binary') and (len(np.unique(m)) != 2):
235+
raise ValueError(
236+
"Only a binary one-dimensional m (mediator) is supported")
237+
238+
return y_converted, t_converted, m_converted, x_converted
239+
240+

0 commit comments

Comments
 (0)