Adds predict_with_params, raw_scores_with_params functions
Mottl committed Jul 19, 2023
1 parent 47f6d2f commit 84173c4
Showing 6 changed files with 98 additions and 10 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "lightgbm3"
version = "1.0.1"
version = "1.0.2"
edition = "2021"
authors = ["Dmitry Mottl <dmitry.mottl@gmail.com>", "vaaaaanquish <6syun9@gmail.com>"]
license = "MIT"
@@ -13,7 +13,7 @@ readme = "README.md"
exclude = [".gitignore", ".github", ".gitmodules", "examples", "benches", "lightgbm3-sys"]

[dependencies]
lightgbm3-sys = { path = "lightgbm3-sys", version = "1.0.0" }
lightgbm3-sys = { path = "lightgbm3-sys", version = "1.0.2" }
libc = "0.2"
derive_builder = "0.12"
serde_json = "1.0"
2 changes: 1 addition & 1 deletion README.md
@@ -64,7 +64,7 @@ use lightgbm3::{Dataset, Booster};
let bst = Booster::from_file("path/to/model.lgb").unwrap();
let features = vec![1.0, 2.0, -5.0];
let n_features = features.len();
let y_pred = bst.predict(&features, n_features as i32, true).unwrap()[0];
let y_pred = bst.predict_with_params(&features, n_features as i32, true, "num_threads=1").unwrap()[0];
```
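For reference, the parameter argument uses LightGBM's usual textual format, so several prediction-time settings can be combined in one string. A minimal sketch along the lines of the snippet above (the model path, feature values, and the extra `predict_disable_shape_check` parameter are placeholders, not taken from this commit):

```rust
use lightgbm3::Booster;

fn main() {
    // Placeholder model path — substitute any trained LightGBM model file.
    let bst = Booster::from_file("path/to/model.lgb").unwrap();

    // One row of three features, laid out row-major.
    let features = vec![1.0_f64, 2.0, -5.0];
    let n_features = features.len() as i32;

    // Prediction-time settings are passed as a space-separated "key=value" string.
    let params = "num_threads=1 predict_disable_shape_check=true";
    let y_pred = bst
        .predict_with_params(&features, n_features, true, params)
        .unwrap()[0];
    println!("prediction: {y_pred}");
}
```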

Look in the [`./examples/`](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/) folder for more details:
2 changes: 1 addition & 1 deletion lightgbm3-sys/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "lightgbm3-sys"
version = "1.0.1"
version = "1.0.2"
edition = "2021"
authors = ["Dmitry Mottl <dmitry.mottl@gmail.com>", "vaaaaanquish <6syun9@gmail.com>"]
build = "build.rs"
2 changes: 1 addition & 1 deletion lightgbm3-sys/lightgbm
Submodule lightgbm updated 45 files
+2 −2 .appveyor.yml
+49 −0 .ci/test-python-oldest.sh
+3 −3 .github/workflows/cuda.yml
+1 −1 .github/workflows/linkchecker.yml
+33 −7 .github/workflows/python_package.yml
+1 −1 .github/workflows/r_valgrind.yml
+11 −11 .vsts-ci.yml
+15 −0 R-package/R/lgb.Booster.R
+3 −0 R-package/R/lgb.drop_serialized.R
+3 −0 R-package/R/lgb.make_serializable.R
+2 −0 R-package/R/lgb.restore_handle.R
+15 −0 R-package/R/lightgbm.R
+9 −9 R-package/configure
+3 −1 R-package/man/lgb.configure_fast_predict.Rd
+2 −0 R-package/man/lgb.drop_serialized.Rd
+2 −0 R-package/man/lgb.make_serializable.Rd
+2 −0 R-package/man/lgb.restore_handle.Rd
+2 −0 R-package/man/lgb_shared_params.Rd
+12 −4 R-package/man/lightgbm.Rd
+8 −2 R-package/man/predict.lgb.Booster.Rd
+2 −0 R-package/man/print.lgb.Booster.Rd
+2 −0 R-package/man/summary.lgb.Booster.Rd
+1 −1 VERSION.txt
+1 −0 build-python.sh
+1 −1 docs/FAQ.rst
+2 −0 docs/Parallel-Learning-Guide.rst
+12 −0 docs/Parameters.rst
+5 −0 docs/conf.py
+6 −0 include/LightGBM/config.h
+1 −1 include/LightGBM/tree_learner.h
+1 −1 include/LightGBM/utils/common.h
+2 −2 include/LightGBM/utils/json11.h
+245 −169 python-package/lightgbm/basic.py
+2 −0 python-package/lightgbm/callback.py
+21 −5 python-package/lightgbm/engine.py
+8 −0 python-package/lightgbm/plotting.py
+7 −0 python-package/lightgbm/sklearn.py
+2 −1 python-package/pyproject.toml
+1 −1 src/boosting/gbdt.h
+1 −1 src/io/dataset_loader.cpp
+3 −3 src/io/json11.cpp
+1 −1 src/treelearner/gpu_tree_learner.h
+1 −1 src/treelearner/gradient_discretizer.hpp
+1 −1 src/treelearner/serial_tree_learner.h
+2 −2 tests/python_package_test/test_basic.py
96 changes: 92 additions & 4 deletions src/booster.rs
@@ -230,6 +230,7 @@ impl Booster {
n_features: i32,
is_row_major: bool,
predict_type: PredictType,
parameters: Option<&str>,
) -> Result<Vec<f64>> {
if self.n_features <= 0 {
return Err(Error::new("n_features should be greater than 0"));
@@ -252,7 +253,10 @@
)));
}
let n_rows = flat_x.len() / n_features as usize;
let params = CString::new("").unwrap();
let params_cstring = parameters
.map(|s| CString::new(s))
.unwrap_or(CString::new(""))
.unwrap();
let mut out_length: c_longlong = 0;

let out_result: Vec<f64> = vec![Default::default(); n_rows * self.n_classes as usize];
@@ -266,7 +270,7 @@
predict_type.into(), // predict_type
0_i32, // start_iteration
self.max_iterations, // num_iteration, <= 0 means no limit
params.as_ptr() as *const c_char,
params_cstring.as_ptr() as *const c_char,
&mut out_length,
out_result.as_ptr() as *mut c_double
))?;
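The new `parameters` argument is forwarded to the LightGBM C API as a C string, with `None` mapping to an empty string (no extra parameters). A minimal standalone illustration of that conversion, outside the crate (the helper name `params_to_cstring` is made up for this sketch):

```rust
use std::ffi::CString;

// Map an optional "key=value ..." parameter string to the CString handed to FFI.
// As in the commit, an interior NUL byte in the input makes `unwrap()` panic.
fn params_to_cstring(parameters: Option<&str>) -> CString {
    parameters
        .map(CString::new)
        .unwrap_or_else(|| CString::new(""))
        .unwrap()
}

fn main() {
    assert_eq!(params_to_cstring(None), CString::new("").unwrap());
    assert_eq!(
        params_to_cstring(Some("num_threads=1")),
        CString::new("num_threads=1").unwrap()
    );
}
```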
@@ -282,7 +286,31 @@
n_features: i32,
is_row_major: bool,
) -> Result<Vec<f64>> {
self.real_predict(flat_x, n_features, is_row_major, PredictType::Normal)
self.real_predict(flat_x, n_features, is_row_major, PredictType::Normal, None)
}

/// Get predictions given `&[f32]` or `&[f64]` slice of features. The resulting vector
/// will have the size of `n_rows` by `n_classes`.
///
/// Example:
/// ```compile_fail
/// let y_pred = bst.predict_with_params(&xs, 10, true, "num_threads=1").unwrap();
/// ```
pub fn predict_with_params<T: DType>(
&self,
flat_x: &[T],
n_features: i32,
is_row_major: bool,
params: &str,
) -> Result<Vec<f64>> {
self.real_predict(
flat_x,
n_features,
is_row_major,
PredictType::Normal,
Some(params),
)
}

/// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector
@@ -293,7 +321,37 @@
n_features: i32,
is_row_major: bool,
) -> Result<Vec<f64>> {
self.real_predict(flat_x, n_features, is_row_major, PredictType::RawScore)
self.real_predict(
flat_x,
n_features,
is_row_major,
PredictType::RawScore,
None,
)
}

/// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector
/// will have the size of `n_rows` by `n_classes`.
///
/// Example:
/// ```compile_fail
/// let raw_scores = bst.raw_scores_with_params(&xs, 10, true, "num_threads=1").unwrap();
/// ```
pub fn raw_scores_with_params<T: DType>(
&self,
flat_x: &[T],
n_features: i32,
is_row_major: bool,
parameters: &str,
) -> Result<Vec<f64>> {
self.real_predict(
flat_x,
n_features,
is_row_major,
PredictType::RawScore,
Some(parameters),
)
}
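Both new `*_with_params` methods delegate to `real_predict` and differ only in the requested `PredictType`: `predict_with_params` returns the transformed output (e.g. a probability under a binary objective), while `raw_scores_with_params` returns the untransformed score. A hedged sketch contrasting the two (the model path and feature values are placeholders):

```rust
use lightgbm3::Booster;

fn main() {
    // Placeholder model trained on 28 features.
    let bst = Booster::from_file("path/to/model.lgb").unwrap();
    let row = vec![0.5_f64; 28];

    // Transformed prediction (a probability under a binary objective)...
    let p = bst
        .predict_with_params(&row, 28, true, "num_threads=1")
        .unwrap()[0];
    // ...versus the raw, untransformed score from the same trees.
    let raw = bst
        .raw_scores_with_params(&row, 28, true, "num_threads=1")
        .unwrap()[0];
    println!("probability = {p}, raw score = {raw}");
}
```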

/// Predicts results for the given `x` and returns a vector or vectors (inner vectors will
@@ -482,6 +540,36 @@ mod tests {
assert_eq!(normalized_result, vec![0, 0, 1]);
}

#[test]
fn predict_with_params() {
let params = json! {
{
"num_iterations": 10,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
let bst = _train_booster(&params);
// let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]];
let mut feature = [0.0; 28 * 3];
for i in 0..28 {
feature[i] = 0.5;
}
for i in 56..feature.len() {
feature[i] = 0.9;
}

let result = bst
.predict_with_params(&feature, 28, true, "num_threads=1")
.unwrap();
let mut normalized_result = Vec::new();
for r in &result {
normalized_result.push(if *r > 0.5 { 1 } else { 0 });
}
assert_eq!(normalized_result, vec![0, 0, 1]);
}

#[test]
fn num_feature() {
let params = _default_params();
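As an aside, the new `predict_with_params` test above fills its 3×28 row-major input by hand-indexing a flat array; an equivalent construction with iterators, shown only to make the row layout explicit (not part of the commit):

```rust
fn main() {
    // Three rows of 28 features each, flattened row-major:
    // row 0 = all 0.5, row 1 = all 0.0, row 2 = all 0.9.
    let feature: Vec<f64> = [0.5, 0.0, 0.9]
        .iter()
        .flat_map(|&v| std::iter::repeat(v).take(28))
        .collect();

    assert_eq!(feature.len(), 3 * 28);
    assert_eq!(feature[0], 0.5);  // first element of row 0
    assert_eq!(feature[28], 0.0); // first element of row 1
    assert_eq!(feature[56], 0.9); // first element of row 2
}
```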
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -37,7 +37,7 @@
//! let bst = Booster::from_file("path/to/model.lgb").unwrap();
//! let features = vec![1.0, 2.0, -5.0];
//! let n_features = features.len();
//! let y_pred = bst.predict(&features, n_features as i32, true).unwrap()[0];
//! let y_pred = bst.predict_with_params(&features, n_features as i32, true, "num_threads=1").unwrap()[0];
//! ```
macro_rules! lgbm_call {
