From b1a03e6b9d6131d1a55b9903a6e31d7bb66db84c Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Tue, 9 Apr 2024 14:20:49 -0700 Subject: [PATCH] Added separate interface for creating and writing in a faiss index Final commit for building graphs iteratively Signed-off-by: Navneet Verma --- jni/include/faiss_wrapper.h | 7 + .../org_opensearch_knn_jni_FaissService.h | 17 + ...Custom-patch-to-support-multi-vector.patch | 1050 ----------------- ...ble-precomp-table-to-be-shared-ivfpq.patch | 512 -------- ...vel-during-add-from-enterpoint-level.patch | 31 - jni/src/faiss_wrapper.cpp | 94 ++ .../org_opensearch_knn_jni_FaissService.cpp | 24 + jni/tests/faiss_wrapper_test.cpp | 51 +- .../KNN80Codec/KNN80DocValuesConsumer.java | 140 ++- .../knn/index/codec/util/KNNCodecUtil.java | 2 +- .../util/KNNVectorSerializerFactory.java | 2 +- .../org/opensearch/knn/jni/FaissService.java | 10 + .../org/opensearch/knn/jni/JNIService.java | 23 + 13 files changed, 325 insertions(+), 1638 deletions(-) delete mode 100644 jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch delete mode 100644 jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch delete mode 100644 jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 3e1adeac4d..5d6d9ee60b 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -22,6 +22,13 @@ namespace knn_jni { void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ); + // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. + // The index is serialized to indexPathJ. + long long CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong indexAddressJ, jobject parametersJ); + + void writeIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,jlong indexAddressJ, jstring indexPathJ, jobject parametersJ); + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 32b6f22f1f..208fc9bbaf 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -26,6 +26,23 @@ extern "C" { JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createIndexIteratively + * Signature: ([IJIJLjava/util/Map;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexIteratively + (JNIEnv *, jclass, jintArray, jlong, jint, jlong, jobject); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeIndex + * Signature: (JLjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex + (JNIEnv *, jclass, jlong, jstring, jobject); + + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate diff --git a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch deleted file mode 100644 index a22e281305..0000000000 --- a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch +++ /dev/null @@ -1,1050 +0,0 @@ -From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001 -From: Heemin Kim -Date: Tue, 30 Jan 2024 14:43:56 -0800 -Subject: [PATCH] Add IDGrouper for HNSW - -Signed-off-by: Heemin Kim ---- - faiss/CMakeLists.txt | 3 + - faiss/Index.h | 8 +- - faiss/IndexHNSW.cpp | 13 ++- - faiss/IndexIDMap.cpp | 29 ++++++ - faiss/IndexIDMap.h | 22 +++++ - faiss/impl/HNSW.cpp | 10 +- - faiss/impl/IDGrouper.cpp | 51 ++++++++++ - faiss/impl/IDGrouper.h | 51 ++++++++++ - faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++ - faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++ - tests/CMakeLists.txt | 2 + - tests/test_group_heap.cpp | 98 +++++++++++++++++++ - tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++ - 13 files changed, 838 insertions(+), 7 deletions(-) - create mode 100644 faiss/impl/IDGrouper.cpp - create mode 100644 faiss/impl/IDGrouper.h - create mode 100644 faiss/utils/GroupHeap.h - create mode 100644 tests/test_group_heap.cpp - create mode 100644 tests/test_id_grouper.cpp - -diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt -index a890a46f..137e68d4 100644 ---- a/faiss/CMakeLists.txt -+++ b/faiss/CMakeLists.txt -@@ -54,6 +54,7 @@ set(FAISS_SRC - impl/AuxIndexStructures.cpp - impl/CodePacker.cpp - impl/IDSelector.cpp -+ impl/IDGrouper.cpp - impl/FaissException.cpp - impl/HNSW.cpp - impl/NSG.cpp -@@ -149,6 +150,7 @@ set(FAISS_HEADERS - impl/AuxIndexStructures.h - impl/CodePacker.h - impl/IDSelector.h -+ impl/IDGrouper.h - impl/DistanceComputer.h - impl/FaissAssert.h - impl/FaissException.h -@@ -183,6 +185,7 @@ set(FAISS_HEADERS - invlists/InvertedLists.h - invlists/InvertedListsIOHook.h - utils/AlignedTable.h -+ utils/GroupHeap.h - utils/Heap.h - utils/WorkerThread.h - utils/distances.h -diff --git a/faiss/Index.h b/faiss/Index.h -index 4b4b302b..3b673d1e 100644 ---- a/faiss/Index.h -+++ b/faiss/Index.h -@@ -38,9 +38,10 @@ - - namespace faiss { - --/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and --/// impl/DistanceComputer.h -+/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h -+/// ,impl/IDGrouper.h and impl/DistanceComputer.h - struct IDSelector; -+struct IDGrouper; - struct RangeSearchResult; - struct DistanceComputer; - -@@ -52,6 +53,9 @@ struct DistanceComputer; - struct SearchParameters { - /// if non-null, only these IDs will be considered during search. - IDSelector* sel = nullptr; -+ /// if non-null, only best matched ID per group will be included in the -+ /// result. -+ IDGrouper* grp = nullptr; - /// make sure we can dynamic_cast this - virtual ~SearchParameters() {} - }; -diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp -index 9a67332d..a5e0fea0 100644 ---- a/faiss/IndexHNSW.cpp -+++ b/faiss/IndexHNSW.cpp -@@ -354,10 +354,17 @@ void IndexHNSW::search( - const SearchParameters* params_in) const { - FAISS_THROW_IF_NOT(k > 0); - -- using RH = HeapBlockResultHandler; -- RH bres(n, distances, labels, k); -+ if (params_in && params_in->grp) { -+ using RH = GroupedHeapBlockResultHandler; -+ RH bres(n, distances, labels, k, params_in->grp); - -- hnsw_search(this, n, x, bres, params_in); -+ hnsw_search(this, n, x, bres, params_in); -+ } else { -+ using RH = HeapBlockResultHandler; -+ RH bres(n, distances, labels, k); -+ -+ hnsw_search(this, n, x, bres, params_in); -+ } - - if (is_similarity_metric(this->metric_type)) { - // we need to revert the negated distances -diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp -index e093bbda..e24365d5 100644 ---- a/faiss/IndexIDMap.cpp -+++ b/faiss/IndexIDMap.cpp -@@ -102,6 +102,23 @@ struct ScopedSelChange { - } - }; - -+/// RAII object to reset the IDGrouper in the params object -+struct ScopedGrpChange { -+ SearchParameters* params = nullptr; -+ IDGrouper* old_grp = nullptr; -+ -+ void set(SearchParameters* params_2, IDGrouper* new_grp) { -+ this->params = params_2; -+ old_grp = params_2->grp; -+ params_2->grp = new_grp; -+ } -+ ~ScopedGrpChange() { -+ if (params) { -+ params->grp = old_grp; -+ } -+ } -+}; -+ - } // namespace - - template -@@ -114,6 +131,8 @@ void IndexIDMapTemplate::search( - const SearchParameters* params) const { - IDSelectorTranslated this_idtrans(this->id_map, nullptr); - ScopedSelChange sel_change; -+ IDGrouperTranslated this_idgrptrans(this->id_map, nullptr); -+ ScopedGrpChange grp_change; - - if (params && params->sel) { - auto idtrans = dynamic_cast(params->sel); -@@ -131,6 +150,16 @@ void IndexIDMapTemplate::search( - sel_change.set(params_non_const, &this_idtrans); - } - } -+ -+ if (params && params->grp) { -+ auto idtrans = dynamic_cast(params->grp); -+ -+ if (!idtrans) { -+ auto params_non_const = const_cast(params); -+ this_idgrptrans.grp = params->grp; -+ grp_change.set(params_non_const, &this_idgrptrans); -+ } -+ } - index->search(n, x, k, distances, labels, params); - idx_t* li = labels; - #pragma omp parallel for -diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h -index 2d164123..a68887bd 100644 ---- a/faiss/IndexIDMap.h -+++ b/faiss/IndexIDMap.h -@@ -9,6 +9,7 @@ - - #include - #include -+#include - #include - - #include -@@ -124,4 +125,25 @@ struct IDSelectorTranslated : IDSelector { - } - }; - -+// IDGrouper that translates the ids using an IDMap -+struct IDGrouperTranslated : IDGrouper { -+ const std::vector& id_map; -+ const IDGrouper* grp; -+ -+ IDGrouperTranslated( -+ const std::vector& id_map, -+ const IDGrouper* grp) -+ : id_map(id_map), grp(grp) {} -+ -+ IDGrouperTranslated(IndexBinaryIDMap& index_idmap, const IDGrouper* grp) -+ : id_map(index_idmap.id_map), grp(grp) {} -+ -+ IDGrouperTranslated(IndexIDMap& index_idmap, const IDGrouper* grp) -+ : id_map(index_idmap.id_map), grp(grp) {} -+ -+ idx_t get_group(idx_t id) const override { -+ return grp->get_group(id_map[id]); -+ } -+}; -+ - } // namespace faiss -diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp -index fb4de678..b6f602a0 100644 ---- a/faiss/impl/HNSW.cpp -+++ b/faiss/impl/HNSW.cpp -@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const { - level, - nb_neighbors(level)); - size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; --#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ -- reduction(+: tot_reciprocal) reduction(+: n_node) -+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \ -+ reduction(+ : tot_reciprocal) reduction(+ : n_node) - for (int i = 0; i < levels.size(); i++) { - if (levels[i] > level) { - n_node++; -@@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler& res) { - if (auto hres = dynamic_cast(&res)) { - return hres->k; - } -+ -+ if (auto hres = dynamic_cast< -+ GroupedHeapBlockResultHandler::SingleResultHandler*>(&res)) { -+ return hres->k; -+ } -+ - return 1; - } - -diff --git a/faiss/impl/IDGrouper.cpp b/faiss/impl/IDGrouper.cpp -new file mode 100644 -index 00000000..ca9f5fda ---- /dev/null -+++ b/faiss/impl/IDGrouper.cpp -@@ -0,0 +1,51 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#include -+#include -+#include -+ -+namespace faiss { -+ -+/*********************************************************************** -+ * IDGrouperBitmap -+ ***********************************************************************/ -+ -+IDGrouperBitmap::IDGrouperBitmap(size_t n, uint64_t* bitmap) -+ : n(n), bitmap(bitmap) {} -+ -+idx_t IDGrouperBitmap::get_group(idx_t id) const { -+ assert(id >= 0 && "id shouldn't be less than zero"); -+ assert(id < this->n * 64 && "is should be less than total number of bits"); -+ -+ idx_t index = id >> 6; // div by 64 -+ uint64_t block = this->bitmap[index] >> -+ (id & 63); // Equivalent of words[i] >> (index % 64) -+ // block is non zero after right shift, it means, next set bit is in current -+ // block The index of set bit is "given index" + "trailing zero in the right -+ // shifted word" -+ if (block != 0) { -+ return id + __builtin_ctzll(block); -+ } -+ -+ while (++index < this->n) { -+ block = this->bitmap[index]; -+ if (block != 0) { -+ return (index << 6) + __builtin_ctzll(block); -+ } -+ } -+ -+ return NO_MORE_DOCS; -+} -+ -+void IDGrouperBitmap::set_group(idx_t group_id) { -+ idx_t index = group_id >> 6; -+ this->bitmap[index] |= 1ULL -+ << (group_id & 63); // Equivalent of 1ULL << (value % 64) -+} -+ -+} // namespace faiss -diff --git a/faiss/impl/IDGrouper.h b/faiss/impl/IDGrouper.h -new file mode 100644 -index 00000000..d56113d9 ---- /dev/null -+++ b/faiss/impl/IDGrouper.h -@@ -0,0 +1,51 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+ -+/** IDGrouper is intended to define a group of vectors to include only -+ * the nearest vector of each group during search */ -+ -+namespace faiss { -+ -+/** Encapsulates a group id of ids */ -+struct IDGrouper { -+ const idx_t NO_MORE_DOCS = std::numeric_limits::max(); -+ virtual idx_t get_group(idx_t id) const = 0; -+ virtual ~IDGrouper() {} -+}; -+ -+/** One bit per element. Constructed with a bitmap, size ceil(n / 8). -+ */ -+struct IDGrouperBitmap : IDGrouper { -+ // length of the bitmap array -+ size_t n; -+ -+ // Array of uint64_t holding the bits -+ // Using uint64_t to leverage function __builtin_ctzll which is defined in -+ // faiss/impl/platform_macros.h Group id of a given id is next set bit in -+ // the bitmap -+ uint64_t* bitmap; -+ -+ /** Construct with a binary mask -+ * -+ * @param n size of the bitmap array -+ * @param bitmap group id of a given id is next set bit in the bitmap -+ */ -+ IDGrouperBitmap(size_t n, uint64_t* bitmap); -+ idx_t get_group(idx_t id) const final; -+ void set_group(idx_t group_id); -+ ~IDGrouperBitmap() override {} -+}; -+ -+} // namespace faiss -diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h -index 270de8dc..2f7f3e7f 100644 ---- a/faiss/impl/ResultHandler.h -+++ b/faiss/impl/ResultHandler.h -@@ -12,6 +12,8 @@ - #pragma once - - #include -+#include -+#include - #include - #include - -@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler { - } - }; - -+/***************************************************************** -+ * Heap based result handler with grouping -+ *****************************************************************/ -+ -+template -+struct GroupedHeapBlockResultHandler : BlockResultHandler { -+ using T = typename C::T; -+ using TI = typename C::TI; -+ using BlockResultHandler::i0; -+ using BlockResultHandler::i1; -+ -+ T* heap_dis_tab; -+ TI* heap_ids_tab; -+ int64_t k; // number of results to keep -+ -+ IDGrouper* id_grouper; -+ TI* heap_group_ids_tab; -+ std::unordered_map* group_id_to_index_in_heap_tab; -+ -+ GroupedHeapBlockResultHandler( -+ size_t nq, -+ T* heap_dis_tab, -+ TI* heap_ids_tab, -+ size_t k, -+ IDGrouper* id_grouper) -+ : BlockResultHandler(nq), -+ heap_dis_tab(heap_dis_tab), -+ heap_ids_tab(heap_ids_tab), -+ k(k), -+ id_grouper(id_grouper) {} -+ -+ /****************************************************** -+ * API for 1 result at a time (each SingleResultHandler is -+ * called from 1 thread) -+ */ -+ -+ struct SingleResultHandler : ResultHandler { -+ GroupedHeapBlockResultHandler& hr; -+ using ResultHandler::threshold; -+ size_t k; -+ -+ T* heap_dis; -+ TI* heap_ids; -+ TI* heap_group_ids; -+ std::unordered_map group_id_to_index_in_heap; -+ -+ explicit SingleResultHandler(GroupedHeapBlockResultHandler& hr) -+ : hr(hr), k(hr.k) {} -+ -+ /// begin results for query # i -+ void begin(size_t i) { -+ heap_dis = hr.heap_dis_tab + i * k; -+ heap_ids = hr.heap_ids_tab + i * k; -+ heap_heapify(k, heap_dis, heap_ids); -+ threshold = heap_dis[0]; -+ heap_group_ids = new TI[hr.k]; -+ for (size_t i = 0; i < hr.k; i++) { -+ heap_group_ids[i] = -1; -+ } -+ } -+ -+ /// add one result for query i -+ bool add_result(T dis, TI idx) final { -+ if (!C::cmp(threshold, dis)) { -+ return false; -+ } -+ -+ idx_t group_id = hr.id_grouper->get_group(idx); -+ typename std::unordered_map::const_iterator it_pos = -+ group_id_to_index_in_heap.find(group_id); -+ if (it_pos == group_id_to_index_in_heap.end()) { -+ group_heap_replace_top( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ dis, -+ idx, -+ group_id, -+ &group_id_to_index_in_heap); -+ return true; -+ } else { -+ size_t pos = it_pos->second; -+ if (!C::cmp(heap_dis[pos], dis)) { -+ return false; -+ } -+ group_heap_replace_at( -+ pos, -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ dis, -+ idx, -+ group_id, -+ &group_id_to_index_in_heap); -+ return true; -+ } -+ } -+ -+ /// series of results for query i is done -+ void end() { -+ heap_reorder(k, heap_dis, heap_ids); -+ delete heap_group_ids; -+ } -+ }; -+ -+ /****************************************************** -+ * API for multiple results (called from 1 thread) -+ */ -+ -+ /// begin -+ void begin_multiple(size_t i0_2, size_t i1_2) final { -+ this->i0 = i0_2; -+ this->i1 = i1_2; -+ for (size_t i = i0; i < i1; i++) { -+ heap_heapify(k, heap_dis_tab + i * k, heap_ids_tab + i * k); -+ } -+ size_t size = (i1 - i0) * k; -+ heap_group_ids_tab = new TI[size]; -+ for (size_t i = 0; i < size; i++) { -+ heap_group_ids_tab[i] = -1; -+ } -+ group_id_to_index_in_heap_tab = -+ new std::unordered_map[i1 - i0]; -+ } -+ -+ /// add results for query i0..i1 and j0..j1 -+ void add_results(size_t j0, size_t j1, const T* dis_tab) final { -+#pragma omp parallel for -+ for (int64_t i = i0; i < i1; i++) { -+ T* heap_dis = heap_dis_tab + i * k; -+ TI* heap_ids = heap_ids_tab + i * k; -+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; -+ T thresh = heap_dis[0]; // NOLINT(*-use-default-none) -+ for (size_t j = j0; j < j1; j++) { -+ T dis = dis_tab_i[j]; -+ if (C::cmp(thresh, dis)) { -+ idx_t group_id = id_grouper->get_group(j); -+ typename std::unordered_map::const_iterator -+ it_pos = group_id_to_index_in_heap_tab[i - i0].find( -+ group_id); -+ if (it_pos == group_id_to_index_in_heap_tab[i - i0].end()) { -+ group_heap_replace_top( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids_tab + ((i - i0) * k), -+ dis, -+ j, -+ group_id, -+ &group_id_to_index_in_heap_tab[i - i0]); -+ thresh = heap_dis[0]; -+ } else { -+ size_t pos = it_pos->first; -+ if (C::cmp(heap_dis[pos], dis)) { -+ group_heap_replace_at( -+ pos, -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids_tab + ((i - i0) * k), -+ dis, -+ j, -+ group_id, -+ &group_id_to_index_in_heap_tab[i - i0]); -+ thresh = heap_dis[0]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// series of results for queries i0..i1 is done -+ void end_multiple() final { -+ // maybe parallel for -+ for (size_t i = i0; i < i1; i++) { -+ heap_reorder(k, heap_dis_tab + i * k, heap_ids_tab + i * k); -+ } -+ delete group_id_to_index_in_heap_tab; -+ delete heap_group_ids_tab; -+ } -+}; -+ - /***************************************************************** - * Reservoir result handler - * -diff --git a/faiss/utils/GroupHeap.h b/faiss/utils/GroupHeap.h -new file mode 100644 -index 00000000..3b7078da ---- /dev/null -+++ b/faiss/utils/GroupHeap.h -@@ -0,0 +1,182 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+namespace faiss { -+ -+/** -+ * From start_index, it compare its value with parent node's and swap if needed. -+ * Continue until either there is no swap or it reaches the top node. -+ */ -+template -+static inline void group_up_heap( -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ std::unordered_map* group_id_to_index_in_heap, -+ size_t start_index) { -+ heap_dis--; /* Use 1-based indexing for easier node->child translation */ -+ heap_ids--; -+ heap_group_ids--; -+ size_t i = start_index + 1, i_father; -+ typename C::T target_dis = heap_dis[i]; -+ typename C::TI target_id = heap_ids[i]; -+ typename C::TI target_group_id = heap_group_ids[i]; -+ -+ while (i > 1) { -+ i_father = i >> 1; -+ if (!C::cmp2( -+ target_dis, -+ heap_dis[i_father], -+ target_id, -+ heap_ids[i_father])) { -+ /* the heap structure is ok */ -+ break; -+ } -+ heap_dis[i] = heap_dis[i_father]; -+ heap_ids[i] = heap_ids[i_father]; -+ heap_group_ids[i] = heap_group_ids[i_father]; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+ i = i_father; -+ } -+ heap_dis[i] = target_dis; -+ heap_ids[i] = target_id; -+ heap_group_ids[i] = target_group_id; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+} -+ -+/** -+ * From start_index, it compare its value with child node's and swap if needed. -+ * Continue until either there is no swap or it reaches the leaf node. -+ */ -+template -+static inline void group_down_heap( -+ size_t k, -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ std::unordered_map* group_id_to_index_in_heap, -+ size_t start_index) { -+ heap_dis--; /* Use 1-based indexing for easier node->child translation */ -+ heap_ids--; -+ heap_group_ids--; -+ size_t i = start_index + 1, i1, i2; -+ typename C::T target_dis = heap_dis[i]; -+ typename C::TI target_id = heap_ids[i]; -+ typename C::TI target_group_id = heap_group_ids[i]; -+ -+ while (1) { -+ i1 = i << 1; -+ i2 = i1 + 1; -+ if (i1 > k) { -+ break; -+ } -+ -+ // Note that C::cmp2() is a bool function answering -+ // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max -+ // heap and same with the `<` sign for min heap. -+ if ((i2 == k + 1) || -+ C::cmp2(heap_dis[i1], heap_dis[i2], heap_ids[i1], heap_ids[i2])) { -+ if (C::cmp2(target_dis, heap_dis[i1], target_id, heap_ids[i1])) { -+ break; -+ } -+ heap_dis[i] = heap_dis[i1]; -+ heap_ids[i] = heap_ids[i1]; -+ heap_group_ids[i] = heap_group_ids[i1]; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+ i = i1; -+ } else { -+ if (C::cmp2(target_dis, heap_dis[i2], target_id, heap_ids[i2])) { -+ break; -+ } -+ heap_dis[i] = heap_dis[i2]; -+ heap_ids[i] = heap_ids[i2]; -+ heap_group_ids[i] = heap_group_ids[i2]; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+ i = i2; -+ } -+ } -+ heap_dis[i] = target_dis; -+ heap_ids[i] = target_id; -+ heap_group_ids[i] = target_group_id; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+} -+ -+template -+static inline void group_heap_replace_top( -+ size_t k, -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ typename C::T dis, -+ typename C::TI id, -+ typename C::TI group_id, -+ std::unordered_map* group_id_to_index_in_heap) { -+ assert(group_id_to_index_in_heap->find(group_id) == -+ group_id_to_index_in_heap->end() && -+ "group id should not exist in the binary heap"); -+ -+ group_id_to_index_in_heap->erase(heap_group_ids[0]); -+ heap_group_ids[0] = group_id; -+ heap_dis[0] = dis; -+ heap_ids[0] = id; -+ (*group_id_to_index_in_heap)[group_id] = 0; -+ group_down_heap( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ group_id_to_index_in_heap, -+ 0); -+} -+ -+template -+static inline void group_heap_replace_at( -+ size_t pos, -+ size_t k, -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ typename C::T dis, -+ typename C::TI id, -+ typename C::TI group_id, -+ std::unordered_map* group_id_to_index_in_heap) { -+ assert(group_id_to_index_in_heap->find(group_id) != -+ group_id_to_index_in_heap->end() && -+ "group id should exist in the binary heap"); -+ assert(group_id_to_index_in_heap->find(group_id)->second == pos && -+ "index of group id in the heap should be same as pos"); -+ -+ heap_dis[pos] = dis; -+ heap_ids[pos] = id; -+ group_up_heap( -+ heap_dis, heap_ids, heap_group_ids, group_id_to_index_in_heap, pos); -+ group_down_heap( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ group_id_to_index_in_heap, -+ pos); -+} -+ -+} // namespace faiss -\ No newline at end of file -diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt -index cc0a4f4c..96e19328 100644 ---- a/tests/CMakeLists.txt -+++ b/tests/CMakeLists.txt -@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC - test_approx_topk.cpp - test_RCQ_cropping.cpp - test_distances_simd.cpp -+ test_id_grouper.cpp -+ test_group_heap.cpp - test_heap.cpp - test_code_distance.cpp - test_hnsw.cpp -diff --git a/tests/test_group_heap.cpp b/tests/test_group_heap.cpp -new file mode 100644 -index 00000000..0e8fe7a7 ---- /dev/null -+++ b/tests/test_group_heap.cpp -@@ -0,0 +1,98 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+#include -+#include -+#include -+#include -+ -+using namespace faiss; -+ -+TEST(GroupHeap, group_heap_replace_top) { -+ using C = CMax; -+ const int k = 100; -+ float binary_heap_values[k]; -+ int64_t binary_heap_ids[k]; -+ heap_heapify(k, binary_heap_values, binary_heap_ids); -+ int64_t binary_heap_group_ids[k]; -+ for (size_t i = 0; i < k; i++) { -+ binary_heap_group_ids[i] = -1; -+ } -+ std::unordered_map group_id_to_index_in_heap; -+ for (int i = 1000; i > 0; i--) { -+ group_heap_replace_top( -+ k, -+ binary_heap_values, -+ binary_heap_ids, -+ binary_heap_group_ids, -+ i * 10.0, -+ i, -+ i, -+ &group_id_to_index_in_heap); -+ } -+ -+ heap_reorder(k, binary_heap_values, binary_heap_ids); -+ -+ for (int i = 0; i < k; i++) { -+ ASSERT_EQ((i + 1) * 10.0, binary_heap_values[i]); -+ ASSERT_EQ(i + 1, binary_heap_ids[i]); -+ } -+} -+ -+TEST(GroupHeap, group_heap_replace_at) { -+ using C = CMax; -+ const int k = 10; -+ float binary_heap_values[k]; -+ int64_t binary_heap_ids[k]; -+ heap_heapify(k, binary_heap_values, binary_heap_ids); -+ int64_t binary_heap_group_ids[k]; -+ for (size_t i = 0; i < k; i++) { -+ binary_heap_group_ids[i] = -1; -+ } -+ std::unordered_map group_id_to_index_in_heap; -+ -+ std::unordered_map group_id_to_id; -+ for (int i = 1000; i > 0; i--) { -+ int64_t group_id = rand() % 100; -+ group_id_to_id[group_id] = i; -+ if (group_id_to_index_in_heap.find(group_id) == -+ group_id_to_index_in_heap.end()) { -+ group_heap_replace_top( -+ k, -+ binary_heap_values, -+ binary_heap_ids, -+ binary_heap_group_ids, -+ i * 10.0, -+ i, -+ group_id, -+ &group_id_to_index_in_heap); -+ } else { -+ group_heap_replace_at( -+ group_id_to_index_in_heap.at(group_id), -+ k, -+ binary_heap_values, -+ binary_heap_ids, -+ binary_heap_group_ids, -+ i * 10.0, -+ i, -+ group_id, -+ &group_id_to_index_in_heap); -+ } -+ } -+ -+ heap_reorder(k, binary_heap_values, binary_heap_ids); -+ -+ std::vector sorted_ids; -+ for (const auto& pair : group_id_to_id) { -+ sorted_ids.push_back(pair.second); -+ } -+ std::sort(sorted_ids.begin(), sorted_ids.end()); -+ -+ for (int i = 0; i < k && binary_heap_ids[i] != -1; i++) { -+ ASSERT_EQ(sorted_ids[i] * 10.0, binary_heap_values[i]); -+ ASSERT_EQ(sorted_ids[i], binary_heap_ids[i]); -+ } -+} -diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp -new file mode 100644 -index 00000000..2aed5500 ---- /dev/null -+++ b/tests/test_id_grouper.cpp -@@ -0,0 +1,189 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+#include -+#include -+#include -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+// 64-bit int -+using idx_t = faiss::idx_t; -+ -+using namespace faiss; -+ -+TEST(IdGrouper, get_group) { -+ uint64_t ids1[1] = {0b1000100010001000}; -+ IDGrouperBitmap bitmap(1, ids1); -+ -+ ASSERT_EQ(3, bitmap.get_group(0)); -+ ASSERT_EQ(3, bitmap.get_group(1)); -+ ASSERT_EQ(3, bitmap.get_group(2)); -+ ASSERT_EQ(3, bitmap.get_group(3)); -+ ASSERT_EQ(7, bitmap.get_group(4)); -+ ASSERT_EQ(7, bitmap.get_group(5)); -+ ASSERT_EQ(7, bitmap.get_group(6)); -+ ASSERT_EQ(7, bitmap.get_group(7)); -+ ASSERT_EQ(11, bitmap.get_group(8)); -+ ASSERT_EQ(11, bitmap.get_group(9)); -+ ASSERT_EQ(11, bitmap.get_group(10)); -+ ASSERT_EQ(11, bitmap.get_group(11)); -+ ASSERT_EQ(15, bitmap.get_group(12)); -+ ASSERT_EQ(15, bitmap.get_group(13)); -+ ASSERT_EQ(15, bitmap.get_group(14)); -+ ASSERT_EQ(15, bitmap.get_group(15)); -+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(16)); -+} -+ -+TEST(IdGrouper, set_group) { -+ idx_t group_ids[] = {64, 127, 128, 1022}; -+ uint64_t ids[16] = {}; // 1023 / 64 + 1 -+ IDGrouperBitmap bitmap(16, ids); -+ -+ for (int i = 0; i < 4; i++) { -+ bitmap.set_group(group_ids[i]); -+ } -+ -+ int group_id_index = 0; -+ for (int i = 0; i <= group_ids[3]; i++) { -+ ASSERT_EQ(group_ids[group_id_index], bitmap.get_group(i)); -+ if (group_ids[group_id_index] == i) { -+ group_id_index++; -+ } -+ } -+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1)); -+} -+ -+TEST(IdGrouper, bitmap_with_hnsw) { -+ int d = 1; // dimension -+ int nb = 10; // database size -+ -+ std::mt19937 rng; -+ std::uniform_real_distribution<> distrib; -+ -+ float* xb = new float[d * nb]; -+ -+ for (int i = 0; i < nb; i++) { -+ for (int j = 0; j < d; j++) -+ xb[d * i + j] = distrib(rng); -+ xb[d * i] += i / 1000.; -+ } -+ -+ uint64_t bitmap[1] = {}; -+ faiss::IDGrouperBitmap id_grouper(1, bitmap); -+ for (int i = 0; i < nb; i++) { -+ if (i % 2 == 1) { -+ id_grouper.set_group(i); -+ } -+ } -+ -+ int k = 10; -+ int m = 8; -+ faiss::Index* index = -+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); -+ index->add(nb, xb); // add vectors to the index -+ -+ // search -+ idx_t* I = new idx_t[k]; -+ float* D = new float[k]; -+ -+ auto pSearchParameters = new faiss::SearchParametersHNSW(); -+ pSearchParameters->grp = &id_grouper; -+ -+ index->search(1, xb, k, D, I, pSearchParameters); -+ -+ std::unordered_set group_ids; -+ ASSERT_EQ(0, I[0]); -+ ASSERT_EQ(0, D[0]); -+ group_ids.insert(id_grouper.get_group(I[0])); -+ for (int j = 1; j < 5; j++) { -+ ASSERT_NE(-1, I[j]); -+ ASSERT_NE(std::numeric_limits::max(), D[j]); -+ group_ids.insert(id_grouper.get_group(I[j])); -+ } -+ for (int j = 5; j < k; j++) { -+ ASSERT_EQ(-1, I[j]); -+ ASSERT_EQ(std::numeric_limits::max(), D[j]); -+ } -+ ASSERT_EQ(5, group_ids.size()); -+ -+ delete[] I; -+ delete[] D; -+ delete[] xb; -+} -+ -+TEST(IdGrouper, bitmap_with_hnswn_idmap) { -+ int d = 1; // dimension -+ int nb = 10; // database size -+ -+ std::mt19937 rng; -+ std::uniform_real_distribution<> distrib; -+ -+ float* xb = new float[d * nb]; -+ idx_t* xids = new idx_t[d * nb]; -+ -+ for (int i = 0; i < nb; i++) { -+ for (int j = 0; j < d; j++) -+ xb[d * i + j] = distrib(rng); -+ xb[d * i] += i / 1000.; -+ } -+ -+ uint64_t bitmap[1] = {}; -+ faiss::IDGrouperBitmap id_grouper(1, bitmap); -+ int num_grp = 0; -+ int grp_size = 2; -+ int id_in_grp = 0; -+ for (int i = 0; i < nb; i++) { -+ xids[i] = i + num_grp; -+ id_in_grp++; -+ if (id_in_grp == grp_size) { -+ id_grouper.set_group(i + num_grp + 1); -+ num_grp++; -+ id_in_grp = 0; -+ } -+ } -+ -+ int k = 10; -+ int m = 8; -+ faiss::Index* index = -+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); -+ faiss::IndexIDMap id_map = -+ faiss::IndexIDMap(index); // add vectors to the index -+ id_map.add_with_ids(nb, xb, xids); -+ -+ // search -+ idx_t* I = new idx_t[k]; -+ float* D = new float[k]; -+ -+ auto pSearchParameters = new faiss::SearchParametersHNSW(); -+ pSearchParameters->grp = &id_grouper; -+ -+ id_map.search(1, xb, k, D, I, pSearchParameters); -+ -+ std::unordered_set group_ids; -+ ASSERT_EQ(0, I[0]); -+ ASSERT_EQ(0, D[0]); -+ group_ids.insert(id_grouper.get_group(I[0])); -+ for (int j = 1; j < 5; j++) { -+ ASSERT_NE(-1, I[j]); -+ ASSERT_NE(std::numeric_limits::max(), D[j]); -+ group_ids.insert(id_grouper.get_group(I[j])); -+ } -+ for (int j = 5; j < k; j++) { -+ ASSERT_EQ(-1, I[j]); -+ ASSERT_EQ(std::numeric_limits::max(), D[j]); -+ } -+ ASSERT_EQ(5, group_ids.size()); -+ -+ delete[] I; -+ delete[] D; -+ delete[] xb; -+} --- -2.39.3 (Apple Git-145) - diff --git a/jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch b/jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch deleted file mode 100644 index dfc5099aaa..0000000000 --- a/jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch +++ /dev/null @@ -1,512 +0,0 @@ -From c5ca07299b427dedafc738b98bd20f8f286f6783 Mon Sep 17 00:00:00 2001 -From: John Mazanec -Date: Wed, 21 Feb 2024 15:34:15 -0800 -Subject: [PATCH] Enable precomp table to be shared ivfpq - -Changes IVFPQ and IVFPQFastScan indices to be able to share the -precomputed table amongst other instances. Switches var to a pointer and -add necessary functions to set them correctly. - -Adds a tests to validate the behavior. - -Signed-off-by: John Mazanec ---- - faiss/IndexIVFPQ.cpp | 47 +++++++- - faiss/IndexIVFPQ.h | 16 ++- - faiss/IndexIVFPQFastScan.cpp | 47 ++++++-- - faiss/IndexIVFPQFastScan.h | 11 +- - tests/CMakeLists.txt | 1 + - tests/test_disable_pq_sdc_tables.cpp | 4 +- - tests/test_ivfpq_share_table.cpp | 173 +++++++++++++++++++++++++++ - 7 files changed, 284 insertions(+), 15 deletions(-) - create mode 100644 tests/test_ivfpq_share_table.cpp - -diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp -index 0b7f4d05..07bc7e83 100644 ---- a/faiss/IndexIVFPQ.cpp -+++ b/faiss/IndexIVFPQ.cpp -@@ -59,6 +59,29 @@ IndexIVFPQ::IndexIVFPQ( - polysemous_training = nullptr; - do_polysemous_training = false; - polysemous_ht = 0; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; -+} -+ -+IndexIVFPQ::IndexIVFPQ(const IndexIVFPQ& orig) : IndexIVF(orig), pq(orig.pq) { -+ code_size = orig.pq.code_size; -+ invlists->code_size = code_size; -+ is_trained = orig.is_trained; -+ by_residual = orig.by_residual; -+ use_precomputed_table = orig.use_precomputed_table; -+ scan_table_threshold = orig.scan_table_threshold; -+ -+ polysemous_training = orig.polysemous_training; -+ do_polysemous_training = orig.do_polysemous_training; -+ polysemous_ht = orig.polysemous_ht; -+ precomputed_table = new AlignedTable(*orig.precomputed_table); -+ owns_precomputed_table = true; -+} -+ -+IndexIVFPQ::~IndexIVFPQ() { -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } - } - - /**************************************************************** -@@ -466,11 +489,23 @@ void IndexIVFPQ::precompute_table() { - use_precomputed_table, - quantizer, - pq, -- precomputed_table, -+ *precomputed_table, - by_residual, - verbose); - } - -+void IndexIVFPQ::set_precomputed_table( -+ AlignedTable* _precompute_table, -+ int _use_precomputed_table) { -+ // Clean up old pre-computed table -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } -+ owns_precomputed_table = false; -+ precomputed_table = _precompute_table; -+ use_precomputed_table = _use_precomputed_table; -+} -+ - namespace { - - #define TIC t0 = get_cycles() -@@ -650,7 +685,7 @@ struct QueryTables { - - fvec_madd( - pq.M * pq.ksub, -- ivfpq.precomputed_table.data() + key * pq.ksub * pq.M, -+ ivfpq.precomputed_table->data() + key * pq.ksub * pq.M, - -2.0, - sim_table_2, - sim_table); -@@ -679,7 +714,7 @@ struct QueryTables { - k >>= cpq.nbits; - - // get corresponding table -- const float* pc = ivfpq.precomputed_table.data() + -+ const float* pc = ivfpq.precomputed_table->data() + - (ki * pq.M + cm * Mf) * pq.ksub; - - if (polysemous_ht == 0) { -@@ -709,7 +744,7 @@ struct QueryTables { - dis0 = coarse_dis; - - const float* s = -- ivfpq.precomputed_table.data() + key * pq.ksub * pq.M; -+ ivfpq.precomputed_table->data() + key * pq.ksub * pq.M; - for (int m = 0; m < pq.M; m++) { - sim_table_ptrs[m] = s; - s += pq.ksub; -@@ -729,7 +764,7 @@ struct QueryTables { - int ki = k & ((uint64_t(1) << cpq.nbits) - 1); - k >>= cpq.nbits; - -- const float* pc = ivfpq.precomputed_table.data() + -+ const float* pc = ivfpq.precomputed_table->data() + - (ki * pq.M + cm * Mf) * pq.ksub; - - for (int m = m0; m < m0 + Mf; m++) { -@@ -1346,6 +1381,8 @@ IndexIVFPQ::IndexIVFPQ() { - do_polysemous_training = false; - polysemous_ht = 0; - polysemous_training = nullptr; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; - } - - struct CodeCmp { -diff --git a/faiss/IndexIVFPQ.h b/faiss/IndexIVFPQ.h -index d5d21da4..850bbe44 100644 ---- a/faiss/IndexIVFPQ.h -+++ b/faiss/IndexIVFPQ.h -@@ -48,7 +48,8 @@ struct IndexIVFPQ : IndexIVF { - - /// if use_precompute_table - /// size nlist * pq.M * pq.ksub -- AlignedTable precomputed_table; -+ bool owns_precomputed_table; -+ AlignedTable* precomputed_table; - - IndexIVFPQ( - Index* quantizer, -@@ -58,6 +59,10 @@ struct IndexIVFPQ : IndexIVF { - size_t nbits_per_idx, - MetricType metric = METRIC_L2); - -+ IndexIVFPQ(const IndexIVFPQ& orig); -+ -+ ~IndexIVFPQ(); -+ - void encode_vectors( - idx_t n, - const float* x, -@@ -139,6 +144,15 @@ struct IndexIVFPQ : IndexIVF { - /// build precomputed table - void precompute_table(); - -+ /** -+ * Initialize the precomputed table -+ * @param precompute_table -+ * @param _use_precomputed_table -+ */ -+ void set_precomputed_table( -+ AlignedTable* precompute_table, -+ int _use_precomputed_table); -+ - IndexIVFPQ(); - }; - -diff --git a/faiss/IndexIVFPQFastScan.cpp b/faiss/IndexIVFPQFastScan.cpp -index d069db13..09a335ff 100644 ---- a/faiss/IndexIVFPQFastScan.cpp -+++ b/faiss/IndexIVFPQFastScan.cpp -@@ -46,6 +46,8 @@ IndexIVFPQFastScan::IndexIVFPQFastScan( - : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) { - by_residual = false; // set to false by default because it's faster - -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; - init_fastscan(M, nbits, nlist, metric, bbs); - } - -@@ -53,6 +55,17 @@ IndexIVFPQFastScan::IndexIVFPQFastScan() { - by_residual = false; - bbs = 0; - M2 = 0; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; -+} -+ -+IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQFastScan& orig) -+ : IndexIVFFastScan(orig), pq(orig.pq) { -+ by_residual = orig.by_residual; -+ bbs = orig.bbs; -+ M2 = orig.M2; -+ precomputed_table = new AlignedTable(*orig.precomputed_table); -+ owns_precomputed_table = true; - } - - IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) -@@ -71,13 +84,15 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) - ntotal = orig.ntotal; - is_trained = orig.is_trained; - nprobe = orig.nprobe; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; - -- precomputed_table.resize(orig.precomputed_table.size()); -+ precomputed_table->resize(orig.precomputed_table->size()); - -- if (precomputed_table.nbytes() > 0) { -- memcpy(precomputed_table.get(), -- orig.precomputed_table.data(), -- precomputed_table.nbytes()); -+ if (precomputed_table->nbytes() > 0) { -+ memcpy(precomputed_table->get(), -+ orig.precomputed_table->data(), -+ precomputed_table->nbytes()); - } - - for (size_t i = 0; i < nlist; i++) { -@@ -102,6 +117,12 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) - orig_invlists = orig.invlists; - } - -+IndexIVFPQFastScan::~IndexIVFPQFastScan() { -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } -+} -+ - /********************************************************* - * Training - *********************************************************/ -@@ -127,11 +148,23 @@ void IndexIVFPQFastScan::precompute_table() { - use_precomputed_table, - quantizer, - pq, -- precomputed_table, -+ *precomputed_table, - by_residual, - verbose); - } - -+void IndexIVFPQFastScan::set_precomputed_table( -+ AlignedTable* _precompute_table, -+ int _use_precomputed_table) { -+ // Clean up old pre-computed table -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } -+ owns_precomputed_table = false; -+ precomputed_table = _precompute_table; -+ use_precomputed_table = _use_precomputed_table; -+} -+ - /********************************************************* - * Code management functions - *********************************************************/ -@@ -229,7 +262,7 @@ void IndexIVFPQFastScan::compute_LUT( - if (cij >= 0) { - fvec_madd_simd( - dim12, -- precomputed_table.get() + cij * dim12, -+ precomputed_table->get() + cij * dim12, - -2, - ip_table.get() + i * dim12, - tab); -diff --git a/faiss/IndexIVFPQFastScan.h b/faiss/IndexIVFPQFastScan.h -index 00dd2f11..91f35a6e 100644 ---- a/faiss/IndexIVFPQFastScan.h -+++ b/faiss/IndexIVFPQFastScan.h -@@ -38,7 +38,8 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { - /// precomputed tables management - int use_precomputed_table = 0; - /// if use_precompute_table size (nlist, pq.M, pq.ksub) -- AlignedTable precomputed_table; -+ bool owns_precomputed_table; -+ AlignedTable* precomputed_table; - - IndexIVFPQFastScan( - Index* quantizer, -@@ -51,6 +52,10 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { - - IndexIVFPQFastScan(); - -+ IndexIVFPQFastScan(const IndexIVFPQFastScan& orig); -+ -+ ~IndexIVFPQFastScan(); -+ - // built from an IndexIVFPQ - explicit IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs = 32); - -@@ -60,6 +65,10 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { - - /// build precomputed table, possibly updating use_precomputed_table - void precompute_table(); -+ /// Pass in externally a precomputed -+ void set_precomputed_table( -+ AlignedTable* precompute_table, -+ int _use_precomputed_table); - - /// same as the regular IVFPQ encoder. The codes are not reorganized by - /// blocks a that point -diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt -index 9017edc5..0889bf72 100644 ---- a/tests/CMakeLists.txt -+++ b/tests/CMakeLists.txt -@@ -33,6 +33,7 @@ set(FAISS_TEST_SRC - test_partitioning.cpp - test_fastscan_perf.cpp - test_disable_pq_sdc_tables.cpp -+ test_ivfpq_share_table.cpp - ) - - add_executable(faiss_test ${FAISS_TEST_SRC}) -diff --git a/tests/test_disable_pq_sdc_tables.cpp b/tests/test_disable_pq_sdc_tables.cpp -index b211a5c4..a27973d5 100644 ---- a/tests/test_disable_pq_sdc_tables.cpp -+++ b/tests/test_disable_pq_sdc_tables.cpp -@@ -15,7 +15,9 @@ - #include "faiss/index_io.h" - #include "test_util.h" - --pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; -+namespace { -+ pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; -+} - - TEST(IO, TestReadHNSWPQ_whenSDCDisabledFlagPassed_thenDisableSDCTable) { - Tempfilename index_filename(&temp_file_mutex, "/tmp/faiss_TestReadHNSWPQ"); -diff --git a/tests/test_ivfpq_share_table.cpp b/tests/test_ivfpq_share_table.cpp -new file mode 100644 -index 00000000..f827315d ---- /dev/null -+++ b/tests/test_ivfpq_share_table.cpp -@@ -0,0 +1,173 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#include -+ -+#include -+ -+#include "faiss/Index.h" -+#include "faiss/IndexHNSW.h" -+#include "faiss/IndexIVFPQFastScan.h" -+#include "faiss/index_factory.h" -+#include "faiss/index_io.h" -+#include "test_util.h" -+ -+namespace { -+ pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; -+} -+ -+std::vector generate_data( -+ int d, -+ int n, -+ std::default_random_engine rng, -+ std::uniform_real_distribution u) { -+ std::vector vectors(n * d); -+ for (size_t i = 0; i < n * d; i++) { -+ vectors[i] = u(rng); -+ } -+ return vectors; -+} -+ -+void assert_float_vectors_almost_equal( -+ std::vector a, -+ std::vector b) { -+ float margin = 0.000001; -+ ASSERT_EQ(a.size(), b.size()); -+ for (int i = 0; i < a.size(); i++) { -+ ASSERT_NEAR(a[i], b[i], margin); -+ } -+} -+ -+/// Test case test precomputed table sharing for IVFPQ indices. -+template /// T represents class cast to use for index -+void test_ivfpq_table_sharing( -+ const std::string& index_description, -+ const std::string& filename, -+ faiss::MetricType metric) { -+ // Setup the index: -+ // 1. Build an index -+ // 2. ingest random data -+ // 3. serialize to disk -+ int d = 32, n = 1000; -+ std::default_random_engine rng( -+ std::chrono::system_clock::now().time_since_epoch().count()); -+ std::uniform_real_distribution u(0, 100); -+ -+ std::vector index_vectors = generate_data(d, n, rng, u); -+ std::vector query_vectors = generate_data(d, n, rng, u); -+ -+ Tempfilename index_filename(&temp_file_mutex, filename); -+ { -+ std::unique_ptr index_writer( -+ faiss::index_factory(d, index_description.c_str(), metric)); -+ -+ index_writer->train(n, index_vectors.data()); -+ index_writer->add(n, index_vectors.data()); -+ faiss::write_index(index_writer.get(), index_filename.c_str()); -+ } -+ -+ // Load index from disk. Confirm that the sdc table is equal to 0 when -+ // disable sdc is set -+ std::unique_ptr> sharedAlignedTable( -+ new faiss::AlignedTable()); -+ int shared_use_precomputed_table = 0; -+ int k = 10; -+ std::vector distances_test_a(k * n); -+ std::vector labels_test_a(k * n); -+ { -+ std::vector distances_baseline(k * n); -+ std::vector labels_baseline(k * n); -+ -+ std::unique_ptr index_read_pq_table_enabled( -+ dynamic_cast(faiss::read_index( -+ index_filename.c_str(), faiss::IO_FLAG_READ_ONLY))); -+ std::unique_ptr index_read_pq_table_disabled( -+ dynamic_cast(faiss::read_index( -+ index_filename.c_str(), -+ faiss::IO_FLAG_READ_ONLY | -+ faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE))); -+ faiss::initialize_IVFPQ_precomputed_table( -+ shared_use_precomputed_table, -+ index_read_pq_table_disabled->quantizer, -+ index_read_pq_table_disabled->pq, -+ *sharedAlignedTable, -+ index_read_pq_table_disabled->by_residual, -+ index_read_pq_table_disabled->verbose); -+ index_read_pq_table_disabled->set_precomputed_table( -+ sharedAlignedTable.get(), shared_use_precomputed_table); -+ -+ ASSERT_TRUE(index_read_pq_table_enabled->owns_precomputed_table); -+ ASSERT_FALSE(index_read_pq_table_disabled->owns_precomputed_table); -+ index_read_pq_table_enabled->search( -+ n, -+ query_vectors.data(), -+ k, -+ distances_baseline.data(), -+ labels_baseline.data()); -+ index_read_pq_table_disabled->search( -+ n, -+ query_vectors.data(), -+ k, -+ distances_test_a.data(), -+ labels_test_a.data()); -+ -+ assert_float_vectors_almost_equal(distances_baseline, distances_test_a); -+ ASSERT_EQ(labels_baseline, labels_test_a); -+ } -+ -+ // The precomputed table should only be set for L2 metric type -+ if (metric == faiss::METRIC_L2) { -+ ASSERT_EQ(shared_use_precomputed_table, 1); -+ } else { -+ ASSERT_EQ(shared_use_precomputed_table, 0); -+ } -+ -+ // At this point, the original has gone out of scope, the destructor has -+ // been called. Confirm that initializing a new index from the table -+ // preserves the functionality. -+ { -+ std::vector distances_test_b(k * n); -+ std::vector labels_test_b(k * n); -+ -+ std::unique_ptr index_read_pq_table_disabled( -+ dynamic_cast(faiss::read_index( -+ index_filename.c_str(), -+ faiss::IO_FLAG_READ_ONLY | -+ faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE))); -+ index_read_pq_table_disabled->set_precomputed_table( -+ sharedAlignedTable.get(), shared_use_precomputed_table); -+ ASSERT_FALSE(index_read_pq_table_disabled->owns_precomputed_table); -+ index_read_pq_table_disabled->search( -+ n, -+ query_vectors.data(), -+ k, -+ distances_test_b.data(), -+ labels_test_b.data()); -+ assert_float_vectors_almost_equal(distances_test_a, distances_test_b); -+ ASSERT_EQ(labels_test_a, labels_test_b); -+ } -+} -+ -+TEST(TestIVFPQTableSharing, L2) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4", "/tmp/ivfpql2", faiss::METRIC_L2); -+} -+ -+TEST(TestIVFPQTableSharing, IP) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4", "/tmp/ivfpqip", faiss::METRIC_INNER_PRODUCT); -+} -+ -+TEST(TestIVFPQTableSharing, FastScanL2) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4fsr", "/tmp/ivfpqfsl2", faiss::METRIC_L2); -+} -+ -+TEST(TestIVFPQTableSharing, FastScanIP) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4fsr", "/tmp/ivfpqfsip", faiss::METRIC_INNER_PRODUCT); -+} --- -2.39.3 (Apple Git-145) - diff --git a/jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch b/jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch deleted file mode 100644 index a9d9381f9b..0000000000 --- a/jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch +++ /dev/null @@ -1,31 +0,0 @@ -From aa1ca485c0ab8b79dae1fb5c1567149c5f61b533 Mon Sep 17 00:00:00 2001 -From: John Mazanec -Date: Thu, 14 Mar 2024 12:22:06 -0700 -Subject: [PATCH] Initialize maxlevel during add from enterpoint->level - -Signed-off-by: John Mazanec ---- - similarity_search/src/method/hnsw.cc | 6 +++++- - 1 file changed, 5 insertions(+), 1 deletion(-) - -diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc -index 35b372c..e9a725e 100644 ---- a/similarity_search/src/method/hnsw.cc -+++ b/similarity_search/src/method/hnsw.cc -@@ -542,8 +542,12 @@ namespace similarity { - - NewElement->init(curlevel, maxM_, maxM0_); - -- int maxlevelcopy = maxlevel_; -+ // Get the enterpoint at this moment and then use it to set the -+ // max level that is used. Copying maxlevel from this->maxlevel_ -+ // can lead to race conditions during concurrent insertion. See: -+ // https://github.com/nmslib/nmslib/issues/544 - HnswNode *ep = enterpoint_; -+ int maxlevelcopy = ep->level; - if (curlevel < maxlevelcopy) { - const Object *currObj = ep->getData(); - --- -2.39.3 (Apple Git-146) - diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 817bdb8163..09bbb7eaf8 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -81,6 +81,100 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); +void knn_jni::faiss_wrapper::writeIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,jlong indexAddressJ, jstring indexPathJ, jobject parametersJ) { + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, + parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + auto *idMap = reinterpret_cast((long long)indexAddressJ); + faiss::write_index(idMap, indexPathCpp.c_str()); +} + +long long knn_jni::faiss_wrapper::CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong indexAddressJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Get space type for this index + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int) dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if (numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexIDMap *idMap = nullptr; + long indexAddress = (long) indexAddressJ; + if(indexAddress == 0) { + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + faiss::Index *indexWriter = faiss::index_factory(dim, indexDescriptionCpp.c_str(), metric); + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, + parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + if (parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); + SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter); + jniUtil->DeleteLocalRef(env, subParametersJ); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Check that the index does not need to be trained + if (!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + idMap = new faiss::IndexIDMap(indexWriter); + idMap->add_with_ids(numVectors, inputVectors->data(), idVector.data()); + } else { + idMap = reinterpret_cast(indexAddress); + idMap->add_with_ids(numVectors, inputVectors->data(), idVector.data()); + } + return (long long)idMap; +} + + void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 3249ed8728..fb8af44f4c 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -17,6 +17,7 @@ #include "faiss_wrapper.h" #include "jni_util.h" +#include static knn_jni::JNIUtil jniUtil; static const jint KNN_FAISS_JNI_VERSION = JNI_VERSION_1_1; @@ -50,6 +51,29 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE } } + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexIteratively(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddressJ, jobject parametersJ) +{ + try { + return (jlong)knn_jni::faiss_wrapper::CreateIndexIteratively(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddressJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return 0; +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddressJ, jstring indexPathJ, jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::writeIndex(&jniUtil, env, indexAddressJ, indexPathJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 05854f7edb..cfdc3fa7d1 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -26,7 +26,7 @@ using ::testing::Return; float randomDataMin = -500.0; float randomDataMax = 500.0; -TEST(FaissCreateIndexTest, BasicAssertions) { +TEST(FaissCreateIndexIterativelyTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; @@ -52,6 +52,55 @@ TEST(FaissCreateIndexTest, BasicAssertions) { JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaObjectArrayLength( + jniEnv, reinterpret_cast(&vectors))) + .WillRepeatedly(Return(vectors->size())); + + // Create the index + long long indexAddress = knn_jni::faiss_wrapper::CreateIndexIteratively( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong) vectors, dim, (jlong)0, + (jobject)¶metersMap); + std::cout<<"Index address is "< index(test_util::FaissLoadIndex(indexPath)); + + // Clean up + std::remove(indexPath.c_str()); +} + + +TEST(FaissCreateIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + auto *vectors = new std::vector(); + int dim = 2; + vectors->reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + std::string indexPath = test_util::RandomString(10, "/tmp/", ".faiss"); + std::string spaceType = knn_jni::L2; + std::string index_description = "HNSW32,Flat"; + + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&index_description; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 096df817a7..1f19b854fa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -8,19 +8,23 @@ import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.util.BytesRef; import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.StopWatch; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; -import org.opensearch.knn.plugin.stats.KNNCounter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; @@ -36,6 +40,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import java.io.ByteArrayInputStream; import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; @@ -46,14 +51,15 @@ import java.nio.file.StandardOpenOption; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; /** * This class writes the KNN docvalues to the segments @@ -108,19 +114,11 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, throws IOException { // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); - if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); + if (values == null) { + log.info("BinaryDocValues is null. Returning.."); return; } - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); - if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - } - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); + // Get the KNN engine final KNNEngine knnEngine = getKNNEngine(field); final String engineFileName = buildEngineFileName( state.segmentInfo.name, @@ -132,35 +130,98 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName ).toString(); - NativeIndexCreator indexCreator; - // Create library index either from model or from scratch - if (field.attributes().containsKey(MODEL_ID)) { - String modelId = field.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); - } else { - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } + Map parametersMap = getKNNIndexFromScratchParameters(field, knnEngine); - if (isMerge) { - recordMergeStats(pair.docs.length, arraySize); - } + long indexAddress = createIndex(values, knnEngine, parametersMap); - if (isRefresh) { - recordRefreshStats(); + if (indexAddress == 0) { + log.info("Index is not created. Returning.."); } - // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that - // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will - // not be marked as added to the directory. state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); + // Now we can write the index + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.writeIndex(indexAddress, parametersMap, indexPath, knnEngine); + return null; + }); writeFooter(indexPath, engineFileName); } + private long createIndex(BinaryDocValues values, final KNNEngine knnEngine, final Map parametersMap) + throws IOException { + List vectorList = new ArrayList<>(); + List docIdList = new ArrayList<>(); + int dimension = 0; + SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; + + long totalLiveDocs = KNNCodecUtil.getTotalLiveDocsCount(values); + long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes(); + long vectorsPerTransfer = Integer.MIN_VALUE; + Long indexAddress = 0L; + + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + BytesRef bytesref = values.binaryValue(); + try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + dimension = vector.length; + + if (vectorsPerTransfer == Integer.MIN_VALUE) { + vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + if (vectorsPerTransfer == 0) { + vectorsPerTransfer = totalLiveDocs; + } + } + if (vectorList.size() == vectorsPerTransfer) { + final long vectorAddress = JNICommons.storeVectorData( + 0, + vectorList.toArray(new float[][] {}), + (long) vectorList.size() * dimension + ); + List docIdList2 = docIdList; + int finalDimension = dimension; + long indexAddress2 = indexAddress; + indexAddress = AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.buildIndex( + docIdList2.stream().mapToInt(Integer::intValue).toArray(), + vectorAddress, + indexAddress2, + finalDimension, + parametersMap, + knnEngine + ) + ); + + // We should probably come up with a better way to reuse the vectorList memory which we have + // created. Problem here is doing like this can lead to a lot of list memory which is of no use and + // will be garbage collected later on, but it creates pressure on JVM. We should revisit this. + vectorList = new ArrayList<>(); + docIdList = new ArrayList<>(); + JNICommons.freeVectorData(vectorAddress); + } + vectorList.add(vector); + } + docIdList.add(doc); + } + if (vectorList.isEmpty() == false) { + long vectorAddress = JNICommons.storeVectorData(0, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + List docIdList2 = docIdList; + int finalDimension = dimension; + long indexAddress2 = indexAddress; + indexAddress = AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.buildIndex( + docIdList2.stream().mapToInt(Integer::intValue).toArray(), + vectorAddress, + indexAddress2, + finalDimension, + parametersMap, + knnEngine + ) + ); + } + return indexAddress; + } + private void recordMergeStats(int length, long arraySize) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); @@ -193,8 +254,7 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN }); } - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) - throws IOException { + private Map getKNNIndexFromScratchParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); @@ -225,11 +285,7 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa // Used to determine how many threads to use when indexing parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - // Pass the path for the nms library to save the file - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); - return null; - }); + return parameters; } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index e059626081..8eb82dc286 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -127,7 +127,7 @@ public static String buildEngineFileSuffix(String fieldName, String extension) { return String.format("_%s%s", fieldName, extension); } - private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { + public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { long totalLiveDocs; if (binaryDocValues instanceof KNN80BinaryDocValues) { totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs(); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java index 5c1e4ca9bb..bf2f5c6be9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java @@ -56,7 +56,7 @@ public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayIn return getSerializerBySerializationMode(serializationMode); } - static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { + public static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { int numberOfAvailableBytesInStream = byteStream.available(); if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) { return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS); diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 32516ef9dd..29feb7932f 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -49,6 +49,16 @@ class FaissService { }); } + public static native long createIndexIteratively( + int[] ids, + long vectorsAddress, + int dim, + long indexAddress, + Map parameters + ); + + public static native void writeIndex(long indexAddress, String indexPath, Map parameters); + /** * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 5a5b6794a2..40678e65cf 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -57,6 +57,29 @@ public static void createIndex( throw new IllegalArgumentException(String.format("CreateIndex not supported for provided engine : %s", knnEngine.getName())); } + public static long buildIndex( + int[] ids, + long vectorsAddress, + long indexAddress, + int dim, + Map parameters, + KNNEngine knnEngine + ) { + if (KNNEngine.FAISS == knnEngine) { + return FaissService.createIndexIteratively(ids, vectorsAddress, dim, indexAddress, parameters); + } + throw new IllegalArgumentException(String.format("buildIndex not supported for provided engine : %s", knnEngine.getName())); + + } + + public static void writeIndex(long indexAddress, Map parameters, String indexPath, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + FaissService.writeIndex(indexAddress, indexPath, parameters); + return; + } + throw new IllegalArgumentException(String.format("writeIndex not supported for provided engine : %s", knnEngine.getName())); + } + /** * Create an index for the native library with a provided template index *