diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 4203f0969b..82659fd202 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -43,6 +43,7 @@ #include #include +#include #include #include @@ -443,13 +444,12 @@ void kmeans_fit_main(raft::resources const& handle, params.batch_centroids, workspace); - // Using TransformInputIteratorT to dereference an array of + // Using thrust::transform_iterator to dereference an array of // raft::KeyValuePair and converting them to just return the Key to be used // in reduce_rows_by_key prims detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> + thrust::transform_iterator, + raft::KeyValuePair*> itr(minClusterAndDistance.data_handle(), conversion_op); update_centroids(handle, diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index 7e78f4b5a0..92a2f6cde1 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -47,6 +47,7 @@ #include #include +#include #include #include @@ -288,7 +289,8 @@ void calc_centers_and_sizes(const raft::resources& handle, dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); } else { // todo(lsugy): use iterator from KV output of fusedL2NN - cub::TransformInputIterator mapping_itr(dataset, mapping_op); + thrust::transform_iterator mapping_itr( + dataset, mapping_op); raft::linalg::reduce_rows_by_key( mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); } @@ -894,7 +896,8 @@ auto build_fine_clusters(const raft::resources& handle, "Number of fine clusters must be non-zero for a non-empty mesocluster"); } - cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); + thrust::transform_iterator mapping_itr( + dataset_mptr, mapping_op); raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); if (params.metric == raft::distance::DistanceType::L2Expanded || params.metric == raft::distance::DistanceType::L2SqrtExpanded) { diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 8263aa4615..e8fbe7cd73 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -43,6 +43,7 @@ #include #include #include +#include #include #include @@ -199,8 +200,8 @@ void computeClusterCost(raft::resources const& handle, { cudaStream_t stream = resource::get_cuda_stream(handle); - cub::TransformInputIterator itr(minClusterDistance.data_handle(), - main_op); + thrust::transform_iterator itr( + minClusterDistance.data_handle(), main_op); size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, @@ -641,13 +642,12 @@ void countSamplesInCluster(raft::resources const& handle, params.batch_centroids, workspace); - // Using TransformInputIteratorT to dereference an array of raft::KeyValuePair + // Using thrust::transform_iterator to dereference an array of raft::KeyValuePair // and converting them to just return the Key to be used in reduce_rows_by_key // prims detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> + thrust::transform_iterator, + raft::KeyValuePair*> itr(minClusterAndDistance.data_handle(), conversion_op); // count # of samples in each cluster diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 24574642ef..a9306e808b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -53,6 +53,7 @@ #include #include +#include #include #include @@ -180,8 +181,8 @@ void select_residuals(raft::resources const& handle, rmm::device_uvector tmp(size_t(n_rows) * size_t(dim), stream, device_memory); // Note: the number of rows of the input dataset isn't actually n_rows, but matrix::gather doesn't // need to know it, any strictly positive number would work. - cub::TransformInputIterator, const T*> mapping_itr( - dataset, utils::mapping{}); + thrust::transform_iterator, const T*> mapping_itr(dataset, + utils::mapping{}); raft::matrix::gather(mapping_itr, (IdxT)dim, n_rows, row_ids, n_rows, tmp.data(), stream); raft::matrix::linewise_op(handle,