From 4ccae4217bffe7fcbb18716a39e7d120c3c537e0 Mon Sep 17 00:00:00 2001 From: Sameer Sheorey Date: Thu, 3 Oct 2024 00:05:24 -0700 Subject: [PATCH] Match PyTorch ops CUDA arch list to that of Open3D --- 3rdparty/cmake/FindPytorch.cmake | 23 +++++++++++++++++++++++ docker/Dockerfile.ci | 3 --- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/3rdparty/cmake/FindPytorch.cmake b/3rdparty/cmake/FindPytorch.cmake index eb2a53e2ec5..28da7ec71ba 100644 --- a/3rdparty/cmake/FindPytorch.cmake +++ b/3rdparty/cmake/FindPytorch.cmake @@ -14,6 +14,20 @@ # # and import the target 'torch'. +# "80-real" to "8.0" and "80" to "8.0+PTX": +macro(translate_arch_string input output) + if("${input}" MATCHES "[0-9]+-real") + string(REGEX REPLACE "([1-9])([0-9])-real" "\\1.\\2" version "${input}") + elseif("${input}" MATCHES "([0-9]+)") + string(REGEX REPLACE "([1-9])([0-9])" "\\1.\\2+PTX" version "${input}") + elseif(input STREQUAL "native") + set(version "Auto") + else() + message(FATAL_ERROR "Invalid architecture string: ${input}") + endif() + set(${output} "${version}") +endmacro() + if(NOT Pytorch_FOUND) # Searching for pytorch requires the python executable if (NOT Python3_EXECUTABLE) @@ -41,6 +55,15 @@ if(NOT Pytorch_FOUND) unset(PyTorch_FETCH_PROPERTIES) unset(PyTorch_PROPERTIES) + # Using CUDA 12.x and Pytorch <2.4 gives the error "Unknown CUDA Architecture Name 9.0a in CUDA_SELECT_NVCC_ARCH_FLAGS". + # As a workaround we explicitly set TORCH_CUDA_ARCH_LIST + set(TORCH_CUDA_ARCH_LIST "") + foreach(arch IN LISTS CMAKE_CUDA_ARCHITECTURES) + translate_arch_string("${arch}" ptarch) + list(APPEND TORCH_CUDA_ARCH_LIST "${ptarch}") + endforeach() + + message(STATUS "Using top level CMAKE_CUDA_ARCHITECTURES for TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}") # Use the cmake config provided by torch find_package(Torch REQUIRED PATHS "${Pytorch_ROOT}" NO_DEFAULT_PATH) diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci index dcdc0599dd9..5c5991e5a76 100644 --- a/docker/Dockerfile.ci +++ b/docker/Dockerfile.ci @@ -196,10 +196,7 @@ RUN \ export CMAKE_CXX_COMPILER=g++; \ export CMAKE_C_COMPILER=gcc; \ # TODO: PyTorch still use old CXX ABI, remove this line when PyTorch is updated - # TODO: Using CUDA 12.x and Pytorch <2.4 gives the error "Unknown CUDA Architecture Name 9.0a in CUDA_SELECT_NVCC_ARCH_FLAGS". - # As a workaround we explicitly set TORCH_CUDA_ARCH_LIST if [ "$BUILD_PYTORCH_OPS" = "ON" ]; then \ - export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0" \ export GLIBCXX_USE_CXX11_ABI=OFF; \ else \ export GLIBCXX_USE_CXX11_ABI=ON; \