Skip to content

Commit

Permalink
Merge pull request #310 from stanfordmlgroup/sklearn_dict
Browse files Browse the repository at this point in the history
add sklearn dictionary support get_param and version update
  • Loading branch information
ryan-wolbeck authored Mar 14, 2023
2 parents 3575781 + bcac835 commit ae3677b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
8 changes: 8 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# RELEASE NOTES

## Version 0.4.0

* Added support for the gamma distribution
* Added sklearn support to `set_params`
* Fixed off-by-one issue for max trees
* Upgraded version of `black` formatter to 22.8.0
29 changes: 29 additions & 0 deletions ngboost/ngboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,35 @@ def fit(

return self

def set_params(self, **parameters):
for parameter, value in parameters.items():
setattr(self, parameter, value)
return self

def get_params(self, deep=True):
"""
Parameters
----------
deep : Ignored. (for compatibility with sklearn)
Returns
----------
params : returns an dictionary of parameters.
"""
params = {
"Dist": self.Dist,
"Score": self.Score,
"Base": self.Base,
"natural_gradient": self.natural_gradient,
"n_estimators": self.n_estimators,
"learning_rate": self.learning_rate,
"minibatch_frac": self.minibatch_frac,
"col_sample": self.col_sample,
"verbose": self.verbose,
"random_state": self.random_state,
}

return params

def score(self, X, Y): # for sklearn
return self.Manifold(self.pred_dist(X)._params).total_score(Y)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ngboost"
version = "0.3.14dev"
version = "0.4.0dev"
description = "Library for probabilistic predictions via gradient boosting."
authors = ["Stanford ML Group <avati@cs.stanford.edu>"]
readme = "README.md"
Expand Down

0 comments on commit ae3677b

Please sign in to comment.