Skip to content

Commit

Permalink
Match PyTorch ops CUDA arch list to that of Open3D
Browse files Browse the repository at this point in the history
  • Loading branch information
ssheorey committed Oct 3, 2024
1 parent 9856186 commit 4ccae42
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
23 changes: 23 additions & 0 deletions 3rdparty/cmake/FindPytorch.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
#
# and import the target 'torch'.

# "80-real" to "8.0" and "80" to "8.0+PTX":
macro(translate_arch_string input output)
if("${input}" MATCHES "[0-9]+-real")
string(REGEX REPLACE "([1-9])([0-9])-real" "\\1.\\2" version "${input}")
elseif("${input}" MATCHES "([0-9]+)")
string(REGEX REPLACE "([1-9])([0-9])" "\\1.\\2+PTX" version "${input}")
elseif(input STREQUAL "native")
set(version "Auto")
else()
message(FATAL_ERROR "Invalid architecture string: ${input}")
endif()
set(${output} "${version}")
endmacro()

if(NOT Pytorch_FOUND)
# Searching for pytorch requires the python executable
if (NOT Python3_EXECUTABLE)
Expand Down Expand Up @@ -41,6 +55,15 @@ if(NOT Pytorch_FOUND)
unset(PyTorch_FETCH_PROPERTIES)
unset(PyTorch_PROPERTIES)

# Using CUDA 12.x and Pytorch <2.4 gives the error "Unknown CUDA Architecture Name 9.0a in CUDA_SELECT_NVCC_ARCH_FLAGS".
# As a workaround we explicitly set TORCH_CUDA_ARCH_LIST
set(TORCH_CUDA_ARCH_LIST "")
foreach(arch IN LISTS CMAKE_CUDA_ARCHITECTURES)
translate_arch_string("${arch}" ptarch)
list(APPEND TORCH_CUDA_ARCH_LIST "${ptarch}")
endforeach()

message(STATUS "Using top level CMAKE_CUDA_ARCHITECTURES for TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
# Use the cmake config provided by torch
find_package(Torch REQUIRED PATHS "${Pytorch_ROOT}"
NO_DEFAULT_PATH)
Expand Down
3 changes: 0 additions & 3 deletions docker/Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,7 @@ RUN \
export CMAKE_CXX_COMPILER=g++; \
export CMAKE_C_COMPILER=gcc; \
# TODO: PyTorch still use old CXX ABI, remove this line when PyTorch is updated
# TODO: Using CUDA 12.x and Pytorch <2.4 gives the error "Unknown CUDA Architecture Name 9.0a in CUDA_SELECT_NVCC_ARCH_FLAGS".
# As a workaround we explicitly set TORCH_CUDA_ARCH_LIST
if [ "$BUILD_PYTORCH_OPS" = "ON" ]; then \
export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0" \
export GLIBCXX_USE_CXX11_ABI=OFF; \
else \
export GLIBCXX_USE_CXX11_ABI=ON; \
Expand Down

0 comments on commit 4ccae42

Please sign in to comment.