Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop cub::TransformInputIterator in favor of thrust::transform_iterator #2588

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

#include <cuda.h>
#include <thrust/fill.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

#include <algorithm>
Expand Down Expand Up @@ -443,13 +444,12 @@ void kmeans_fit_main(raft::resources const& handle,
params.batch_centroids,
workspace);

// Using TransformInputIteratorT to dereference an array of
// Using thrust::transform_iterator to dereference an array of
// raft::KeyValuePair and converting them to just return the Key to be used
// in reduce_rows_by_key prims
detail::KeyValueIndexOp<IndexT, DataT> conversion_op;
cub::TransformInputIterator<IndexT,
detail::KeyValueIndexOp<IndexT, DataT>,
raft::KeyValuePair<IndexT, DataT>*>
thrust::transform_iterator<detail::KeyValueIndexOp<IndexT, DataT>,
raft::KeyValuePair<IndexT, DataT>*>
itr(minClusterAndDistance.data_handle(), conversion_op);

update_centroids(handle,
Expand Down
7 changes: 5 additions & 2 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <rmm/resource_ref.hpp>

#include <thrust/gather.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

#include <limits>
Expand Down Expand Up @@ -288,7 +289,8 @@ void calc_centers_and_sizes(const raft::resources& handle,
dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters);
} else {
// todo(lsugy): use iterator from KV output of fusedL2NN
cub::TransformInputIterator<MathT, MappingOpT, const T*> mapping_itr(dataset, mapping_op);
thrust::transform_iterator<MappingOpT, const T*, thrust::use_default, MathT> mapping_itr(
dataset, mapping_op);
raft::linalg::reduce_rows_by_key(
mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters);
}
Expand Down Expand Up @@ -894,7 +896,8 @@ auto build_fine_clusters(const raft::resources& handle,
"Number of fine clusters must be non-zero for a non-empty mesocluster");
}

cub::TransformInputIterator<MathT, MappingOpT, const T*> mapping_itr(dataset_mptr, mapping_op);
thrust::transform_iterator<MappingOpT, const T*, thrust::use_default, MathT> mapping_itr(
dataset_mptr, mapping_op);
raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream);
if (params.metric == raft::distance::DistanceType::L2Expanded ||
params.metric == raft::distance::DistanceType::L2SqrtExpanded) {
Expand Down
12 changes: 6 additions & 6 deletions cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <cuda.h>
#include <thrust/fill.h>
#include <thrust/for_each.h>
#include <thrust/iterator/transform_iterator.h>

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -199,8 +200,8 @@ void computeClusterCost(raft::resources const& handle,
{
cudaStream_t stream = resource::get_cuda_stream(handle);

cub::TransformInputIterator<OutputT, MainOpT, InputT*> itr(minClusterDistance.data_handle(),
main_op);
thrust::transform_iterator<MainOpT, InputT*, thrust::use_default, OutputT> itr(
minClusterDistance.data_handle(), main_op);

size_t temp_storage_bytes = 0;
RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr,
Expand Down Expand Up @@ -641,13 +642,12 @@ void countSamplesInCluster(raft::resources const& handle,
params.batch_centroids,
workspace);

// Using TransformInputIteratorT to dereference an array of raft::KeyValuePair
// Using thrust::transform_iterator to dereference an array of raft::KeyValuePair
// and converting them to just return the Key to be used in reduce_rows_by_key
// prims
detail::KeyValueIndexOp<IndexT, DataT> conversion_op;
cub::TransformInputIterator<IndexT,
detail::KeyValueIndexOp<IndexT, DataT>,
raft::KeyValuePair<IndexT, DataT>*>
thrust::transform_iterator<detail::KeyValueIndexOp<IndexT, DataT>,
raft::KeyValuePair<IndexT, DataT>*>
itr(minClusterAndDistance.data_handle(), conversion_op);

// count # of samples in each cluster
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

#include <cuda_fp16.h>
#include <thrust/extrema.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/scan.h>

#include <memory>
Expand Down Expand Up @@ -180,8 +181,8 @@ void select_residuals(raft::resources const& handle,
rmm::device_uvector<float> tmp(size_t(n_rows) * size_t(dim), stream, device_memory);
// Note: the number of rows of the input dataset isn't actually n_rows, but matrix::gather doesn't
// need to know it, any strictly positive number would work.
cub::TransformInputIterator<float, utils::mapping<float>, const T*> mapping_itr(
dataset, utils::mapping<float>{});
thrust::transform_iterator<utils::mapping<float>, const T*> mapping_itr(dataset,
utils::mapping<float>{});
raft::matrix::gather(mapping_itr, (IdxT)dim, n_rows, row_ids, n_rows, tmp.data(), stream);

raft::matrix::linewise_op(handle,
Expand Down
Loading