diff --git a/3rdparty/cmake/FindPaddle.cmake b/3rdparty/cmake/FindPaddle.cmake new file mode 100644 index 00000000000..a564e0061f9 --- /dev/null +++ b/3rdparty/cmake/FindPaddle.cmake @@ -0,0 +1,120 @@ +# Find the Paddle root and use the provided cmake module +# The following variables will be set: +# - Paddle_FOUND +# - Paddle_VERSION +# - Paddle_ROOT +# - Paddle_DEFINITIONS +# +# - PADDLE_FOUND +# - PADDLE_INCLUDE_DIRS +# - PADDLE_LIBRARY_DIRS +# - PADDLE_LIBRARIES +# - PADDLE_CXX_FLAGS +# +# and import the target 'paddle'. + +if(NOT Paddle_FOUND) + # Searching for Paddle requires the python executable + if (NOT Python3_EXECUTABLE) + message(FATAL_ERROR "Python 3 not found in top level file") + endif() + + if(BUILD_CUDA_MODULE) + find_package(CUDAToolkit REQUIRED) + string(SUBSTRING ${CUDAToolkit_VERSION} 0 4 CUDA_VERSION) + endif() + + message(STATUS "Getting Paddle properties ...") + + set(Paddle_FETCH_PROPERTIES + "import os" + "import paddle" + "import sysconfig" + "print(paddle.__version__, end=';')" + "print(os.path.dirname(paddle.__file__), end=';')" + "print(sysconfig.get_path('include', scheme='posix_prefix'), end=';')" + ) + execute_process( + COMMAND ${Python3_EXECUTABLE} "-c" "${Paddle_FETCH_PROPERTIES}" + OUTPUT_VARIABLE Paddle_PROPERTIES + ) + + + list(GET Paddle_PROPERTIES 0 Paddle_VERSION) + list(GET Paddle_PROPERTIES 1 Paddle_ROOT) + list(GET Paddle_PROPERTIES 2 Python_INCLUDE) + + set(Paddle_CXX11_ABI True) + + unset(Paddle_FETCH_PROPERTIES) + unset(Paddle_PROPERTIES) + + add_library(paddle STATIC IMPORTED) + + # handle include directories + set(PADDLE_INCLUDE_DIRS) + list(APPEND PADDLE_INCLUDE_DIRS "${Paddle_ROOT}/include") + list(APPEND PADDLE_INCLUDE_DIRS "${Paddle_ROOT}/include/third_party") + list(APPEND PADDLE_INCLUDE_DIRS "${Python_INCLUDE}") + + if(BUILD_CUDA_MODULE) + list(APPEND PADDLE_INCLUDE_DIRS "${CUDAToolkit_INCLUDE_DIRS}") + endif() + + # handle library directories + set(PADDLE_LIBRARY_DIRS) + list(APPEND PADDLE_LIBRARY_DIRS "${Paddle_ROOT}/libs") + list(APPEND PADDLE_LIBRARY_DIRS "${Paddle_ROOT}/base") + + if(BUILD_CUDA_MODULE) + list(APPEND PADDLE_LIBRARY_DIRS "${CUDAToolkit_LIBRARY_DIR}") + endif() + + # handle libraries + set(PADDLE_LIBRARIES) + find_library(PADDLE_LIB NAMES paddle PATHS "${Paddle_ROOT}/base") + list(APPEND PADDLE_LIBRARY_DIRS "${PADDLE_LIB}") + + if(BUILD_CUDA_MODULE) + find_library(CUDART_LIB NAMES cudart PATHS "${CUDAToolkit_LIBRARY_DIR}") + list(APPEND PADDLE_LIBRARY_DIRS "${CUDART_LIB}") + endif() + + # handle compile flags + set(PADDLE_CXX_FLAGS) + if(BUILD_CUDA_MODULE) + set(PADDLE_CXX_FLAGS "-DPADDLE_WITH_CUDA ${PADDLE_CXX_FLAGS}") + endif() + + set_target_properties(paddle PROPERTIES + IMPORTED_LOCATION "${PADDLE_LIB}" + ) + set_target_properties(paddle PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PADDLE_INCLUDE_DIRS}" + ) + set_property(TARGET paddle PROPERTY INTERFACE_COMPILE_OPTIONS "${PADDLE_CXX_FLAGS}") + + set(PADDLE_FOUND True) +endif() + +if(PRINT_ONCE) + message(STATUS "Paddle version: ${Paddle_VERSION}") + message(STATUS " root dir: ${Paddle_ROOT}") + message(STATUS " compile flags: ${PADDLE_CXX_FLAGS}") + if (UNIX AND NOT APPLE) + message(STATUS " use cxx11 abi: ${Paddle_CXX11_ABI}") + endif() + foreach(idir ${PADDLE_INCLUDE_DIRS}) + message(STATUS " include dirs: ${idir}") + endforeach(idir) + foreach(ldir ${PADDLE_LIBRARY_DIRS}) + message(STATUS " library dirs: ${ldir}") + endforeach(ldir) + foreach(lib ${PADDLE_LIBRARIES}) + message(STATUS " libraries: ${lib}") + endforeach(lib) +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Paddle DEFAULT_MSG Paddle_VERSION + Paddle_ROOT) diff --git a/CMakeLists.txt b/CMakeLists.txt index 354125dc01d..a7a6dea3aac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -141,6 +141,7 @@ option(BUILD_AZURE_KINECT "Build support for Azure Kinect sensor" OFF # ML library options option(BUILD_TENSORFLOW_OPS "Build ops for TensorFlow" OFF) option(BUILD_PYTORCH_OPS "Build ops for PyTorch" OFF) +option(BUILD_PADDLE_OPS "Build ops for Paddle" OFF) option(BUNDLE_OPEN3D_ML "Includes the Open3D-ML repo in the wheel" OFF) # Release build options @@ -292,6 +293,15 @@ endif() if(BUILD_SYCL_MODULE AND NOT GLIBCXX_USE_CXX11_ABI) message(FATAL_ERROR "BUILD_SYCL_MODULE=ON requires GLIBCXX_USE_CXX11_ABI=ON") endif() +if(BUILD_SYCL_MODULE AND BUILD_TENSORFLOW_OPS) + message(FATAL_ERROR "BUILD_SYCL_MODULE=ON requires BUILD_TENSORFLOW_OPS=OFF") +endif() +if(BUILD_SYCL_MODULE AND BUILD_PYTORCH_OPS) + message(FATAL_ERROR "BUILD_SYCL_MODULE=ON requires BUILD_PYTORCH_OPS=OFF") +endif() +if(BUILD_SYCL_MODULE AND BUILD_PADDLE_OPS) + message(FATAL_ERROR "BUILD_SYCL_MODULE=ON requires BUILD_PADDLE_OPS=OFF") +endif() if(BUILD_SYCL_MODULE AND BUILD_CUDA_MODULE) message(FATAL_ERROR "BUILD_SYCL_MODULE and BUILD_SYCL_MODULE cannot be on at the same time for now.") endif() diff --git a/cpp/open3d/ml/CMakeLists.txt b/cpp/open3d/ml/CMakeLists.txt index 35f0b65112a..be445a9f238 100644 --- a/cpp/open3d/ml/CMakeLists.txt +++ b/cpp/open3d/ml/CMakeLists.txt @@ -3,6 +3,10 @@ if (BUILD_TENSORFLOW_OPS AND WIN32) # see https://github.com/tensorflow/custom-op/issues/24 endif() +if (BUILD_PADDLE_OPS AND (WIN32 OR APPLE)) + message(FATAL_ERROR "Building Paddle ops on Windows or MacOS is currently not supported.") +endif() + if (BUILD_TENSORFLOW_OPS) add_subdirectory(tensorflow) endif() @@ -11,4 +15,9 @@ if (BUILD_PYTORCH_OPS) add_subdirectory(pytorch) endif() +if (BUILD_PADDLE_OPS) + add_subdirectory(paddle) +endif() + + add_subdirectory(contrib) diff --git a/cpp/open3d/ml/paddle/CMakeLists.txt b/cpp/open3d/ml/paddle/CMakeLists.txt new file mode 100644 index 00000000000..a6809665bde --- /dev/null +++ b/cpp/open3d/ml/paddle/CMakeLists.txt @@ -0,0 +1,172 @@ +if(BUILD_CUDA_MODULE) + message(STATUS "Building Paddle ops with CUDA") +else() + message(STATUS "Building Paddle ops") +endif() + +set(PRINT_ONCE ON) +find_package(Paddle REQUIRED) + +add_library(open3d_paddle_ops SHARED) + +target_sources(open3d_paddle_ops PRIVATE + PaddleHelper.cpp + misc/BuildSpatialHashTableOpKernel.cpp + misc/BuildSpatialHashTableOps.cpp + misc/FixedRadiusSearchOps.cpp + misc/FixedRadiusSearchOpKernel.cpp + misc/RadiusSearchOps.cpp + misc/RadiusSearchOpKernel.cpp + misc/InvertNeighborsListOps.cpp + misc/InvertNeighborsListOpKernel.cpp + misc/KnnSearchOps.cpp + misc/KnnSearchOpKernel.cpp + misc/RaggedToDenseOpKernel.cpp + misc/RaggedToDenseOps.cpp + misc/NmsOps.cpp + misc/VoxelizeOpKernel.cpp + misc/VoxelizeOps.cpp + misc/ReduceSubarraysSumOpKernel.cpp + misc/ReduceSubarraysSumOps.cpp + misc/VoxelPoolingOps.cpp + misc/VoxelPoolingOpKernel.cpp + misc/RoiPoolOps.cpp +) + +target_sources(open3d_paddle_ops PRIVATE + pointnet/BallQueryOps.cpp + pointnet/InterpolateOps.cpp + pointnet/SamplingOps.cpp +) + +target_sources(open3d_paddle_ops PRIVATE + continuous_conv/ContinuousConvOps.cpp + continuous_conv/ContinuousConvOpKernel.cpp + continuous_conv/ContinuousConvBackpropFilterOpKernel.cpp + continuous_conv/ContinuousConvTransposeOps.cpp + continuous_conv/ContinuousConvTransposeOpKernel.cpp + continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cpp +) + +target_sources(open3d_paddle_ops PRIVATE + sparse_conv/SparseConvBackpropFilterOpKernel.cpp + sparse_conv/SparseConvOpKernel.cpp + sparse_conv/SparseConvOps.cpp + sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cpp + sparse_conv/SparseConvTransposeOpKernel.cpp + sparse_conv/SparseConvTransposeOps.cpp +) + +target_sources(open3d_paddle_ops PRIVATE + ../contrib/Nms.cpp +) + +if (BUILD_CUDA_MODULE) + target_sources(open3d_paddle_ops PRIVATE + misc/BuildSpatialHashTableOpKernel.cu + misc/FixedRadiusSearchOpKernel.cu + misc/InvertNeighborsListOpKernel.cu + misc/RaggedToDenseOpKernel.cu + misc/ReduceSubarraysSumOpKernel.cu + misc/VoxelizeOpKernel.cu + ) + + target_sources(open3d_paddle_ops PRIVATE + pointnet/BallQueryKernel.cu + pointnet/InterpolateKernel.cu + pointnet/SamplingKernel.cu + ) + + target_sources(open3d_paddle_ops PRIVATE + continuous_conv/ContinuousConvOpKernel.cu + continuous_conv/ContinuousConvBackpropFilterOpKernel.cu + continuous_conv/ContinuousConvTransposeOpKernel.cu + continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cu + ) + target_sources(open3d_paddle_ops PRIVATE + sparse_conv/SparseConvBackpropFilterOpKernel.cu + sparse_conv/SparseConvOpKernel.cu + sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cu + sparse_conv/SparseConvTransposeOpKernel.cu + ) + target_sources(open3d_paddle_ops PRIVATE + ../contrib/BallQuery.cu + ../contrib/InterpolatePoints.cu + ../contrib/Nms.cu + ../contrib/RoiPoolKernel.cu + ) + + target_sources(open3d_paddle_ops PRIVATE + ../impl/continuous_conv/ContinuousConvCUDAKernels.cu + ../impl/sparse_conv/SparseConvCUDAKernels.cu + ) +endif() + +open3d_show_and_abort_on_warning(open3d_paddle_ops) +open3d_set_global_properties(open3d_paddle_ops) + +# Set output directory according to architecture (cpu/cuda) +get_target_property(PADDLE_OPS_DIR open3d_paddle_ops LIBRARY_OUTPUT_DIRECTORY) +set(PADDLE_OPS_ARCH_DIR + "${PADDLE_OPS_DIR}/$,cuda,cpu>") +set_target_properties(open3d_paddle_ops PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${PADDLE_OPS_ARCH_DIR}" + ARCHIVE_OUTPUT_DIRECTORY "${PADDLE_OPS_ARCH_DIR}") + +# Do not add "lib" prefix +set_target_properties(open3d_paddle_ops PROPERTIES PREFIX "") +set_target_properties(open3d_paddle_ops PROPERTIES DEBUG_POSTFIX "_debug") + +target_include_directories(open3d_paddle_ops SYSTEM PRIVATE + ${PROJECT_SOURCE_DIR}/cpp + ${PADDLE_INCLUDE_DIRS} +) + +target_link_libraries(open3d_paddle_ops PRIVATE + paddle + Open3D::Open3D + Open3D::3rdparty_eigen3 + Open3D::3rdparty_fmt + Open3D::3rdparty_nanoflann + TBB::tbb +) + +if (BUILD_CUDA_MODULE) + target_link_libraries(open3d_paddle_ops PRIVATE + Open3D::3rdparty_cutlass + ${PADDLE_LIBRARIES} + CUDA::cuda_driver + ) + + if (TARGET Open3D::3rdparty_cub) + target_link_libraries(open3d_paddle_ops PRIVATE + Open3D::3rdparty_cub + ) + endif() +endif() + +install(TARGETS open3d_paddle_ops EXPORT Open3DPaddleOps + LIBRARY DESTINATION ${Open3D_INSTALL_LIB_DIR} +) +install(EXPORT Open3DPaddleOps NAMESPACE ${PROJECT_NAME}:: DESTINATION ${Open3D_INSTALL_CMAKE_DIR}) + +if (BUILD_SHARED_LIBS AND UNIX) +file(CONFIGURE OUTPUT open3d_paddle_ops.pc.in + CONTENT [=[ +prefix=${pcfiledir}/../.. +libdir=${prefix}/lib +includedir=${prefix}/include/ + +Name: Open3D Paddle Ops +Description: @PROJECT_DESCRIPTION@ This library contains 3D ML Ops for use with Paddle. +URL: @PROJECT_HOMEPAGE_URL@ +Version: @PROJECT_VERSION@ +Requires: Open3D = @PROJECT_VERSION@ +Cflags: +Libs: -lopen3d_paddle_ops]=] @ONLY NEWLINE_STYLE LF) + file(GENERATE OUTPUT open3d_paddle_ops.pc INPUT + "${CMAKE_CURRENT_BINARY_DIR}/open3d_paddle_ops.pc.in" + TARGET open3d_paddle_ops) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/open3d_paddle_ops.pc" + DESTINATION "${Open3D_INSTALL_LIB_DIR}/pkgconfig") +endif() diff --git a/cpp/open3d/ml/paddle/PaddleHelper.cpp b/cpp/open3d/ml/paddle/PaddleHelper.cpp new file mode 100644 index 00000000000..f67baf2c497 --- /dev/null +++ b/cpp/open3d/ml/paddle/PaddleHelper.cpp @@ -0,0 +1,72 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- + +#include "PaddleHelper.h" + +paddle::Tensor InitializedEmptyTensor(const phi::DataType dtype, + const phi::IntArray& shape, + const phi::Place& place) { + switch (dtype) { + case phi::DataType::INT8: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::UINT8: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::INT16: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::FLOAT32: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::INT32: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::FLOAT64: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::INT64: + return InitializedEmptyTensor(shape, place); + break; + default: + PD_CHECK(false, + "Only support phi::DataType as `INT8`, `UINT8`, `INT16`, " + "`FLOAT32`, `FLOAT64`, " + "`INT32` and `INT64` but got %s.", + phi::DataTypeToString(dtype)); + } +} + +paddle::Tensor Arange(const int end, const paddle::Place& place) { + PD_CHECK(end > 0, "end:%d ,end must greater than 0", end); + auto start_tensor = paddle::zeros({1}, paddle::DataType::INT32, place); + auto end_tensor = paddle::experimental::full( + {1}, end, paddle::DataType::INT32, place); + auto step_tensor = + paddle::experimental::full({1}, 1, paddle::DataType::INT32, place); + return paddle::experimental::arange(start_tensor, end_tensor, step_tensor, + paddle::DataType::INT32, place); +} + +paddle::Tensor Transpose(const paddle::Tensor& t, int64_t dim0, int64_t dim1) { + int len = t.shape().size(); + dim0 = dim0 >= 0 ? dim0 : len + dim0; + dim1 = dim1 >= 0 ? dim1 : len + dim1; + PD_CHECK(dim0 >= 0 && dim0 < len, + "dim0 not in range" + "dim0:%d ,range:%d", + dim0, len); + PD_CHECK(dim1 >= 0 && dim1 < len, + "dim1 not in range" + "dim1:%d ,range:%d", + dim1, len); + std::vector transpose_perm(len); + std::iota(transpose_perm.begin(), transpose_perm.end(), 0); + transpose_perm[dim0] = dim1; + transpose_perm[dim1] = dim0; + return paddle::experimental::transpose(t, transpose_perm); +} diff --git a/cpp/open3d/ml/paddle/PaddleHelper.h b/cpp/open3d/ml/paddle/PaddleHelper.h new file mode 100644 index 00000000000..17f86e21e4d --- /dev/null +++ b/cpp/open3d/ml/paddle/PaddleHelper.h @@ -0,0 +1,292 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- + +#pragma once + +#include +#include + +#include +#include + +#include "open3d/ml/ShapeChecking.h" +#include "paddle/extension.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/allocator.h" + +// Macros for checking tensor properties +#define CHECK_CUDA(x) \ + do { \ + PD_CHECK(x.is_gpu(), #x " must be a CUDA tensor"); \ + } while (0) + +// NOTE: The input Tensor will be preprocessed into a contiguous Tensor within +// the execution function of the custom operator, so CHECK_CONTIGUOUS will be +// always True as there is no need for an explicit conversion in Open3D. For +// reference, please see: +// https://github.com/PaddlePaddle/Paddle/blob/65126f558a5c0fbb0cd1aa0a42844a73632ff9e9/paddle/fluid/eager/custom_operator/custom_operator_utils.cc#L803-L810 +#define CHECK_CONTIGUOUS(x) \ + do { \ + } while (0) + +#define CHECK_TYPE(x, type) \ + do { \ + PD_CHECK(x.dtype() == type, #x " must have type " #type); \ + } while (0) + +#define CHECK_SAME_DEVICE_TYPE(...) \ + do { \ + if (!SameDeviceType({__VA_ARGS__})) { \ + PD_CHECK(false, \ + #__VA_ARGS__ \ + " must all have the same device type but got " + \ + TensorInfoStr({__VA_ARGS__})); \ + } \ + } while (0) + +#define CHECK_SAME_DTYPE(...) \ + do { \ + if (!SameDtype({__VA_ARGS__})) { \ + PD_CHECK(false, #__VA_ARGS__ \ + " must all have the same dtype but got " + \ + TensorInfoStr({__VA_ARGS__})); \ + } \ + } while (0) +// Conversion from standard types to paddle types +typedef std::remove_const::type + PaddleDtype_t; +template +inline PaddleDtype_t ToPaddleDtype() { + PD_CHECK(false, "Unsupported type"); +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::UINT8; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT8; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT16; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT32; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT64; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::FLOAT32; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::FLOAT64; +} + +// convenience function for comparing standard types with paddle types +template +inline bool ComparePaddleDtype(const TDtype& t) { + return ToPaddleDtype() == t; +} + +// convenience function to check if all tensors have the same device type +inline bool SameDeviceType(std::initializer_list tensors) { + if (tensors.size()) { + auto device_type = tensors.begin()->place().GetDeviceType(); + for (const auto& t : tensors) { + if (device_type != t.place().GetDeviceType()) { + return false; + } + } + } + return true; +} + +// convenience function to check if all tensors have the same dtype +inline bool SameDtype(std::initializer_list tensors) { + if (tensors.size()) { + auto dtype = tensors.begin()->dtype(); + for (const auto& t : tensors) { + if (dtype != t.dtype()) { + return false; + } + } + } + return true; +} + +inline std::string TensorInfoStr( + std::initializer_list tensors) { + std::stringstream sstr; + size_t count = 0; + for (const auto& t : tensors) { + sstr << "Tensor(" << t.size() << ", " << t.place() << ")"; + ++count; + if (count < tensors.size()) sstr << ", "; + } + return sstr.str(); +} + +// convenience function for creating a tensor for temp memory +inline paddle::Tensor CreateTempTensor(const int64_t size, + const paddle::Place& device, + void** ptr = nullptr) { + paddle::Tensor tensor = + paddle::empty({size}, ToPaddleDtype(), device); + if (ptr) { + *ptr = tensor.data(); + } + return tensor; +} + +inline std::vector GetShapeVector( + paddle::Tensor tensor) { + using namespace open3d::ml::op_util; + const auto old_shape = tensor.shape(); + std::vector shape; + for (auto i = 0u; i < old_shape.size(); ++i) { + shape.push_back(old_shape[i]); + } + return shape; +} + +template +std::tuple CheckShape(paddle::Tensor tensor, + TDimX&& dimex, + TArgs&&... args) { + return open3d::ml::op_util::CheckShape(GetShapeVector(tensor), + std::forward(dimex), + std::forward(args)...); +} + +// +// Macros for checking the shape of Tensors. +// Usage: +// { +// using namespace open3d::ml::op_util; +// Dim w("w"); +// Dim h("h"); +// CHECK_SHAPE(tensor1, 10, w, h); // checks if the first dim is 10 +// // and assigns w and h based on +// // the shape of tensor1 +// +// CHECK_SHAPE(tensor2, 10, 20, h); // this checks if the the last dim +// // of tensor2 matches the last dim +// // of tensor1. The first two dims +// // must match 10, 20. +// } +// +// +// See "../ShapeChecking.h" for more info and limitations. +// +#define CHECK_SHAPE(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#ifdef BUILD_CUDA_MODULE +static void cudaFreeWrapper(void* ptr) { + phi::gpuError_t result = cudaFree(ptr); + PADDLE_ENFORCE_GPU_SUCCESS(result); +} +#endif + +// NOTE: Hack to support empty tensor, like Tensor(shape=[0], []) +template +paddle::Tensor InitializedEmptyTensor(const phi::IntArray& shape, + const phi::Place& place) { + int64_t size = 1; + for (auto v : shape.GetData()) { + size *= v; + } + PD_CHECK(size == 0, "The numel of empty tensor is not equal to 0."); + + paddle::Deleter deleter; + T* ptr = nullptr; + if (phi::is_gpu_place(place)) { +#ifdef BUILD_CUDA_MODULE + phi::gpuError_t result = cudaMalloc(&ptr, sizeof(T) * 1); + PADDLE_ENFORCE_GPU_SUCCESS(result); + deleter = std::function(cudaFreeWrapper); +#else + PD_CHECK(false, + "InitializedEmptyTensor was not compiled with CUDA support"); +#endif + } else if (phi::is_cpu_place(place)) { + ptr = (T*)malloc(sizeof(T) * 1); + deleter = std::function(free); + } else { + PD_CHECK(false, "Not supported backend!"); + } + + // NOTE: In Paddle, the stride of an empty (0-size) tensor can be the same + // as its shape. + return paddle::from_blob(static_cast(ptr), shape, shape, + paddle::DataType(ToPaddleDtype()), + phi::DataLayout::NCHW, place, deleter); +} + +paddle::Tensor InitializedEmptyTensor(const phi::DataType dtype, + const phi::IntArray& shape, + const phi::Place& place); + +// return a array of [0 1 2 ... end-1] +paddle::Tensor Arange(const int end, const paddle::Place& place); + +// just like tensor.transpose(dim0,dim1) +paddle::Tensor Transpose(const paddle::Tensor& t, int64_t dim0, int64_t dim1); diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.cpp b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.cpp new file mode 100644 index 00000000000..8387331742c --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.cpp @@ -0,0 +1,72 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvBackpropFilter.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvBackpropFilterCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + CConvBackpropFilterCPU( + filter_backprop.data(), filter_dims, out_positions.shape()[0], + out_positions.data(), inp_positions.shape()[0], + inp_positions.data(), inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), out_features_gradient.data(), + interpolation, coordinate_mapping, align_corners, + individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void ContinuousConvBackpropFilterCPU( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, \ + const bool align_corners, \ + const open3d::ml::impl::CoordinateMapping coordinate_mapping, \ + const bool normalize, \ + const open3d::ml::impl::InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, float, int32_t) diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.cu b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.cu new file mode 100644 index 00000000000..016b1567fac --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.cu @@ -0,0 +1,109 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#include +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvBackpropFilter.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvBackpropFilterCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + CConvBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, out_positions.shape()[0], + out_positions.data(), inp_positions.shape()[0], + inp_positions.data(), inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), out_features_gradient.data(), + interpolation, coordinate_mapping, align_corners, + individual_extents, isotropic_extents, normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + CConvBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, out_positions.shape()[0], + out_positions.data(), inp_positions.shape()[0], + inp_positions.data(), inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), out_features_gradient.data(), + interpolation, coordinate_mapping, align_corners, + individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void \ + ContinuousConvBackpropFilterCUDA( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, \ + const bool align_corners, \ + const open3d::ml::impl::CoordinateMapping coordinate_mapping, \ + const bool normalize, \ + const open3d::ml::impl::InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, float, int32_t); diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.h b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.h new file mode 100644 index 00000000000..c92ff9bd2bc --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.h @@ -0,0 +1,54 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTypes.h" + +template +void ContinuousConvBackpropFilterCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); + +#ifdef BUILD_CUDA_MODULE +template +void ContinuousConvBackpropFilterCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); +#endif diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvHelper.h b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvHelper.h new file mode 100644 index 00000000000..f7628047382 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvHelper.h @@ -0,0 +1,56 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTypes.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +// +// helper functions for parsing arguments +// + +inline open3d::ml::impl::CoordinateMapping ParseCoordinateMappingStr( + const std::string& str) { + using open3d::ml::impl::CoordinateMapping; + CoordinateMapping coordinate_mapping = + CoordinateMapping::BALL_TO_CUBE_RADIAL; + if (str == "ball_to_cube_radial") { + coordinate_mapping = CoordinateMapping::BALL_TO_CUBE_RADIAL; + } else if (str == "ball_to_cube_volume_preserving") { + coordinate_mapping = CoordinateMapping::BALL_TO_CUBE_VOLUME_PRESERVING; + } else if (str == "identity") { + coordinate_mapping = CoordinateMapping::IDENTITY; + } else { + PD_CHECK(false, + "coordinate_mapping must be one of ('ball_to_cube_radial', " + "'ball_to_cube_volume_preserving', 'identity') but got " + + str); + } + return coordinate_mapping; +} + +inline open3d::ml::impl::InterpolationMode ParseInterpolationStr( + const std::string& str) { + using open3d::ml::impl::InterpolationMode; + InterpolationMode interpolation = InterpolationMode::LINEAR; + if (str == "linear") { + interpolation = InterpolationMode::LINEAR; + } else if (str == "linear_border") { + interpolation = InterpolationMode::LINEAR_BORDER; + } else if (str == "nearest_neighbor") { + interpolation = InterpolationMode::NEAREST_NEIGHBOR; + } else { + PD_CHECK(false, + "interpolation must be one of ('linear', " + "'linear_border', 'nearest_neighbor') but got " + + str); + } + return interpolation; +} diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.cpp b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.cpp new file mode 100644 index 00000000000..8fc0409e637 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.cpp @@ -0,0 +1,68 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConv.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvCPU(const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const CoordinateMapping coordinate_mapping, + const bool normalize, + const InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + CConvComputeFeaturesCPU( + out_features.data(), filter_dims, filters.data(), + out_positions.shape()[0], out_positions.data(), + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), interpolation, coordinate_mapping, + align_corners, individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void ContinuousConvCPU( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const bool align_corners, \ + const CoordinateMapping coordinate_mapping, const bool normalize, \ + const InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, float, int32_t) diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.cu b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.cu new file mode 100644 index 00000000000..d2a3f38ff73 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.cu @@ -0,0 +1,105 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConv.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvCUDA(const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const CoordinateMapping coordinate_mapping, + const bool normalize, + const InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + CConvComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + out_positions.shape()[0], out_positions.data(), + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), interpolation, coordinate_mapping, + align_corners, individual_extents, isotropic_extents, normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + CConvComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + out_positions.shape()[0], out_positions.data(), + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), interpolation, coordinate_mapping, + align_corners, individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void ContinuousConvCUDA( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const bool align_corners, \ + const CoordinateMapping coordinate_mapping, const bool normalize, \ + const InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, float, int32_t) diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.h b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.h new file mode 100644 index 00000000000..b26e628e7e6 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.h @@ -0,0 +1,53 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTypes.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void ContinuousConvCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); + +#ifdef BUILD_CUDA_MODULE +template +void ContinuousConvCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); +#endif diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOps.cpp b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOps.cpp new file mode 100644 index 00000000000..677086e6038 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvOps.cpp @@ -0,0 +1,239 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvBackpropFilterOpKernel.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvHelper.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.h" +#include "open3d/ml/paddle/misc/InvertNeighborsListOps.h" +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOps.h" + +using namespace open3d::ml::impl; + +std::vector ContinuousConvForward( + paddle::Tensor& filters, + paddle::Tensor& out_positions, + paddle::Tensor& extents, + paddle::Tensor& offset, + paddle::Tensor& inp_positions, + paddle::Tensor& inp_features, + paddle::Tensor& inp_importance, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const std::string& coordinate_mapping_str, + const bool normalize, + const std::string& interpolation_str, + const int64_t max_temp_mem_MB) { + CoordinateMapping coordinate_mapping = + ParseCoordinateMappingStr(coordinate_mapping_str); + + InterpolationMode interpolation = ParseInterpolationStr(interpolation_str); + + CHECK_TYPE(neighbors_row_splits, paddle::DataType::INT64); + CHECK_SAME_DTYPE(filters, out_positions, extents, offset, inp_positions, + inp_features, inp_importance, neighbors_importance); + CHECK_SAME_DEVICE_TYPE(filters, out_positions, inp_positions, inp_features, + inp_importance); + + // check input shapes + using namespace open3d::ml::op_util; + Dim kernel_depth("kernel_depth"); + Dim kernel_height("kernel_height"); + Dim kernel_width("kernel_width"); + Dim out_channels("out_channels"); + Dim in_channels("in_channels"); + Dim num_out_points("num_out_points"); + Dim num_inp_points("num_inp_points"); + Dim num_neighbors("nun_neighbors"); + + CHECK_SHAPE(filters, kernel_depth, kernel_height, kernel_width, in_channels, + out_channels); + CHECK_SHAPE(out_positions, num_out_points, 3); + CHECK_SHAPE(extents, num_out_points || 1, Dim(3) || 1); + CHECK_SHAPE(offset, 3); + CHECK_SHAPE(inp_positions, num_inp_points, 3); + CHECK_SHAPE(inp_features, num_inp_points, in_channels); + CHECK_SHAPE(inp_importance, num_inp_points || 0); + CHECK_SHAPE(neighbors_index, num_neighbors); + CHECK_SHAPE(neighbors_importance, num_neighbors || 0); + CHECK_SHAPE(neighbors_row_splits, num_out_points + 1); + + // make sure that these are on the same place as the filters, positions + // and feats + auto place = inp_features.place(); + offset = offset.copy_to(place, false); + extents = extents.copy_to(place, false); + neighbors_index = neighbors_index.copy_to(place, false); + neighbors_importance = neighbors_importance.copy_to(place, false); + neighbors_row_splits = neighbors_row_splits.copy_to(place, false); + + const auto& feat_dtype = filters.dtype(); + const auto& real_dtype = inp_positions.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + + paddle::Tensor out_features = paddle::empty( + {num_out_points.value(), out_channels.value()}, feat_dtype, place); +#define FN_PARAMETERS \ + filters, out_positions, extents, offset, inp_positions, inp_features, \ + inp_importance, neighbors_index, neighbors_importance, \ + neighbors_row_splits, align_corners, coordinate_mapping, \ + normalize, interpolation, max_temp_mem_MB, out_features + +#define CALL(feat_t, out_t, real_t, index_t, fn) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(real_dtype) && \ + ComparePaddleDtype(index_dtype)) { \ + fn(FN_PARAMETERS); \ + return {out_features}; \ + } + + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, float, int32_t, ::ContinuousConvCUDA) +#else + PD_CHECK(false, "ContinuousConv was not compiled with CUDA support"); +#endif + } else { + CALL(float, float, float, int32_t, ::ContinuousConvCPU) + } +#undef FN_PARAMETERS +#undef CALL + + PD_CHECK(false, "ContinuousConv does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index"); + return {paddle::Tensor()}; +} + +std::vector ContinuousConvBackward( + paddle::Tensor& filters, + paddle::Tensor& out_positions, + paddle::Tensor& extents, + paddle::Tensor& offset, + paddle::Tensor& inp_positions, + paddle::Tensor& inp_features, + paddle::Tensor& inp_importance, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& out_features_gradient, + const bool align_corners, + const std::string& coordinate_mapping_str, + const bool normalize, + const std::string& interpolation_str, + const int64_t max_temp_mem_MB) { + CoordinateMapping coordinate_mapping = + ParseCoordinateMappingStr(coordinate_mapping_str); + + InterpolationMode interpolation = ParseInterpolationStr(interpolation_str); + + auto place = inp_features.place(); + const auto& feat_dtype = filters.dtype(); + const auto& real_dtype = inp_positions.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + CHECK_SAME_DTYPE(out_features_gradient, inp_features, filters); + CHECK_SAME_DEVICE_TYPE(out_features_gradient, inp_features, filters); + + // output vars + paddle::Tensor filters_backprop; + paddle::Tensor inp_features_backprop; + +#define CALL(feat_t, out_t, real_t, index_t, fn_suffix) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(real_dtype) && \ + ComparePaddleDtype(index_dtype)) { \ + filters_backprop = paddle::empty(filters.shape(), real_dtype, place); \ + ContinuousConvBackpropFilter##fn_suffix( \ + filters, out_positions, extents, offset, inp_positions, \ + inp_features, inp_importance, neighbors_index, \ + neighbors_importance, neighbors_row_splits, \ + out_features_gradient, align_corners, coordinate_mapping, \ + normalize, interpolation, max_temp_mem_MB, filters_backprop); \ + \ + paddle::Tensor inv_neighbors_index, inv_neighbors_row_splits, \ + inv_neighbors_importance; \ + auto inv = InvertNeighborsList(neighbors_index, neighbors_row_splits, \ + neighbors_importance, \ + inp_positions.shape()[0]); \ + inv_neighbors_index = inv[0]; \ + inv_neighbors_row_splits = inv[1]; \ + inv_neighbors_importance = inv[2]; \ + auto neighbors_importance_sum = ReduceSubarraysSum( \ + neighbors_importance, neighbors_row_splits)[0]; \ + inp_features_backprop = \ + paddle::ones(inp_features.shape(), real_dtype, place); \ + auto filters_transposed = Transpose(filters, 3, 4).contiguous(); \ + \ + ContinuousConvTranspose##fn_suffix( \ + filters_transposed, inp_positions, inp_importance, extents, \ + offset, out_positions, out_features_gradient, neighbors_index, \ + neighbors_importance_sum, neighbors_row_splits, \ + inv_neighbors_index, inv_neighbors_importance, \ + inv_neighbors_row_splits, align_corners, coordinate_mapping, \ + normalize, interpolation, max_temp_mem_MB, \ + inp_features_backprop); \ + dispatch_success = true; \ + } + + bool dispatch_success = false; + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, float, int32_t, CUDA) +#else + PD_CHECK(false, + "ContinuousConv backward was not compiled " + "with CUDA support"); +#endif + } else { + CALL(float, float, float, int32_t, CPU) + } + PD_CHECK(dispatch_success, + "ContinuousConv backward does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index"); + + return {filters_backprop, inp_features_backprop}; +} + +std::vector ContinuousConvInferDtype( + paddle::DataType filters_dtype) { + return {filters_dtype}; +} + +PD_BUILD_OP(open3d_continuous_conv) + .Inputs({"filters", "out_positions", "extents", "offset", + "inp_positions", "inp_features", "inp_importance", + "neighbors_index", "neighbors_importance", + "neighbors_row_splits"}) + .Outputs({"out_features"}) + .Attrs({"align_corners:bool", "coordinate_mapping:std::string", + "normalize:bool", "interpolation:std::string", + "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(ContinuousConvForward)) + .SetInferDtypeFn(PD_INFER_DTYPE(ContinuousConvInferDtype)); + +PD_BUILD_GRAD_OP(open3d_continuous_conv) + .Inputs({"filters", "out_positions", "extents", "offset", + "inp_positions", "inp_features", "inp_importance", + "neighbors_index", "neighbors_importance", + "neighbors_row_splits", paddle::Grad("out_features")}) + .Outputs({paddle::Grad("filters"), paddle::Grad("inp_features")}) + .Attrs({"align_corners:bool", "coordinate_mapping:std::string", + "normalize:bool", "interpolation:std::string", + "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(ContinuousConvBackward)); diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cpp b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cpp new file mode 100644 index 00000000000..d977a013c51 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cpp @@ -0,0 +1,81 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTransposeBackpropFilter.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvTransposeBackpropFilterCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const CoordinateMapping coordinate_mapping, + const bool normalize, + const InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + CConvTransposeBackpropFilterCPU( + filter_backprop.data(), filter_dims, out_positions.shape()[0], + out_positions.data(), + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), out_features_gradient.data(), + interpolation, coordinate_mapping, align_corners, + individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void \ + ContinuousConvTransposeBackpropFilterCPU( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, \ + const bool align_corners, \ + const CoordinateMapping coordinate_mapping, const bool normalize, \ + const InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, float, int32_t) diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cu b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cu new file mode 100644 index 00000000000..e500e0bb421 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.cu @@ -0,0 +1,122 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#include +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTransposeBackpropFilter.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvTransposeBackpropFilterCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const CoordinateMapping coordinate_mapping, + const bool normalize, + const InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + CConvTransposeBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, out_positions.shape()[0], + out_positions.data(), + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), out_features_gradient.data(), + interpolation, coordinate_mapping, align_corners, + individual_extents, isotropic_extents, normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + CConvTransposeBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, out_positions.shape()[0], + out_positions.data(), + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), out_features_gradient.data(), + interpolation, coordinate_mapping, align_corners, + individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void \ + ContinuousConvTransposeBackpropFilterCUDA( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, \ + const bool align_corners, \ + const CoordinateMapping coordinate_mapping, const bool normalize, \ + const InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, float, int32_t) diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.h b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.h new file mode 100644 index 00000000000..e8b935c7365 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.h @@ -0,0 +1,59 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTypes.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void ContinuousConvTransposeBackpropFilterCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); + +#ifdef BUILD_CUDA_MODULE +template +void ContinuousConvTransposeBackpropFilterCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); +#endif diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.cpp b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.cpp new file mode 100644 index 00000000000..a1d84596636 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.cpp @@ -0,0 +1,80 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTranspose.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvTransposeCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const CoordinateMapping coordinate_mapping, + const bool normalize, + const InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + CConvTransposeComputeFeaturesCPU( + out_features.data(), filter_dims, filters.data(), + out_positions.shape()[0], out_positions.data(), + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), interpolation, coordinate_mapping, + align_corners, individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void ContinuousConvTransposeCPU( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_index, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const bool align_corners, \ + const CoordinateMapping coordinate_mapping, const bool normalize, \ + const InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, float, int32_t) diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.cu b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.cu new file mode 100644 index 00000000000..839b566b740 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.cu @@ -0,0 +1,120 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTranspose.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void ContinuousConvTransposeCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const CoordinateMapping coordinate_mapping, + const bool normalize, + const InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + const bool individual_extents = extents.shape()[0] > 1; + const bool isotropic_extents = extents.shape()[1] == 1; + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + CConvTransposeComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + out_positions.shape()[0], out_positions.data(), + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), interpolation, coordinate_mapping, + align_corners, individual_extents, isotropic_extents, normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + CConvTransposeComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + out_positions.shape()[0], out_positions.data(), + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_positions.shape()[0], inp_positions.data(), + inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), extents.data(), + offset.data(), interpolation, coordinate_mapping, + align_corners, individual_extents, isotropic_extents, normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void ContinuousConvTransposeCUDA( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_positions, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& extents, const paddle::Tensor& offset, \ + const paddle::Tensor& inp_positions, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_index, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const bool align_corners, \ + const CoordinateMapping coordinate_mapping, const bool normalize, \ + const InterpolationMode interpolation, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, float, int32_t) diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.h b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.h new file mode 100644 index 00000000000..df2bf86848a --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.h @@ -0,0 +1,59 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/impl/continuous_conv/ContinuousConvTypes.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void ContinuousConvTransposeCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); + +#ifdef BUILD_CUDA_MODULE +template +void ContinuousConvTransposeCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_positions, + const paddle::Tensor& out_importance, + const paddle::Tensor& extents, + const paddle::Tensor& offset, + const paddle::Tensor& inp_positions, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const open3d::ml::impl::CoordinateMapping coordinate_mapping, + const bool normalize, + const open3d::ml::impl::InterpolationMode interpolation, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); +#endif diff --git a/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOps.cpp b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOps.cpp new file mode 100644 index 00000000000..ca438b39b90 --- /dev/null +++ b/cpp/open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOps.cpp @@ -0,0 +1,258 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvHelper.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvOpKernel.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvTransposeBackpropFilterOpKernel.h" +#include "open3d/ml/paddle/continuous_conv/ContinuousConvTransposeOpKernel.h" +#include "open3d/ml/paddle/misc/InvertNeighborsListOps.h" +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOps.h" + +using namespace open3d::ml::impl; + +std::vector ContinuousConvTransposeForward( + paddle::Tensor& filters, + paddle::Tensor& out_positions, + paddle::Tensor& out_importance, + paddle::Tensor& extents, + paddle::Tensor& offset, + paddle::Tensor& inp_positions, + paddle::Tensor& inp_features, + paddle::Tensor& inp_neighbors_index, + paddle::Tensor& inp_neighbors_importance_sum, + paddle::Tensor& inp_neighbors_row_splits, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + const bool align_corners, + const std::string& coordinate_mapping_str, + const bool normalize, + const std::string& interpolation_str, + const int64_t max_temp_mem_MB) { + CoordinateMapping coordinate_mapping = + ParseCoordinateMappingStr(coordinate_mapping_str); + + InterpolationMode interpolation = ParseInterpolationStr(interpolation_str); + + CHECK_TYPE(neighbors_row_splits, paddle::DataType::INT64); + CHECK_TYPE(inp_neighbors_row_splits, paddle::DataType::INT64); + CHECK_SAME_DTYPE(neighbors_index, inp_neighbors_index); + CHECK_SAME_DTYPE(filters, out_positions, extents, offset, inp_positions, + inp_features, out_importance, neighbors_importance); + CHECK_SAME_DEVICE_TYPE(filters, out_positions, inp_positions, inp_features, + out_importance); + + // check input shapes + using namespace open3d::ml::op_util; + Dim kernel_depth("kernel_depth"); + Dim kernel_height("kernel_height"); + Dim kernel_width("kernel_width"); + Dim out_channels("out_channels"); + Dim in_channels("in_channels"); + Dim num_out_points("num_out_points"); + Dim num_inp_points("num_inp_points"); + Dim num_neighbors("nun_neighbors"); + + CHECK_SHAPE(filters, kernel_depth, kernel_height, kernel_width, in_channels, + out_channels); + CHECK_SHAPE(out_positions, num_out_points, 3); + CHECK_SHAPE(inp_positions, num_inp_points, 3); + CHECK_SHAPE(extents, num_inp_points || 1, Dim(3) || 1); + CHECK_SHAPE(offset, 3); + CHECK_SHAPE(inp_features, num_inp_points, in_channels); + CHECK_SHAPE(out_importance, num_out_points || 0); + CHECK_SHAPE(inp_neighbors_index, num_neighbors); + CHECK_SHAPE(inp_neighbors_importance_sum, num_inp_points || 0); + CHECK_SHAPE(inp_neighbors_row_splits, num_inp_points + 1); + CHECK_SHAPE(neighbors_index, num_neighbors); + CHECK_SHAPE(neighbors_importance, num_neighbors || 0); + CHECK_SHAPE(neighbors_row_splits, num_out_points + 1); + + // make sure that these are on the same place as the filters, positions + // and feats + auto place = inp_features.place(); + offset = offset.copy_to(place, false); + extents = extents.copy_to(place, false); + neighbors_index = neighbors_index.copy_to(place, false); + neighbors_importance = neighbors_importance.copy_to(place, false); + neighbors_row_splits = neighbors_row_splits.copy_to(place, false); + inp_neighbors_index = inp_neighbors_index.copy_to(place, false); + inp_neighbors_importance_sum = + inp_neighbors_importance_sum.copy_to(place, false); + inp_neighbors_row_splits = inp_neighbors_row_splits.copy_to(place, false); + + const auto& feat_dtype = filters.dtype(); + const auto& real_dtype = inp_positions.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + + paddle::Tensor out_features = paddle::empty( + {num_out_points.value(), out_channels.value()}, real_dtype, place); +#define FN_PARAMETERS \ + filters, out_positions, out_importance, extents, offset, inp_positions, \ + inp_features, inp_neighbors_index, inp_neighbors_importance_sum, \ + inp_neighbors_row_splits, neighbors_index, neighbors_importance, \ + neighbors_row_splits, align_corners, coordinate_mapping, \ + normalize, interpolation, max_temp_mem_MB, out_features + +#define CALL(feat_t, out_t, real_t, index_t, fn) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(real_dtype) && \ + ComparePaddleDtype(index_dtype)) { \ + fn(FN_PARAMETERS); \ + return {out_features}; \ + } + + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, float, int32_t, ::ContinuousConvTransposeCUDA) +#else + PD_CHECK(false, + "ContinuousConvTranspose was not compiled with CUDA " + "support"); +#endif + } else { + CALL(float, float, float, int32_t, ::ContinuousConvTransposeCPU) + } +#undef FN_PARAMETERS +#undef CALL + + PD_CHECK(false, "ContinuousConv does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index"); + return {paddle::Tensor()}; +} + +std::vector ContinuousConvTransposeBackward( + paddle::Tensor& filters, + paddle::Tensor& out_positions, + paddle::Tensor& out_importance, + paddle::Tensor& extents, + paddle::Tensor& offset, + paddle::Tensor& inp_positions, + paddle::Tensor& inp_features, + paddle::Tensor& inp_neighbors_importance_sum, + paddle::Tensor& inp_neighbors_row_splits, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& out_features_gradient, + const bool align_corners, + const std::string& coordinate_mapping_str, + const bool normalize, + const std::string& interpolation_str, + const int64_t max_temp_mem_MB) { + CoordinateMapping coordinate_mapping = + ParseCoordinateMappingStr(coordinate_mapping_str); + + InterpolationMode interpolation = ParseInterpolationStr(interpolation_str); + + auto place = inp_features.place(); + const auto& feat_dtype = filters.dtype(); + const auto& real_dtype = inp_features.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + CHECK_SAME_DTYPE(out_features_gradient, inp_features, filters); + CHECK_SAME_DEVICE_TYPE(out_features_gradient, inp_features, filters); + + // output vars + paddle::Tensor filters_backprop; + paddle::Tensor inp_features_backprop; + +#define CALL(feat_t, out_t, real_t, index_t, fn_suffix) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(real_dtype) && \ + ComparePaddleDtype(index_dtype)) { \ + filters_backprop = paddle::empty(filters.shape(), real_dtype, place); \ + ContinuousConvTransposeBackpropFilter##fn_suffix( \ + filters, out_positions, out_importance, extents, offset, \ + inp_positions, inp_features, inp_neighbors_importance_sum, \ + inp_neighbors_row_splits, neighbors_index, \ + neighbors_importance, neighbors_row_splits, \ + out_features_gradient, align_corners, coordinate_mapping, \ + normalize, interpolation, max_temp_mem_MB, filters_backprop); \ + \ + paddle::Tensor inv_neighbors_index, inv_neighbors_row_splits, \ + inv_neighbors_importance; \ + auto inv = InvertNeighborsList(neighbors_index, neighbors_row_splits, \ + neighbors_importance, \ + inp_positions.shape()[0]); \ + inv_neighbors_index = inv[0]; \ + inv_neighbors_row_splits = inv[1]; \ + inv_neighbors_importance = inv[2]; \ + InvertNeighborsList(neighbors_index, neighbors_row_splits, \ + neighbors_importance, inp_positions.shape()[0]); \ + inp_features_backprop = \ + paddle::ones(inp_features.shape(), real_dtype, place); \ + auto filters_transposed = Transpose(filters, 3, 4).contiguous(); \ + \ + ContinuousConv##fn_suffix( \ + filters_transposed, inp_positions, extents, offset, \ + out_positions, out_features_gradient, out_importance, \ + inv_neighbors_index, inv_neighbors_importance, \ + inp_neighbors_row_splits, align_corners, coordinate_mapping, \ + normalize, interpolation, max_temp_mem_MB, \ + inp_features_backprop); \ + dispatch_success = true; \ + } + + bool dispatch_success = false; + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, float, int32_t, CUDA) +#else + PD_CHECK(false, + "ContinuousConvTranspose backward was not compiled " + "with CUDA support"); +#endif + } else { + CALL(float, float, float, int32_t, CPU) + } + PD_CHECK(dispatch_success, + "ContinuousConvTranspose backward does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index"); + + return {filters_backprop, inp_features_backprop}; +} + +std::vector ContinuousConvTransposeInferDtype( + paddle::DataType inp_positions_dtype) { + return {inp_positions_dtype}; +} + +PD_BUILD_OP(open3d_continuous_conv_transpose) + .Inputs({"filters", "out_positions", "out_importance", "extents", + "offset", "inp_positions", "inp_features", + "inp_neighbors_index", "inp_neighbors_importance_sum", + "inp_neighbors_row_splits", "neighbors_index", + "neighbors_importance", "neighbors_row_splits"}) + .Outputs({"out_features"}) + .Attrs({"align_corners:bool", "coordinate_mapping:std::string", + "normalize:bool", "interpolation:std::string", + "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(ContinuousConvTransposeForward)) + .SetInferDtypeFn(PD_INFER_DTYPE(ContinuousConvTransposeInferDtype)); + +PD_BUILD_GRAD_OP(open3d_continuous_conv_transpose) + .Inputs({"filters", "out_positions", "out_importance", "extents", + "offset", "inp_positions", "inp_features", + "inp_neighbors_importance_sum", "inp_neighbors_row_splits", + "neighbors_index", "neighbors_importance", + "neighbors_row_splits", paddle::Grad("out_features")}) + .Outputs({paddle::Grad("filters"), paddle::Grad("inp_features")}) + .Attrs({"align_corners:bool", "coordinate_mapping:std::string", + "normalize:bool", "interpolation:std::string", + "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(ContinuousConvTransposeBackward)); diff --git a/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cpp b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cpp new file mode 100644 index 00000000000..bcb6d42a41e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cpp @@ -0,0 +1,34 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void BuildSpatialHashTableCPU(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits) { + open3d::core::nns::impl::BuildSpatialHashTableCPU( + points.shape()[0], points.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + hash_table_splits.data(), hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data()))); +} +#define INSTANTIATE(T) \ + template void BuildSpatialHashTableCPU( \ + const paddle::Tensor&, double, const paddle::Tensor&, \ + const std::vector&, paddle::Tensor&, paddle::Tensor&); + +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cu b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cu new file mode 100644 index 00000000000..061fcc68b1e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cu @@ -0,0 +1,59 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::core::nns; + +template +void BuildSpatialHashTableCUDA(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits) { + auto stream = points.stream(); + // -1 means current global place + auto cuda_place_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_place_props.textureAlignment; + + void* temp_ptr = nullptr; + size_t temp_size = 0; + + // determine temp_size + impl::BuildSpatialHashTableCUDA( + stream, temp_ptr, temp_size, texture_alignment, points.shape()[0], + points.data(), T(radius), points_row_splits.shape()[0], + points_row_splits.data(), hash_table_splits.data(), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data()))); + auto place = points.place(); + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually build the table + impl::BuildSpatialHashTableCUDA( + stream, temp_ptr, temp_size, texture_alignment, points.shape()[0], + points.data(), T(radius), points_row_splits.shape()[0], + points_row_splits.data(), hash_table_splits.data(), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data()))); +} + +#define INSTANTIATE(T) \ + template void BuildSpatialHashTableCUDA( \ + const paddle::Tensor&, double, const paddle::Tensor&, \ + const std::vector&, paddle::Tensor&, paddle::Tensor&); + +INSTANTIATE(float) diff --git a/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOps.cpp b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOps.cpp new file mode 100644 index 00000000000..a0f49e6a442 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOps.cpp @@ -0,0 +1,121 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include +#include + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void BuildSpatialHashTableCPU(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits); +#ifdef BUILD_CUDA_MODULE +template +void BuildSpatialHashTableCUDA(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits); +#endif + +std::vector BuildSpatialHashTable( + paddle::Tensor& points, + paddle::Tensor& points_row_splits, + double radius, + double hash_table_size_factor, + int64_t max_hash_table_size) { + points_row_splits = points_row_splits.copy_to(phi::CPUPlace(), false); + CHECK_TYPE(points_row_splits, paddle::DataType::INT64); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim batch_size("batch_size"); + + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(points_row_splits, batch_size + 1); + + const auto& point_type = points.dtype(); + + std::vector hash_table_splits(batch_size.value() + 1, 0); + for (int i = 0; i < batch_size.value(); ++i) { + int64_t num_points_i = points_row_splits.data()[i + 1] - + points_row_splits.data()[i]; + int64_t hash_table_size = std::min( + std::max(hash_table_size_factor * num_points_i, 1), + max_hash_table_size); + hash_table_splits[i + 1] = hash_table_splits[i] + hash_table_size; + } + + auto place = points.place(); + paddle::Tensor hash_table_index; + if (points.shape()[0] != 0) { + hash_table_index = + paddle::empty({points.shape()[0]}, + paddle::DataType(ToPaddleDtype()), place); + } else { + hash_table_index = InitializedEmptyTensor({0}, place); + } + paddle::Tensor hash_table_cell_splits = + paddle::empty({hash_table_splits.back() + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor out_hash_table_splits = paddle::empty( + {batch_size.value() + 1}, + paddle::DataType(ToPaddleDtype()), phi::CPUPlace()); + for (size_t i = 0; i < hash_table_splits.size(); ++i) { + out_hash_table_splits.data()[i] = hash_table_splits[i]; + } +#define FN_PARAMETERS \ + points, radius, points_row_splits, hash_table_splits, hash_table_index, \ + hash_table_cell_splits +#define CALL(type, fn) \ + if (ComparePaddleDtype(point_type)) { \ + fn(FN_PARAMETERS); \ + return {hash_table_index, hash_table_cell_splits, \ + out_hash_table_splits}; \ + } + if (points.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(float, BuildSpatialHashTableCUDA) +#else + PD_CHECK(false, + "BuildSpatialHashTable was not compiled with CUDA support"); +#endif + } else { + CALL(float, BuildSpatialHashTableCPU) + CALL(double, BuildSpatialHashTableCPU) + } + PD_CHECK(false, "BuildSpatialHashTable does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for " + "points"); + + return std::vector(); +} + +std::vector BuildSpatialHashTableInferDtype() { + auto dtype = paddle::DataType::INT32; + return {dtype, dtype, dtype}; +} + +PD_BUILD_OP(open3d_build_spatial_hash_table) + .Inputs({"points", "points_row_splits"}) + .Outputs({"hash_table_index", "hash_table_cell_splits", + "hash_table_splits"}) + .Attrs({"radius: double", "hash_table_size_factor: double", + "max_hash_table_size: int64_t"}) + .SetKernelFn(PD_KERNEL(BuildSpatialHashTable)) + .SetInferDtypeFn(PD_INFER_DTYPE(BuildSpatialHashTableInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cpp b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cpp new file mode 100644 index 00000000000..c48003d7891 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cpp @@ -0,0 +1,67 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void FixedRadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + NeighborSearchAllocator output_allocator(points.place()); + + impl::FixedRadiusSearchCPU( + neighbors_row_splits.data(), points.shape()[0], + points.data(), queries.shape()[0], queries.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + queries_row_splits.shape()[0], queries_row_splits.data(), + reinterpret_cast( + const_cast(hash_table_splits.data())), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data())), + metric, ignore_query_point, return_distances, output_allocator); + + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void FixedRadiusSearchCPU( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + double radius, const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, \ + const paddle::Tensor& hash_table_splits, \ + const paddle::Tensor& hash_table_index, \ + const paddle::Tensor& hash_table_cell_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) +INSTANTIATE(double, int32_t) +INSTANTIATE(double, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cu b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cu new file mode 100644 index 00000000000..889ef81389f --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cu @@ -0,0 +1,94 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.cuh" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void FixedRadiusSearchCUDA(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + auto stream = points.stream(); + // -1 means current global place + auto cuda_place_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_place_props.textureAlignment; + + auto place = points.place(); + + NeighborSearchAllocator output_allocator(place); + void* temp_ptr = nullptr; + size_t temp_size = 0; + + // determine temp_size + impl::FixedRadiusSearchCUDA( + stream, temp_ptr, temp_size, texture_alignment, + neighbors_row_splits.data(), points.shape()[0], + points.data(), queries.shape()[0], queries.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + queries_row_splits.shape()[0], queries_row_splits.data(), + reinterpret_cast( + const_cast(hash_table_splits.data())), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data())), + metric, ignore_query_point, return_distances, output_allocator); + + auto temp_tensor = CreateTempTensor(temp_size, points.place(), &temp_ptr); + + // actually run the search + impl::FixedRadiusSearchCUDA( + stream, temp_ptr, temp_size, texture_alignment, + neighbors_row_splits.data(), points.shape()[0], + points.data(), queries.shape()[0], queries.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + queries_row_splits.shape()[0], queries_row_splits.data(), + reinterpret_cast( + const_cast(hash_table_splits.data())), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data())), + metric, ignore_query_point, return_distances, output_allocator); + + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void FixedRadiusSearchCUDA( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + double radius, const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, \ + const paddle::Tensor& hash_table_splits, \ + const paddle::Tensor& hash_table_index, \ + const paddle::Tensor& hash_table_cell_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOps.cpp b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOps.cpp new file mode 100644 index 00000000000..26feb76bd4e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOps.cpp @@ -0,0 +1,190 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/Dtype.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/utility/Helper.h" + +using namespace open3d::core::nns; + +template +void FixedRadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); +#ifdef BUILD_CUDA_MODULE +template +void FixedRadiusSearchCUDA(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); +#endif + +std::vector FixedRadiusSearch( + paddle::Tensor& points, + paddle::Tensor& queries, + paddle::Tensor& points_row_splits, + paddle::Tensor& queries_row_splits, + paddle::Tensor& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits, + double radius, + const std::string& index_dtype, + const std::string& metric_str, + const bool ignore_query_point, + const bool return_distances) { + Metric metric = L2; + if (metric_str == "L1") { + metric = L1; + } else if (metric_str == "L2") { + metric = L2; + } else if (metric_str == "Linf") { + metric = Linf; + } else { + PD_CHECK(false, + "metric must be one of (L1, L2, Linf) but got " + metric_str); + } + CHECK_TYPE(points_row_splits, paddle::DataType::INT64); + CHECK_TYPE(queries_row_splits, paddle::DataType::INT64); + CHECK_TYPE(hash_table_splits, paddle::DataType::INT32); + CHECK_TYPE(hash_table_index, paddle::DataType::INT32); + CHECK_TYPE(hash_table_cell_splits, paddle::DataType::INT32); + CHECK_SAME_DTYPE(points, queries); + CHECK_SAME_DEVICE_TYPE(points, queries); + // PD_CHECK(index_dtype == paddle::DataType::INT32 || index_dtype == + // paddle::DataType::INT64, + PD_CHECK(index_dtype == "int32" || index_dtype == "int64", + "index_dtype must be int32 or int64"); + // ensure that these are on the cpu + points_row_splits = points_row_splits.copy_to(paddle::CPUPlace(), false); + queries_row_splits = queries_row_splits.copy_to(paddle::CPUPlace(), false); + hash_table_splits = hash_table_splits.copy_to(paddle::CPUPlace(), false); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim num_queries("num_queries"); + Dim batch_size("batch_size"); + Dim num_cells("num_cells"); + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(hash_table_index, num_points); + CHECK_SHAPE(queries, num_queries, 3); + CHECK_SHAPE(points_row_splits, batch_size + 1); + CHECK_SHAPE(queries_row_splits, batch_size + 1); + CHECK_SHAPE(hash_table_splits, batch_size + 1); + CHECK_SHAPE(hash_table_cell_splits, num_cells + 1); + + const auto& point_type = points.dtype(); + + auto place = points.place(); + + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_row_splits = + paddle::empty({queries.shape()[0] + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_distance; + +#define FN_PARAMETERS \ + points, queries, radius, points_row_splits, queries_row_splits, \ + hash_table_splits, hash_table_index, hash_table_cell_splits, \ + metric, ignore_query_point, return_distances, neighbors_index, \ + neighbors_row_splits, neighbors_distance + + if (points.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + FixedRadiusSearchCUDA(FN_PARAMETERS); + } else { + FixedRadiusSearchCUDA(FN_PARAMETERS); + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } +#else + PD_CHECK(false, "FixedRadiusSearch was not compiled with CUDA support"); +#endif + } else { + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + FixedRadiusSearchCPU(FN_PARAMETERS); + } else { + FixedRadiusSearchCPU(FN_PARAMETERS); + } + } else { + if (index_dtype == "int32") { + FixedRadiusSearchCPU(FN_PARAMETERS); + } else { + FixedRadiusSearchCPU(FN_PARAMETERS); + } + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } + + // in torch the name is ToString, but paddle not have this function + PD_CHECK(false, "FixedRadiusSearch does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for points"); + return std::vector(); +} + +std::vector FixedRadiusSearchInferDtype( + const std::string& index_dtype) { + paddle::DataType dtype = index_dtype == "int32" ? paddle::DataType::INT32 + : paddle::DataType::INT64; + return {dtype, paddle::DataType::INT64, dtype}; +} + +std::vector> FixedRadiusSearchInferShape( + std::vector queries_shape, const bool return_distances) { + // this just a temp impl , all return is fake data + // TODO(woodman3): impl real data + int64_t neighbors_row_splits_shape = queries_shape[0] + 1; + int64_t neighbors_distance_shape = return_distances ? 1 : 0; + return {{neighbors_row_splits_shape}, + {neighbors_row_splits_shape}, + {neighbors_distance_shape}}; +} + +PD_BUILD_OP(open3d_fixed_radius_search) + .Inputs({"points", "queries", "points_row_splits", "queries_row_splits", + "hash_table_splits", "hash_table_index", + "hash_table_cell_splits"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_distance"}) + .Attrs({ + "radius: double", + "index_dtype:std::string", + "metric_str: std::string", + "ignore_query_point: bool", + "return_distances: bool", + }) + .SetKernelFn(PD_KERNEL(FixedRadiusSearch)) + .SetInferShapeFn(PD_INFER_SHAPE(FixedRadiusSearchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FixedRadiusSearchInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cpp b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cpp new file mode 100644 index 00000000000..597167f0cb6 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cpp @@ -0,0 +1,67 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h" + +#include "open3d/ml/impl/misc/InvertNeighborsList.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +std::vector InvertNeighborsListCPU( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes) { + paddle::Tensor neighbors_index = + paddle::empty(inp_neighbors_index.shape(), + paddle::DataType(ToPaddleDtype())); + paddle::Tensor neighbors_row_splits = paddle::empty( + {num_points + 1}, paddle::DataType(paddle::DataType::INT64)); + + paddle::Tensor neighbors_attributes = + paddle::empty_like(inp_neighbors_attributes); + + int num_attributes; + if (inp_neighbors_attributes.shape()[0] == 0) { + num_attributes = 0; + neighbors_attributes = + InitializedEmptyTensor(inp_neighbors_attributes.dtype(), + inp_neighbors_attributes.shape(), + inp_neighbors_attributes.place()); + + } else { + num_attributes = 1; + for (size_t i = 1; i < inp_neighbors_attributes.shape().size(); ++i) + num_attributes *= inp_neighbors_attributes.shape()[i]; + } + + open3d::ml::impl::InvertNeighborsListCPU( + inp_neighbors_index.data(), + num_attributes ? inp_neighbors_attributes.data() : nullptr, + num_attributes, inp_neighbors_row_splits.data(), + inp_neighbors_row_splits.shape()[0] - 1, + neighbors_index.data(), + num_attributes ? neighbors_attributes.data() : nullptr, + neighbors_index.shape()[0], neighbors_row_splits.data(), + neighbors_row_splits.shape()[0] - 1); + + return {neighbors_index, neighbors_row_splits, neighbors_attributes}; +} +#define INSTANTIATE(TIndex, TAttr) \ + template std::vector \ + InvertNeighborsListCPU(int64_t, const paddle::Tensor&, \ + const paddle::Tensor&, \ + const paddle::Tensor&); + +INSTANTIATE(int32_t, uint8_t) +INSTANTIATE(int32_t, int8_t) +INSTANTIATE(int32_t, int16_t) +INSTANTIATE(int32_t, int32_t) +INSTANTIATE(int32_t, int64_t) +INSTANTIATE(int32_t, float) +INSTANTIATE(int32_t, double) diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cu b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cu new file mode 100644 index 00000000000..dc4b13589e2 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cu @@ -0,0 +1,92 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/impl/misc/InvertNeighborsList.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +std::vector InvertNeighborsListCUDA( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes) { + auto place = inp_neighbors_index.place(); + paddle::Tensor neighbors_index = + paddle::empty(inp_neighbors_index.shape(), + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_row_splits = paddle::empty( + {num_points + 1}, paddle::DataType(paddle::DataType::INT64), place); + paddle::Tensor neighbors_attributes = + paddle::empty_like(inp_neighbors_attributes); + + // maybe this can use torch's impl way ? + auto stream = inp_neighbors_index.stream(); + // -1 means current global place + auto cuda_place_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_place_props.textureAlignment; + + int num_attributes; + if (inp_neighbors_attributes.shape()[0] == 0) { + std::cout << inp_neighbors_attributes.dtype() << std::endl; + num_attributes = 0; + neighbors_attributes = + InitializedEmptyTensor(inp_neighbors_attributes.dtype(), + inp_neighbors_attributes.shape(), + inp_neighbors_attributes.place()); + } else { + num_attributes = 1; + for (int i = 1; i < inp_neighbors_attributes.dims().size(); ++i) + num_attributes *= inp_neighbors_attributes.shape()[i]; + } + + void* temp_ptr = nullptr; + size_t temp_size = 0; + + // determine temp_size + open3d::ml::impl::InvertNeighborsListCUDA( + stream, temp_ptr, temp_size, texture_alignment, + inp_neighbors_index.data(), + num_attributes ? inp_neighbors_attributes.data() : nullptr, + num_attributes, inp_neighbors_row_splits.data(), + inp_neighbors_row_splits.shape()[0] - 1, + neighbors_index.data(), + num_attributes ? neighbors_attributes.data() : nullptr, + neighbors_index.shape()[0], + neighbors_row_splits.data(), // NOLINT + neighbors_row_splits.shape()[0] - 1); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually invert the list + open3d::ml::impl::InvertNeighborsListCUDA( + stream, temp_ptr, temp_size, texture_alignment, + inp_neighbors_index.data(), + num_attributes ? inp_neighbors_attributes.data() : nullptr, + num_attributes, inp_neighbors_row_splits.data(), + inp_neighbors_row_splits.shape()[0] - 1, + neighbors_index.data(), + num_attributes ? neighbors_attributes.data() : nullptr, + neighbors_index.shape()[0], + neighbors_row_splits.data(), // NOLINT + neighbors_row_splits.shape()[0] - 1); + + return {neighbors_index, neighbors_row_splits, neighbors_attributes}; +} +#define INSTANTIATE(TIndex, TAttr) \ + template std::vector \ + InvertNeighborsListCUDA(int64_t, const paddle::Tensor&, \ + const paddle::Tensor&, \ + const paddle::Tensor&); + +INSTANTIATE(int32_t, uint8_t) +INSTANTIATE(int32_t, int8_t) +INSTANTIATE(int32_t, int16_t) +INSTANTIATE(int32_t, int32_t) +INSTANTIATE(int32_t, int64_t) +INSTANTIATE(int32_t, float) +INSTANTIATE(int32_t, double) diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h new file mode 100644 index 00000000000..f97500abc55 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h @@ -0,0 +1,26 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +std::vector InvertNeighborsListCPU( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes); + +#ifdef BUILD_CUDA_MODULE +template +std::vector InvertNeighborsListCUDA( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes); +#endif diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.cpp b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.cpp new file mode 100644 index 00000000000..2fb3967d8fe --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.cpp @@ -0,0 +1,104 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/InvertNeighborsListOps.h" + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h" + +std::vector InvertNeighborsList( + paddle::Tensor& inp_neighbors_index, + paddle::Tensor& inp_neighbors_row_splits, + paddle::Tensor& inp_neighbors_attributes, + int64_t num_points) { + CHECK_TYPE(inp_neighbors_row_splits, paddle::DataType::INT64); + + // check input shapes + { + using namespace open3d::ml::op_util; + Dim num_neighbors("num_neighbors"); + + CHECK_SHAPE(inp_neighbors_index, num_neighbors); + CHECK_SHAPE_IGNORE_LAST_DIMS(inp_neighbors_attributes, + num_neighbors || 0); + CHECK_SHAPE(inp_neighbors_row_splits, Dim()); + } + + const auto& index_type = inp_neighbors_index.dtype(); + const auto& attr_type = inp_neighbors_attributes.dtype(); + +#define FN_PARAMETERS \ + num_points, inp_neighbors_index, inp_neighbors_row_splits, \ + inp_neighbors_attributes + +#define CALL(idx_t, attr_t, fn) \ + if (ComparePaddleDtype(index_type) && \ + ComparePaddleDtype(attr_type)) { \ + return fn(FN_PARAMETERS); \ + } + + CHECK_SAME_DEVICE_TYPE(inp_neighbors_index, inp_neighbors_row_splits, + inp_neighbors_attributes); + if (inp_neighbors_index.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(int32_t, uint8_t, InvertNeighborsListCUDA) + CALL(int32_t, int8_t, InvertNeighborsListCUDA) + CALL(int32_t, int16_t, InvertNeighborsListCUDA) + CALL(int32_t, int32_t, InvertNeighborsListCUDA) + CALL(int32_t, int64_t, InvertNeighborsListCUDA) + CALL(int32_t, float, InvertNeighborsListCUDA) + CALL(int32_t, double, InvertNeighborsListCUDA) +#else + PD_CHECK(false, + "InvertNeighborsList was not compiled with CUDA support"); +#endif + } else { + CALL(int32_t, uint8_t, InvertNeighborsListCPU) + CALL(int32_t, int8_t, InvertNeighborsListCPU) + CALL(int32_t, int16_t, InvertNeighborsListCPU) + CALL(int32_t, int32_t, InvertNeighborsListCPU) + CALL(int32_t, int64_t, InvertNeighborsListCPU) + CALL(int32_t, float, InvertNeighborsListCPU) + CALL(int32_t, double, InvertNeighborsListCPU) + } + + PD_CHECK(false, + "InvertNeighborsList does not support " + + phi::DataTypeToString(inp_neighbors_index.dtype()) + + " as input for inp_neighbors_index and " + + phi::DataTypeToString(inp_neighbors_attributes.dtype()) + + " as input for inp_neighbors_attributes"); + return {}; +} + +std::vector InvertNeighborsListInferDtype( + const paddle::DataType inp_neighbors_attributes_dtype) { + return {paddle::DataType::INT32, paddle::DataType::INT64, + inp_neighbors_attributes_dtype}; +} + +std::vector> InvertNeighborsListInferShape( + int64_t num_points, + std::vector inp_neighbors_index_shape, + std::vector inp_neighbors_attributes_shape) { + return {inp_neighbors_index_shape, + {num_points + 1}, + inp_neighbors_attributes_shape}; +} +PD_BUILD_OP(open3d_invert_neighbors_list) + .Inputs({"inp_neighbors_index", "inp_neighbors_row_splits", + "inp_neighbors_attributes"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_attributes"}) + .Attrs({"num_points: int64_t"}) + .SetKernelFn(PD_KERNEL(InvertNeighborsList)) + .SetInferShapeFn(PD_INFER_SHAPE(InvertNeighborsListInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InvertNeighborsListInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.h b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.h new file mode 100644 index 00000000000..d9ecb757007 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.h @@ -0,0 +1,18 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +// this file seem not use +#pragma once + +#include "open3d/ml/paddle/PaddleHelper.h" + +std::vector InvertNeighborsList( + paddle::Tensor& inp_neighbors_index, + paddle::Tensor& inp_neighbors_row_splits, + paddle::Tensor& inp_neighbors_attributes, + int64_t num_points); diff --git a/cpp/open3d/ml/paddle/misc/KnnSearchOpKernel.cpp b/cpp/open3d/ml/paddle/misc/KnnSearchOpKernel.cpp new file mode 100644 index 00000000000..ebc90a07a5e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/KnnSearchOpKernel.cpp @@ -0,0 +1,120 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/NanoFlannImpl.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void KnnSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const int64_t k, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + const int batch_size = points_row_splits.shape()[0] - 1; + + // run radius search for each batch item + std::vector> batch_output_allocators( + batch_size, NeighborSearchAllocator(points.place()) + + ); + int64_t last_neighbors_count = 0; + for (int i = 0; i < batch_size; ++i) { + const T* const points_i = + points.data() + 3 * points_row_splits.data()[i]; + const T* const queries_i = + queries.data() + 3 * queries_row_splits.data()[i]; + size_t num_points_i = points_row_splits.data()[i + 1] - + points_row_splits.data()[i]; + size_t num_queries_i = queries_row_splits.data()[i + 1] - + queries_row_splits.data()[i]; + + int64_t* neighbors_row_splits_i = neighbors_row_splits.data() + + queries_row_splits.data()[i]; + + std::unique_ptr holder = + impl::BuildKdTree(num_points_i, points_i, 3, metric); + + impl::KnnSearchCPU( + holder.get(), neighbors_row_splits_i, num_points_i, points_i, + num_queries_i, queries_i, 3, k, metric, ignore_query_point, + return_distances, batch_output_allocators[i]); + + if (i > 0) { + for (size_t j = 0; j <= num_queries_i; ++j) + neighbors_row_splits_i[j] += last_neighbors_count; + } + last_neighbors_count = neighbors_row_splits_i[num_queries_i]; + } + + if (batch_size == 1) { + // no need to combine just return the results from the first batch item + neighbors_index = batch_output_allocators[0].NeighborsIndex(); + neighbors_distance = batch_output_allocators[0].NeighborsDistance(); + return; + } + + NeighborSearchAllocator output_allocator(points.place()); + + // combine results + int64_t neighbors_index_size = 0; + int64_t neighbors_distance_size = 0; + for (const auto& a : batch_output_allocators) { + neighbors_index_size += a.NeighborsIndex().shape()[0]; + neighbors_distance_size += a.NeighborsDistance().shape()[0]; + } + TIndex* neighbors_index_data_ptr; + T* neighbors_distance_data_ptr; + output_allocator.AllocIndices(&neighbors_index_data_ptr, + neighbors_index_size); + output_allocator.AllocDistances(&neighbors_distance_data_ptr, + neighbors_distance_size); + + for (int i = 0; i < batch_size; ++i) { + auto& a = batch_output_allocators[i]; + if (a.NeighborsIndex().shape()[0]) { + for (int64_t j = 0; j < a.NeighborsIndex().shape()[0]; ++j) { + neighbors_index_data_ptr[0] = + a.IndicesPtr()[j] + + points_row_splits.data()[i]; + ++neighbors_index_data_ptr; + } + } + if (a.NeighborsDistance().shape()[0]) { + memcpy(neighbors_distance_data_ptr, a.DistancesPtr(), + a.NeighborsDistance().shape()[0] * sizeof(T)); + neighbors_distance_data_ptr += a.NeighborsDistance().shape()[0]; + } + } + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void KnnSearchCPU( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + const int64_t k, const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) +INSTANTIATE(double, int32_t) +INSTANTIATE(double, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/KnnSearchOps.cpp b/cpp/open3d/ml/paddle/misc/KnnSearchOps.cpp new file mode 100644 index 00000000000..e1d9f143253 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/KnnSearchOps.cpp @@ -0,0 +1,140 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/core/Dtype.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/utility/Helper.h" +#include "paddle/extension.h" + +using namespace open3d::core::nns; + +template +void KnnSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const int64_t k, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); + +std::vector KnnSearch(paddle::Tensor& points, + paddle::Tensor& queries, + paddle::Tensor& points_row_splits, + paddle::Tensor& queries_row_splits, + const int64_t k, + const std::string& index_dtype, + const std::string& metric_str, + const bool ignore_query_point, + const bool return_distances) { + Metric metric = L2; + if (metric_str == "L1") { + metric = L1; + } else if (metric_str == "L2") { + metric = L2; + } else { + PD_CHECK(false, "metric must be one of (L1, L2) but got " + metric_str); + } + PD_CHECK(k > 0, "k must be greater than zero"); + CHECK_TYPE(points_row_splits, phi::DataType::INT64); + CHECK_TYPE(queries_row_splits, phi::DataType::INT64); + CHECK_SAME_DTYPE(points, queries); + CHECK_SAME_DEVICE_TYPE(points, queries); + PD_CHECK(index_dtype == "int32" || index_dtype == "int64", + "index_dtype must be int32 or int64"); + // ensure that these are on the cpu + points_row_splits = points_row_splits.copy_to(paddle::CPUPlace(), false); + queries_row_splits = queries_row_splits.copy_to(paddle::CPUPlace(), false); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim num_queries("num_queries"); + Dim batch_size("batch_size"); + Dim num_cells("num_cells"); + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(queries, num_queries, 3); + CHECK_SHAPE(points_row_splits, batch_size + 1); + CHECK_SHAPE(queries_row_splits, batch_size + 1); + + const auto& point_type = points.dtype(); + + auto place = points.place(); + + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_row_splits = + paddle::empty({queries.shape()[0] + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_distance; + +#define FN_PARAMETERS \ + points, queries, k, points_row_splits, queries_row_splits, metric, \ + ignore_query_point, return_distances, neighbors_index, \ + neighbors_row_splits, neighbors_distance + + if (points.is_gpu()) { + PD_CHECK(false, "KnnSearch does not support CUDA"); + } else { + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + KnnSearchCPU(FN_PARAMETERS); + } else { + KnnSearchCPU(FN_PARAMETERS); + } + } else { + if (index_dtype == "int32") { + KnnSearchCPU(FN_PARAMETERS); + } else { + KnnSearchCPU(FN_PARAMETERS); + } + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } + PD_CHECK(false, "KnnSearch does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for points"); + return std::vector(); +} + +std::vector KnnSearchInferDtype( + const std::string& index_dtype) { + paddle::DataType dtype = index_dtype == "int32" ? paddle::DataType::INT32 + : paddle::DataType::INT64; + return {dtype, paddle::DataType::INT64, dtype}; +} + +std::vector> KnnSearchInferShape( + std::vector queries_shape, const bool return_distances) { + int64_t neighbors_row_splits_shape = queries_shape[0] + 1; + int64_t neighbors_distance_shape = return_distances ? 1 : 0; + return {{neighbors_row_splits_shape}, + {neighbors_row_splits_shape}, + {neighbors_distance_shape}}; +} + +PD_BUILD_OP(open3d_knn_search) + .Inputs({"points", "queries", "points_row_splits", + "queries_row_splits"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_distance"}) + .Attrs({ + "k: int64_t", + "index_dtype:std::string", + "metric_str: std::string", + "ignore_query_point: bool", + "return_distances: bool", + }) + .SetKernelFn(PD_KERNEL(KnnSearch)) + .SetInferShapeFn(PD_INFER_SHAPE(KnnSearchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(KnnSearchInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/NeighborSearchAllocator.h b/cpp/open3d/ml/paddle/misc/NeighborSearchAllocator.h new file mode 100644 index 00000000000..05f631ba56e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/NeighborSearchAllocator.h @@ -0,0 +1,54 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/PaddleHelper.h" + +// These classes implement functors that can be passed to the neighbor search +// functions. + +template +class NeighborSearchAllocator { +public: + NeighborSearchAllocator(paddle::Place place) : place(place) {} + + void AllocIndices(TIndex** ptr, size_t num) { + if (num == 0) { + neighbors_index = InitializedEmptyTensor({0}, place); + } else { + neighbors_index = paddle::empty( + {int64_t(num)}, paddle::DataType(ToPaddleDtype()), + place); + } + *ptr = neighbors_index.data(); + } + + void AllocDistances(T** ptr, size_t num) { + if (num == 0) { + neighbors_distance = InitializedEmptyTensor({0}, place); + } else { + neighbors_distance = + paddle::empty({int64_t(num)}, + paddle::DataType(ToPaddleDtype()), place); + } + *ptr = neighbors_distance.data(); + } + + const TIndex* IndicesPtr() const { return neighbors_index.data(); } + + const T* DistancesPtr() const { return neighbors_distance.data(); } + + const paddle::Tensor& NeighborsIndex() const { return neighbors_index; } + const paddle::Tensor& NeighborsDistance() const { + return neighbors_distance; + } + +private: + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_distance; + paddle::Place place; +}; diff --git a/cpp/open3d/ml/paddle/misc/NmsOps.cpp b/cpp/open3d/ml/paddle/misc/NmsOps.cpp new file mode 100644 index 00000000000..50c7e2c5f97 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/NmsOps.cpp @@ -0,0 +1,62 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- + +#include + +#include "open3d/ml/contrib/Nms.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "paddle/extension.h" + +std::vector Nms(paddle::Tensor& boxes, + paddle::Tensor& scores, + double nms_overlap_thresh) { + CHECK_TYPE(boxes, phi::DataType::FLOAT32); + CHECK_TYPE(scores, phi::DataType::FLOAT32); + + std::vector keep_indices_blob; + if (boxes.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + keep_indices_blob = open3d::ml::contrib::NmsCUDAKernel( + boxes.data(), scores.data(), boxes.shape()[0], + nms_overlap_thresh); +#else + PD_CHECK(false, "Nms was not compiled with CUDA support"); + +#endif + } else { + keep_indices_blob = open3d::ml::contrib::NmsCPUKernel( + boxes.data(), scores.data(), boxes.shape()[0], + nms_overlap_thresh); + } + + paddle::IntArray out_shape( + {static_cast(keep_indices_blob.size())}); + paddle::IntArray out_strides({1}); + // NOTE: Not pass deleter because data will be free as vector destroy. + if (keep_indices_blob.data()) { + paddle::Tensor temp_keep_indices = paddle::from_blob( + keep_indices_blob.data(), out_shape, out_strides, + phi::DataType::INT64, phi::DataLayout::NCHW, phi::CPUPlace()); + paddle::Tensor keep_indices = temp_keep_indices.copy_to(boxes.place(), false); + + return {keep_indices}; + } else { + // keep indices is nullptr + return {InitializedEmptyTensor({0}, boxes.place())}; + } +} + +std::vector NmsInferDtype() { + return {paddle::DataType::INT64}; +} + +PD_BUILD_OP(open3d_nms) + .Inputs({"boxes", "scores"}) + .Outputs({"keep_indices"}) + .Attrs({"nms_overlap_thresh: double"}) + .SetKernelFn(PD_KERNEL(Nms)) + .SetInferDtypeFn(PD_INFER_DTYPE(NmsInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/RadiusSearchOpKernel.cpp b/cpp/open3d/ml/paddle/misc/RadiusSearchOpKernel.cpp new file mode 100644 index 00000000000..be25cb4c847 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RadiusSearchOpKernel.cpp @@ -0,0 +1,125 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/NanoFlannImpl.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void RadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const paddle::Tensor& radii, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + const bool normalize_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + const int batch_size = points_row_splits.shape()[0] - 1; + // run radius search for each batch item + std::vector> batch_output_allocators( + batch_size, NeighborSearchAllocator(points.place()) + + ); + int64_t last_neighbors_count = 0; + for (int i = 0; i < batch_size; ++i) { + const T* const points_i = + points.data() + 3 * points_row_splits.data()[i]; + const T* const queries_i = + queries.data() + 3 * queries_row_splits.data()[i]; + const T* const radius_i = + radii.data() + queries_row_splits.data()[i]; + size_t num_points_i = points_row_splits.data()[i + 1] - + points_row_splits.data()[i]; + size_t num_queries_i = queries_row_splits.data()[i + 1] - + queries_row_splits.data()[i]; + + int64_t* neighbors_row_splits_i = neighbors_row_splits.data() + + queries_row_splits.data()[i]; + + std::unique_ptr holder = + impl::BuildKdTree(num_points_i, points_i, 3, metric); + + impl::RadiusSearchCPU( + holder.get(), neighbors_row_splits_i, num_points_i, points_i, + num_queries_i, queries_i, 3, radius_i, metric, + ignore_query_point, return_distances, normalize_distances, + /* sort */ false, batch_output_allocators[i]); + + if (i > 0) { + for (size_t j = 0; j <= num_queries_i; ++j) + neighbors_row_splits_i[j] += last_neighbors_count; + } + last_neighbors_count = neighbors_row_splits_i[num_queries_i]; + } + + if (batch_size == 1) { + // no need to combine just return the results from the first batch + // item + neighbors_index = batch_output_allocators[0].NeighborsIndex(); + neighbors_distance = batch_output_allocators[0].NeighborsDistance(); + return; + } + + NeighborSearchAllocator output_allocator(points.place()); + + // combine results + int64_t neighbors_index_size = 0; + int64_t neighbors_distance_size = 0; + for (const auto& a : batch_output_allocators) { + neighbors_index_size += a.NeighborsIndex().shape()[0]; + neighbors_distance_size += a.NeighborsDistance().shape()[0]; + } + TIndex* neighbors_index_data_ptr; + T* neighbors_distance_data_ptr; + output_allocator.AllocIndices(&neighbors_index_data_ptr, + neighbors_index_size); + output_allocator.AllocDistances(&neighbors_distance_data_ptr, + neighbors_distance_size); + + for (int i = 0; i < batch_size; ++i) { + const auto& a = batch_output_allocators[i]; + if (a.NeighborsIndex().shape()[0]) { + for (int64_t j = 0; j < a.NeighborsIndex().shape()[0]; ++j) { + neighbors_index_data_ptr[0] = + a.IndicesPtr()[j] + + points_row_splits.data()[i]; + ++neighbors_index_data_ptr; + } + } + if (a.NeighborsDistance().shape()[0]) { + memcpy(neighbors_distance_data_ptr, a.DistancesPtr(), + a.NeighborsDistance().shape()[0] * sizeof(T)); + neighbors_distance_data_ptr += a.NeighborsDistance().shape()[0]; + } + } + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void RadiusSearchCPU( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + const paddle::Tensor& radii, \ + const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + const bool normalize_distances, paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) +INSTANTIATE(double, int32_t) +INSTANTIATE(double, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/RadiusSearchOps.cpp b/cpp/open3d/ml/paddle/misc/RadiusSearchOps.cpp new file mode 100644 index 00000000000..4f48057422c --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RadiusSearchOps.cpp @@ -0,0 +1,141 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/core/Dtype.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/utility/Helper.h" + +using namespace open3d::core::nns; + +template +void RadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const paddle::Tensor& radii, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + const bool normalize_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); + +std::vector MultiRadiusSearch( + paddle::Tensor& points, + paddle::Tensor& queries, + paddle::Tensor& radii, + paddle::Tensor& points_row_splits, + paddle::Tensor& queries_row_splits, + const std::string& index_dtype, + const std::string& metric_str, + const bool ignore_query_point, + const bool return_distances, + const bool normalize_distances) { + Metric metric = L2; + if (metric_str == "L1") { + metric = L1; + } else if (metric_str == "L2") { + metric = L2; + } else { + PD_CHECK(false, "metric must be one of (L1, L2) but got " + metric_str); + } + CHECK_TYPE(points_row_splits, paddle::DataType::INT64); + CHECK_TYPE(queries_row_splits, paddle::DataType::INT64); + CHECK_SAME_DTYPE(points, queries, radii); + CHECK_SAME_DEVICE_TYPE(points, queries, radii); + PD_CHECK(index_dtype == "int32" || index_dtype == "int64", + "index_dtype must be int32 or int64"); + // ensure that these are on the cpu + points_row_splits = points_row_splits.copy_to(paddle::CPUPlace(), false); + queries_row_splits = queries_row_splits.copy_to(paddle::CPUPlace(), false); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim num_queries("num_queries"); + Dim batch_size("batch_size"); + Dim num_cells("num_cells"); + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(queries, num_queries, 3); + CHECK_SHAPE(radii, num_queries); + CHECK_SHAPE(points_row_splits, batch_size + 1); + CHECK_SHAPE(queries_row_splits, batch_size + 1); + + const auto& point_type = points.dtype(); + + auto place = points.place(); + + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_row_splits = + paddle::empty({queries.shape()[0] + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_distance; + +#define FN_PARAMETERS \ + points, queries, radii, points_row_splits, queries_row_splits, metric, \ + ignore_query_point, return_distances, normalize_distances, \ + neighbors_index, neighbors_row_splits, neighbors_distance + + if (points.is_gpu()) { + PD_CHECK(false, "MultiRadiusSearch does not support CUDA"); + } else { + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + RadiusSearchCPU(FN_PARAMETERS); + } else { + RadiusSearchCPU(FN_PARAMETERS); + } + } else { + if (index_dtype == "int32") { + RadiusSearchCPU(FN_PARAMETERS); + } else { + RadiusSearchCPU(FN_PARAMETERS); + } + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } + // same question of fixed_radius_search + PD_CHECK(false, "MultiRadiusSearch does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for points"); + return {neighbors_index, neighbors_row_splits, neighbors_distance}; +} + +std::vector MultiRadiusSearchInferDtype( + const std::string& index_dtype) { + paddle::DataType dtype = index_dtype == "int32" ? paddle::DataType::INT32 + : paddle::DataType::INT64; + return {dtype, paddle::DataType::INT64, dtype}; +} + +std::vector> MultiRadiusSearchInferShape( + std::vector queries_shape, const bool return_distances) { + // this just a temp impl , all return is fake data + // TODO(woodman3): impl real data + int64_t neighbors_row_splits_shape = queries_shape[0] + 1; + int64_t neighbors_distance_shape = return_distances ? 1 : 0; + return {{neighbors_row_splits_shape}, + {neighbors_row_splits_shape}, + {neighbors_distance_shape}}; +} + +PD_BUILD_OP(open3d_radius_search) + .Inputs({"points", "queries", "radii", "points_row_splits", + "queries_row_splits"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_distance"}) + .Attrs({"index_dtype: std::string", "metric_str: std::string", + "ignore_query_point: bool", "return_distances: bool", + "normalize_distances: bool"}) + .SetKernelFn(PD_KERNEL(MultiRadiusSearch)) + .SetInferShapeFn(PD_INFER_SHAPE(MultiRadiusSearchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MultiRadiusSearchInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cpp b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cpp new file mode 100644 index 00000000000..64fec24034a --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cpp @@ -0,0 +1,45 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/RaggedToDenseOpKernel.h" + +#include "open3d/ml/impl/misc/RaggedToDense.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +paddle::Tensor RaggedToDenseCPU(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value) { + auto out_shape = values.shape(); + out_shape.erase(out_shape.begin()); + out_shape.insert(out_shape.begin(), + {row_splits.shape()[0] - 1, out_col_size}); + paddle::Tensor out = + paddle::empty(out_shape, paddle::DataType(ToPaddleDtype())); + + open3d::ml::impl::RaggedToDenseCPU( + values.data(), row_splits.data(), row_splits.shape()[0], + out_col_size, default_value.data(), default_value.numel(), + out.data()); + + return out; +} + +#define INSTANTIATE(T) \ + template paddle::Tensor RaggedToDenseCPU( \ + const paddle::Tensor&, const paddle::Tensor&, const int64_t, \ + const paddle::Tensor&); + +INSTANTIATE(uint8_t) +INSTANTIATE(int8_t) +INSTANTIATE(int16_t) +INSTANTIATE(int32_t) +INSTANTIATE(int64_t) +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cu b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cu new file mode 100644 index 00000000000..ae3677bf035 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cu @@ -0,0 +1,48 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/impl/misc/RaggedToDense.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/RaggedToDenseOpKernel.h" +#include "paddle/extension.h" + +template +paddle::Tensor RaggedToDenseCUDA(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value) { + auto out_shape = values.shape(); + out_shape.erase(out_shape.begin()); + out_shape.insert(out_shape.begin(), + {row_splits.shape()[0] - 1, out_col_size}); + auto place = values.place(); + paddle::Tensor out = paddle::empty( + out_shape, paddle::DataType(ToPaddleDtype()), place); + + auto stream = values.stream(); + + open3d::ml::impl::RaggedToDenseCUDA( + stream, values.data(), row_splits.data(), + row_splits.shape()[0], out_col_size, default_value.data(), + default_value.numel(), out.data()); + + return out; +} + +#define INSTANTIATE(T) \ + template paddle::Tensor RaggedToDenseCUDA( \ + const paddle::Tensor&, const paddle::Tensor&, const int64_t, \ + const paddle::Tensor&); + +INSTANTIATE(uint8_t) +INSTANTIATE(int8_t) +INSTANTIATE(int16_t) +INSTANTIATE(int32_t) +INSTANTIATE(int64_t) +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.h b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.h new file mode 100644 index 00000000000..1834c710979 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.h @@ -0,0 +1,24 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include "paddle/extension.h" + +template +paddle::Tensor RaggedToDenseCPU(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value); + +#ifdef BUILD_CUDA_MODULE +template +paddle::Tensor RaggedToDenseCUDA(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value); +#endif diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOps.cpp b/cpp/open3d/ml/paddle/misc/RaggedToDenseOps.cpp new file mode 100644 index 00000000000..60ea9885dd2 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOps.cpp @@ -0,0 +1,111 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/RaggedToDenseOpKernel.h" +#include "paddle/extension.h" + +std::vector RaggedToDense(paddle::Tensor& values, + paddle::Tensor& row_splits, + paddle::Tensor& default_value, + const int64_t out_col_size) { + CHECK_TYPE(row_splits, phi::DataType::INT64); + CHECK_SAME_DTYPE(values, default_value); + + // check input shapes + { + using namespace open3d::ml::op_util; + Dim num_rows("num_rows"); + CHECK_SHAPE(row_splits, num_rows + 1); + if (default_value.shape().size()) { + Dim item_size("item_size"); + CHECK_SHAPE_COMBINE_LAST_DIMS(default_value, item_size); + CHECK_SHAPE_COMBINE_LAST_DIMS(values, Dim(), item_size); + auto value_shape = values.shape(); + + // check shape tail + std::vector item_shape(value_shape.begin() + 1, + value_shape.end()); + auto default_value_shape = default_value.shape(); + PD_CHECK(default_value_shape == item_shape, + "default_value " + + phi::DataTypeToString(default_value.dtype()) + + "has incompatible with the shape of items in " + "values" + + TensorInfoStr({values})); + } else // scalar default_value + { + Dim num_values("num_values"); + CHECK_SHAPE_COMBINE_LAST_DIMS(values, num_values); + } + } + + // make sure everything is on the same place as 'values' + auto place = values.place(); + row_splits = row_splits.copy_to(place, false); + default_value = default_value.copy_to(place, false); + + const auto& value_type = values.dtype(); + +#define CALL(value_t, fn) \ + if (ComparePaddleDtype(value_type)) { \ + return {fn(values, row_splits, out_col_size, default_value)}; \ + } + + if (values.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(uint8_t, RaggedToDenseCUDA) + CALL(int8_t, RaggedToDenseCUDA) + CALL(int16_t, RaggedToDenseCUDA) + CALL(int32_t, RaggedToDenseCUDA) + CALL(int64_t, RaggedToDenseCUDA) + CALL(float, RaggedToDenseCUDA) + CALL(double, RaggedToDenseCUDA) +#else + PD_CHECK(false, "RaggedToDense was not compiled with CUDA support"); +#endif + } else { + CALL(uint8_t, RaggedToDenseCPU) + CALL(int8_t, RaggedToDenseCPU) + CALL(int16_t, RaggedToDenseCPU) + CALL(int32_t, RaggedToDenseCPU) + CALL(int64_t, RaggedToDenseCPU) + CALL(float, RaggedToDenseCPU) + CALL(double, RaggedToDenseCPU) + } + PD_CHECK(false, "RaggedToDense does not support " + + phi::DataTypeToString(values.dtype()) + + " as input for values"); +} + +std::vector RaggedToDenseInferDtype( + const paddle::DataType values_dtype) { + return {values_dtype}; +} + +std::vector> RaggedToDenseInferShape( + std::vector values_shape, + std::vector row_splits_shape, + const int64_t out_col_size) { + auto out_shape = values_shape; + out_shape.erase(out_shape.begin()); + out_shape.insert(out_shape.begin(), + {row_splits_shape[0] - 1, out_col_size}); + return {out_shape}; +} + +PD_BUILD_OP(open3d_ragged_to_dense) + .Inputs({"values", "row_splits", "default_value"}) + .Attrs({"out_col_size: int64_t"}) + .Outputs({"out"}) + .SetKernelFn(PD_KERNEL(RaggedToDense)) + .SetInferShapeFn(PD_INFER_SHAPE(RaggedToDenseInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(RaggedToDenseInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.cpp b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.cpp new file mode 100644 index 00000000000..8e74c82bceb --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.cpp @@ -0,0 +1,33 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.h" + +#include "open3d/ml/impl/misc/ReduceSubarraysSum.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "paddle/extension.h" + +template +paddle::Tensor ReduceSubarraysSumCPU(const paddle::Tensor& values, + const paddle::Tensor& row_splits) { + paddle::Tensor sums = paddle::empty({row_splits.shape()[0] - 1}, + paddle::DataType(ToPaddleDtype())); + + open3d::ml::impl::ReduceSubarraysSumCPU( + values.data(), values.shape()[0], row_splits.data(), + row_splits.shape()[0] - 1, sums.data()); + return sums; +} +#define INSTANTIATE(T) \ + template paddle::Tensor ReduceSubarraysSumCPU(const paddle::Tensor&, \ + const paddle::Tensor&); + +INSTANTIATE(int32_t) +INSTANTIATE(int64_t) +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.cu b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.cu new file mode 100644 index 00000000000..9acedbf04f8 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.cu @@ -0,0 +1,36 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/impl/misc/ReduceSubarraysSum.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.h" +#include "paddle/extension.h" + +template +paddle::Tensor ReduceSubarraysSumCUDA(const paddle::Tensor& values, + const paddle::Tensor& row_splits) { + auto place = values.place(); + paddle::Tensor sums = + paddle::empty({row_splits.shape()[0] - 1}, + paddle::DataType(ToPaddleDtype()), place); + + auto stream = values.stream(); + open3d::ml::impl::ReduceSubarraysSumCUDA( + stream, values.data(), values.shape()[0], + row_splits.data(), row_splits.shape()[0] - 1, + sums.data()); + return sums; +} +#define INSTANTIATE(T) \ + template paddle::Tensor ReduceSubarraysSumCUDA(const paddle::Tensor&, \ + const paddle::Tensor&); + +INSTANTIATE(int32_t) +INSTANTIATE(int64_t) +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.h b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.h new file mode 100644 index 00000000000..a66b260a6b5 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.h @@ -0,0 +1,20 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include "paddle/extension.h" + +template +paddle::Tensor ReduceSubarraysSumCPU(const paddle::Tensor& values, + const paddle::Tensor& row_splits); + +#ifdef BUILD_CUDA_MODULE +template +paddle::Tensor ReduceSubarraysSumCUDA(const paddle::Tensor& values, + const paddle::Tensor& row_splits); +#endif diff --git a/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOps.cpp b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOps.cpp new file mode 100644 index 00000000000..c55db9217ba --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOps.cpp @@ -0,0 +1,71 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOps.h" + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOpKernel.h" +#include "paddle/extension.h" + +std::vector ReduceSubarraysSum(paddle::Tensor& values, + paddle::Tensor& row_splits) { + CHECK_TYPE(row_splits, phi::DataType::INT64); + + const auto& attr_type = values.dtype(); + + // special treatment for empty values vector + if (values.shape()[0] == 0) { + return {InitializedEmptyTensor(values.dtype(), values.shape(), + values.place())}; + } + +#define CALL(attr_t, fn) \ + if (ComparePaddleDtype(attr_type)) { \ + return {fn(values, row_splits)}; \ + } + + CHECK_SAME_DEVICE_TYPE(values, row_splits); + + if (values.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(int32_t, ReduceSubarraysSumCUDA) + CALL(int64_t, ReduceSubarraysSumCUDA) + CALL(float, ReduceSubarraysSumCUDA) + CALL(double, ReduceSubarraysSumCUDA) +#else + PD_CHECK(false, + "ReduceSubarraysSum was not compiled with CUDA support"); +#endif + } else { + CALL(int32_t, ReduceSubarraysSumCPU) + CALL(int64_t, ReduceSubarraysSumCPU) + CALL(float, ReduceSubarraysSumCPU) + CALL(double, ReduceSubarraysSumCPU) + } + return {paddle::Tensor()}; +} + +std::vector ReduceSubarraysSumInferDtype( + const paddle::DataType values_dtype) { + return {values_dtype}; +} + +std::vector> ReduceSubarraysSumInferShape( + std::vector values_shape) { + return {values_shape}; +} + +PD_BUILD_OP(open3d_reduce_subarrays_sum) + .Inputs({"values", "row_splits"}) + .Outputs({"sums"}) + .SetKernelFn(PD_KERNEL(ReduceSubarraysSum)) + .SetInferShapeFn(PD_INFER_SHAPE(ReduceSubarraysSumInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ReduceSubarraysSumInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOps.h b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOps.h new file mode 100644 index 00000000000..a3336a1011c --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/ReduceSubarraysSumOps.h @@ -0,0 +1,13 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include "paddle/extension.h" + +std::vector ReduceSubarraysSum(paddle::Tensor& values, + paddle::Tensor& row_splits); diff --git a/cpp/open3d/ml/paddle/misc/RoiPoolOps.cpp b/cpp/open3d/ml/paddle/misc/RoiPoolOps.cpp new file mode 100644 index 00000000000..52550054512 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RoiPoolOps.cpp @@ -0,0 +1,96 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on PointRCNN Library (MIT License): +// https://github.com/sshaoshuai/PointRCNN +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include "open3d/ml/contrib/RoiPoolKernel.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +#ifdef BUILD_CUDA_MODULE + +std::vector RoiPool(paddle::Tensor &xyz, + paddle::Tensor &boxes3d, + paddle::Tensor &pts_feature, + const int64_t sampled_pts_num) { + int batch_size = xyz.shape()[0]; + int pts_num = xyz.shape()[1]; + int boxes_num = boxes3d.shape()[1]; + int feature_in_len = pts_feature.shape()[2]; + + auto place = xyz.place(); + paddle::Tensor features = paddle::full( + {batch_size, boxes_num, sampled_pts_num, 3 + feature_in_len}, 0.0f, + paddle::DataType(ToPaddleDtype()), place); + + paddle::Tensor empty_flag = + paddle::full({batch_size, boxes_num}, 0.0f, + paddle::DataType(ToPaddleDtype()), place); + + const float *xyz_data = xyz.data(); + const float *boxes3d_data = boxes3d.data(); + const float *pts_feature_data = pts_feature.data(); + float *pooled_features_data = features.data(); + int *pooled_empty_flag_data = empty_flag.data(); + + open3d::ml::contrib::roipool3dLauncher( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + xyz_data, boxes3d_data, pts_feature_data, pooled_features_data, + pooled_empty_flag_data); + + return {features, empty_flag}; +} + +std::vector RoiPoolInferDtype() { + return {paddle::DataType::FLOAT32, paddle::DataType::INT32}; +} + +std::vector> RoiPoolInferShape( + std::vector xyz_shape, + std::vector boxes3d_shape, + std::vector pts_feature_shape, + const int64_t sampled_pts_num) { + std::vector features_shape{xyz_shape[0], boxes3d_shape[1], + sampled_pts_num, + 3 + pts_feature_shape[2]}; + return {features_shape, {xyz_shape[0], boxes3d_shape[1]}}; +} + +PD_BUILD_OP(open3d_roi_pool) + .Inputs({"xyz", "boxes3d", "pts_feature"}) + .Outputs({"features", "empty_flag"}) + .Attrs({ + "sampled_pts_num: int64_t", + }) + .SetKernelFn(PD_KERNEL(RoiPool)) + .SetInferShapeFn(PD_INFER_SHAPE(RoiPoolInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(RoiPoolInferDtype)); + +#endif diff --git a/cpp/open3d/ml/paddle/misc/VoxelPoolingOpKernel.cpp b/cpp/open3d/ml/paddle/misc/VoxelPoolingOpKernel.cpp new file mode 100644 index 00000000000..a70e05281d7 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/VoxelPoolingOpKernel.cpp @@ -0,0 +1,118 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/impl/misc/VoxelPooling.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +namespace { + +template +class OutputAllocator { +public: + explicit OutputAllocator(paddle::Place place) : place(place) {} + + void AllocPooledPositions(TReal** ptr, size_t num) { + if (num != 0) { + positions = paddle::empty({int64_t(num), 3}, ToPaddleDtype(), + place); + } else { + positions = InitializedEmptyTensor(ToPaddleDtype(), {0, 3}, + place); + } + *ptr = positions.data(); + } + + void AllocPooledFeatures(TFeat** ptr, size_t num, size_t channels) { + if (num != 0) { + features = paddle::empty({int64_t(num), int64_t(channels)}, + ToPaddleDtype(), place); + } else { + features = InitializedEmptyTensor(ToPaddleDtype(), + {0, int64_t(channels)}, place); + } + *ptr = features.data(); + } + + const paddle::Tensor& PooledPositions() const { return positions; } + const paddle::Tensor& PooledFeatures() const { return features; } + +private: + paddle::Tensor positions; + paddle::Tensor features; + paddle::Place place; +}; + +} // namespace + +template +std::vector VoxelPoolingCPU(const paddle::Tensor& positions, + const paddle::Tensor& features, + const double voxel_size, + const AccumulationFn position_fn, + const AccumulationFn feature_fn, + const bool debug) { + OutputAllocator output_allocator(positions.place()); + + if (debug) { + std::string err; + PD_CHECK(CheckVoxelSize(err, positions.shape()[0], + positions.data(), TReal(voxel_size)), + err); + } + + VoxelPooling(positions.shape()[0], positions.data(), + features.shape()[1], features.data(), + voxel_size, output_allocator, position_fn, + feature_fn); + + return {output_allocator.PooledPositions(), + output_allocator.PooledFeatures()}; +} +#define INSTANTIATE(TReal, TFeat) \ + template std::vector VoxelPoolingCPU( \ + const paddle::Tensor&, const paddle::Tensor&, const double, \ + const AccumulationFn, const AccumulationFn, const bool); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) +INSTANTIATE(float, float) +INSTANTIATE(float, double) +INSTANTIATE(double, int32_t) +INSTANTIATE(double, int64_t) +INSTANTIATE(double, float) +INSTANTIATE(double, double) +#undef INSTANTIATE + +template +void VoxelPoolingGradCPU(paddle::Tensor& features_backprop, + const paddle::Tensor& positions, + const paddle::Tensor& features, + const paddle::Tensor& pooled_positions, + const paddle::Tensor& pooled_features_gradient, + const double voxel_size, + const AccumulationFn position_fn, + const AccumulationFn feature_fn) { + VoxelPoolingBackprop( + features_backprop.data(), positions.shape()[0], + positions.data(), features.shape()[1], + features.data(), pooled_positions.shape()[0], + pooled_positions.data(), + pooled_features_gradient.data(), TReal(voxel_size), + position_fn, feature_fn); +} +#define INSTANTIATE(TReal, TFeat) \ + template void VoxelPoolingGradCPU( \ + paddle::Tensor&, const paddle::Tensor&, const paddle::Tensor&, \ + const paddle::Tensor&, const paddle::Tensor&, const double, \ + const AccumulationFn, const AccumulationFn); +INSTANTIATE(float, float) +INSTANTIATE(float, double) +INSTANTIATE(double, float) +INSTANTIATE(double, double) diff --git a/cpp/open3d/ml/paddle/misc/VoxelPoolingOps.cpp b/cpp/open3d/ml/paddle/misc/VoxelPoolingOps.cpp new file mode 100644 index 00000000000..d2c4f8c127f --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/VoxelPoolingOps.cpp @@ -0,0 +1,211 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/misc/VoxelPooling.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +std::vector VoxelPoolingCPU(const paddle::Tensor& positions, + const paddle::Tensor& features, + const double voxel_size, + const AccumulationFn position_fn, + const AccumulationFn feature_fn, + const bool debug); + +template +void VoxelPoolingGradCPU(paddle::Tensor& features_backprop, + const paddle::Tensor& positions, + const paddle::Tensor& features, + const paddle::Tensor& pooled_positions, + const paddle::Tensor& pooled_features_gradient, + const double voxel_size, + const AccumulationFn position_fn, + const AccumulationFn feature_fn); + +std::vector VoxelPoolingForward( + paddle::Tensor& positions, + paddle::Tensor& features, + const double voxel_size, + const std::string& position_fn_str, + const std::string& feature_fn_str, + const bool debug) { + AccumulationFn position_fn = AVERAGE; + if (position_fn_str == "average") { + position_fn = AVERAGE; + } else if (position_fn_str == "nearest_neighbor") { + position_fn = NEAREST_NEIGHBOR; + } else if (position_fn_str == "center") { + position_fn = CENTER; + } else { + PD_CHECK(false, + "position_fn must be one of ('average', " + "'nearest_neighbor', 'center') but got " + + position_fn_str); + } + AccumulationFn feature_fn = AVERAGE; + if (feature_fn_str == "average") { + feature_fn = AVERAGE; + } else if (feature_fn_str == "nearest_neighbor") { + feature_fn = NEAREST_NEIGHBOR; + } else if (feature_fn_str == "max") { + feature_fn = MAX; + } else { + PD_CHECK(false, + "feature_fn must be one of ('average', " + "'nearest_neighbor', 'max') but got " + + feature_fn_str); + } + + // check input shapes + { + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim num_channels("num_channels"); + + CHECK_SHAPE(positions, num_points, 3); + CHECK_SHAPE_COMBINE_LAST_DIMS(features, num_points, num_channels); + } + + // ctx->saved_data["position_fn_str"] = position_fn_str; + // ctx->saved_data["feature_fn_str"] = feature_fn_str; + // ctx->saved_data["voxel_size"] = voxel_size; + + const auto& positions_type = positions.dtype(); + const auto& features_type = features.dtype(); + +#define FN_PARAMETERS \ + positions, features, voxel_size, position_fn, feature_fn, debug + +#define CALL(real_t, feat_t, fn) \ + if (ComparePaddleDtype(positions_type) && \ + ComparePaddleDtype(features_type)) { \ + return fn(FN_PARAMETERS); \ + } + + CHECK_SAME_DEVICE_TYPE(positions, features); + if (positions.is_gpu()) { + PD_CHECK(false, "VoxelPooling does not support CUDA"); + } else { + CALL(float, float, VoxelPoolingCPU) + CALL(float, int32_t, VoxelPoolingCPU) + CALL(float, int64_t, VoxelPoolingCPU) + CALL(float, double, VoxelPoolingCPU) + CALL(double, float, VoxelPoolingCPU) + CALL(double, int32_t, VoxelPoolingCPU) + CALL(double, int64_t, VoxelPoolingCPU) + CALL(double, double, VoxelPoolingCPU) + } +#undef FN_PARAMETERS +#undef CALL + + PD_CHECK(false, "VoxelPooling does not support " + + phi::DataTypeToString(positions.dtype()) + + " as input for positions and " + + phi::DataTypeToString(features.dtype()) + + " as input for features"); + return {paddle::Tensor(), paddle::Tensor()}; +} + +std::vector VoxelPoolingBackward( + paddle::Tensor& positions, + paddle::Tensor& features, + paddle::Tensor& pooled_positions, + paddle::Tensor& pooled_features_gradient, + const double voxel_size, + const std::string& position_fn_str, + const std::string& feature_fn_str) { + AccumulationFn position_fn = AVERAGE; + if (position_fn_str == "average") { + position_fn = AVERAGE; + } else if (position_fn_str == "nearest_neighbor") { + position_fn = NEAREST_NEIGHBOR; + } else if (position_fn_str == "center") { + position_fn = CENTER; + } else { + PD_CHECK(false, + "position_fn must be one of ('average', " + "'nearest_neighbor', 'center') but got " + + position_fn_str); + } + AccumulationFn feature_fn = AVERAGE; + if (feature_fn_str == "average") { + feature_fn = AVERAGE; + } else if (feature_fn_str == "nearest_neighbor") { + feature_fn = NEAREST_NEIGHBOR; + } else if (feature_fn_str == "max") { + feature_fn = MAX; + } else { + PD_CHECK(false, + "feature_fn must be one of ('average', " + "'nearest_neighbor', 'max') but got " + + feature_fn_str); + } + + // auto pooled_positions = saved_vars[2]; + + paddle::Tensor features_backprop = + paddle::empty(features.shape(), features.dtype()); + + const auto& positions_type = positions.dtype(); + const auto& features_type = features.dtype(); + +#define FN_PARAMETERS \ + features_backprop, positions, features, pooled_positions, \ + pooled_features_gradient, voxel_size, position_fn, feature_fn + +#define CALL(real_t, feat_t, fn) \ + if (ComparePaddleDtype(positions_type) && \ + ComparePaddleDtype(features_type)) { \ + fn(FN_PARAMETERS); \ + return {features_backprop}; \ + } + + CHECK_SAME_DEVICE_TYPE(positions, features); + if (positions.is_gpu()) { + PD_CHECK(false, "VoxelPooling backward does not support CUDA"); + } else { + CALL(float, float, VoxelPoolingGradCPU) + CALL(float, double, VoxelPoolingGradCPU) + CALL(double, float, VoxelPoolingGradCPU) + CALL(double, double, VoxelPoolingGradCPU) + PD_CHECK(false, "VoxelPooling backward does not support " + + phi::DataTypeToString(positions.dtype()) + + " as input for positions and " + + phi::DataTypeToString(features.dtype()) + + " as input for features"); + } +#undef FN_PARAMETERS +#undef CALL + + return {}; +} + +std::vector VoxelPoolingInferDtype( + paddle::DataType positions_dtype, paddle::DataType features_dtype) { + return {positions_dtype, features_dtype}; +} + +PD_BUILD_OP(open3d_voxel_pooling) + .Inputs({"positions", "features"}) + .Outputs({"pooled_positions", "pooled_features"}) + .Attrs({"voxel_size:double", "position_fn:std::string", + "feature_fn:std::string", "debug:bool"}) + .SetKernelFn(PD_KERNEL(VoxelPoolingForward)) + .SetInferDtypeFn(PD_INFER_DTYPE(VoxelPoolingInferDtype)); + +PD_BUILD_GRAD_OP(open3d_voxel_pooling) + .Inputs({"positions", "features", "pooled_positions", + paddle::Grad("pooled_features")}) + .Outputs({paddle::Grad("features")}) + .Attrs({"voxel_size:double", "position_fn:std::string", + "feature_fn:std::string"}) + .SetKernelFn(PD_KERNEL(VoxelPoolingBackward)); diff --git a/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.cpp b/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.cpp new file mode 100644 index 00000000000..67f349bee1a --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.cpp @@ -0,0 +1,73 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/VoxelizeOpKernel.h" + +#include "open3d/ml/impl/misc/Voxelize.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "paddle/extension.h" + +using namespace open3d::ml::impl; + +template +void VoxelizeCPU(const paddle::Tensor& points, + const paddle::Tensor& row_splits, + const paddle::Tensor& voxel_size, + const paddle::Tensor& points_range_min, + const paddle::Tensor& points_range_max, + const int64_t max_points_per_voxel, + const int64_t max_voxels, + paddle::Tensor& voxel_coords, + paddle::Tensor& voxel_point_indices, + paddle::Tensor& voxel_point_row_splits, + paddle::Tensor& voxel_batch_splits) { + VoxelizeOutputAllocator output_allocator(points.place()); + + switch (points.shape()[1]) { +#define CASE(NDIM) \ + case NDIM: \ + VoxelizeCPU(points.shape()[0], points.data(), \ + row_splits.shape()[0] - 1, \ + row_splits.data(), voxel_size.data(), \ + points_range_min.data(), \ + points_range_max.data(), max_points_per_voxel, \ + max_voxels, output_allocator); \ + break; + CASE(1) + CASE(2) + CASE(3) + CASE(4) + CASE(5) + CASE(6) + CASE(7) + CASE(8) + default: + break; // will be handled by the generic paddle function + +#undef CASE + } + + voxel_coords = output_allocator.VoxelCoords(); + voxel_point_indices = output_allocator.VoxelPointIndices(); + voxel_point_row_splits = output_allocator.VoxelPointRowSplits(); + voxel_batch_splits = output_allocator.VoxelBatchSplits(); +} + +#define INSTANTIATE(T) \ + template void VoxelizeCPU( \ + const paddle::Tensor& points, const paddle::Tensor& row_splits, \ + const paddle::Tensor& voxel_size, \ + const paddle::Tensor& points_range_min, \ + const paddle::Tensor& points_range_max, \ + const int64_t max_points_per_voxel, const int64_t max_voxels, \ + paddle::Tensor& voxel_coords, paddle::Tensor& voxel_point_indices, \ + paddle::Tensor& voxel_point_row_splits, \ + paddle::Tensor& voxel_batch_splits); + +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.cu b/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.cu new file mode 100644 index 00000000000..712d78ada1b --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.cu @@ -0,0 +1,91 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/impl/misc/Voxelize.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/VoxelizeOpKernel.h" +#include "paddle/extension.h" + +using namespace open3d::ml::impl; + +template +void VoxelizeCUDA(const paddle::Tensor& points, + const paddle::Tensor& row_splits, + const paddle::Tensor& voxel_size, + const paddle::Tensor& points_range_min, + const paddle::Tensor& points_range_max, + const int64_t max_points_per_voxel, + const int64_t max_voxels, + paddle::Tensor& voxel_coords, + paddle::Tensor& voxel_point_indices, + paddle::Tensor& voxel_point_row_splits, + paddle::Tensor& voxel_batch_splits) { + auto stream = points.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + VoxelizeOutputAllocator output_allocator(points.place()); + + switch (points.shape()[1]) { +#define CASE(NDIM) \ + case NDIM: { \ + void* temp_ptr = nullptr; \ + size_t temp_size = 0; \ + VoxelizeCUDA( \ + stream, temp_ptr, temp_size, texture_alignment, \ + points.shape()[0], points.data(), \ + row_splits.shape()[0] - 1, row_splits.data(), \ + voxel_size.data(), points_range_min.data(), \ + points_range_max.data(), max_points_per_voxel, max_voxels, \ + output_allocator); \ + \ + auto temp_tensor = \ + CreateTempTensor(temp_size, points.place(), &temp_ptr); \ + \ + VoxelizeCUDA( \ + stream, temp_ptr, temp_size, texture_alignment, \ + points.shape()[0], points.data(), \ + row_splits.shape()[0] - 1, row_splits.data(), \ + voxel_size.data(), points_range_min.data(), \ + points_range_max.data(), max_points_per_voxel, max_voxels, \ + output_allocator); \ + } break; + CASE(1) + CASE(2) + CASE(3) + CASE(4) + CASE(5) + CASE(6) + CASE(7) + CASE(8) + default: + break; // will be handled by the generic paddle function + +#undef CASE + } + + voxel_coords = output_allocator.VoxelCoords(); + voxel_point_indices = output_allocator.VoxelPointIndices(); + voxel_point_row_splits = output_allocator.VoxelPointRowSplits(); + voxel_batch_splits = output_allocator.VoxelBatchSplits(); +} + +#define INSTANTIATE(T) \ + template void VoxelizeCUDA( \ + const paddle::Tensor& points, const paddle::Tensor& row_splits, \ + const paddle::Tensor& voxel_size, \ + const paddle::Tensor& points_range_min, \ + const paddle::Tensor& points_range_max, \ + const int64_t max_points_per_voxel, const int64_t max_voxels, \ + paddle::Tensor& voxel_coords, paddle::Tensor& voxel_point_indices, \ + paddle::Tensor& voxel_point_row_splits, \ + paddle::Tensor& voxel_batch_splits); + +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.h b/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.h new file mode 100644 index 00000000000..f9f4e3a0aa7 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/VoxelizeOpKernel.h @@ -0,0 +1,98 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "paddle/extension.h" + +template +void VoxelizeCPU(const paddle::Tensor& points, + const paddle::Tensor& row_splits, + const paddle::Tensor& voxel_size, + const paddle::Tensor& points_range_min, + const paddle::Tensor& points_range_max, + const int64_t max_points_per_voxel, + const int64_t max_voxels, + paddle::Tensor& voxel_coords, + paddle::Tensor& voxel_point_indices, + paddle::Tensor& voxel_point_row_splits, + paddle::Tensor& voxel_batch_splits); + +#ifdef BUILD_CUDA_MODULE +template +void VoxelizeCUDA(const paddle::Tensor& points, + const paddle::Tensor& row_splits, + const paddle::Tensor& voxel_size, + const paddle::Tensor& points_range_min, + const paddle::Tensor& points_range_max, + const int64_t max_points_per_voxel, + const int64_t max_voxels, + paddle::Tensor& voxel_coords, + paddle::Tensor& voxel_point_indices, + paddle::Tensor& voxel_point_row_splits, + paddle::Tensor& voxel_batch_splits); +#endif + +class VoxelizeOutputAllocator { +public: + VoxelizeOutputAllocator(paddle::Place place) : place(place) {} + + void AllocVoxelCoords(int32_t** ptr, int64_t rows, int64_t cols) { + if (rows * cols == 0) { + voxel_coords = InitializedEmptyTensor({rows, cols}, place); + } else { + voxel_coords = paddle::empty( + {rows, cols}, paddle::DataType(ToPaddleDtype()), + place); + } + *ptr = voxel_coords.data(); + } + + void AllocVoxelPointIndices(int64_t** ptr, int64_t num) { + if (num == 0) { + voxel_point_indices = InitializedEmptyTensor({num}, place); + } else { + voxel_point_indices = paddle::empty( + {num}, paddle::DataType(ToPaddleDtype()), place); + } + *ptr = voxel_point_indices.data(); + } + + void AllocVoxelPointRowSplits(int64_t** ptr, int64_t num) { + voxel_point_row_splits = paddle::empty( + {num}, paddle::DataType(ToPaddleDtype()), place); + *ptr = voxel_point_row_splits.data(); + } + + void AllocVoxelBatchSplits(int64_t** ptr, int64_t num) { + voxel_batch_splits = paddle::empty( + {num}, paddle::DataType(ToPaddleDtype()), place); + *ptr = voxel_batch_splits.data(); + } + + const paddle::Tensor& VoxelCoords() const { return voxel_coords; } + const paddle::Tensor& VoxelPointIndices() const { + return voxel_point_indices; + } + const paddle::Tensor& VoxelPointRowSplits() const { + return voxel_point_row_splits; + } + const paddle::Tensor& VoxelBatchSplits() const { + return voxel_batch_splits; + } + +private: + paddle::Tensor voxel_coords; + paddle::Tensor voxel_point_indices; + paddle::Tensor voxel_point_row_splits; + paddle::Tensor voxel_batch_splits; + paddle::Place place; +}; diff --git a/cpp/open3d/ml/paddle/misc/VoxelizeOps.cpp b/cpp/open3d/ml/paddle/misc/VoxelizeOps.cpp new file mode 100644 index 00000000000..100d829be42 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/VoxelizeOps.cpp @@ -0,0 +1,96 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/VoxelizeOpKernel.h" +#include "paddle/extension.h" + +std::vector Voxelize(paddle::Tensor& points, + paddle::Tensor& row_splits, + paddle::Tensor& voxel_size, + paddle::Tensor& points_range_min, + paddle::Tensor& points_range_max, + const int64_t max_points_per_voxel, + const int64_t max_voxels) { + CHECK_TYPE(row_splits, phi::DataType::INT64); + + auto cpu_place = paddle::CPUPlace(); + // make sure that these tensors are on the cpu + voxel_size = voxel_size.copy_to(cpu_place, false); + points_range_min = points_range_min.copy_to(cpu_place, false); + points_range_max = points_range_max.copy_to(cpu_place, false); + + CHECK_SAME_DTYPE(points, voxel_size, points_range_min, points_range_max); + + // check input shapes + { + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim ndim("ndim"); + CHECK_SHAPE(points, num_points, ndim); + CHECK_SHAPE(voxel_size, ndim); + CHECK_SHAPE(points_range_min, ndim); + CHECK_SHAPE(points_range_max, ndim); + PD_CHECK(0 < ndim.value() && ndim.value() < 9, + "the number of dimensions must be in [1,..,8]"); + } + + const auto& points_dtype = points.dtype(); + + // output tensors + paddle::Tensor voxel_coords, voxel_point_indices, voxel_point_row_splits, + voxel_batch_splits; + +#define CALL(point_t, fn) \ + if (ComparePaddleDtype(points_dtype)) { \ + fn(points, row_splits, voxel_size, points_range_min, \ + points_range_max, max_points_per_voxel, max_voxels, \ + voxel_coords, voxel_point_indices, voxel_point_row_splits, \ + voxel_batch_splits); \ + return {voxel_coords, voxel_point_indices, voxel_point_row_splits, \ + voxel_batch_splits}; \ + } + + if (points.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(float, VoxelizeCUDA) + CALL(double, VoxelizeCUDA) +#else + PD_CHECK(false, "Voxelize was not compiled with CUDA support"); +#endif + } else { + CALL(float, VoxelizeCPU) + CALL(double, VoxelizeCPU) + } + + PD_CHECK(false, "Voxelize does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for values"); + + return std::vector(); +} + +std::vector VoxelizeInferDtype() { + return {paddle::DataType::INT32, paddle::DataType::INT64, + paddle::DataType::INT64, paddle::DataType::INT64}; +} + +PD_BUILD_OP(open3d_voxelize) + .Inputs({"points", "row_splits", "voxel_size", "points_range_min", + "points_range_max"}) + .Outputs({"voxel_coords", "voxel_point_indices", + "voxel_point_row_splits", "voxel_batch_splits"}) + .Attrs({ + "max_points_per_voxel: int64_t", + "max_voxels: int64_t", + }) + .SetKernelFn(PD_KERNEL(Voxelize)) + .SetInferDtypeFn(PD_INFER_DTYPE(VoxelizeInferDtype)); \ No newline at end of file diff --git a/cpp/open3d/ml/paddle/pointnet/BallQueryKernel.cu b/cpp/open3d/ml/paddle/pointnet/BallQueryKernel.cu new file mode 100644 index 00000000000..604b4d1bac3 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/BallQueryKernel.cu @@ -0,0 +1,74 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include +#include + +#include "open3d/ml/contrib/BallQuery.cuh" +#include "open3d/ml/contrib/cuda_utils.h" +#include "open3d/ml/paddle/pointnet/BallQueryKernel.h" + +using namespace open3d::ml::contrib; + +void ball_query_launcher(int b, + int n, + int m, + float radius, + int nsample, + const float *new_xyz, + const float *xyz, + int *idx, + uint64_t stream_id) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + + cudaStream_t stream = reinterpret_cast(stream_id); + + cudaError_t err; + + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + ball_query_kernel<<>>(b, n, m, radius, nsample, + new_xyz, xyz, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/cpp/open3d/ml/paddle/pointnet/BallQueryKernel.h b/cpp/open3d/ml/paddle/pointnet/BallQueryKernel.h new file mode 100644 index 00000000000..e27fa600639 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/BallQueryKernel.h @@ -0,0 +1,44 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +void ball_query_launcher(int b, + int n, + int m, + float radius, + int nsample, + const float *xyz, + const float *new_xyz, + int *idx, + uint64_t stream_id); diff --git a/cpp/open3d/ml/paddle/pointnet/BallQueryOps.cpp b/cpp/open3d/ml/paddle/pointnet/BallQueryOps.cpp new file mode 100644 index 00000000000..0554d779dfd --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/BallQueryOps.cpp @@ -0,0 +1,87 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/pointnet/BallQueryKernel.h" +#include "paddle/extension.h" + +#ifdef BUILD_CUDA_MODULE + +std::vector BallQuery(paddle::Tensor &xyz, + paddle::Tensor ¢er, + double radius, + const int64_t nsample) { + int batch_size = xyz.shape()[0]; + int pts_num = xyz.shape()[1]; + int ball_num = center.shape()[1]; + + auto place = xyz.place(); + paddle::Tensor out = + paddle::full({batch_size, ball_num, nsample}, 0.0f, + paddle::DataType(ToPaddleDtype()), place); + + const float *center_data = center.data(); + const float *xyz_data = xyz.data(); + int *idx = out.data(); + + ball_query_launcher(batch_size, pts_num, ball_num, radius, nsample, + center_data, xyz_data, idx, + reinterpret_cast(xyz.stream())); + return {out}; +} + +std::vector BallQueryInferDtype() { + return {paddle::DataType::FLOAT32}; +} + +std::vector> BallQueryInferShape( + std::vector xyz_shape, + std::vector center_shape, + const int64_t nsample) { + return {{xyz_shape[0], xyz_shape[1], center_shape[1]}}; +} + +PD_BUILD_OP(open3d_ball_query) + .Inputs({"xyz", "center"}) + .Outputs({"out"}) + .Attrs({ + "radius: double", + "nsample: int64_t", + }) + .SetKernelFn(PD_KERNEL(BallQuery)) + .SetInferShapeFn(PD_INFER_SHAPE(BallQueryInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(BallQueryInferDtype)); + +#endif diff --git a/cpp/open3d/ml/paddle/pointnet/InterpolateKernel.cu b/cpp/open3d/ml/paddle/pointnet/InterpolateKernel.cu new file mode 100644 index 00000000000..eb6d7d60de6 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/InterpolateKernel.cu @@ -0,0 +1,139 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include +#include + +#include + +#include "open3d/ml/contrib/InterpolatePoints.cuh" +#include "open3d/ml/contrib/cuda_utils.h" +#include "open3d/ml/paddle/pointnet/InterpolateKernel.h" + +using namespace open3d::ml::contrib; + +void three_nn_launcher(int b, + int n, + int m, + const float *unknown, + const float *known, + float *dist2, + int *idx, + uint64_t stream_id) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + cudaError_t err; + + cudaStream_t stream = reinterpret_cast(stream_id); + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + three_nn_kernel<<>>(b, n, m, unknown, known, + dist2, idx); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +void three_interpolate_launcher(int b, + int c, + int m, + int n, + const float *points, + const int *idx, + const float *weight, + float *out, + uint64_t stream_id) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + cudaError_t err; + + cudaStream_t stream = reinterpret_cast(stream_id); + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_kernel<<>>(b, c, m, n, points, + idx, weight, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +void three_interpolate_grad_launcher(int b, + int c, + int n, + int m, + const float *grad_out, + const int *idx, + const float *weight, + float *grad_points, + uint64_t stream_id) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + cudaError_t err; + + cudaStream_t stream = reinterpret_cast(stream_id); + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_grad_kernel<<>>( + b, c, n, m, grad_out, idx, weight, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/cpp/open3d/ml/paddle/pointnet/InterpolateKernel.h b/cpp/open3d/ml/paddle/pointnet/InterpolateKernel.h new file mode 100644 index 00000000000..bb9e6a639a4 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/InterpolateKernel.h @@ -0,0 +1,63 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +void three_nn_launcher(int b, + int n, + int m, + const float *unknown, + const float *known, + float *dist2, + int *idx, + uint64_t stream_id); + +void three_interpolate_launcher(int b, + int c, + int m, + int n, + const float *points, + const int *idx, + const float *weight, + float *out, + uint64_t stream_id); + +void three_interpolate_grad_launcher(int b, + int c, + int n, + int m, + const float *grad_out, + const int *idx, + const float *weight, + float *grad_points, + uint64_t stream_id); diff --git a/cpp/open3d/ml/paddle/pointnet/InterpolateOps.cpp b/cpp/open3d/ml/paddle/pointnet/InterpolateOps.cpp new file mode 100644 index 00000000000..e812dedbba0 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/InterpolateOps.cpp @@ -0,0 +1,181 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include +#include + +#include +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/pointnet/InterpolateKernel.h" +#include "paddle/extension.h" + +#ifdef BUILD_CUDA_MODULE +std::vector ThreeNN(paddle::Tensor &query_pts, + paddle::Tensor &data_pts) { + int batch_size = query_pts.shape()[0]; + int pts_num_out = query_pts.shape()[1]; + int pts_num_in = data_pts.shape()[1]; + + auto place = data_pts.place(); + paddle::Tensor out_idx = + paddle::full({batch_size, pts_num_out, 3}, 0, + paddle::DataType(ToPaddleDtype()), place); + + paddle::Tensor out_dist2 = + paddle::zeros({batch_size, pts_num_out, 3}, + paddle::DataType(ToPaddleDtype()), place); + + const float *pts_out = query_pts.data(); + const float *pts_in = data_pts.data(); + float *dist2 = out_dist2.data(); + int *idx = out_idx.data(); + + three_nn_launcher(batch_size, pts_num_out, pts_num_in, pts_out, pts_in, + dist2, idx, + reinterpret_cast(query_pts.stream())); + + return {out_dist2, out_idx}; +} + +std::vector ThreeNNInferDtype() { + return {paddle::DataType::INT32, paddle::DataType::FLOAT32}; +} + +std::vector> ThreeNNInferShape( + std::vector query_pts_shape, + std::vector data_pts_shape) { + std::vector shape{query_pts_shape[0], query_pts_shape[1], 3}; + return {shape, shape}; +} + +PD_BUILD_OP(open3d_three_nn) + .Inputs({"query_pts", "data_pts"}) + .Outputs({"dist", "idx"}) + .Attrs({}) + .SetKernelFn(PD_KERNEL(ThreeNN)) + .SetInferShapeFn(PD_INFER_SHAPE(ThreeNNInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ThreeNNInferDtype)); + +std::vector ThreeInterpolate(paddle::Tensor &points, + paddle::Tensor &idx, + paddle::Tensor &weights) { + int batch_size = points.shape()[0]; + int C = points.shape()[1]; + int M = points.shape()[2]; + int N = idx.shape()[1]; + + auto place = points.place(); + paddle::Tensor out = + paddle::full({batch_size, C, N}, 0.0f, + paddle::DataType(ToPaddleDtype()), place); + + const float *points_data = points.data(); + const float *weights_data = weights.data(); + const int *idx_data = idx.data(); + float *out_data = out.data(); + + three_interpolate_launcher(batch_size, C, M, N, points_data, idx_data, + weights_data, out_data, + reinterpret_cast(points.stream())); + + return {out}; +} + +std::vector ThreeInterpolateInferDtype() { + return {paddle::DataType::FLOAT32}; +} + +std::vector> ThreeInterpolateInferShape( + std::vector points_shape, std::vector idx_shape) { + return {{points_shape[0], points_shape[1], idx_shape[1]}}; +} + +PD_BUILD_OP(open3d_three_interpolate) + .Inputs({ + "points", + "idx", + "weights", + }) + .Outputs({"out"}) + .Attrs({}) + .SetKernelFn(PD_KERNEL(ThreeInterpolate)) + .SetInferShapeFn(PD_INFER_SHAPE(ThreeInterpolateInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ThreeInterpolateInferDtype)); + +std::vector ThreeInterpolateGrad(paddle::Tensor &grad_out, + paddle::Tensor &idx, + paddle::Tensor &weights, + const int64_t M) { + int batch_size = grad_out.shape()[0]; + int C = grad_out.shape()[1]; + int N = grad_out.shape()[2]; + + auto place = grad_out.place(); + paddle::Tensor out = + paddle::full({batch_size, C, M}, 0.0f, + paddle::DataType(ToPaddleDtype()), place); + + const float *grad_out_data = grad_out.data(); + const float *weights_data = weights.data(); + const int *idx_data = idx.data(); + + float *out_data = out.data(); + + three_interpolate_grad_launcher( + batch_size, C, N, M, grad_out_data, idx_data, weights_data, + out_data, reinterpret_cast(grad_out.stream())); + + return {out}; +} + +std::vector ThreeInterpolateGradInferDtype() { + return {paddle::DataType::FLOAT32}; +} + +std::vector> ThreeInterpolateGradInferShape( + std::vector grad_out_shape) { + return {{grad_out_shape[0], grad_out_shape[1], grad_out_shape[2]}}; +} + +PD_BUILD_OP(open3d_three_interpolate_grad) + .Inputs({"grad_out", "idx", "weights"}) + .Outputs({"out"}) + .Attrs({"M: int64_t"}) + .SetKernelFn(PD_KERNEL(ThreeInterpolateGrad)) + .SetInferShapeFn(PD_INFER_SHAPE(ThreeInterpolateGradInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ThreeInterpolateGradInferDtype)); + +#endif diff --git a/cpp/open3d/ml/paddle/pointnet/SamplingKernel.cu b/cpp/open3d/ml/paddle/pointnet/SamplingKernel.cu new file mode 100644 index 00000000000..60b66838f75 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/SamplingKernel.cu @@ -0,0 +1,118 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include + +#include + +#include "open3d/ml/contrib/PointSampling.cuh" +#include "open3d/ml/contrib/cuda_utils.h" +#include "open3d/ml/paddle/pointnet/SamplingKernel.h" + +using namespace open3d::ml::contrib; + +void furthest_point_sampling_launcher(int b, + int n, + int m, + const float *dataset, + float *temp, + int *idxs, + uint64_t stream_id) { + // dataset: (B, N, 3) + // tmp: (B, N) + // output: + // idx: (B, M) + + cudaError_t err; + + cudaStream_t stream = reinterpret_cast(stream_id); + + unsigned int n_threads = OptNumThreads(n); + + switch (n_threads) { + case 1024: + furthest_point_sampling_kernel<1024> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 512: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 256: + furthest_point_sampling_kernel<256> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 128: + furthest_point_sampling_kernel<128> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 64: + furthest_point_sampling_kernel<64> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 32: + furthest_point_sampling_kernel<32> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 16: + furthest_point_sampling_kernel<16> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 8: + furthest_point_sampling_kernel<8> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 4: + furthest_point_sampling_kernel<4> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 2: + furthest_point_sampling_kernel<2> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 1: + furthest_point_sampling_kernel<1> + <<>>(b, n, m, dataset, temp, idxs); + break; + default: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + } + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/cpp/open3d/ml/paddle/pointnet/SamplingKernel.h b/cpp/open3d/ml/paddle/pointnet/SamplingKernel.h new file mode 100644 index 00000000000..5baa9e0be62 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/SamplingKernel.h @@ -0,0 +1,42 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +void furthest_point_sampling_launcher(int b, + int n, + int m, + const float *dataset, + float *temp, + int *idxs, + uint64_t stream_id); diff --git a/cpp/open3d/ml/paddle/pointnet/SamplingOps.cpp b/cpp/open3d/ml/paddle/pointnet/SamplingOps.cpp new file mode 100644 index 00000000000..a6d944c5151 --- /dev/null +++ b/cpp/open3d/ml/paddle/pointnet/SamplingOps.cpp @@ -0,0 +1,85 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyPaddle +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/pointnet/SamplingKernel.h" +#include "paddle/extension.h" + +#ifdef BUILD_CUDA_MODULE + +std::vector FurthestPointSampling(paddle::Tensor &points, + const int64_t sample_size) { + int batch_size = points.shape()[0]; + int pts_size = points.shape()[1]; + + auto place = points.place(); + paddle::Tensor out = + paddle::full({batch_size, sample_size}, 0, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor temp = + paddle::full({batch_size, pts_size}, 1e10, + paddle::DataType(ToPaddleDtype()), place); + + const float *points_data = points.data(); + float *temp_data = temp.data(); + int *out_data = out.data(); + + furthest_point_sampling_launcher( + batch_size, pts_size, sample_size, points_data, temp_data, out_data, + reinterpret_cast(points.stream())); + + return {out}; +} + +std::vector FurthestPointSamplingInferDtype() { + return {paddle::DataType::INT32}; +} + +std::vector> FurthestPointSamplingInferShape( + std::vector points_shape, const int64_t sample_size) { + return {{points_shape[0], sample_size}}; +} + +PD_BUILD_OP(open3d_furthest_point_sampling) + .Inputs({"points"}) + .Outputs({"out"}) + .Attrs({ + "sample_size: int64_t", + }) + .SetKernelFn(PD_KERNEL(FurthestPointSampling)) + .SetInferShapeFn(PD_INFER_SHAPE(FurthestPointSamplingInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FurthestPointSamplingInferDtype)); + +#endif diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.cpp b/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.cpp new file mode 100644 index 00000000000..bbfcf91ad1b --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.cpp @@ -0,0 +1,56 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/sparse_conv/SparseConvBackpropFilter.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvBackpropFilterCPU(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + SparseConvBackpropFilterCPU( + filter_backprop.data(), filter_dims, + neighbors_row_splits.shape()[0] - 1, inp_features.shape()[0], + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), + out_features_gradient.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TIndex, TKernelIndex) \ + template void \ + SparseConvBackpropFilterCPU( \ + const paddle::Tensor& filters, const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.cu b/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.cu new file mode 100644 index 00000000000..ab2eb18a7de --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.cu @@ -0,0 +1,94 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include +#include + +#include "open3d/ml/impl/sparse_conv/SparseConvBackpropFilter.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvBackpropFilterCUDA(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + SparseConvBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, + neighbors_row_splits.shape()[0] - 1, inp_features.shape()[0], + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), + out_features_gradient.data(), normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + SparseConvBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, + neighbors_row_splits.shape()[0] - 1, inp_features.shape()[0], + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), + out_features_gradient.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TIndex, TKernelIndex) \ + template void \ + SparseConvBackpropFilterCUDA( \ + const paddle::Tensor& filters, const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.h b/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.h new file mode 100644 index 00000000000..0f166227eb4 --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.h @@ -0,0 +1,40 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void SparseConvBackpropFilterCPU(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); + +#ifdef BUILD_CUDA_MODULE +template +void SparseConvBackpropFilterCUDA(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); +#endif diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.cpp b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.cpp new file mode 100644 index 00000000000..06f9850361d --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.cpp @@ -0,0 +1,52 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/sparse_conv/SparseConv.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvCPU(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + SparseConvComputeFeaturesCPU( + out_features.data(), filter_dims, filters.data(), + neighbors_row_splits.shape()[0] - 1, inp_features.shape()[0], + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TIndex, TKernelIndex) \ + template void SparseConvCPU( \ + const paddle::Tensor& filters, const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.cu b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.cu new file mode 100644 index 00000000000..545969f54c1 --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.cu @@ -0,0 +1,87 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include +#include + +#include "open3d/ml/impl/sparse_conv/SparseConv.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvCUDA(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + SparseConvComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + neighbors_row_splits.shape()[0] - 1, inp_features.shape()[0], + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + SparseConvComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + neighbors_row_splits.shape()[0] - 1, inp_features.shape()[0], + inp_features.data(), + inp_importance.shape()[0] ? inp_importance.data() : nullptr, + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TReal, TIndex) \ + template void SparseConvCUDA( \ + const paddle::Tensor& filters, const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_importance, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.h b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.h new file mode 100644 index 00000000000..95578e74dec --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOpKernel.h @@ -0,0 +1,38 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void SparseConvCPU(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); + +#ifdef BUILD_CUDA_MODULE +template +void SparseConvCUDA(const paddle::Tensor& filters, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_importance, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); +#endif diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvOps.cpp b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOps.cpp new file mode 100644 index 00000000000..55d01bc871d --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvOps.cpp @@ -0,0 +1,212 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/InvertNeighborsListOps.h" +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOps.h" +#include "open3d/ml/paddle/sparse_conv/SparseConvBackpropFilterOpKernel.h" +#include "open3d/ml/paddle/sparse_conv/SparseConvOpKernel.h" +#include "open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.h" + +std::vector SparseConvForward( + paddle::Tensor& filters, + paddle::Tensor& inp_features, + paddle::Tensor& inp_importance, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_kernel_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB) { + CHECK_TYPE(neighbors_row_splits, paddle::DataType::INT64); + CHECK_SAME_DTYPE(filters, inp_features, inp_importance, + neighbors_importance); + CHECK_SAME_DEVICE_TYPE(filters, inp_features, inp_importance); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_kernel_elements("num_kernel_elements"); + Dim in_channels("in_channels"); + Dim out_channels("out_channels"); + Dim num_out_points("num_out_points"); + Dim num_inp_points("num_inp_points"); + Dim num_neighbors("nun_neighbors"); + + CHECK_SHAPE_COMBINE_FIRST_DIMS(filters, num_kernel_elements, in_channels, + out_channels); + CHECK_SHAPE(inp_features, num_inp_points, in_channels); + CHECK_SHAPE(inp_importance, num_inp_points || 0); + CHECK_SHAPE(neighbors_index, num_neighbors); + CHECK_SHAPE(neighbors_kernel_index, num_neighbors); + CHECK_SHAPE(neighbors_importance, num_neighbors || 0); + CHECK_SHAPE(neighbors_row_splits, num_out_points + 1); + + // make sure that these are on the same place as the filters and feats + auto place = inp_features.place(); + neighbors_index = neighbors_index.copy_to(place, false); + neighbors_kernel_index = neighbors_kernel_index.copy_to(place, false); + neighbors_importance = neighbors_importance.copy_to(place, false); + neighbors_row_splits = neighbors_row_splits.copy_to(place, false); + + const auto& feat_dtype = filters.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + const auto& kernel_index_dtype = neighbors_kernel_index.dtype(); + + paddle::Tensor out_features = paddle::empty( + {num_out_points.value(), out_channels.value()}, feat_dtype, place); +#define FN_PARAMETERS \ + filters, inp_features, inp_importance, neighbors_index, \ + neighbors_kernel_index, neighbors_importance, \ + neighbors_row_splits, normalize, max_temp_mem_MB, out_features + +#define CALL(feat_t, out_t, index_t, kernel_index_t, fn) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(index_dtype) && \ + ComparePaddleDtype(kernel_index_dtype)) { \ + fn(FN_PARAMETERS); \ + return {out_features}; \ + } + + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, int32_t, uint8_t, ::SparseConvCUDA) +#else + PD_CHECK(false, "SparseConv was not compiled with CUDA support"); +#endif + } else { + CALL(float, float, int32_t, uint8_t, ::SparseConvCPU) + } +#undef FN_PARAMETERS +#undef CALL + + PD_CHECK(false, + "SparseConv does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features, and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index, and " + + phi::DataTypeToString(neighbors_kernel_index.dtype()) + + " as input for neighbors_kernel_indexcgcgcc"); + return {paddle::Tensor()}; +} + +std::vector SparseConvBackward( + paddle::Tensor& filters, + paddle::Tensor& inp_features, + paddle::Tensor& inp_importance, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_kernel_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB) { + auto place = inp_features.place(); + const auto& feat_dtype = filters.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + const auto& kernel_index_dtype = neighbors_kernel_index.dtype(); + CHECK_SAME_DTYPE(out_features_gradient, inp_features, filters); + CHECK_SAME_DEVICE_TYPE(out_features_gradient, inp_features, filters); + + // output vars + paddle::Tensor filters_backprop; + paddle::Tensor inp_features_backprop; + +#define CALL(feat_t, out_t, index_t, kernel_index_t, fn_suffix) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(index_dtype) && \ + ComparePaddleDtype(kernel_index_dtype)) { \ + filters_backprop = paddle::empty(filters.shape(), feat_dtype, place); \ + SparseConvBackpropFilter##fn_suffix( \ + filters, inp_features, inp_importance, neighbors_index, \ + neighbors_kernel_index, neighbors_importance, \ + neighbors_row_splits, out_features_gradient, normalize, \ + max_temp_mem_MB, filters_backprop); \ + \ + paddle::Tensor inv_neighbors_index, inv_neighbors_row_splits, \ + inv_neighbors_importance, inv_arange; \ + paddle::Tensor arange = Arange(neighbors_index.shape()[0], place); \ + auto inv = InvertNeighborsList(neighbors_index, neighbors_row_splits, \ + arange, inp_features.shape()[0]); \ + inv_neighbors_index = inv[0]; \ + inv_neighbors_row_splits = inv[1]; \ + inv_arange = inv[2]; \ + paddle::Tensor inv_neighbors_kernel_index = \ + paddle::experimental::gather(neighbors_kernel_index, \ + inv_arange); \ + if (neighbors_importance.shape()[0] > 0) { \ + inv_neighbors_importance = paddle::experimental::gather( \ + neighbors_importance, inv_arange); \ + } else { \ + inv_neighbors_importance = paddle::empty({0}, feat_dtype, place); \ + } \ + \ + auto neighbors_importance_sum = ReduceSubarraysSum( \ + neighbors_importance, neighbors_row_splits)[0]; \ + inp_features_backprop = \ + paddle::ones(inp_features.shape(), feat_dtype, place); \ + auto filters_transposed = Transpose(filters, -1, -2).contiguous(); \ + \ + SparseConvTranspose##fn_suffix( \ + filters_transposed, inp_importance, out_features_gradient, \ + neighbors_importance_sum, neighbors_row_splits, \ + inv_neighbors_index, inv_neighbors_kernel_index, \ + inv_neighbors_importance, inv_neighbors_row_splits, normalize, \ + max_temp_mem_MB, inp_features_backprop); \ + dispatch_success = true; \ + } + + bool dispatch_success = false; + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, int32_t, uint8_t, CUDA) +#else + PD_CHECK(false, + "SparseConv backward was not compiled " + "with CUDA support"); +#endif + } else { + CALL(float, float, int32_t, uint8_t, CPU) + } + PD_CHECK(dispatch_success, + "SparseConv backward does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features, and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index, and " + + phi::DataTypeToString(neighbors_kernel_index.dtype()) + + " as input for neighbors_kernel_index"); + + return {filters_backprop, inp_features_backprop}; +} + +std::vector SparseConvInferDtype( + paddle::DataType filters_dtype) { + return {filters_dtype}; +} + +PD_BUILD_OP(open3d_sparse_conv) + .Inputs({"filters", "inp_features", "inp_importance", "neighbors_index", + "neighbors_kernel_index", "neighbors_importance", + "neighbors_row_splits"}) + .Outputs({"out_features"}) + .Attrs({"normalize:bool", "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(SparseConvForward)) + .SetInferDtypeFn(PD_INFER_DTYPE(SparseConvInferDtype)); + +PD_BUILD_GRAD_OP(open3d_sparse_conv) + .Inputs({"filters", "inp_features", "inp_importance", "neighbors_index", + "neighbors_kernel_index", "neighbors_importance", + "neighbors_row_splits", paddle::Grad("out_features")}) + .Outputs({paddle::Grad("filters"), paddle::Grad("inp_features")}) + .Attrs({"normalize:bool", "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(SparseConvBackward)); \ No newline at end of file diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cpp b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cpp new file mode 100644 index 00000000000..5c063d23ad9 --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cpp @@ -0,0 +1,66 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/sparse_conv/SparseConvTransposeBackpropFilter.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvTransposeBackpropFilterCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + SparseConvTransposeBackpropFilterCPU( + filter_backprop.data(), filter_dims, + neighbors_row_splits.shape()[0] - 1, + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_features.shape()[0], inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), + out_features_gradient.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TIndex, TKernelIndex) \ + template void \ + SparseConvTransposeBackpropFilterCPU( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cu b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cu new file mode 100644 index 00000000000..b2bed16dc21 --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.cu @@ -0,0 +1,106 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include +#include + +#include "open3d/ml/impl/sparse_conv/SparseConvTransposeBackpropFilter.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvTransposeBackpropFilterCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + SparseConvTransposeBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, + neighbors_row_splits.shape()[0] - 1, + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_features.shape()[0], inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), + out_features_gradient.data(), normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + SparseConvTransposeBackpropFilterCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + filter_backprop.data(), filter_dims, + neighbors_row_splits.shape()[0] - 1, + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_features.shape()[0], inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), + out_features_gradient.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TIndex, TKernelIndex) \ + template void \ + SparseConvTransposeBackpropFilterCUDA( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, \ + const paddle::Tensor& out_features_gradient, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& filter_backprop); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.h b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.h new file mode 100644 index 00000000000..d9c8eae2d2a --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.h @@ -0,0 +1,46 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void SparseConvTransposeBackpropFilterCPU( + const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); + +#ifdef BUILD_CUDA_MODULE +template +void SparseConvTransposeBackpropFilterCUDA( + const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& filter_backprop); +#endif diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.cpp b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.cpp new file mode 100644 index 00000000000..a517ed5af53 --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.cpp @@ -0,0 +1,62 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/impl/sparse_conv/SparseConvTranspose.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvTransposeCPU(const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + SparseConvTransposeComputeFeaturesCPU( + out_features.data(), filter_dims, filters.data(), + neighbors_row_splits.shape()[0] - 1, + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_features.shape()[0], inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TIndex, TKernelIndex) \ + template void SparseConvTransposeCPU( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.cu b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.cu new file mode 100644 index 00000000000..5df7ddc7520 --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.cu @@ -0,0 +1,100 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include +#include + +#include "open3d/ml/impl/sparse_conv/SparseConvTranspose.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::ml::impl; + +template +void SparseConvTransposeCUDA(const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features) { + std::vector filter_dims; + for (auto d : filters.shape()) { + filter_dims.push_back(static_cast(d)); + } + + auto stream = filters.stream(); + // -1 means current global place + auto cuda_device_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_device_props.textureAlignment; + + auto place = filters.place(); + + void* temp_ptr = nullptr; + size_t temp_size = 0; + size_t max_temp_size = 0; + + // determine temp_size + SparseConvTransposeComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + neighbors_row_splits.shape()[0] - 1, + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_features.shape()[0], inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), normalize); + + temp_size = std::max( + std::min(static_cast(max_temp_mem_MB) * 1024 * 1024, + max_temp_size), + temp_size); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually run the operation + SparseConvTransposeComputeFeaturesCUDA( + stream, temp_ptr, temp_size, max_temp_size, texture_alignment, + out_features.data(), filter_dims, filters.data(), + neighbors_row_splits.shape()[0] - 1, + out_importance.shape()[0] ? out_importance.data() : nullptr, + inp_features.shape()[0], inp_features.data(), + inp_neighbors_importance_sum.shape()[0] + ? inp_neighbors_importance_sum.data() + : nullptr, + inp_neighbors_row_splits.data(), + neighbors_index.shape()[0], neighbors_index.data(), + neighbors_kernel_index.data(), + neighbors_importance.shape()[0] ? neighbors_importance.data() + : nullptr, + neighbors_row_splits.data(), normalize); +} +#define INSTANTIATE(TFeat, TOut, TIndex, TKernelIndex) \ + template void SparseConvTransposeCUDA( \ + const paddle::Tensor& filters, \ + const paddle::Tensor& out_importance, \ + const paddle::Tensor& inp_features, \ + const paddle::Tensor& inp_neighbors_importance_sum, \ + const paddle::Tensor& inp_neighbors_row_splits, \ + const paddle::Tensor& neighbors_index, \ + const paddle::Tensor& neighbors_kernel_index, \ + const paddle::Tensor& neighbors_importance, \ + const paddle::Tensor& neighbors_row_splits, const bool normalize, \ + const int64_t max_temp_mem_MB, paddle::Tensor& out_features); + +INSTANTIATE(float, float, int32_t, uint8_t) diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.h b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.h new file mode 100644 index 00000000000..51507320107 --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.h @@ -0,0 +1,42 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void SparseConvTransposeCPU(const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); + +#ifdef BUILD_CUDA_MODULE +template +void SparseConvTransposeCUDA(const paddle::Tensor& filters, + const paddle::Tensor& out_importance, + const paddle::Tensor& inp_features, + const paddle::Tensor& inp_neighbors_importance_sum, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& neighbors_index, + const paddle::Tensor& neighbors_kernel_index, + const paddle::Tensor& neighbors_importance, + const paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB, + paddle::Tensor& out_features); +#endif diff --git a/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOps.cpp b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOps.cpp new file mode 100644 index 00000000000..490b7b9a76c --- /dev/null +++ b/cpp/open3d/ml/paddle/sparse_conv/SparseConvTransposeOps.cpp @@ -0,0 +1,224 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/InvertNeighborsListOps.h" +#include "open3d/ml/paddle/misc/ReduceSubarraysSumOps.h" +#include "open3d/ml/paddle/sparse_conv/SparseConvOpKernel.h" +#include "open3d/ml/paddle/sparse_conv/SparseConvTransposeBackpropFilterOpKernel.h" +#include "open3d/ml/paddle/sparse_conv/SparseConvTransposeOpKernel.h" + +std::vector SparseConvTransposeForward( + paddle::Tensor& filters, + paddle::Tensor& out_importance, + paddle::Tensor& inp_features, + paddle::Tensor& inp_neighbors_index, + paddle::Tensor& inp_neighbors_importance_sum, + paddle::Tensor& inp_neighbors_row_splits, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_kernel_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + const bool normalize, + const int64_t max_temp_mem_MB) { + CHECK_TYPE(neighbors_row_splits, paddle::DataType::INT64); + CHECK_TYPE(inp_neighbors_row_splits, paddle::DataType::INT64); + CHECK_SAME_DTYPE(neighbors_index, inp_neighbors_index); + CHECK_SAME_DTYPE(filters, inp_features, out_importance, + neighbors_importance); + CHECK_SAME_DEVICE_TYPE(filters, inp_features, out_importance); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_out("num_out"); + Dim num_inp("num_inp"); + Dim num_kernel_elements("num_kernel_elements"); + Dim in_channels("in_channels"); + Dim out_channels("out_channels"); + Dim num_neighbors("num_neighbors"); + + CHECK_SHAPE_COMBINE_FIRST_DIMS(filters, num_kernel_elements, in_channels, + out_channels); + CHECK_SHAPE(neighbors_row_splits, num_out + 1); + CHECK_SHAPE(out_importance, 0 || num_out); + CHECK_SHAPE(inp_features, num_inp, in_channels); + CHECK_SHAPE(inp_neighbors_index, num_neighbors); + CHECK_SHAPE(inp_neighbors_importance_sum, 0 || num_inp); + CHECK_SHAPE(inp_neighbors_row_splits, num_inp + 1); + CHECK_SHAPE(neighbors_index, num_neighbors); + CHECK_SHAPE(neighbors_kernel_index, num_neighbors); + CHECK_SHAPE(neighbors_importance, 0 || num_neighbors); + + // make sure that these are on the same place as the filters and feats + auto place = inp_features.place(); + neighbors_index = neighbors_index.copy_to(place, false); + neighbors_kernel_index = neighbors_kernel_index.copy_to(place, false); + neighbors_importance = neighbors_importance.copy_to(place, false); + neighbors_row_splits = neighbors_row_splits.copy_to(place, false); + inp_neighbors_index = inp_neighbors_index.copy_to(place, false); + inp_neighbors_importance_sum = + inp_neighbors_importance_sum.copy_to(place, false); + inp_neighbors_row_splits = inp_neighbors_row_splits.copy_to(place, false); + + const auto& feat_dtype = filters.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + const auto& kernel_index_dtype = neighbors_kernel_index.dtype(); + + paddle::Tensor out_features = paddle::empty( + {num_out.value(), out_channels.value()}, feat_dtype, place); +#define FN_PARAMETERS \ + filters, out_importance, inp_features, inp_neighbors_importance_sum, \ + inp_neighbors_row_splits, neighbors_index, neighbors_kernel_index, \ + neighbors_importance, neighbors_row_splits, normalize, \ + max_temp_mem_MB, out_features + +#define CALL(feat_t, out_t, index_t, kernel_index_t, fn) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(index_dtype) && \ + ComparePaddleDtype(kernel_index_dtype)) { \ + fn(FN_PARAMETERS); \ + return {out_features}; \ + } + + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, int32_t, uint8_t, ::SparseConvTransposeCUDA) +#else + PD_CHECK(false, + "SparseConvTranspose was not compiled with CUDA " + "support"); +#endif + } else { + CALL(float, float, int32_t, uint8_t, ::SparseConvTransposeCPU) + } +#undef FN_PARAMETERS +#undef CALL + + PD_CHECK(false, "SparseConv does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index"); + return {paddle::Tensor()}; +} + +std::vector SparseConvTransposeBackward( + paddle::Tensor& filters, + paddle::Tensor& out_importance, + paddle::Tensor& inp_features, + paddle::Tensor& inp_neighbors_importance_sum, + paddle::Tensor& inp_neighbors_row_splits, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_kernel_index, + paddle::Tensor& neighbors_importance, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& out_features_gradient, + const bool normalize, + const int64_t max_temp_mem_MB) { + auto place = inp_features.place(); + const auto& feat_dtype = filters.dtype(); + const auto& index_dtype = neighbors_index.dtype(); + const auto& kernel_index_dtype = neighbors_kernel_index.dtype(); + CHECK_SAME_DTYPE(out_features_gradient, inp_features, filters); + CHECK_SAME_DEVICE_TYPE(out_features_gradient, inp_features, filters); + + // output vars + paddle::Tensor filters_backprop; + paddle::Tensor inp_features_backprop; + +#define CALL(feat_t, out_t, index_t, kernel_index_t, fn_suffix) \ + if (ComparePaddleDtype(feat_dtype) && \ + ComparePaddleDtype(index_dtype) && \ + ComparePaddleDtype(kernel_index_dtype)) { \ + filters_backprop = paddle::empty(filters.shape(), feat_dtype, place); \ + SparseConvTransposeBackpropFilter##fn_suffix( \ + filters, out_importance, inp_features, \ + inp_neighbors_importance_sum, inp_neighbors_row_splits, \ + neighbors_index, neighbors_kernel_index, neighbors_importance, \ + neighbors_row_splits, out_features_gradient, normalize, \ + max_temp_mem_MB, filters_backprop); \ + \ + paddle::Tensor inv_neighbors_index, _inv_neighbors_row_splits, \ + inv_neighbors_importance, inv_arange; \ + paddle::Tensor arange = Arange(neighbors_index.shape()[0], place); \ + auto inv = InvertNeighborsList(neighbors_index, neighbors_row_splits, \ + arange, inp_features.shape()[0]); \ + inv_neighbors_index = inv[0]; \ + _inv_neighbors_row_splits = inv[1]; \ + inv_arange = inv[2]; \ + paddle::Tensor inv_neighbors_kernel_index = \ + paddle::experimental::gather(neighbors_kernel_index, \ + inv_arange); \ + if (neighbors_importance.shape()[0] > 0) { \ + inv_neighbors_importance = paddle::experimental::gather( \ + neighbors_importance, inv_arange); \ + } else { \ + inv_neighbors_importance = paddle::empty({0}, feat_dtype, place); \ + } \ + inp_features_backprop = \ + paddle::ones(inp_features.shape(), feat_dtype, place); \ + auto filters_transposed = Transpose(filters, -1, -2).contiguous(); \ + \ + SparseConv##fn_suffix( \ + filters_transposed, out_features_gradient, out_importance, \ + inv_neighbors_index, inv_neighbors_kernel_index, \ + inv_neighbors_importance, inp_neighbors_row_splits, normalize, \ + max_temp_mem_MB, inp_features_backprop); \ + dispatch_success = true; \ + } + + bool dispatch_success = false; + if (inp_features.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + CALL(float, float, int32_t, uint8_t, CUDA) +#else + PD_CHECK(false, + "SparseConvTranspose backward was not compiled " + "with CUDA support"); +#endif + } else { + CALL(float, float, int32_t, uint8_t, CPU) + } + PD_CHECK(dispatch_success, + "SparseConvTranspose backward does not support " + + phi::DataTypeToString(inp_features.dtype()) + + " as input for inp_features and " + + phi::DataTypeToString(neighbors_index.dtype()) + + " as input for neighbors_index"); + + return {filters_backprop, inp_features_backprop}; +} + +std::vector SparseConvTransposeInferDtype( + paddle::DataType inp_positions_dtype) { + return {inp_positions_dtype}; +} + +PD_BUILD_OP(open3d_sparse_conv_transpose) + .Inputs({"filters", "out_importance", "inp_features", + "inp_neighbors_index", "inp_neighbors_importance_sum", + "inp_neighbors_row_splits", "neighbors_index", + "neighbors_kernel_index", "neighbors_importance", + "neighbors_row_splits"}) + .Outputs({"out_features"}) + .Attrs({"normalize:bool", "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(SparseConvTransposeForward)) + .SetInferDtypeFn(PD_INFER_DTYPE(SparseConvTransposeInferDtype)); + +PD_BUILD_GRAD_OP(open3d_sparse_conv_transpose) + .Inputs({"filters", "out_importance", "inp_features", + "inp_neighbors_importance_sum", "inp_neighbors_row_splits", + "neighbors_index", "neighbors_kernel_index", + "neighbors_importance", "neighbors_row_splits", + paddle::Grad("out_features")}) + .Outputs({paddle::Grad("filters"), paddle::Grad("inp_features")}) + .Attrs({"normalize:bool", "max_temp_mem_MB:int64_t"}) + .SetKernelFn(PD_KERNEL(SparseConvTransposeBackward)); \ No newline at end of file diff --git a/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h b/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h index 01bd9060b9e..bd40b18b37a 100644 --- a/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h +++ b/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h @@ -703,7 +703,7 @@ void RayCastCPU index_t block_buf_idx = cache.Check(key[0], key[1], key[2]); if (block_buf_idx < 0) { auto iter = hashmap_impl.find(key); - if (iter == hashmap_impl.end()) return -1; + if (iter == hashmap_impl.cend()) return -1; block_buf_idx = iter->second; cache.Update(key[0], key[1], key[2], block_buf_idx); } @@ -730,7 +730,7 @@ void RayCastCPU index_t block_buf_idx = cache.Check(x_b, y_b, z_b); if (block_buf_idx < 0) { auto iter = hashmap_impl.find(key); - if (iter == hashmap_impl.end()) return -1; + if (iter == hashmap_impl.cend()) return -1; block_buf_idx = iter->second; cache.Update(x_b, y_b, z_b, block_buf_idx); } @@ -929,7 +929,7 @@ void RayCastCPU index_t block_buf_idx = cache.Check(x_b, y_b, z_b); if (block_buf_idx < 0) { auto iter = hashmap_impl.find(key); - if (iter == hashmap_impl.end()) return; + if (iter == hashmap_impl.cend()) return; block_buf_idx = iter->second; cache.Update(x_b, y_b, z_b, block_buf_idx); } diff --git a/cpp/pybind/CMakeLists.txt b/cpp/pybind/CMakeLists.txt index 5fdce155f18..65d80cdbb28 100644 --- a/cpp/pybind/CMakeLists.txt +++ b/cpp/pybind/CMakeLists.txt @@ -184,6 +184,25 @@ if (BUILD_PYTORCH_OPS) OUTPUT_VARIABLE Pytorch_VERSION) endif() +# add additional optional compiled modules +if (BUILD_PADDLE_OPS) + list( APPEND COMPILED_MODULE_PATH_LIST $ ) + add_custom_command( OUTPUT "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/ops.py" "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/return_types.py" + COMMAND ${Python3_EXECUTABLE} generate_paddle_ops_wrapper.py + --input_return_types_py_in "${PYTHON_PACKAGE_SRC_DIR}/open3d/ml/paddle/python/return_types.py.in" + --output_dir "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/" + --lib $ + DEPENDS open3d_paddle_ops + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Generating python ops.py and return_types.py" ) + + list(APPEND GENERATED_OUTPUTS "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/ops.py" "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/return_types.py") + # find paddle to get some info for the _build_config.py + set(PRINT_ONCE OFF) + find_package(Paddle) +endif() + + if (BUNDLE_OPEN3D_ML) find_path( OPEN3D_ML_ROOT @@ -229,6 +248,7 @@ add_custom_target(python-package -DBUILD_JUPYTER_EXTENSION=${BUILD_JUPYTER_EXTENSION} -DBUILD_TENSORFLOW_OPS=${BUILD_TENSORFLOW_OPS} -DBUILD_PYTORCH_OPS=${BUILD_PYTORCH_OPS} + -DBUILD_PADDLE_OPS=${BUILD_PADDLE_OPS} -DBUNDLE_OPEN3D_ML=${BUNDLE_OPEN3D_ML} -DOPEN3D_ML_ROOT=${OPEN3D_ML_ROOT} -DBUILD_GUI=${BUILD_GUI} diff --git a/cpp/pybind/_build_config.py.in b/cpp/pybind/_build_config.py.in index 6c32224de02..7ea14135ff2 100644 --- a/cpp/pybind/_build_config.py.in +++ b/cpp/pybind/_build_config.py.in @@ -1,6 +1,7 @@ _build_config = { "BUILD_TENSORFLOW_OPS" : $,True,False>, "BUILD_PYTORCH_OPS" : $,True,False>, + "BUILD_PADDLE_OPS" : $,True,False>, "BUILD_CUDA_MODULE" : $,True,False>, "BUILD_SYCL_MODULE" : $,True,False>, "BUILD_AZURE_KINECT" : $,True,False>, @@ -16,5 +17,6 @@ _build_config = { "CUDA_GENCODES" : "@CUDA_GENCODES@", "Tensorflow_VERSION" : "@Tensorflow_VERSION@", "Pytorch_VERSION" : "@Pytorch_VERSION@", + "Paddle_VERSION" : "@Paddle_VERSION@", "WITH_OPENMP" : $,True,False> } diff --git a/cpp/pybind/generate_paddle_ops_wrapper.py b/cpp/pybind/generate_paddle_ops_wrapper.py new file mode 100644 index 00000000000..40038e044e7 --- /dev/null +++ b/cpp/pybind/generate_paddle_ops_wrapper.py @@ -0,0 +1,197 @@ +import argparse +import textwrap +import sys +import os +from yapf.yapflib.yapf_api import FormatFile + + +from paddle.utils.cpp_extension.extension_utils import ( + load_op_meta_info_and_register_op, + _get_api_inputs_str, + _gen_output_content +) + + +def remove_op_name_prefix(op_name): + PADDLE_OPS_PREFIX = "open3d_" + + assert op_name.startswith(PADDLE_OPS_PREFIX), "Paddle operators should be start with `open3d_`." + func_name = op_name[len(PADDLE_OPS_PREFIX):] + + return func_name + + +def custom_api_header(): + HEADER = textwrap.dedent( + """ + # ---------------------------------------------------------------------------- + # - Open3D: www.open3d.org - + # ---------------------------------------------------------------------------- + # The MIT License (MIT) + # + # Copyright (c) 2018-2024 www.open3d.org + # + # Permission is hereby granted, free of charge, to any person obtaining a copy + # of this software and associated documentation files (the "Software"), to deal + # in the Software without restriction, including without limitation the rights + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + # copies of the Software, and to permit persons to whom the Software is + # furnished to do so, subject to the following conditions: + # + # The above copyright notice and this permission notice shall be included in + # all copies or substantial portions of the Software. + # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + # IN THE SOFTWARE. + # ---------------------------------------------------------------------------- + + # This file is machine generated. Do not modify. + from paddle import _C_ops + from paddle.framework import in_dynamic_or_pir_mode + from paddle.base.layer_helper import LayerHelper + from . import return_types + """ + ).lstrip() + + return HEADER + + +def custom_api_footer(custom_ops): + FOOTER = textwrap.dedent( + """ + __all__ = [ + {export_func_name_strs} + ] + """ + ).lstrip() + + export_func_name_strs = "" + for op_name in custom_ops: + export_func_name_strs += f"'{remove_op_name_prefix(op_name)}', \n" + + return FOOTER.format( + export_func_name_strs = export_func_name_strs + ) + + +def custom_api_content(op_name): + ( + params_list, + ins_map, + attrs_map, + outs_list, + in_names, + _, + out_names, + inplace_reverse_idx, + ) = _get_api_inputs_str(op_name) + dynamic_content, static_content = _gen_output_content( + op_name, + in_names, + out_names, + ins_map, + attrs_map, + outs_list, + inplace_reverse_idx, + ) + API_TEMPLATE = textwrap.dedent( + """ + def {func_name}({params_list}): + # The output variable's dtype use default value 'float32', + # and the actual dtype of output variable will be inferred in runtime. + if in_dynamic_or_pir_mode(): + outs = _C_ops._run_custom_op("{op_name}", {params_list}) + {dynamic_content} + else: + {static_content} + """ + ).lstrip() + + # NOTE: Hack return express to wrapper multi return value by return_types + if len(out_names) > 1: + RETURN_NAMEDTUPLE_TEMPLATE = textwrap.dedent("""return return_types.{op_name}(*res)""").lstrip() + REPLACED_RETURN_TEMPLATE = textwrap.dedent("""return res[0] if len(res)==1 else res""").lstrip() + dynamic_content = dynamic_content.replace(REPLACED_RETURN_TEMPLATE, RETURN_NAMEDTUPLE_TEMPLATE.format(op_name=op_name)) + static_content = static_content.replace(REPLACED_RETURN_TEMPLATE, RETURN_NAMEDTUPLE_TEMPLATE.format(op_name=op_name)) + + func_name = remove_op_name_prefix(op_name) + + # generate python api file + api_content = API_TEMPLATE.format( + func_name=func_name, + op_name=op_name, + params_list=params_list, + dynamic_content=dynamic_content, + static_content=static_content, + ) + + NAMEDTUPLE_TEMPLATE= textwrap.dedent("""{op_name} = _namedtuple('{op_name}', '{out_names}')""").lstrip() + out_names = ' '.join([out_name for out_name in out_names]) + api_namedtuple = NAMEDTUPLE_TEMPLATE.format( + op_name=op_name, out_names=out_names) + + + return api_content, api_namedtuple + + +def main(): + parser = argparse.ArgumentParser( + description="Creates the ops.py and return_types.py files") + parser.add_argument("--input_return_types_py_in", + type=str, + required=True, + help="input file with header") + parser.add_argument("--lib", + type=str, + required=True, + help="path to open3d_paddle_ops.so") + parser.add_argument("--output_dir", + type=str, + required=True, + help="output directory") + args = parser.parse_args() + + generated_fuction_strs = "" + generated_namedtuple_strs = "" + custom_ops = load_op_meta_info_and_register_op(args.lib) + for _custom_op in custom_ops: + generated_fuction_str, generated_namedtuple_str = custom_api_content(_custom_op) + generated_fuction_strs += generated_fuction_str + "\n" + generated_namedtuple_strs += generated_namedtuple_str + "\n" + + CUSTOM_API_TEMPLATE = textwrap.dedent(""" + {custom_api_header} + + {custom_api_content} + + {custom_api_footer} + """).lstrip() + generated_ops_strs = CUSTOM_API_TEMPLATE.format( + custom_api_header = custom_api_header(), + custom_api_content = generated_fuction_strs, + custom_api_footer = custom_api_footer(custom_ops) + ) + + os.makedirs(args.output_dir, exist_ok=True) + output_ops_py_path = os.path.join(args.output_dir, 'ops.py') + with open(output_ops_py_path,'w') as f: + f.write(generated_ops_strs) + FormatFile(output_ops_py_path, in_place=True) + + output_return_types_py_path = os.path.join(args.output_dir, + 'return_types.py') + with open(args.input_return_types_py_in, 'r') as f: + input_header = f.read() + with open(output_return_types_py_path, 'w') as f: + f.write(input_header + generated_namedtuple_strs) + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) \ No newline at end of file diff --git a/cpp/pybind/make_python_package.cmake b/cpp/pybind/make_python_package.cmake index aa4171414b0..204b12a5427 100644 --- a/cpp/pybind/make_python_package.cmake +++ b/cpp/pybind/make_python_package.cmake @@ -65,7 +65,7 @@ configure_file("${PYTHON_PACKAGE_SRC_DIR}/../cpp/open3d/visualization/webrtc_ser file(COPY "${PYTHON_COMPILED_MODULE_DIR}/_build_config.py" DESTINATION "${PYTHON_PACKAGE_DST_DIR}/open3d/") -if (BUILD_TENSORFLOW_OPS OR BUILD_PYTORCH_OPS) +if (BUILD_TENSORFLOW_OPS OR BUILD_PYTORCH_OPS OR BUILD_PADDLE_OPS) # copy generated files file(COPY "${PYTHON_PACKAGE_DST_DIR}/../ml" DESTINATION "${PYTHON_PACKAGE_DST_DIR}/open3d/" ) diff --git a/docker/Dockerfile.paddle b/docker/Dockerfile.paddle new file mode 100644 index 00000000000..1a4d71334d6 --- /dev/null +++ b/docker/Dockerfile.paddle @@ -0,0 +1,83 @@ +# FROM must be called before other ARGS except for ARG BASE_IMAGE +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +# Customizable build arguments from cuda.yml +ARG DEVELOPER_BUILD +ARG CMAKE_VERSION +ARG PYTHON_VERSION +ARG BUILD_PADDLE_OPS + +# Forward all ARG to ENV +# ci_utils.sh requires these environment variables +ENV DEVELOPER_BUILD=${DEVELOPER_BUILD} +ENV CMAKE_VERSION=${CMAKE_VERSION} +ENV PYTHON_VERSION=${PYTHON_VERSION} +ENV BUILD_PYTORCH_OPS=OFF +ENV BUILD_TENSORFLOW_OPS=OFF +ENV BUILD_PADDLE_OPS=${BUILD_PADDLE_OPS} + +# Prevent interactive inputs when installing packages +ENV DEBIAN_FRONTEND=noninteractive +ENV TZ=America/Los_Angeles +ENV SUDO=command + +# Miniconda requires bash as the default shell. +SHELL ["/bin/bash", "-c"] + +# Dependencies: basic +RUN apt-get update && apt-get install -y \ + git \ + wget \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Dependencies: cmake +RUN CMAKE_VERSION_NUMBERS=$(echo "${CMAKE_VERSION}" | cut -d"-" -f2) \ + && wget -q https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION_NUMBERS}/${CMAKE_VERSION}.tar.gz \ + && tar -xf ${CMAKE_VERSION}.tar.gz \ + && cp -ar ${CMAKE_VERSION} ${HOME} +ENV PATH=${HOME}/${CMAKE_VERSION}/bin:${PATH} + +# Dependencies: gcc +RUN rm /usr/bin/gcc && rm /usr/bin/g++ \ + && mv /usr/bin/gcc.bak /usr/bin/gcc \ + && mv /usr/bin/g++.bak /usr/bin/g++ +ENV PATH=/usr/bin:${PATH} + +# Miniconda +ENV PATH="/root/miniconda3/bin:${PATH}" +RUN wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + && bash Miniconda3-latest-Linux-x86_64.sh -b \ + && rm Miniconda3-latest-Linux-x86_64.sh \ + && conda --version +ENV PATH="/root/miniconda3/envs/open3d/bin:${PATH}" +RUN conda create -y -n open3d python=${PYTHON_VERSION} \ + && source activate open3d +RUN which python \ + && python --version + +# Open3D C++ dependencies +# Done before copying the full Open3D directory for better Docker caching +COPY ./util/install_deps_ubuntu.sh /root/Open3D/util/ +RUN /root/Open3D/util/install_deps_ubuntu.sh assume-yes \ + && rm -rf /var/lib/apt/lists/* + +# Open3D Python dependencies +COPY ./util/ci_utils.sh /root/Open3D/util/ +COPY ./python/requirements.txt /root/Open3D/python/ +RUN source /root/Open3D/util/ci_utils.sh \ + && install_python_dependencies + +# Open3D repo +# Always keep /root/Open3D as the WORKDIR +COPY . /root/Open3D +WORKDIR /root/Open3D + +# Build python wheel +RUN export NPROC=$(nproc) \ + && export BUILD_SHARED_LIBS=OFF \ + && source /root/Open3D/util/ci_utils.sh \ + && build_pip_package build_azure_kinect + +RUN echo "Docker build done." diff --git a/docker/README.md b/docker/README.md index cd3a8198579..752d28ba439 100644 --- a/docker/README.md +++ b/docker/README.md @@ -65,4 +65,4 @@ cd docker ./docker_test.sh openblas-amd64-py38-dev ``` -See `./docker_build.sh` and `./docker_test.sh` for all available options. +See `./docker_build.sh` and `./docker_test.sh` for all available options. \ No newline at end of file diff --git a/docker/docker_build.sh b/docker/docker_build.sh index 5ff0a1a0942..b4afaddde2f 100755 --- a/docker/docker_build.sh +++ b/docker/docker_build.sh @@ -27,65 +27,62 @@ OPTION: openblas-amd64-py39-dev : OpenBLAS AMD64 3.9 wheel, developer mode openblas-amd64-py310-dev : OpenBLAS AMD64 3.10 wheel, developer mode openblas-amd64-py311-dev : OpenBLAS AMD64 3.11 wheel, developer mode - openblas-amd64-py312-dev : OpenBLAS AMD64 3.12 wheel, developer mode openblas-amd64-py38 : OpenBLAS AMD64 3.8 wheel, release mode openblas-amd64-py39 : OpenBLAS AMD64 3.9 wheel, release mode openblas-amd64-py310 : OpenBLAS AMD64 3.10 wheel, release mode openblas-amd64-py311 : OpenBLAS AMD64 3.11 wheel, release mode - openblas-amd64-py312 : OpenBLAS AMD64 3.12 wheel, release mode # OpenBLAS ARM64 (Dockerfile.openblas) openblas-arm64-py38-dev : OpenBLAS ARM64 3.8 wheel, developer mode openblas-arm64-py39-dev : OpenBLAS ARM64 3.9 wheel, developer mode openblas-arm64-py310-dev : OpenBLAS ARM64 3.10 wheel, developer mode openblas-arm64-py311-dev : OpenBLAS ARM64 3.11 wheel, developer mode - openblas-arm64-py312-dev : OpenBLAS ARM64 3.12 wheel, developer mode openblas-arm64-py38 : OpenBLAS ARM64 3.8 wheel, release mode openblas-arm64-py39 : OpenBLAS ARM64 3.9 wheel, release mode openblas-arm64-py310 : OpenBLAS ARM64 3.10 wheel, release mode openblas-arm64-py311 : OpenBLAS ARM64 3.11 wheel, release mode - openblas-arm64-py312 : OpenBLAS ARM64 3.12 wheel, release mode # Ubuntu CPU CI (Dockerfile.ci) - cpu-static : Ubuntu CPU static - cpu-shared : Ubuntu CPU shared (cxx11_abi) - cpu-shared-release : Ubuntu CPU shared (cxx11_abi), release mode - cpu-shared-ml : Ubuntu CPU shared with ML (pre_cxx11_abi) - cpu-shared-ml-release : Ubuntu CPU shared with ML (pre_cxx11_abi), release mode + cpu-static : Ubuntu CPU static + cpu-shared : Ubuntu CPU shared (cxx11_abi) + cpu-shared-release : Ubuntu CPU shared (cxx11_abi), release mode + cpu-shared-ml : Ubuntu CPU shared with ML (pre_cxx11_abi) + cpu-shared-ml-release : Ubuntu CPU shared with ML (pre_cxx11_abi), release mode # Sycl CPU CI (Dockerfile.ci) - sycl-shared : SYCL (oneAPI) with shared lib - sycl-static : SYCL (oneAPI) with static lib + sycl-shared : SYCL (oneAPI) with shared lib + sycl-static : SYCL (oneAPI) with static lib # ML CIs (Dockerfile.ci) - 2-focal : CUDA CI, 2-bionic, developer mode - 3-ml-shared-focal-release : CUDA CI, 3-ml-shared-bionic (pre_cxx11_abi), release mode - 3-ml-shared-focal : CUDA CI, 3-ml-shared-bionic (pre_cxx11_abi), developer mode - 4-shared-focal : CUDA CI, 4-shared-bionic (cxx11_abi), developer mode - 4-shared-focal-release : CUDA CI, 4-shared-bionic (cxx11_abi), release mode - 5-ml-jammy : CUDA CI, 5-ml-focal, developer mode + 2-bionic : CUDA CI, 2-bionic, developer mode + 3-ml-shared-bionic-release : CUDA CI, 3-ml-shared-bionic (pre_cxx11_abi), release mode + 3-ml-shared-bionic : CUDA CI, 3-ml-shared-bionic (pre_cxx11_abi), developer mode + 4-shared-bionic : CUDA CI, 4-shared-bionic (cxx11_abi), developer mode + 4-shared-bionic-release : CUDA CI, 4-shared-bionic (cxx11_abi), release mode + 5-ml-focal : CUDA CI, 5-ml-focal, developer mode # CUDA wheels (Dockerfile.wheel) cuda_wheel_py38_dev : CUDA Python 3.8 wheel, developer mode cuda_wheel_py39_dev : CUDA Python 3.9 wheel, developer mode cuda_wheel_py310_dev : CUDA Python 3.10 wheel, developer mode cuda_wheel_py311_dev : CUDA Python 3.11 wheel, developer mode - cuda_wheel_py312_dev : CUDA Python 3.12 wheel, developer mode cuda_wheel_py38 : CUDA Python 3.8 wheel, release mode cuda_wheel_py39 : CUDA Python 3.9 wheel, release mode cuda_wheel_py310 : CUDA Python 3.10 wheel, release mode cuda_wheel_py311 : CUDA Python 3.11 wheel, release mode - cuda_wheel_py312 : CUDA Python 3.12 wheel, release mode + + # Paddle wheels (Dockerfile.paddle) + paddle_cuda_wheel_py310_dev : CUDA Python 3.10 wheel, developer mode + paddle_cuda_wheel_py310 : CUDA Python 3.10 wheel, release mode " HOST_OPEN3D_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")"/.. >/dev/null 2>&1 && pwd)" # Shared variables CCACHE_VERSION=4.3 -CMAKE_VERSION=cmake-3.29.2-linux-x86_64 -CMAKE_VERSION_AARCH64=cmake-3.24.4-linux-aarch64 -CUDA_VERSION=12.1.0-cudnn8 -CUDA_VERSION_LATEST=12.1.0-cudnn8 +CMAKE_VERSION=cmake-3.20.6-linux-x86_64 +CMAKE_VERSION_AARCH64=cmake-3.20.6-linux-aarch64 +CUDA_VERSION=11.7.1-cudnn8 print_usage_and_exit_docker_build() { echo "$__usage_docker_build" @@ -230,6 +227,45 @@ cuda_wheel_build() { && chown $(id -u):$(id -g) /opt/mount/${CCACHE_TAR_NAME}.tar.xz" } +paddle_cuda_wheel_build() { + BASE_IMAGE=registry.baidubce.com/paddlepaddle/paddle:3.0.0b1-gpu-${PADDLE_CUDA_VERSION} + BUILD_PADDLE_OPS=ON + + options="$(echo "$@" | tr ' ' '|')" + echo "[cuda_wheel_build()] options: ${options}" + if [[ "py310" =~ ^($options)$ ]]; then + PYTHON_VERSION=3.10 + else + echo "Invalid python version." + print_usage_and_exit_docker_build + fi + if [[ "dev" =~ ^($options)$ ]]; then + DEVELOPER_BUILD=ON + else + DEVELOPER_BUILD=OFF + fi + echo "[paddle_cuda_wheel_build()] PYTHON_VERSION: ${PYTHON_VERSION}" + echo "[paddle_cuda_wheel_build()] DEVELOPER_BUILD: ${DEVELOPER_BUILD}" + echo "[paddle_cuda_wheel_build()] BUILD_PADDLE_OPS=${BUILD_PADDLE_OPS:?'env var must be set.'}" + + pushd "${HOST_OPEN3D_ROOT}" + docker build \ + --progress plain \ + --build-arg BASE_IMAGE="${BASE_IMAGE}" \ + --build-arg DEVELOPER_BUILD="${DEVELOPER_BUILD}" \ + --build-arg CMAKE_VERSION="${CMAKE_VERSION}" \ + --build-arg PYTHON_VERSION="${PYTHON_VERSION}" \ + --build-arg BUILD_PADDLE_OPS="${BUILD_PADDLE_OPS}" \ + -t open3d-ci:wheel \ + -f docker/Dockerfile.paddle . + popd + + python_package_dir=/root/Open3D/build/lib/python_package + docker run -v "${PWD}:/opt/mount" --rm open3d-ci:wheel \ + bash -c "cp ${python_package_dir}/pip_package/open3d*.whl /opt/mount \ + && chown $(id -u):$(id -g) /opt/mount/open3d*.whl" +} + ci_build() { echo "[ci_build()] DOCKER_TAG=${DOCKER_TAG}" echo "[ci_build()] BASE_IMAGE=${BASE_IMAGE}" @@ -636,9 +672,6 @@ function main() { cuda_wheel_py311) cuda_wheel_build py311 ;; - cuda_wheel_py312) - cuda_wheel_build py312 - ;; # ML CIs 2-focal) diff --git a/docs/documented_modules.txt b/docs/documented_modules.txt index 62748dfcd91..a521b7b1e66 100644 --- a/docs/documented_modules.txt +++ b/docs/documented_modules.txt @@ -45,6 +45,10 @@ open3d.ml.torch.modules open3d.ml.torch.modules.losses open3d.ml.torch.modules.metrics open3d.ml.torch.pipelines +open3d.ml.paddle +open3d.ml.paddle.layers +open3d.ml.paddle.ops +open3d.ml.paddle.classes open3d.pipelines open3d.pipelines.color_map open3d.pipelines.integration diff --git a/docs/make_docs.py b/docs/make_docs.py index 7f691afd81c..ffa58548b3a 100644 --- a/docs/make_docs.py +++ b/docs/make_docs.py @@ -103,6 +103,8 @@ def _try_import_module(self, full_module_name): import open3d.ml.tf if open3d._build_config['BUILD_PYTORCH_OPS']: import open3d.ml.torch + if open3d._build_config['BUILD_PADDLE_OPS']: + import open3d.ml.paddle try: # Try to import directly. This will work for pure python submodules @@ -139,7 +141,8 @@ def _generate_class_doc(self, full_module_name, class_name, output_path): out_string += "\n :members:" out_string += "\n :undoc-members:" if not (full_module_name.startswith("open3d.ml.tf") or - full_module_name.startswith("open3d.ml.torch")): + full_module_name.startswith("open3d.ml.torch") or + full_module_name.startswith("open3d.ml.paddle")): out_string += "\n :inherited-members:" out_string += "\n" diff --git a/docs/python_api_in/open3d.ml.paddle.layers.rst b/docs/python_api_in/open3d.ml.paddle.layers.rst new file mode 100644 index 00000000000..4bd5af4a953 --- /dev/null +++ b/docs/python_api_in/open3d.ml.paddle.layers.rst @@ -0,0 +1,29 @@ +open3d.ml.paddle.layers +---------------------- + +.. currentmodule:: open3d.ml.paddle.layers + +.. automodule:: open3d.ml.paddle.layers + +**Classes** + +.. autosummary:: + + ContinuousConv + FixedRadiusSearch + KNNSearch + RadiusSearch + SparseConv + SparseConvTranspose + VoxelPooling + +.. toctree:: + :hidden: + + ContinuousConv + FixedRadiusSearch + KNNSearch + RadiusSearch + SparseConv + SparseConvTranspose + VoxelPooling diff --git a/docs/python_api_in/open3d.ml.paddle.ops.rst b/docs/python_api_in/open3d.ml.paddle.ops.rst new file mode 100644 index 00000000000..0a027448cd1 --- /dev/null +++ b/docs/python_api_in/open3d.ml.paddle.ops.rst @@ -0,0 +1,38 @@ +open3d.ml.paddle.ops +------------------- + +.. currentmodule:: open3d.ml.paddle.ops + +.. automodule:: open3d.ml.paddle.ops + +**Functions** + +.. autosummary:: + + build_spatial_hash_table + continuous_conv + continuous_conv_transpose + fixed_radius_search + invert_neighbors_list + knn_search + nms + radius_search + reduce_subarrays_sum + voxel_pooling + voxelize + +.. toctree:: + :hidden: + + build_spatial_hash_table + continuous_conv + continuous_conv_transpose + fixed_radius_search + invert_neighbors_list + knn_search + nms + radius_search + reduce_subarrays_sum + voxel_pooling + voxelize + diff --git a/python/open3d/ml/paddle/__init__.py b/python/open3d/ml/paddle/__init__.py new file mode 100644 index 00000000000..260429d1d20 --- /dev/null +++ b/python/open3d/ml/paddle/__init__.py @@ -0,0 +1,33 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +from packaging.version import parse as _verp +import paddle as _paddle +from open3d import _build_config + +if not _build_config["Paddle_VERSION"]: + raise Exception('Open3D was not built with Paddle support!') +_o3d_paddle_version = _verp(_build_config["Paddle_VERSION"]) +# Check match with Paddle version, any patch level is OK +if _verp(_paddle.__version__).release[:2] != _o3d_paddle_version.release[:2]: + match_paddle_ver = '.'.join( + str(v) for v in _o3d_paddle_version.release[:2] + ('*',)) + raise Exception('Version mismatch: Open3D needs Paddle version {}, but ' + 'version {} is installed!'.format(match_paddle_ver, + _paddle.__version__)) + +_loaded = False +try: + from . import ops + _loaded = True +except Exception as e: + raise e + +from . import layers +from . import classes + +# put contrib at the same level +from open3d.ml import contrib diff --git a/python/open3d/ml/paddle/classes/__init__.py b/python/open3d/ml/paddle/classes/__init__.py new file mode 100644 index 00000000000..a321fa69f87 --- /dev/null +++ b/python/open3d/ml/paddle/classes/__init__.py @@ -0,0 +1,22 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +"""Paddle specific machine learning classes.""" +import paddle + +from .ragged_tensor import RaggedTensor + +DTYPE_MAP = { + paddle.bool: 'bool', + paddle.float16: 'float16', + paddle.float32: 'float32', + paddle.float64: 'float64', + paddle.int8: 'int8', + paddle.int16: 'int16', + paddle.int32: 'int32', + paddle.int64: 'int64', + paddle.bfloat16: 'uint16', +} diff --git a/python/open3d/ml/paddle/classes/ragged_tensor.py b/python/open3d/ml/paddle/classes/ragged_tensor.py new file mode 100644 index 00000000000..ecc6ad5b443 --- /dev/null +++ b/python/open3d/ml/paddle/classes/ragged_tensor.py @@ -0,0 +1,199 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +import paddle +import numpy as np + +__all__ = ['RaggedTensor'] + + +class RaggedTensor: + + def __init__(self, values, row_splits, internal=False): + if not internal: + raise ValueError( + "RaggedTensor constructor is private, please use one of the factory method instead(e.g. RaggedTensor.from_row_splits())" + ) + self._values = values + self._row_splits = row_splits + + @classmethod + def _from_row_splits(cls, values, row_splits, validate=True): + if row_splits.dtype != paddle.int64: + raise ValueError("row_splits must have type paddle.int64") + + values = values.contiguous() + row_splits = row_splits.contiguous() + + if validate: + if len(row_splits.shape) != 1: + raise ValueError("row_splits must be of rank 1") + if row_splits[0] != 0: + raise ValueError( + f"Arguments to from_row_splits do not form a valid RaggedTensor. Expect row_splits[0] == 0 but received row_splits[0] == {row_splits[0]}." + ) + for i in range(0, row_splits.shape[0] - 1): + if row_splits[i] > row_splits[i + 1]: + raise ValueError( + "row_splits must be monotonically increasing") + + row_splits = row_splits.to(values.place) + + return values, row_splits + + @classmethod + def from_row_splits(cls, values, row_splits, validate=True, copy=True): + + if isinstance(values, list): + values = paddle.to_tensor(values, dtype=paddle.float64) + elif isinstance(values, np.ndarray): + values = paddle.to_tensor(values) + elif isinstance(values, paddle.Tensor) and copy: + values = values.clone() + + if isinstance(row_splits, list): + row_splits = paddle.to_tensor(row_splits, dtype=paddle.int64) + elif isinstance(row_splits, np.ndarray): + row_splits = paddle.to_tensor(row_splits) + elif isinstance(row_splits, paddle.Tensor) and copy: + row_splits = row_splits.clone() + + values, row_splits = cls._from_row_splits(values, row_splits, validate) + + return cls(values, row_splits, internal=True) + + @property + def values(self): + """The concatenated rows for this ragged tensor.""" + return self._values + + @property + def row_splits(self): + """The row-split indices for this ragged tensor's `values`.""" + return self._row_splits + + @property + def dtype(self): + """The `DType` of values in this ragged tensor.""" + return self._values.dtype + + @property + def device(self): + """The device of values in this ragged tensor.""" + return self._values.place + + @property + def shape(self): + """The statically known shape of this ragged tensor.""" + return [ + len(self._row_splits.shape[0] - 1), None, *self._values.shape[1:] + ] + + @property + def requires_grad(self): + """Read/writeble `requires_grad` for values.""" + return not self._values.stop_gradient + + @requires_grad.setter + def requires_grad(self, value): + # NOTE: stop_gradient=True means not requires grad + self._values.stop_gradient = not value + + def clone(self): + """Returns a clone of object.""" + return self.__class__(self._values.clone(), self._row_splits.clone(), + True) + + def to_list(self): + """Returns a list of tensors""" + return [tensor for tensor in self._values] + + def __getitem__(self, idx): + return self._values.slice([ + 0, + ], [ + self._row_splits[idx], + ], [ + self._row_splits[idx + 1], + ]) + + def __repr__(self): + return f"RaggedTensor(values={self._values}, row_splits={self._row_splits})" + + def __len__(self): + return len(self._row_splits.shape[0] - 1) + + def __add__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values + self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values, row_splits, True) + + def __iadd__(self, other): + paddle.assign(self._values + self.__convert_to_tensor(other), + self._values) + return self + + def __sub__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values - self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __isub__(self, other): + paddle.assign(self._values - self.__convert_to_tensor(other), + self._values) + return self + + def __mul__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values * self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __imul__(self, other): + paddle.assign(self._values * self.__convert_to_tensor(other), + self._values) + return self + + def __truediv__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values / self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __itruediv__(self, other): + paddle.assign(self._values / self.__convert_to_tensor(other), + self._values) + return self + + def __floordiv__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values // self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __ifloordiv__(self, other): + paddle.assign(self._values // self.__convert_to_tensor(other), + self._values) + return self + + def __convert_to_tensor(self, value): + """Converts scalar/tensor/RaggedTensor to paddle.Tensor""" + if isinstance(value, RaggedTensor): + if self._row_splits.shape != value.row_splits.shape or paddle.any( + self._row_splits != value.row_splits).item(): + raise ValueError( + f"Incompatible shape : {self._row_splits} and {value.row_splits}" + ) + return value.values + elif isinstance(value, paddle.Tensor): + return value + elif isinstance(value, (int, float, bool)): + return paddle.to_tensor([value], dtype=type(value)) + else: + raise ValueError(f"Unknown type : {type(value)}") diff --git a/python/open3d/ml/paddle/layers/__init__.py b/python/open3d/ml/paddle/layers/__init__.py new file mode 100644 index 00000000000..368c90f5e43 --- /dev/null +++ b/python/open3d/ml/paddle/layers/__init__.py @@ -0,0 +1,14 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +"""High level layer API for building networks. + +This module contains layers for processing 3D data. +All layers subclass paddle.nn.Layer +""" +from ..python.layers.neighbor_search import * +from ..python.layers.convolutions import * +from ..python.layers.voxel_pooling import * diff --git a/python/open3d/ml/paddle/ops/__init__.py b/python/open3d/ml/paddle/ops/__init__.py new file mode 100644 index 00000000000..2e9336fd76e --- /dev/null +++ b/python/open3d/ml/paddle/ops/__init__.py @@ -0,0 +1,79 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +"""Functional API with operators. + +These are the building blocks for the layers. See The layer API for an easy to +use high level interface. +""" +import os as _os +import sys as _sys +import types as _types +import importlib as _importlib +import importlib.abc as _importlib_abc +import importlib.util as _importlib_util +import paddle as _paddle +from open3d import _build_config + +from ..python.ops import * + +_lib_path = [] +# allow overriding the path to the op library with an env var. +if 'OPEN3D_PADDLE_OP_LIB' in _os.environ: + _lib_path.append(_os.environ['OPEN3D_PADDLE_OP_LIB']) + +_this_dir = _os.path.dirname(__file__) +_package_root = _os.path.join(_this_dir, '..', '..', '..') +_lib_ext = {'linux': '.so', 'darwin': '.dylib', 'win32': '.dll'}[_sys.platform] +_lib_suffix = '_debug' if _build_config['CMAKE_BUILD_TYPE'] == 'Debug' else '' +_lib_arch = ('cpu',) +if _build_config["BUILD_CUDA_MODULE"] and _paddle.device.cuda.device_count( +) >= 1: + if _paddle.version.cuda() == _build_config["CUDA_VERSION"]: + _lib_arch = ('cuda', 'cpu') + else: + print("Warning: Open3D was built with CUDA {} but" + "Paddle was built with CUDA {}. Falling back to CPU for now." + "Otherwise, install Paddle with CUDA {}.".format( + _build_config["CUDA_VERSION"], _paddle.version.cuda(), + _build_config["CUDA_VERSION"])) +_lib_path.extend([ + _os.path.join(_package_root, la, + 'open3d_paddle_ops' + _lib_suffix + _lib_ext) + for la in _lib_arch +]) + +_loaded_lib = False +_loaded_except = None +for _lp in _lib_path: + try: + _load_lib_path = _lp + # load custom op shared library with abs path + _custom_ops = _paddle.utils.cpp_extension.load_op_meta_info_and_register_op( + _load_lib_path) + _loaded_lib = True + + break + + except Exception as e: + _loaded_except = e + + if not _os.path.isfile(_lp): + print('The op library at "{}" was not found. Make sure that ' + 'BUILD_PADDLE_OPS was enabled.'.format( + _os.path.realpath(_lp))) + +if not _loaded_lib: + raise _loaded_except + +try: + _spec = _importlib_util.spec_from_file_location(__name__, _load_lib_path) + assert _spec is not None + _mod = _importlib_util.module_from_spec(_spec) + assert isinstance(_spec.loader, _importlib_abc.Loader) + _spec.loader.exec_module(_mod) +except ImportError: + _mod = _types.ModuleType(__name__) diff --git a/python/open3d/ml/paddle/python/layers/convolutions.py b/python/open3d/ml/paddle/python/layers/convolutions.py new file mode 100644 index 00000000000..59eade8aba8 --- /dev/null +++ b/python/open3d/ml/paddle/python/layers/convolutions.py @@ -0,0 +1,742 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +from ...python import ops +from ....paddle import classes +from .neighbor_search import FixedRadiusSearch, RadiusSearch +import paddle +from paddle import create_parameter +import numpy as np + +__all__ = ['ContinuousConv', 'SparseConv', 'SparseConvTranspose'] + + +class ContinuousConv(paddle.nn.Layer): + r"""Continuous Convolution. + + This convolution supports continuous input and output point positions. + This layer implements the convolution defined in + + *B. Ummenhofer and V. Koltun, Lagrangian Fluid Simulation with Continuous Convolutions, ICLR 2020.* + + The convolution at position :math:`\mathbf x` is defined as + + .. math:: + (f*g)(\mathbf x) = \frac{1}{\psi(\mathbf x)} \sum_{i \in \mathcal N(\mathbf x, R)} a(\mathbf x_i, \mathbf x)\; f_i\; g(\Lambda(\mathbf x_i - \mathbf x)). + + With :math:`f` as the input feature function and :math:`g` as the filter function. + The input points are :math:`\mathbf x_i` and the input features are :math:`f_i`. + The normalization :math:`\frac{1}{\psi(\mathbf x)}` can be turned on with the **normalize** parameter. + The per neighbor value :math:`a(\mathbf x_i, \mathbf x)` can be used to implement window functions; see parameter **window_function**. + The function :math:`\Lambda` for looking up filter values is defined by the parameters **coordinate_mapping** and **interpolation**. + + Example: + This shows a minimal example of how to use the layer:: + + import paddle + import open3d.ml.paddle as ml3d + + inp_positions = paddle.randn([20, 3]) + inp_features = paddle.randn([20, 8]) + out_positions = paddle.randn([10, 3]) + + conv = ml3d.layers.ContinuousConv(in_channels=8, filters=16, kernel_size=[3,3,3]) + out_features = conv(inp_features, inp_positions, out_positions, extents=2.0) + + + Arguments: + in_channels: The number of input channels. + + filters: The number of filters/output channels. + + kernel_size: The spatial resolution of the filter, e.g. [3,3,3]. + + activation: The activation function to use. None means no activation. + + use_bias: If True adds an additive bias vector. + + kernel_initializer: Initializer for the kernel weights. + + bias_initializer: Initializer for the bias vector. + + align_corners: If true then the voxel centers of the outer voxels of the + filter array are mapped to the boundary of the filter shape. + If false then the boundary of the filter array is mapped to the + boundary of the filter shape. + + coordinate_mapping: The mapping that is applied to the input coordinates. + One of 'ball_to_cube_radial', 'ball_to_cube_volume_preserving', + 'identity'. + + * 'ball_to_cube_radial' uses radial stretching to map a sphere to + a cube. + * 'ball_to_cube_volume_preserving' is using a more expensive volume + preserving mapping to map a sphere to a cube. + * 'identity' no mapping is applied to the coordinates. + + interpolation: One of 'linear', 'linear_border', 'nearest_neighbor'. + * 'linear' is trilinear interpolation with coordinate clamping. + * 'linear_border' uses a zero border if outside the range. + * 'nearest_neighbor' uses the nearest neighbor instead of interpolation. + + normalize: If true then the result is normalized either by the number of + points (neighbors_importance is null) or by the sum of the respective + values in neighbors_importance. + + radius_search_ignore_query_points: If true the points that coincide with the + center of the search window will be ignored. This excludes the query point + if 'queries' and 'points' are the same point cloud. + + radius_search_metric: Either L1, L2 or Linf. Default is L2 + + offset: A single 3D vector used in the filter coordinate computation. + The shape is [3]. + + window_function: Optional radial window function to steer the importance of + points based on their distance to the center. The input to the function + is a 1D tensor of distances (squared distances if radius_search_metric is + 'L2'). The output must be a tensor of the same shape. Example:: + + def window_fn(r_sqr): + return paddle.clamp((1-r_sqr)**3, 0, 1) + + use_dense_layer_for_center: If True a linear dense layer is used to + process the input features for each point. The result is added to the + result of the convolution before adding the bias. This option is + useful when using even kernel sizes that have no center element and + input and output point sets are the same and + 'radius_search_ignore_query_points' has been set to True. + + dense_kernel_initializer: Initializer for the kernel weights of the + linear layer used for the center if 'use_dense_layer_for_center' + is True. + """ + + def __init__(self, + in_channels, + filters, + kernel_size, + activation=None, + use_bias=True, + kernel_initializer=paddle.nn.initializer.Uniform(-0.05, 0.05), + bias_initializer=paddle.nn.initializer.Constant(), + align_corners=True, + coordinate_mapping='ball_to_cube_radial', + interpolation='linear', + normalize=True, + radius_search_ignore_query_points=False, + radius_search_metric='L2', + offset=None, + window_function=None, + use_dense_layer_for_center=False, + dense_kernel_initializer=paddle.nn.initializer.XavierUniform(), + **kwargs): + super().__init__() + + self.in_channels = in_channels + self.filters = filters + self.kernel_size = kernel_size + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.align_corners = align_corners + self.coordinate_mapping = coordinate_mapping + self.interpolation = interpolation + self.normalize = normalize + self.radius_search_ignore_query_points = radius_search_ignore_query_points + self.radius_search_metric = radius_search_metric + self.dense_kernel_initializer = dense_kernel_initializer + + if offset is None: + offset = paddle.zeros(shape=(3,), dtype=paddle.float32) + self.register_buffer('offset', offset) + + self.window_function = window_function + + self.fixed_radius_search = FixedRadiusSearch( + metric=self.radius_search_metric, + ignore_query_point=self.radius_search_ignore_query_points, + return_distances=not self.window_function is None) + + self.radius_search = RadiusSearch( + metric=self.radius_search_metric, + ignore_query_point=self.radius_search_ignore_query_points, + return_distances=not self.window_function is None, + normalize_distances=not self.window_function is None) + + self.use_dense_layer_for_center = use_dense_layer_for_center + if self.use_dense_layer_for_center: + self.dense = paddle.nn.Linear(self.in_channels, + self.filters, + bias=False) + self.dense_kernel_initializer(self.dense.weight) + + kernel_shape = (*self.kernel_size, self.in_channels, self.filters) + # self.kernel = paddle.nn.Parameter(data=paddle.Tensor(*kernel_shape), + # requires_grad=True) + self.kernel = create_parameter( + kernel_shape, + dtype=paddle.float32, + default_initializer=self.kernel_initializer) + + if self.use_bias: + self.bias = create_parameter((self.filters,), + dtype=paddle.float32, + is_bias=True, + default_initializer=bias_initializer) + + def forward(self, + inp_features, + inp_positions, + out_positions, + extents, + inp_importance=None, + fixed_radius_search_hash_table=None, + user_neighbors_index=None, + user_neighbors_row_splits=None, + user_neighbors_importance=None): + """This function computes the output features. + + Arguments: + inp_features: A 2D tensor which stores a feature vector for each input + point. + + inp_positions: A 2D tensor with the 3D point positions of each input + point. The coordinates for each point is a vector with format [x,y,z]. + + out_positions: A 2D tensor with the 3D point positions of each output + point. The coordinates for each point is a vector with format [x,y,z]. + + extents: The extent defines the spatial size of the filter for each + output point. + For 'ball to cube' coordinate mappings the extent defines the + bounding box of the ball. + The shape of the tensor is either [1] or [num output points]. + + inp_importance: Optional scalar importance value for each input point. + + fixed_radius_search_hash_table: A precomputed hash table generated with + build_spatial_hash_table(). + This input can be used to explicitly force the reuse of a hash table in + special cases and is usually not needed. + Note that the hash table must have been generated with the same 'points' + array. Note that this parameter is only used if 'extents' is a scalar. + + user_neighbors_index: This parameter together with 'user_neighbors_row_splits' + and 'user_neighbors_importance' allows to override the automatic neighbor + search. This is the list of neighbor indices for each output point. + This is a nested list for which the start and end of each sublist is + defined by 'user_neighbors_row_splits'. + + user_neighbors_row_splits: Defines the start and end of each neighbors + list in 'user_neighbors_index'. + + user_neighbors_importance: Defines a scalar importance value for each + element in 'user_neighbors_index'. + + + Returns: A tensor of shape [num output points, filters] with the output + features. + """ + offset = self.offset + if isinstance(extents, (float, int)): + extents = paddle.to_tensor(extents, dtype=inp_positions.dtype) + + if inp_importance is None: + inp_importance = paddle.empty( + (0,), dtype=paddle.float32).to(self.kernel.place) + + + if not user_neighbors_index is None and not user_neighbors_row_splits is None: + + if user_neighbors_importance is None: + neighbors_importance = paddle.empty( + (0,), dtype=paddle.float32).to(self.kernel.place) + else: + neighbors_importance = user_neighbors_importance + + neighbors_index = user_neighbors_index + neighbors_row_splits = user_neighbors_row_splits + + else: + if len(extents.shape) == 0: + radius = 0.5 * extents + self.nns = self.fixed_radius_search( + inp_positions, + queries=out_positions, + radius=radius, + hash_table=fixed_radius_search_hash_table) + + elif len(extents.shape) == 1: + radii = 0.5 * extents + self.nns = self.radius_search(inp_positions, + queries=out_positions, + radii=radii) + + else: + raise ValueError("extents rank must be 0 or 1") + + if self.window_function is None: + neighbors_importance = paddle.empty((0,), dtype=paddle.float32) + else: + if self.radius_search_metric == 'L2': + neighbors_distance_normalized = self.nns.neighbors_distance / ( + radius * radius) + else: # L1 + neighbors_distance_normalized = self.nns.neighbors_distance / radius + neighbors_importance = self.window_function( + neighbors_distance_normalized) + + neighbors_index = self.nns.neighbors_index + neighbors_row_splits = self.nns.neighbors_row_splits + + # for stats and debugging + num_pairs = neighbors_index.shape[0] + self._avg_neighbors = num_pairs / out_positions.shape[0] + + extents_rank2 = extents + while len(extents_rank2.shape) < 2: + extents_rank2 = paddle.unsqueeze(extents_rank2, axis=-1) + + self._conv_values = { + 'filters': self.kernel, + 'out_positions': out_positions, + 'extents': extents_rank2, + 'offset': offset, + 'inp_positions': inp_positions, + 'inp_features': inp_features, + 'inp_importance': inp_importance, + 'neighbors_index': neighbors_index, + 'neighbors_row_splits': neighbors_row_splits, + 'neighbors_importance': neighbors_importance, + 'align_corners': self.align_corners, + 'coordinate_mapping': self.coordinate_mapping, + 'interpolation': self.interpolation, + 'normalize': self.normalize, + 'max_temp_mem_mb': 64 + } + + out_features = ops.continuous_conv(**self._conv_values) + + self._conv_output = out_features + + if self.use_dense_layer_for_center: + self._dense_output = self.dense(inp_features) + out_features = out_features + self._dense_output + + if self.use_bias: + out_features += self.bias + if not self.activation is None: + out_features = self.activation(out_features) + + return out_features + + +class SparseConv(paddle.nn.Layer): + """Sparse Convolution. + + This layer computes a convolution which is only evaluated at the specified output positions. + The layer assumes that input and output points lie on a regular grid. + + Example: + This shows a minimal example of how to use the layer:: + + import paddle + import open3d.ml.paddle as ml3d + + # +0.5 to move the points to the voxel center + inp_positions = paddle.randint(0, 10, [20,3]).to(paddle.float32) + 0.5 + inp_features = paddle.randn([20,8]) + out_positions = paddle.randint(0, 10, [20,3]).to(paddle.float32) + 0.5 + + conv = ml3d.layers.SparseConv(in_channels=8, filters=16, kernel_size=[3,3,3]) + out_features = conv(inp_features, inp_positions, out_positions, voxel_size=1.0) + + + Arguments: + in_channels: The number of input channels. + + filters: The number of filters/output channels. + + kernel_size: The spatial resolution of the filter, e.g. [3,3,3]. + + activation: The activation function to use. None means no activation. + + use_bias: If True adds an additive bias vector. + + kernel_initializer: Initializer for the kernel weights. + + bias_initializer: Initializer for the bias vector. + + normalize: If true then the result is normalized by the number of input points. + + offset: A single 3D vector used in the filter coordinate computation. + The shape is [3]. This can be used to control how the filters are + centered. It will be set automatically for kernels with even sizes. + """ + + def __init__(self, + in_channels, + filters, + kernel_size, + activation=None, + use_bias=True, + kernel_initializer=paddle.nn.initializer.Uniform(-0.05, 0.05), + bias_initializer=paddle.nn.initializer.Constant(), + normalize=False, + offset=None, + **kwargs): + super().__init__() + + self.in_channels = in_channels + self.filters = filters + self.kernel_size = kernel_size + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.normalize = normalize + + if not (np.asarray(kernel_size) == kernel_size[0]).all(): + raise ValueError("Only cubic kernel sizes are supported.") + + if offset is None: + if kernel_size[0] % 2: + offset = paddle.zeros(shape=(3,), dtype=paddle.float32) + else: + offset = paddle.full((3,), -0.5, dtype=paddle.float32) + self.register_buffer('offset', offset) + + self.fixed_radius_search = FixedRadiusSearch(metric='Linf', + ignore_query_point=False, + return_distances=False) + + kernel_shape = (*self.kernel_size, self.in_channels, self.filters) + self.kernel = create_parameter( + kernel_shape, + dtype=paddle.float32, + default_initializer=self.kernel_initializer) + + if self.use_bias: + self.bias = create_parameter((self.filters,), + dtype=paddle.float32, + is_bias=True, + default_initializer=bias_initializer) + + def forward(self, + inp_features, + inp_positions, + out_positions, + voxel_size, + inp_importance=None, + fixed_radius_search_hash_table=None): + """This function computes the output features. + + Arguments: + inp_features: A 2D tensor which stores a feature vector for each input + point. + + inp_positions: A 2D tensor with the 3D point positions of each input + point. The coordinates for each point is a vector with format [x,y,z]. + + out_positions: A 2D tensor with the 3D point positions of each output + point. The coordinates for each point is a vector with format [x,y,z]. + + voxel_size: A scalar float that defines the edge length of a voxel. + + inp_importance: Optional scalar importance value for each input point. + + fixed_radius_search_hash_table: A precomputed hash table generated with + build_spatial_hash_table(). This input can be used to explicitly force the + reuse of a hash table in special cases and is usually not needed. + Note that the hash table must have been generated with the same 'points' + array. Note that this parameter is only used if 'extents' is a scalar. + + Returns: A tensor of shape [num output points, filters] with the output + features. + """ + if isinstance(inp_features, classes.RaggedTensor): + if not (isinstance(inp_positions, classes.RaggedTensor) and + isinstance(out_positions, classes.RaggedTensor)): + raise ValueError( + "All of inp_positions, inp_features and out_positions must be paddle.Tensor, or ml3d.classes.RaggedTensor" + ) + + offset = self.offset + if isinstance(voxel_size, (float, int)): + voxel_size = paddle.to_tensor( + voxel_size, dtype=inp_positions.dtype).to(self.kernel.place) + if len(voxel_size.shape) != 0: + raise Exception("voxel_size must be a scalar") + + if inp_importance is None: + inp_importance = paddle.empty( + (0,), dtype=paddle.float32).to(self.kernel.place) + + hash_table_size_factor = 1 / 64 + self.nns = self.fixed_radius_search( + inp_positions, + queries=out_positions - offset * voxel_size, + radius=self.kernel_size[0] * voxel_size * 0.51, + hash_table_size_factor=hash_table_size_factor, + hash_table=fixed_radius_search_hash_table) + + out_positions_split = None + if isinstance(inp_positions, classes.RaggedTensor): + inp_positions = inp_positions.values + inp_features = inp_features.values + out_positions_split = out_positions.row_splits + out_positions = out_positions.values + + # for stats and debugging + num_pairs = self.nns.neighbors_index.shape[0] + self._avg_neighbors = num_pairs / out_positions.shape[0] + + extents_rank2 = paddle.full([1, 1], voxel_size * self.kernel_size[0]) + + self._conv_values = { + 'filters': self.kernel, + 'out_positions': out_positions, + 'extents': extents_rank2, + 'offset': offset, + 'inp_positions': inp_positions, + 'inp_features': inp_features, + 'inp_importance': inp_importance, + 'neighbors_index': self.nns.neighbors_index, + 'neighbors_importance': paddle.empty((0,), dtype=paddle.float32), + 'neighbors_row_splits': self.nns.neighbors_row_splits, + 'align_corners': False, + 'coordinate_mapping': 'identity', + 'interpolation': 'nearest_neighbor', + 'normalize': self.normalize, + 'max_temp_mem_mb': 64 + } + + out_features = ops.continuous_conv(**self._conv_values) + + self._conv_output = out_features + + if self.use_bias: + out_features += self.bias + if self.activation: + out_features = self.activation(out_features) + + if out_positions_split is not None: + out_features = classes.RaggedTensor.from_row_splits( + out_features, out_positions_split, validate=False, copy=False) + + return out_features + + +class SparseConvTranspose(paddle.nn.Layer): + """Sparse Transposed Convolution. + + This layer computes a transposed convolution which is only evaluated at the specified output positions. + The layer assumes that input and output points lie on a regular grid. + + Example: + This shows a minimal example of how to use the layer:: + + import paddle + import open3d.ml.paddle as ml3d + + # +0.5 to move the points to the voxel center + inp_positions = paddle.randint(0, 10, [20,3]).to(paddle.float32) + 0.5 + inp_features = paddle.randn([20,8]) + out_positions = paddle.randint(0, 10, [20,3]).to(paddle.float32) + 0.5 + + conv = ml3d.layers.SparseConv(in_channels=8, filters=16, kernel_size=[3,3,3]) + out_features = conv(inp_features, inp_positions, out_positions, voxel_size=1.0) + + + Arguments: + in_channels: The number of input channels. + + filters: The number of filters/output channels. + + kernel_size: The spatial resolution of the filter, e.g. [3,3,3]. + + activation: The activation function to use. None means no activation. + + use_bias: If True adds an additive bias vector. + + kernel_initializer: Initializer for the kernel weights. + + bias_initializer: Initializer for the bias vector. + + normalize: If true then the input features will be normalized with the number of + output points. + + offset: A single 3D vector used in the filter coordinate computation. + The shape is [3]. This can be used to control how the filters are + centered. It will be set automatically for kernels with even sizes. + """ + + def __init__(self, + in_channels, + filters, + kernel_size, + activation=None, + use_bias=True, + kernel_initializer=paddle.nn.initializer.Uniform(-0.05, 0.05), + bias_initializer=paddle.nn.initializer.Constant(), + normalize=False, + offset=None, + **kwargs): + super().__init__() + + self.in_channels = in_channels + self.filters = filters + self.kernel_size = kernel_size + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.normalize = normalize + + if not (np.asarray(kernel_size) == kernel_size[0]).all(): + raise ValueError("Only cubic kernel sizes are supported.") + + if offset is None: + if kernel_size[0] % 2: + offset = paddle.zeros(shape=(3,), dtype=paddle.float32) + else: + offset = paddle.full((3,), -0.5, dtype=paddle.float32) + self.register_buffer('offset', offset) + + self.fixed_radius_search = FixedRadiusSearch(metric='Linf', + ignore_query_point=False, + return_distances=False) + + kernel_shape = (*self.kernel_size, self.in_channels, self.filters) + self.kernel = create_parameter( + kernel_shape, + dtype=paddle.float32, + default_initializer=self.kernel_initializer) + + if self.use_bias: + self.bias = create_parameter((self.filters,), + dtype=paddle.float32, + is_bias=True, + default_initializer=bias_initializer) + + def forward(self, + inp_features, + inp_positions, + out_positions, + voxel_size, + out_importance=None, + fixed_radius_search_hash_table=None): + """This function computes the output features. + + Arguments: + inp_features: A 2D tensor which stores a feature vector for each input + point. + + inp_positions: A 2D tensor with the 3D point positions of each input + point. The coordinates for each point is a vector with format [x,y,z]. + + out_positions: A 2D tensor with the 3D point positions of each output + point. The coordinates for each point is a vector with format [x,y,z]. + + voxel_size: A scalar float that defines the edge length of a voxel. + + out_importance: Optional scalar importance value for each output point. + + fixed_radius_search_hash_table: A precomputed hash table generated with + build_spatial_hash_table(). This input can be used to explicitly force the + reuse of a hash table in special cases and is usually not needed. + Note that the hash table must have been generated with the same 'points' + array. Note that this parameter is only used if 'extents' is a scalar. + + Returns: A tensor of shape [num output points, filters] with the output + features. + """ + if isinstance(inp_features, classes.RaggedTensor): + if not (isinstance(inp_positions, classes.RaggedTensor) and + isinstance(out_positions, classes.RaggedTensor)): + raise ValueError( + "All of inp_positions, inp_features and out_positions must be paddle.Tensor, or ml3d.classes.RaggedTensor" + ) + + offset = self.offset + if isinstance(voxel_size, (float, int)): + voxel_size = paddle.to_tensor( + voxel_size, dtype=inp_positions.dtype).to(self.kernel.place) + if len(voxel_size.shape) != 0: + raise Exception("voxel_size must be a scalar") + + if out_importance is None: + out_importance = paddle.empty( + (0,), dtype=paddle.float32).to(self.kernel.place) + + empty_vec = paddle.empty((0,), + dtype=paddle.float32).to(self.kernel.place) + + hash_table_size_factor = 1 / 64 + self.nns_inp = self.fixed_radius_search( + out_positions, + queries=inp_positions - offset * voxel_size, + radius=self.kernel_size[0] * voxel_size * 0.51, + hash_table_size_factor=hash_table_size_factor, + hash_table=fixed_radius_search_hash_table) + + out_positions_split = None + if isinstance(inp_positions, classes.RaggedTensor): + inp_positions = inp_positions.values + inp_features = inp_features.values + out_positions_split = out_positions.row_splits + out_positions = out_positions.values + + num_out = out_positions.shape[0] + + neighbors_index, neighbors_row_splits, _ = ops.invert_neighbors_list( + self.nns_inp.neighbors_index, self.nns_inp.neighbors_row_splits, + empty_vec, num_out) + + # for stats and debugging + num_pairs = neighbors_index.shape[0] + self._avg_neighbors = num_pairs / out_positions.shape[0] + + extents_rank2 = paddle.full([1, 1], voxel_size * self.kernel_size[0]) + + self._conv_values = { + 'filters': self.kernel, + 'out_positions': out_positions, + 'extents': extents_rank2, + 'offset': offset, + 'inp_positions': inp_positions, + 'inp_features': inp_features, + 'out_importance': out_importance, + 'inp_neighbors_index': self.nns_inp.neighbors_index, + 'inp_neighbors_importance_sum': empty_vec, + 'inp_neighbors_row_splits': self.nns_inp.neighbors_row_splits, + 'neighbors_index': neighbors_index, + 'neighbors_importance': empty_vec, + 'neighbors_row_splits': neighbors_row_splits, + 'align_corners': False, + 'coordinate_mapping': 'identity', + 'interpolation': 'nearest_neighbor', + 'normalize': self.normalize, + 'max_temp_mem_mb': 64, + } + + out_features = ops.continuous_conv_transpose(**self._conv_values) + + self._conv_output = out_features + + if self.use_bias: + out_features += self.bias + if self.activation: + out_features = self.activation(out_features) + + if out_positions_split is not None: + out_features = classes.RaggedTensor.from_row_splits( + out_features, out_positions_split, validate=False, copy=False) + + return out_features diff --git a/python/open3d/ml/paddle/python/layers/neighbor_search.py b/python/open3d/ml/paddle/python/layers/neighbor_search.py new file mode 100644 index 00000000000..2600cf03b54 --- /dev/null +++ b/python/open3d/ml/paddle/python/layers/neighbor_search.py @@ -0,0 +1,376 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +from ...python import ops +from ....paddle import classes +from ...classes import DTYPE_MAP +import paddle + +__all__ = ['FixedRadiusSearch', 'RadiusSearch', 'KNNSearch'] + + +class FixedRadiusSearch(paddle.nn.Layer): + """Fixed radius search for 3D point clouds. + + This layer computes the neighbors for a fixed radius on a point cloud. + + Example: + This example shows a neighbor search that returns the indices to the + found neighbors and the distances.:: + + import paddle + import open3d.ml.paddle as ml3d + + points = paddle.randn([20, 3]) + queries = paddle.randn([10, 3]) + radius = 0.8 + + nsearch = ml3d.layers.FixedRadiusSearch(return_distances=True) + ans = nsearch(points, queries, radius) + # returns a tuple of neighbors_index, neighbors_row_splits, and neighbors_distance + + + Arguments: + metric: Either L1, L2 or Linf. Default is L2. + + ignore_query_point: If True the points that coincide with the center of + the search window will be ignored. This excludes the query point if + 'queries' and 'points' are the same point cloud. + + return_distances: If True the distances for each neighbor will be returned. + If False a zero length Tensor will be returned instead. + """ + + def __init__(self, + metric='L2', + ignore_query_point=False, + return_distances=False, + max_hash_table_size=32 * 2**20, + index_dtype=paddle.int32, + **kwargs): + super().__init__() + self.metric = metric + self.ignore_query_point = ignore_query_point + self.return_distances = return_distances + self.max_hash_table_size = max_hash_table_size + assert index_dtype in [paddle.int32, paddle.int64] + self.index_dtype = DTYPE_MAP[index_dtype] + + def forward(self, + points, + queries, + radius, + points_row_splits=None, + queries_row_splits=None, + hash_table_size_factor=1 / 64, + hash_table=None): + """This function computes the neighbors within a fixed radius for each query point. + + Arguments: + + points: The 3D positions of the input points. It can be a RaggedTensor. + + queries: The 3D positions of the query points. It can be a RaggedTensor. + + radius: A scalar with the neighborhood radius + + points_row_splits: Optional 1D vector with the row splits information + if points is batched. This vector is [0, num_points] if there is + only 1 batch item. + + queries_row_splits: Optional 1D vector with the row splits information + if queries is batched. This vector is [0, num_queries] if there is + only 1 batch item. + + hash_table_size_factor: Scalar. The size of the hash table as fraction + of points. + + hash_table: A precomputed hash table generated with build_spatial_hash_table(). + This input can be used to explicitly force the reuse of a hash table in special + cases and is usually not needed. + Note that the hash table must have been generated with the same 'points' array. + + Returns: + 3 Tensors in the following order + + neighbors_index + The compact list of indices of the neighbors. The corresponding query point + can be inferred from the 'neighbor_count_row_splits' vector. + + neighbors_row_splits + The exclusive prefix sum of the neighbor count for the query points including + the total neighbor count as the last element. The size of this array is the + number of queries + 1. + + neighbors_distance + Stores the distance to each neighbor if 'return_distances' is True. + Note that the distances are squared if metric is L2. + This is a zero length Tensor if 'return_distances' is False. + """ + + if isinstance(points, classes.RaggedTensor): + points_row_splits = points.row_splits + points = points.values + if isinstance(queries, classes.RaggedTensor): + queries_row_splits = queries.row_splits + queries = queries.values + + if points_row_splits is None: + points_row_splits = paddle.to_tensor([0, points.shape[0]], + dtype="int64") + if queries_row_splits is None: + queries_row_splits = paddle.to_tensor([0, queries.shape[0]], + dtype="int64") + + if hash_table is None: + table = ops.build_spatial_hash_table( + max_hash_table_size=self.max_hash_table_size, + points=points, + radius=radius, + points_row_splits=points_row_splits, + hash_table_size_factor=hash_table_size_factor) + else: + table = hash_table + + result = ops.fixed_radius_search( + ignore_query_point=self.ignore_query_point, + return_distances=self.return_distances, + metric_str=self.metric, + points=points, + queries=queries, + radius=radius, + points_row_splits=points_row_splits, + queries_row_splits=queries_row_splits, + hash_table_splits=table.hash_table_splits, + hash_table_index=table.hash_table_index, + hash_table_cell_splits=table.hash_table_cell_splits, + index_dtype=self.index_dtype) + + return result + + +class RadiusSearch(paddle.nn.Layer): + """Radius search for 3D point clouds. + + This layer computes the neighbors for each query point with each query + having an individual radius. + + Example: + This example shows a neighbor search that returns the indices to the + found neighbors and the distances.:: + + import paddle + import open3d.ml.paddle as ml3d + + points = paddle.randn([20, 3]) + queries = paddle.randn([10, 3]) + radii = paddle.randn([10]) + 1.0 + + nsearch = ml3d.layers.RadiusSearch(return_distances=True) + ans = nsearch(points, queries, radii) + # returns a tuple of neighbors_index, neighbors_row_splits, and neighbors_distance + + + Arguments: + metric: Either L1, L2 or Linf. Default is L2. + + ignore_query_point: If True the points that coincide with the center of the + search window will be ignored. This excludes the query point if 'queries' + and 'points' are the same point cloud. + + return_distances: If True the distances for each neighbor will be returned. + If False a zero length Tensor will be returned instead. + + normalize_distances: If True the returned distances will be normalized with + the radii. + """ + + def __init__(self, + metric='L2', + ignore_query_point=False, + return_distances=False, + normalize_distances=False, + index_dtype=paddle.int32, + **kwargs): + super().__init__() + self.metric = metric + self.ignore_query_point = ignore_query_point + self.return_distances = return_distances + self.normalize_distances = normalize_distances + assert index_dtype in [paddle.int32, paddle.int64] + self.index_dtype = DTYPE_MAP[index_dtype] + + def forward(self, + points, + queries, + radii, + points_row_splits=None, + queries_row_splits=None): + """This function computes the neighbors within a radius for each query point. + + Arguments: + + points: The 3D positions of the input points. + + queries: The 3D positions of the query points. + + radii: A radius for each query point. + + points_row_splits: Optional 1D vector with the row splits information + if points is batched. This vector is [0, num_points] if there is + only 1 batch item. + + queries_row_splits: Optional 1D vector with the row splits information + if queries is batched. This vector is [0, num_queries] if there is + only 1 batch item. + + Returns: + 3 Tensors in the following order + + neighbors_index + The compact list of indices of the neighbors. The corresponding query point + can be inferred from the 'neighbor_count_row_splits' vector. + + neighbors_row_splits + The exclusive prefix sum of the neighbor count for the query points including + the total neighbor count as the last element. The size of this array is the + number of queries + 1. + + neighbors_distance + Stores the distance to each neighbor if 'return_distances' is True. + Note that the distances are squared if metric is L2. + This is a zero length Tensor if 'return_distances' is False. + """ + if points_row_splits is None: + points_row_splits = paddle.to_tensor([0, points.shape[0]], + dtype="int64") + if queries_row_splits is None: + queries_row_splits = paddle.to_tensor([0, queries.shape[0]], + dtype="int64") + + result = ops.radius_search(ignore_query_point=self.ignore_query_point, + return_distances=self.return_distances, + normalize_distances=self.normalize_distances, + metric_str=self.metric, + points=points, + queries=queries, + radii=radii, + points_row_splits=points_row_splits, + queries_row_splits=queries_row_splits, + index_dtype=self.index_dtype) + + return result + + +class KNNSearch(paddle.nn.Layer): + """KNN search for 3D point clouds. + + This layer computes the k nearest neighbors for each query point. + + Example: + This example shows a neighbor search that returns the indices to the + found neighbors and the distances.:: + + import paddle + import open3d.ml.paddle as ml3d + + points = paddle.randn([20, 3]) + queries = paddle.randn([10, 3]) + k = 8 + + nsearch = ml3d.layers.KNNSearch(return_distances=True) + ans = nsearch(points, queries, k) + # returns a tuple of neighbors_index, neighbors_row_splits, and neighbors_distance + # Since there are more than k points and we do not ignore any points we can + # reshape the output to [num_queries, k] with + neighbors_index = ans.neighbors_index.reshape(10,k) + neighbors_distance = ans.neighbors_distance.reshape(10,k) + + + Arguments: + metric: Either L1, L2 or Linf. Default is L2. + + ignore_query_point: If True the points that coincide with the center of the + search window will be ignored. This excludes the query point if 'queries' + and 'points' are the same point cloud. + + return_distances: If True the distances for each neighbor will be returned. + If False a zero length Tensor will be returned instead. + """ + + def __init__(self, + metric='L2', + ignore_query_point=False, + return_distances=False, + index_dtype=paddle.int32, + **kwargs): + super().__init__() + self.metric = metric + self.ignore_query_point = ignore_query_point + self.return_distances = return_distances + assert index_dtype in [paddle.int32, paddle.int64] + self.index_dtype = DTYPE_MAP[index_dtype] + + def forward(self, + points, + queries, + k, + points_row_splits=None, + queries_row_splits=None): + """This function computes the k nearest neighbors for each query point. + + Arguments: + points: The 3D positions of the input points. *This argument must be + given as a positional argument!* + + queries: The 3D positions of the query points. + + k: The number of nearest neighbors to search. + + points_row_splits: Optional 1D vector with the row splits information + if points is batched. + This vector is [0, num_points] if there is only 1 batch item. + + queries_row_splits: Optional 1D vector with the row splits information + if queries is batched. + This vector is [0, num_queries] if there is only 1 batch item. + + Returns: 3 Tensors in the following order + + neighbors_index + The compact list of indices of the neighbors. The corresponding query point + can be inferred from the 'neighbor_count_row_splits' vector. + + neighbors_row_splits + The exclusive prefix sum of the neighbor count for the query points including + the total neighbor count as the last element. The size of this array is the + number of queries + 1. + + neighbors_distance + Stores the distance to each neighbor if 'return_distances' is True. + Note that the distances are squared if metric is L2. + This is a zero length Tensor if 'return_distances' is False. + """ + + if points_row_splits is None: + points_row_splits = paddle.to_tensor([0, points.shape[0]], + dtype=paddle.int64) + if queries_row_splits is None: + queries_row_splits = paddle.to_tensor([0, queries.shape[0]], + dtype=paddle.int64) + + result = ops.knn_search(ignore_query_point=self.ignore_query_point, + return_distances=self.return_distances, + metric_str=self.metric, + points=points, + queries=queries, + k=k, + points_row_splits=points_row_splits, + queries_row_splits=queries_row_splits, + index_dtype=self.index_dtype) + + return result diff --git a/python/open3d/ml/paddle/python/layers/voxel_pooling.py b/python/open3d/ml/paddle/python/layers/voxel_pooling.py new file mode 100644 index 00000000000..8cfbf953ce8 --- /dev/null +++ b/python/open3d/ml/paddle/python/layers/voxel_pooling.py @@ -0,0 +1,101 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +from ...python import ops +import paddle + +__all__ = ['VoxelPooling'] + + +class VoxelPooling(paddle.nn.Layer): + """Voxel pooling for 3D point clouds. + + Spatial pooling for point clouds by combining points that fall into the same voxel bin. + + The voxel grid used for pooling is always aligned to the origin (0,0,0) to + simplify building voxel grid hierarchies. The order of the returned voxels is + not defined as can be seen in the following example:: + + import paddle + import open3d.ml.paddle as ml3d + + positions = paddle.to_tensor([ + [0.1, 0.1, 0.1], + [0.5, 0.5, 0.5], + [1.7, 1.7, 1.7], + [1.8, 1.8, 1.8], + [0.3, 2.4, 1.4]]) + + features = paddle.to_tensor([ + [1.0, 2.0], + [1.1, 2.3], + [4.2, 0.1], + [1.3, 3.4], + [2.3, 1.9]]) + + voxel_pooling = ml3d.layers.VoxelPooling(position_fn='center', feature_fn='max') + + ans = voxel_pooling(positions, features, 1.0) + + # returns the voxel centers in + # ans.pooled_positions = [[0.5, 2.5, 1.5], + # [1.5, 1.5, 1.5], + # [0.5, 0.5, 0.5]] + # + # and the max pooled features for each voxel in + # ans.pooled_features = [[2.3, 1.9], + # [4.2, 3.4], + # [1.1, 2.3]] + + Arguments: + position_fn: Defines how the new point positions will be computed. + The options are + * "average" computes the center of gravity for the points within one voxel. + * "nearest_neighbor" selects the point closest to the voxel center. + * "center" uses the voxel center for the position of the generated point. + + feature_fn: Defines how the pooled features will be computed. + The options are + * "average" computes the average feature vector. + * "nearest_neighbor" selects the feature vector of the point closest to the voxel center. + * "max" uses the maximum feature among all points within the voxel. + """ + + def __init__(self, position_fn='center', feature_fn='max', **kwargs): + super().__init__() + self.position_fn = position_fn + self.feature_fn = feature_fn + + def forward(self, positions, features, voxel_size): + """This function computes the pooled positions and features. + + Arguments: + positions: The point positions with shape [N,3] with N as the number of points. + *This argument must be given as a positional argument!* + + features: The feature vector with shape [N,channels]. + + voxel_size: The voxel size. + + Returns: + 2 Tensors in the following order: + + pooled_positions + The output point positions with shape [M,3] and M <= N. + + pooled_features: + The output point features with shape [M,channels] and M <= N. + """ + if isinstance(voxel_size, (float, int)): + voxel_size = paddle.to_tensor(voxel_size, dtype=positions.dtype) + result = ops.voxel_pooling(positions, + features, + voxel_size, + position_fn=self.position_fn, + feature_fn=self.feature_fn, + debug=False) + return result diff --git a/python/open3d/ml/paddle/python/return_types.py.in b/python/open3d/ml/paddle/python/return_types.py.in new file mode 100644 index 00000000000..a73a0765ed9 --- /dev/null +++ b/python/open3d/ml/paddle/python/return_types.py.in @@ -0,0 +1,29 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2018-2024 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +# This file is machine generated. Do not modify. +from collections import namedtuple as _namedtuple + diff --git a/python/test/ml_ops/mltest.py b/python/test/ml_ops/mltest.py index 0f25f5d46f6..123d5ac51b3 100644 --- a/python/test/ml_ops/mltest.py +++ b/python/test/ml_ops/mltest.py @@ -17,7 +17,8 @@ # skip all tests if the ml ops were not built default_marks = [ pytest.mark.skipif(not (o3d._build_config['BUILD_TENSORFLOW_OPS'] or - o3d._build_config['BUILD_PYTORCH_OPS']), + o3d._build_config['BUILD_PYTORCH_OPS'] or + o3d._build_config['BUILD_PADDLE_OPS']), reason='ml ops not built'), ] @@ -62,9 +63,24 @@ except ImportError: pass +try: + paddle = importlib.import_module('paddle') + ml3d_ops = importlib.import_module('open3d.ml.paddle.ops') + ml3d_layers = importlib.import_module('open3d.ml.paddle.layers') + ml3d_classes = importlib.import_module('open3d.ml.paddle.classes') + _ml_modules['paddle'] = MLModules(paddle, ml3d_ops, ml3d_layers, + ml3d_classes, 'cpu', 'cpu', False) + if paddle.device.is_compiled_with_cuda( + ) and o3d._build_config['BUILD_CUDA_MODULE']: + _ml_modules['paddle_cuda'] = MLModules(paddle, ml3d_ops, ml3d_layers, + ml3d_classes, 'cuda', 'cpu', + True) +except ImportError: + pass + def is_gpu_device_name(name): - return name in ('GPU:0', 'cuda') + return name in ('GPU:0', 'cuda', 'gpu:0', 'gpu') def to_numpy(tensor): @@ -75,6 +91,8 @@ def to_numpy(tensor): if tensor.device.type == 'cuda': tensor = tensor.cpu() + return tensor.numpy() + elif 'paddle' in _ml_modules and isinstance(tensor, paddle.Tensor): return tensor.numpy() else: return tensor.numpy() @@ -88,6 +106,25 @@ def to_torch(x, device): return x +def to_paddle(x, device): + """Converts x such that it can be used as input to a paddle op.""" + if isinstance(x, np.ndarray): + return paddle.to_tensor( + x, place='gpu') if device == 'cuda' else paddle.to_tensor( + x, place='cpu') + else: + return x + + +def paddle_cmp_device(x, device): + if device == 'cuda' and x.place.is_gpu_place(): + return True + elif device == 'cpu' and x.place.is_cpu_place(): + return True + else: + return False + + def run_op(ml, device_name, check_device, fn, *args, **kwargs): """Runs an op using an ml framework""" if ml.module.__name__ == 'tensorflow': @@ -126,7 +163,27 @@ def run_op(ml, device_name, check_device, fn, *args, **kwargs): x, torch.Tensor) and device_name == x.device.type: tensor_on_device = True assert tensor_on_device + elif ml.module.__name__ == 'paddle': + _args = [to_paddle(x, device_name) for x in args] + _kwargs = { + k.lower(): to_paddle(v, device_name) for k, v in kwargs.items() + } + ans = fn(*_args, **_kwargs) + + if check_device: + # not all returned tensor have to use the device. + # check if there is at least one tensor using device memory + tensor_on_device = False + if isinstance(ans, paddle.Tensor): + if paddle_cmp_device(ans, device_name): + tensor_on_device = True + else: + for x in ans: + if isinstance(x, paddle.Tensor) and paddle_cmp_device( + x, device_name): + tensor_on_device = True + assert tensor_on_device else: raise ValueError('unsupported ml framework {}'.format(ml.module)) @@ -189,6 +246,30 @@ def run_op_grad(ml, device_name, check_device, fn, x, y_attr_name, torch.Tensor) and device_name == dy_dx.device.type: tensor_on_device = True assert tensor_on_device + elif ml.module.__name__ == 'paddle': + x_var = to_paddle(x, device_name) + x_var.stop_gradient = False + _args = [x_var if a is x else to_paddle(a, device_name) for a in args] + _kwargs = { + k.lower(): x_var if a is x else to_paddle(a, device_name) + for k, a in kwargs.items() + } + + ans = fn(*_args, **_kwargs) + if y_attr_name: + y = getattr(ans, y_attr_name) + else: + y = ans + y.backward(to_paddle(backprop_values, device_name)) + dy_dx = x_var.grad + + if check_device: + # check if the gradient is using device memory + tensor_on_device = False + if isinstance(dy_dx, paddle.Tensor) and paddle_cmp_device( + dy_dx, device_name): + tensor_on_device = True + assert tensor_on_device else: raise ValueError('unsupported ml framework {}'.format(ml.module)) @@ -213,6 +294,8 @@ def set_seed(self, seed): self.module.random.set_seed(seed) elif self.module.__name__ == 'torch': self.module.manual_seed(seed) + elif self.module.__name__ == 'paddle': + self.module.seed(seed) else: raise Exception('Unsupported ml framework') @@ -221,6 +304,11 @@ def set_deterministic(self, deterministic): pass elif self.module.__name__ == 'torch': self.module.set_deterministic(deterministic) + elif self.module.__name__ == 'paddle': + paddle.set_flags({ + "FLAGS_cudnn_deterministic": "1", + "FLAGS_cpu_deterministic": "1" + }) else: raise Exception('Unsupported ml framework') @@ -235,6 +323,9 @@ def random_uniform(self, size, dtype, minval=0, maxval=1): elif self.module.__name__ == 'torch': ans = self.module.empty(size=size, dtype=dtype) return ans.uniform_(minval, maxval) + elif self.module.__name__ == 'paddle': + ans = self.module.empty(shape=size, dtype=dtype) + return ans.uniform_(minval, maxval) else: raise Exception('Unsupported ml framework') @@ -245,6 +336,8 @@ def empty(self, shape, dtype): return self.module.zeros(shape=shape, dtype=dtype) elif self.module.__name__ == 'torch': return self.module.empty(size=shape, dtype=dtype) + elif self.module.__name__ == 'paddle': + return self.module.empty(shape=shape, dtype=dtype) else: raise Exception('Unsupported ml framework') @@ -255,6 +348,8 @@ def zeros(self, shape, dtype): return self.module.zeros(shape=shape, dtype=dtype) elif self.module.__name__ == 'torch': return self.module.zeros(size=shape, dtype=dtype) + elif self.module.__name__ == 'paddle': + return self.module.zeros(shape=shape, dtype=dtype) else: raise Exception('Unsupported ml framework') @@ -272,6 +367,13 @@ def zeros(self, shape, dtype): ml_tf_only=pytest.mark.parametrize('ml', [ v for k, v in _ml_modules.items() if v.module.__name__ == 'tensorflow' ]), + ml_paddle_only=pytest.mark.parametrize( + 'ml', + [v for k, v in _ml_modules.items() if v.module.__name__ == 'paddle']), + ml_torch_and_paddle_only=pytest.mark.parametrize('ml', [ + v for k, v in _ml_modules.items() + if v.module.__name__ == 'paddle' or v.module.__name__ == 'torch' + ]), ) diff --git a/python/test/ml_ops/test_cconv.py b/python/test/ml_ops/test_cconv.py index 8f07d7f9b5e..61aa4c9879a 100644 --- a/python/test/ml_ops/test_cconv.py +++ b/python/test/ml_ops/test_cconv.py @@ -183,6 +183,7 @@ def test_cconv_gradient(ml, dtype, filter_size, out_channels, in_channels, 'coordinate_mapping': coordinate_mapping, 'normalize': with_normalization, 'interpolation': interpolation, + 'max_temp_mem_MB': 64 } filters = np.random.random(size=(*filter_size, in_channels, @@ -223,9 +224,14 @@ def test_cconv_gradient(ml, dtype, filter_size, out_channels, in_channels, neighbors_importance_sum = np.empty((0,), dtype=dtype) inverted_neighbors_index, inverted_neighbors_row_splits, inverted_neighbors_importance = mltest.run_op( - ml, ml.device, False, ml.ops.invert_neighbors_list, - inp_positions.shape[0], neighbors_index, neighbors_row_splits, - neighbors_importance) + ml, + ml.device, + False, + ml.ops.invert_neighbors_list, + num_points=inp_positions.shape[0], + inp_neighbors_index=neighbors_index, + inp_neighbors_row_splits=neighbors_row_splits, + inp_neighbors_attributes=neighbors_importance) # print(neighbors_row_splits, inverted_neighbors_row_splits) # print(neighbors_index, inverted_neighbors_index) diff --git a/python/test/ml_ops/test_fixed_radius_search.py b/python/test/ml_ops/test_fixed_radius_search.py index b0a8f720129..c3da38894ae 100644 --- a/python/test/ml_ops/test_fixed_radius_search.py +++ b/python/test/ml_ops/test_fixed_radius_search.py @@ -14,6 +14,8 @@ import torch if o3d._build_config['BUILD_TENSORFLOW_OPS']: import tensorflow as tf +if o3d._build_config['BUILD_PADDLE_OPS']: + import paddle # skip all tests if the ml ops were not built pytestmark = mltest.default_marks @@ -62,6 +64,11 @@ def test_fixed_radius_search(dtype, ml, num_points_queries, radius, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') @@ -209,6 +216,11 @@ def test_fixed_radius_search_batches(dtype, ml, batch_size, radius, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') diff --git a/python/test/ml_ops/test_general_sparseconv.py b/python/test/ml_ops/test_general_sparseconv.py index deae230826d..19e5f2c3326 100644 --- a/python/test/ml_ops/test_general_sparseconv.py +++ b/python/test/ml_ops/test_general_sparseconv.py @@ -39,9 +39,7 @@ def test_sparseconv_gradient(ml, dtype, kernel_size, out_channels, in_channels, rng = np.random.RandomState(123) - conv_attrs = { - 'normalize': with_normalization, - } + conv_attrs = {'normalize': with_normalization, 'max_temp_mem_MB': 64} filters = rng.random(size=(kernel_size, in_channels, out_channels)).astype(dtype) @@ -75,8 +73,14 @@ def test_sparseconv_gradient(ml, dtype, kernel_size, out_channels, in_channels, arange = np.arange(neighbors_index.shape[0]) inv_neighbors_index, inv_neighbors_row_splits, inv_arange = mltest.run_op( - ml, ml.device, False, ml.ops.invert_neighbors_list, num_inp, - neighbors_index, neighbors_row_splits, arange) + ml, + ml.device, + False, + ml.ops.invert_neighbors_list, + num_points=num_inp, + inp_neighbors_index=neighbors_index, + inp_neighbors_row_splits=neighbors_row_splits, + inp_neighbors_attributes=arange) inv_neighbors_kernel_index = neighbors_kernel_index[inv_arange] if with_neighbors_importance: diff --git a/python/test/ml_ops/test_knn_search.py b/python/test/ml_ops/test_knn_search.py index 38d0e11d644..8deb5d453f5 100644 --- a/python/test/ml_ops/test_knn_search.py +++ b/python/test/ml_ops/test_knn_search.py @@ -14,6 +14,8 @@ import torch if o3d._build_config['BUILD_TENSORFLOW_OPS']: import tensorflow as tf +if o3d._build_config['BUILD_PADDLE_OPS']: + import paddle # skip all tests if the ml ops were not built pytestmark = mltest.default_marks @@ -59,6 +61,11 @@ def test_knn_search(dtype, ml, num_points_queries, metric, ignore_query_point, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') diff --git a/python/test/ml_ops/test_radius_search.py b/python/test/ml_ops/test_radius_search.py index c23590c5596..a9d7a6cce1c 100644 --- a/python/test/ml_ops/test_radius_search.py +++ b/python/test/ml_ops/test_radius_search.py @@ -14,6 +14,8 @@ import torch if o3d._build_config['BUILD_TENSORFLOW_OPS']: import tensorflow as tf +if o3d._build_config['BUILD_PADDLE_OPS']: + import paddle # skip all tests if the tf ops were not built and disable warnings caused by # tensorflow @@ -58,6 +60,11 @@ def test_radius_search(dtype, ml, num_points_queries, metric, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') diff --git a/python/test/ml_ops/test_ragged_tensor_paddle.py b/python/test/ml_ops/test_ragged_tensor_paddle.py new file mode 100644 index 00000000000..eaafd72dd03 --- /dev/null +++ b/python/test/ml_ops/test_ragged_tensor_paddle.py @@ -0,0 +1,209 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +# noqa # pylint: disable=unused-import +import open3d as o3d +import numpy as np +import pytest +import mltest +import paddle + +# skip all tests if the tf ops were not built and disable warnings caused by +# tensorflow +pytestmark = mltest.default_marks + +# the supported dtypes for the values +dtypes = pytest.mark.parametrize('dtype', + [np.int32, np.int64, np.float32, np.float64]) + +# this class is only available for torch + + +@dtypes +@mltest.parametrize.ml_paddle_only +def test_creation(dtype, ml): + values = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype) + row_splits = np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.int64) + + # From numpy arrays + r_tensor = ml.classes.RaggedTensor.from_row_splits(values, row_splits) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + # From List + r_tensor = ml.classes.RaggedTensor.from_row_splits(list(values), + list(row_splits)) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + # Incompatible tensors. + # Non zero first element. + row_splits = np.array([1, 2, 4, 4, 5, 12, 13], dtype=np.int64) + + context = np.testing.assert_raises(ValueError) + + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + # Rank > 1. + row_splits = np.array([[0, 2, 4, 4, 5, 12, 13]], dtype=np.int64) + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + # Not increasing monotonically. + row_splits = np.array([[0, 2, 4, 6, 5, 12, 13]], dtype=np.int64) + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + # Wrong dtype. + row_splits = np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.float32) + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + +# test with more dimensions +@dtypes +@mltest.parametrize.ml_paddle_only +def test_creation_more_dims(dtype, ml): + values = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], + [7, 7], [8, 8], [9, 9], [10, 10], [11, 11], [12, 12]], + dtype=dtype) + row_splits = np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.int64) + + # From numpy arrays + r_tensor = ml.classes.RaggedTensor.from_row_splits(values, row_splits) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + # From List + r_tensor = ml.classes.RaggedTensor.from_row_splits(list(values), + list(row_splits)) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + +@mltest.parametrize.ml_paddle_only +def test_backprop(ml): + # Create 3 different RaggedTensors and torch.tensor + t_1 = paddle.randn([10, 3]) + t_1.stop_gradient = False + + t_2 = paddle.randn([10, 3]) + t_2.stop_gradient = False + + t_3 = paddle.randn([10, 3]) + t_3.stop_gradient = False + + row_splits = paddle.to_tensor([0, 4, 6, 6, 8, 10]) + + r_1 = ml.classes.RaggedTensor.from_row_splits(t_1.detach().numpy(), + row_splits) + r_1.requires_grad = True + r_2 = ml.classes.RaggedTensor.from_row_splits(t_2.detach().numpy(), + row_splits) + r_2.requires_grad = True + r_3 = ml.classes.RaggedTensor.from_row_splits(t_3.detach().numpy(), + row_splits) + r_3.requires_grad = True + + r_ans = (r_1 + r_2) * r_3 + t_ans = (t_1 + t_2) * t_3 + + np.testing.assert_equal(mltest.to_numpy(t_ans), + mltest.to_numpy(r_ans.values)) + + # Compute gradients + t_ans.sum().backward() + r_ans.values.sum().backward() + + np.testing.assert_equal(mltest.to_numpy(t_1.grad), + mltest.to_numpy(r_1.values.grad)) + + +@dtypes +@mltest.parametrize.ml_paddle_only +def test_binary_ew_ops(dtype, ml): + # Binary Ops. + device = 'gpu' if ml.device == 'cuda' else 'cpu' + + t_1 = paddle.to_tensor( + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + dtype=dtype)).to(device) + t_2 = paddle.to_tensor( + np.array([2, 3, 6, 3, 11, 3, 43, 12, 8, 15, 12, 87, 45], + dtype=dtype)).to(device) + + row_splits = paddle.to_tensor( + np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.int64)).to(device) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + b = ml.classes.RaggedTensor.from_row_splits(t_2, row_splits) + + np.testing.assert_equal( + (a + b).values.cpu().numpy(), + np.array([2, 4, 8, 6, 15, 8, 49, 19, 16, 24, 22, 98, 57])) + np.testing.assert_equal( + (a - b).values.cpu().numpy(), + np.array([-2, -2, -4, 0, -7, 2, -37, -5, 0, -6, -2, -76, -33])) + np.testing.assert_equal( + (a * b).values.cpu().numpy(), + np.array([0, 3, 12, 9, 44, 15, 258, 84, 64, 135, 120, 957, 540])) + np.testing.assert_equal((a / b).values.cpu().numpy(), + (t_1 / t_2).cpu().numpy()) + np.testing.assert_equal((a // b).values.cpu().numpy(), + np.array([0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0])) + + # Assignment Ops. + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a += b + np.testing.assert_equal( + a.values.cpu().numpy(), + np.array([2, 4, 8, 6, 15, 8, 49, 19, 16, 24, 22, 98, 57])) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a -= b + np.testing.assert_equal( + a.values.cpu().numpy(), + np.array([-2, -2, -4, 0, -7, 2, -37, -5, 0, -6, -2, -76, -33])) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a *= b + np.testing.assert_equal( + a.values.cpu().numpy(), + np.array([0, 3, 12, 9, 44, 15, 258, 84, 64, 135, 120, 957, 540])) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a //= b + np.testing.assert_equal(a.values.cpu().numpy(), + np.array([0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0])) + + # Failure cases with incompatible shape. + # Different row_splits. + row_splits = [0, 4, 5, 13] + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + row_splits = [0, 4, 6, 13] + b = ml.classes.RaggedTensor.from_row_splits(t_2, row_splits) + + with np.testing.assert_raises(ValueError): + a + b # noqa # pylint: disable=pointless-statement + with np.testing.assert_raises(ValueError): + a += b # noqa # pylint: disable=pointless-statement + + # Different length + row_splits = [0, 4, 5, 13] + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + row_splits = [0, 4, 13] + b = ml.classes.RaggedTensor.from_row_splits(t_2, row_splits) + + with np.testing.assert_raises(ValueError): + a + b # noqa # pylint: disable=pointless-statement + with np.testing.assert_raises(ValueError): + a += b diff --git a/python/test/ml_ops/test_ragged_to_dense.py b/python/test/ml_ops/test_ragged_to_dense.py index 9121faa1b46..289320032ff 100644 --- a/python/test/ml_ops/test_ragged_to_dense.py +++ b/python/test/ml_ops/test_ragged_to_dense.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: MIT # ---------------------------------------------------------------------------- +# noqa # pylint: disable=unused-import import open3d as o3d import numpy as np import pytest @@ -22,7 +23,7 @@ @dtypes -@mltest.parametrize.ml_torch_only +@mltest.parametrize.ml_torch_and_paddle_only def test_ragged_to_dense(dtype, ml): values = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype) @@ -30,8 +31,14 @@ def test_ragged_to_dense(dtype, ml): out_col_size = 4 default_value = np.array(-1, dtype=dtype) - ans = mltest.run_op(ml, ml.device, True, ml.ops.ragged_to_dense, values, - row_splits, out_col_size, default_value) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.ragged_to_dense, + values=values, + row_splits=row_splits, + out_col_size=out_col_size, + default_value=default_value) expected = np.full((row_splits.shape[0] - 1, out_col_size), default_value) for i in range(row_splits.shape[0] - 1): @@ -44,7 +51,7 @@ def test_ragged_to_dense(dtype, ml): # test with more dimensions @dtypes -@mltest.parametrize.ml_torch_only +@mltest.parametrize.ml_torch_and_paddle_only def test_ragged_to_dense_more_dims(dtype, ml): values = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], @@ -54,8 +61,14 @@ def test_ragged_to_dense_more_dims(dtype, ml): out_col_size = 4 default_value = np.array([-1, -1], dtype=dtype) - ans = mltest.run_op(ml, ml.device, True, ml.ops.ragged_to_dense, values, - row_splits, out_col_size, default_value) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.ragged_to_dense, + values=values, + row_splits=row_splits, + out_col_size=out_col_size, + default_value=default_value) expected = np.full(( row_splits.shape[0] - 1, @@ -71,7 +84,7 @@ def test_ragged_to_dense_more_dims(dtype, ml): # test with larger random data @dtypes -@mltest.parametrize.ml_torch_only +@mltest.parametrize.ml_torch_and_paddle_only @pytest.mark.parametrize('seed', [123, 456]) def test_ragged_to_dense_random(dtype, ml, seed): @@ -87,8 +100,14 @@ def test_ragged_to_dense_random(dtype, ml, seed): default_value = np.array(-1, dtype=dtype) - ans = mltest.run_op(ml, ml.device, True, ml.ops.ragged_to_dense, values, - row_splits, out_col_size, default_value) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.ragged_to_dense, + values=values, + row_splits=row_splits, + out_col_size=out_col_size, + default_value=default_value) expected = np.full((row_splits.shape[0] - 1, out_col_size), default_value) for i in range(row_splits.shape[0] - 1): diff --git a/python/test/ml_ops/test_sparseconv.py b/python/test/ml_ops/test_sparseconv.py index 5ed8881e9cb..c2ec8c23d5a 100644 --- a/python/test/ml_ops/test_sparseconv.py +++ b/python/test/ml_ops/test_sparseconv.py @@ -76,6 +76,11 @@ def kernel_initializer(a): def bias_initializer(a): a.data = torch.from_numpy(bias) + elif ml.module.__name__ == 'paddle': + paddle = ml.module + + kernel_initializer = paddle.nn.initializer.Assign(filters) + bias_initializer = paddle.nn.initializer.Assign(bias) else: raise Exception('Unsupported ml framework {}'.format( ml.module.__name__)) @@ -88,6 +93,8 @@ def bias_initializer(a): bias_initializer=bias_initializer) if ml.module.__name__ == 'torch': sparse_conv.to(ml.device) + elif ml.module.__name__ == 'paddle': + sparse_conv.to('gpu' if ml.device == 'cuda' else 'cpu') y = mltest.run_op(ml, ml.device, @@ -333,6 +340,11 @@ def kernel_initializer(a): def bias_initializer(a): a.data = torch.from_numpy(bias) + elif ml.module.__name__ == 'paddle': + paddle = ml.module + + kernel_initializer = paddle.nn.initializer.Assign(filters) + bias_initializer = paddle.nn.initializer.Assign(bias) else: raise Exception('Unsupported ml framework {}'.format( ml.module.__name__)) @@ -347,6 +359,8 @@ def bias_initializer(a): if ml.module.__name__ == 'torch': sparse_conv_transpose.to(ml.device) + elif ml.module.__name__ == 'paddle': + sparse_conv_transpose.to('gpu' if ml.device == 'cuda' else 'cpu') y = mltest.run_op(ml, ml.device, diff --git a/python/test/ml_ops/test_voxel_pooling.py b/python/test/ml_ops/test_voxel_pooling.py index 7fb85366571..3af251382cd 100644 --- a/python/test/ml_ops/test_voxel_pooling.py +++ b/python/test/ml_ops/test_voxel_pooling.py @@ -57,8 +57,16 @@ def test_voxel_pooling(ml, pos_dtype, feat_dtype, position_fn, feature_fn): # yapf: enable voxel_size = 1 - ans = mltest.run_op(ml, ml.device, True, ml.ops.voxel_pooling, points, - features, voxel_size, position_fn, feature_fn) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.voxel_pooling, + positions=points, + features=features, + voxel_size=voxel_size, + position_fn=position_fn, + feature_fn=feature_fn, + debug=False) if position_fn == 'average': expected_positions = np.stack( @@ -117,8 +125,16 @@ def test_voxel_pooling_empty_point_set(ml, pos_dtype, feat_dtype, position_fn, features = np.zeros(shape=[0, 5], dtype=feat_dtype) voxel_size = 1 - ans = mltest.run_op(ml, ml.device, True, ml.ops.voxel_pooling, points, - features, voxel_size, position_fn, feature_fn) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.voxel_pooling, + positions=points, + features=features, + voxel_size=voxel_size, + position_fn=position_fn, + feature_fn=feature_fn, + debug=False) np.testing.assert_array_equal(points, ans.pooled_positions) np.testing.assert_array_equal(features, ans.pooled_features) @@ -154,16 +170,32 @@ def test_voxel_pooling_grad(ml, pos_dtype, feat_dtype, position_fn, feature_fn, voxel_size = 0.25 def fn(features): - ans = mltest.run_op(ml, ml.device, True, ml.ops.voxel_pooling, - positions, features, voxel_size, position_fn, - feature_fn) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.voxel_pooling, + positions=positions, + features=features, + voxel_size=voxel_size, + position_fn=position_fn, + feature_fn=feature_fn, + debug=False) return ans.pooled_features def fn_grad(features_bp, features): - return mltest.run_op_grad(ml, ml.device, True, ml.ops.voxel_pooling, - features, 'pooled_features', features_bp, - positions, features, voxel_size, position_fn, - feature_fn) + return mltest.run_op_grad(ml, + ml.device, + True, + ml.ops.voxel_pooling, + features, + 'pooled_features', + features_bp, + positions=positions, + features=features, + voxel_size=voxel_size, + position_fn=position_fn, + feature_fn=feature_fn, + debug=False) gradient_OK = check_gradients(features, fn, fn_grad, epsilon=1) assert gradient_OK diff --git a/util/ci_utils.sh b/util/ci_utils.sh index 0d798ca9139..1e4817aabd2 100644 --- a/util/ci_utils.sh +++ b/util/ci_utils.sh @@ -20,6 +20,7 @@ if [ -z "${BUILD_CUDA_MODULE:+x}" ]; then fi BUILD_TENSORFLOW_OPS=${BUILD_TENSORFLOW_OPS:-ON} BUILD_PYTORCH_OPS=${BUILD_PYTORCH_OPS:-ON} +BUILD_PADDLE_OPS=${BUILD_PADDLE_OPS:-ON} LOW_MEM_USAGE=${LOW_MEM_USAGE:-OFF} BUILD_SYCL_MODULE=${BUILD_SYCL_MODULE:-OFF} @@ -47,6 +48,8 @@ install_python_dependencies() { python -m pip install -U -r "${OPEN3D_SOURCE_ROOT}/python/requirements_test.txt" fi if [[ "with-cuda" =~ ^($options)$ ]]; then + PADDLE_GLNX=paddlepaddle-gpu + PADDLE_GLNX_PIP_INDEX=https://www.paddlepaddle.org.cn/packages/nightly/cu118/ TF_ARCH_NAME=tensorflow TF_ARCH_DISABLE_NAME=tensorflow-cpu CUDA_VER=$(nvcc --version | grep "release " | cut -c33-37 | sed 's|[^0-9]||g') # e.g.: 117, 118, 121, ... @@ -57,6 +60,8 @@ install_python_dependencies() { TF_ARCH_NAME=tensorflow TF_ARCH_DISABLE_NAME=tensorflow else + PADDLE_GLNX=paddlepaddle + PADDLE_GLNX_PIP_INDEX=https://www.paddlepaddle.org.cn/packages/nightly/cpu/ TF_ARCH_NAME=tensorflow-cpu TF_ARCH_DISABLE_NAME=tensorflow fi @@ -88,7 +93,16 @@ install_python_dependencies() { exit 1 fi fi - if [ "$BUILD_TENSORFLOW_OPS" == "ON" ] || [ "$BUILD_PYTORCH_OPS" == "ON" ]; then + if [ "$BUILD_PADDLE_OPS" == "ON" ]; then # ML/requirements-torch.txt + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + python -m pip uninstall paddlepaddle paddlepaddle-gpu -y + python -m pip install --pre "$PADDLE_GLNX" -i "$PADDLE_GLNX_PIP_INDEX" + else + echo "unknown OS $OSTYPE" + exit 1 + fi + fi + if [ "$BUILD_TENSORFLOW_OPS" == "ON" ] || [ "$BUILD_PYTORCH_OPS" == "ON" ] || [ "$BUILD_PADDLE_OPS" == "ON" ]; then python -m pip install -U -c "${OPEN3D_SOURCE_ROOT}/python/requirements_build.txt" yapf # Fix Protobuf compatibility issue # https://stackoverflow.com/a/72493690/1255535 @@ -121,6 +135,7 @@ build_all() { -DGLIBCXX_USE_CXX11_ABI=OFF -DBUILD_TENSORFLOW_OPS="$BUILD_TENSORFLOW_OPS" -DBUILD_PYTORCH_OPS="$BUILD_PYTORCH_OPS" + -DBUILD_PADDLE_OPS="$BUILD_ADDLE_OPS" -DCMAKE_INSTALL_PREFIX="$OPEN3D_INSTALL_DIR" -DBUILD_UNIT_TESTS=ON -DBUILD_BENCHMARKS=ON @@ -178,6 +193,8 @@ build_pip_package() { CXX11_ABI=$(python -c "import tensorflow as tf; print('ON' if tf.__cxx11_abi_flag__ else 'OFF')") elif [ "$BUILD_PYTORCH_OPS" == "ON" ]; then CXX11_ABI=$(python -c "import torch; print('ON' if torch._C._GLIBCXX_USE_CXX11_ABI else 'OFF')") + elif [ "$BUILD_PADDLE_OPS" == "ON" ]; then + CXX11_ABI="ON" fi echo Building with GLIBCXX_USE_CXX11_ABI="$CXX11_ABI" set -u @@ -194,6 +211,7 @@ build_pip_package() { "-DGLIBCXX_USE_CXX11_ABI=$CXX11_ABI" "-DBUILD_TENSORFLOW_OPS=$BUILD_TENSORFLOW_OPS" "-DBUILD_PYTORCH_OPS=$BUILD_PYTORCH_OPS" + "-DBUILD_PADDLE_OPS=$BUILD_PADDLE_OPS" "-DBUILD_FILAMENT_FROM_SOURCE=$BUILD_FILAMENT_FROM_SOURCE" "-DBUILD_JUPYTER_EXTENSION=$BUILD_JUPYTER_EXTENSION" "-DCMAKE_INSTALL_PREFIX=$OPEN3D_INSTALL_DIR" @@ -211,7 +229,11 @@ build_pip_package() { make VERBOSE=1 -j"$NPROC" pip-package mv lib/python_package/pip_package/open3d*.whl . # save CPU wheel - if [ "$BUILD_CUDA_MODULE" == ON ]; then + if [ "$BUILD_CUDA_MODULE" == "ON" ]; then + if [ "$BUILD_PADDLE_OPS" == "ON" ]; then + export LD_LIBRARY_PATH=/usr/local/cuda/compat/:${LD_LIBRARY_PATH} + fi + echo echo Installing CUDA versions of TensorFlow and PyTorch... install_python_dependencies with-cuda purge-cache @@ -267,6 +289,10 @@ test_wheel() { python -W default -c \ "import open3d.ml.torch; print('PyTorch Ops library loaded:', open3d.ml.torch._loaded)" fi + if [ "$BUILD_PADDLE_OPS" == ON ]; then + python -c \ + "import open3d.ml.paddle; print('Paddle Ops library loaded:', open3d.ml.paddle._loaded)" + fi if python -c "import sys, open3d; sys.exit(not open3d._build_config['BUILD_TENSORFLOW_OPS'])"; then BUILD_TENSORFLOW_OPS=ON python -m pip install -r "$OPEN3D_ML_ROOT/requirements-tensorflow.txt" @@ -289,10 +315,17 @@ run_python_tests() { python -m pip install -U -r "$OPEN3D_SOURCE_ROOT/python/requirements_test.txt" echo Add --randomly-seed=SEED to the test command to reproduce test order. pytest_args=("$OPEN3D_SOURCE_ROOT"/python/test/) - if [ "$BUILD_PYTORCH_OPS" == "OFF" ] && [ "$BUILD_TENSORFLOW_OPS" == "OFF" ]; then + if [ "$BUILD_PYTORCH_OPS" == "OFF" ] && [ "$BUILD_TENSORFLOW_OPS" == "OFF" ] && [ "$BUILD_PADDLE_OPS" == "OFF" ]; then echo Testing ML Ops disabled pytest_args+=(--ignore "$OPEN3D_SOURCE_ROOT"/python/test/ml_ops/) fi + + if [ "$BUILD_PADDLE_OPS" == "OFF" ]; then + pytest_args+=(--ignore "$OPEN3D_SOURCE_ROOT"/python/test/ml_ops/test_ragged_tensor_paddle.py) + elif [ "$BUILD_PYTORCH_OPS" == "OFF" ]; then + pytest_args+=(--ignore "$OPEN3D_SOURCE_ROOT"/python/test/ml_ops/test_ragged_tensor.py) + fi + python -m pytest "${pytest_args[@]}" deactivate open3d_test.venv # argument prevents unbound variable error rm -rf open3d_test.venv # cleanup for testing the next wheel