From 7ce62f969c94598c64be7d4da3a0c66a9d1d183c Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 25 Feb 2025 21:35:25 -0800 Subject: [PATCH 1/7] Expose kmeans to python --- cpp/CMakeLists.txt | 1 + cpp/include/cuvs/cluster/kmeans.h | 201 ++++++++++++++++++++++ cpp/include/cuvs/cluster/kmeans.hpp | 47 ++++- cpp/src/cluster/kmeans.cuh | 46 +++++ cpp/src/cluster/kmeans_c.cpp | 236 ++++++++++++++++++++++++++ python/cuvs/CMakeLists.txt | 1 + python/cuvs/cuvs/tests/test_kmeans.py | 70 ++++++++ 7 files changed, 598 insertions(+), 4 deletions(-) create mode 100644 cpp/include/cuvs/cluster/kmeans.h create mode 100644 cpp/src/cluster/kmeans_c.cpp create mode 100644 python/cuvs/cuvs/tests/test_kmeans.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 869854847..c8f867bb4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -688,6 +688,7 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$:NVTX_ENAB add_library( cuvs_c SHARED src/core/c_api.cpp + src/cluster/kmeans_c.cpp src/neighbors/brute_force_c.cpp src/neighbors/ivf_flat_c.cpp src/neighbors/ivf_pq_c.cpp diff --git a/cpp/include/cuvs/cluster/kmeans.h b/cpp/include/cuvs/cluster/kmeans.h new file mode 100644 index 000000000..2719f963e --- /dev/null +++ b/cpp/include/cuvs/cluster/kmeans.h @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +enum cuvsKMeansInitMethod { + /** + * Sample the centroids using the kmeans++ strategy + */ + KMeansPlusPlus, + + /** + * Sample the centroids uniformly at random + */ + Random, + + /** + * User provides the array of initial centroids + */ + Array +}; + +/** + * @brief Hyper-parameters for the kmeans algorithm + */ +struct cuvsKMeansParams { + cuvsDistanceType metric; + + /** + * The number of clusters to form as well as the number of centroids to generate (default:8). + */ + int n_clusters; + + /** + * Method for initialization, defaults to k-means++: + * - cuvsKMeansInitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm + * to select the initial cluster centers. + * - cuvsKMeansInitMethod::Random (random): Choose 'n_clusters' observations (rows) at + * random from the input data for the initial centroids. + * - cuvsKMeansInitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. + */ + cuvsKMeansInitMethod init; + + /** + * Maximum number of iterations of the k-means algorithm for a single run. + */ + int max_iter; + + /** + * Relative tolerance with regard to inertia to declare convergence. + */ + double tol; + + /** + * Number of times the k-means algorithm will be run with different seeds. + */ + int n_init; + + /** + * Oversampling factor for use in the k-means|| algorithm + */ + double oversampling_factor; + + /** + * batch_samples and batch_centroids are used to tile 1NN computation which is + * useful to optimize/control the memory footprint + * Default tile is [batch_samples x n_clusters] i.e. 
when batch_centroids is 0 + * then don't tile the centroids + */ + int batch_samples; + + /** + * if 0 then batch_centroids = n_clusters + */ + int batch_centroids; + + bool inertia_check; + + // TODO: handle balanced kmeans +}; + +typedef struct cuvsKMeansParams* cuvsKMeansParams_t; + +/** + * @brief Allocate Scalar Quantizer params, and populate with default values + * + * @param[in] params cuvsKMeansParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params); + +/** + * @brief De-allocate Scalar Quantizer params + * + * @param[in] params + * @return cuvsError_t + */ +cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t params); + +/** + * @brief Find clusters with k-means algorithm. + * + * Initial centroids are chosen with k-means++ algorithm. Empty + * clusters are reinitialized by choosing new centroids with + * k-means++ algorithm. + * + * @param[in] res opaque C handle + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] When init is cuvsKMeansInitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed to by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +cuvsError_t cuvsKMeansFit(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int* n_iter); + +/** + * @brief Predict the closest cluster each sample in X belongs to. + * + * @param[in] res opaque C handle 
+ * @param[in] params Parameters for KMeans model. + * @param[in] X New data to predict. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized + * @param[out] labels Index of the cluster each sample in X + * belongs to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. + */ +cuvsError_t cuvsKMeansPredict(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + DLManagedTensor* labels, + bool normalize_weight, + double* inertia); + +/** + * @brief Compute cluster cost + * + * @param[in] res opaque C handle + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * + */ +cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, + DLManagedTensor* X, + DLManagedTensor* centroids, + double* cost); +#ifdef __cplusplus +} +#endif diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 64ac813ab..7c6af27e1 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -104,10 +104,12 @@ struct params : base_params { */ double oversampling_factor = 2.0; - // batch_samples and batch_centroids are used to tile 1NN computation which is - // useful to optimize/control the memory footprint - // Default tile is [batch_samples x n_clusters] i.e. 
when batch_centroids is 0 - // then don't tile the centroids + /** + * batch_samples and batch_centroids are used to tile 1NN computation which is + * useful to optimize/control the memory footprint + * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 + * then don't tile the centroids + */ int batch_samples = 1 << 15; /** @@ -1089,6 +1091,43 @@ void transform(raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_matrix_view X_new); + +/** + * @brief Compute cluster cost + * + * @param[in] handle The raft handle + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * + */ +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost); + +/** + * @brief Compute cluster cost + * + * @param[in] handle The raft handle + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. 
+ * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * + */ +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost); + /** * @} */ diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 5e6d756cc..4115e4abe 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -465,6 +465,52 @@ void min_cluster_distance(raft::resources const& handle, workspace); } +template +void cluster_cost(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost) +{ + auto stream = raft::resource::get_cuda_stream(handle); + + auto n_clusters = centroids.extent(0); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + + rmm::device_uvector workspace(n_samples * sizeof(IndexT), stream); + + rmm::device_uvector x_norms(n_samples, stream); + rmm::device_uvector centroid_norms(n_clusters, stream); + raft::linalg::rowNorm( + x_norms.data(), X.data_handle(), n_features, n_samples, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm( + centroid_norms.data(), centroids, n_features, n_clusters, raft::linalg::L2Norm, true, stream); + + rmm::device_uvector min_cluster_distance(n_samples, stream); + rmm::device_uvector l2_norm_or_distance_buffer(0, stream); + + auto metric = cuvs::distance::DistanceType::L2Expanded; + + cuvs::cluster::kmeans::min_cluster_distance(handle, + X, + centroids, + min_cluster_distance, + x_norms, + l2_norm_or_distance_buffer, + metric, + n_samples, + n_clusters, + workspace); + + rmm::device_scalar device_cost(0, stream); + cuvs::cluster::kmeans::cluster_cost(handle, + min_cluster_distance.view(), + workspace, + raft::make_device_scalar_view(device_cost.data()), + raft::add_op{}); + raft::update_host(cost.data(), device_cost.data(), 1, stream); +} + /** * @brief Calculates a pair for every sample in input 'X' where key is an * 
index of one of the 'centroids' (index of the nearest centroid) and 'value' diff --git a/cpp/src/cluster/kmeans_c.cpp b/cpp/src/cluster/kmeans_c.cpp new file mode 100644 index 000000000..243c3db87 --- /dev/null +++ b/cpp/src/cluster/kmeans_c.cpp @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) +{ + auto kmeans_params = cuvs::cluster::kmeans::params(); + kmeans_params.metric = params.metric; + kmeans_params.init = static_cast(params.init); + kmeans_params.n_clusters = params.n_clusters; + kmeans_params.max_iter = params.max_iter; + kmeans_params.tol = params.tol; + kmeans_params.oversampling_factor = params.oversampling_factor; + kmeans_params.batch_samples = params.batch_samples; + kmeans_params.batch_centroids = params.batch_centroids; + kmeans_params.inertia_check = params.inertia_check; + return kmeans_params; +} + +template +void _fit(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + double* inertia, + int* n_iter) +{ + auto X = X_tensor->dl_tensor; + auto res_ptr = reinterpret_cast(res); + + auto kmeans_params = convert_params(params); + + T inertia_temp; + IdxT n_iter_temp; + + if 
(cuvs::core::is_dlpack_device_compatible(X)) { + using const_mdspan_type = raft::device_matrix_view; + using mdspan_type = raft::device_matrix_view; + + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + sample_weight = + cuvs::core::from_dlpack>(sample_weight_tensor); + } + + cuvs::cluster::kmeans::fit(*res_ptr, + kmeans_params, + cuvs::core::from_dlpack(X_tensor), + sample_weight, + cuvs::core::from_dlpack(centroids_tensor), + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + } else { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + *inertia = inertia_temp; + *n_iter = n_iter_temp; +} + +template +void _predict(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + DLManagedTensor* labels_tensor, + bool normalize_weight, + double* inertia) +{ + auto X = X_tensor->dl_tensor; + auto res_ptr = reinterpret_cast(res); + + auto kmeans_params = convert_params(params); + T inertia_temp; + + if (cuvs::core::is_dlpack_device_compatible(X)) { + using labels_mdspan_type = raft::device_vector_view; + using const_mdspan_type = raft::device_matrix_view; + using mdspan_type = raft::device_matrix_view; + + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + sample_weight = + cuvs::core::from_dlpack>(sample_weight_tensor); + } + + cuvs::cluster::kmeans::predict(*res_ptr, + kmeans_params, + cuvs::core::from_dlpack(X_tensor), + sample_weight, + cuvs::core::from_dlpack(centroids_tensor), + cuvs::core::from_dlpack(labels_tensor), + normalize_weight, + raft::make_host_scalar_view(&inertia_temp)); + } else { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + *inertia = inertia_temp; +} + +template +void _cluster_cost(cuvsResources_t res, + DLManagedTensor* X_tensor, + DLManagedTensor* centroids_tensor, + double* cost) +{ + auto X = X_tensor->dl_tensor; + auto 
res_ptr = reinterpret_cast(res); + + T cost_temp; + + if (cuvs::core::is_dlpack_device_compatible(X)) { + using mdspan_type = raft::device_matrix_view; + + cuvs::cluster::kmeans::cluster_cost(*res_ptr, + cuvs::core::from_dlpack(X_tensor), + cuvs::core::from_dlpack(centroids_tensor), + raft::make_host_scalar_view(&cost_temp)); + } else { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + *cost = cost_temp; +} +} // namespace + +extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) +{ + return cuvs::core::translate_exceptions([=] { + cuvs::cluster::kmeans::params cpp_params; + *params = new cuvsKMeansParams{.metric = cpp_params.metric, + .n_clusters = cpp_params.n_clusters, + .init = static_cast(cpp_params.init), + .max_iter = cpp_params.max_iter, + .tol = cpp_params.tol, + .oversampling_factor = cpp_params.oversampling_factor, + .batch_samples = cpp_params.batch_samples, + .inertia_check = cpp_params.inertia_check}; + }); +} + +extern "C" cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t params) +{ + return cuvs::core::translate_exceptions([=] { delete params; }); +} + +extern "C" cuvsError_t cuvsKMeansFit(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsKMeansPredict(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + 
DLManagedTensor* centroids, + DLManagedTensor* labels, + bool normalize_weight, + double* inertia) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _predict(res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _predict( + res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, + DLManagedTensor* X, + DLManagedTensor* centroids, + double* cost) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _cluster_cost(res, X, centroids, cost); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _cluster_cost(res, X, centroids, cost); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} diff --git a/python/cuvs/CMakeLists.txt b/python/cuvs/CMakeLists.txt index 93946cfdb..91cf0d503 100644 --- a/python/cuvs/CMakeLists.txt +++ b/python/cuvs/CMakeLists.txt @@ -57,6 +57,7 @@ target_include_directories(cuvs::cuvs INTERFACE "$= 1 + assert np.allclose(cluster_cost(X, centroids), inertia, rtol=1e-6) + + +@pytest.mark.parametrize("n_rows", [100]) +@pytest.mark.parametrize("n_cols", [5, 25]) +@pytest.mark.parametrize("n_clusters", [4, 15]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): + X = np.random.random_sample((n_rows, n_cols)).astype(dtype) + X_device = device_ndarray(X) + + centroids = X[:n_clusters] + centroids_device = device_ndarray(centroids) + + inertia = 
cluster_cost(X_device, centroids_device) + + # compute the nearest centroid to each sample + distances = pairwise_distance( + X_device, centroids_device, metric="sqeuclidean" + ).copy_to_host() + cluster_ids = np.argmin(distances, axis=1) + + cluster_distances = np.take_along_axis( + distances, cluster_ids[:, None], axis=1 + ) + + # need reduced tolerance for float32 + tol = 1e-3 if dtype == np.float32 else 1e-6 + assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) From 5e85c0c4cf1854de0353aa8a1ac163495f465800 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 26 Feb 2025 10:06:16 -0800 Subject: [PATCH 2/7] add missing cython bindings --- python/cuvs/cuvs/cluster/CMakeLists.txt | 15 + python/cuvs/cuvs/cluster/__init__.pxd | 0 python/cuvs/cuvs/cluster/__init__.py | 18 ++ .../cuvs/cuvs/cluster/kmeans/CMakeLists.txt | 23 ++ python/cuvs/cuvs/cluster/kmeans/__init__.pxd | 0 python/cuvs/cuvs/cluster/kmeans/__init__.py | 18 ++ python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 65 +++++ python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 264 ++++++++++++++++++ python/cuvs/cuvs/tests/test_kmeans.py | 2 +- 9 files changed, 404 insertions(+), 1 deletion(-) create mode 100644 python/cuvs/cuvs/cluster/CMakeLists.txt create mode 100644 python/cuvs/cuvs/cluster/__init__.pxd create mode 100644 python/cuvs/cuvs/cluster/__init__.py create mode 100644 python/cuvs/cuvs/cluster/kmeans/CMakeLists.txt create mode 100644 python/cuvs/cuvs/cluster/kmeans/__init__.pxd create mode 100644 python/cuvs/cuvs/cluster/kmeans/__init__.py create mode 100644 python/cuvs/cuvs/cluster/kmeans/kmeans.pxd create mode 100644 python/cuvs/cuvs/cluster/kmeans/kmeans.pyx diff --git a/python/cuvs/cuvs/cluster/CMakeLists.txt b/python/cuvs/cuvs/cluster/CMakeLists.txt new file mode 100644 index 000000000..8f9b7e3b7 --- /dev/null +++ b/python/cuvs/cuvs/cluster/CMakeLists.txt @@ -0,0 +1,15 @@ +# ============================================================================= +# Copyright (c) 2025, NVIDIA 
CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +add_subdirectory(kmeans) diff --git a/python/cuvs/cuvs/cluster/__init__.pxd b/python/cuvs/cuvs/cluster/__init__.pxd new file mode 100644 index 000000000..e69de29bb diff --git a/python/cuvs/cuvs/cluster/__init__.py b/python/cuvs/cuvs/cluster/__init__.py new file mode 100644 index 000000000..2d4c4c0d3 --- /dev/null +++ b/python/cuvs/cuvs/cluster/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from cuvs.cluster import kmeans + +__all__ = ["kmeans"] diff --git a/python/cuvs/cuvs/cluster/kmeans/CMakeLists.txt b/python/cuvs/cuvs/cluster/kmeans/CMakeLists.txt new file mode 100644 index 000000000..ae2a91dbc --- /dev/null +++ b/python/cuvs/cuvs/cluster/kmeans/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources kmeans.pyx) +set(linked_libraries cuvs::cuvs cuvs::c_api) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS cuvs MODULE_PREFIX cluster_kmeans_ +) diff --git a/python/cuvs/cuvs/cluster/kmeans/__init__.pxd b/python/cuvs/cuvs/cluster/kmeans/__init__.pxd new file mode 100644 index 000000000..e69de29bb diff --git a/python/cuvs/cuvs/cluster/kmeans/__init__.py b/python/cuvs/cuvs/cluster/kmeans/__init__.py new file mode 100644 index 000000000..87a01f59d --- /dev/null +++ b/python/cuvs/cuvs/cluster/kmeans/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from .kmeans import KMeansParams, cluster_cost, fit + +__all__ = ["KMeansParams", "cluster_cost", "fit"] diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd new file mode 100644 index 000000000..aa76c4f6a --- /dev/null +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -0,0 +1,65 @@ +# +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: language_level=3 + +from libc.stdint cimport uintptr_t +from libcpp cimport bool + +from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t +from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor +from cuvs.distance_type cimport cuvsDistanceType + + +cdef extern from "cuvs/cluster/kmeans.h" nogil: + ctypedef enum cuvsKMeansInitMethod: + KMeansPlusPlus + Random + Array + + ctypedef struct cuvsKMeansParams: + cuvsDistanceType metric, + int n_clusters, + cuvsKMeansInitMethod init, + int max_iter, + double tol, + int n_init, + double oversampling_factor, + int batch_samples, + int batch_centroids, + bool inertia_check, + + ctypedef cuvsKMeansParams* cuvsKMeansParams_t + + cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* index) + + cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t index) + + cuvsError_t cuvsKMeansFit(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor * centroids, + double * inertia, + int * n_iter) except + + + cuvsError_t 
cuvsKMeansPredict(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor * centroids, + DLManagedTensor * labels, + bool normalize_weight, + double * inertia) diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx new file mode 100644 index 000000000..1e18bd74c --- /dev/null +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -0,0 +1,264 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# cython: language_level=3 + +from collections import namedtuple + +import numpy as np + +cimport cuvs.common.cydlpack + +from cuvs.common.resources import auto_sync_resources + +from cython.operator cimport dereference as deref +from libcpp cimport bool, cast +from libcpp.string cimport string + +from cuvs.common cimport cydlpack +from cuvs.distance_type cimport cuvsDistanceType + +from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray +from pylibraft.common.cai_wrapper import wrap_array +from pylibraft.common.interruptible import cuda_interruptible + +from cuvs.distance import DISTANCE_NAMES, DISTANCE_TYPES +from cuvs.neighbors.common import _check_input_array + +from libc.stdint cimport ( + int8_t, + int64_t, + uint8_t, + uint32_t, + uint64_t, + uintptr_t, +) + +from cuvs.common.exceptions import check_cuvs + + +cdef class KMeansParams: + """ + Hyper-parameters for the kmeans algorithm + + Parameters + ---------- + metric : str + String denoting the metric type. + n_clusters : int + The number of clusters to form as well as the number of centroids + to generate + max_iter : int + Maximum number of iterations of the k-means algorithm for a single run + tol : float + Relative tolerance with regards to inertia to declare convergence. 
+ n_init : int + Number of times the k-means algorithm will be run with different seeds + oversampling_factor : double + Oversampling factor for use in the k-means|| algorithm + """ + + cdef cuvsKMeansParams* params + + def __cinit__(self): + cuvsKMeansParamsCreate(&self.params) + + def __dealloc__(self): + check_cuvs(cuvsKMeansParamsDestroy(self.params)) + + # TODO: initMethod + def __init__(self, *, + metric=None, + n_clusters=None, + max_iter=None, + tol=None, + n_init=None, + oversampling_factor=None): + if metric is not None: + self.params.metric = DISTANCE_TYPES[metric] + if n_clusters is not None: + self.params.n_clusters = n_clusters + if max_iter is not None: + self.params.max_iter = max_iter + if tol is not None: + self.params.tol = tol + if n_init is not None: + self.params.n_init = n_init + if oversampling_factor is not None: + self.params.oversampling_factor = oversampling_factor + + @property + def metric(self): + return DISTANCE_NAMES[self.params.metric] + + @property + def n_clusters(self): + return self.params.n_clusters + + @property + def max_iter(self): + return self.params.max_iter + + @property + def tol(self): + return self.params.tol + + @property + def n_init(self): + return self.params.n_init + + @property + def oversampling_factor(self): + return self.params.oversampling_factor + + +FitOutput = namedtuple("FitOutput", "centroids inertia n_iter") + + +@auto_sync_resources +@auto_convert_output +def fit( + KMeansParams params, X, centroids=None, sample_weights=None, resources=None +): + """ + Find clusters with the k-means algorithm + + Parameters + ---------- + + params : KMeansParams + Parameters to use to fit KMeans model + X : Input CUDA array interface compliant matrix shape (m, k) + centroids : Optional writable CUDA array interface compliant matrix + shape (n_clusters, k) + sample_weights : Optional input CUDA array interface compliant vector shape + (m,) default: None + {resources_docstring} + + Returns + ------- + 
centroids : raft.device_ndarray + The computed centroids for each cluster + inertia : float + Sum of squared distances of samples to their closest cluster center + n_iter : int + The number of iterations used to fit the model + + Examples + -------- + + >>> import cupy as cp + >>> + >>> from cuvs.cluster.kmeans import fit, KMeansParams + >>> + >>> n_samples = 5000 + >>> n_features = 50 + >>> n_clusters = 3 + >>> + >>> X = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> params = KMeansParams(n_clusters=n_clusters) + >>> centroids, inertia, n_iter = fit(params, X) + """ + + x_ai = wrap_array(X) + _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')]) + + cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai) + cdef cydlpack.DLManagedTensor* sample_weight_dlpack = NULL + + cdef cuvsResources_t res = resources.get_c_obj() + + cdef double inertia = 0 + cdef int n_iter = 0 + + if centroids is None: + centroids = device_ndarray.empty((params.n_clusters, x_ai.shape[1]), + dtype=x_ai.dtype) + + centroids_ai = wrap_array(centroids) + cdef cydlpack.DLManagedTensor * centroids_dlpack = \ + cydlpack.dlpack_c(centroids_ai) + + if sample_weights is not None: + sample_weight_dlpack = cydlpack.dlpack_c(wrap_array(sample_weights)) + + with cuda_interruptible(): + check_cuvs(cuvsKMeansFit( + res, + params.params, + x_dlpack, + sample_weight_dlpack, + centroids_dlpack, + &inertia, + &n_iter)) + + return FitOutput(centroids, inertia, n_iter) + + +@auto_sync_resources +@auto_convert_output +def cluster_cost(X, centroids, resources=None): + """ + Compute cluster cost given an input matrix and existing centroids + + Parameters + ---------- + X : Input CUDA array interface compliant matrix shape (m, k) + centroids : Input CUDA array interface compliant matrix shape + (n_clusters, k) + {resources_docstring} + + Examples + -------- + + >>> import cupy as cp + >>> + >>> from cuvs.cluster.kmeans import cluster_cost + >>> + >>> n_samples 
= 5000 + >>> n_features = 50 + >>> n_clusters = 3 + >>> + >>> X = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> centroids = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + + >>> inertia = cluster_cost(X, centroids) + """ + + x_ai = wrap_array(X) + _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')]) + cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai) + + centroids_ai = wrap_array(centroids) + _check_input_array(centroids_ai, [np.dtype('float32'), + np.dtype('float64')]) + cdef cydlpack.DLManagedTensor* centroids_dlpack = \ + cydlpack.dlpack_c(centroids_ai) + + cdef double inertia = 0 + + with cuda_interruptible(): + check_cuvs(cuvsKMeansClusterCost( + res, + x_dlpack, + centroids_dlpack, + &inertia)) + + return inertia diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index b694125db..616bb56b3 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From bf0cbd24332e0d68c66a9f6d3a704a9bd1f7ec0c Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 27 Feb 2025 12:14:21 -0800 Subject: [PATCH 3/7] fix cluster_cost --- cpp/CMakeLists.txt | 1 + cpp/include/cuvs/cluster/kmeans.h | 4 +- cpp/src/cluster/kmeans.cuh | 45 +++++++++++++--------- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 5 +++ python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 1 + python/cuvs/cuvs/tests/test_kmeans.py | 3 +- 6 files changed, 38 insertions(+), 21 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c8f867bb4..c6b534ca9 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -313,6 +313,7 @@ if(BUILD_SHARED_LIBS) add_library( cuvs_objs OBJECT src/cluster/kmeans_balanced_fit_float.cu + src/cluster/kmeans_cluster_cost.cu src/cluster/kmeans_fit_mg_float.cu src/cluster/kmeans_fit_mg_double.cu src/cluster/kmeans_fit_double.cu diff --git a/cpp/include/cuvs/cluster/kmeans.h b/cpp/include/cuvs/cluster/kmeans.h index 2719f963e..2d87cc7eb 100644 --- a/cpp/include/cuvs/cluster/kmeans.h +++ b/cpp/include/cuvs/cluster/kmeans.h @@ -104,7 +104,7 @@ struct cuvsKMeansParams { typedef struct cuvsKMeansParams* cuvsKMeansParams_t; /** - * @brief Allocate Scalar Quantizer params, and populate with default values + * @brief Allocate KMeans params, and populate with default values * * @param[in] params cuvsKMeansParams_t to allocate * @return cuvsError_t @@ -112,7 +112,7 @@ typedef struct cuvsKMeansParams* cuvsKMeansParams_t; cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params); /** - * @brief De-allocate Scalar Quantizer params + * @brief De-allocate KMeans params * * @param[in] params * @return cuvsError_t diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 4115e4abe..d78c6ed0b 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -483,32 +483,41 @@ void cluster_cost(raft::resources const& handle, rmm::device_uvector centroid_norms(n_clusters, stream); raft::linalg::rowNorm( 
x_norms.data(), X.data_handle(), n_features, n_samples, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm( - centroid_norms.data(), centroids, n_features, n_clusters, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(centroid_norms.data(), + centroids.data_handle(), + n_features, + n_clusters, + raft::linalg::L2Norm, + true, + stream); rmm::device_uvector min_cluster_distance(n_samples, stream); rmm::device_uvector l2_norm_or_distance_buffer(0, stream); auto metric = cuvs::distance::DistanceType::L2Expanded; - cuvs::cluster::kmeans::min_cluster_distance(handle, - X, - centroids, - min_cluster_distance, - x_norms, - l2_norm_or_distance_buffer, - metric, - n_samples, - n_clusters, - workspace); + cuvs::cluster::kmeans::min_cluster_distance( + handle, + X, + raft::make_device_matrix_view( + const_cast(centroids.data_handle()), n_clusters, n_features), + raft::make_device_vector_view(min_cluster_distance.data(), n_samples), + raft::make_device_vector_view(x_norms.data(), n_samples), + l2_norm_or_distance_buffer, + metric, + n_samples, + n_clusters, + workspace); rmm::device_scalar device_cost(0, stream); - cuvs::cluster::kmeans::cluster_cost(handle, - min_cluster_distance.view(), - workspace, - raft::make_device_scalar_view(device_cost.data()), - raft::add_op{}); - raft::update_host(cost.data(), device_cost.data(), 1, stream); + + cuvs::cluster::kmeans::cluster_cost( + handle, + raft::make_device_vector_view(min_cluster_distance.data(), n_samples), + workspace, + raft::make_device_scalar_view(device_cost.data()), + raft::add_op{}); + raft::update_host(cost.data_handle(), device_cost.data(), 1, stream); } /** diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index aa76c4f6a..f53dc9905 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -63,3 +63,8 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: DLManagedTensor * labels, bool 
normalize_weight, double * inertia) + + cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, + DLManagedTensor* X, + DLManagedTensor* centroids, + double* cost) diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 1e18bd74c..f34c317ab 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -253,6 +253,7 @@ def cluster_cost(X, centroids, resources=None): cydlpack.dlpack_c(centroids_ai) cdef double inertia = 0 + cdef cuvsResources_t res = resources.get_c_obj() with cuda_interruptible(): check_cuvs(cuvsKMeansClusterCost( diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 616bb56b3..7c465a30d 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -21,6 +21,7 @@ from cuvs.distance import pairwise_distance +@pytest.mark.parametrize("n_rows", [100]) @pytest.mark.parametrize("n_cols", [5, 25]) @pytest.mark.parametrize("n_clusters", [5, 15]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -33,7 +34,7 @@ def test_kmeans_fit(n_rows, n_cols, n_clusters, dtype): # compute the inertia, before fitting centroids original_inertia = cluster_cost(X, centroids) - params = KMeansParams(n_clusters=n_clusters, seed=42) + params = KMeansParams(n_clusters=n_clusters) # fit the centroids, make sure inertia has gone down centroids, inertia, n_iter = fit(params, X, centroids) From 599caac88e2ed2554a5fa81302da043f2544b3fa Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 27 Feb 2025 12:45:09 -0800 Subject: [PATCH 4/7] Expose CAGRA persistent mode to python --- cpp/include/cuvs/neighbors/cagra.h | 24 ++++++++++++++++++++ cpp/src/neighbors/cagra_c.cpp | 26 ++++++++++++++-------- python/cuvs/cuvs/neighbors/cagra/cagra.pxd | 3 +++ python/cuvs/cuvs/neighbors/cagra/cagra.pyx | 18 ++++++++++++++- 4 files changed, 61 insertions(+), 10 deletions(-) diff --git 
a/cpp/include/cuvs/neighbors/cagra.h b/cpp/include/cuvs/neighbors/cagra.h index f3dc0d5c2..7d04fab2b 100644 --- a/cpp/include/cuvs/neighbors/cagra.h +++ b/cpp/include/cuvs/neighbors/cagra.h @@ -256,6 +256,30 @@ struct cuvsCagraSearchParams { uint32_t num_random_samplings; /** Bit mask used for initial random seed node selection. */ uint64_t rand_xor_mask; + + /** Whether to use the persistent version of the kernel (only SINGLE_CTA is supported a.t.m.) */ + bool persistent = false; + /** Persistent kernel: time in seconds before the kernel stops if no requests received. */ + float persistent_lifetime = 2; + /** + * Set the fraction of maximum grid size used by persistent kernel. + * Value 1.0 means the kernel grid size is maximum possible for the selected device. + * The value must be greater than 0.0 and not greater than 1.0. + * + * One may need to run other kernels alongside this persistent kernel. This parameter can + * be used to reduce the grid size of the persistent kernel to leave a few SMs idle. + * Note: running any other work on the GPU alongside the persistent kernel makes the setup + * fragile. + * - Running another kernel in another thread usually works, but progress is not guaranteed + * - Any CUDA allocations block the context (this issue may be obscured by using pools) + * - Memory copies to non-pinned host memory may block the context + * + * Even when we know there are no other kernels working at the same time, setting + * persistent_device_usage to 1.0 surprisingly sometimes hurts performance. Proceed with care. + * If you suspect this is an issue, you can reduce this number to ~0.9 without a significant + * impact on the throughput. 
+ */ + float persistent_device_usage; }; typedef struct cuvsCagraSearchParams* cuvsCagraSearchParams_t; diff --git a/cpp/src/neighbors/cagra_c.cpp b/cpp/src/neighbors/cagra_c.cpp index 9b86072ef..5b42a2f47 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -152,10 +152,13 @@ void _search(cuvsResources_t res, search_params.min_iterations = params.min_iterations; search_params.thread_block_size = params.thread_block_size; search_params.hashmap_mode = static_cast(params.hashmap_mode); - search_params.hashmap_min_bitlen = params.hashmap_min_bitlen; - search_params.hashmap_max_fill_rate = params.hashmap_max_fill_rate; - search_params.num_random_samplings = params.num_random_samplings; - search_params.rand_xor_mask = params.rand_xor_mask; + search_params.hashmap_min_bitlen = params.hashmap_min_bitlen; + search_params.hashmap_max_fill_rate = params.hashmap_max_fill_rate; + search_params.num_random_samplings = params.num_random_samplings; + search_params.rand_xor_mask = params.rand_xor_mask; + search_params.persistent = params.persistent; + search_params.persistent_lifetime = params.persistent_lifetime; + search_params.persistent_device_usage = params.persistent_device_usage; using queries_mdspan_type = raft::device_matrix_view; using neighbors_mdspan_type = raft::device_matrix_view; @@ -395,11 +398,16 @@ extern "C" cuvsError_t cuvsCagraExtendParamsDestroy(cuvsCagraExtendParams_t para extern "C" cuvsError_t cuvsCagraSearchParamsCreate(cuvsCagraSearchParams_t* params) { return cuvs::core::translate_exceptions([=] { - *params = new cuvsCagraSearchParams{.itopk_size = 64, - .search_width = 1, - .hashmap_max_fill_rate = 0.5, - .num_random_samplings = 1, - .rand_xor_mask = 0x128394}; + *params = new cuvsCagraSearchParams{ + .itopk_size = 64, + .search_width = 1, + .hashmap_max_fill_rate = 0.5, + .num_random_samplings = 1, + .rand_xor_mask = 0x128394, + .persistent = false, + .persistent_lifetime = 2, + .persistent_device_usage = 1.0, + }; }); } 
diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd index 41d74dbc7..f6435a5b2 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd @@ -84,6 +84,9 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: float hashmap_max_fill_rate uint32_t num_random_samplings uint64_t rand_xor_mask + bool persistent + float persistent_lifetime + float persistent_device_usage ctypedef struct cuvsCagraIndex: uintptr_t addr diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx index 56a7c061b..c31df724f 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx @@ -375,6 +375,13 @@ cdef class SearchParams: more. rand_xor_mask: int, default = 0x128394 Bit mask used for initial random seed node selection. + persistent: bool, default = False + Whether to use the persistent version of the kernel. + persistent_lifetime: float, optional + Persistent kernel: time in seconds before the kernel stops if no + requests are received. + persistent_device_usage: float, optional + Sets the fraction of the maximum grid size used by the persistent kernel. 
""" cdef cuvsCagraSearchParams params @@ -392,7 +399,11 @@ cdef class SearchParams: hashmap_min_bitlen=0, hashmap_max_fill_rate=0.5, num_random_samplings=1, - rand_xor_mask=0x128394): + rand_xor_mask=0x128394, + persistent=False, + persistent_lifetime=None, + persistent_device_usage=None + ): self.params.max_queries = max_queries self.params.itopk_size = itopk_size self.params.max_iterations = max_iterations @@ -424,6 +435,11 @@ cdef class SearchParams: self.params.hashmap_max_fill_rate = hashmap_max_fill_rate self.params.num_random_samplings = num_random_samplings self.params.rand_xor_mask = rand_xor_mask + self.params.persistent = persistent + if persistent_lifetime is not None: + self.params.persistent_lifetime = persistent_lifetime + if persistent_device_usage is not None: + self.params.persistent_device_usage = persistent_device_usage def __repr__(self): attr_str = [attr + "=" + str(getattr(self, attr)) From d1a28038b0d9572f4f922a0184dece89b0b5600a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 27 Feb 2025 14:19:55 -0800 Subject: [PATCH 5/7] add missing file --- cpp/src/cluster/kmeans_cluster_cost.cu | 37 ++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 cpp/src/cluster/kmeans_cluster_cost.cu diff --git a/cpp/src/cluster/kmeans_cluster_cost.cu b/cpp/src/cluster/kmeans_cluster_cost.cu new file mode 100644 index 000000000..6fe7e4d44 --- /dev/null +++ b/cpp/src/cluster/kmeans_cluster_cost.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kmeans.cuh" +// #include +#include + +namespace cuvs::cluster::kmeans { +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost) +{ + cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost); +} + +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost) +{ + cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost); +} +} // namespace cuvs::cluster::kmeans From 8be33733f4f097cdb7dcc1cba8fb06d3cad5fa5f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 27 Feb 2025 14:21:22 -0800 Subject: [PATCH 6/7] . --- cpp/src/cluster/kmeans_cluster_cost.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/cluster/kmeans_cluster_cost.cu b/cpp/src/cluster/kmeans_cluster_cost.cu index 6fe7e4d44..d55569255 100644 --- a/cpp/src/cluster/kmeans_cluster_cost.cu +++ b/cpp/src/cluster/kmeans_cluster_cost.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -15,7 +15,7 @@ */ #include "kmeans.cuh" -// #include +#include #include namespace cuvs::cluster::kmeans { From 13b21d355640fb4c44147d7655c9fc41e4a7dd1f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 27 Feb 2025 15:47:23 -0800 Subject: [PATCH 7/7] fix c-api for cagra persistent --- cpp/include/cuvs/neighbors/cagra.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.h b/cpp/include/cuvs/neighbors/cagra.h index 7d04fab2b..64ac15c97 100644 --- a/cpp/include/cuvs/neighbors/cagra.h +++ b/cpp/include/cuvs/neighbors/cagra.h @@ -258,9 +258,9 @@ struct cuvsCagraSearchParams { uint64_t rand_xor_mask; /** Whether to use the persistent version of the kernel (only SINGLE_CTA is supported a.t.m.) */ - bool persistent = false; + bool persistent; /** Persistent kernel: time in seconds before the kernel stops if no requests received. */ - float persistent_lifetime = 2; + float persistent_lifetime; /** * Set the fraction of maximum grid size used by persistent kernel. * Value 1.0 means the kernel grid size is maximum possible for the selected device.