Skip to content

Commit 3be612f

Browse files
authored
Add EMVB search: Part 1 (#1305)
### What problem does this PR solve?

Add EMVB search: Part 1. The product quantizer is still missing in this part.

Issue link: #1179

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Test cases
1 parent f9503a2 commit 3be612f

File tree

10 files changed

+1053
-5
lines changed

10 files changed

+1053
-5
lines changed

.github/workflows/tests.yml

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ jobs:
6262
run: sudo docker exec infinity_build bash -c "mkdir -p /var/infinity && cd /infinity/ && cmake-build-debug/src/test_main > unittest_debug.log 2>&1"
6363

6464
- name: Collect infinity unit test debug output
65+
if: ${{ !cancelled() }} # always run this step even if previous steps failed
6566
run: cat unittest_debug.log 2>/dev/null || true
6667

6768
- name: Install pysdk
@@ -159,6 +160,7 @@ jobs:
159160
run: sudo docker exec infinity_build bash -c "mkdir -p /var/infinity && cd /infinity/ && cmake-build-release/src/test_main > unittest_release.log 2>&1"
160161

161162
- name: Collect infinity unit test release output
163+
if: ${{ !cancelled() }} # always run this step even if previous steps failed
162164
run: cat unittest_release.log 2>/dev/null || true
163165

164166
- name: Install pysdk

src/common/stl.cppm

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module;
1818
#include <algorithm>
1919
#include <atomic>
2020
#include <bit>
21+
#include <bitset>
2122
#include <cassert>
2223
#include <charconv>
2324
#include <chrono>
@@ -112,6 +113,7 @@ export namespace std {
112113
using std::try_to_lock;
113114

114115
using std::accumulate;
116+
using std::bitset;
115117
using std::binary_search;
116118
using std::ceil;
117119
using std::copy_n;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
module;
16+
17+
export module emvb_result_handler;
18+
import stl;
19+
import infinity_exception;
20+
21+
namespace infinity {
22+
23+
// Comparison policy for the EMVB result handlers.
// - EMVB needs move-only ID support.
// - EMVB only uses a min heap, so this is the only policy defined here.
template <typename DistType, typename ID>
struct EMVBCompareMin {
    using DistanceType = DistType;
    using IDType = ID;

    // Strict ordering test: true exactly when lhs < rhs.
    inline static bool Compare(DistanceType lhs, DistanceType rhs) { return lhs < rhs; }

    // std::numeric_limits<DistanceType>::lowest().
    static constexpr DistanceType InitialValue() { return std::numeric_limits<DistanceType>::lowest(); }

    // Reversed policy; only its initial value is provided.
    struct CompareReverse {
        // std::numeric_limits<DistanceType>::max().
        static constexpr DistanceType InitialValue() { return std::numeric_limits<DistanceType>::max(); }
    };
};
35+
36+
template <typename Compare>
37+
inline void HeapifyDown(typename Compare::DistanceType *distance, typename Compare::IDType *id, const u32 size, u32 index) {
38+
if (index == 0 || (index << 1) > size) {
39+
return;
40+
}
41+
auto tmp_d = distance[index];
42+
auto tmp_i = std::move(id[index]);
43+
for (u32 sub; (sub = (index << 1)) <= size; index = sub) {
44+
if (sub + 1 <= size && Compare::Compare(distance[sub + 1], distance[sub])) {
45+
++sub;
46+
}
47+
if (!Compare::Compare(distance[sub], tmp_d)) {
48+
break;
49+
}
50+
distance[index] = distance[sub];
51+
id[index] = std::move(id[sub]);
52+
}
53+
distance[index] = tmp_d;
54+
id[index] = std::move(tmp_i);
55+
}
56+
57+
// Returns the median of three values (arguments are taken by value).
template <typename DistType>
inline DistType median3(DistType a, DistType b, DistType c) {
    // Put the first two in order so that a is not greater than b.
    if (a > b) {
        std::swap(a, b);
    }
    // c is at least b: b is the middle value.
    if (b <= c) {
        return b;
    }
    // c is below b: the middle value is the larger of a and c.
    return std::max(a, c);
}
64+
65+
template <typename Compare, typename DistType>
66+
inline void count_lt_and_eq(const DistType *vals, const u32 n, DistType thresh, u32 &n_lt, u32 &n_eq) {
67+
n_lt = n_eq = 0;
68+
for (SizeT i = 0; i < n; ++i) {
69+
auto v = *(vals++);
70+
if (Compare::Compare(thresh, v)) {
71+
n_lt++;
72+
} else if (v == thresh) {
73+
n_eq++;
74+
}
75+
}
76+
}
77+
78+
template <typename Compare, typename DistType>
79+
inline DistType sample_threshold_median3(const DistType *vals, const u32 n, DistType thresh_inf, DistType thresh_sup) {
80+
DistType val3[3];
81+
u32 vi = 0;
82+
for (u32 i = 0; i < n; ++i) {
83+
DistType v = vals[(i * 6700417ull) % n];
84+
if (Compare::Compare(v, thresh_inf) && Compare::Compare(thresh_sup, v)) {
85+
val3[vi++] = v;
86+
if (vi == 3) {
87+
break;
88+
}
89+
}
90+
}
91+
if (vi == 3) {
92+
return median3(val3[0], val3[1], val3[2]);
93+
}
94+
if (vi != 0) {
95+
return val3[0];
96+
}
97+
for (SizeT i = 0; i < n; i++) {
98+
if (const DistType v = vals[i]; Compare::Compare(v, thresh_inf) && Compare::Compare(thresh_sup, v)) {
99+
return v;
100+
}
101+
}
102+
return thresh_inf;
103+
}
104+
105+
template <typename Compare, typename DistType, typename ID>
106+
inline u32 compress_array(DistType *vals, ID *ids, const u32 n, DistType thresh, u32 n_eq) {
107+
u32 wp = 0;
108+
for (u32 i = 0; i < n; ++i) {
109+
if (Compare::Compare(thresh, vals[i])) {
110+
if (wp != i) {
111+
vals[wp] = vals[i];
112+
ids[wp] = std::move(ids[i]);
113+
}
114+
++wp;
115+
} else if (n_eq > 0 && vals[i] == thresh) {
116+
if (wp != i) {
117+
vals[wp] = vals[i];
118+
ids[wp] = std::move(ids[i]);
119+
}
120+
++wp;
121+
--n_eq;
122+
}
123+
}
124+
if (n_eq != 0) {
125+
UnrecoverableError("compress_array error: n_eq != 0");
126+
}
127+
return wp;
128+
}
129+
130+
template <typename Compare, typename DistType, typename ID>
131+
inline DistType partition_median3(DistType *vals, ID *ids, const u32 n, const u32 q_min, const u32 q_max, u32 &q_out) {
132+
if (n < 3) {
133+
UnrecoverableError("partition_median3 error: n < 3");
134+
}
135+
DistType thresh_inf = Compare::CompareReverse::InitialValue();
136+
DistType thresh_sup = Compare::InitialValue();
137+
DistType thresh = median3(vals[0], vals[n / 2], vals[n - 1]);
138+
u32 n_eq = 0;
139+
u32 n_lt = 0;
140+
u32 q = 0;
141+
for (int it = 0; it < 200; ++it) {
142+
count_lt_and_eq<Compare>(vals, n, thresh, n_lt, n_eq);
143+
if (n_lt <= q_min) {
144+
if (n_lt + n_eq >= q_min) {
145+
q = q_min;
146+
break;
147+
}
148+
thresh_inf = thresh;
149+
} else if (n_lt <= q_max) {
150+
q = n_lt;
151+
break;
152+
} else {
153+
thresh_sup = thresh;
154+
}
155+
DistType new_thresh = sample_threshold_median3<Compare>(vals, n, thresh_inf, thresh_sup);
156+
if (new_thresh == thresh_inf) {
157+
UnrecoverableError("partition_median3 error: new_thresh == thresh_inf");
158+
}
159+
thresh = new_thresh;
160+
}
161+
if (q < n_lt) {
162+
UnrecoverableError("partition_median3 error: q < n_lt");
163+
}
164+
const u32 n_eq_extra = q - n_lt;
165+
auto wp = compress_array<Compare>(vals, ids, n, thresh, n_eq_extra);
166+
if (wp != q) {
167+
UnrecoverableError("partition_median3 error: wp != q");
168+
}
169+
q_out = q;
170+
return thresh;
171+
}
172+
173+
template <class Compare>
174+
class EMVBReservoirResultHandlerT {
175+
using DistType = typename Compare::DistanceType;
176+
using ID = typename Compare::IDType;
177+
u32 top_k_;
178+
u32 capacity_;
179+
u32 size_;
180+
DistType threshold_;
181+
UniquePtr<DistType[]> reservoir_distance_ptr_;
182+
UniquePtr<ID[]> reservoir_id_ptr_;
183+
184+
public:
185+
explicit EMVBReservoirResultHandlerT(const u32 top_k) : top_k_{top_k}, capacity_{2 * top_k}, size_{0}, threshold_{Compare::InitialValue()} {
186+
if (capacity_ < 8) {
187+
capacity_ = 8;
188+
}
189+
reservoir_distance_ptr_ = MakeUniqueForOverwrite<DistType[]>(capacity_);
190+
reservoir_id_ptr_ = MakeUniqueForOverwrite<ID[]>(capacity_);
191+
}
192+
193+
[[nodiscard]] auto GetThreshold() const { return threshold_; }
194+
195+
[[nodiscard]] auto GetSize() const { return size_; }
196+
197+
[[nodiscard]] auto GetDistancePtr() { return std::move(reservoir_distance_ptr_); }
198+
199+
[[nodiscard]] auto GetIdPtr() { return std::move(reservoir_id_ptr_); }
200+
201+
void Add(DistType distance, auto &&id) {
202+
auto q_id_distance = reservoir_distance_ptr_.get();
203+
auto q_id_id = reservoir_id_ptr_.get();
204+
if (Compare::Compare(threshold_, distance)) {
205+
if (size_ == capacity_) {
206+
threshold_ = partition_median3<Compare>(q_id_distance, q_id_id, capacity_, top_k_, (capacity_ + top_k_) / 2, size_);
207+
}
208+
q_id_distance[size_] = distance;
209+
q_id_id[size_] = std::move(id);
210+
++size_;
211+
}
212+
}
213+
214+
void EndWithoutSort() {
215+
if (size_ > top_k_) {
216+
const auto size = size_;
217+
const auto result_size = top_k_;
218+
const auto distance_ptr = reservoir_distance_ptr_.get();
219+
const auto id_ptr = reservoir_id_ptr_.get();
220+
const auto dis_result = distance_ptr - 1;
221+
const auto id_result = id_ptr - 1;
222+
for (u32 index = result_size / 2; index > 0; --index) {
223+
HeapifyDown<Compare>(dis_result, id_result, result_size, index);
224+
}
225+
for (u32 j = result_size; j < size; ++j) {
226+
if (Compare::Compare(distance_ptr[0], distance_ptr[j])) {
227+
distance_ptr[0] = distance_ptr[j];
228+
id_ptr[0] = std::move(id_ptr[j]);
229+
HeapifyDown<Compare>(dis_result, id_result, result_size, 1);
230+
}
231+
}
232+
size_ = result_size;
233+
}
234+
}
235+
236+
void EndSort() {
237+
EndWithoutSort();
238+
auto result_size = size_;
239+
const auto dis_result = reservoir_distance_ptr_.get() - 1;
240+
const auto id_result = reservoir_id_ptr_.get() - 1;
241+
while (result_size > 1) {
242+
std::swap(dis_result[result_size], dis_result[1]);
243+
std::swap(id_result[result_size], id_result[1]);
244+
--result_size;
245+
HeapifyDown<Compare>(dis_result, id_result, result_size, 1);
246+
}
247+
}
248+
};
249+
250+
export template <typename DistType, typename ID>
251+
using EMVBReservoirResultHandler = EMVBReservoirResultHandlerT<EMVBCompareMin<DistType, ID>>;
252+
253+
} // namespace infinity

0 commit comments

Comments
 (0)