Skip to content

Commit 284bfa7

Browse files
authored
Merge pull request #219 from kevinaboos/linux_cuda_support
moxin-runner: explicitly detect CUDA on Linux
2 parents df682ca + c78bf1b commit 284bfa7

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

moxin-runner/src/main.rs

+31-21
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,13 @@ fn install_wasmedge<P: AsRef<Path>>(install_path: P) -> Result<PathBuf, std::io:
328328
// The default `/tmp/` dir used in `install_v2.sh` isn't always accessible to bundled apps.
329329
.arg(&format!("--tmpdir={}", temp_dir.display()));
330330

331-
// If the current CPU doesn't support AVX512, tell the install script to
332-
// the WASI-nn plugin built without AVX support.
331+
let cuda = get_cuda_version();
332+
println!(" --> Found CUDA installation: {cuda:?}");
333+
334+
// If the current machine doesn't have CUDA and the CPU doesn't support AVX512,
335+
// tell the install script to select the no-AVX WASI-nn plugin version.
333336
#[cfg(target_arch = "x86_64")]
334-
if !is_x86_feature_detected!("avx512f") {
337+
if cuda.is_none() && !is_x86_feature_detected!("avx512f") {
335338
bash_cmd.arg("--noavx");
336339
}
337340

@@ -470,6 +473,7 @@ fn set_env_vars<P: AsRef<Path>>(wasmedge_root_dir_path: &P) {
470473

471474

472475
/// Versions of CUDA that WasmEdge supports.
476+
#[derive(Debug)]
473477
enum CudaVersion {
474478
/// CUDA Version 12
475479
V12,
@@ -482,26 +486,32 @@ enum CudaVersion {
482486
/// This function first runs `nvcc --version` on both Linux and Windows,
483487
/// and if that fails, it will try `/usr/local/cuda/bin/nvcc --version` on Linux only.
484488
fn get_cuda_version() -> Option<CudaVersion> {
485-
let mut output = Command::new("nvcc")
486-
.arg("--version")
487-
.output();
488-
489-
#[cfg(target_os = "linux")] {
490-
output = output.or_else(|_|
491-
Command::new("/usr/local/cuda/bin/nvcc")
492-
.arg("--version")
493-
.output()
494-
);
489+
#[cfg(target_os = "macos")] {
490+
None
495491
}
496492

497-
let output = output.ok()?;
498-
let output = String::from_utf8_lossy(&output.stdout);
499-
if output.contains("V12") {
500-
Some(CudaVersion::V12)
501-
} else if output.contains("V11") {
502-
Some(CudaVersion::V11)
503-
} else {
504-
None
493+
#[cfg(not(target_os = "macos"))] {
494+
let mut output = Command::new("nvcc")
495+
.arg("--version")
496+
.output();
497+
498+
#[cfg(target_os = "linux")] {
499+
output = output.or_else(|_|
500+
Command::new("/usr/local/cuda/bin/nvcc")
501+
.arg("--version")
502+
.output()
503+
);
504+
}
505+
506+
let output = output.ok()?;
507+
let output = String::from_utf8_lossy(&output.stdout);
508+
if output.contains("V12") {
509+
Some(CudaVersion::V12)
510+
} else if output.contains("V11") {
511+
Some(CudaVersion::V11)
512+
} else {
513+
None
514+
}
505515
}
506516
}
507517

0 commit comments

Comments
 (0)