-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
1,382 additions
and
1,329 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,57 @@ | ||
#include <tiledb/tiledb> | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "linalg.h" | ||
#include "flat_query.h" | ||
#include "ivf_index.h" | ||
#include "ivf_query.h" | ||
#include "flat_query.h" | ||
#include "linalg.h" | ||
|
||
namespace py = pybind11; | ||
using Ctx = tiledb::Context; | ||
|
||
namespace { | ||
|
||
template <typename T, typename shuffled_ids_type = uint64_t> | ||
static void declare_kmeans(py::module& m, const std::string& suffix) { | ||
template<typename T, typename shuffled_ids_type = uint64_t> | ||
static void declare_kmeans(py::module &m, const std::string &suffix) { | ||
m.def(("kmeans_fit_" + suffix).c_str(), | ||
[](size_t n_clusters, | ||
std::string init, | ||
size_t max_iter, | ||
bool verbose, | ||
size_t n_init, | ||
const ColMajorMatrix<T>& sample_vectors, | ||
const ColMajorMatrix<T> &sample_vectors, | ||
std::optional<double> tol, | ||
std::optional<unsigned int> seed, | ||
std::optional<size_t> nthreads) { | ||
// TODO: support verbose and n_init | ||
std::ignore = verbose; | ||
std::ignore = n_init; | ||
kmeans_init init_val; | ||
if (init == "k-means++") { | ||
init_val = kmeans_init::kmeanspp; | ||
} else if (init == "random") { | ||
init_val = kmeans_init::random; | ||
} else { | ||
throw std::invalid_argument("Invalid init method"); | ||
} | ||
kmeans_index<T> idx(sample_vectors.num_rows(), n_clusters, max_iter, tol.value_or(0.0001), nthreads, seed); | ||
idx.train(sample_vectors, init_val); | ||
return std::move(idx.get_centroids()); | ||
}); | ||
// TODO: support verbose and n_init | ||
std::ignore = verbose; | ||
std::ignore = n_init; | ||
kmeans_init init_val; | ||
if (init == "k-means++") { | ||
init_val = kmeans_init::kmeanspp; | ||
} else if (init == "random") { | ||
init_val = kmeans_init::random; | ||
} else { | ||
throw std::invalid_argument("Invalid init method"); | ||
} | ||
kmeans_index<T> idx(sample_vectors.num_rows(), n_clusters, max_iter, tol.value_or(0.0001), nthreads, seed); | ||
idx.train(sample_vectors, init_val); | ||
return std::move(idx.get_centroids()); | ||
}); | ||
|
||
m.def(("kmeans_predict_" + suffix).c_str(), | ||
[](const ColMajorMatrix<T>& centroids, | ||
const ColMajorMatrix<T>& sample_vectors) { | ||
return kmeans_index<T>::predict(centroids, sample_vectors); | ||
}); | ||
[](const ColMajorMatrix<T> &centroids, | ||
const ColMajorMatrix<T> &sample_vectors) { | ||
return kmeans_index<T>::predict(centroids, sample_vectors); | ||
}); | ||
} | ||
|
||
} // anonymous namespace | ||
}// anonymous namespace | ||
|
||
|
||
void init_kmeans(py::module_& m) { | ||
void init_kmeans(py::module_ &m) { | ||
declare_kmeans<float>(m, "f32"); | ||
} |
Oops, something went wrong.