Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding build using binary downloads #8

Merged
merged 4 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modules/utils/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
onnx_driver/
target/
output/
output/
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tbh we could probably move the package to output/downloaded_onnx instead. Because it is a byproduct of building, even if its required for the build.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now removed the reliance on downloads and it's all now in thetarget directory

downloaded_onnx_package/
6 changes: 6 additions & 0 deletions modules/utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@ tokio = { version = "1.12.0", features = ["full"] }
[lib]
name = "surrealml_core"
path = "src/lib.rs"

[build-dependencies]
ureq = { version = "2.1", default-features = false, features = [ "tls" ] }
tar = { version = "0.4" }
flate2 = { version = "1.0" }
sha2 = { version = "0.10" }
175 changes: 130 additions & 45 deletions modules/utils/build.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,67 @@
use std::process::Command;
//! This build file downloads the prebuilt binaries for ONNX Runtime and places them in the root of the crate
//! to be pointed at by the ort environment so we can run the ONNX models.
use ureq;
use std::path::Path;


static CACHE_FILE: &str = "./downloaded_onnx_package";


fn extract_tgz(buf: &[u8], output: &Path) {
let buf: std::io::BufReader<&[u8]> = std::io::BufReader::new(buf);
let tar = flate2::read::GzDecoder::new(buf);
let mut archive = tar::Archive::new(tar);
archive.unpack(output).expect("Failed to extract .tgz file");
}


fn hex_str_to_bytes(c: impl AsRef<[u8]>) -> Vec<u8> {
fn nibble(c: u8) -> u8 {
match c {
b'A'..=b'F' => c - b'A' + 10,
b'a'..=b'f' => c - b'a' + 10,
b'0'..=b'9' => c - b'0',
_ => panic!()
}
}

c.as_ref().chunks(2).map(|n| nibble(n[0]) << 4 | nibble(n[1])).collect()
}


fn verify_file(buf: &[u8], hash: impl AsRef<[u8]>) -> bool {
use sha2::Digest;
sha2::Sha256::digest(buf)[..] == hex_str_to_bytes(hash)
}


/// Fetches a file from the given URL and returns it as a vector of bytes.
///
/// # Arguments
/// * `source_url` - The URL to fetch the file from.
///
/// # Returns
/// A vector of bytes containing the file.
fn fetch_file(source_url: &str) -> Vec<u8> {
let resp = ureq::get(source_url)
.timeout(std::time::Duration::from_secs(1800))
.call()
.unwrap_or_else(|err| panic!("Failed to GET `{source_url}`: {err}"));

let len = resp
.header("Content-Length")
.and_then(|s| s.parse::<usize>().ok())
.expect("Content-Length header should be present on archive response");
let mut reader = resp.into_reader();
let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.unwrap_or_else(|err| panic!("Failed to download from `{source_url}`: {err}"));
assert_eq!(buffer.len(), len);
buffer
}



fn main() {

Expand All @@ -7,50 +70,72 @@ fn main() {
println!("cargo:rustc-cfg=onnx_runtime_env_var_set");
},
Err(_) => {
#[cfg(not(windows))]
{
let _ = Command::new("sh")
.arg("-c")
.arg("cargo new onnx_driver && cd onnx_driver && echo 'ort = \"1.16.2\"' >> Cargo.toml
")
.status()
.expect("failed to execute process");
}

#[cfg(windows)]
{
// let _ = Command::new("cmd")
// .args(&["/C", "cargo new onnx_driver && cd onnx_driver && echo ort = \"1.16.2\" >> Cargo.toml"])
// .status()
// .expect("failed to execute process");
let _ = Command::new("powershell")
.arg("-Command")
.arg("cargo new onnx_driver; Set-Location onnx_driver; Add-Content -Path .\\Cargo.toml -Value 'ort = \"1.16.2\"'")
.status()
.expect("failed to execute process");
}

#[cfg(not(windows))]
{
let _ = Command::new("sh")
.arg("-c")
.arg("cd onnx_driver && cargo build")
.status()
.expect("failed to execute process");
}

#[cfg(windows)]
{
let _ = Command::new("cmd")
.args(&["/C", "cd onnx_driver && cargo build"])
.status()
.expect("failed to execute process");
}
let target = std::env::var("TARGET").unwrap();

let (prebuilt_url, prebuilt_hash) = match target.as_str() {
"aarch64-apple-darwin" => (
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-aarch64-apple-darwin.tgz",
"188E07B9304CCC28877195ECD2177EF3EA6603A0B5B3497681A6C9E584721387"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The versions and shared prefixes can be extracted. This guarantees that we don't have version typos then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now moved over to just using the download binaries feature in the ort for the dev-dependencies and then using the binary for the include bytes, all builds are now passing

),
"aarch64-pc-windows-msvc" => (
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-aarch64-pc-windows-msvc.tgz",
"B35F6526EAF61527531D6F73EBA19EF09D6B0886FB66C14E1B594EE70F447817"
),
"aarch64-unknown-linux-gnu" => (
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-aarch64-unknown-linux-gnu.tgz",
"C1E315515856D7880545058479020756BC5CE4C0BA07FB3DD2104233EC7C3C81"
),
"wasm32-unknown-emscripten" => (
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-wasm32-unknown-emscripten.tgz",
"468F74FB4C7451DC94EBABC080779CDFF0C7DA0617D85ADF21D5435A96F9D470"
),
"x86_64-apple-darwin" => (
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-x86_64-apple-darwin.tgz",
"0191C95D9E797BF77C723AD82DC078C6400834B55B8465FA5176BA984FFEAB08"
),
"x86_64-pc-windows-msvc" => {
if cfg!(any(feature = "cuda", feature = "tensorrt")) {
(
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_dylib_cuda-v1.16.3-x86_64-pc-windows-msvc.tgz",
"B0F08E93E580297C170F04933742D04813C9C3BAD3705E1100CA9EF464AE4011"
)
} else {
(
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-x86_64-pc-windows-msvc.tgz",
"32ADC031C0EAA6C521680EEB9A8C39572C600A5B4F90AFE984590EA92B99E3BE"
)
}
}
"x86_64-unknown-linux-gnu" => {
if cfg!(any(feature = "cuda", feature = "tensorrt")) {
(
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_dylib_cuda-v1.16.3-x86_64-unknown-linux-gnu.tgz",
"0F0651D10BA56A6EA613F10B60E5C4D892384416C4D76E1F618BE57D1270993F"
)
} else {
(
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.16.3/ortrs-msort_static-v1.16.3-x86_64-unknown-linux-gnu.tgz",
"D0E63AC1E5A56D0480009049183542E0BB1482CE29A1D110CC8550BEC5D994E2"
)
}
}
x => panic!("downloaded binaries not available for target {x}\nyou may have to compile ONNX Runtime from source")
};

// download the prebuilt binary that is compatible with the target
let downloaded_file = fetch_file(prebuilt_url);

// verify the hash of the downloaded file to ensure that we have downloaded the correct file
assert!(verify_file(&downloaded_file, prebuilt_hash));

// delete the ONNX cache file if it exists
let _ = std::fs::remove_dir_all(CACHE_FILE);

// extract the downloaded file to the cache file
extract_tgz(&downloaded_file, Path::new(CACHE_FILE));

// remove the downloaded bundled file
let _ = std::fs::remove_file("./msort.tar.gz");
}
}
}

// fn main() {
// // println!("cargo:rustc-cfg=onnx_runtime_env_var_set");
// println!("test");
// }
6 changes: 3 additions & 3 deletions modules/utils/src/execution/onnx_environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ use std::sync::Arc;

// Compiles the ONNX module into the rust binary.
#[cfg(all(target_os = "macos", not(doc), not(onnx_runtime_env_var_set)))]
pub static LIB_BYTES: &'static [u8] = include_bytes!("../../onnx_driver/target/debug/libonnxruntime.dylib");
pub static LIB_BYTES: &'static [u8] = include_bytes!("../../downloaded_onnx_package/onnxruntime/lib/libonnxruntime.a");

#[cfg(all(any(target_os = "linux", target_os = "android"), not(doc), not(onnx_runtime_env_var_set)))]
pub static LIB_BYTES: &'static [u8] = include_bytes!("../../onnx_driver/target/debug/libonnxruntime.so");
pub static LIB_BYTES: &'static [u8] = include_bytes!("../../downloaded_onnx_package/onnxruntime/lib/libonnxruntime.a");

#[cfg(all(target_os = "windows", not(doc), not(onnx_runtime_env_var_set)))]
pub static LIB_BYTES: &'static [u8] = include_bytes!("../../onnx_driver/target/debug/onnxruntime.dll");
pub static LIB_BYTES: &'static [u8] = include_bytes!("../../downloaded_onnx_package/onnxruntime/lib/libonnxruntime.a");

// Fallback for documentation and other targets
#[cfg(any(doc, onnx_runtime_env_var_set, not(any(target_os = "macos", target_os = "linux", target_os = "android", target_os = "windows"))))]
Expand Down
Loading