Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix thrust::raw_reference_cast for tuple_of_iterator_references and simplify thrust::generate #3970

Merged
merged 10 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions thrust/testing/generate_const_iterators.cu

This file was deleted.

48 changes: 48 additions & 0 deletions thrust/testing/raw_reference_cast.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include <thrust/detail/raw_reference_cast.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>

#include <unittest/unittest.h>

void TestRawReferenceCast()
{
using ::cuda::std::is_same_v;

{
[[maybe_unused]] int i = 42;
[[maybe_unused]] const int ci = 42;
static_assert(is_same_v<decltype(thrust::raw_reference_cast(i)), int&>);
static_assert(is_same_v<decltype(thrust::raw_reference_cast(ci)), const int&>);
}
{
[[maybe_unused]] thrust::host_vector<int> vec(1);
static_assert(is_same_v<decltype(thrust::raw_reference_cast(*vec.begin())), int&>);
static_assert(is_same_v<decltype(thrust::raw_reference_cast(*vec.cbegin())), const int&>);

[[maybe_unused]] auto zip = thrust::make_zip_iterator(vec.begin(), vec.begin());
static_assert(
is_same_v<decltype(thrust::raw_reference_cast(*zip)), thrust::detail::tuple_of_iterator_references<int&, int&>>);
Comment on lines +22 to +24
Copy link
Contributor Author

@bernhardmgruber bernhardmgruber Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thrust::raw_reference_cast(*zip)) would return const thrust::detail::tuple_of_iterator_references<int&, int&>& (const and ref qualified) before this PR, if vec is a host vector. It worked correctly for a device vector.


[[maybe_unused]] auto zip2 = thrust::make_zip_iterator(zip, zip);
static_assert(
is_same_v<decltype(thrust::raw_reference_cast(*zip2)),
thrust::detail::tuple_of_iterator_references<thrust::detail::tuple_of_iterator_references<int&, int&>,
thrust::detail::tuple_of_iterator_references<int&, int&>>>);
}
{
[[maybe_unused]] thrust::device_vector<int> vec(1);
static_assert(is_same_v<decltype(thrust::raw_reference_cast(*vec.begin())), int&>);
static_assert(is_same_v<decltype(thrust::raw_reference_cast(*vec.cbegin())), const int&>);

[[maybe_unused]] auto zip = thrust::make_zip_iterator(vec.begin(), vec.begin());
static_assert(
is_same_v<decltype(thrust::raw_reference_cast(*zip)), thrust::detail::tuple_of_iterator_references<int&, int&>>);

[[maybe_unused]] auto zip2 = thrust::make_zip_iterator(zip, zip);
static_assert(
is_same_v<decltype(thrust::raw_reference_cast(*zip2)),
thrust::detail::tuple_of_iterator_references<thrust::detail::tuple_of_iterator_references<int&, int&>,
thrust::detail::tuple_of_iterator_references<int&, int&>>>);
}
}
DECLARE_UNITTEST(TestRawReferenceCast);
115 changes: 0 additions & 115 deletions thrust/testing/unittest/runtime_static_assert.h

This file was deleted.

10 changes: 0 additions & 10 deletions thrust/testing/unittest_static_assert.cmake

This file was deleted.

33 changes: 0 additions & 33 deletions thrust/testing/unittest_static_assert.cu

This file was deleted.

73 changes: 0 additions & 73 deletions thrust/thrust/detail/internal_functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,79 +112,6 @@ struct tuple_not_binary_predicate
mutable Predicate pred;
};

template <typename Generator>
struct host_generate_functor
{
using result_type = void;

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE host_generate_functor(Generator g)
: gen(g)
{}

// operator() does not take an lvalue reference because some iterators
// produce temporary proxy references when dereferenced. for example,
// consider the temporary tuple of references produced by zip_iterator.
// such temporaries cannot bind to an lvalue reference.
//
// to WAR this, accept a const reference (which is bindable to a temporary),
// and const_cast in the implementation.
//
// XXX change to an rvalue reference upon c++0x (which either a named variable
// or temporary can bind to)
template <typename T>
_CCCL_HOST void operator()(const T& x)
{
// we have to be naughty and const_cast this to get it to work
T& lvalue = const_cast<T&>(x);

// this assigns correctly whether x is a true reference or proxy
lvalue = gen();
}

Generator gen;
};

template <typename Generator>
struct device_generate_functor
{
using result_type = void;

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE device_generate_functor(Generator g)
: gen(g)
{}

// operator() does not take an lvalue reference because some iterators
// produce temporary proxy references when dereferenced. for example,
// consider the temporary tuple of references produced by zip_iterator.
// such temporaries cannot bind to an lvalue reference.
//
// to WAR this, accept a const reference (which is bindable to a temporary),
// and const_cast in the implementation.
//
// XXX change to an rvalue reference upon c++0x (which either a named variable
// or temporary can bind to)
template <typename T>
_CCCL_HOST_DEVICE void operator()(const T& x)
{
// we have to be naughty and const_cast this to get it to work
T& lvalue = const_cast<T&>(x);

// this assigns correctly whether x is a true reference or proxy
lvalue = gen();
}

Generator gen;
};

template <typename System, typename Generator>
struct generate_functor
: thrust::detail::eval_if<::cuda::std::is_convertible<System, thrust::host_system_tag>::value,
thrust::detail::identity_<host_generate_functor<Generator>>,
thrust::detail::identity_<device_generate_functor<Generator>>>
{};

template <typename T>
struct is_non_const_reference
: ::cuda::std::_And<thrust::detail::not_<::cuda::std::is_const<T>>,
Expand Down
Loading
Loading