 import psutil
 import pytest
 from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr
-from sklearn.datasets import load_svmlight_file, make_blobs, make_multilabel_classification
+from sklearn.datasets import load_svmlight_file, make_blobs, make_classification, make_multilabel_classification
 from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score
 from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split

@@ -2314,6 +2314,33 @@ def test_refit():
     assert err_pred > new_err_pred


+def test_refit_with_one_tree_regression():
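+    # refit() keeps the already-learned tree structure and only re-fits leaf values on the data passed in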
+    X, y = make_synthetic_regression(n_samples=1_000, n_features=2)
+    lgb_train = lgb.Dataset(X, label=y)
+    params = {"objective": "regression", "verbosity": -1}
+    model = lgb.train(params, lgb_train, num_boost_round=1)
+    model_refit = model.refit(X, y)
+    assert isinstance(model_refit, lgb.Booster)
+
+
+def test_refit_with_one_tree_binary_classification():
+    X, y = load_breast_cancer(return_X_y=True)
+    lgb_train = lgb.Dataset(X, label=y)
+    params = {"objective": "binary", "verbosity": -1}
+    model = lgb.train(params, lgb_train, num_boost_round=1)
+    model_refit = model.refit(X, y)
+    assert isinstance(model_refit, lgb.Booster)
+
+
+def test_refit_with_one_tree_multiclass_classification():
+    X, y = load_iris(return_X_y=True)
+    lgb_train = lgb.Dataset(X, y)
+    params = {"objective": "multiclass", "num_class": 3, "verbose": -1}
+    model = lgb.train(params, lgb_train, num_boost_round=1)
+    model_refit = model.refit(X, y)
+    assert isinstance(model_refit, lgb.Booster)
+
+
 def test_refit_dataset_params(rng):
     # check refit accepts dataset_params
     X, y = load_breast_cancer(return_X_y=True)
@@ -3872,6 +3899,67 @@ def test_predict_stump(rng, use_init_score):
     np.testing.assert_allclose(preds_all, np.full_like(preds_all, fill_value=y_avg))


+def test_predict_regression_output_shape():
+    n_samples = 1_000
+    n_features = 4
+    X, y = make_synthetic_regression(n_samples=n_samples, n_features=n_features)
+    dtrain = lgb.Dataset(X, label=y)
+    params = {"objective": "regression", "verbosity": -1}
+
+    # 1-round model
+    bst = lgb.train(params, dtrain, num_boost_round=1)
+    assert bst.predict(X).shape == (n_samples,)
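+    # pred_contrib is expected to add one column per feature plus a final expected-value (bias) column;
+    # pred_leaf is expected to return one column per tree in the model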
+    assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
+    assert bst.predict(X, pred_leaf=True).shape == (n_samples, 1)
+
+    # 2-round model
+    bst = lgb.train(params, dtrain, num_boost_round=2)
+    assert bst.predict(X).shape == (n_samples,)
+    assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
+    assert bst.predict(X, pred_leaf=True).shape == (n_samples, 2)
+
+
+def test_predict_binary_classification_output_shape():
+    n_samples = 1_000
+    n_features = 4
+    X, y = make_classification(n_samples=n_samples, n_features=n_features, n_classes=2)
+    dtrain = lgb.Dataset(X, label=y)
+    params = {"objective": "binary", "verbosity": -1}
+
+    # 1-round model
+    bst = lgb.train(params, dtrain, num_boost_round=1)
+    assert bst.predict(X).shape == (n_samples,)
+    assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
+    assert bst.predict(X, pred_leaf=True).shape == (n_samples, 1)
+
+    # 2-round model
+    bst = lgb.train(params, dtrain, num_boost_round=2)
+    assert bst.predict(X).shape == (n_samples,)
+    assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
+    assert bst.predict(X, pred_leaf=True).shape == (n_samples, 2)
+
+
+def test_predict_multiclass_classification_output_shape():
+    n_samples = 1_000
+    n_features = 10
+    n_classes = 3
+    X, y = make_classification(n_samples=n_samples, n_features=n_features, n_classes=n_classes, n_informative=6)
+    dtrain = lgb.Dataset(X, label=y)
+    params = {"objective": "multiclass", "verbosity": -1, "num_class": n_classes}
+
+    # 1-round model
+    bst = lgb.train(params, dtrain, num_boost_round=1)
+    assert bst.predict(X).shape == (n_samples, n_classes)
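+    # multiclass models grow num_class trees per boosting round, so pred_leaf has n_classes columns per round
+    # and pred_contrib has one (n_features + 1) block per class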
+    assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_classes * (n_features + 1))
+    assert bst.predict(X, pred_leaf=True).shape == (n_samples, n_classes)
+
+    # 2-round model
+    bst = lgb.train(params, dtrain, num_boost_round=2)
+    assert bst.predict(X).shape == (n_samples, n_classes)
+    assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_classes * (n_features + 1))
+    assert bst.predict(X, pred_leaf=True).shape == (n_samples, n_classes * 2)
+
+
 def test_average_precision_metric():
     # test against sklearn average precision metric
     X, y = load_breast_cancer(return_X_y=True)