C++ k-means implementation #130

Closed
wants to merge 44 commits into main from teo/kmeans
Commits
24f5bb6
Default the shuffled ID and index types of `kmeans_index` to `size_t`.
teo-tsirpanis Aug 8, 2023
8fe7850
Enable the k-means initialization tests.
teo-tsirpanis Aug 9, 2023
23bf690
Support specifying the seed when creating a `kmeans_index`.
teo-tsirpanis Aug 9, 2023
ce8745f
Avoid randomly choosing the same centroid many times.
teo-tsirpanis Aug 10, 2023
0674917
Apply some fixes to the superbuild CMake file from Core.
teo-tsirpanis Aug 23, 2023
960e036
Add default values for tolerance and number of threads in `kmeans_ind…
teo-tsirpanis Aug 23, 2023
c37f20c
Start writing the Python kmeans APIs in a separate file.
teo-tsirpanis Aug 24, 2023
87983b4
Set internal linkage to some utility functions.
teo-tsirpanis Aug 25, 2023
8e33ac4
Fix more duplicate symbol errors.
teo-tsirpanis Aug 25, 2023
43ef100
Add a kmeans predict function.
teo-tsirpanis Aug 28, 2023
115b8f2
Train the kmeans index in the Python wrapper.
teo-tsirpanis Aug 29, 2023
c773e2a
Use kmeans_fit in the ingestion code instead of sklearn.
teo-tsirpanis Aug 29, 2023
455ca20
Fix compile errors and a warning.
teo-tsirpanis Aug 29, 2023
fcf88f3
More refactorings and use `array_to_matrix`.
teo-tsirpanis Aug 29, 2023
f35f100
Fix errors in the ingestion.
teo-tsirpanis Aug 30, 2023
239a753
Improve a test and diagnostic output.
teo-tsirpanis Aug 30, 2023
66de269
Always use floats to train kmeans.
teo-tsirpanis Aug 30, 2023
fc5c0cf
Add more parameters to `kmeans_fit`.
teo-tsirpanis Aug 31, 2023
2879be9
Add a test that compares the results of sklearn's and our own kmeans …
teo-tsirpanis Sep 4, 2023
94643ce
Use kmeans_predict instead of sklearn. This removes the sklearn depen…
teo-tsirpanis Sep 4, 2023
a8f1679
Merge remote-tracking branch 'origin/main' into teo/kmeans
teo-tsirpanis Sep 4, 2023
bd9e702
Merge branch 'main' into teo/kmeans
teo-tsirpanis Sep 14, 2023
45f2852
Use common options across sklearn's and our kmeans implementations.
teo-tsirpanis Sep 14, 2023
b307de5
Rename `kmeans++` to `k-means++` to match sklearn.
teo-tsirpanis Sep 14, 2023
584d548
Assert that the score of the our kmeans implementation is smaller tha…
teo-tsirpanis Sep 14, 2023
a7da424
fix transposed args in kmeans.cc -- add unit test [skip ci]
lums658 Sep 14, 2023
8527303
Test both kmeans++ and random initialization.
teo-tsirpanis Sep 15, 2023
6575791
Fix formatting and delete commented code.
teo-tsirpanis Sep 15, 2023
34ddcb5
Make the kmeans test more deterministic.
teo-tsirpanis Sep 15, 2023
8769d04
Add back the asserts.
teo-tsirpanis Sep 15, 2023
ef38b0b
Add an opt-in switch to use sklearn's kmeans implementation.
teo-tsirpanis Sep 15, 2023
697c481
Parameterize min heap with comparison function [skip ci]
lums658 Sep 17, 2023
e5a797a
Debug zero cluster fix [skip ci]
lums658 Sep 17, 2023
d085f66
Uncomment debug statements [skip ci]
lums658 Sep 17, 2023
0f73373
Merge branch 'main' into teo/kmeans
teo-tsirpanis Oct 4, 2023
6b07f17
Initial partition-equalization
lums658 Oct 5, 2023
9867c90
Updates for kmeans and kmeans++
lums658 Oct 5, 2023
ff71e87
Small update
lums658 Oct 6, 2023
9176a54
clean up warnings, clang-format
lums658 Oct 6, 2023
e5e3690
Add documentation, update unit tests
lums658 Oct 6, 2023
3d3c712
Post merge from origin
lums658 Oct 6, 2023
cddc05a
Revert "Add documentation, update unit tests"
lums658 Oct 6, 2023
682ac00
Revert "Post merge from origin"
lums658 Oct 6, 2023
03694b2
Revert "Revert "Add documentation, update unit tests""
lums658 Oct 6, 2023
1 change: 0 additions & 1 deletion .github/workflows/build_wheels.yml
@@ -4,7 +4,6 @@ on:
   push:
     branches:
       - release-*
-      - '*wheel*' # must quote since "*" is a YAML reserved character; we want a string
   tags:
     - '*'
   pull_request:
2 changes: 1 addition & 1 deletion Dockerfile
@@ -8,7 +8,7 @@ RUN conda config --prepend channels conda-forge
 # Install mamba for faster installations
 RUN conda install mamba
 
-RUN mamba install -y -c tiledb 'tiledb>=2.17,<2.18' tiledb-py cmake pybind11 pytest c-compiler cxx-compiler ninja openblas-devel "pip>22"
+RUN mamba install -y -c tiledb 'tiledb>=2.16,<2.17' tiledb-py cmake pybind11 pytest c-compiler cxx-compiler ninja openblas-devel "pip>22"
 
 COPY . TileDB-Vector-Search/
5 changes: 4 additions & 1 deletion apis/python/CMakeLists.txt
@@ -49,7 +49,10 @@ find_package(pybind11 CONFIG REQUIRED)
 
 set(VSPY_TARGET_NAME _tiledbvspy)
 
-python_add_library(${VSPY_TARGET_NAME} MODULE "src/tiledb/vector_search/module.cc" WITH_SOABI)
+python_add_library(${VSPY_TARGET_NAME} MODULE
+  "src/tiledb/vector_search/module.cc"
+  "src/tiledb/vector_search/kmeans.cc"
+  WITH_SOABI)
 
 target_link_libraries(${VSPY_TARGET_NAME}
   PRIVATE
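
The build change above compiles a new kmeans.cc source into the _tiledbvspy extension module. Together with the commit messages ("Add a kmeans predict function", "Use kmeans_fit in the ingestion code instead of sklearn"), this indicates the module now exposes k-means training and prediction to Python. A hypothetical sketch of how ingestion might call it — the import path, function names, and parameter names below are assumptions, since the bindings themselves are not shown in this diff:

import numpy as np
# Assumed binding names; the diff only shows that kmeans.cc is built into
# the _tiledbvspy module, not what it exports.
from tiledb.vector_search.module import kmeans_fit, kmeans_predict

sample_vectors = np.random.rand(10_000, 128).astype(np.float32)  # commit 66de269: always train on floats

centroids = kmeans_fit(
    partitions=100,        # number of clusters (assumed name)
    init="k-means++",      # spelling matches sklearn per commit b307de5
    max_iter=10,           # assumed; commit fc5c0cf adds extra parameters
    sample_vectors=sample_vectors,
)
assignments = kmeans_predict(centroids, sample_vectors)  # replaces sklearn per commit 94643ce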
2 changes: 1 addition & 1 deletion apis/python/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tiledb-vector-search"
-version = "0.0.14"
+version = "0.0.10"
 #dynamic = ["version"]
 description = "TileDB Vector Search Python client"
 license = { text = "MIT" }
136 changes: 49 additions & 87 deletions apis/python/src/tiledb/vector_search/index.py
@@ -1,5 +1,3 @@
-import concurrent.futures as futures
-import os
 import numpy as np
 import sys
 
@@ -22,7 +20,6 @@ class Index:
     config: Optional[Mapping[str, Any]]
         config dictionary, defaults to None
     """
-
     def __init__(
         self,
         uri: str,
@@ -39,28 +36,16 @@ def __init__(
         self.storage_version = self.group.meta.get("storage_version", "0.1")
         self.update_arrays_uri = None
         self.index_version = self.group.meta.get("index_version", "")
-        self.thread_executor = futures.ThreadPoolExecutor()
 
     def query(self, queries: np.ndarray, k, **kwargs):
+        updated_ids = set(self.read_updated_ids())
+        retrieval_k = k
+        if len(updated_ids) > 0:
+            retrieval_k = 2*k
+        internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
         if self.update_arrays_uri is None:
-            return self.query_internal(queries, k, **kwargs)
-
-        # Query with updates
-        # Perform the queries in parallel
-        retrieval_k = 2 * k
-        kwargs["nthreads"] = int(os.cpu_count() / 2)
-        future = self.thread_executor.submit(
-            Index.query_additions,
-            queries,
-            k,
-            self.dtype,
-            self.update_arrays_uri,
-            int(os.cpu_count() / 2),
-        )
-        internal_results_d, internal_results_i = self.query_internal(
-            queries, retrieval_k, **kwargs
-        )
-        addition_results_d, addition_results_i, updated_ids = future.result()
+            return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
 
         # Filter updated vectors
         query_id = 0
@@ -70,137 +55,119 @@ def query(self, queries: np.ndarray, k, **kwargs):
                 if res in updated_ids:
                     internal_results_d[query_id, res_id] = MAX_FLOAT_32
                     internal_results_i[query_id, res_id] = MAX_UINT64
-                if (
-                    internal_results_d[query_id, res_id] == 0
-                    and internal_results_i[query_id, res_id] == 0
-                ):
-                    internal_results_d[query_id, res_id] = MAX_FLOAT_32
-                    internal_results_i[query_id, res_id] = MAX_UINT64
                 res_id += 1
             query_id += 1
         sort_index = np.argsort(internal_results_d, axis=1)
         internal_results_d = np.take_along_axis(internal_results_d, sort_index, axis=1)
         internal_results_i = np.take_along_axis(internal_results_i, sort_index, axis=1)
 
         # Merge update results
+        addition_results_d, addition_results_i = self.query_additions(queries, k)
         if addition_results_d is None:
             return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
 
         query_id = 0
         for query in addition_results_d:
             res_id = 0
             for res in query:
-                if (
-                    addition_results_d[query_id, res_id] == 0
-                    and addition_results_i[query_id, res_id] == 0
-                ):
+                if addition_results_d[query_id, res_id] == 0 and addition_results_i[query_id, res_id] == 0:
                     addition_results_d[query_id, res_id] = MAX_FLOAT_32
                     addition_results_i[query_id, res_id] = MAX_UINT64
                 res_id += 1
             query_id += 1
 
         results_d = np.hstack((internal_results_d, addition_results_d))
         results_i = np.hstack((internal_results_i, addition_results_i))
         sort_index = np.argsort(results_d, axis=1)
         results_d = np.take_along_axis(results_d, sort_index, axis=1)
         results_i = np.take_along_axis(results_i, sort_index, axis=1)
         return results_d[:, 0:k], results_i[:, 0:k]
 
-    @staticmethod
-    def query_additions(
-        queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
-    ):
+    def query_internal(self, queries: np.ndarray, k, **kwargs):
+        raise NotImplementedError
+
+    def query_additions(self, queries: np.ndarray, k):
         assert queries.dtype == np.float32
-        additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
-            update_arrays_uri
-        )
+        additions_vectors, additions_external_ids = self.read_additions()
        if additions_vectors is None:
-            return None, None, updated_ids
-
+            return None, None
         queries_m = array_to_matrix(np.transpose(queries))
         d, i = query_vq_heap_pyarray(
-            array_to_matrix(np.transpose(additions_vectors).astype(dtype)),
+            array_to_matrix(np.transpose(additions_vectors).astype(self.dtype)),
             queries_m,
             StdVector_u64(additions_external_ids),
             k,
-            nthreads,
-        )
-        return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids
-
-    @staticmethod
-    def read_additions(update_arrays_uri) -> (np.ndarray, np.array):
-        if update_arrays_uri is None:
-            return None, None, np.array([], np.uint64)
-        updates_array = tiledb.open(update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=("vector",), coords=True)
-        data = q[:]
-        updates_array.close()
-        updated_ids = data["external_id"]
-        additions_filter = [len(item) > 0 for item in data["vector"]]
-        if len(data["external_id"][additions_filter]) > 0:
-            return (
-                np.vstack(data["vector"][additions_filter]),
-                data["external_id"][additions_filter],
-                updated_ids
-            )
-        else:
-            return None, None, updated_ids
-
-    def query_internal(self, queries: np.ndarray, k, **kwargs):
-        raise NotImplementedError
+            8)
+        return np.transpose(np.array(d)), np.transpose(np.array(i))
 
     def update(self, vector: np.array, external_id: np.uint64):
         updates_array = self.open_updates_array()
-        vectors = np.empty((1), dtype="O")
+        vectors = np.empty((1), dtype='O')
         vectors[0] = vector
-        updates_array[external_id] = {"vector": vectors}
+        updates_array[external_id] = {'vector': vectors}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def update_batch(self, vectors: np.ndarray, external_ids: np.array):
         updates_array = self.open_updates_array()
-        updates_array[external_ids] = {"vector": vectors}
+        updates_array[external_ids] = {'vector': vectors}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def delete(self, external_id: np.uint64):
         updates_array = self.open_updates_array()
-        deletes = np.empty((1), dtype="O")
+        deletes = np.empty((1), dtype='O')
         deletes[0] = np.array([], dtype=self.dtype)
-        updates_array[external_id] = {"vector": deletes}
+        updates_array[external_id] = {'vector': deletes}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def delete_batch(self, external_ids: np.array):
         updates_array = self.open_updates_array()
-        deletes = np.empty((len(external_ids)), dtype="O")
+        deletes = np.empty((len(external_ids)), dtype='O')
         for i in range(len(external_ids)):
             deletes[i] = np.array([], dtype=self.dtype)
-        updates_array[external_ids] = {"vector": deletes}
+        updates_array[external_ids] = {'vector': deletes}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def consolidate_update_fragments(self):
         fragments_info = tiledb.array_fragments(self.update_arrays_uri)
-        if len(fragments_info) > 10:
+        if(len(fragments_info) > 10):
             tiledb.consolidate(self.update_arrays_uri)
             tiledb.vacuum(self.update_arrays_uri)
 
     def get_updates_uri(self):
         return self.update_arrays_uri
 
+    def read_additions(self) -> (np.ndarray, np.array):
+        if self.update_arrays_uri is None:
+            return None, None
+        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
+        q = updates_array.query(attrs=('vector',), coords=True)
+        data = q[:]
+        additions_filter = [len(item) > 0 for item in data["vector"]]
+        if len(data["external_id"][additions_filter]) > 0:
+            return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
+        else:
+            return None, None
+
+    def read_updated_ids(self) -> np.array:
+        if self.update_arrays_uri is None:
+            return np.array([], np.uint64)
+        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
+        q = updates_array.query(attrs=('vector',), coords=True)
+        data = q[:]
+        return data["external_id"]
+
     def open_updates_array(self):
         if self.update_arrays_uri is None:
-            updates_array_name = storage_formats[self.storage_version][
-                "UPDATES_ARRAY_NAME"
-            ]
+            updates_array_name = storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]
             updates_array_uri = f"{self.group.uri}/{updates_array_name}"
             if tiledb.array_exists(updates_array_uri):
                 raise RuntimeError(f"Array {updates_array_uri} already exists.")
             external_id_dim = tiledb.Dim(
-                name="external_id",
-                domain=(0, MAX_UINT64 - 1),
-                dtype=np.dtype(np.uint64),
+                name="external_id", domain=(0, MAX_UINT64-1), dtype=np.dtype(np.uint64)
             )
             dom = tiledb.Domain(external_id_dim)
             vector_attr = tiledb.Attr(name="vector", dtype=self.dtype, var=True)
@@ -221,18 +188,13 @@ def open_updates_array(self):
 
     def consolidate_updates(self):
         from tiledb.vector_search.ingestion import ingest
-
         new_index = ingest(
             index_type=self.index_type,
             index_uri=self.uri,
             size=self.size,
             source_uri=self.db_uri,
             external_ids_uri=self.ids_uri,
-            updates_uri=self.update_arrays_uri,
+            updates_uri=self.update_arrays_uri
         )
         tiledb.Array.delete_array(self.update_arrays_uri)
         self.group.close()
         self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
         self.group.remove(self.update_arrays_uri)
         self.group.close()
         return new_index
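
The query() method in the diff above over-fetches 2*k candidates when an update log exists, invalidates hits whose IDs appear in the log, and then merges the survivors with results computed over the log itself. A self-contained, vectorized sketch of that merge step (the diff does it with explicit loops; MAX_FLOAT_32 and MAX_UINT64 are the library's sentinel constants, reproduced here with assumed values):

import numpy as np

MAX_FLOAT_32 = np.finfo(np.float32).max
MAX_UINT64 = np.iinfo(np.uint64).max

def merge_results(internal_d, internal_i, addition_d, addition_i, k, updated_ids):
    # All four arrays are (num_queries, retrieval_k): distances and external IDs
    # from the main index and from the update log, respectively.
    # Push stale hits (vectors updated since indexing) past any real match.
    stale = np.isin(internal_i, np.fromiter(updated_ids, np.uint64))
    internal_d = np.where(stale, MAX_FLOAT_32, internal_d)
    internal_i = np.where(stale, MAX_UINT64, internal_i)
    # Concatenate the candidate sets, re-sort each row by distance, keep top k.
    d = np.hstack((internal_d, addition_d))
    i = np.hstack((internal_i, addition_i))
    order = np.argsort(d, axis=1)
    return (np.take_along_axis(d, order, axis=1)[:, :k],
            np.take_along_axis(i, order, axis=1)[:, :k])

The zero-padding checks in the diff (slots where both distance and ID are 0) handle entries the underlying kernels leave empty when fewer than retrieval_k matches exist; they are omitted here for brevity.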