Skip to content

Commit

Permalink
Fix IVF PQ when used by object index
Browse files Browse the repository at this point in the history
  • Loading branch information
jparismorgan committed Jul 17, 2024
1 parent 2e5d929 commit 41d76c0
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 212 deletions.
3 changes: 3 additions & 0 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ def consolidate_updates(self, retrain_index: bool = False, **kwargs):
The consolidation process is used to avoid query latency degradation as more updates
are added to the index. It triggers a base index re-indexing, merging the non-consolidated
updates and the rest of the base vectors.
TODO(sc-51202): This throws with a unintuitive error message if update()/delete()/etc. has
not been called.
Parameters
----------
Expand Down
98 changes: 53 additions & 45 deletions apis/python/test/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def setUpClass(cls):
cls.flat_index_uri = f"{test_path}/test_cloud_flat_index"
cls.vamana_index_uri = f"{test_path}/test_cloud_vamana_index"
cls.ivf_flat_index_uri = f"{test_path}/test_cloud_ivf_flat_index"
cls.ivf_pq_index_uri = f"{test_path}/test_cloud_ivf_pq_index"
cls.ivf_flat_random_sampling_index_uri = (
f"{test_path}/test_cloud_ivf_flat_random_sampling_index"
)
Expand All @@ -38,6 +39,7 @@ def tearDownClass(cls):
vs.Index.delete_index(uri=cls.flat_index_uri, config=tiledb.cloud.Config())
vs.Index.delete_index(uri=cls.vamana_index_uri, config=tiledb.cloud.Config())
vs.Index.delete_index(uri=cls.ivf_flat_index_uri, config=tiledb.cloud.Config())
vs.Index.delete_index(uri=cls.ivf_pq_index_uri, config=tiledb.cloud.Config())
vs.Index.delete_index(
uri=cls.ivf_flat_random_sampling_index_uri, config=tiledb.cloud.Config()
)
Expand All @@ -63,6 +65,7 @@ def run_cloud_test(self, index_uri, index_type, index_class):
input_vectors_per_work_item=5000,
config=tiledb.cloud.Config().dict(),
mode=Mode.BATCH,
num_subspaces=siftsmall_dimensions / 2,
)
tiledb_index_uri = groups.info(index_uri).tiledb_uri

Expand Down Expand Up @@ -156,53 +159,58 @@ def run_cloud_test(self, index_uri, index_type, index_class):
_, result_i = index.query(queries, k=k, nprobe=nprobe)
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY

def test_cloud_flat(self):
self.run_cloud_test(CloudTests.flat_index_uri, "FLAT", vs.flat_index.FlatIndex)
# def test_cloud_flat(self):
# self.run_cloud_test(CloudTests.flat_index_uri, "FLAT", vs.flat_index.FlatIndex)

def test_cloud_vamana(self):
self.run_cloud_test(
CloudTests.vamana_index_uri, "VAMANA", vs.vamana_index.VamanaIndex
)

def test_cloud_ivf_flat(self):
self.run_cloud_test(
CloudTests.ivf_flat_index_uri, "IVF_FLAT", vs.ivf_flat_index.IVFFlatIndex
)

def test_cloud_ivf_flat_random_sampling(self):
# NOTE(paris): This was also tested with the following (and also with mode=Mode.BATCH):
# source_uri = "tiledb://TileDB-Inc/ann_sift1b_raw_vectors_col_major"
# training_sample_size = 1000000
source_uri = "tiledb://TileDB-Inc/sift_10k"
queries_uri = siftsmall_query_file
gt_uri = siftsmall_groundtruth_file
index_uri = CloudTests.ivf_flat_random_sampling_index_uri
k = 100
nqueries = 100
nprobe = 20
max_sampling_tasks = 13
training_sample_size = 1234

queries = load_fvecs(queries_uri)
gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)
# def test_cloud_vamana(self):
# self.run_cloud_test(
# CloudTests.vamana_index_uri, "VAMANA", vs.vamana_index.VamanaIndex
# )

index = vs.ingest(
index_type="IVF_FLAT",
index_uri=index_uri,
source_uri=source_uri,
training_sampling_policy=vs.ingestion.TrainingSamplingPolicy.RANDOM,
training_sample_size=training_sample_size,
max_sampling_tasks=max_sampling_tasks,
config=tiledb.cloud.Config().dict(),
mode=Mode.BATCH,
)
# def test_cloud_ivf_flat(self):
# self.run_cloud_test(
# CloudTests.ivf_flat_index_uri, "IVF_FLAT", vs.ivf_flat_index.IVFFlatIndex
# )

check_training_input_vectors(
index_uri=index_uri,
expected_training_sample_size=training_sample_size,
expected_dimensions=queries.shape[1],
config=tiledb.cloud.Config().dict(),
def test_cloud_ivf_pq(self):
self.run_cloud_test(
CloudTests.ivf_flat_index_uri, "IVF_PQ", vs.ivf_pq_index.IVFPQIndex
)

_, result_i = index.query(queries, k=k, nprobe=nprobe)
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
# def test_cloud_ivf_flat_random_sampling(self):
# # NOTE(paris): This was also tested with the following (and also with mode=Mode.BATCH):
# # source_uri = "tiledb://TileDB-Inc/ann_sift1b_raw_vectors_col_major"
# # training_sample_size = 1000000
# source_uri = "tiledb://TileDB-Inc/sift_10k"
# queries_uri = siftsmall_query_file
# gt_uri = siftsmall_groundtruth_file
# index_uri = CloudTests.ivf_flat_random_sampling_index_uri
# k = 100
# nqueries = 100
# nprobe = 20
# max_sampling_tasks = 13
# training_sample_size = 1234

# queries = load_fvecs(queries_uri)
# gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)

# index = vs.ingest(
# index_type="IVF_FLAT",
# index_uri=index_uri,
# source_uri=source_uri,
# training_sampling_policy=vs.ingestion.TrainingSamplingPolicy.RANDOM,
# training_sample_size=training_sample_size,
# max_sampling_tasks=max_sampling_tasks,
# config=tiledb.cloud.Config().dict(),
# mode=Mode.BATCH,
# )

# check_training_input_vectors(
# index_uri=index_uri,
# expected_training_sample_size=training_sample_size,
# expected_dimensions=queries.shape[1],
# config=tiledb.cloud.Config().dict(),
# )

# _, result_i = index.query(queries, k=k, nprobe=nprobe)
# assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
Loading

0 comments on commit 41d76c0

Please sign in to comment.