from abc import ABCMeta, abstractmethod

import numpy as np
from sklearn import clone

from med_bench.utils.decorators import fitted


class Estimator(metaclass=ABCMeta):
    """General abstract class for causal mediation estimators."""

    def __init__(self, verbose: bool = True, crossfit: int = 0):
        """Initializes the Estimator base class.

        Parameters
        ----------
        verbose : bool
            If True, prints some logs.
        crossfit : int
            Number of cross-fitting folds; if 0, no cross-fitting is performed.
        """
        self._crossfit = crossfit
        self._crossfit_check()

        self._verbose = verbose

        self._fitted = False

    @property
    def verbose(self):
        return self._verbose

    def _crossfit_check(self):
        """Checks that the cross-fitting parameter is valid."""
        if self._crossfit > 0:
            raise NotImplementedError(
                """Cross-fitting is not implemented yet.
                You can perform it on your side, along these lines:
                cf_iterator = KFold(n_splits=5)
                for train_idx, test_idx in cf_iterator.split(x):
                    result.append(DML(..., crossfit=0)
                        .fit(t[train_idx], m[train_idx], x[train_idx], y[train_idx])
                        .estimate(t[test_idx], m[test_idx], x[test_idx], y[test_idx]))
                np.mean(result)"""
            )
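
    # Illustrative only: a hedged sketch of the manual cross-fitting recipe
    # suggested in the error message above, assuming a hypothetical concrete
    # subclass `DML` and standard sklearn KFold:
    #
    #   from sklearn.model_selection import KFold
    #
    #   results = []
    #   for train_idx, test_idx in KFold(n_splits=5).split(x):
    #       est = DML(crossfit=0)
    #       est.fit(t[train_idx], m[train_idx], x[train_idx], y[train_idx])
    #       results.append(
    #           est.estimate(t[test_idx], m[test_idx], x[test_idx], y[test_idx])
    #       )
    #   effect = np.mean(results)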

    @abstractmethod
    def fit(self, t, m, x, y):
        """Fits nuisance parameters to data.

        Parameters
        ----------
        t : array-like, shape (n_samples)
            Treatment value for each unit, binary.

        m : array-like, shape (n_samples)
            Mediator value for each unit; here m is necessarily binary and
            unidimensional.

        x : array-like, shape (n_samples, n_features_covariates)
            Covariate (potential confounder) values.

        y : array-like, shape (n_samples)
            Outcome value for each unit, continuous.
        """
        pass

    @abstractmethod
    @fitted
    def estimate(self, t, m, x, y):
        """Estimates the causal effect on data, using the fitted nuisances.

        Parameters
        ----------
        t : array-like, shape (n_samples)
            Treatment value for each unit, binary.

        m : array-like, shape (n_samples)
            Mediator value for each unit; here m is necessarily binary and
            unidimensional.

        x : array-like, shape (n_samples, n_features_covariates)
            Covariate (potential confounder) values.

        y : array-like, shape (n_samples)
            Outcome value for each unit, continuous.
        """
        pass
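
    # Illustrative only: a minimal sketch of a concrete subclass, assuming it
    # sets `self.classifier` and `self.regressor` (the attributes the nuisance
    # helpers below rely on); the estimand combination in `estimate` is left
    # open, since it depends on the estimator:
    #
    #   class MySimpleEstimator(Estimator):
    #       def __init__(self, classifier, regressor, **kwargs):
    #           super().__init__(**kwargs)
    #           self.classifier = classifier
    #           self.regressor = regressor
    #
    #       def fit(self, t, m, x, y):
    #           t, m, x, y = self._resize(t, m, x, y)
    #           self._fit_treatment_propensity_x(t, x)
    #           self._fit_conditional_mean_outcome(t, m, x, y)
    #           self._fitted = True
    #           return self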

    def _resize(self, t, m, x, y):
        """Resizes data to the expected shapes.

        Parameters
        ----------
        t : array-like, shape (n_samples)
            Treatment value for each unit, binary.

        m : array-like, shape (n_samples)
            Mediator value for each unit; here m is necessarily binary and
            unidimensional.

        x : array-like, shape (n_samples, n_features_covariates)
            Covariate (potential confounder) values.

        y : array-like, shape (n_samples)
            Outcome value for each unit, continuous.
        """
        if len(y) != len(y.ravel()):
            raise ValueError("Multidimensional y is not supported")
        if len(t) != len(t.ravel()):
            raise ValueError("Multidimensional t is not supported")

        n = len(y)
        if len(x.shape) == 1:
            x = x.reshape(n, 1)
        if len(m.shape) == 1:
            m = m.reshape(n, 1)

        if n != len(x) or n != len(m) or n != len(t):
            raise ValueError("Inputs don't have the same number of observations")

        y = y.ravel()
        t = t.ravel()

        return t, m, x, y
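
    # For example (illustrative), `_resize` lifts 1D covariates and mediators
    # to 2D column form while flattening t and y:
    #
    #   t = np.array([0, 1, 1]); m = np.array([1, 0, 1])
    #   x = np.array([0.5, 0.1, 0.3]); y = np.array([1.2, 0.7, 0.9])
    #   t, m, x, y = estimator._resize(t, m, x, y)
    #   # x.shape == (3, 1), m.shape == (3, 1), t.shape == y.shape == (3,)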

    def _fit_treatment_propensity_x(self, t, x):
        """Fits the nuisance parameter for the propensity P(T=1|X)."""
        self._classifier_t_x = clone(self.classifier).fit(x, t)

        return self

    def _fit_treatment_propensity_xm(self, t, m, x):
        """Fits the nuisance parameter for the propensity P(T=1|X, M)."""
        xm = np.hstack((x, m))
        self._classifier_t_xm = clone(self.classifier).fit(xm, t)

        return self
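
    # The propensity helpers assume `self.classifier` is any sklearn-style
    # classifier exposing `fit` and `predict_proba`. A hedged usage sketch:
    #
    #   from sklearn.linear_model import LogisticRegression
    #
    #   estimator.classifier = LogisticRegression()  # normally set by a subclass
    #   estimator._fit_treatment_propensity_x(t, x)
    #   estimator._fit_treatment_propensity_xm(t, m, x)
    #   p_x = estimator._estimate_treatment_propensity_x(x)       # P(T=1|X)
    #   p_xm = estimator._estimate_treatment_propensity_xm(m, x)  # P(T=1|X, M)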

    # TODO : Enable any sklearn object as classifier or regressor
    def _fit_binary_mediator_probability(self, t, m, x):
        """Fits the nuisance parameter for the mediator density f(M=m|T, X)."""
        # stack treatment and covariates as classifier inputs
        t_x = np.hstack([t.reshape(-1, 1), x])

        # fit the mediator classifier
        self._classifier_m = clone(self.classifier).fit(t_x, m.ravel())

        return self

    def _fit_conditional_mean_outcome(self, t, m, x, y):
        """Fits the nuisance parameter for the conditional mean outcome E[Y|T, M, X]."""
        x_t_m = np.hstack([x, t.reshape(-1, 1), m])
        self._regressor_y = clone(self.regressor).fit(x_t_m, y)

        return self

    def _fit_cross_conditional_mean_outcome(self, t, m, x, y):
        """Fits the cross conditional mean outcome E[E[Y|T=t,M,X]|T=t',X]."""
        xm = np.hstack((x, m))

        n = t.shape[0]
        train = np.arange(n)
        mu_1mx_nested = np.zeros(n)  # E[Y|T=1,M,X] predicted on the nested set
        mu_0mx_nested = np.zeros(n)  # E[Y|T=0,M,X] predicted on the nested set

        train1 = train[t[train] == 1]
        train0 = train[t[train] == 0]

        # split the sample: the first half fits the outcome regressions, the
        # second (nested) half fits the regressions of their predictions on X
        train_mean, train_nested = np.array_split(train, 2)
        train_mean1 = train_mean[t[train_mean] == 1]
        train_mean0 = train_mean[t[train_mean] == 0]
        train_nested1 = train_nested[t[train_nested] == 1]
        train_nested0 = train_nested[t[train_nested] == 0]

        self.regressors = {}

        # fit E[Y|T=1,M,X] and predict it on the nested set
        self.regressors["y_t1_mx"] = clone(self.regressor)
        self.regressors["y_t1_mx"].fit(xm[train_mean1], y[train_mean1])
        mu_1mx_nested[train_nested] = self.regressors["y_t1_mx"].predict(
            xm[train_nested]
        )

        # fit E[Y|T=0,M,X] and predict it on the nested set
        self.regressors["y_t0_mx"] = clone(self.regressor)
        self.regressors["y_t0_mx"].fit(xm[train_mean0], y[train_mean0])
        mu_0mx_nested[train_nested] = self.regressors["y_t0_mx"].predict(
            xm[train_nested]
        )

        # fit E[E[Y|T=1,M,X]|T=0,X]
        self.regressors["y_t1_x_t0"] = clone(self.regressor)
        self.regressors["y_t1_x_t0"].fit(x[train_nested0], mu_1mx_nested[train_nested0])

        # fit E[E[Y|T=0,M,X]|T=1,X]
        self.regressors["y_t0_x_t1"] = clone(self.regressor)
        self.regressors["y_t0_x_t1"].fit(x[train_nested1], mu_0mx_nested[train_nested1])

        # fit E[Y|T=1,X]
        self.regressors["y_t1_x"] = clone(self.regressor)
        self.regressors["y_t1_x"].fit(x[train1], y[train1])

        # fit E[Y|T=0,X]
        self.regressors["y_t0_x"] = clone(self.regressor)
        self.regressors["y_t0_x"].fit(x[train0], y[train0])

        return self
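
    # To make the sample splitting above concrete (illustrative): with n = 6
    # and t = [1, 0, 1, 0, 1, 0],
    #
    #   train = np.arange(6)
    #   train_mean, train_nested = np.array_split(train, 2)
    #   # train_mean == [0, 1, 2]: fits the outcome regressions E[Y|T,M,X]
    #   # train_nested == [3, 4, 5]: fits the nested regressions on X, so the
    #   # outer and nested regressions never share observations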

    def _estimate_binary_mediator_probability(self, x, m):
        """
        Estimates the mediator probability P(M=m|T, X) for a binary M.

        Returns
        -------
        f_m0x : array-like, shape (n_samples)
            probabilities f(M|T=0,X)
        f_m1x : array-like, shape (n_samples)
            probabilities f(M|T=1,X)
        """
        n = x.shape[0]

        t0 = np.zeros((n, 1))
        t1 = np.ones((n, 1))

        # binary mediator values are used as column indices below
        m = m.ravel().astype(int)

        t0_x = np.hstack([t0, x])
        t1_x = np.hstack([t1, x])

        f_m0x = self._classifier_m.predict_proba(t0_x)[np.arange(n), m]
        f_m1x = self._classifier_m.predict_proba(t1_x)[np.arange(n), m]

        return f_m0x, f_m1x
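
    # Note (assumption): the column indexing above relies on sklearn ordering
    # `predict_proba` columns by `self._classifier_m.classes_`, so for a
    # mediator coded {0, 1} column m gives P(M=m|.). A defensive check could be:
    #
    #   assert list(self._classifier_m.classes_) == [0, 1]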

    def _estimate_binary_mediator_probability_table(self, x):
        """
        Estimates the mediator density f(M|T, X) for a binary M.

        Returns
        -------
        f_00x : array-like, shape (n_samples)
            probabilities f(M=0|T=0,X)
        f_01x : array-like, shape (n_samples)
            probabilities f(M=1|T=0,X)
        f_10x : array-like, shape (n_samples)
            probabilities f(M=0|T=1,X)
        f_11x : array-like, shape (n_samples)
            probabilities f(M=1|T=1,X)
        """
        n = x.shape[0]

        t0 = np.zeros((n, 1))
        t1 = np.ones((n, 1))

        t0_x = np.hstack([t0, x])
        t1_x = np.hstack([t1, x])

        # predict f(M=m|T=t,X)
        fm_0 = self._classifier_m.predict_proba(t0_x)
        f_00x = fm_0[:, 0]
        f_01x = fm_0[:, 1]
        fm_1 = self._classifier_m.predict_proba(t1_x)
        f_10x = fm_1[:, 0]
        f_11x = fm_1[:, 1]

        return f_00x, f_01x, f_10x, f_11x

    def _estimate_treatment_propensity_x(self, x):
        """
        Estimates the treatment propensity P(T=1|X).

        Returns
        -------
        p_x : array-like, shape (n_samples)
            probabilities P(T=1|X)
        """
        p_x = self._classifier_t_x.predict_proba(x)[:, 1]

        return p_x

    def _estimate_treatment_propensity_xm(self, m, x):
        """
        Estimates the treatment propensity P(T=1|X, M).

        Returns
        -------
        p_xm : array-like, shape (n_samples)
            probabilities P(T=1|X, M)
        """
        xm = np.hstack((x, m))

        p_xm = self._classifier_t_xm.predict_proba(xm)[:, 1]

        return p_xm
| 306 | + |
| 307 | + def _estimate_cross_conditional_mean_outcome(self, m, x): |
| 308 | + """ |
| 309 | + Estimate the conditional mean outcome, |
| 310 | + the cross conditional mean outcome |
| 311 | +
|
| 312 | + Returns |
| 313 | + ------- |
| 314 | + mu_m0x, array-like, shape (n_samples) |
| 315 | + conditional mean outcome estimates E[Y|T=0,M,X] |
| 316 | + mu_m1x, array-like, shape (n_samples) |
| 317 | + conditional mean outcome estimates E[Y|T=1,M,X] |
| 318 | + mu_0x, array-like, shape (n_samples) |
| 319 | + cross conditional mean outcome estimates E[E[Y|T=0,M,X]|T=0,X] |
| 320 | + E_mu_t0_t1, array-like, shape (n_samples) |
| 321 | + cross conditional mean outcome estimates E[E[Y|T=0,M,X]|T=1,X] |
| 322 | + E_mu_t1_t0, array-like, shape (n_samples) |
| 323 | + cross conditional mean outcome estimates E[E[Y|T=1,M,X]|T=0,X] |
| 324 | + mu_1x, array-like, shape (n_samples) |
| 325 | + cross conditional mean outcome estimates E[E[Y|T=1,M,X]|T=1,X] |
| 326 | + """ |
| 327 | + xm = np.hstack((x, m)) |
| 328 | + |
| 329 | + # predict E[Y|T=1,M,X] |
| 330 | + mu_1mx = self.regressors["y_t1_mx"].predict(xm) |
| 331 | + |
| 332 | + # predict E[Y|T=0,M,X] |
| 333 | + mu_0mx = self.regressors["y_t0_mx"].predict(xm) |
| 334 | + |
| 335 | + # predict E[E[Y|T=1,M,X]|T=0,X] |
| 336 | + E_mu_t1_t0 = self.regressors["y_t1_x_t0"].predict(x) |
| 337 | + |
| 338 | + # predict E[E[Y|T=0,M,X]|T=1,X] |
| 339 | + E_mu_t0_t1 = self.regressors["y_t0_x_t1"].predict(x) |
| 340 | + |
| 341 | + # predict E[Y|T=1,X] |
| 342 | + mu_1x = self.regressors["y_t1_x"].predict(x) |
| 343 | + |
| 344 | + # predict E[Y|T=0,X] |
| 345 | + mu_0x = self.regressors["y_t0_x"].predict(x) |
| 346 | + |
| 347 | + return mu_0mx, mu_1mx, mu_0x, E_mu_t0_t1, E_mu_t1_t0, mu_1x |
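

# Illustrative only: a hedged end-to-end sketch of the intended workflow on
# synthetic data, assuming the hypothetical `MySimpleEstimator` subclass
# sketched above (with `estimate` implemented):
#
#   from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
#
#   rng = np.random.default_rng(0)
#   n = 500
#   x = rng.normal(size=(n, 3))
#   t = rng.binomial(1, 0.5, size=n)
#   m = rng.binomial(1, 0.3 + 0.4 * t, size=n)
#   y = x[:, 0] + t + m + rng.normal(size=n)
#
#   est = MySimpleEstimator(
#       classifier=RandomForestClassifier(), regressor=RandomForestRegressor()
#   )
#   est.fit(t, m, x, y)
#   effects = est.estimate(t, m, x, y)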