diff --git a/README.md b/README.md index f25ad89b0b..ee5b4a8009 100755 --- a/README.md +++ b/README.md @@ -369,7 +369,7 @@ If citing the k-selection routines, please consider the following bibtex: isbn = {9798400701092}, publisher = {Association for Computing Machinery}, address = {New York, NY, USA}, - location = {Denver, CO, USA} + location = {Denver, CO, USA}, series = {SC '23} } ``` diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index e4e3ea3512..93faf9dd19 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -423,24 +423,39 @@ void optimize(raft::resources const& res, const auto num_full = host_stats.data_handle()[1]; // Create pruned kNN graph - uint32_t max_detour = 0; -#pragma omp parallel for reduction(max : max_detour) +#pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { - if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } + // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable + // count of the neighbors while increasing the target detourable count from zero. + uint64_t pk = 0; + uint32_t num_detour = 0; + while (pk < output_graph_degree) { + uint32_t next_num_detour = std::numeric_limits::max(); for (uint64_t k = 0; k < input_graph_degree; k++) { - if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; } + const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)]; + // Find the detourable count to check in the next iteration + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } + + // Store the neighbor index if its detourable count is equal to `num_detour`. + if (num_detour_k != num_detour) { continue; } output_graph_ptr[pk + (output_graph_degree * i)] = input_graph_ptr[k + (input_graph_degree * i)]; pk += 1; if (pk >= output_graph_degree) break; } if (pk >= output_graph_degree) break; + + assert(next_num_detour != std::numeric_limits::max()); + num_detour = next_num_detour; } - assert(pk == output_graph_degree); + RAFT_EXPECTS(pk == output_graph_degree, + "Couldn't find the output_graph_degree (%u) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + static_cast(i)); } - // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); const double time_prune_end = cur_time(); RAFT_LOG_DEBUG(