Fix corner cases for ingestion #156

Merged · 3 commits · Nov 29, 2023
79 changes: 57 additions & 22 deletions apis/python/src/tiledb/vector_search/ingestion.py
@@ -29,6 +29,7 @@ def ingest(
training_sample_size: int = -1,
workers: int = -1,
input_vectors_per_work_item: int = -1,
max_tasks_per_stage: int = -1,
storage_version: str = STORAGE_VERSION,
verbose: bool = False,
trace_id: Optional[str] = None,
@@ -83,6 +84,9 @@ def ingest(
input_vectors_per_work_item: int = -1
number of vectors per ingestion work item,
if not provided, is auto-configured
max_tasks_per_stage: int = -1
Max number of tasks per execution stage of ingestion,
if not provided, is auto-configured
storage_version: str
Vector index storage format version.
verbose: bool
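
Not part of the diff — a minimal usage sketch of the new max_tasks_per_stage parameter, mirroring the call made by the new test added at the bottom of this PR (the index_uri value here is a hypothetical local path):

    from tiledb.vector_search.ingestion import ingest

    # Cap every ingestion stage at 4 tasks while splitting the input into
    # ~421-vector work items (values taken from the new test in this PR).
    index = ingest(
        index_type="IVF_FLAT",
        index_uri="/tmp/ivf_flat_index",  # hypothetical output location
        source_uri="test/data/siftsmall/siftsmall_base.fvecs",
        partitions=100,
        input_vectors_per_work_item=421,
        max_tasks_per_stage=4,  # new in this PR; -1 keeps the auto-configured cap
    )
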
@@ -348,7 +352,7 @@ def create_arrays(
index_type: str,
size: int,
dimensions: int,
input_vectors_work_tasks: int,
input_vectors_work_items: int,
vector_type: np.dtype,
logger: logging.Logger,
) -> None:
@@ -475,10 +479,7 @@ def create_arrays(
partial_write_array_group.add(
partial_write_array_parts_uri, name=PARTS_ARRAY_NAME
)
partial_write_arrays = input_vectors_work_tasks
if updates_uri is not None:
partial_write_arrays += 1
for part in range(partial_write_arrays):
for part in range(input_vectors_work_items):
part_index_uri = partial_write_array_index_uri + "/" + str(part)
if not tiledb.array_exists(part_index_uri):
logger.debug(f"Creating part array {part_index_uri}")
@@ -505,6 +506,33 @@
logger.debug(index_schema)
tiledb.Array.create(part_index_uri, index_schema)
partial_write_array_index_group.add(part_index_uri, name=str(part))
if updates_uri is not None:
part_index_uri = partial_write_array_index_uri + "/additions"
if not tiledb.array_exists(part_index_uri):
logger.debug(f"Creating part array {part_index_uri}")
index_array_rows_dim = tiledb.Dim(
name="rows",
domain=(0, partitions),
tile=partitions,
dtype=np.dtype(np.int32),
)
index_array_dom = tiledb.Domain(index_array_rows_dim)
index_attr = tiledb.Attr(
name="values",
dtype=np.dtype(np.uint64),
filters=DEFAULT_ATTR_FILTERS,
)
index_schema = tiledb.ArraySchema(
domain=index_array_dom,
sparse=False,
attrs=[index_attr],
capacity=partitions,
cell_order="col-major",
tile_order="col-major",
)
logger.debug(index_schema)
tiledb.Array.create(part_index_uri, index_schema)
partial_write_array_index_group.add(part_index_uri, name="additions")
partial_write_array_group.close()
partial_write_array_index_group.close()
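
Not part of the diff — a small sketch of the corresponding lookup: the per-work-item index arrays are registered under their part number, while the updates index array is now registered under the fixed name "additions" and resolved by name (as ingest_additions_udf does further down), rather than by a task id:

    import tiledb

    def resolve_additions_index_uri(partial_write_array_index_dir_uri: str) -> str:
        # Open the group read-only and return the URI of the member named "additions".
        group = tiledb.Group(partial_write_array_index_dir_uri, "r")
        try:
            return group["additions"].uri
        finally:
            group.close()
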

@@ -1098,7 +1126,7 @@ def ingest_vectors_udf(
part_name = str(part) + "-" + str(part_end)

partial_write_array_index_uri = partial_write_array_index_group[
str(int(start / batch))
str(int(part / batch))
].uri
logger.debug("Input vectors start_pos: %d, end_pos: %d", part, part_end)
updated_ids = read_updated_ids(
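
Not part of the diff — a worked sketch of the corner case this hunk fixes, assuming start/end bound the slice handled by one task and batch equals input_vectors_per_work_item. When max_tasks_per_stage folds several work items into one task, indexing by start selects the same partial index array for every work item, while indexing by part selects one array per work item:

    # Assumed values: batch = 421, one task covering vectors [0, 2526), i.e. six work items.
    batch = 421
    start, end = 0, 2526
    for part in range(start, end, batch):
        old_index = int(start / batch)  # always 0: every work item would write to part array "0"
        new_index = int(part / batch)   # 0, 1, 2, 3, 4, 5: one part array per work item
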
@@ -1203,7 +1231,6 @@ def ingest_additions_udf(
updates_uri: str,
vector_type: np.dtype,
write_offset: int,
task_id: int,
threads: int,
config: Optional[Mapping[str, Any]] = None,
verbose: bool = False,
@@ -1228,7 +1255,7 @@
partial_write_array_index_dir_uri
)
partial_write_array_index_uri = partial_write_array_index_group[
str(task_id)
"additions"
].uri
additions_vectors, additions_external_ids = read_additions(
updates_uri=updates_uri,
@@ -1678,7 +1705,6 @@ def create_ingestion_dag(
updates_uri=updates_uri,
vector_type=vector_type,
write_offset=size,
task_id=task_id,
threads=threads,
config=config,
verbose=verbose,
@@ -1744,15 +1770,22 @@ def consolidate_and_vacuum(
tiledb.vacuum(parts_uri, config=conf)
tiledb.consolidate(ids_uri, config=conf)
tiledb.vacuum(ids_uri, config=conf)
group.close()

# TODO remove temp data for tiledb URIs
if not index_group_uri.startswith("tiledb://"):
vfs = tiledb.VFS(config)
partial_write_array_dir_uri = (
index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
)
if vfs.is_dir(partial_write_array_dir_uri):
vfs.remove_dir(partial_write_array_dir_uri)
group = tiledb.Group(index_group_uri, "r")
if PARTIAL_WRITE_ARRAY_DIR in group:
group.close()
group = tiledb.Group(index_group_uri, "w")
group.remove(PARTIAL_WRITE_ARRAY_DIR)
vfs = tiledb.VFS(config)
partial_write_array_dir_uri = (
index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
)
if vfs.is_dir(partial_write_array_dir_uri):
vfs.remove_dir(partial_write_array_dir_uri)
group.close()

# --------------------------------------------------------------------
# End internal function definitions
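
Not part of the diff — a short verification sketch for the new cleanup in consolidate_and_vacuum, assuming a local (non tiledb://) index_group_uri and the module-level PARTIAL_WRITE_ARRAY_DIR constant:

    import tiledb

    # After ingestion, the temporary partial-write member should be gone from the
    # group and its directory removed from disk.
    group = tiledb.Group(index_group_uri, "r")
    assert PARTIAL_WRITE_ARRAY_DIR not in group
    group.close()
    assert not tiledb.VFS(config).is_dir(index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR)
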
@@ -1852,11 +1885,13 @@ def consolidate_and_vacuum(
input_vectors_work_items = int(math.ceil(size / input_vectors_per_work_item))
input_vectors_work_tasks = input_vectors_work_items
input_vectors_work_items_per_worker = 1
if input_vectors_work_tasks > MAX_TASKS_PER_STAGE:
if max_tasks_per_stage == -1:
max_tasks_per_stage = MAX_TASKS_PER_STAGE
if input_vectors_work_tasks > max_tasks_per_stage:
input_vectors_work_items_per_worker = int(
math.ceil(input_vectors_work_items / MAX_TASKS_PER_STAGE)
math.ceil(input_vectors_work_items / max_tasks_per_stage)
)
input_vectors_work_tasks = MAX_TASKS_PER_STAGE
input_vectors_work_tasks = max_tasks_per_stage
logger.debug("input_vectors_per_work_item %d", input_vectors_per_work_item)
logger.debug("input_vectors_work_items %d", input_vectors_work_items)
logger.debug("input_vectors_work_tasks %d", input_vectors_work_tasks)
@@ -1875,11 +1910,11 @@
)
table_partitions_work_tasks = table_partitions_work_items
table_partitions_work_items_per_worker = 1
if table_partitions_work_tasks > MAX_TASKS_PER_STAGE:
if table_partitions_work_tasks > max_tasks_per_stage:
table_partitions_work_items_per_worker = int(
math.ceil(table_partitions_work_items / MAX_TASKS_PER_STAGE)
math.ceil(table_partitions_work_items / max_tasks_per_stage)
)
table_partitions_work_tasks = MAX_TASKS_PER_STAGE
table_partitions_work_tasks = max_tasks_per_stage
logger.debug(
"table_partitions_per_work_item %d", table_partitions_per_work_item
)
@@ -1897,7 +1932,7 @@
index_type=index_type,
size=size,
dimensions=dimensions,
input_vectors_work_tasks=input_vectors_work_tasks,
input_vectors_work_items=input_vectors_work_items,
vector_type=vector_type,
logger=logger,
)
45 changes: 45 additions & 0 deletions apis/python/test/test_ingestion.py
@@ -262,6 +262,51 @@ def test_ivf_flat_ingestion_numpy(tmp_path):
assert accuracy(result, gt_i) > MINIMUM_ACCURACY


def test_ivf_flat_ingestion_multiple_workers(tmp_path):
source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
gt_uri = "test/data/siftsmall/siftsmall_groundtruth.ivecs"
index_uri = os.path.join(tmp_path, "array")
k = 100
partitions = 100
nqueries = 100
nprobe = 20

query_vectors = load_fvecs(queries_uri)
gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)

index = ingest(
index_type="IVF_FLAT",
index_uri=index_uri,
source_uri=source_uri,
partitions=partitions,
input_vectors_per_work_item=421,
max_tasks_per_stage=4,
)
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

# Test single query vector handling
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY

index_ram = IVFFlatIndex(uri=index_uri)
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

_, result = index_ram.query(
query_vectors,
k=k,
nprobe=nprobe,
use_nuv_implementation=True,
)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

# NB: local mode currently does not return distances
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
assert accuracy(result, gt_i) > MINIMUM_ACCURACY


def test_ivf_flat_ingestion_external_ids_numpy(tmp_path):
source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"