Skip to content

Commit 5d13de7

Browse files
authored
Add EMVB search: Part 3 (#1318)
### What problem does this PR solve? Add EMVB search: Part 3 Add EMVBIndex class Issue link:#1179 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
1 parent e821c02 commit 5d13de7

File tree

10 files changed

+598
-44
lines changed

10 files changed

+598
-44
lines changed
+263
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
module;
16+
17+
module emvb_index;
18+
import stl;
19+
import mlas_matrix_multiply;
20+
import vector_distance;
21+
import emvb_product_quantization;
22+
import emvb_search;
23+
import kmeans_partition;
24+
import index_base;
25+
import status;
26+
import logger;
27+
import third_party;
28+
import infinity_exception;
29+
30+
namespace infinity {
31+
32+
extern template class EMVBSharedVec<u32>;
33+
extern template class EMVBSearch<32>;
34+
extern template class EMVBSearch<64>;
35+
extern template class EMVBSearch<96>;
36+
extern template class EMVBSearch<128>;
37+
extern template class EMVBSearch<160>;
38+
extern template class EMVBSearch<192>;
39+
extern template class EMVBSearch<224>;
40+
extern template class EMVBSearch<256>;
41+
42+
EMVBIndex::EMVBIndex(const u32 start_segment_offset,
43+
const u32 embedding_dimension,
44+
const u32 n_centroids,
45+
const u32 residual_pq_subspace_num,
46+
const u32 residual_pq_subspace_bits)
47+
: start_segment_offset_(start_segment_offset), embedding_dimension_(embedding_dimension), n_centroids_(n_centroids),
48+
residual_pq_subspace_num_(residual_pq_subspace_num), residual_pq_subspace_bits_(residual_pq_subspace_bits) {
49+
centroids_data_.resize(n_centroids * embedding_dimension);
50+
centroid_norms_neg_half_.resize(n_centroids);
51+
centroids_to_docid_ = MakeUnique<EMVBSharedVec<u32>[]>(n_centroids);
52+
// now always use OPQ
53+
product_quantizer_ = GetEMVBOPQ(residual_pq_subspace_num, residual_pq_subspace_bits, embedding_dimension);
54+
}
55+
56+
// need embedding num:
57+
// 1. 256 * n_centroids_ for centroids
58+
// 2. 256 * (1 << residual_pq_subspace_bits) for residual product quantizer
59+
u32 EMVBIndex::ExpectLeastTrainingDataNum() const { return std::max<u32>(256 * n_centroids_, 256 * (1 << residual_pq_subspace_bits_)); }
60+
61+
void EMVBIndex::Train(const f32 *embedding_data, const u32 embedding_num, const u32 iter_cnt) {
62+
// check n_centroids_
63+
if (((n_centroids_ % 8) != 0) || (n_centroids_ == 0)) {
64+
const auto error_msg = fmt::format("EMVBIndex::Train: n_centroids_ must be a multiple of 8, got {} instead.", n_centroids_);
65+
LOG_ERROR(error_msg);
66+
UnrecoverableError(error_msg);
67+
}
68+
// check training data num
69+
if (const u32 least_num = ExpectLeastTrainingDataNum(); embedding_num < least_num) {
70+
const auto error_msg = fmt::format("EMVBIndex::Train: embedding_num must be at least {}, got {} instead.", least_num, embedding_num);
71+
LOG_ERROR(error_msg);
72+
UnrecoverableError(error_msg);
73+
}
74+
// train both centroids and residual product quantizer
75+
// step 1. train centroids
76+
{
77+
const auto result_centroid_num = GetKMeansCentroids<f32>(MetricType::kMetricL2,
78+
embedding_dimension_,
79+
embedding_num,
80+
embedding_data,
81+
centroids_data_,
82+
n_centroids_,
83+
iter_cnt);
84+
if (result_centroid_num != n_centroids_) {
85+
const auto error_msg =
86+
fmt::format("EMVBIndex::Train: KMeans failed to get {} centroids, got {} instead.", n_centroids_, result_centroid_num);
87+
LOG_ERROR(error_msg);
88+
UnrecoverableError(error_msg);
89+
}
90+
LOG_TRACE(fmt::format("EMVBIndex::Train: KMeans got {} centroids.", result_centroid_num));
91+
}
92+
{
93+
const f32 *centroid_data = centroids_data_.data();
94+
for (u32 i = 0; i < n_centroids_; ++i) {
95+
centroid_norms_neg_half_[i] = -0.5f * L2NormSquare<f32, f32, u32>(centroid_data, embedding_dimension_);
96+
centroid_data += embedding_dimension_;
97+
}
98+
}
99+
// step 2. get residuals
100+
const auto residuals = MakeUniqueForOverwrite<f32[]>(embedding_num * embedding_dimension_);
101+
{
102+
// distance: for every embedding, e * c - 0.5 * c^2, find max
103+
const auto dist_table = MakeUniqueForOverwrite<f32[]>(embedding_num * n_centroids_);
104+
matrixA_multiply_transpose_matrixB_output_to_C(embedding_data,
105+
centroids_data_.data(),
106+
embedding_num,
107+
n_centroids_,
108+
embedding_dimension_,
109+
dist_table.get());
110+
for (u32 i = 0; i < embedding_num; ++i) {
111+
const f32 *embedding_data_ptr = embedding_data + i * embedding_dimension_;
112+
f32 *output_ptr = residuals.get() + i * embedding_dimension_;
113+
f32 max_neg_distance = std::numeric_limits<f32>::lowest();
114+
u32 max_id = 0;
115+
const f32 *dist_ptr = dist_table.get() + i * n_centroids_;
116+
for (u32 k = 0; k < n_centroids_; ++k) {
117+
if (const f32 neg_distance = dist_ptr[k] + centroid_norms_neg_half_[k]; neg_distance > max_neg_distance) {
118+
max_neg_distance = neg_distance;
119+
max_id = k;
120+
}
121+
}
122+
const f32 *centroids_data_ptr = centroids_data_.data() + max_id * embedding_dimension_;
123+
for (u32 j = 0; j < embedding_dimension_; ++j) {
124+
output_ptr[j] = embedding_data_ptr[j] - centroids_data_ptr[j];
125+
}
126+
}
127+
}
128+
LOG_TRACE("EMVBIndex::Train: Finish calculate residuals.");
129+
// step 3. train residuals
130+
product_quantizer_->Train(residuals.get(), embedding_num, iter_cnt);
131+
LOG_TRACE("EMVBIndex::Train: Finish train pq for residuals.");
132+
}
133+
134+
void EMVBIndex::AddOneDocEmbeddings(const f32 *embedding_data, const u32 embedding_num) {
135+
std::lock_guard lock(append_mutex_);
136+
// only one thread can add doc to the index at the same time
137+
// step 1. doc - embedding info
138+
const u32 old_doc_num = n_docs_;
139+
const u32 old_total_embeddings = n_total_embeddings_;
140+
doc_lens_.PushBack(embedding_num);
141+
doc_offsets_.PushBack(old_total_embeddings);
142+
// step 2. assign to centroids
143+
const auto centroid_id_assignments = MakeUniqueForOverwrite<u32[]>(embedding_num);
144+
const auto residuals = MakeUniqueForOverwrite<f32[]>(embedding_num * embedding_dimension_);
145+
{
146+
// distance: for every embedding, e * c - 0.5 * c^2, find max
147+
const auto dist_table = MakeUniqueForOverwrite<f32[]>(embedding_num * n_centroids_);
148+
matrixA_multiply_transpose_matrixB_output_to_C(embedding_data,
149+
centroids_data_.data(),
150+
embedding_num,
151+
n_centroids_,
152+
embedding_dimension_,
153+
dist_table.get());
154+
for (u32 i = 0; i < embedding_num; ++i) {
155+
const f32 *embedding_data_ptr = embedding_data + i * embedding_dimension_;
156+
f32 *output_ptr = residuals.get() + i * embedding_dimension_;
157+
f32 max_neg_distance = std::numeric_limits<f32>::lowest();
158+
u32 max_id = 0;
159+
const f32 *dist_ptr = dist_table.get() + i * n_centroids_;
160+
for (u32 k = 0; k < n_centroids_; ++k) {
161+
if (const f32 neg_distance = dist_ptr[k] + centroid_norms_neg_half_[k]; neg_distance > max_neg_distance) {
162+
max_neg_distance = neg_distance;
163+
max_id = k;
164+
}
165+
}
166+
centroid_id_assignments[i] = max_id;
167+
const f32 *centroids_data_ptr = centroids_data_.data() + max_id * embedding_dimension_;
168+
for (u32 j = 0; j < embedding_dimension_; ++j) {
169+
output_ptr[j] = embedding_data_ptr[j] - centroids_data_ptr[j];
170+
}
171+
}
172+
}
173+
centroid_id_assignments_.PushBack(centroid_id_assignments.get(), centroid_id_assignments.get() + embedding_num);
174+
for (u32 i = 0; i < embedding_num; ++i) {
175+
const u32 centroid_id = centroid_id_assignments[i];
176+
centroids_to_docid_[centroid_id].PushBackIfDifferentFromLast(old_doc_num);
177+
}
178+
// step 3. add residuals to product quantizer
179+
product_quantizer_->AddEmbeddings(residuals.get(), embedding_num);
180+
// finally, update count
181+
n_total_embeddings_ += embedding_num;
182+
++n_docs_;
183+
}
184+
185+
// the two thresholds are for every (query embedding, candidate embedding) pair
186+
// candidate embeddings are centroids
187+
// unqualified pairs will not be scored
188+
// but if nothing is left, an exhaustive search will be performed
189+
190+
constexpr u32 current_max_query_token_num = 256;
191+
192+
EMVBQueryResultType EMVBIndex::GetQueryResult(const f32 *query_ptr,
193+
const u32 query_embedding_num,
194+
const u32 centroid_nprobe, // step 1, centroid candidates for every query embedding
195+
const f32 threshold_first, // step 1, threshold for query - centroid score
196+
const u32 n_doc_to_score, // topn by centroids hit count
197+
const u32 out_second_stage, // step 2, topn, use nearest centroid score as embedding score
198+
const u32 top_k, // step 3, final topk, refine score by residual pq
199+
const f32 threshold_final // step 3, threshold to reduce maxsim calculation
200+
) const {
201+
// template argument should be in ascending order
202+
// keep consistent with emvb_search.cpp
203+
return query_token_num_helper<32, 64, 96, 128, 160, 192, 224, 256>(query_ptr,
204+
query_embedding_num,
205+
centroid_nprobe,
206+
threshold_first,
207+
n_doc_to_score,
208+
out_second_stage,
209+
top_k,
210+
threshold_final);
211+
}
212+
213+
template <u32 I, u32... J>
214+
EMVBQueryResultType EMVBIndex::query_token_num_helper(const f32 *query_ptr, u32 query_embedding_num, auto... query_args) const {
215+
if (query_embedding_num <= I) {
216+
return GetQueryResultT<I>(query_ptr, query_embedding_num, query_args...);
217+
}
218+
return query_token_num_helper<J...>(query_ptr, query_embedding_num, query_args...);
219+
}
220+
221+
template <>
222+
EMVBQueryResultType EMVBIndex::query_token_num_helper(const f32 *query_ptr, u32 query_embedding_num, auto... query_args) const {
223+
auto error_msg = fmt::format("EMVBIndex::GetQueryResult: query_embedding_num max value: {}, got {} instead.",
224+
current_max_query_token_num,
225+
query_embedding_num);
226+
error_msg += fmt::format(" Embeddings after {} will not be used for search.", current_max_query_token_num);
227+
error_msg += " Please Add instantiation of EMVBSearch with a bigger FIXED_QUERY_TOKEN_NUM value.";
228+
LOG_ERROR(error_msg);
229+
return GetQueryResultT<current_max_query_token_num>(query_ptr, query_embedding_num, query_args...);
230+
}
231+
232+
template <u32 FIXED_QUERY_TOKEN_NUM>
233+
EMVBQueryResultType EMVBIndex::GetQueryResultT(const f32 *query_ptr, const u32 query_embedding_num, auto... query_args) const {
234+
UniquePtr<f32[]> extended_query_ptr;
235+
const f32 *query_ptr_to_use = query_ptr;
236+
// extend query to FIXED_QUERY_TOKEN_NUM
237+
if (query_embedding_num < FIXED_QUERY_TOKEN_NUM) {
238+
extended_query_ptr = MakeUniqueForOverwrite<f32[]>(FIXED_QUERY_TOKEN_NUM * embedding_dimension_);
239+
std::copy_n(query_ptr, query_embedding_num * embedding_dimension_, extended_query_ptr.get());
240+
std::fill_n(extended_query_ptr.get() + query_embedding_num * embedding_dimension_,
241+
(FIXED_QUERY_TOKEN_NUM - query_embedding_num) * embedding_dimension_,
242+
0.0f);
243+
query_ptr_to_use = extended_query_ptr.get();
244+
}
245+
// access snapshot of index data
246+
const u32 n_docs = n_docs_.load(std::memory_order_acquire);
247+
const auto doc_lens_snapshot = doc_lens_.GetData();
248+
const auto doc_offsets_snapshot = doc_offsets_.GetData();
249+
const auto centroid_id_assignments_snapshot = centroid_id_assignments_.GetData();
250+
// execute search
251+
EMVBSearch<FIXED_QUERY_TOKEN_NUM> search_helper(embedding_dimension_,
252+
n_docs,
253+
n_centroids_,
254+
doc_lens_snapshot.first.get(),
255+
doc_offsets_snapshot.first.get(),
256+
centroid_id_assignments_snapshot.first.get(),
257+
centroids_data_.data(),
258+
centroids_to_docid_.get(),
259+
product_quantizer_.get());
260+
return search_helper.GetQueryResult(query_ptr_to_use, query_args...);
261+
}
262+
263+
} // namespace infinity
+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
module;
16+
17+
export module emvb_index;
18+
import stl;
19+
import emvb_shared_vec;
20+
21+
namespace infinity {
22+
23+
extern template class EMVBSharedVec<u32>;
24+
class EMVBProductQuantizer;
25+
26+
using EMVBQueryResultType = Tuple<u32, UniquePtr<f32[]>, UniquePtr<u32[]>>;
27+
28+
// created with fixed embedding dimension and number of centroids
29+
export class EMVBIndex {
30+
const u32 start_segment_offset_ = 0; // start offset of the index in the segment
31+
const u32 embedding_dimension_ = 0; // dimension of the embeddings
32+
const u32 n_centroids_ = 0; // number of centroids, need to be a multiple of 8
33+
const u32 residual_pq_subspace_num_ = 0; // number of subspaces in the residual product quantizer
34+
const u32 residual_pq_subspace_bits_ = 0; // number of bits for each centroid representation in the residual product quantizer
35+
Vector<f32> centroids_data_; // centroids data
36+
Vector<f32> centroid_norms_neg_half_; // (-0.5 * norm) for each centroid
37+
atomic_u32 n_docs_ = 0; // number of documents in the entire collection
38+
u32 n_total_embeddings_ = 0; // number of embeddings in the entire collection
39+
EMVBSharedVec<u32> doc_lens_; // array of document lengths
40+
EMVBSharedVec<u32> doc_offsets_; // start offsets of each document in all the embeddings
41+
EMVBSharedVec<u32> centroid_id_assignments_; // centroid id assignments for each embedding
42+
UniquePtr<EMVBSharedVec<u32>[]> centroids_to_docid_; // docids belonging to each centroid
43+
UniquePtr<EMVBProductQuantizer> product_quantizer_; // product quantizer for residuals of the embeddings
44+
std::mutex append_mutex_; // mutex for append all embeddings for one doc
45+
46+
public:
47+
EMVBIndex(u32 start_segment_offset, u32 embedding_dimension, u32 n_centroids, u32 residual_pq_subspace_num, u32 residual_pq_subspace_bits);
48+
49+
[[nodiscard]] u32 ExpectLeastTrainingDataNum() const;
50+
51+
void Train(const f32 *embedding_data, u32 embedding_num, u32 iter_cnt = 20);
52+
53+
void AddOneDocEmbeddings(const f32 *embedding_data, u32 embedding_num);
54+
55+
// the two thresholds are for every (query embedding, candidate embedding) pair
56+
// candidate embeddings are centroids
57+
// unqualified pairs will not be scored
58+
// but if nothing is left, an exhaustive search will be performed
59+
EMVBQueryResultType GetQueryResult(const f32 *query_ptr,
60+
u32 query_embedding_num,
61+
u32 centroid_nprobe, // step 1, centroid candidates for every query embedding
62+
f32 threshold_first, // step 1, threshold for query - centroid score
63+
u32 n_doc_to_score, // topn by centroids hit count
64+
u32 out_second_stage, // step 2, topn, use nearest centroid score as embedding score
65+
u32 top_k, // step 3, final topk, refine score by residual pq
66+
f32 threshold_final // step 3, threshold to reduce maxsim calculation
67+
) const;
68+
69+
private:
70+
template <u32 I, u32... J>
71+
EMVBQueryResultType query_token_num_helper(const f32 *query_ptr, u32 query_embedding_num, auto... query_args) const;
72+
73+
template <>
74+
EMVBQueryResultType query_token_num_helper(const f32 *query_ptr, u32 query_embedding_num, auto... query_args) const;
75+
76+
template <u32 FIXED_QUERY_TOKEN_NUM>
77+
EMVBQueryResultType GetQueryResultT(const f32 *query_ptr, u32 query_embedding_num, auto... query_args) const;
78+
};
79+
80+
} // namespace infinity

0 commit comments

Comments
 (0)