Skip to content

Commit 3d23dc0

Browse files
committed
feat: compat with new sklearn version
1 parent 85b4f40 commit 3d23dc0

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

pysr/sr.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@
5858
_suggest_keywords,
5959
)
6060

61+
try:
62+
from sklearn.utils.validation import validate_data
63+
64+
OLD_SKLEARN = False
65+
except ImportError:
66+
OLD_SKLEARN = True
67+
6168
ALREADY_RAN = False
6269

6370

@@ -1604,11 +1611,17 @@ def _validate_and_set_fit_params(
16041611
)
16051612

16061613
def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]:
1607-
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
1614+
if OLD_SKLEARN:
1615+
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
1616+
else:
1617+
raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True) # type: ignore
16081618
return cast(tuple[ndarray, ndarray], raw_out)
16091619

16101620
def _validate_data_X(self, X: Any) -> ndarray:
1611-
raw_out = self._validate_data(X=X, reset=False) # type: ignore
1621+
if OLD_SKLEARN:
1622+
raw_out = self._validate_data(X=X, reset=False) # type: ignore
1623+
else:
1624+
raw_out = validate_data(self, X=X, reset=False) # type: ignore
16121625
return cast(ndarray, raw_out)
16131626

16141627
def _get_precision_mapped_dtype(self, X: np.ndarray) -> type:

0 commit comments

Comments
 (0)