Skip to content

Commit

Permalink
Sync C++ kmeans branch with main and fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ihnorton committed Nov 14, 2023
1 parent e668ec5 commit 52fa45d
Show file tree
Hide file tree
Showing 28 changed files with 1,886 additions and 781 deletions.
9 changes: 4 additions & 5 deletions apis/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
[project]
name = "tiledb-vector-search"
version = "0.0.14"
#dynamic = ["version"]
dynamic = ["version"]
description = "TileDB Vector Search Python client"
license = { text = "MIT" }
readme = "README.md"
Expand All @@ -19,8 +18,8 @@ classifiers = [
]

dependencies = [
"tiledb-cloud>=0.10.5",
"tiledb>=0.15.2",
"tiledb-cloud>=0.11",
"tiledb>=0.23.1",
"typing-extensions", # for tiledb-cloud indirect, x-ref https://github.com/TileDB-Inc/TileDB-Cloud-Py/pull/428
"scikit-learn",
]
Expand All @@ -47,7 +46,7 @@ zip-safe = false

[tool.setuptools_scm]
root = "../.."
#write_to = "apis/python/src/tiledb/vector_search/version.py"
write_to = "apis/python/src/tiledb/vector_search/version.py"

[tool.ruff]
extend-select = ["I"]
Expand Down
4 changes: 2 additions & 2 deletions apis/python/requirements-py.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy==1.24.3
tiledb-cloud==0.10.5
tiledb==0.21.3
tiledb-cloud==0.10.24
tiledb==0.23.1
32 changes: 14 additions & 18 deletions apis/python/src/tiledb/vector_search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
# Re-import mode from cloud.dag
from tiledb.cloud.dag.mode import Mode

from . import utils
from .index import Index
from .ivf_flat_index import IVFFlatIndex
from .flat_index import FlatIndex
from .index import Index
from .ingestion import ingest
from .storage_formats import storage_formats, STORAGE_VERSION
from .module import load_as_array
from .module import load_as_matrix
from .module import (
query_vq_heap,
query_vq_nth,
ivf_query,
ivf_query_ram,
validate_top_k,
array_to_matrix,
ivf_index,
ivf_index_tdb,
partition_ivf_index,
)
from .ivf_flat_index import IVFFlatIndex
from .module import (array_to_matrix, ivf_index, ivf_index_tdb, ivf_query,
ivf_query_ram, load_as_array, load_as_matrix,
partition_ivf_index, query_vq_heap, query_vq_nth,
validate_top_k)
from .storage_formats import STORAGE_VERSION, storage_formats

# Re-import mode from cloud.dag
from tiledb.cloud.dag.mode import Mode
try:
from tiledb.vector_search.version import version as __version__
except ImportError:
__version__ = "0.0.0.local"

__all__ = [
"FlatIndex",
Expand Down
160 changes: 132 additions & 28 deletions apis/python/src/tiledb/vector_search/flat_index.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import json
from typing import Any, Mapping

import numpy as np

from tiledb.vector_search import index
from tiledb.vector_search.module import *
from tiledb.vector_search.storage_formats import storage_formats
from tiledb.vector_search.index import Index
from typing import Any, Mapping
from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
storage_formats)

MAX_INT32 = np.iinfo(np.dtype("int32")).max
TILE_SIZE_BYTES = 128000000 # 128MB
INDEX_TYPE = "FLAT"

class FlatIndex(Index):

class FlatIndex(index.Index):
"""
Open a flat index
Expand All @@ -22,36 +29,51 @@ def __init__(
self,
uri: str,
config: Optional[Mapping[str, Any]] = None,
timestamp=None,
**kwargs,
):
super().__init__(uri=uri, config=config)
self.index_type = "FLAT"
super().__init__(uri=uri, config=config, timestamp=timestamp)
self.index_type = INDEX_TYPE
self._index = None
self.db_uri = self.group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"] + self.index_version].uri
schema = tiledb.ArraySchema.load(
self.db_uri, ctx=tiledb.Ctx(self.config)
)
self.size = schema.domain.dim(1).domain[1]+1
self._db = load_as_matrix(
self.db_uri,
ctx=self.ctx,
config=config,
)
self.db_uri = self.group[
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
+ self.index_version
].uri
schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
if self.base_size == -1:
self.size = schema.domain.dim(1).domain[1] + 1
else:
self.size = self.base_size

# Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
# that the external_ids were the position of the vector in the array.
if storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version in self.group:
self.dtype = np.dtype(self.group.meta.get("dtype", None))
if (
storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
in self.group
):
self.ids_uri = self.group[
storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
storage_formats[self.storage_version]["IDS_ARRAY_NAME"]
+ self.index_version
].uri
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
else:
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))

dtype = self.group.meta.get("dtype", None)
if dtype is None:
self.dtype = self._db.dtype
else:
self.dtype = np.dtype(dtype)
self.ids_uri = ""
if self.size > 0:
self._db = load_as_matrix(
self.db_uri,
ctx=self.ctx,
config=config,
size=self.size,
timestamp=self.base_array_timestamp,
)
if self.dtype is None:
self.dtype = self._db.dtype
# Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
# that the external_ids were the position of the vector in the array.
if self.ids_uri == "":
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
else:
self._ids = read_vector_u64(
self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
)

def query_internal(
self,
Expand All @@ -74,10 +96,92 @@ def query_internal(
# TODO:
# - typecheck queries
# - add all the options and query strategies
if self.size == 0:
return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
(queries.shape[0], k), index.MAX_UINT64
)

assert queries.dtype == np.float32

queries_m = array_to_matrix(np.transpose(queries))
d, i = query_vq_heap(self._db, queries_m, self._ids, k, nthreads)

return np.transpose(np.array(d)), np.transpose(np.array(i))


def create(
    uri: str,
    dimensions: int,
    vector_type: np.dtype,
    group_exists: bool = False,
    config: Optional[Mapping[str, Any]] = None,
    **kwargs,
) -> FlatIndex:
    """Create an empty FLAT index at ``uri`` and return a handle to it.

    Writes the index group metadata via ``index.create_metadata``, then
    creates the two dense col-major member arrays — the external-ids
    vector and the ``dimensions x capacity`` parts matrix — and registers
    both in the TileDB group.

    Parameters:
        uri: destination URI for the index group.
        dimensions: dimensionality of the stored vectors.
        vector_type: numpy dtype of the vector components.
        group_exists: pass True if the group was already created.
        config: optional TileDB config mapping.
    """
    index.create_metadata(
        uri=uri,
        dimensions=dimensions,
        vector_type=vector_type,
        index_type=INDEX_TYPE,
        group_exists=group_exists,
        config=config,
    )
    with tiledb.scope_ctx(ctx_or_config=config):
        grp = tiledb.Group(uri, "w")
        # Tile extent sized so one tile holds ~TILE_SIZE_BYTES of vector data.
        tile_extent = TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
        fmt = storage_formats[STORAGE_VERSION]
        ids_array_name = fmt["IDS_ARRAY_NAME"]
        parts_array_name = fmt["PARTS_ARRAY_NAME"]
        ids_uri = f"{uri}/{ids_array_name}"
        parts_uri = f"{uri}/{parts_array_name}"
        attr_filters = fmt["DEFAULT_ATTR_FILTERS"]

        # External-ids array: 1-D dense vector of uint64 ids.
        ids_schema = tiledb.ArraySchema(
            domain=tiledb.Domain(
                tiledb.Dim(
                    name="rows",
                    domain=(0, MAX_INT32),
                    tile=tile_extent,
                    dtype=np.dtype(np.int32),
                )
            ),
            sparse=False,
            attrs=[
                tiledb.Attr(
                    name="values",
                    dtype=np.dtype(np.uint64),
                    filters=attr_filters,
                )
            ],
            cell_order="col-major",
            tile_order="col-major",
        )
        tiledb.Array.create(ids_uri, ids_schema)
        grp.add(ids_uri, name=ids_array_name)

        # Parts array: dense matrix with one vector per column.
        parts_schema = tiledb.ArraySchema(
            domain=tiledb.Domain(
                tiledb.Dim(
                    name="rows",
                    domain=(0, dimensions - 1),
                    tile=dimensions,
                    dtype=np.dtype(np.int32),
                ),
                tiledb.Dim(
                    name="cols",
                    domain=(0, MAX_INT32),
                    tile=tile_extent,
                    dtype=np.dtype(np.int32),
                ),
            ),
            sparse=False,
            attrs=[
                tiledb.Attr(
                    name="values",
                    dtype=vector_type,
                    filters=attr_filters,
                )
            ],
            cell_order="col-major",
            tile_order="col-major",
        )
        tiledb.Array.create(parts_uri, parts_schema)
        grp.add(parts_uri, name=parts_array_name)

        grp.close()
        return FlatIndex(uri=uri, config=config)
Loading

0 comments on commit 52fa45d

Please sign in to comment.