Skip to content

Commit 379cfdf

Browse files
committed
ENH eval_X and eval_y in _DaskLGBMModel
1 parent 5a95c9a commit 379cfdf

File tree

2 files changed

+59
-16
lines changed

2 files changed

+59
-16
lines changed

python-package/lightgbm/dask.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
_lgbmmodel_doc_custom_eval_note,
5050
_lgbmmodel_doc_fit,
5151
_lgbmmodel_doc_predict,
52+
_validate_eval_set_Xy,
5253
)
5354

5455
__all__ = [
@@ -318,6 +319,13 @@ def _train_part(
318319
if eval_class_weight:
319320
kwargs["eval_class_weight"] = [eval_class_weight[i] for i in eval_component_idx]
320321

322+
if local_eval_set is None:
323+
local_eval_X=None
324+
local_eval_y=None
325+
else:
326+
local_eval_X=(X for X, y in local_eval_set),
327+
local_eval_y=(y for X, y in local_eval_set),
328+
321329
model = model_factory(**params)
322330
if remote_socket is not None:
323331
remote_socket.release()
@@ -329,7 +337,8 @@ def _train_part(
329337
sample_weight=weight,
330338
init_score=init_score,
331339
group=group,
332-
eval_set=local_eval_set,
340+
eval_X=local_eval_X,
341+
eval_y=local_eval_y,
333342
eval_sample_weight=local_eval_sample_weight,
334343
eval_init_score=local_eval_init_score,
335344
eval_group=local_eval_group,
@@ -342,7 +351,8 @@ def _train_part(
342351
label,
343352
sample_weight=weight,
344353
init_score=init_score,
345-
eval_set=local_eval_set,
354+
eval_X=local_eval_X,
355+
eval_y=local_eval_y,
346356
eval_sample_weight=local_eval_sample_weight,
347357
eval_init_score=local_eval_init_score,
348358
eval_names=local_eval_names,
@@ -422,6 +432,8 @@ def _train(
422432
group: Optional[_DaskVectorLike] = None,
423433
eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
424434
eval_names: Optional[List[str]] = None,
435+
eval_X: Optional[Union[_DaskMatrixLike, Tuple[_DaskMatrixLike]]] = None,
436+
eval_y: Optional[Union[_DaskCollection, Tuple[_DaskCollection]]] = None,
425437
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
426438
eval_class_weight: Optional[List[Union[dict, str]]] = None,
427439
eval_init_score: Optional[List[_DaskCollection]] = None,
@@ -461,6 +473,10 @@ def _train(
461473
of ``evals_result_`` and ``best_score_`` will be empty dictionaries.
462474
eval_names : list of str, or None, optional (default=None)
463475
Names of eval_set.
476+
eval_X : Dask Array or Dask DataFrame, tuple thereof or None, optional (default=None)
477+
Feature matrix or tuple thereof, e.g. `(X_val0, X_val1)`, to use as validation sets.
478+
eval_y : Dask Array or Dask DataFrame, tuple thereof or None, optional (default=None)
479+
Target values or tuple thereof, i.g. `(y_val0, y_val1)`, to use as validation sets.
464480
eval_sample_weight : list of Dask Array or Dask Series, or None, optional (default=None)
465481
Weights for each validation set in eval_set. Weights should be non-negative.
466482
eval_class_weight : list of dict or str, or None, optional (default=None)
@@ -570,6 +586,7 @@ def _train(
570586
for i in range(n_parts):
571587
parts[i]["init_score"] = init_score_parts[i]
572588

589+
eval_set = _validate_eval_set_Xy(eval_set=eval_set, eval_X=eval_X, eval_y=eval_y)
573590
# evals_set will to be re-constructed into smaller lists of (X, y) tuples, where
574591
# X and y are each delayed sub-lists of original eval dask Collections.
575592
if eval_set:
@@ -1049,6 +1066,8 @@ def _lgb_dask_fit(
10491066
group: Optional[_DaskVectorLike] = None,
10501067
eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
10511068
eval_names: Optional[List[str]] = None,
1069+
eval_X: Optional[Union[_DaskMatrixLike, Tuple[_DaskMatrixLike]]] = None,
1070+
eval_y: Optional[Union[_DaskCollection, Tuple[_DaskCollection]]] = None,
10521071
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
10531072
eval_class_weight: Optional[List[Union[dict, str]]] = None,
10541073
eval_init_score: Optional[List[_DaskCollection]] = None,
@@ -1076,6 +1095,8 @@ def _lgb_dask_fit(
10761095
group=group,
10771096
eval_set=eval_set,
10781097
eval_names=eval_names,
1098+
eval_X=eval_X,
1099+
eval_y=eval_y,
10791100
eval_sample_weight=eval_sample_weight,
10801101
eval_class_weight=eval_class_weight,
10811102
eval_init_score=eval_init_score,
@@ -1182,6 +1203,8 @@ def fit( # type: ignore[override]
11821203
init_score: Optional[_DaskCollection] = None,
11831204
eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
11841205
eval_names: Optional[List[str]] = None,
1206+
eval_X: Optional[Union[_DaskMatrixLike, Tuple[_DaskMatrixLike]]] = None,
1207+
eval_y: Optional[Union[_DaskCollection, Tuple[_DaskCollection]]] = None,
11851208
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
11861209
eval_class_weight: Optional[List[Union[dict, str]]] = None,
11871210
eval_init_score: Optional[List[_DaskCollection]] = None,
@@ -1197,6 +1220,8 @@ def fit( # type: ignore[override]
11971220
init_score=init_score,
11981221
eval_set=eval_set,
11991222
eval_names=eval_names,
1223+
eval_X=eval_X,
1224+
eval_y=eval_y,
12001225
eval_sample_weight=eval_sample_weight,
12011226
eval_class_weight=eval_class_weight,
12021227
eval_init_score=eval_init_score,
@@ -1386,6 +1411,8 @@ def fit( # type: ignore[override]
13861411
init_score: Optional[_DaskVectorLike] = None,
13871412
eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
13881413
eval_names: Optional[List[str]] = None,
1414+
eval_X: Optional[Union[_DaskMatrixLike, Tuple[_DaskMatrixLike]]] = None,
1415+
eval_y: Optional[Union[_DaskCollection, Tuple[_DaskCollection]]] = None,
13891416
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
13901417
eval_init_score: Optional[List[_DaskVectorLike]] = None,
13911418
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
@@ -1400,6 +1427,8 @@ def fit( # type: ignore[override]
14001427
init_score=init_score,
14011428
eval_set=eval_set,
14021429
eval_names=eval_names,
1430+
eval_X=eval_X,
1431+
eval_y=eval_y,
14031432
eval_sample_weight=eval_sample_weight,
14041433
eval_init_score=eval_init_score,
14051434
eval_metric=eval_metric,
@@ -1555,6 +1584,8 @@ def fit( # type: ignore[override]
15551584
group: Optional[_DaskVectorLike] = None,
15561585
eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
15571586
eval_names: Optional[List[str]] = None,
1587+
eval_X: Optional[Union[_DaskMatrixLike, Tuple[_DaskMatrixLike]]] = None,
1588+
eval_y: Optional[Union[_DaskCollection, Tuple[_DaskCollection]]] = None,
15581589
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
15591590
eval_init_score: Optional[List[_DaskVectorLike]] = None,
15601591
eval_group: Optional[List[_DaskVectorLike]] = None,
@@ -1572,6 +1603,8 @@ def fit( # type: ignore[override]
15721603
group=group,
15731604
eval_set=eval_set,
15741605
eval_names=eval_names,
1606+
eval_X=eval_X,
1607+
eval_y=eval_y,
15751608
eval_sample_weight=eval_sample_weight,
15761609
eval_init_score=eval_init_score,
15771610
eval_group=eval_group,

python-package/lightgbm/sklearn.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,29 @@ def _extract_evaluation_meta_data(
490490
raise TypeError(f"{name} should be dict or list")
491491

492492

493+
def _validate_eval_set_Xy(eval_set, eval_X, eval_y):
494+
"""Validate eval args.
495+
496+
Returns
497+
-------
498+
eval_set
499+
"""
500+
if eval_set is not None:
501+
msg = "The argument 'eval_set' is deprecated, use 'eval_X' and 'eval_y' instead."
502+
warnings.warn(msg, category=LGBMDeprecationWarning, stacklevel=2)
503+
if (eval_X is None) != (eval_y is None):
504+
raise ValueError("You must specify eval_X and eval_y, not just one of them.")
505+
if eval_set is None and eval_X is not None:
506+
if isinstance(eval_X, tuple) != isinstance(eval_y, tuple):
507+
raise ValueError("If eval_X is a tuple, y_val must be a tuple of same length, and vice versa.")
508+
if isinstance(eval_X, tuple) and len(eval_X) != len(eval_y):
509+
raise ValueError("If eval_X is a tuple, y_val must be a tuple of same length, and vice versa.")
510+
if not isinstance(eval_X, tuple):
511+
eval_set = (eval_X, eval_y)
512+
else:
513+
eval_set = list(zip(eval_X, eval_y))
514+
515+
493516
class LGBMModel(_LGBMModelBase):
494517
"""Implementation of the scikit-learn API for LightGBM."""
495518

@@ -996,20 +1019,7 @@ def fit(
9961019
)
9971020

9981021
valid_sets: List[Dataset] = []
999-
if eval_set is not None:
1000-
msg = "The argument 'eval_set' is deprecated, use 'eval_X' and 'eval_y' instead."
1001-
warnings.warn(msg, category=LGBMDeprecationWarning, stacklevel=2)
1002-
if (eval_X is None) != (eval_y is None):
1003-
raise ValueError("You must specify eval_X and eval_y, not just one of them.")
1004-
if eval_set is None and eval_X is not None:
1005-
if isinstance(eval_X, tuple) != isinstance(eval_y, tuple):
1006-
raise ValueError("If eval_X is a tuple, y_val must be a tuple of same length, and vice versa.")
1007-
if isinstance(eval_X, tuple) and len(eval_X) != len(eval_y):
1008-
raise ValueError("If eval_X is a tuple, y_val must be a tuple of same length, and vice versa.")
1009-
if not isinstance(eval_X, tuple):
1010-
eval_set = (eval_X, eval_y)
1011-
else:
1012-
eval_set = list(zip(eval_X, eval_y))
1022+
eval_set = _validate_eval_set_Xy(eval_set=eval_set, eval_X=eval_X, eval_y=eval_y)
10131023
if eval_set is not None:
10141024
if isinstance(eval_set, tuple):
10151025
eval_set = [eval_set]

0 commit comments

Comments
 (0)