Commit 7aff1ce: update tests

jeongyoonlee committed Oct 13, 2024
1 parent 99fc59e
Showing 5 changed files with 168 additions and 141 deletions.
4 changes: 2 additions & 2 deletions tests/const.py
@@ -1,5 +1,5 @@
RANDOM_SEED = 42
-N_SAMPLE = 1000
+N_SAMPLE = 2000
ERROR_THRESHOLD = 0.5
NUM_FEATURES = 6

@@ -14,4 +14,4 @@
DELTA_UPLIFT_INCREASE_DICT = {
    "treatment1": 0.1,
}
-N_UPLIFT_INCREASE_DICT = {"treatment1": 5}
+N_UPLIFT_INCREASE_DICT = {"treatment1": 2}
6 changes: 3 additions & 3 deletions tests/test_causal_trees.py
@@ -13,7 +13,7 @@

class CausalTreeBase:
    test_size: float = 0.2
-    control_name: int or str = 0
+    control_name: int = 0

    @abstractmethod
    def prepare_model(self, *args, **kwargs):

@@ -84,7 +84,7 @@ def test_fit(self, generate_regression_data):
            treatment_col="is_treated",
            treatment_effect_col="treatment_effect",
        )
-        assert df_qini["ctree_ite_pred"] > df_qini["Random"]
+        assert df_qini["ctree_ite_pred"] > 0.0

    @pytest.mark.parametrize("return_ci", (False, True))
    @pytest.mark.parametrize("bootstrap_size", (500, 800))

@@ -165,7 +165,7 @@ def test_fit(self, generate_regression_data, n_estimators):
            treatment_col="is_treated",
            treatment_effect_col="treatment_effect",
        )
-        assert df_qini["crforest_ite_pred"] > df_qini["Random"]
+        assert df_qini["crforest_ite_pred"] > 0.0

    @pytest.mark.parametrize("n_estimators", (5,))
    def test_predict(self, generate_regression_data, n_estimators):
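Note on the two assertion changes above: the call that produces df_qini is not shown in these hunks; assuming it is causalml.metrics.qini_score (whose keyword arguments match), each prediction column is scored by the area between its Qini curve and the random-targeting baseline, so a random ordering scores roughly 0 and any positive value beats random, which is what the new > 0.0 threshold checks. A minimal sketch of that assumed usage follows; the synthetic data and the "outcome" column name are illustrative assumptions, not taken from the tests.

# Minimal sketch of the assumed qini_score usage (synthetic data; the "outcome"
# column name and the random "ctree_ite_pred" values are illustrative only).
import numpy as np
import pandas as pd
from causalml.metrics import qini_score

rng = np.random.default_rng(42)
n = 2000
df = pd.DataFrame(
    {
        "outcome": rng.normal(size=n),
        "is_treated": rng.binomial(1, 0.5, size=n),
        "treatment_effect": rng.normal(loc=0.5, size=n),
        "ctree_ite_pred": rng.normal(size=n),  # stand-in for the tree's ITE predictions
    }
)
df_qini = qini_score(
    df,
    outcome_col="outcome",
    treatment_col="is_treated",
    treatment_effect_col="treatment_effect",
)
# qini_score returns a Series keyed by prediction column; the Random baseline is ~0,
# so a positive entry means the ranking beats random targeting.
print(df_qini["ctree_ite_pred"])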
22 changes: 10 additions & 12 deletions tests/test_ivlearner.py
@@ -1,15 +1,12 @@
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import statsmodels.api as sm
from xgboost import XGBRegressor
import warnings

from causalml.inference.iv import BaseDRIVLearner
-from causalml.metrics import ape, get_cumgain
+from causalml.metrics import ape, auuc_score

-from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION
+from .const import RANDOM_SEED, ERROR_THRESHOLD


def test_drivlearner():

@@ -34,7 +31,6 @@ def test_drivlearner():
    e = e_raw.copy()
    e[assignment == 0] = 0
    tau = (X[:, 0] + X[:, 1]) / 2
-    X_obs = X[:, [i for i in range(8) if i != 1]]

    w = np.random.binomial(1, e, size=n)
    treatment = w

@@ -75,10 +71,12 @@ def test_drivlearner():
        }
    )

-    cumgain = get_cumgain(
-        auuc_metrics, outcome_col="y", treatment_col="W", treatment_effect_col="tau"
+    # Check if the normalized AUUC score of model's prediction is higher than random (0.5).
+    auuc = auuc_score(
+        auuc_metrics,
+        outcome_col="y",
+        treatment_col="W",
+        treatment_effect_col="tau",
+        normalize=True,
    )

-    # Check if the cumulative gain when using the model's prediction is
-    # higher than it would be under random targeting
-    assert cumgain["cate_p"].sum() > cumgain["Random"].sum()
+    assert auuc["cate_p"] > 0.5
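For reference, the switch from get_cumgain to auuc_score changes the check from comparing raw cumulative-gain sums against the Random column to comparing a normalized area under the uplift curve against 0.5, the score of random targeting when the curves are normalized. Below is a minimal, self-contained sketch of that call; the synthetic auuc_metrics frame and the noisy cate_p column are illustrative assumptions, since the real test fills them from a DRIV learner's predictions.

# Minimal sketch of the auuc_score check (synthetic data; the real test builds
# auuc_metrics from BaseDRIVLearner predictions rather than random noise).
import numpy as np
import pandas as pd
from causalml.metrics import auuc_score

rng = np.random.default_rng(42)
n = 1000
tau = rng.normal(loc=0.5, size=n)                       # true treatment effects
auuc_metrics = pd.DataFrame(
    {
        "y": rng.normal(size=n),
        "W": rng.binomial(1, 0.5, size=n),
        "tau": tau,
        "cate_p": tau + rng.normal(scale=0.5, size=n),  # noisy effect estimates
    }
)
auuc = auuc_score(
    auuc_metrics,
    outcome_col="y",
    treatment_col="W",
    treatment_effect_col="tau",
    normalize=True,
)
# With normalize=True each cumulative-gain curve ends at 1, so random targeting
# (the diagonal) scores ~0.5; anything above 0.5 means the ranking adds value.
print(auuc["cate_p"], auuc["Random"])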