     _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
     _LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
     dt_DataTable, pd_DataFrame)
-from .engine import train
+from .engine import _make_n_folds, train

 __all__ = [
     'LGBMClassifier',
@@ -412,6 +412,7 @@ def __init__(
         random_state: Optional[Union[int, np.random.RandomState]] = None,
         n_jobs: Optional[int] = None,
         importance_type: str = 'split',
+        validation_fraction: Optional[float] = 0.1,
         **kwargs
     ):
         r"""Construct a gradient boosting model.
@@ -491,6 +492,10 @@ def __init__(
             The type of feature importance to be filled into ``feature_importances_``.
             If 'split', result contains numbers of times the feature is used in a model.
             If 'gain', result contains total gains of splits which use the feature.
+        validation_fraction : float or None, optional (default=0.1)
+            Proportion of training data to set aside as validation data for
+            early stopping. If None, early stopping is done on the training
+            data. Only used if early stopping is performed.
         **kwargs
             Other parameters for the model.
             Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
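Not part of the diff, just a sketch of how the new keyword would be passed once this branch lands; the estimator name and other arguments are the existing scikit-learn-style API, and whether early stopping actually fires is decided later in `fit()`:

```python
# Hypothetical usage of the parameter added above (assumes this branch).
from lightgbm import LGBMRegressor

reg = LGBMRegressor(
    n_estimators=1000,
    validation_fraction=0.2,  # hold out 20% of the training data for early stopping
)
```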
@@ -566,6 +571,7 @@ def __init__(
         self._n_features_in: int = -1
         self._classes: Optional[np.ndarray] = None
         self._n_classes: int = -1
+        self.validation_fraction = validation_fraction
         self.set_params(**kwargs)

     def _more_tags(self) -> Dict[str, Any]:
@@ -668,9 +674,24 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
         params.pop('importance_type', None)
         params.pop('n_estimators', None)
         params.pop('class_weight', None)
+        params.pop("validation_fraction", None)

         if isinstance(params['random_state'], np.random.RandomState):
             params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
+
+        params = _choose_param_value(
+            main_param_name="early_stopping_round",
+            params=params,
+            default_value="auto",
+        )
+        if params["early_stopping_round"] == "auto":
+            params["early_stopping_round"] = 10 if hasattr(self, "n_rows_train") and self.n_rows_train > 10000 else None
+
+        if params["early_stopping_round"] is True:
+            params["early_stopping_round"] = 10
+        elif params["early_stopping_round"] is False:
+            params["early_stopping_round"] = None
+
         if self._n_classes > 2:
             for alias in _ConfigAliases.get('num_class'):
                 params.pop(alias, None)
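The normalization added in this hunk is easier to read in isolation. A minimal sketch of the same decision table (the function name is illustrative, not part of the library): `"auto"` enables a 10-round patience only above 10,000 training rows, the booleans map to 10 / disabled, and explicit values pass through.

```python
from typing import Optional, Union

def _resolve_early_stopping_round(
    value: Union[str, bool, int, None],
    n_rows_train: Optional[int],
) -> Optional[int]:
    # Mirrors the _process_params logic above (hypothetical helper).
    if value == "auto":
        return 10 if n_rows_train is not None and n_rows_train > 10000 else None
    if value is True:
        return 10
    if value is False:
        return None
    return value  # an explicit int (or None) passes through unchanged

assert _resolve_early_stopping_round("auto", 20_000) == 10
assert _resolve_early_stopping_round("auto", 500) is None
assert _resolve_early_stopping_round(True, None) == 10
assert _resolve_early_stopping_round(False, 20_000) is None
```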
@@ -765,7 +786,6 @@ def fit(
         params['metric'] = [params['metric']] if isinstance(params['metric'], (str, type(None))) else params['metric']
         params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric']
         params['metric'] = [metric for metric in params['metric'] if metric is not None]
-
         if not isinstance(X, (pd_DataFrame, dt_DataTable)):
             _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
             if sample_weight is not None:
@@ -789,44 +809,61 @@ def fit(
         train_set = Dataset(data=_X, label=_y, weight=sample_weight, group=group,
                             init_score=init_score, categorical_feature=categorical_feature,
                             params=params)
+        self._n_rows_train = _X.shape[0]
+        if params["early_stopping_round"] == "auto":
+            params["early_stopping_round"] = 10 if self.n_rows_train > 10000 else None
+        if params["early_stopping_round"] is not None and eval_set is None:
+            if self.validation_fraction is not None:
+                n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
+                stratified = isinstance(self, LGBMClassifier)
+                cvfolds = _make_n_folds(full_data=train_set, folds=None, nfold=n_splits,
+                                        params=params, seed=self.random_state,
+                                        stratified=stratified, shuffle=True)
+                train_idx, val_idx = next(cvfolds)
+                valid_set = train_set.subset(sorted(val_idx))
+                train_set = train_set.subset(sorted(train_idx))
+            else:
+                valid_set = train_set
+            valid_set = valid_set.construct()
+            valid_sets = [valid_set]

-        valid_sets: List[Dataset] = []
-        if eval_set is not None:
-
-            def _get_meta_data(collection, name, i):
-                if collection is None:
-                    return None
-                elif isinstance(collection, list):
-                    return collection[i] if len(collection) > i else None
-                elif isinstance(collection, dict):
-                    return collection.get(i, None)
-                else:
-                    raise TypeError(f"{name} should be dict or list")
-
-            if isinstance(eval_set, tuple):
-                eval_set = [eval_set]
-            for i, valid_data in enumerate(eval_set):
-                # reduce cost for prediction training data
-                if valid_data[0] is X and valid_data[1] is y:
-                    valid_set = train_set
-                else:
-                    valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
-                    valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
-                    if valid_class_weight is not None:
-                        if isinstance(valid_class_weight, dict) and self._class_map is not None:
-                            valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
-                        valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
-                        if valid_weight is None or len(valid_weight) == 0:
-                            valid_weight = valid_class_sample_weight
-                        else:
-                            valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
-                    valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
-                    valid_group = _get_meta_data(eval_group, 'eval_group', i)
-                    valid_set = Dataset(data=valid_data[0], label=valid_data[1], weight=valid_weight,
-                                        group=valid_group, init_score=valid_init_score,
-                                        categorical_feature='auto', params=params)
-
-                valid_sets.append(valid_set)
+        else:
+            valid_sets: List[Dataset] = []
+            if eval_set is not None:
+                def _get_meta_data(collection, name, i):
+                    if collection is None:
+                        return None
+                    elif isinstance(collection, list):
+                        return collection[i] if len(collection) > i else None
+                    elif isinstance(collection, dict):
+                        return collection.get(i, None)
+                    else:
+                        raise TypeError(f"{name} should be dict or list")
+
+                if isinstance(eval_set, tuple):
+                    eval_set = [eval_set]
+                for i, valid_data in enumerate(eval_set):
+                    # reduce cost for prediction training data
+                    if valid_data[0] is X and valid_data[1] is y:
+                        valid_set = train_set
+                    else:
+                        valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
+                        valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
+                        if valid_class_weight is not None:
+                            if isinstance(valid_class_weight, dict) and self._class_map is not None:
+                                valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
+                            valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
+                            if valid_weight is None or len(valid_weight) == 0:
+                                valid_weight = valid_class_sample_weight
+                            else:
+                                valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
+                        valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
+                        valid_group = _get_meta_data(eval_group, 'eval_group', i)
+                        valid_set = Dataset(data=valid_data[0], label=valid_data[1], weight=valid_weight,
+                                            group=valid_group, init_score=valid_init_score,
+                                            categorical_feature='auto', params=params)
+
+                    valid_sets.append(valid_set)

         if isinstance(init_model, LGBMModel):
             init_model = init_model.booster_
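End to end, the hunk above means a large fit with no `eval_set` gets an automatic holdout: `validation_fraction=0.1` yields `ceil(1 / 0.1) = 10` folds, and the first fold's held-out slice becomes the early-stopping validation set (stratified for classifiers). A hedged sketch on synthetic data, assuming this branch:

```python
# Illustrative only: with > 10,000 rows and no eval_set, early stopping
# resolves from "auto" to 10 rounds and fit() carves out a 10% holdout.
import numpy as np
from lightgbm import LGBMClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(20_000, 10))
y = (X[:, 0] + rng.normal(scale=0.5, size=20_000) > 0).astype(int)

clf = LGBMClassifier(n_estimators=1_000, validation_fraction=0.1, random_state=0)
clf.fit(X, y)  # stops early against the internal validation split
print(clf.best_iteration_)  # typically far below n_estimators
```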