Skip to content

Commit 06d87a8

Browse files
committed
FIX _validate_eval_set_Xy
1 parent 5dd3171 commit 06d87a8

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

python-package/lightgbm/sklearn.py

+2
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ def _validate_eval_set_Xy(eval_set, eval_X, eval_y):
500500
if eval_set is not None:
501501
msg = "The argument 'eval_set' is deprecated, use 'eval_X' and 'eval_y' instead."
502502
warnings.warn(msg, category=LGBMDeprecationWarning, stacklevel=2)
503+
return eval_set
503504
if (eval_X is None) != (eval_y is None):
504505
raise ValueError("You must specify eval_X and eval_y, not just one of them.")
505506
if eval_set is None and eval_X is not None:
@@ -511,6 +512,7 @@ def _validate_eval_set_Xy(eval_set, eval_X, eval_y):
511512
eval_set = (eval_X, eval_y)
512513
else:
513514
eval_set = list(zip(eval_X, eval_y))
515+
return eval_set
514516

515517

516518
class LGBMModel(_LGBMModelBase):

tests/python_package_test/test_sklearn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2053,7 +2053,7 @@ def test_eval_set_deprecation():
20532053
gbm = lgb.LGBMRegressor()
20542054
msg = "The argument 'eval_set' is deprecated.*"
20552055
with pytest.warns(LGBMDeprecationWarning, match=msg):
2056-
gbm.fit(X_train, y_train, eval_set=(X_test, y_test))
2056+
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)])
20572057

20582058

20592059
def test_eval_X_eval_y_eval_set_equivalence():
@@ -2062,7 +2062,7 @@ def test_eval_X_eval_y_eval_set_equivalence():
20622062
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
20632063
cbs = [lgb.early_stopping(2)]
20642064
gbm1 = lgb.LGBMRegressor()
2065-
gbm1.fit(X_train, y_train, eval_set=(X_test, y_test), callbacks=cbs)
2065+
gbm1.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=cbs)
20662066
gbm2 = lgb.LGBMRegressor()
20672067
gbm2.fit(X_train, y_train, eval_X=X_test, eval_y=y_test, callbacks=cbs)
20682068
np.testing.assert_allclose(gbm1.predict(X), gbm2.predict(X))

0 commit comments

Comments
 (0)