Skip to content

Commit bb1b3a3

Browse files
committed
embedding int8
1 parent 04d0083 commit bb1b3a3

File tree

4 files changed

+257
-11
lines changed

4 files changed

+257
-11
lines changed

src/storage/buffer/file_worker/hnsw_file_worker.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ void HnswFileWorker::AllocateInMemory() {
6262
SizeT ef_c = index_hnsw->ef_construction_;
6363
EmbeddingDataType embedding_type = GetType();
6464
switch (embedding_type) {
65+
case kElemInt8: {
66+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(nullptr, index_hnsw);
67+
abstract_hnsw.Make(chunk_size_, max_chunk_num_, dimension, M, ef_c);
68+
data_ = abstract_hnsw.RawPtr();
69+
break;
70+
}
6571
case kElemFloat: {
6672
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(nullptr, index_hnsw);
6773
abstract_hnsw.Make(chunk_size_, max_chunk_num_, dimension, M, ef_c);
@@ -83,6 +89,11 @@ void HnswFileWorker::FreeInMemory() {
8389
const IndexHnsw *index_hnsw = static_cast<const IndexHnsw *>(index_base_.get());
8490
EmbeddingDataType embedding_type = GetType();
8591
switch (embedding_type) {
92+
case kElemInt8: {
93+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(data_, index_hnsw);
94+
abstract_hnsw.Free();
95+
break;
96+
}
8697
case kElemFloat: {
8798
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(data_, index_hnsw);
8899
abstract_hnsw.Free();
@@ -104,6 +115,12 @@ void HnswFileWorker::CompressToLVQ(IndexHnsw *index_hnsw) {
104115
}
105116
EmbeddingDataType embedding_type = GetType();
106117
switch (embedding_type) {
118+
case kElemInt8: {
119+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(data_, index_hnsw);
120+
abstract_hnsw.CompressToLVQ();
121+
data_ = abstract_hnsw.RawPtr();
122+
break;
123+
}
107124
case kElemFloat: {
108125
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(data_, index_hnsw);
109126
abstract_hnsw.CompressToLVQ();
@@ -125,6 +142,11 @@ void HnswFileWorker::WriteToFileImpl(bool to_spill, bool &prepare_success) {
125142
const IndexHnsw *index_hnsw = static_cast<const IndexHnsw *>(index_base_.get());
126143
EmbeddingDataType embedding_type = GetType();
127144
switch (embedding_type) {
145+
case kElemInt8: {
146+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(data_, index_hnsw);
147+
abstract_hnsw.Save(*file_handler_);
148+
break;
149+
}
128150
case kElemFloat: {
129151
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(data_, index_hnsw);
130152
abstract_hnsw.Save(*file_handler_);
@@ -143,6 +165,12 @@ void HnswFileWorker::ReadFromFileImpl() {
143165
const IndexHnsw *index_hnsw = static_cast<const IndexHnsw *>(index_base_.get());
144166
EmbeddingDataType embedding_type = GetType();
145167
switch (embedding_type) {
168+
case kElemInt8: {
169+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(nullptr, index_hnsw);
170+
abstract_hnsw.Load(*file_handler_);
171+
data_ = abstract_hnsw.RawPtr();
172+
break;
173+
}
146174
case kElemFloat: {
147175
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(nullptr, index_hnsw);
148176
abstract_hnsw.Load(*file_handler_);

src/storage/knn_index/knn_hnsw/dist_func_l2.cppm

+22
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,28 @@ public:
7474
}
7575
#else
7676
SIMDFunc = F32IPBF;
77+
#endif
78+
} else if constexpr (std::is_same<DataType, i8>()) {
79+
#if defined(USE_AVX512)
80+
if (dim % 64 == 0) {
81+
SIMDFunc = I8IPAVX512;
82+
} else {
83+
SIMDFunc = I8IPAVX512Residual;
84+
}
85+
#elif defined(USE_AVX)
86+
if (dim % 16 == 0) {
87+
SIMDFunc = I8IPAVX;
88+
} else {
89+
SIMDFunc = I8IPAVXResidual;
90+
}
91+
#elif defined(USE_SSE)
92+
if (dim % 16 == 0) {
93+
SIMDFunc = I8IPSSE;
94+
} else {
95+
SIMDFunc = I8IPSSEResidual;
96+
}
97+
#else
98+
SIMDFunc = I8L2BF;
7799
#endif
78100
}
79101
}

src/storage/knn_index/knn_hnsw/simd_func.cppm

+149-11
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include <immintrin.h>
1516
module;
1617

17-
#include <cmath>
1818
#include "../header.h"
19+
#include <cmath>
20+
#include <cstdint>
1921

2022
import stl;
2123

@@ -91,9 +93,7 @@ export float F32CosAVX512(const float *pv1, const float *pv2, size_t dim) {
9193
return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
9294
}
9395

94-
export float F32CosAVX512Residual(const float *pv1, const float *pv2, size_t dim) {
95-
return F32CosAVX512(pv1, pv2, dim);
96-
}
96+
// Residual-named entry point for the AVX-512 cosine kernel.
// Forwards all three arguments to F32CosAVX512 unchanged and returns its result.
// NOTE(review): no separate tail handling happens here - confirm F32CosAVX512
// copes with dims that are not a multiple of its vector width.
export float F32CosAVX512Residual(const float *pv1, const float *pv2, size_t dim) {
    const float cosine = F32CosAVX512(pv1, pv2, dim);
    return cosine;
}
9797

9898
#endif
9999

@@ -151,9 +151,7 @@ export float F32CosAVX(const float *pv1, const float *pv2, size_t dim) {
151151
return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
152152
}
153153

154-
export float F32CosAVXResidual(const float *pv1, const float *pv2, size_t dim) {
155-
return F32CosAVX(pv1, pv2, dim);
156-
}
154+
// Residual-named entry point for the AVX cosine kernel.
// Forwards all three arguments to F32CosAVX unchanged and returns its result.
// NOTE(review): no separate tail handling happens here - confirm F32CosAVX
// copes with dims that are not a multiple of its vector width.
export float F32CosAVXResidual(const float *pv1, const float *pv2, size_t dim) {
    const float cosine = F32CosAVX(pv1, pv2, dim);
    return cosine;
}
157155

158156
#endif
159157

@@ -211,7 +209,7 @@ export float F32CosSSE(const float *pv1, const float *pv2, size_t dim) {
211209
_mm_store_ps(V1TmpRes, norm_v1);
212210
_mm_store_ps(V2TmpRes, norm_v2);
213211

214-
float mul_res = MulTmpRes[0] + MulTmpRes[1] + MulTmpRes[2] + MulTmpRes[3];
212+
float mul_res = MulTmpRes[0] + MulTmpRes[1] + MulTmpRes[2] + MulTmpRes[3];
215213
float v1_res = V1TmpRes[0] + V1TmpRes[1] + V1TmpRes[2] + V1TmpRes[3];
216214
float v2_res = V2TmpRes[0] + V2TmpRes[1] + V2TmpRes[2] + V2TmpRes[3];
217215

@@ -227,9 +225,7 @@ export float F32CosSSE(const float *pv1, const float *pv2, size_t dim) {
227225
return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
228226
}
229227

230-
export float F32CosSSEResidual(const float *pv1, const float *pv2, size_t dim) {
231-
return F32CosSSE(pv1, pv2, dim);
232-
}
228+
// Residual-named entry point for the SSE cosine kernel.
// Forwards all three arguments to F32CosSSE unchanged and returns its result.
// NOTE(review): no separate tail handling happens here - confirm F32CosSSE
// copes with dims that are not a multiple of its vector width.
export float F32CosSSEResidual(const float *pv1, const float *pv2, size_t dim) {
    const float cosine = F32CosSSE(pv1, pv2, dim);
    return cosine;
}
233229

234230
#endif
235231

@@ -354,6 +350,148 @@ export int32_t I8IPSSEResidual(const int8_t *pv1, const int8_t *pv2, size_t dim)
354350

355351
//------------------------------//------------------------------//------------------------------
356352

353+
// Scalar (brute-force) squared L2 distance between two int8 vectors.
//
// pv1, pv2 : input vectors, each with at least `dim` readable elements.
// dim      : number of elements to compare; 0 yields 0.
// returns  : sum over i of (pv1[i] - pv2[i])^2.
//
// The result is returned as int32_t, like I8IPSSEResidual in this file. A single
// term can reach 255 * 255 = 65025, so the previous `signed char` return type
// threw away nearly all of the distance.
// NOTE(review): the accumulator is 32-bit; extremely large `dim` with maximal
// differences could overflow it - confirm the supported dimension range.
export int32_t I8L2BF(const int8_t *pv1, const int8_t *pv2, size_t dim) {
    int32_t res = 0;
    for (size_t i = 0; i < dim; ++i) {
        // Widen before subtracting so the difference is computed in 32 bits.
        const int32_t diff = static_cast<int32_t>(pv1[i]) - static_cast<int32_t>(pv2[i]);
        res += diff * diff;
    }
    return res;
}
361+
362+
#if defined(USE_AVX512)
363+
364+
export signed char I8L2AVX512(const int8_t *pv1, const int8_t *pv2, size_t dim) {
365+
int8_t PORTABLE_ALIGN64 TmpRes[64];
366+
size_t dim16 = dim >> 4;
367+
368+
const int8_t *pEnd1 = pv1 + (dim16 << 4);
369+
370+
__m512i diff, v1, v2;
371+
__m512i sum = __mm512_set1_ps(0);
372+
373+
while (pv1 < pEnd1) {
374+
v1 = _mm512_loadu_si512(pv1);
375+
pv1 += 16;
376+
v2 = _mm512_loadu_si512(pv2);
377+
pv2 += 16;
378+
diff = _mm512_sub_epi8(v1, v2);
379+
sum = _mm512_add_epi8(sum, _mm512_mul_epi8(diff, diff));
380+
}
381+
382+
_mm512_store_epi8(TmpRes, sum);
383+
int32_t res = 0;
384+
for (size_t i = 0; i < 64; i++) {
385+
res += TmpRes[i];
386+
}
387+
388+
return (signed char)res;
389+
}
390+
391+
export signed char I8L2AVX512Residual(const int8_t *pv1, const int8_t *pv2, size_t dim) {
392+
return I8L2AVX512(pv1, pv2, dim) + I8L2BF(pv1 + (dim & ~63), pv2 + (dim & ~63), dim & 63);
393+
}
394+
#endif
395+
396+
#if defined(USE_AVX)
397+
398+
export signed char I8L2AVX(const int8_t *pv1, const int8_t *pv2, size_t dim) {
399+
int8_t PORTABLE_ALIGN32 TmpRes[32];
400+
size_t dim16 = dim >> 4;
401+
402+
const int8_t *pEnd1 = pv1 + (dim16 << 4);
403+
404+
__m256i diff, v1, v2;
405+
__m256i sum = _mm256_set1_epi8(0);
406+
__m512i diff16, lo, hi;
407+
408+
while (pv1 < pEnd1) {
409+
v1 = _mm256_loadu_epi8(pv1);
410+
pv1 += 8;
411+
v2 = _mm256_loadu_epi8(pv2);
412+
pv2 += 8;
413+
diff = _mm256_sub_epi8(v1, v2);
414+
diff16 = _mm512_cvtepi8_epi16(diff);
415+
lo = _mm512_extracti64x4_epi64(diff16, 0);
416+
hi = _mm512_extracti64x4_epi64(diff16, 1);
417+
sum = _mm256_add_epi8(sum, _mm512_mullo_epi16(diff16, diff16));
418+
419+
v1 = _mm256_loadu_epi8(pv1);
420+
pv1 += 8;
421+
v2 = _mm256_loadu_epi8(pv2);
422+
pv2 += 8;
423+
diff = _mm256_sub_epi8(v1, v2);
424+
sum = _mm256_add_epi8(sum, _mm256_mul_epi8(diff, diff));
425+
}
426+
427+
_mm256_storeu_epi8(TmpRes, sum);
428+
int32_t res = 0;
429+
for (size_t i = 0; i < 32; i++) {
430+
res += TmpRes[i];
431+
}
432+
return (signed char)res;
433+
}
434+
435+
export signed char I8L2AVXResidual(const int8_t *pv1, const int8_t *pv2, size_t dim) {
436+
return I8L2AVX(pv1, pv2, dim) + I8L2BF(pv1 + (dim & ~31), pv2 + (dim & ~31), dim & 31);
437+
}
438+
439+
#endif
440+
441+
#if defined(USE_SSE)
442+
443+
// SSE2 squared L2 distance between two int8 vectors.
//
// pv1, pv2 : input vectors (unaligned loads are used).
// dim      : element count. Only the leading floor(dim / 16) * 16 elements are
//            processed here; I8L2SSEResidual adds the remaining tail.
// returns  : sum of (pv1[i] - pv2[i])^2 over the processed elements, as int32_t.
//
// Bytes are sign-extended to 16-bit lanes with plain SSE2: interleaving a
// register with itself puts each byte in the high half of a 16-bit lane, and
// an arithmetic shift right by 8 then yields the sign-extended value. The
// 16-bit difference (range [-255, 255]) is squared and pair-summed into
// 32-bit lanes by _mm_madd_epi16.
export int32_t I8L2SSE(const int8_t *pv1, const int8_t *pv2, size_t dim) {
    const size_t block_cnt = dim >> 4;
    const int8_t *const end1 = pv1 + (block_cnt << 4);

    __m128i acc = _mm_setzero_si128();
    while (pv1 < end1) {
        const __m128i v1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pv1));
        const __m128i v2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pv2));
        pv1 += 16;
        pv2 += 16;

        const __m128i v1_lo = _mm_srai_epi16(_mm_unpacklo_epi8(v1, v1), 8);
        const __m128i v1_hi = _mm_srai_epi16(_mm_unpackhi_epi8(v1, v1), 8);
        const __m128i v2_lo = _mm_srai_epi16(_mm_unpacklo_epi8(v2, v2), 8);
        const __m128i v2_hi = _mm_srai_epi16(_mm_unpackhi_epi8(v2, v2), 8);

        const __m128i diff_lo = _mm_sub_epi16(v1_lo, v2_lo);
        const __m128i diff_hi = _mm_sub_epi16(v1_hi, v2_hi);

        acc = _mm_add_epi32(acc, _mm_madd_epi16(diff_lo, diff_lo));
        acc = _mm_add_epi32(acc, _mm_madd_epi16(diff_hi, diff_hi));
    }

    // Horizontal sum of the four 32-bit partial sums.
    alignas(16) int32_t lanes[4];
    _mm_store_si128(reinterpret_cast<__m128i *>(lanes), acc);
    return lanes[0] + lanes[1] + lanes[2] + lanes[3];
}
486+
487+
export signed char I8L2SSEResidual(const int8_t *pv1, const int8_t *pv2, size_t dim) {
488+
return I8L2SSE(pv1, pv2, dim) + I8L2BF(pv1 + (dim & ~15), pv2 + (dim & ~15), dim & 15);
489+
}
490+
491+
#endif
492+
493+
//------------------------------//------------------------------//------------------------------
494+
357495
export float F32L2BF(const float *pv1, const float *pv2, size_t dim) {
358496
float res = 0;
359497
for (size_t i = 0; i < dim; i++) {

src/storage/meta/entry/segment_index_entry.cpp

+58
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,13 @@ void SegmentIndexEntry::MemIndexInsert(SharedPtr<BlockEntry> block_entry,
300300
BlockColumnEntry *block_column_entry = block_entry->GetColumnBlockEntry(column_id);
301301
SizeT row_cnt = 0;
302302
switch (embedding_info->Type()) {
303+
case kElemInt8: {
304+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
305+
MemIndexInserterIter<i8> iter(block_offset, block_column_entry, buffer_manager, row_offset, row_count);
306+
auto [start_i, end_i] = abstract_hnsw.InsertVecs(std::move(iter));
307+
row_cnt = end_i;
308+
break;
309+
}
303310
case kElemFloat: {
304311
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
305312
MemIndexInserterIter<f32> iter(block_offset, block_column_entry, buffer_manager, row_offset, row_count);
@@ -532,6 +539,34 @@ void SegmentIndexEntry::PopulateEntirely(const SegmentEntry *segment_entry, Txn
532539
BufferHandle buffer_handle = chunk_index_entry->GetIndex();
533540

534541
switch (embedding_info->Type()) {
542+
case kElemInt8: {
543+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
544+
auto InsertHnswInner = [&](auto &iter) {
545+
HnswInsertConfig insert_config;
546+
insert_config.optimize_ = true;
547+
SegmentOffset start_i, end_i;
548+
if (!config.prepare_) {
549+
// Single thread insert
550+
std::tie(start_i, end_i) = abstract_hnsw.InsertVecs(std::move(iter), insert_config);
551+
} else {
552+
// Multi thread insert data, write file in the physical create index finish stage.
553+
std::tie(start_i, end_i) = abstract_hnsw.StoreData(std::move(iter), insert_config);
554+
}
555+
LOG_TRACE(fmt::format("Insert index: {} - {}", start_i, end_i));
556+
return end_i - start_i;
557+
};
558+
SegmentOffset row_count = 0;
559+
if (config.check_ts_) {
560+
OneColumnIterator<i8> iter(segment_entry, buffer_mgr, column_def->id(), begin_ts);
561+
row_count = InsertHnswInner(iter);
562+
} else {
563+
// Not check ts in uncommitted segment when compact segment
564+
OneColumnIterator<i8, false> iter(segment_entry, buffer_mgr, column_def->id(), begin_ts);
565+
row_count = InsertHnswInner(iter);
566+
}
567+
chunk_index_entry->SetRowCount(row_count);
568+
break;
569+
}
535570
case kElemFloat: {
536571
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
537572
auto InsertHnswInner = [&](auto &iter) {
@@ -730,6 +765,17 @@ Status SegmentIndexEntry::CreateIndexDo(atomic_u64 &create_index_idx) {
730765
BufferHandle buffer_handle = chunk_index_entry->GetIndex();
731766

732767
switch (embedding_info->Type()) {
768+
case kElemInt8: {
769+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
770+
while (true) {
771+
SizeT idx = create_index_idx.fetch_add(1);
772+
if (idx >= row_count) {
773+
break;
774+
}
775+
abstract_hnsw.Build(offset + idx);
776+
}
777+
break;
778+
}
733779
case kElemFloat: {
734780
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
735781
while (true) {
@@ -992,6 +1038,18 @@ ChunkIndexEntry *SegmentIndexEntry::RebuildChunkIndexEntries(TxnTableStore *txn_
9921038
BufferHandle buffer_handle = merged_chunk_index_entry->GetIndex();
9931039

9941040
switch (embedding_info->Type()) {
1041+
case kElemInt8: {
1042+
AbstractHnsw<i8, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
1043+
OneColumnIterator<i8> iter(segment_entry, buffer_mgr, column_def->id(), begin_ts);
1044+
HnswInsertConfig insert_config;
1045+
insert_config.optimize_ = true;
1046+
auto [start_i, end_i] = abstract_hnsw.InsertVecs(std::move(iter), insert_config);
1047+
if (end_i - start_i != row_count) {
1048+
String error_message = "Rebuild HNSW index failed.";
1049+
UnrecoverableError(error_message);
1050+
}
1051+
break;
1052+
}
9951053
case kElemFloat: {
9961054
AbstractHnsw<f32, SegmentOffset> abstract_hnsw(buffer_handle.GetDataMut(), index_hnsw);
9971055
OneColumnIterator<float, true /*check ts*/> iter(segment_entry, buffer_mgr, column_def->id(), begin_ts);

0 commit comments

Comments
 (0)