Skip to content

Commit

Permalink
Allow NaNs in input X (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinmccarter authored Jan 31, 2025
1 parent f1dd7c6 commit ecffeac
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions ngboost/ngboost.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The NGBoost library"""

# pylint: disable=line-too-long,too-many-instance-attributes,too-many-arguments
# pylint: disable=unused-argument,too-many-locals,too-many-branches,too-many-statements
# pylint: disable=unused-variable,invalid-unary-operand-type,attribute-defined-outside-init
Expand Down Expand Up @@ -342,7 +343,12 @@ def partial_fit(
raise ValueError("y cannot be None")

X, Y = check_X_y(
X, Y, accept_sparse=True, y_numeric=True, multi_output=self.multi_output
X,
Y,
accept_sparse=True,
force_all_finite="allow-nan",
multi_output=self.multi_output,
y_numeric=True,
)

self.n_features = X.shape[1]
Expand All @@ -357,8 +363,9 @@ def partial_fit(
X_val,
Y_val,
accept_sparse=True,
y_numeric=True,
force_all_finite="allow-nan",
multi_output=self.multi_output,
y_numeric=True,
)
val_params = self.pred_param(X_val)
val_loss_list = []
Expand Down Expand Up @@ -490,7 +497,7 @@ def pred_dist(self, X, max_iter=None):
A NGBoost distribution object
"""

X = check_array(X, accept_sparse=True)
X = check_array(X, accept_sparse=True, force_all_finite="allow-nan")

params = np.asarray(self.pred_param(X, max_iter))
dist = self.Dist(params.T)
Expand Down Expand Up @@ -537,7 +544,7 @@ def predict(self, X, max_iter=None):
Numpy array of the estimates of Y
"""

X = check_array(X, accept_sparse=True)
X = check_array(X, accept_sparse=True, force_all_finite="allow-nan")

return self.pred_dist(X, max_iter=max_iter).predict()

Expand Down

0 comments on commit ecffeac

Please sign in to comment.