Skip to content

Commit df682ca

Browse files
authored
Merge pull request #200 from kevinaboos/windows_cuda_support
Windows cuda support
2 parents 245d104 + 309d2aa commit df682ca

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

moxin-runner/src/main.rs

+54
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,25 @@ const ENV_LD_LIBRARY_PATH: &str = "LD_LIBRARY_PATH";
144144
#[cfg(target_os = "macos")]
145145
const ENV_DYLD_FALLBACK_LIBRARY_PATH: &str = "DYLD_FALLBACK_LIBRARY_PATH";
146146

147+
147148
/// Returns the URL of the WASI-NN plugin that should be downloaded, and its inner directory name.
149+
///
150+
/// Note that this is only used on Windows, because the install_v2.sh script handles it on Linux.
151+
///
152+
/// The plugin selection follows this priority order of hardware features:
153+
/// 1. The CUDA build, if CUDA V12 is installed.
154+
/// 2. The default AVX512 build, if on x86_64 and AVX512F is supported.
155+
/// 3. Otherwise, the noavx build (which itself still requires SSE4.2 or SSE4a).
148156
#[cfg(windows)]
149157
fn wasmedge_wasi_nn_plugin_url() -> (&'static str, &'static str) {
158+
// Currently, WasmEdge's b3499 release only provides a CUDA 12 build for Windows.
159+
if matches!(get_cuda_version(), Some(CudaVersion::V12)) {
160+
return (
161+
"https://github.com/second-state/WASI-NN-GGML-PLUGIN-REGISTRY/releases/download/b3499/WasmEdge-plugin-wasi_nn-ggml-cuda-0.14.0-windows_x86_64.zip",
162+
"WasmEdge-plugin-wasi_nn-ggml-cuda-0.14.0-windows_x86_64",
163+
);
164+
}
165+
150166
#[cfg(target_arch = "x86_64")]
151167
if is_x86_feature_detected!("avx512f") {
152168
return (
@@ -452,6 +468,44 @@ fn set_env_vars<P: AsRef<Path>>(wasmedge_root_dir_path: &P) {
452468
std::env::set_var(ENV_WASMEDGE_PLUGIN_PATH, wasmedge_root_dir_path.as_ref());
453469
}
454470

471+
472+
/// Versions of CUDA that WasmEdge supports.
///
/// This is a field-less tag type, so it derives the usual value-type traits:
/// detected versions can be logged (`Debug`), passed by value (`Copy`),
/// and compared directly (`PartialEq`/`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CudaVersion {
    /// CUDA Version 12
    V12,
    /// CUDA Version 11
    V11,
}
479+
480+
/// Attempts to discover what version of CUDA is locally installed, if any.
481+
///
482+
/// This function first runs `nvcc --version` on both Linux and Windows,
483+
/// and if that fails, it will try `/usr/local/cuda/bin/nvcc --version` on Linux only.
484+
fn get_cuda_version() -> Option<CudaVersion> {
485+
let mut output = Command::new("nvcc")
486+
.arg("--version")
487+
.output();
488+
489+
#[cfg(target_os = "linux")] {
490+
output = output.or_else(|_|
491+
Command::new("/usr/local/cuda/bin/nvcc")
492+
.arg("--version")
493+
.output()
494+
);
495+
}
496+
497+
let output = output.ok()?;
498+
let output = String::from_utf8_lossy(&output.stdout);
499+
if output.contains("V12") {
500+
Some(CudaVersion::V12)
501+
} else if output.contains("V11") {
502+
Some(CudaVersion::V11)
503+
} else {
504+
None
505+
}
506+
}
507+
508+
455509
/// Runs the `_moxin_app` binary, which must be located in the same directory as this moxin-runner binary.
456510
///
457511
/// An optional path to the directory containing the main WasmEdge dylib can be provided,

0 commit comments

Comments
 (0)