Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize asset creation and registration #540

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 78 additions & 21 deletions apis/python/src/tiledb/vector_search/flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
Stores all vectors in a 2D TileDB array performing exhaustive similarity
search between the query vectors and all the dataset vectors.
"""
from typing import Any, Mapping
from threading import Thread
from typing import Any, Mapping, Sequence

import numpy as np

Expand All @@ -16,7 +17,7 @@
from tiledb.vector_search.utils import MAX_FLOAT32
from tiledb.vector_search.utils import MAX_INT32
from tiledb.vector_search.utils import MAX_UINT64
from tiledb.vector_search.utils import add_to_group
from tiledb.vector_search.utils import create_array_and_add_to_group

TILE_SIZE_BYTES = 128000000 # 128MB
INDEX_TYPE = "FLAT"
Expand Down Expand Up @@ -151,9 +152,11 @@ def create(
dimensions: int,
vector_type: np.dtype,
group_exists: bool = False,
group: tiledb.Group = None,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
asset_creation_threads: Sequence[Thread] = None,
**kwargs,
) -> FlatIndex:
"""
Expand All @@ -179,21 +182,42 @@ def create(
distance_metric: vspy.DistanceMetric
Distance metric to use for the index.
If not provided, use L2 distance.
group: tiledb.Group
TileDB group open in write mode.
Internal, this is used to avoid opening the group multiple times during
ingestion.
asset_creation_threads: Sequence[Thread]
List to which the asset creation threads started by this method are appended.
Internal, this is used to parallelize all asset creation during
ingestion.
"""
validate_storage_version(storage_version)

index.create_metadata(
uri=uri,
dimensions=dimensions,
vector_type=vector_type,
index_type=INDEX_TYPE,
storage_version=storage_version,
distance_metric=distance_metric,
group_exists=group_exists,
config=config,
)
with tiledb.scope_ctx(ctx_or_config=config):
group = tiledb.Group(uri, "w")
if not group_exists:
try:
tiledb.group_create(uri)
except tiledb.TileDBError as err:
raise err
if group is None:
grp = tiledb.Group(uri, "w")
else:
grp = group

if asset_creation_threads is not None:
threads = asset_creation_threads
else:
threads = []

index.create_metadata(
group=grp,
dimensions=dimensions,
vector_type=vector_type,
index_type=INDEX_TYPE,
storage_version=storage_version,
distance_metric=distance_metric,
)

tile_size = TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
ids_array_name = storage_formats[storage_version]["IDS_ARRAY_NAME"]
parts_array_name = storage_formats[storage_version]["PARTS_ARRAY_NAME"]
Expand Down Expand Up @@ -221,8 +245,17 @@ def create(
cell_order="col-major",
tile_order="col-major",
)
tiledb.Array.create(ids_uri, ids_schema)
add_to_group(group, ids_uri, ids_array_name)
thread = Thread(
target=create_array_and_add_to_group,
kwargs={
"array_uri": ids_uri,
"array_name": ids_array_name,
"group": grp,
"schema": ids_schema,
},
)
thread.start()
threads.append(thread)

parts_array_rows_dim = tiledb.Dim(
name="rows",
Expand All @@ -249,8 +282,17 @@ def create(
cell_order="col-major",
tile_order="col-major",
)
tiledb.Array.create(parts_uri, parts_schema)
add_to_group(group, parts_uri, parts_array_name)
thread = Thread(
target=create_array_and_add_to_group,
kwargs={
"array_uri": parts_uri,
"array_name": parts_array_name,
"group": grp,
"schema": parts_schema,
},
)
thread.start()
threads.append(thread)

external_id_dim = tiledb.Dim(
name="external_id",
Expand All @@ -265,8 +307,23 @@ def create(
attrs=[vector_attr],
allows_duplicates=False,
)
tiledb.Array.create(updates_array_uri, updates_schema)
add_to_group(group, updates_array_uri, updates_array_name)
thread = Thread(
target=create_array_and_add_to_group,
kwargs={
"array_uri": updates_array_uri,
"array_name": updates_array_name,
"group": grp,
"schema": updates_schema,
},
)
thread.start()
threads.append(thread)

group.close()
return FlatIndex(uri=uri, config=config)
if asset_creation_threads is None:
for thread in threads:
thread.join()
if group is None:
grp.close()
return FlatIndex(uri=uri, config=config)
else:
return None
29 changes: 10 additions & 19 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,34 +853,25 @@ def _open_updates_array(self, timestamp: int = None):


def create_metadata(
uri: str,
group: tiledb.Group,
dimensions: int,
vector_type: np.dtype,
index_type: str,
storage_version: str,
distance_metric: vspy.DistanceMetric,
group_exists: bool = False,
config: Optional[Mapping[str, Any]] = None,
):
"""
Writes the index metadata to an index group that is already open in write mode.
"""
with tiledb.scope_ctx(ctx_or_config=config):
if not group_exists:
try:
tiledb.group_create(uri)
except tiledb.TileDBError as err:
raise err
group = tiledb.Group(uri, "w")
group.meta["dataset_type"] = DATASET_TYPE
group.meta["dtype"] = np.dtype(vector_type).name
group.meta["storage_version"] = storage_version
group.meta["index_type"] = index_type
group.meta["base_sizes"] = json.dumps([0])
group.meta["ingestion_timestamps"] = json.dumps([0])
group.meta["has_updates"] = False
group.meta["distance_metric"] = int(distance_metric)
group.close()
group.meta["dataset_type"] = DATASET_TYPE
group.meta["dtype"] = np.dtype(vector_type).name
group.meta["dimensions"] = dimensions
group.meta["storage_version"] = storage_version
group.meta["index_type"] = index_type
group.meta["base_sizes"] = json.dumps([0])
group.meta["ingestion_timestamps"] = json.dumps([0])
group.meta["has_updates"] = False
group.meta["distance_metric"] = int(distance_metric)


"""
Expand Down
65 changes: 49 additions & 16 deletions apis/python/src/tiledb/vector_search/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import enum
from functools import partial
from typing import Any, Mapping, Optional, Tuple
from threading import Thread
from typing import Any, Mapping, Optional, Sequence, Tuple

import numpy as np

Expand All @@ -24,6 +25,7 @@
from tiledb.vector_search.storage_formats import STORAGE_VERSION
from tiledb.vector_search.storage_formats import validate_storage_version
from tiledb.vector_search.utils import add_to_group
from tiledb.vector_search.utils import create_array_and_add_to_group
from tiledb.vector_search.utils import is_type_erased_index
from tiledb.vector_search.utils import normalize_vectors
from tiledb.vector_search.utils import to_temporal_policy
Expand Down Expand Up @@ -611,6 +613,7 @@ def create_partial_write_array_group(
dimensions: int,
filters: Any,
create_index_array: bool,
asset_creation_threads: Sequence[Thread],
) -> str:
tile_size = int(
ivf_flat_index.TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
Expand Down Expand Up @@ -640,12 +643,17 @@ def create_partial_write_array_group(
cell_order="col-major",
tile_order="col-major",
)
tiledb.Array.create(partial_write_array_index_uri, index_schema)
add_to_group(
temp_data_group,
partial_write_array_index_uri,
INDEX_ARRAY_NAME,
thread = Thread(
target=create_array_and_add_to_group,
kwargs={
"array_uri": partial_write_array_index_uri,
"array_name": INDEX_ARRAY_NAME,
"group": temp_data_group,
"schema": index_schema,
},
)
thread.start()
asset_creation_threads.append(thread)

if not tiledb.array_exists(partial_write_array_ids_uri):
logger.debug("Creating temp ids array")
Expand All @@ -670,12 +678,17 @@ def create_partial_write_array_group(
tile_order="col-major",
)
logger.debug(ids_schema)
tiledb.Array.create(partial_write_array_ids_uri, ids_schema)
add_to_group(
temp_data_group,
partial_write_array_ids_uri,
IDS_ARRAY_NAME,
thread = Thread(
target=create_array_and_add_to_group,
kwargs={
"array_uri": partial_write_array_ids_uri,
"array_name": IDS_ARRAY_NAME,
"group": temp_data_group,
"schema": ids_schema,
},
)
thread.start()
asset_creation_threads.append(thread)

if not tiledb.array_exists(partial_write_array_parts_uri):
logger.debug("Creating temp parts array")
Expand All @@ -702,12 +715,17 @@ def create_partial_write_array_group(
)
logger.debug(parts_schema)
logger.debug(partial_write_array_parts_uri)
tiledb.Array.create(partial_write_array_parts_uri, parts_schema)
add_to_group(
temp_data_group,
partial_write_array_parts_uri,
PARTS_ARRAY_NAME,
thread = Thread(
target=create_array_and_add_to_group,
kwargs={
"array_uri": partial_write_array_parts_uri,
"array_name": PARTS_ARRAY_NAME,
"group": temp_data_group,
"schema": parts_schema,
},
)
thread.start()
asset_creation_threads.append(thread)
return partial_write_array_index_uri

def create_arrays(
Expand All @@ -720,6 +738,7 @@ def create_arrays(
vector_type: np.dtype,
logger: logging.Logger,
storage_version: str,
asset_creation_threads: Sequence[Thread],
) -> None:
if index_type == "FLAT":
if not arrays_created:
Expand All @@ -728,9 +747,11 @@ def create_arrays(
dimensions=dimensions,
vector_type=vector_type,
group_exists=True,
group=group,
config=config,
storage_version=storage_version,
distance_metric=distance_metric,
asset_creation_threads=asset_creation_threads,
)
elif index_type == "IVF_FLAT":
if not arrays_created:
Expand All @@ -739,16 +760,19 @@ def create_arrays(
dimensions=dimensions,
vector_type=vector_type,
group_exists=True,
group=group,
config=config,
storage_version=storage_version,
distance_metric=distance_metric,
asset_creation_threads=asset_creation_threads,
)
create_partial_write_array_group(
temp_data_group=temp_data_group,
vector_type=vector_type,
dimensions=dimensions,
filters=DEFAULT_ATTR_FILTERS,
create_index_array=True,
asset_creation_threads=asset_creation_threads,
)

# Note that we don't create type-erased indexes (i.e. Vamana) here. Instead we create them
Expand Down Expand Up @@ -1539,13 +1563,17 @@ def ingest_type_erased(

temp_data_group_uri = f"{index_group_uri}/{PARTIAL_WRITE_ARRAY_DIR}"
temp_data_group = tiledb.Group(temp_data_group_uri, "w")
asset_creation_threads = []
create_partial_write_array_group(
temp_data_group=temp_data_group,
vector_type=vector_type,
dimensions=dimensions,
filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
create_index_array=False,
asset_creation_threads=asset_creation_threads,
)
for thread in asset_creation_threads:
thread.join()
temp_data_group.close()
temp_data_group = tiledb.Group(temp_data_group_uri)
ids_array_uri = temp_data_group[IDS_ARRAY_NAME].uri
Expand Down Expand Up @@ -2942,6 +2970,7 @@ def consolidate_and_vacuum(

logger.debug("Creating arrays")
group = tiledb.Group(index_group_uri, "w")
asset_creation_threads = []
temp_data_group = create_temp_data_group(group=group)
create_arrays(
group=group,
Expand All @@ -2953,6 +2982,7 @@ def consolidate_and_vacuum(
vector_type=vector_type,
logger=logger,
storage_version=storage_version,
asset_creation_threads=asset_creation_threads,
)

if (
Expand Down Expand Up @@ -2996,6 +3026,9 @@ def consolidate_and_vacuum(
else:
if external_ids_type is None:
external_ids_type = "U64BIN"

for thread in asset_creation_threads:
thread.join()
temp_data_group.close()
group.meta["temp_size"] = size
group.close()
Expand Down
Loading
Loading