Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable setting resource_class or resources when calling query() on IVFFlatIndex #165

Merged
merged 6 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions apis/python/src/tiledb/vector_search/ivf_flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def submit_local(d, func, *args, **kwargs):
    """Delegate to ``d.submit_local`` after discarding selected keywords.

    The ``image_name``, ``resource_class`` and ``resources`` keyword
    arguments are removed from ``kwargs`` when present; every other
    positional and keyword argument is passed through unchanged.

    Returns whatever ``d.submit_local(func, *args, **kwargs)`` returns.
    """
    # These keywords are stripped before forwarding to the local submit call.
    for dropped in ("image_name", "resource_class", "resources"):
        kwargs.pop(dropped, None)
    return d.submit_local(func, *args, **kwargs)

Expand Down Expand Up @@ -133,6 +134,8 @@ def query_internal(
nthreads: int = -1,
use_nuv_implementation: bool = False,
mode: Mode = None,
resource_class: Optional[str] = None,
resources: Optional[Mapping[str, Any]] = None,
num_partitions: int = -1,
num_workers: int = -1,
):
Expand All @@ -153,7 +156,21 @@ def query_internal(
whether to use the nuv query implementation. Default: False
mode: Mode
If provided the query will be executed using TileDB cloud taskgraphs.
For distributed execution you can use REALTIME or BATCH mode
For distributed execution you can use REALTIME or BATCH mode.
For local execution you can use LOCAL mode.
resource_class:
The name of the resource class to use ("standard" or "large"). Resource classes define maximum
limits for cpu and memory usage. Can only be used in REALTIME or BATCH mode.
Cannot be used alongside resources.
In REALTIME or BATCH mode if neither resource_class nor resources are provided,
we default to the "large" resource class.
resources:
A specification for the amount of resources to use when executing using TileDB cloud
taskgraphs, of the form: {"cpu": "6", "memory": "12Gi", "gpu": 1}. Can only be used
in BATCH mode. Cannot be used alongside resource_class.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can only be used in BATCH mode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops thanks, good catch, updated.

num_partitions: int
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate documentation for num_partitions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, good catch, removed.

Only relevant for taskgraph based execution.
If provided, we split the query execution in that many partitions.
num_partitions: int
Only relevant for taskgraph based execution.
If provided, we split the query execution in that many partitions.
Expand All @@ -167,6 +184,9 @@ def query_internal(
(queries.shape[0], k), index.MAX_UINT64
)

if not (mode == Mode.REALTIME or mode == Mode.BATCH) and (resource_class or resources):
raise TypeError("Can only pass resource_class or resources in REALTIME or BATCH mode")

assert queries.dtype == np.float32

if queries.ndim == 1:
Expand Down Expand Up @@ -217,6 +237,8 @@ def query_internal(
nthreads=nthreads,
nprobe=nprobe,
mode=mode,
resource_class=resource_class,
resources=resources,
num_partitions=num_partitions,
num_workers=num_workers,
config=self.config,
Expand All @@ -229,6 +251,8 @@ def taskgraph_query(
nprobe: int = 10,
nthreads: int = -1,
mode: Mode = None,
resource_class: Optional[str] = None,
resources: Optional[Mapping[str, Any]] = None,
num_partitions: int = -1,
num_workers: int = -1,
config: Optional[Mapping[str, Any]] = None,
Expand All @@ -248,7 +272,18 @@ def taskgraph_query(
Number of threads to use for query
mode: Mode
If provided the query will be executed using TileDB cloud taskgraphs.
For distributed execution you can use REALTIME or BATCH mode
For distributed execution you can use REALTIME or BATCH mode.
For local execution you can use LOCAL mode.
resource_class:
The name of the resource class to use ("standard" or "large"). Resource classes define maximum
limits for cpu and memory usage. Can only be used in REALTIME or BATCH mode.
Cannot be used alongside resources.
In REALTIME or BATCH mode if neither resource_class nor resources are provided,
we default to the "large" resource class.
resources:
A specification for the amount of resources to use when executing using TileDB cloud
taskgraphs, of the form: {"cpu": "6", "memory": "12Gi", "gpu": 1}. Can only be used
in BATCH mode. Cannot be used alongside resource_class.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above, only BATCH mode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, done.

num_partitions: int
Only relevant for taskgraph based execution.
If provided, we split the query execution in that many partitions.
Expand All @@ -268,6 +303,9 @@ def taskgraph_query(
from tiledb.vector_search.module import (array_to_matrix, dist_qv,
partition_ivf_index)

if resource_class and resources:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a check to make sure that resources is not defined for REALTIME

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done. Originally I just let tiledb.cloud throw this error, but you're right, that's inconsistent with how I'm handling other errors, so updated.

raise TypeError("Cannot provide both resource_class and resources")

def dist_qv_udf(
dtype: np.dtype,
parts_uri: str,
Expand Down Expand Up @@ -373,7 +411,8 @@ def dist_qv_udf(
k_nn=k,
config=config,
timestamp=self.base_array_timestamp,
resource_class="large" if mode == Mode.REALTIME else None,
resource_class="large" if (not resources and not resource_class) else resource_class,
resources=resources,
image_name="3.9-vectorsearch",
)
)
Expand Down
35 changes: 33 additions & 2 deletions apis/python/test/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest

from common import *
from tiledb.cloud import groups
from tiledb.cloud import groups, tiledb_cloud_error
from tiledb.cloud.dag import Mode

import tiledb.vector_search as vs
Expand All @@ -17,6 +17,8 @@ class CloudTests(unittest.TestCase):

@classmethod
def setUpClass(cls):
if not os.getenv("TILEDB_REST_TOKEN"):
raise ValueError("TILEDB_REST_TOKEN not set")
tiledb.cloud.login(token=os.getenv("TILEDB_REST_TOKEN"))
namespace, storage_path, _ = groups._default_ns_path_cred()
storage_path = storage_path.replace("//", "/").replace("/", "//", 1)
Expand Down Expand Up @@ -76,10 +78,39 @@ def test_cloud_ivf_flat(self):
# UDF library releases.
# mode=Mode.BATCH,
)

_, result_i = index.query(query_vectors, k=k, nprobe=nprobe)
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY

_, result_i = index.query(
query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, num_partitions=2
query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, num_partitions=2, resource_class="standard"
)
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY

_, result_i = index.query(
query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, num_partitions=2
)
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY

# We now will test for invalid scenarios when setting the query() resources.
resources = {"cpu": "9", "memory": "12Gi", "gpu": 0}

# Cannot pass resource_class or resources to LOCAL mode or to no mode.
with self.assertRaises(TypeError):
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, resource_class="large")
with self.assertRaises(TypeError):
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, resources=resources)
with self.assertRaises(TypeError):
index.query(query_vectors, k=k, nprobe=nprobe, resource_class="large")
with self.assertRaises(TypeError):
index.query(query_vectors, k=k, nprobe=nprobe, resources=resources)

# Cannot pass resources to REALTIME.
with self.assertRaises(tiledb_cloud_error.TileDBCloudError):
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, resources=resources)

# Cannot pass both resource_class and resources.
with self.assertRaises(TypeError):
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, resource_class="large", resources=resources)
with self.assertRaises(TypeError):
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.BATCH, resource_class="large", resources=resources)
Loading