@@ -328,10 +328,13 @@ fn install_wasmedge<P: AsRef<Path>>(install_path: P) -> Result<PathBuf, std::io:
328
328
// The default `/tmp/` dir used in `install_v2.sh` isn't always accessible to bundled apps.
329
329
. arg ( & format ! ( "--tmpdir={}" , temp_dir. display( ) ) ) ;
330
330
331
- // If the current CPU doesn't support AVX512, tell the install script to
332
- // the WASI-nn plugin built without AVX support.
331
+ let cuda = get_cuda_version ( ) ;
332
+ println ! ( " --> Found CUDA installation: {cuda:?}" ) ;
333
+
334
+ // If the current machine doesn't have CUDA and the CPU doesn't support AVX512,
335
+ // tell the install script to select the no-AVX WASI-nn plugin version.
333
336
#[ cfg( target_arch = "x86_64" ) ]
334
- if !is_x86_feature_detected ! ( "avx512f" ) {
337
+ if cuda . is_none ( ) && !is_x86_feature_detected ! ( "avx512f" ) {
335
338
bash_cmd. arg ( "--noavx" ) ;
336
339
}
337
340
@@ -470,6 +473,7 @@ fn set_env_vars<P: AsRef<Path>>(wasmedge_root_dir_path: &P) {
470
473
471
474
472
475
/// Versions of CUDA that WasmEdge supports.
476
+ #[ derive( Debug ) ]
473
477
enum CudaVersion {
474
478
/// CUDA Version 12
475
479
V12 ,
@@ -482,26 +486,32 @@ enum CudaVersion {
482
486
/// This function first runs `nvcc --version` on both Linux and Windows,
483
487
/// and if that fails, it will try `/usr/local/cuda/bin/nvcc --version` on Linux only.
484
488
fn get_cuda_version ( ) -> Option < CudaVersion > {
485
- let mut output = Command :: new ( "nvcc" )
486
- . arg ( "--version" )
487
- . output ( ) ;
488
-
489
- #[ cfg( target_os = "linux" ) ] {
490
- output = output. or_else ( |_|
491
- Command :: new ( "/usr/local/cuda/bin/nvcc" )
492
- . arg ( "--version" )
493
- . output ( )
494
- ) ;
489
+ #[ cfg( target_os = "macos" ) ] {
490
+ None
495
491
}
496
492
497
- let output = output. ok ( ) ?;
498
- let output = String :: from_utf8_lossy ( & output. stdout ) ;
499
- if output. contains ( "V12" ) {
500
- Some ( CudaVersion :: V12 )
501
- } else if output. contains ( "V11" ) {
502
- Some ( CudaVersion :: V11 )
503
- } else {
504
- None
493
+ #[ cfg( not( target_os = "macos" ) ) ] {
494
+ let mut output = Command :: new ( "nvcc" )
495
+ . arg ( "--version" )
496
+ . output ( ) ;
497
+
498
+ #[ cfg( target_os = "linux" ) ] {
499
+ output = output. or_else ( |_|
500
+ Command :: new ( "/usr/local/cuda/bin/nvcc" )
501
+ . arg ( "--version" )
502
+ . output ( )
503
+ ) ;
504
+ }
505
+
506
+ let output = output. ok ( ) ?;
507
+ let output = String :: from_utf8_lossy ( & output. stdout ) ;
508
+ if output. contains ( "V12" ) {
509
+ Some ( CudaVersion :: V12 )
510
+ } else if output. contains ( "V11" ) {
511
+ Some ( CudaVersion :: V11 )
512
+ } else {
513
+ None
514
+ }
505
515
}
506
516
}
507
517
0 commit comments