
Merge pull request #156 from TileDB-Inc/npapa/fix-ingestion
Fix corner cases for ingestion
NikolaosPapailiou authored Nov 29, 2023
2 parents bcfdaa1 + a331d3b commit 2ad1ae4
Showing 2 changed files with 102 additions and 22 deletions.
79 changes: 57 additions & 22 deletions apis/python/src/tiledb/vector_search/ingestion.py
@@ -29,6 +29,7 @@ def ingest(
training_sample_size: int = -1,
workers: int = -1,
input_vectors_per_work_item: int = -1,
+ max_tasks_per_stage: int = -1,
storage_version: str = STORAGE_VERSION,
verbose: bool = False,
trace_id: Optional[str] = None,
@@ -83,6 +84,9 @@ def ingest(
input_vectors_per_work_item: int = -1
number of vectors per ingestion work item,
if not provided, is auto-configured
+ max_tasks_per_stage: int = -1
+     Max number of tasks per execution stage of ingestion,
+     if not provided, is auto-configured
storage_version: str
Vector index storage format version.
verbose: bool
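For context, a minimal usage sketch of the new max_tasks_per_stage knob, mirroring the test added below (the URIs are placeholders, not from this commit):

# Illustrative only: cap every ingestion stage at 4 tasks while keeping
# small work items, e.g. for a resource-constrained task scheduler.
from tiledb.vector_search.ingestion import ingest

index = ingest(
    index_type="IVF_FLAT",
    index_uri="/tmp/my_index",        # placeholder
    source_uri="my_vectors.fvecs",    # placeholder
    input_vectors_per_work_item=421,
    max_tasks_per_stage=4,
)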
@@ -348,7 +352,7 @@ def create_arrays(
index_type: str,
size: int,
dimensions: int,
- input_vectors_work_tasks: int,
+ input_vectors_work_items: int,
vector_type: np.dtype,
logger: logging.Logger,
) -> None:
@@ -475,10 +479,7 @@ def create_arrays(
partial_write_array_group.add(
partial_write_array_parts_uri, name=PARTS_ARRAY_NAME
)
- partial_write_arrays = input_vectors_work_tasks
- if updates_uri is not None:
-     partial_write_arrays += 1
- for part in range(partial_write_arrays):
+ for part in range(input_vectors_work_items):
part_index_uri = partial_write_array_index_uri + "/" + str(part)
if not tiledb.array_exists(part_index_uri):
logger.debug(f"Creating part array {part_index_uri}")
@@ -505,6 +506,33 @@
logger.debug(index_schema)
tiledb.Array.create(part_index_uri, index_schema)
partial_write_array_index_group.add(part_index_uri, name=str(part))
+ if updates_uri is not None:
+     part_index_uri = partial_write_array_index_uri + "/additions"
+     if not tiledb.array_exists(part_index_uri):
+         logger.debug(f"Creating part array {part_index_uri}")
+         index_array_rows_dim = tiledb.Dim(
+             name="rows",
+             domain=(0, partitions),
+             tile=partitions,
+             dtype=np.dtype(np.int32),
+         )
+         index_array_dom = tiledb.Domain(index_array_rows_dim)
+         index_attr = tiledb.Attr(
+             name="values",
+             dtype=np.dtype(np.uint64),
+             filters=DEFAULT_ATTR_FILTERS,
+         )
+         index_schema = tiledb.ArraySchema(
+             domain=index_array_dom,
+             sparse=False,
+             attrs=[index_attr],
+             capacity=partitions,
+             cell_order="col-major",
+             tile_order="col-major",
+         )
+         logger.debug(index_schema)
+         tiledb.Array.create(part_index_uri, index_schema)
+         partial_write_array_index_group.add(part_index_uri, name="additions")
partial_write_array_group.close()
partial_write_array_index_group.close()
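As a side note (not part of the diff): the dense "additions" index array created above can be read back like any TileDB array. A minimal sketch, assuming part_index_uri points at it and partitions matches the schema:

import tiledb

# Sketch only: the rows domain is (0, partitions), i.e. partitions + 1 cells,
# each holding a uint64 value.
with tiledb.open(part_index_uri, "r") as index_array:
    values = index_array[0 : partitions + 1]["values"]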

@@ -1098,7 +1126,7 @@ def ingest_vectors_udf(
part_name = str(part) + "-" + str(part_end)

partial_write_array_index_uri = partial_write_array_index_group[
-     str(int(start / batch))
+     str(int(part / batch))
].uri
logger.debug("Input vectors start_pos: %d, end_pos: %d", part, part_end)
updated_ids = read_updated_ids(
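Why the indexing fix above matters (an illustration, not code from the commit; the loop shape is assumed from the surrounding debug output): when a capped stage makes one task own several work items, each iteration must resolve to its own partial-write array.

# Assumed simplified shape of the loop in ingest_vectors_udf:
batch = 421                # input_vectors_per_work_item
start, end = 2526, 5052    # this task owns work items 6..11
for part in range(start, end, batch):
    # str(int(start / batch)) is always "6": every iteration would hit the
    # same part array, clobbering earlier work items of the same task.
    # str(int(part / batch)) yields "6", "7", ..., "11" as intended.
    part_index = int(part / batch)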
@@ -1203,7 +1231,6 @@ def ingest_additions_udf(
updates_uri: str,
vector_type: np.dtype,
write_offset: int,
- task_id: int,
threads: int,
config: Optional[Mapping[str, Any]] = None,
verbose: bool = False,
@@ -1228,7 +1255,7 @@
partial_write_array_index_dir_uri
)
partial_write_array_index_uri = partial_write_array_index_group[
-     str(task_id)
+     "additions"
].uri
additions_vectors, additions_external_ids = read_additions(
updates_uri=updates_uri,
@@ -1678,7 +1705,6 @@ def create_ingestion_dag(
updates_uri=updates_uri,
vector_type=vector_type,
write_offset=size,
- task_id=task_id,
threads=threads,
config=config,
verbose=verbose,
@@ -1744,15 +1770,22 @@ def consolidate_and_vacuum(
tiledb.vacuum(parts_uri, config=conf)
tiledb.consolidate(ids_uri, config=conf)
tiledb.vacuum(ids_uri, config=conf)
group.close()

- # TODO remove temp data for tiledb URIs
- if not index_group_uri.startswith("tiledb://"):
-     vfs = tiledb.VFS(config)
-     partial_write_array_dir_uri = (
-         index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
-     )
-     if vfs.is_dir(partial_write_array_dir_uri):
-         vfs.remove_dir(partial_write_array_dir_uri)
+ group = tiledb.Group(index_group_uri, "r")
+ if PARTIAL_WRITE_ARRAY_DIR in group:
+     group.close()
+     group = tiledb.Group(index_group_uri, "w")
+     group.remove(PARTIAL_WRITE_ARRAY_DIR)
+     vfs = tiledb.VFS(config)
+     partial_write_array_dir_uri = (
+         index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
+     )
+     if vfs.is_dir(partial_write_array_dir_uri):
+         vfs.remove_dir(partial_write_array_dir_uri)
+ group.close()

# --------------------------------------------------------------------
# End internal function definitions
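The new cleanup first detaches the temp member from the group, then deletes its directory. A condensed standalone sketch of that pattern (hypothetical helper name; semantics follow the diff above):

import tiledb

def remove_group_member_and_data(group_uri: str, member_name: str, config=None):
    # Sketch only: drop the member from the group metadata first, then
    # delete the underlying directory so no orphan data remains.
    group = tiledb.Group(group_uri, "r")
    has_member = member_name in group
    group.close()
    if has_member:
        group = tiledb.Group(group_uri, "w")
        group.remove(member_name)
        group.close()
        vfs = tiledb.VFS(config)
        member_uri = group_uri + "/" + member_name
        if vfs.is_dir(member_uri):
            vfs.remove_dir(member_uri)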
Expand Down Expand Up @@ -1852,11 +1885,13 @@ def consolidate_and_vacuum(
input_vectors_work_items = int(math.ceil(size / input_vectors_per_work_item))
input_vectors_work_tasks = input_vectors_work_items
input_vectors_work_items_per_worker = 1
- if input_vectors_work_tasks > MAX_TASKS_PER_STAGE:
+ if max_tasks_per_stage == -1:
+     max_tasks_per_stage = MAX_TASKS_PER_STAGE
+ if input_vectors_work_tasks > max_tasks_per_stage:
input_vectors_work_items_per_worker = int(
-     math.ceil(input_vectors_work_items / MAX_TASKS_PER_STAGE)
+     math.ceil(input_vectors_work_items / max_tasks_per_stage)
)
-     input_vectors_work_tasks = MAX_TASKS_PER_STAGE
+     input_vectors_work_tasks = max_tasks_per_stage
logger.debug("input_vectors_per_work_item %d", input_vectors_per_work_item)
logger.debug("input_vectors_work_items %d", input_vectors_work_items)
logger.debug("input_vectors_work_tasks %d", input_vectors_work_tasks)
@@ -1875,11 +1910,11 @@
)
table_partitions_work_tasks = table_partitions_work_items
table_partitions_work_items_per_worker = 1
- if table_partitions_work_tasks > MAX_TASKS_PER_STAGE:
+ if table_partitions_work_tasks > max_tasks_per_stage:
table_partitions_work_items_per_worker = int(
-     math.ceil(table_partitions_work_items / MAX_TASKS_PER_STAGE)
+     math.ceil(table_partitions_work_items / max_tasks_per_stage)
)
-     table_partitions_work_tasks = MAX_TASKS_PER_STAGE
+     table_partitions_work_tasks = max_tasks_per_stage
logger.debug(
"table_partitions_per_work_item %d", table_partitions_per_work_item
)
@@ -1897,7 +1932,7 @@
index_type=index_type,
size=size,
dimensions=dimensions,
- input_vectors_work_tasks=input_vectors_work_tasks,
+ input_vectors_work_items=input_vectors_work_items,
vector_type=vector_type,
logger=logger,
)
45 changes: 45 additions & 0 deletions apis/python/test/test_ingestion.py
@@ -262,6 +262,51 @@ def test_ivf_flat_ingestion_numpy(tmp_path):
assert accuracy(result, gt_i) > MINIMUM_ACCURACY


+ def test_ivf_flat_ingestion_multiple_workers(tmp_path):
+     source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
+     queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
+     gt_uri = "test/data/siftsmall/siftsmall_groundtruth.ivecs"
+     index_uri = os.path.join(tmp_path, "array")
+     k = 100
+     partitions = 100
+     nqueries = 100
+     nprobe = 20
+
+     query_vectors = load_fvecs(queries_uri)
+     gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)
+
+     index = ingest(
+         index_type="IVF_FLAT",
+         index_uri=index_uri,
+         source_uri=source_uri,
+         partitions=partitions,
+         input_vectors_per_work_item=421,
+         max_tasks_per_stage=4,
+     )
+     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+     assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+     # Test single query vector handling
+     _, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
+     assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY
+
+     index_ram = IVFFlatIndex(uri=index_uri)
+     _, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
+     assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+     _, result = index_ram.query(
+         query_vectors,
+         k=k,
+         nprobe=nprobe,
+         use_nuv_implementation=True,
+     )
+     assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+     # NB: local mode currently does not return distances
+     _, result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
+     assert accuracy(result, gt_i) > MINIMUM_ACCURACY


def test_ivf_flat_ingestion_external_ids_numpy(tmp_path):
source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
