Skip to content

Commit f2ccd52

Browse files
Merge branch '3313-enable-auto-early-stopping' of https://github.com/ClaudioSalvatoreArcidiacono/LightGBM into 3313-enable-auto-early-stopping
2 parents 2ca8cc1 + 1910076 commit f2ccd52

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed

python-package/lightgbm/sklearn.py

+38-23
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,16 @@
8585
]
8686
_LGBM_ScikitCustomEvalSetSplitter = Union[
8787
Callable[
88-
[np.ndarray, np.ndarray],
89-
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
88+
[_LGBM_ScikitMatrixLike, _LGBM_LabelType],
89+
Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType]
9090
],
9191
Callable[
92-
[np.ndarray, np.ndarray, np.ndarray],
93-
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
92+
[_LGBM_ScikitMatrixLike, _LGBM_LabelType, Optional[np.ndarray]],
93+
Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType, Optional[np.ndarray], Optional[np.ndarray]]
9494
],
9595
Callable[
96-
[np.ndarray, np.ndarray, np.ndarray],
97-
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]
96+
[_LGBM_ScikitMatrixLike, _LGBM_LabelType, Optional[np.ndarray], _LGBM_GroupType],
97+
Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType, Optional[np.ndarray], Optional[np.ndarray], _LGBM_GroupType, _LGBM_GroupType]
9898
],
9999
]
100100
_LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]
@@ -256,17 +256,17 @@ def __call__(
256256

257257

258258
def _train_test_split(
259-
X,
260-
y,
259+
X: _LGBM_ScikitMatrixLike,
260+
y: _LGBM_LabelType,
261261
weight,
262262
test_size: float,
263263
random_state: Optional[Union[int, np.random.RandomState]],
264264
stratified: bool,
265265
) -> Tuple[
266-
np.ndarray,
267-
np.ndarray,
268-
np.ndarray,
269-
np.ndarray,
266+
_LGBM_ScikitMatrixLike,
267+
_LGBM_ScikitMatrixLike,
268+
_LGBM_LabelType,
269+
_LGBM_LabelType,
270270
Optional[np.ndarray],
271271
Optional[np.ndarray],
272272
]:
@@ -319,7 +319,22 @@ def _train_test_split(
319319
return X_train, X_val, y_train, y_val, None, None
320320

321321

322-
def _train_test_group_split(X, y, weight, group, n_splits: int):
322+
def _train_test_group_split(
323+
X: _LGBM_ScikitMatrixLike,
324+
y: _LGBM_LabelType,
325+
weight,
326+
group: _LGBM_GroupType,
327+
n_splits: int
328+
) -> Tuple[
329+
_LGBM_ScikitMatrixLike,
330+
_LGBM_ScikitMatrixLike,
331+
_LGBM_LabelType,
332+
_LGBM_LabelType,
333+
Optional[np.ndarray],
334+
Optional[np.ndarray],
335+
_LGBM_GroupType,
336+
_LGBM_GroupType,
337+
]:
323338
"""Split X, y, weights and group into train and test subsets.
324339
325340
Parameters
@@ -390,20 +405,20 @@ def _train_test_group_split(X, y, weight, group, n_splits: int):
390405

391406

392407
def _train_test_split_custom_splitter(
393-
custom_splitter,
394-
X,
395-
y,
408+
custom_splitter: _LGBM_ScikitCustomEvalSetSplitter,
409+
X: _LGBM_ScikitMatrixLike,
410+
y: _LGBM_LabelType,
396411
weight,
397-
group
412+
group: Optional[_LGBM_GroupType]
398413
) -> Tuple[
399-
np.ndarray,
400-
np.ndarray,
401-
np.ndarray,
402-
np.ndarray,
403-
Optional[np.ndarray],
404-
Optional[np.ndarray],
414+
_LGBM_ScikitMatrixLike,
415+
_LGBM_ScikitMatrixLike,
416+
_LGBM_LabelType,
417+
_LGBM_LabelType,
405418
Optional[np.ndarray],
406419
Optional[np.ndarray],
420+
Optional[_LGBM_GroupType],
421+
Optional[_LGBM_GroupType],
407422
]:
408423
"""Call passed custom_splitter with appropriate arguments.
409424

0 commit comments

Comments
 (0)