Skip to content

Commit

Permalink
Change rmm::exec_policy to take async_resource_ref (#1449)
Browse files Browse the repository at this point in the history
Authors:
  - Michael Schellenberger Costa (https://github.com/miscco)

Approvers:
  - Mark Harris (https://github.com/harrism)

URL: #1449
  • Loading branch information
miscco authored Feb 2, 2024
1 parent b85f482 commit beb6c36
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions include/rmm/exec_policy.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/thrust_allocator_adaptor.hpp>
#include <rmm/resource_ref.hpp>

#include <rmm/detail/thrust_namespace.h>
#include <thrust/system/cuda/execution_policy.h>
Expand All @@ -39,7 +40,7 @@ namespace rmm {
* @brief Synchronous execution policy for allocations using thrust
*/
using thrust_exec_policy_t =
thrust::detail::execute_with_allocator<rmm::mr::thrust_allocator<char>,
thrust::detail::execute_with_allocator<mr::thrust_allocator<char>,
thrust::cuda_cub::execute_on_stream_base>;

/**
Expand All @@ -54,10 +55,10 @@ class exec_policy : public thrust_exec_policy_t {
* @param stream The stream on which to allocate temporary memory
* @param mr The resource to use for allocating temporary memory
*/
explicit exec_policy(cuda_stream_view stream = cuda_stream_default,
rmm::mr::device_memory_resource* mr = mr::get_current_device_resource())
explicit exec_policy(cuda_stream_view stream = cuda_stream_default,
device_async_resource_ref mr = mr::get_current_device_resource())
: thrust_exec_policy_t(
thrust::cuda::par(rmm::mr::thrust_allocator<char>(stream, mr)).on(stream.value()))
thrust::cuda::par(mr::thrust_allocator<char>(stream, mr)).on(stream.value()))
{
}
};
Expand All @@ -68,7 +69,7 @@ class exec_policy : public thrust_exec_policy_t {
* @brief Asynchronous execution policy for allocations using thrust
*/
using thrust_exec_policy_nosync_t =
thrust::detail::execute_with_allocator<rmm::mr::thrust_allocator<char>,
thrust::detail::execute_with_allocator<mr::thrust_allocator<char>,
thrust::cuda_cub::execute_on_stream_nosync_base>;
/**
* @brief Helper class usable as a Thrust CUDA execution policy
Expand All @@ -78,11 +79,10 @@ using thrust_exec_policy_nosync_t =
*/
class exec_policy_nosync : public thrust_exec_policy_nosync_t {
public:
explicit exec_policy_nosync(
cuda_stream_view stream = cuda_stream_default,
rmm::mr::device_memory_resource* mr = mr::get_current_device_resource())
explicit exec_policy_nosync(cuda_stream_view stream = cuda_stream_default,
device_async_resource_ref mr = mr::get_current_device_resource())
: thrust_exec_policy_nosync_t(
thrust::cuda::par_nosync(rmm::mr::thrust_allocator<char>(stream, mr)).on(stream.value()))
thrust::cuda::par_nosync(mr::thrust_allocator<char>(stream, mr)).on(stream.value()))
{
}
};
Expand Down

0 comments on commit beb6c36

Please sign in to comment.