Commit d264888

merging from main as main was bypassed in a previous pull request
maxwellflitton committed Jan 19, 2024
2 parents 1ede48e + 1769322 commit d264888
Showing 5 changed files with 220 additions and 108 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -30,4 +30,4 @@ name = "surrealml"
path = "src/lib.rs"

[build-dependencies]
ort = { version = "1.16.2", default-features = true }
96 changes: 67 additions & 29 deletions README.md
@@ -1,8 +1,42 @@
# SurrealML

This package is for storing machine learning models with metadata in Rust so they can be used on the SurrealDB server.

## What is SurrealML?

SurrealML is a feature that allows you to store trained machine learning models in a special format called 'surml'. This enables you to run these models in either Python or Rust, and even upload them to a SurrealDB node to run the models on the server.

## Prerequisites

1. A basic understanding of Machine Learning: You should be familiar with ML concepts, algorithms, and model training processes.
2. Knowledge of Python: Proficiency in Python is necessary as SurrealML involves working with Python-based ML models.
3. Familiarity with SurrealDB: Basic knowledge of how SurrealDB operates is required since SurrealML integrates directly with it.
4. Python Environment Setup: A Python environment with necessary libraries installed, including SurrealML, PyTorch or SKLearn (depending on your model preference).
5. SurrealDB Installation: Ensure you have SurrealDB installed and running on your machine or server.

## Installation

To install SurrealML, make sure you have Python installed. Then, install the SurrealML library and either PyTorch or SKLearn, based on your model choice. You can install these using pip:

```bash
pip install surrealml
pip install torch # If using PyTorch
pip install scikit-learn # If using SKLearn
```

After that, you can train your model and save it in the SurrealML format.
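
For example, a minimal end-to-end sketch might look like the following. This is only an illustration: the model shape, file name, and the `save` call are assumptions, with the `SurMlFile` constructor used exactly as in the PyTorch tutorial below.

```python
import torch

from surrealml import SurMlFile

# a hypothetical, untrained linear model: 2 input features, 1 predicted value
model = torch.nn.Linear(in_features=2, out_features=1)

# illustrative example inputs: square footage and number of floors
example_inputs = torch.tensor([[1000.0, 2.0]])

# wrap the model and example inputs in a surml file and write it to disk;
# `save` is assumed here to mirror the `SurMlFile.load` call shown later
file = SurMlFile(model=model, name="house-price-prediction", inputs=example_inputs)
file.save("./house_price.surml")
```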

## Compilation config

If nothing is configured the crate will compile the ONNX runtime into the binary. This is the default behaviour. However, you have two more options:

- If you want to use an ONNX runtime that is installed on your system, set the environment variable `ONNXRUNTIME_LIB_PATH` before you compile the crate. The crate will then use that runtime instead.
- If you want to statically compile the library, you can download it from https://github.com/surrealdb/onnxruntime-build/releases/tag/v1.16.3 and then build the crate this way:

```bash
$ tar xvf <onnx-archive-file> -C extract-dir
$ ORT_STRATEGY=system ORT_LIB_LOCATION=$(pwd)/extract-dir cargo build
```

## Quick start with SKLearn

@@ -33,14 +67,15 @@ print(test_load.raw_compute(random_floats, [1, -1]))
```

## Python tutorial using PyTorch

To carry out this example we need the following:

- PyTorch (pip installed for Python)
- numpy
- surrealml

First we need one script where we create and store the model. In this example we will simply train a linear regression model
to predict the house price using the number of floors and the square footage.

### Defining the data

Expand Down Expand Up @@ -150,6 +185,7 @@ from surrealml import SurMlFile

file = SurMlFile(model=model, name="House Price Prediction", inputs=test_inputs)
```

The name is optional, but the inputs and model are essential. We can now add some metadata to the file, such as our inputs and outputs, with the following code. Metadata is not essential; it just helps with some types of computation:

```python
Expand Down Expand Up @@ -187,6 +223,7 @@ new_file = SurMlFile.load("./test.surml")
Our model is now loaded. We can now perform computations.

### Raw computation in Python

If you haven't put any metadata into the file, don't worry: we can just perform a raw computation with the following command:

```python
Expand Down Expand Up @@ -219,12 +256,12 @@ We can upload our trained model with the following code:
```python
url = "http://0.0.0.0:8000/ml/import"
SurMlFile.upload(
path="./linear_test.surml",
url=url,
chunk_size=36864,
namespace="test",
database="test",
username="root",
path="./linear_test.surml",
url=url,
chunk_size=36864,
namespace="test",
database="test",
username="root",
password="root"
)
```
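
Before relying on the uploaded model, it can be worth sanity-checking the `.surml` file locally. The following is a minimal sketch reusing the `load` and `raw_compute` calls shown earlier in this README; the input values are illustrative:

```python
from surrealml import SurMlFile

# load the same file we just uploaded and run a single raw computation
local_file = SurMlFile.load("./linear_test.surml")

# two features (square footage, number of floors), reshaped with the
# same [1, -1] dims argument used in the SKLearn quick start above
print(local_file.raw_compute([1000.0, 2.0], [1, -1]))
```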
@@ -239,35 +276,35 @@ CREATE house_listing SET squarefoot_col = 1000.0, num_floors_col = 2.0;
CREATE house_listing SET squarefoot_col = 1500.0, num_floors_col = 3.0;

SELECT * FROM (
SELECT
*,
ml::house-price-prediction<0.0.1>({
squarefoot: squarefoot_col,
num_floors: num_floors_col
}) AS price_prediction
FROM house_listing
)
WHERE price_prediction > 177206.21875;
```

What is happening here is that we are feeding the columns from the table `house_listing` into the `house-price-prediction` model (version `0.0.1`) that we uploaded. We then get the results of that trained ML model back as the column `price_prediction`, and use those predictions to filter the rows, giving us the following result:

```json
[
{
"id": "house_listing:7bo0f35tl4hpx5bymq5d",
"num_floors_col": 3,
"price_prediction": 406534.75,
"squarefoot_col": 1500
},
{
"id": "house_listing:8k2ttvhp2vh8v7skwyie",
"num_floors_col": 2,
"price_prediction": 291870.5,
"squarefoot_col": 1000
}
]
```
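
If you would rather run this query from Python than from a SurrealQL shell, one option is SurrealDB's HTTP `/sql` endpoint. The following is a sketch, assuming a local SurrealDB v1.x node and the same `test`/`test` namespace and database and `root`/`root` credentials used in the upload step:

```python
import requests

QUERY = """
SELECT * FROM (
    SELECT
        *,
        ml::house-price-prediction<0.0.1>({
            squarefoot: squarefoot_col,
            num_floors: num_floors_col
        }) AS price_prediction
    FROM house_listing
)
WHERE price_prediction > 177206.21875;
"""

# the NS/DB headers select the namespace and database; basic auth matches the upload step
response = requests.post(
    "http://0.0.0.0:8000/sql",
    data=QUERY,
    headers={"NS": "test", "DB": "test", "Accept": "application/json"},
    auth=("root", "root"),
)
print(response.json())
```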

@@ -282,6 +319,7 @@ let mut file = SurMlFile::from_file("./test.surml").unwrap();
```

### Raw computation in Rust

You can have an empty header if you want. This makes sense if you're doing something novel or complex, such as convolutional neural networks
for image processing. To perform a raw computation, you can simply do the following:

37 changes: 37 additions & 0 deletions build.rs
@@ -0,0 +1,37 @@
use std::env;
use std::path::Path;

fn main() {
    // if the ONNX library is statically linked we do not need to do anything; the
    // `onnx_statically_linked` cfg is emitted by modules/utils/build.rs when ORT_STRATEGY=system
    if cfg!(onnx_statically_linked) {
return;
}

let target_lib = match env::var("CARGO_CFG_TARGET_OS").unwrap() {
ref s if s.contains("linux") => "libonnxruntime.so",
ref s if s.contains("macos") => "libonnxruntime.dylib",
ref s if s.contains("windows") => "onnxruntime.dll",
// ref s if s.contains("android") => "android", => not building for android
_ => panic!("Unsupported target os"),
};
let profile = match env::var("PROFILE").unwrap() {
ref s if s.contains("release") => "release",
ref s if s.contains("debug") => "debug",
_ => panic!("Unsupported profile"),
};

    // remove the ./modules/utils/target folder if it is present; ignore the error if it is not
    let _ = std::fs::remove_dir_all(Path::new("modules").join("utils").join("target"));

// create the target module folder for the utils module
let _ = std::fs::create_dir(Path::new("modules").join("utils").join("target"));
let _ = std::fs::create_dir(Path::new("modules").join("utils").join("target").join(profile));

// copy target folder to modules/utils/target profile for the utils modules
std::fs::copy(
Path::new("target").join(profile).join(target_lib),
Path::new("modules").join("utils").join("target").join(profile).join(target_lib),
)
.unwrap();
}
78 changes: 43 additions & 35 deletions modules/utils/build.rs
@@ -1,11 +1,11 @@
use std::env;
use std::fs;
use std::path::Path;


/// works out where the `onnxruntime` library is in the build target and copies the library to the root
/// of the crate so the core library can find it and load it into the binary using `include_bytes!()`.
///
/// # Notes
/// This is a workaround for the fact that `onnxruntime` doesn't support `cargo` yet. This build step
/// is reliant on the `ort` crate downloading and building the `onnxruntime` library. This is
Expand All @@ -29,45 +29,53 @@ use std::fs;
/// we do not need to move the `onnxruntime` library around with the binary, and there is no complicated setup required
/// or linking.
fn unpack_onnx() -> std::io::Result<()> {
let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set");
let out_path = Path::new(&out_dir);
let build_dir = out_path
.ancestors() // This gives an iterator over all ancestors of the path
.nth(3) // 'nth(3)' gets the fourth ancestor (counting from 0), which should be the debug directory
.expect("Failed to find debug directory");

match std::env::var("ONNXRUNTIME_LIB_PATH") {
Ok(_) => {
println!("cargo:rustc-cfg=onnx_runtime_env_var_set");
}
Err(_) => {
let target_lib = match env::var("CARGO_CFG_TARGET_OS").unwrap() {
ref s if s.contains("linux") => "libonnxruntime.so",
ref s if s.contains("macos") => "libonnxruntime.dylib",
ref s if s.contains("windows") => "onnxruntime.dll",
// ref s if s.contains("android") => "android", => not building for android
_ => panic!("Unsupported target os"),
};

let lib_path = build_dir.join(target_lib);
let lib_path = lib_path.to_str().unwrap();

// put it next to the file of the embedding
let destination = Path::new(target_lib);
fs::copy(lib_path, destination)?;
}
}
Ok(())
}


fn main() -> std::io::Result<()> {
if std::env::var("DOCS_RS").is_ok() {
        // we are not going to do anything here for docs.rs, because we are merely building
        // the docs; when just building the docs nothing will look for the `onnxruntime` library, so we don't need to unpack it.
return Ok(());
}

if env::var("ORT_STRATEGY").as_deref() == Ok("system") {
// If the ORT crate is built with the `system` strategy, then the crate will take care of statically linking the library.
// No need to do anything here.
println!("cargo:rustc-cfg=onnx_statically_linked");

return Ok(());
}

if std::env::var("DOCS_RS").is_ok() {
// we are not going to be anything here for docs.rs, because we are merely building the docs. When we are just building
// the docs, the onnx environment variable will not look for the `onnxruntime` library, so we don't need to unpack it.
} else {
unpack_onnx()?;
}
Ok(())
unpack_onnx()?;
Ok(())
}