Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distance metrics integration #422

Merged
merged 26 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3166e79
Distance metrics added for flat
cainamisir Jun 19, 2024
e946cbb
Added cosine distance function [skip ci]
cainamisir Jun 20, 2024
3796274
Changed distance to Enum
cainamisir Jun 20, 2024
9e914dc
exception for invalid distance metric
cainamisir Jun 20, 2024
81565e7
enum use and exception for non-flat distance metrics
cainamisir Jun 20, 2024
504d046
distance metric test
cainamisir Jun 20, 2024
c571281
formatting
cainamisir Jun 20, 2024
5af20ee
Cleaning up integration by moving to create
cainamisir Jul 1, 2024
2f67be5
Formatting and cleaning up test
cainamisir Jul 1, 2024
3525ff9
Merge branch 'main' of https://github.com/TileDB-Inc/TileDB-Vector-Se…
cainamisir Jul 4, 2024
27b93f3
Cosine optimization
cainamisir Jul 4, 2024
9396bdc
Formatting
cainamisir Jul 4, 2024
4417a28
distance metric integration in C++ layer for vamana+ivfpq
cainamisir Jul 5, 2024
face742
Final FLAT changes
cainamisir Jul 15, 2024
701b2eb
Metadata handling changes
cainamisir Jul 17, 2024
183641a
Cleanup remaining stuff from ivf_flat
cainamisir Jul 18, 2024
daea5f7
Cleanup setting metadata
cainamisir Jul 18, 2024
ca403c3
Cleanup
cainamisir Jul 19, 2024
53410b6
Testing for metadata (Vamana and IVF_PQ)
cainamisir Jul 22, 2024
3c543f8
Merge remote-tracking branch 'origin/main' into vlad/distancemetrics
cainamisir Jul 23, 2024
dd19772
fixing test
cainamisir Jul 23, 2024
6df9286
formatting
cainamisir Jul 23, 2024
6c0de47
Remove useless throw
cainamisir Jul 24, 2024
8517c53
swapped order of params and made style changes
cainamisir Jul 25, 2024
7ca69af
Merge branch 'main' of https://github.com/TileDB-Inc/TileDB-Vector-Se…
cainamisir Jul 25, 2024
7e3ba46
formatting
cainamisir Jul 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion apis/python/src/tiledb/vector_search/flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,19 @@ def query_internal(
(queries.shape[0], k), MAX_UINT64
)

try:
distance_metric = vspy.DistanceMetric(
self.group.meta.get("distance_metric", vspy.DistanceMetric.L2)
)
cainamisir marked this conversation as resolved.
Show resolved Hide resolved
except ValueError:
raise ValueError(
jparismorgan marked this conversation as resolved.
Show resolved Hide resolved
f"Invalid distance metric in metadata: {self.group.meta.get('distance_metric')}."
)

queries_m = array_to_matrix(np.transpose(queries))
d, i = query_vq_heap(self._db, queries_m, self._ids, k, nthreads)
d, i = query_vq_heap(
self._db, queries_m, self._ids, k, nthreads, distance_metric
)

return np.transpose(np.array(d)), np.transpose(np.array(i))

Expand All @@ -149,6 +160,7 @@ def create(
group_exists: bool = False,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
cainamisir marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
) -> FlatIndex:
"""
Expand Down Expand Up @@ -185,6 +197,7 @@ def create(
)
with tiledb.scope_ctx(ctx_or_config=config):
group = tiledb.Group(uri, "w")
group.meta["distance_metric"] = int(distance_metric)
tile_size = TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
ids_array_name = storage_formats[storage_version]["IDS_ARRAY_NAME"]
parts_array_name = storage_formats[storage_version]["PARTS_ARRAY_NAME"]
Expand Down
1 change: 1 addition & 0 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
self.ctx = vspy.Ctx(config)
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(config))
self.storage_version = self.group.meta.get("storage_version", "0.1")
self.distance_metric = self.group.meta.get("distance_metric", "L2")
if (
not storage_formats[self.storage_version]["SUPPORT_TIMETRAVEL"]
and timestamp is not None
Expand Down
10 changes: 9 additions & 1 deletion apis/python/src/tiledb/vector_search/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np

from tiledb.cloud.dag import Mode
from tiledb.vector_search import _tiledbvspy as vspy
from tiledb.vector_search._tiledbvspy import *
from tiledb.vector_search.storage_formats import STORAGE_VERSION
from tiledb.vector_search.storage_formats import validate_storage_version
Expand Down Expand Up @@ -80,6 +81,7 @@ def ingest(
] = None,
write_centroids_resources: Optional[Mapping[str, Any]] = None,
partial_index_resources: Optional[Mapping[str, Any]] = None,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
**kwargs,
):
"""
Expand Down Expand Up @@ -183,6 +185,8 @@ def ingest(
Resources to request when performing the write of centroids, only applies to BATCH mode
partial_index_resources: Optional[Mapping[str, Any]]
Resources to request when performing the computation of partial indexing, only applies to BATCH mode
distance_metric: Optional[vspy.DistanceMetric]
cainamisir marked this conversation as resolved.
Show resolved Hide resolved
Distance metric to use for the index, defaults to 'vspy.DistanceMetric.L2'. Options are 'vspy.DistanceMetric.L2', 'vspy.DistanceMetric.INNER_PRODUCT', 'vspy.DistanceMetric.COSINE'.
"""
import enum
import json
Expand Down Expand Up @@ -682,6 +686,7 @@ def create_arrays(
group_exists=True,
config=config,
storage_version=storage_version,
distance_metric=distance_metric,
)
elif index_type == "IVF_FLAT":
if not arrays_created:
Expand All @@ -692,6 +697,7 @@ def create_arrays(
group_exists=True,
config=config,
storage_version=storage_version,
distance_metric=distance_metric,
)
partial_write_array_index_uri = create_partial_write_array_group(
temp_data_group=temp_data_group,
Expand Down Expand Up @@ -1641,7 +1647,6 @@ def ingest_type_erased(
ids_array.close()

# Now that we've ingested the vectors and their IDs, train the index with the data.
from tiledb.vector_search import _tiledbvspy as vspy

ctx = vspy.Ctx(config)
if index_type == "VAMANA":
Expand Down Expand Up @@ -2672,6 +2677,7 @@ def consolidate_and_vacuum(
vector_type=vector_type,
config=config,
storage_version=storage_version,
distance_metric=distance_metric,
)
elif index_type == "IVF_PQ":
ivf_pq_index.create(
Expand All @@ -2682,6 +2688,7 @@ def consolidate_and_vacuum(
partitions=partitions,
config=config,
storage_version=storage_version,
distance_metric=distance_metric,
)
else:
raise ValueError(f"Unsupported index type {index_type}")
Expand Down Expand Up @@ -2948,6 +2955,7 @@ def consolidate_and_vacuum(
group.meta["partition_history"] = json.dumps(partition_history)
group.meta["base_sizes"] = json.dumps(base_sizes)
group.meta["ingestion_timestamps"] = json.dumps(ingestion_timestamps)

group.close()

consolidate_and_vacuum(index_group_uri=index_group_uri, config=config)
Expand Down
7 changes: 7 additions & 0 deletions apis/python/src/tiledb/vector_search/ivf_flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def create(
group_exists: bool = False,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
**kwargs,
) -> IVFFlatIndex:
"""
Expand All @@ -528,6 +529,11 @@ def create(
"""
validate_storage_version(storage_version)

if distance_metric != vspy.DistanceMetric.L2:
raise ValueError(
f"Distance metric {distance_metric} is not supported in IVF_FLAT"
)

index.create_metadata(
uri=uri,
dimensions=dimensions,
Expand All @@ -539,6 +545,7 @@ def create(
)
with tiledb.scope_ctx(ctx_or_config=config):
group = tiledb.Group(uri, "w")
group.meta["distance_metric"] = int(distance_metric)
tile_size = int(TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions)
group.meta["partition_history"] = json.dumps([0])
centroids_array_name = storage_formats[storage_version]["CENTROIDS_ARRAY_NAME"]
Expand Down
9 changes: 9 additions & 0 deletions apis/python/src/tiledb/vector_search/ivf_pq_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def create(
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
partitions: Optional[int] = None,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
**kwargs,
) -> IVFPQIndex:
"""
Expand Down Expand Up @@ -177,6 +178,14 @@ def create(
raise ValueError(
f"Number of dimensions ({dimensions}) must be divisible by num_subspaces ({num_subspaces})."
)
if distance_metric != vspy.DistanceMetric.L2:
raise ValueError(
f"Distance metric {distance_metric} is not supported in IVF_PQ"
)

group = tiledb.Group(uri, "w")
group.meta["distance_metric"] = int(distance_metric)
group.close()
cainamisir marked this conversation as resolved.
Show resolved Hide resolved
index = vspy.IndexIVFPQ(
feature_type=np.dtype(vector_type).name,
id_type=np.dtype(np.uint64).name,
Expand Down
27 changes: 23 additions & 4 deletions apis/python/src/tiledb/vector_search/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,11 +573,24 @@ static void declare_vq_query_heap(py::module& m, const std::string& suffix) {
ColMajorMatrix<float>& query_vectors,
const std::vector<uint64_t>& ids,
int k,
size_t nthreads)
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2)
-> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> {
auto r =
detail::flat::vq_query_heap(data, query_vectors, ids, k, nthreads);
return r;
if (distance_metric == DistanceMetric::L2) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, ids, k, nthreads, sum_of_squares_distance{});
return r;
} else if (distance_metric == DistanceMetric::INNER_PRODUCT) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, ids, k, nthreads, inner_product_distance{});
return r;
} else if (distance_metric == DistanceMetric::COSINE) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, ids, k, nthreads, cosine_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
});
}

Expand Down Expand Up @@ -812,6 +825,12 @@ PYBIND11_MODULE(_tiledbvspy, m) {
declare_debug_matrix<float>(m, "_f32");
declare_debug_matrix<uint64_t>(m, "_u64");

py::enum_<DistanceMetric>(m, "DistanceMetric")
.value("L2", DistanceMetric::L2)
.value("INNER_PRODUCT", DistanceMetric::INNER_PRODUCT)
.value("COSINE", DistanceMetric::COSINE)
.export_values();

/* === Module inits === */

init_kmeans(m);
Expand Down
8 changes: 8 additions & 0 deletions apis/python/src/tiledb/vector_search/vamana_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def create(
vector_type: np.dtype,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
**kwargs,
) -> VamanaIndex:
"""
Expand Down Expand Up @@ -170,6 +171,13 @@ def create(
id_type=np.dtype(np.uint64).name,
dimensions=dimensions,
)
if distance_metric != vspy.DistanceMetric.L2:
raise ValueError(
f"Distance metric {distance_metric} is not supported in VAMANA"
)
cainamisir marked this conversation as resolved.
Show resolved Hide resolved
group = tiledb.Group(uri, "w")
group.meta["distance_metric"] = int(distance_metric)
group.close()
# TODO(paris): Run all of this with a single C++ call.
empty_vector = vspy.FeatureVectorArray(
dimensions, 0, np.dtype(vector_type).name, np.dtype(np.uint64).name
Expand Down
119 changes: 119 additions & 0 deletions apis/python/test/test_distance_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os

import numpy as np
import pytest
from array_paths import *
from sklearn.neighbors import NearestNeighbors

from tiledb.vector_search import _tiledbvspy as vspy
from tiledb.vector_search.ingestion import ingest
from tiledb.vector_search.utils import load_fvecs

# Dataset file ingested by the tests in this module, and the query file used
# by the tests that run searches.
# NOTE(review): both right-hand names presumably come from the `array_paths`
# star import - confirm.
siftsmall_uri = siftsmall_inputs_file
queries_uri = siftsmall_query_file
cainamisir marked this conversation as resolved.
Show resolved Hide resolved


def normalize_vectors(vectors):
    """Scale every row of *vectors* to unit Euclidean length.

    Parameters
    ----------
    vectors:
        2-D array-like with one vector per row.

    Returns
    -------
    numpy.ndarray
        Array of the same shape whose non-zero rows have L2 norm 1.
        All-zero rows have no direction and are returned unchanged (all
        zeros) instead of turning into NaN through a division by zero.
    """
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Dividing a zero row by 1 keeps it at zero and avoids the 0/0 -> NaN
    # (and the accompanying RuntimeWarning) of a plain `vectors / norms`.
    safe_norms = np.where(norms == 0, 1, norms)
    return vectors / safe_norms


def test_cosine_similarity(tmp_path):
    """A FLAT index ingested with the COSINE metric must agree with scikit-learn.

    Ingests the siftsmall dataset into a FLAT index, then checks that the
    5 nearest neighbours of every query (distances and ids) match what
    ``NearestNeighbors(metric="cosine")`` reports on the same data.
    """
    uri = os.path.join(tmp_path, "sift10k_flat_FLAT")
    index = ingest(
        index_type="FLAT",
        index_uri=uri,
        source_uri=siftsmall_uri,
        source_type="FVEC",
        distance_metric=vspy.DistanceMetric.COSINE,
    )

    base = load_fvecs(siftsmall_uri)
    queries = load_fvecs(queries_uri)

    # Reference answer computed by scikit-learn on the raw vectors.
    reference = NearestNeighbors(n_neighbors=5, metric="cosine")
    reference.fit(base)
    expected_distances, expected_ids = reference.kneighbors(queries)

    actual_distances, actual_ids = index.query(queries, k=5)

    distances_match = np.allclose(expected_distances, actual_distances, rtol=1e-4)
    assert distances_match, "Cosine similarity distances do not match"

    ids_match = np.array_equal(expected_ids, actual_ids)
    assert ids_match, "Cosine similarity indices do not match"


def test_inner_product(tmp_path):
    """A FLAT index ingested with INNER_PRODUCT must report true inner products.

    Ingests the siftsmall dataset into a FLAT index, queries it for ``k``
    results per query vector and compares the returned scores with inner
    products computed directly with NumPy.
    """
    k = 5
    index_uri = os.path.join(tmp_path, "sift10k_flat_IP")
    index = ingest(
        index_type="FLAT",
        index_uri=index_uri,
        source_uri=siftsmall_uri,
        source_type="FVEC",
        distance_metric=vspy.DistanceMetric.INNER_PRODUCT,
    )

    dataset_vectors = load_fvecs(siftsmall_uri)
    query_vectors = load_fvecs(queries_uri)

    # Brute-force reference: one row of inner products per query, sorted
    # ascending, keeping the first k columns.
    # NOTE(review): ascending order means the reference is the k *smallest*
    # inner products per query - confirm this is the ranking the index is
    # meant to implement for INNER_PRODUCT.
    all_inner_products = np.dot(query_vectors, dataset_vectors.T)
    expected = np.sort(all_inner_products, axis=1)[:, :k]

    distances, _ = index.query(query_vectors, k=k)
    actual = np.sort(distances, axis=1)

    # Guard against a result of the wrong shape silently passing the
    # comparison below through broadcasting.
    assert (
        actual.shape == expected.shape
    ), f"Unexpected result shape {actual.shape}, expected {expected.shape}"
    # One vectorised comparison replaces the per-row index loop; it fails
    # under exactly the same conditions (any element out of tolerance).
    assert np.allclose(expected, actual, rtol=1e-4), "Inner products do not match"


def test_l2_distance(tmp_path):
    """A FLAT index ingested without an explicit metric must match euclidean k-NN.

    No ``distance_metric`` is passed to ``ingest``, so this exercises the
    default. The square root of the distances returned by the index is
    compared with scikit-learn's euclidean distances, and the ids must be
    identical.
    """
    uri = os.path.join(tmp_path, "sift10k_flat_L2")
    index = ingest(
        index_type="FLAT",
        index_uri=uri,
        source_uri=siftsmall_uri,
        source_type="FVEC",
    )

    base = load_fvecs(siftsmall_uri)
    queries = load_fvecs(queries_uri)

    # Reference answer computed by scikit-learn on the raw vectors.
    reference = NearestNeighbors(n_neighbors=5, metric="euclidean")
    reference.fit(base)
    expected_distances, expected_ids = reference.kneighbors(queries)

    raw_distances, actual_ids = index.query(queries, k=5)
    # NOTE(review): the square root implies the index reports squared L2
    # distances - confirm against the index implementation.
    actual_distances = np.sqrt(raw_distances)

    distances_match = np.allclose(expected_distances, actual_distances, rtol=1e-4)
    assert distances_match, "L2 distances do not match"
    ids_match = np.array_equal(expected_ids, actual_ids)
    assert ids_match, "L2 indices do not match"


def test_wrong_distance_metric(tmp_path):
    """Referring to a metric that is not a ``DistanceMetric`` member must fail.

    The test expects an ``AttributeError`` while using the name
    ``vspy.DistanceMetric.IDK`` as the ``distance_metric`` of an ingestion.
    """
    uri = os.path.join(tmp_path, "sift10k_flat_IDK")
    with pytest.raises(AttributeError):
        # The attribute lookup is deliberately kept inside the `raises`
        # block: the argument is evaluated before `ingest` is entered.
        ingest_kwargs = dict(
            index_type="FLAT",
            index_uri=uri,
            source_uri=siftsmall_uri,
            source_type="FVEC",
            distance_metric=vspy.DistanceMetric.IDK,
        )
        ingest(**ingest_kwargs)


def test_wrong_type_with_distance_metric(tmp_path):
    """Ingesting an IVF_FLAT index with the COSINE metric must raise ``ValueError``."""
    uri = os.path.join(tmp_path, "sift10k_IVF_FLAT_COSINE")
    requested_metric = vspy.DistanceMetric.COSINE
    with pytest.raises(ValueError):
        ingest(
            index_type="IVF_FLAT",
            index_uri=uri,
            source_uri=siftsmall_uri,
            source_type="FVEC",
            distance_metric=requested_metric,
        )
Loading
Loading