|
85 | 85 | ]
|
86 | 86 | _LGBM_ScikitCustomEvalSetSplitter = Union[
|
87 | 87 | Callable[
|
88 |
| - [np.ndarray, np.ndarray], |
89 |
| - Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] |
| 88 | + [_LGBM_ScikitMatrixLike, _LGBM_LabelType], |
| 89 | + Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType] |
90 | 90 | ],
|
91 | 91 | Callable[
|
92 |
| - [np.ndarray, np.ndarray, np.ndarray], |
93 |
| - Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]] |
| 92 | + [_LGBM_ScikitMatrixLike, _LGBM_LabelType, Optional[np.ndarray]], |
| 93 | + Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType, Optional[np.ndarray], Optional[np.ndarray]] |
94 | 94 | ],
|
95 | 95 | Callable[
|
96 |
| - [np.ndarray, np.ndarray, np.ndarray], |
97 |
| - Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]] |
| 96 | + [_LGBM_ScikitMatrixLike, _LGBM_LabelType, Optional[np.ndarray], _LGBM_GroupType], |
| 97 | + Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType, Optional[np.ndarray], Optional[np.ndarray], _LGBM_GroupType, _LGBM_GroupType] |
98 | 98 | ],
|
99 | 99 | ]
|
100 | 100 | _LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]
|
@@ -256,17 +256,17 @@ def __call__(
|
256 | 256 |
|
257 | 257 |
|
258 | 258 | def _train_test_split(
|
259 |
| - X, |
260 |
| - y, |
| 259 | + X: _LGBM_ScikitMatrixLike, |
| 260 | + y: _LGBM_LabelType, |
261 | 261 | weight,
|
262 | 262 | test_size: float,
|
263 | 263 | random_state: Optional[Union[int, np.random.RandomState]],
|
264 | 264 | stratified: bool,
|
265 | 265 | ) -> Tuple[
|
266 |
| - np.ndarray, |
267 |
| - np.ndarray, |
268 |
| - np.ndarray, |
269 |
| - np.ndarray, |
| 266 | + _LGBM_ScikitMatrixLike, |
| 267 | + _LGBM_ScikitMatrixLike, |
| 268 | + _LGBM_LabelType, |
| 269 | + _LGBM_LabelType, |
270 | 270 | Optional[np.ndarray],
|
271 | 271 | Optional[np.ndarray],
|
272 | 272 | ]:
|
@@ -319,7 +319,22 @@ def _train_test_split(
|
319 | 319 | return X_train, X_val, y_train, y_val, None, None
|
320 | 320 |
|
321 | 321 |
|
322 |
| -def _train_test_group_split(X, y, weight, group, n_splits: int): |
| 322 | +def _train_test_group_split( |
| 323 | + X: _LGBM_ScikitMatrixLike, |
| 324 | + y: _LGBM_LabelType, |
| 325 | + weight, |
| 326 | + group: _LGBM_GroupType, |
| 327 | + n_splits: int |
| 328 | +) -> Tuple[ |
| 329 | + _LGBM_ScikitMatrixLike, |
| 330 | + _LGBM_ScikitMatrixLike, |
| 331 | + _LGBM_LabelType, |
| 332 | + _LGBM_LabelType, |
| 333 | + Optional[np.ndarray], |
| 334 | + Optional[np.ndarray], |
| 335 | + _LGBM_GroupType, |
| 336 | + _LGBM_GroupType, |
| 337 | +]: |
323 | 338 | """Split X, y, weights and group into train and test subsets.
|
324 | 339 |
|
325 | 340 | Parameters
|
@@ -390,20 +405,20 @@ def _train_test_group_split(X, y, weight, group, n_splits: int):
|
390 | 405 |
|
391 | 406 |
|
392 | 407 | def _train_test_split_custom_splitter(
|
393 |
| - custom_splitter, |
394 |
| - X, |
395 |
| - y, |
| 408 | + custom_splitter: _LGBM_ScikitCustomEvalSetSplitter, |
| 409 | + X: _LGBM_ScikitMatrixLike, |
| 410 | + y: _LGBM_LabelType, |
396 | 411 | weight,
|
397 |
| - group |
| 412 | + group: Optional[_LGBM_GroupType] |
398 | 413 | ) -> Tuple[
|
399 |
| - np.ndarray, |
400 |
| - np.ndarray, |
401 |
| - np.ndarray, |
402 |
| - np.ndarray, |
403 |
| - Optional[np.ndarray], |
404 |
| - Optional[np.ndarray], |
| 414 | + _LGBM_ScikitMatrixLike, |
| 415 | + _LGBM_ScikitMatrixLike, |
| 416 | + _LGBM_LabelType, |
| 417 | + _LGBM_LabelType, |
405 | 418 | Optional[np.ndarray],
|
406 | 419 | Optional[np.ndarray],
|
| 420 | + Optional[_LGBM_GroupType], |
| 421 | + Optional[_LGBM_GroupType], |
407 | 422 | ]:
|
408 | 423 | """Call passed custom_splitter with appropriate arguments.
|
409 | 424 |
|
|
0 commit comments