Skip to content

Commit 60b0155

Browse files
authored
[python-package] Fix inconsistency in predict() output shape for 1-tree models (#6753)
1 parent 4ee0bc0 commit 60b0155

File tree

2 files changed

+90
-2
lines changed

2 files changed

+90
-2
lines changed

python-package/lightgbm/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1248,7 +1248,7 @@ def predict(
12481248
if pred_leaf:
12491249
preds = preds.astype(np.int32)
12501250
is_sparse = isinstance(preds, (list, scipy.sparse.spmatrix))
1251-
if not is_sparse and preds.size != nrow:
1251+
if not is_sparse and (preds.size != nrow or pred_leaf or pred_contrib):
12521252
if preds.size % nrow == 0:
12531253
preds = preds.reshape(nrow, -1)
12541254
else:

tests/python_package_test/test_engine.py

+89-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import psutil
1616
import pytest
1717
from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr
18-
from sklearn.datasets import load_svmlight_file, make_blobs, make_multilabel_classification
18+
from sklearn.datasets import load_svmlight_file, make_blobs, make_classification, make_multilabel_classification
1919
from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score
2020
from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split
2121

@@ -2314,6 +2314,33 @@ def test_refit():
23142314
assert err_pred > new_err_pred
23152315

23162316

2317+
def test_refit_with_one_tree_regression():
2318+
X, y = make_synthetic_regression(n_samples=1_000, n_features=2)
2319+
lgb_train = lgb.Dataset(X, label=y)
2320+
params = {"objective": "regression", "verbosity": -1}
2321+
model = lgb.train(params, lgb_train, num_boost_round=1)
2322+
model_refit = model.refit(X, y)
2323+
assert isinstance(model_refit, lgb.Booster)
2324+
2325+
2326+
def test_refit_with_one_tree_binary_classification():
2327+
X, y = load_breast_cancer(return_X_y=True)
2328+
lgb_train = lgb.Dataset(X, label=y)
2329+
params = {"objective": "binary", "verbosity": -1}
2330+
model = lgb.train(params, lgb_train, num_boost_round=1)
2331+
model_refit = model.refit(X, y)
2332+
assert isinstance(model_refit, lgb.Booster)
2333+
2334+
2335+
def test_refit_with_one_tree_multiclass_classification():
2336+
X, y = load_iris(return_X_y=True)
2337+
lgb_train = lgb.Dataset(X, y)
2338+
params = {"objective": "multiclass", "num_class": 3, "verbose": -1}
2339+
model = lgb.train(params, lgb_train, num_boost_round=1)
2340+
model_refit = model.refit(X, y)
2341+
assert isinstance(model_refit, lgb.Booster)
2342+
2343+
23172344
def test_refit_dataset_params(rng):
23182345
# check refit accepts dataset_params
23192346
X, y = load_breast_cancer(return_X_y=True)
@@ -3872,6 +3899,67 @@ def test_predict_stump(rng, use_init_score):
38723899
np.testing.assert_allclose(preds_all, np.full_like(preds_all, fill_value=y_avg))
38733900

38743901

3902+
def test_predict_regression_output_shape():
3903+
n_samples = 1_000
3904+
n_features = 4
3905+
X, y = make_synthetic_regression(n_samples=n_samples, n_features=n_features)
3906+
dtrain = lgb.Dataset(X, label=y)
3907+
params = {"objective": "regression", "verbosity": -1}
3908+
3909+
# 1-round model
3910+
bst = lgb.train(params, dtrain, num_boost_round=1)
3911+
assert bst.predict(X).shape == (n_samples,)
3912+
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
3913+
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 1)
3914+
3915+
# 2-round model
3916+
bst = lgb.train(params, dtrain, num_boost_round=2)
3917+
assert bst.predict(X).shape == (n_samples,)
3918+
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
3919+
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 2)
3920+
3921+
3922+
def test_predict_binary_classification_output_shape():
3923+
n_samples = 1_000
3924+
n_features = 4
3925+
X, y = make_classification(n_samples=n_samples, n_features=n_features, n_classes=2)
3926+
dtrain = lgb.Dataset(X, label=y)
3927+
params = {"objective": "binary", "verbosity": -1}
3928+
3929+
# 1-round model
3930+
bst = lgb.train(params, dtrain, num_boost_round=1)
3931+
assert bst.predict(X).shape == (n_samples,)
3932+
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
3933+
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 1)
3934+
3935+
# 2-round model
3936+
bst = lgb.train(params, dtrain, num_boost_round=2)
3937+
assert bst.predict(X).shape == (n_samples,)
3938+
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_features + 1)
3939+
assert bst.predict(X, pred_leaf=True).shape == (n_samples, 2)
3940+
3941+
3942+
def test_predict_multiclass_classification_output_shape():
3943+
n_samples = 1_000
3944+
n_features = 10
3945+
n_classes = 3
3946+
X, y = make_classification(n_samples=n_samples, n_features=n_features, n_classes=n_classes, n_informative=6)
3947+
dtrain = lgb.Dataset(X, label=y)
3948+
params = {"objective": "multiclass", "verbosity": -1, "num_class": n_classes}
3949+
3950+
# 1-round model
3951+
bst = lgb.train(params, dtrain, num_boost_round=1)
3952+
assert bst.predict(X).shape == (n_samples, n_classes)
3953+
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_classes * (n_features + 1))
3954+
assert bst.predict(X, pred_leaf=True).shape == (n_samples, n_classes)
3955+
3956+
# 2-round model
3957+
bst = lgb.train(params, dtrain, num_boost_round=2)
3958+
assert bst.predict(X).shape == (n_samples, n_classes)
3959+
assert bst.predict(X, pred_contrib=True).shape == (n_samples, n_classes * (n_features + 1))
3960+
assert bst.predict(X, pred_leaf=True).shape == (n_samples, n_classes * 2)
3961+
3962+
38753963
def test_average_precision_metric():
38763964
# test against sklearn average precision metric
38773965
X, y = load_breast_cancer(return_X_y=True)

0 commit comments

Comments
 (0)