Skip to content

Commit 8c8695b

Browse files
authored
Merge pull request #776 from MilesCranmer/fix-sklearn-tests
test: skip new sklearn checks
2 parents 89b5a89 + 3d23dc0 commit 8c8695b

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

pysr/sr.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@
5858
_suggest_keywords,
5959
)
6060

61+
try:
62+
from sklearn.utils.validation import validate_data
63+
64+
OLD_SKLEARN = False
65+
except ImportError:
66+
OLD_SKLEARN = True
67+
6168
ALREADY_RAN = False
6269

6370

@@ -1604,11 +1611,17 @@ def _validate_and_set_fit_params(
16041611
)
16051612

16061613
def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]:
1607-
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
1614+
if OLD_SKLEARN:
1615+
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
1616+
else:
1617+
raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True) # type: ignore
16081618
return cast(tuple[ndarray, ndarray], raw_out)
16091619

16101620
def _validate_data_X(self, X: Any) -> ndarray:
1611-
raw_out = self._validate_data(X=X, reset=False) # type: ignore
1621+
if OLD_SKLEARN:
1622+
raw_out = self._validate_data(X=X, reset=False) # type: ignore
1623+
else:
1624+
raw_out = validate_data(self, X=X, reset=False) # type: ignore
16121625
return cast(ndarray, raw_out)
16131626

16141627
def _get_precision_mapped_dtype(self, X: np.ndarray) -> type:

pysr/test/test_main.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -876,8 +876,14 @@ def test_scikit_learn_compatibility(self):
876876
check_generator = check_estimator(model, generate_only=True)
877877
exception_messages = []
878878
for _, check in check_generator:
879-
if check.func.__name__ == "check_complex_data":
880-
# We can use complex data, so avoid this check.
879+
if check.func.__name__ in {
880+
# We can use complex data, so avoid this check
881+
"check_complex_data",
882+
# We handle kwargs manually, so skip this check
883+
"check_do_not_raise_errors_in_init_or_set_params",
884+
# TODO:
885+
"check_n_features_in_after_fitting",
886+
}:
881887
continue
882888
try:
883889
with warnings.catch_warnings():

0 commit comments

Comments
 (0)