Skip to content

Commit

Permalink
pr feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jparismorgan committed Dec 12, 2023
1 parent 80f42ca commit cf48096
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions apis/python/src/tiledb/vector_search/flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
+ self.index_version
].uri
schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
self.expected_query_columns = schema.shape[0]
self.expected_query_dimensions = schema.shape[0]
if self.base_size == -1:
self.size = schema.domain.dim(1).domain[1] + 1
else:
Expand Down Expand Up @@ -76,8 +76,8 @@ def __init__(
self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
)

def get_expected_query_columns(self):
return self.expected_query_columns
def get_expected_query_dimensions(self):
return self.expected_query_dimensions

def query_internal(
self,
Expand Down
8 changes: 4 additions & 4 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def query(self, queries: np.ndarray, k, **kwargs):
if queries.ndim != 1 and queries.ndim != 2:
raise TypeError(f"Expected queries to have either 1 or 2 dimensions (i.e. [...] or [[...], [...]]), but it had {queries.ndim} dimensions")

query_columns = queries.shape[0] if queries.ndim == 1 else queries.shape[1]
if query_columns != self.get_expected_query_columns():
raise TypeError(f"A query in queries has {query_columns} columns, but the indexed data had {self.expected_query_columns} columns")
query_dimensions = queries.shape[0] if queries.ndim == 1 else queries.shape[1]
if query_dimensions != self.get_expected_query_dimensions():
raise TypeError(f"A query in queries has {query_dimensions} dimensions, but the indexed data had {self.expected_query_dimensions} dimensions")

with tiledb.scope_ctx(ctx_or_config=self.config):
if not tiledb.array_exists(self.updates_array_uri):
Expand Down Expand Up @@ -260,7 +260,7 @@ def read_additions(
else:
return None, None, updated_ids

def get_expected_query_columns(self):
def get_expected_query_dimensions(self):
raise NotImplementedError

def query_internal(self, queries: np.ndarray, k, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions apis/python/src/tiledb/vector_search/ivf_flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self.memory_budget = memory_budget

schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
self.expected_query_columns = schema.shape[0]
self.expected_query_dimensions = schema.shape[0]

self.dtype = self.group.meta.get("dtype", None)
if self.dtype is None:
Expand Down Expand Up @@ -122,8 +122,8 @@ def __init__(
self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
)

def get_expected_query_columns(self):
return self.expected_query_columns
def get_expected_query_dimensions(self):
return self.expected_query_dimensions

def query_internal(
self,
Expand Down
10 changes: 5 additions & 5 deletions apis/python/test/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,19 @@ def test_index_with_incorrect_dimensions(tmp_path):
indexes = [flat_index, ivf_flat_index]
for index_type in indexes:
uri = os.path.join(tmp_path, f"array_{index_type.__name__}")
index = index_type.create(uri=uri, dimensions=1, vector_type=np.dtype(np.uint8))
index = index_type.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))

# Wrong number of dimensions will raise a TypeError.
with pytest.raises(TypeError):
index.query(np.array(1, dtype=np.float32), k=3)
with pytest.raises(TypeError):
index.query(np.array([[[1]]], dtype=np.float32), k=3)
index.query(np.array([[[1, 1, 1]]], dtype=np.float32), k=3)
with pytest.raises(TypeError):
index.query(np.array([[[[1]]]], dtype=np.float32), k=3)
index.query(np.array([[[[1, 1, 1]]]], dtype=np.float32), k=3)

# Okay otherwise.
index.query(np.array([1], dtype=np.float32), k=3)
index.query(np.array([[1]], dtype=np.float32), k=3)
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
index.query(np.array([[1, 1, 1]], dtype=np.float32), k=3)

def test_index_with_incorrect_num_of_query_columns_simple(tmp_path):
siftsmall_uri = "test/data/siftsmall/siftsmall_base.fvecs"
Expand Down

0 comments on commit cf48096

Please sign in to comment.