Skip to content

Commit

Permalink
Enforce queries has two dimensions (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
jparismorgan authored Dec 14, 2023
1 parent b8c68d3 commit 15f83d8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 49 deletions.
4 changes: 2 additions & 2 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def __init__(
self.thread_executor = futures.ThreadPoolExecutor()

def query(self, queries: np.ndarray, k, **kwargs):
if queries.ndim != 1 and queries.ndim != 2:
raise TypeError(f"Expected queries to have either 1 or 2 dimensions (i.e. [...] or [[...], [...]]), but it had {queries.ndim} dimensions")
if queries.ndim != 2:
raise TypeError(f"Expected queries to have 2 dimensions (i.e. [[...], etc.]), but it had {queries.ndim} dimensions")

query_dimensions = queries.shape[0] if queries.ndim == 1 else queries.shape[1]
if query_dimensions != self.get_dimensions():
Expand Down
37 changes: 2 additions & 35 deletions apis/python/test/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,14 @@ def test_index_with_incorrect_dimensions(tmp_path):
# Wrong number of dimensions will raise a TypeError.
with pytest.raises(TypeError):
index.query(np.array(1, dtype=np.float32), k=3)
with pytest.raises(TypeError):
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
with pytest.raises(TypeError):
index.query(np.array([[[1, 1, 1]]], dtype=np.float32), k=3)
with pytest.raises(TypeError):
index.query(np.array([[[[1, 1, 1]]]], dtype=np.float32), k=3)

# Okay otherwise.
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
index.query(np.array([[1, 1, 1]], dtype=np.float32), k=3)

def test_index_with_incorrect_num_of_query_columns_simple(tmp_path):
Expand Down Expand Up @@ -156,37 +157,3 @@ def test_index_with_incorrect_num_of_query_columns_complex(tmp_path):
else:
with pytest.raises(TypeError):
index.query(query, k=1)

# TODO(paris): This will throw with the following error. Fix and re-enable, then remove
# test_index_with_incorrect_num_of_query_columns_in_single_vector_query:
# def array_to_matrix(array: np.ndarray):
# if array.dtype == np.float32:
# > return pyarray_copyto_matrix_f32(array)
# E RuntimeError: Number of dimensions must be two
# Here we test with a query which is just a vector, i.e. [1, 2, 3].
# query = query[0]
# if num_columns_for_query == num_columns:
# index.query(query, k=1)
# else:
# with pytest.raises(TypeError):
# index.query(query, k=1)

def test_index_with_incorrect_num_of_query_columns_in_single_vector_query(tmp_path):
# Tests that we raise a TypeError if the number of columns in the query is not the same as the
# number of columns in the indexed data, specifically for a single vector query.
# i.e. queries = [1, 2, 3] instead of queries = [[1, 2, 3], [4, 5, 6]].
indexes = [flat_index, ivf_flat_index]
for index_type in indexes:
uri = os.path.join(tmp_path, f"array_{index_type.__name__}")
index = index_type.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))

# Wrong number of columns will raise a TypeError.
with pytest.raises(TypeError):
index.query(np.array([1], dtype=np.float32), k=3)
with pytest.raises(TypeError):
index.query(np.array([1, 1], dtype=np.float32), k=3)
with pytest.raises(TypeError):
index.query(np.array([1, 1, 1, 1], dtype=np.float32), k=3)

# Okay otherwise.
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
12 changes: 0 additions & 12 deletions apis/python/test/test_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,6 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

# Test single query vector handling
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY

index_ram = IVFFlatIndex(uri=index_uri)
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
Expand Down Expand Up @@ -242,10 +238,6 @@ def test_ivf_flat_ingestion_numpy(tmp_path):
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

# Test single query vector handling
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY

index_ram = IVFFlatIndex(uri=index_uri)
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
Expand Down Expand Up @@ -286,10 +278,6 @@ def test_ivf_flat_ingestion_multiple_workers(tmp_path):
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

# Test single query vector handling
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY

index_ram = IVFFlatIndex(uri=index_uri)
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
Expand Down

0 comments on commit 15f83d8

Please sign in to comment.