1
+ // Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // https://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ module;
16
+
17
+ export module emvb_result_handler;
18
+ import stl;
19
+ import infinity_exception;
20
+
21
+ namespace infinity {
22
+
23
+ // EMVB needs move-only ID support
24
+ // EMVB only use min heap
25
+ template <typename DistType, typename ID>
26
+ struct EMVBCompareMin {
27
+ using DistanceType = DistType;
28
+ using IDType = ID;
29
+ inline static bool Compare (DistanceType a, DistanceType b) { return a < b; }
30
+ static constexpr DistanceType InitialValue () { return std::numeric_limits<DistanceType>::lowest (); }
31
+ struct CompareReverse {
32
+ static constexpr DistanceType InitialValue () { return std::numeric_limits<DistanceType>::max (); }
33
+ };
34
+ };
35
+
36
+ template <typename Compare>
37
+ inline void HeapifyDown (typename Compare::DistanceType *distance, typename Compare::IDType *id, const u32 size, u32 index) {
38
+ if (index == 0 || (index << 1 ) > size) {
39
+ return ;
40
+ }
41
+ auto tmp_d = distance[index ];
42
+ auto tmp_i = std::move (id[index ]);
43
+ for (u32 sub; (sub = (index << 1 )) <= size; index = sub) {
44
+ if (sub + 1 <= size && Compare::Compare (distance[sub + 1 ], distance[sub])) {
45
+ ++sub;
46
+ }
47
+ if (!Compare::Compare (distance[sub], tmp_d)) {
48
+ break ;
49
+ }
50
+ distance[index ] = distance[sub];
51
+ id[index ] = std::move (id[sub]);
52
+ }
53
+ distance[index ] = tmp_d;
54
+ id[index ] = std::move (tmp_i);
55
+ }
56
+
57
+ template <typename DistType>
58
+ inline DistType median3 (DistType a, DistType b, DistType c) {
59
+ if (a > b) {
60
+ std::swap (a, b);
61
+ }
62
+ return b <= c ? b : std::max (a, c);
63
+ }
64
+
65
+ template <typename Compare, typename DistType>
66
+ inline void count_lt_and_eq (const DistType *vals, const u32 n, DistType thresh, u32 &n_lt, u32 &n_eq) {
67
+ n_lt = n_eq = 0 ;
68
+ for (SizeT i = 0 ; i < n; ++i) {
69
+ auto v = *(vals++);
70
+ if (Compare::Compare (thresh, v)) {
71
+ n_lt++;
72
+ } else if (v == thresh) {
73
+ n_eq++;
74
+ }
75
+ }
76
+ }
77
+
78
+ template <typename Compare, typename DistType>
79
+ inline DistType sample_threshold_median3 (const DistType *vals, const u32 n, DistType thresh_inf, DistType thresh_sup) {
80
+ DistType val3[3 ];
81
+ u32 vi = 0 ;
82
+ for (u32 i = 0 ; i < n; ++i) {
83
+ DistType v = vals[(i * 6700417ull ) % n];
84
+ if (Compare::Compare (v, thresh_inf) && Compare::Compare (thresh_sup, v)) {
85
+ val3[vi++] = v;
86
+ if (vi == 3 ) {
87
+ break ;
88
+ }
89
+ }
90
+ }
91
+ if (vi == 3 ) {
92
+ return median3 (val3[0 ], val3[1 ], val3[2 ]);
93
+ }
94
+ if (vi != 0 ) {
95
+ return val3[0 ];
96
+ }
97
+ for (SizeT i = 0 ; i < n; i++) {
98
+ if (const DistType v = vals[i]; Compare::Compare (v, thresh_inf) && Compare::Compare (thresh_sup, v)) {
99
+ return v;
100
+ }
101
+ }
102
+ return thresh_inf;
103
+ }
104
+
105
+ template <typename Compare, typename DistType, typename ID>
106
+ inline u32 compress_array (DistType *vals, ID *ids, const u32 n, DistType thresh, u32 n_eq) {
107
+ u32 wp = 0 ;
108
+ for (u32 i = 0 ; i < n; ++i) {
109
+ if (Compare::Compare (thresh, vals[i])) {
110
+ if (wp != i) {
111
+ vals[wp] = vals[i];
112
+ ids[wp] = std::move (ids[i]);
113
+ }
114
+ ++wp;
115
+ } else if (n_eq > 0 && vals[i] == thresh) {
116
+ if (wp != i) {
117
+ vals[wp] = vals[i];
118
+ ids[wp] = std::move (ids[i]);
119
+ }
120
+ ++wp;
121
+ --n_eq;
122
+ }
123
+ }
124
+ if (n_eq != 0 ) {
125
+ UnrecoverableError (" compress_array error: n_eq != 0" );
126
+ }
127
+ return wp;
128
+ }
129
+
130
+ template <typename Compare, typename DistType, typename ID>
131
+ inline DistType partition_median3 (DistType *vals, ID *ids, const u32 n, const u32 q_min, const u32 q_max, u32 &q_out) {
132
+ if (n < 3 ) {
133
+ UnrecoverableError (" partition_median3 error: n < 3" );
134
+ }
135
+ DistType thresh_inf = Compare::CompareReverse::InitialValue ();
136
+ DistType thresh_sup = Compare::InitialValue ();
137
+ DistType thresh = median3 (vals[0 ], vals[n / 2 ], vals[n - 1 ]);
138
+ u32 n_eq = 0 ;
139
+ u32 n_lt = 0 ;
140
+ u32 q = 0 ;
141
+ for (int it = 0 ; it < 200 ; ++it) {
142
+ count_lt_and_eq<Compare>(vals, n, thresh, n_lt, n_eq);
143
+ if (n_lt <= q_min) {
144
+ if (n_lt + n_eq >= q_min) {
145
+ q = q_min;
146
+ break ;
147
+ }
148
+ thresh_inf = thresh;
149
+ } else if (n_lt <= q_max) {
150
+ q = n_lt;
151
+ break ;
152
+ } else {
153
+ thresh_sup = thresh;
154
+ }
155
+ DistType new_thresh = sample_threshold_median3<Compare>(vals, n, thresh_inf, thresh_sup);
156
+ if (new_thresh == thresh_inf) {
157
+ UnrecoverableError (" partition_median3 error: new_thresh == thresh_inf" );
158
+ }
159
+ thresh = new_thresh;
160
+ }
161
+ if (q < n_lt) {
162
+ UnrecoverableError (" partition_median3 error: q < n_lt" );
163
+ }
164
+ const u32 n_eq_extra = q - n_lt;
165
+ auto wp = compress_array<Compare>(vals, ids, n, thresh, n_eq_extra);
166
+ if (wp != q) {
167
+ UnrecoverableError (" partition_median3 error: wp != q" );
168
+ }
169
+ q_out = q;
170
+ return thresh;
171
+ }
172
+
173
+ template <class Compare >
174
+ class EMVBReservoirResultHandlerT {
175
+ using DistType = typename Compare::DistanceType;
176
+ using ID = typename Compare::IDType;
177
+ u32 top_k_;
178
+ u32 capacity_;
179
+ u32 size_;
180
+ DistType threshold_;
181
+ UniquePtr<DistType[]> reservoir_distance_ptr_;
182
+ UniquePtr<ID[]> reservoir_id_ptr_;
183
+
184
+ public:
185
+ explicit EMVBReservoirResultHandlerT (const u32 top_k) : top_k_{top_k}, capacity_{2 * top_k}, size_{0 }, threshold_{Compare::InitialValue ()} {
186
+ if (capacity_ < 8 ) {
187
+ capacity_ = 8 ;
188
+ }
189
+ reservoir_distance_ptr_ = MakeUniqueForOverwrite<DistType[]>(capacity_);
190
+ reservoir_id_ptr_ = MakeUniqueForOverwrite<ID[]>(capacity_);
191
+ }
192
+
193
+ [[nodiscard]] auto GetThreshold () const { return threshold_; }
194
+
195
+ [[nodiscard]] auto GetSize () const { return size_; }
196
+
197
+ [[nodiscard]] auto GetDistancePtr () { return std::move (reservoir_distance_ptr_); }
198
+
199
+ [[nodiscard]] auto GetIdPtr () { return std::move (reservoir_id_ptr_); }
200
+
201
+ void Add (DistType distance, auto &&id) {
202
+ auto q_id_distance = reservoir_distance_ptr_.get ();
203
+ auto q_id_id = reservoir_id_ptr_.get ();
204
+ if (Compare::Compare (threshold_, distance)) {
205
+ if (size_ == capacity_) {
206
+ threshold_ = partition_median3<Compare>(q_id_distance, q_id_id, capacity_, top_k_, (capacity_ + top_k_) / 2 , size_);
207
+ }
208
+ q_id_distance[size_] = distance;
209
+ q_id_id[size_] = std::move (id);
210
+ ++size_;
211
+ }
212
+ }
213
+
214
+ void EndWithoutSort () {
215
+ if (size_ > top_k_) {
216
+ const auto size = size_;
217
+ const auto result_size = top_k_;
218
+ const auto distance_ptr = reservoir_distance_ptr_.get ();
219
+ const auto id_ptr = reservoir_id_ptr_.get ();
220
+ const auto dis_result = distance_ptr - 1 ;
221
+ const auto id_result = id_ptr - 1 ;
222
+ for (u32 index = result_size / 2 ; index > 0 ; --index ) {
223
+ HeapifyDown<Compare>(dis_result, id_result, result_size, index );
224
+ }
225
+ for (u32 j = result_size; j < size; ++j) {
226
+ if (Compare::Compare (distance_ptr[0 ], distance_ptr[j])) {
227
+ distance_ptr[0 ] = distance_ptr[j];
228
+ id_ptr[0 ] = std::move (id_ptr[j]);
229
+ HeapifyDown<Compare>(dis_result, id_result, result_size, 1 );
230
+ }
231
+ }
232
+ size_ = result_size;
233
+ }
234
+ }
235
+
236
+ void EndSort () {
237
+ EndWithoutSort ();
238
+ auto result_size = size_;
239
+ const auto dis_result = reservoir_distance_ptr_.get () - 1 ;
240
+ const auto id_result = reservoir_id_ptr_.get () - 1 ;
241
+ while (result_size > 1 ) {
242
+ std::swap (dis_result[result_size], dis_result[1 ]);
243
+ std::swap (id_result[result_size], id_result[1 ]);
244
+ --result_size;
245
+ HeapifyDown<Compare>(dis_result, id_result, result_size, 1 );
246
+ }
247
+ }
248
+ };
249
+
250
+ export template <typename DistType, typename ID>
251
+ using EMVBReservoirResultHandler = EMVBReservoirResultHandlerT<EMVBCompareMin<DistType, ID>>;
252
+
253
+ } // namespace infinity
0 commit comments