Skip to content

Commit

Permalink
distance metric integration in C++ layer for vamana+ivfpq
Browse files Browse the repository at this point in the history
generalize metric handling logic to index.py
  • Loading branch information
cainamisir committed Jul 5, 2024
1 parent 9396bdc commit 4417a28
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 33 deletions.
11 changes: 1 addition & 10 deletions apis/python/src/tiledb/vector_search/flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,9 @@ def query_internal(
(queries.shape[0], k), MAX_UINT64
)

try:
distance_metric = vspy.DistanceMetric(
self.group.meta.get("distance_metric", vspy.DistanceMetric.L2)
)
except ValueError:
raise ValueError(
f"Invalid distance metric in metadata: {self.group.meta.get('distance_metric')}."
)

queries_m = array_to_matrix(np.transpose(queries))
d, i = query_vq_heap(
self._db, queries_m, self._ids, k, nthreads, distance_metric
self._db, queries_m, self._ids, k, nthreads, self.distance_metric
)

return np.transpose(np.array(d)), np.transpose(np.array(i))
Expand Down
10 changes: 9 additions & 1 deletion apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,15 @@ def __init__(
self.ctx = vspy.Ctx(config)
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(config))
self.storage_version = self.group.meta.get("storage_version", "0.1")
self.distance_metric = self.group.meta.get("distance_metric", "L2")
try:
self.distance_metric = vspy.DistanceMetric(
self.group.meta.get("distance_metric", vspy.DistanceMetric.L2)
)
except ValueError:
raise ValueError(
f"Invalid distance metric in metadata: {self.group.meta.get('distance_metric')}."
)

if (
not storage_formats[self.storage_version]["SUPPORT_TIMETRAVEL"]
and timestamp is not None
Expand Down
4 changes: 1 addition & 3 deletions apis/python/src/tiledb/vector_search/ivf_pq_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,14 @@ def create(
f"Distance metric {distance_metric} is not supported in IVF_PQ"
)

group = tiledb.Group(uri, "w")
group.meta["distance_metric"] = int(distance_metric)
group.close()
index = vspy.IndexIVFPQ(
feature_type=np.dtype(vector_type).name,
id_type=np.dtype(np.uint64).name,
partitioning_index_type=np.dtype(np.uint64).name,
dimensions=dimensions,
n_list=partitions if (partitions is not None and partitions is not -1) else 0,
num_subspaces=num_subspaces,
distance_metric=int(distance_metric),
)
# TODO(paris): Run all of this with a single C++ call.
empty_vector = vspy.FeatureVectorArray(
Expand Down
4 changes: 1 addition & 3 deletions apis/python/src/tiledb/vector_search/vamana_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,12 @@ def create(
dimensions=dimensions,
l_build=l_build if l_build > 0 else L_BUILD_DEFAULT,
r_max_degree=r_max_degree if l_build > 0 else R_MAX_DEGREE_DEFAULT,
distance_metric=int(distance_metric),
)
if distance_metric != vspy.DistanceMetric.L2:
raise ValueError(
f"Distance metric {distance_metric} is not supported in VAMANA"
)
group = tiledb.Group(uri, "w")
group.meta["distance_metric"] = int(distance_metric)
group.close()
# TODO(paris): Run all of this with a single C++ call.
empty_vector = vspy.FeatureVectorArray(
dimensions, 0, np.dtype(vector_type).name, np.dtype(np.uint64).name
Expand Down
69 changes: 57 additions & 12 deletions apis/python/test/test_distance_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def normalize_vectors(vectors):
return vectors / norms


def test_cosine_similarity(tmp_path):
def test_cosine_DISTANCE(tmp_path):
index_uri = os.path.join(tmp_path, "sift10k_flat_FLAT")
index = ingest(
index_type="FLAT",
Expand All @@ -33,19 +33,17 @@ def test_cosine_similarity(tmp_path):

nn_cosine_sklearn = NearestNeighbors(n_neighbors=5, metric="cosine")
nn_cosine_sklearn.fit(dataset_vectors)
distances_sklearn, indices_sklearn = nn_cosine_sklearn.kneighbors(query_vectors)
distances_sklearn, ids_sklearn = nn_cosine_sklearn.kneighbors(query_vectors)

distances, indices = index.query(query_vectors, k=5)
distances, ids = index.query(query_vectors, k=5)

assert np.allclose(
distances_sklearn, distances, 1e-4
), "Cosine similarity distances do not match"
assert np.array_equal(
indices_sklearn, indices
), "Cosine similarity indices do not match"
), "Cosine distances do not match"
assert np.array_equal(ids_sklearn, ids), "Cosine distance ids do not match"


def test_inner_product(tmp_path):
def test_inner_product_distances(tmp_path):
index_uri = os.path.join(tmp_path, "sift10k_flat_IP")
index = ingest(
index_type="FLAT",
Expand All @@ -69,7 +67,9 @@ def test_inner_product(tmp_path):

for i in range(len(sorted_inner_products_sklearn)):
compare = sorted_inner_products_sklearn[i][:5]
assert np.allclose(compare, distances[i], 1e-4), "Inner products do not match"
assert np.allclose(
compare, distances[i], 1e-4
), "Inner product distances do not match"


def test_l2_distance(tmp_path):
Expand All @@ -86,12 +86,12 @@ def test_l2_distance(tmp_path):

nn_l2 = NearestNeighbors(n_neighbors=5, metric="euclidean")
nn_l2.fit(dataset_vectors)
distances_l2, indices_l2 = nn_l2.kneighbors(query_vectors)
distances_l2, ids_l2 = nn_l2.kneighbors(query_vectors)

distances, indices = index.query(query_vectors, k=5)
distances, ids = index.query(query_vectors, k=5)
distances = np.sqrt(distances)
assert np.allclose(distances_l2, distances, 1e-4), "L2 distances do not match"
assert np.array_equal(indices_l2, indices), "L2 indices do not match"
assert np.array_equal(ids_l2, ids), "L2 ids do not match"


def test_wrong_distance_metric(tmp_path):
Expand All @@ -116,3 +116,48 @@ def test_wrong_type_with_distance_metric(tmp_path):
source_type="FVEC",
distance_metric=vspy.DistanceMetric.COSINE,
)


def test_vamana_create_l2(tmp_path):
    """Ingest a VAMANA index with the L2 metric; the test passes if nothing raises."""
    ingest(
        index_type="VAMANA",
        index_uri=os.path.join(tmp_path, "sift10k_flat_L2"),
        source_uri=siftsmall_uri,
        source_type="FVEC",
        distance_metric=vspy.DistanceMetric.L2,
    )


def test_vamana_create_cosine(tmp_path):
    """Ingesting a VAMANA index with the COSINE metric must raise RuntimeError."""
    target_uri = os.path.join(tmp_path, "sift10k_flat_COSINE")
    # Only the ingest call itself is expected to raise, so keep the
    # `raises` context limited to it.
    with pytest.raises(RuntimeError):
        ingest(
            index_type="VAMANA",
            index_uri=target_uri,
            source_uri=siftsmall_uri,
            source_type="FVEC",
            distance_metric=vspy.DistanceMetric.COSINE,
        )


# def test_ivfpq_create_l2(tmp_path):
# index_uri = os.path.join(tmp_path, "sift10k_flat_L2")
# index = ingest(
# index_type="IVFPQ",
# index_uri=index_uri,
# source_uri=siftsmall_uri,
# source_type="FVEC",
# distance_metric=vspy.DistanceMetric.L2,
# )

# def test_ivfpq_create_cosine(tmp_path):
# index_uri = os.path.join(tmp_path, "sift10k_flat_COSINE")
# with pytest.raises(RuntimeError):
# ingest(
# index_type="IVFPQ",
# index_uri=index_uri,
# source_uri=siftsmall_uri,
# source_type="FVEC",
# distance_metric=vspy.DistanceMetric.COSINE,
# )
19 changes: 19 additions & 0 deletions src/include/api/ivf_pq_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "api/feature_vector_array.h"
#include "api_defs.h"
#include "index/index_defs.h"
#include "index/index_metadata.h"
#include "index/ivf_pq_index.h"

/*******************************************************************************
Expand Down Expand Up @@ -112,6 +113,19 @@ class IndexIVFPQ {
id_datatype_ = string_to_datatype(value);
} else if (key == "partitioning_index_type") {
partitioning_index_datatype_ = string_to_datatype(value);
} else if (key == "distance_metric") {
try {
int metric_value = std::stoi(value);
if (metric_value < 0 ||
metric_value > static_cast<int>(DistanceMetric::COSINE)) {
throw std::runtime_error(
"Invalid distance metric value: " + value);
}
distance_metric = static_cast<DistanceMetric>(metric_value);
} catch (const std::exception& e) {
throw std::runtime_error(
"Error setting distance metric: " + std::string(e.what()));
}
} else {
throw std::runtime_error("Invalid index config key: " + key);
}
Expand Down Expand Up @@ -157,6 +171,8 @@ class IndexIVFPQ {
" != " + std::to_string(index_->dimensions()));
}
dimensions_ = index_->dimensions();
base_index_metadata<IndexIVFPQ>::get_distance_metric_metadata(
ctx, group_uri, distance_metric);
}

/**
Expand Down Expand Up @@ -242,6 +258,8 @@ class IndexIVFPQ {
"Cannot write_index() because there is no index.");
}
index_->write_index(ctx, group_uri, temporal_policy, storage_version);
base_index_metadata<IndexIVFPQ>::set_distance_metric_metadata(
ctx, group_uri, distance_metric);
}

static void clear_history(
Expand Down Expand Up @@ -562,6 +580,7 @@ class IndexIVFPQ {
tiledb_datatype_t id_datatype_{TILEDB_ANY};
tiledb_datatype_t partitioning_index_datatype_{TILEDB_ANY};
std::unique_ptr<index_base> index_;
DistanceMetric distance_metric;
};

// clang-format off
Expand Down
24 changes: 23 additions & 1 deletion src/include/api/vamana_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "api/feature_vector.h"
#include "api/feature_vector_array.h"
#include "api_defs.h"
#include "index/index_metadata.h"
#include "index/vamana_index.h"

/*******************************************************************************
Expand Down Expand Up @@ -69,7 +70,6 @@ class IndexVamana {
IndexVamana(IndexVamana&&) = default;
IndexVamana& operator=(const IndexVamana&) = delete;
IndexVamana& operator=(IndexVamana&&) = default;

/**
* @brief Create an index with the given configuration. The index in this
* state must next be trained. The sequence for creating an index in this
Expand Down Expand Up @@ -104,6 +104,23 @@ class IndexVamana {
feature_datatype_ = string_to_datatype(value);
} else if (key == "id_type") {
id_datatype_ = string_to_datatype(value);
} else if (key == "distance_metric") {
try {
int metric_value = std::stoi(value);
if (metric_value < 0 ||
metric_value > static_cast<int>(DistanceMetric::COSINE)) {
throw std::runtime_error(
"Invalid distance metric value: " + value);
}
distance_metric = static_cast<DistanceMetric>(metric_value);
if (distance_metric != DistanceMetric::L2) {
throw std::runtime_error(
"Invalid distance metric for Vamana: " + value);
}
} catch (const std::exception& e) {
throw std::runtime_error(
"Error setting distance metric: " + std::string(e.what()));
}
} else {
throw std::runtime_error("Invalid index config key: " + key);
}
Expand Down Expand Up @@ -141,6 +158,8 @@ class IndexVamana {
"Dimensions mismatch: " + std::to_string(dimensions_) +
" != " + std::to_string(index_->dimensions()));
}
base_index_metadata<IndexVamana>::get_distance_metric_metadata(
ctx, group_uri, distance_metric);
dimensions_ = index_->dimensions();
}

Expand Down Expand Up @@ -225,6 +244,8 @@ class IndexVamana {
"Cannot write_index() because there is no index.");
}
index_->write_index(ctx, group_uri, temporal_policy, storage_version);
base_index_metadata<IndexVamana>::set_distance_metric_metadata(
ctx, group_uri, distance_metric);
}

static void clear_history(
Expand Down Expand Up @@ -488,6 +509,7 @@ class IndexVamana {
static constexpr tiledb_datatype_t adjacency_row_index_datatype_{
TILEDB_UINT64};
std::unique_ptr<index_base> index_;
DistanceMetric distance_metric;
};

// clang-format off
Expand Down
35 changes: 33 additions & 2 deletions src/include/index/index_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@

#include "index/index_defs.h"
#include "index/index_group.h"
#include "tdb_defs.h"

#include "nlohmann/json.hpp"
#include "scoring.h"
#include "tdb_defs.h"

/**
* @brief Metadata for an IVF_FLAT index.
Expand Down Expand Up @@ -143,6 +143,37 @@ class base_index_metadata {
return vec;
}

/**
 * @brief Store the distance metric on the index group as metadata.
 *
 * Opens the group at `group_uri` for writing, stores the metric under the
 * "distance_metric" key as a single TILEDB_UINT32 value, and closes the
 * group again.
 *
 * @param ctx TileDB context used to open the group.
 * @param group_uri URI of the index group to write to.
 * @param distance_metric Metric to persist.
 */
static void set_distance_metric_metadata(
    const tiledb::Context& ctx,
    const std::string& group_uri,
    const DistanceMetric& distance_metric) {
  // Metadata is stored as the enumerator's numeric value.
  const auto encoded_metric = static_cast<uint32_t>(distance_metric);

  tiledb::Group group{ctx, group_uri, TILEDB_WRITE};
  group.put_metadata("distance_metric", TILEDB_UINT32, 1, &encoded_metric);
  group.close();
}

/**
 * @brief Read the "distance_metric" group metadata into `distance_metric`.
 *
 * `distance_metric` is always assigned. It starts at DistanceMetric::L2 and
 * is only overwritten when the group stores exactly one TILEDB_UINT32 value
 * that maps to a known enumerator. Missing, mistyped or out-of-range
 * metadata therefore yields L2 instead of leaving the caller's variable
 * untouched or holding an invalid enumerator.
 *
 * @param ctx TileDB context used to open the group.
 * @param group_uri URI of the index group to read from.
 * @param distance_metric Output: the stored metric, or L2 as the fallback.
 */
static void get_distance_metric_metadata(
    const tiledb::Context& ctx,
    const std::string& group_uri,
    DistanceMetric& distance_metric) {
  // Assign the default first so every path leaves the output well-defined.
  distance_metric = DistanceMetric::L2;

  tiledb::Group read_group(ctx, group_uri, TILEDB_READ);
  tiledb_datatype_t type;
  if (read_group.has_metadata("distance_metric", &type)) {
    uint32_t value_num = 0;
    const void* value = nullptr;
    read_group.get_metadata("distance_metric", &type, &value_num, &value);
    if (type == TILEDB_UINT32 && value_num == 1 && value != nullptr) {
      const uint32_t metric_value = *static_cast<const uint32_t*>(value);
      // COSINE is the largest enumerator; larger values are not valid
      // metrics and must not be cast into the enum.
      if (metric_value <= static_cast<uint32_t>(DistanceMetric::COSINE)) {
        distance_metric = static_cast<DistanceMetric>(metric_value);
      }
    }
  }
  read_group.close();
}

/**
* @brief Given a name, value, and required flag, read in the metadata
* associated with the name and store into value. An exception is thrown if
Expand Down
2 changes: 1 addition & 1 deletion src/include/scoring.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ struct cosine_distance {
} // namespace _cosine_distance
using cosine_distance = _cosine_distance::cosine_distance;

enum class DistanceMetric { L2, INNER_PRODUCT, COSINE };
enum class DistanceMetric : uint32_t { L2 = 0, INNER_PRODUCT = 1, COSINE = 2 };
// ----------------------------------------------------------------------------
// Functions for dealing with the case of when size of scores < k_nn
// ----------------------------------------------------------------------------
Expand Down

0 comments on commit 4417a28

Please sign in to comment.