@@ -1,6 +1,7 @@
 """Define the PySRRegressor scikit-learn interface."""
 
 import copy
+import logging
 import os
 import pickle as pkl
 import re
@@ -67,6 +68,8 @@
 
 ALREADY_RAN = False
 
+pysr_logger = logging.getLogger(__name__)
+
 
 def _process_constraints(
     binary_operators: list[str],
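
Since the informational `print` calls below now go through `pysr_logger.info`, their output is hidden unless the host application configures logging. A minimal sketch of restoring that output, assuming the module-level logger created via `__name__` ends up under a package logger named "pysr":

import logging

# Attach a default stderr handler to the root logger; basicConfig is a
# no-op if handlers are already configured.
logging.basicConfig()

# Raise verbosity for PySR only (assumes the package logger is named "pysr").
logging.getLogger("pysr").setLevel(logging.INFO)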
@@ -1107,7 +1110,7 @@ def from_file(
 
         pkl_filename = Path(run_directory) / "checkpoint.pkl"
         if pkl_filename.exists():
-            print(f"Attempting to load model from {pkl_filename}...")
+            pysr_logger.info(f"Attempting to load model from {pkl_filename}...")
             assert binary_operators is None
             assert unary_operators is None
             assert n_features_in is None
@@ -1129,7 +1132,7 @@ def from_file(
 
             return model
         else:
-            print(
+            pysr_logger.info(
                 f"Checkpoint file {pkl_filename} does not exist. "
                 "Attempting to recreate model from scratch..."
             )
@@ -1232,12 +1235,16 @@ def __getstate__(self) -> dict[str, Any]:
         )
         state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
         for state_key in state_keys_containing_lambdas:
-            if state[state_key] is not None and show_pickle_warning:
-                warnings.warn(
-                    f"`{state_key}` cannot be pickled and will be removed from the "
-                    "serialized instance. When loading the model, please redefine "
-                    f"`{state_key}` at runtime."
-                )
+            warn_msg = (
+                f"`{state_key}` cannot be pickled and will be removed from the "
+                "serialized instance. When loading the model, please redefine "
+                f"`{state_key}` at runtime."
+            )
+            if state[state_key] is not None:
+                if show_pickle_warning:
+                    warnings.warn(warn_msg)
+                else:
+                    pysr_logger.debug(warn_msg)
         state_keys_to_clear = state_keys_containing_lambdas
         state_keys_to_clear.append("logger_")
         pickled_state = {
@@ -1280,7 +1287,7 @@ def _checkpoint(self):
             try:
                 pkl.dump(self, f)
             except Exception as e:
-                print(f"Error checkpointing model: {e}")
+                pysr_logger.debug(f"Error checkpointing model: {e}")
         self.show_pickle_warnings_ = True
 
     def get_pkl_filename(self) -> Path:
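
The checkpointing error above, like the pickling note in `__getstate__`, is now logged at DEBUG level and is therefore silent by default. A minimal sketch for surfacing it while debugging, again assuming the package logger is named "pysr":

import logging

debug_handler = logging.StreamHandler()   # write records to stderr
debug_handler.setLevel(logging.DEBUG)     # let DEBUG records through the handler
pysr_log = logging.getLogger("pysr")
pysr_log.addHandler(debug_handler)
pysr_log.setLevel(logging.DEBUG)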
@@ -1752,7 +1759,7 @@ def _pre_transform_training_data(
             self.selection_mask_ = selection_mask
             self.feature_names_in_ = _check_feature_names_in(self, variable_names)
             self.display_feature_names_in_ = self.feature_names_in_
-            print(f"Using features {self.feature_names_in_}")
+            pysr_logger.info(f"Using features {self.feature_names_in_}")
 
         # Denoising transformation
         if self.denoise:
@@ -1824,7 +1831,7 @@ def _run(
 
         # Start julia backend processes
         if not ALREADY_RAN and runtime_params.update_verbosity != 0:
-            print("Compiling Julia backend...")
+            pysr_logger.info("Compiling Julia backend...")
 
         parallelism, numprocs = _map_parallelism_params(
             self.parallelism, self.procs, getattr(self, "multithreading", None)