Skip to content

Commit

Permalink
FIX make_synthetic_competing_weibull() with default parameters (#66)
Browse files Browse the repository at this point in the history
* fix key error when return_X_y=False

* fix tests

---------

Co-authored-by: Vincent Maladiere <maladiere.vincent@yahoo.fr>
  • Loading branch information
jovan-stojanovic and Vincent-Maladiere authored Jan 14, 2025
1 parent ddc2f35 commit 82fe8cb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hazardous/data/_competing_weibull.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ def make_synthetic_competing_weibull(
return X, y

frame = pd.concat([X, y], axis=1)
return Bunch(data=frame[X.columns], target=X[y.columns], frame=frame)
return Bunch(data=frame[X.columns], target=frame[y.columns], frame=frame)
20 changes: 20 additions & 0 deletions hazardous/data/tests/test_competing_weibull.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from numpy.testing import assert_array_equal
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import cross_val_score
Expand Down Expand Up @@ -121,3 +122,22 @@ def test_competing_weibull_with_censoring(seed):
# Check that high scale censoring keeps approximate balance between events:
event_counts = y_high_scale.query("event != 0")["event"].value_counts().sort_index()
assert event_counts.max() < 2 * event_counts.min(), event_counts


@pytest.mark.parametrize("seed", range(3))
def test_make_synthetic_competing_weibull_return(seed):
n_samples = 1000
df = make_synthetic_competing_weibull(
n_events=3,
n_samples=n_samples,
return_X_y=False,
random_state=seed,
)
X, y = make_synthetic_competing_weibull(
n_events=3,
n_samples=n_samples,
return_X_y=True,
random_state=seed,
)
assert_array_equal(df["data"], X)
assert_array_equal(df["target"], y)

0 comments on commit 82fe8cb

Please sign in to comment.