Skip to content

Commit ac7554c

Browse files
committed Jun 17, 2024
Support cosine similarity
1 parent 3bd70ca commit ac7554c

22 files changed

+963
-101
lines changed
 

‎src/executor/operator/physical_scan/physical_knn_scan.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,10 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
362362
IVFFlatScanTemplate.template operator()<AnnIVFFlatIP<DataType>>(std::forward<OptionalFilter>(filter)...);
363363
break;
364364
}
365+
case KnnDistanceType::kCosine: {
366+
IVFFlatScanTemplate.template operator()<AnnIVFFlatCOS<DataType>>(std::forward<OptionalFilter>(filter)...);
367+
break;
368+
}
365369
default: {
366370
Status status = Status::NotSupport("Not implemented KNN distance");
367371
LOG_ERROR(status.message());
@@ -453,6 +457,7 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
453457
case KnnDistanceType::kHamming: {
454458
break;
455459
}
460+
// FIXME:
456461
case KnnDistanceType::kCosine:
457462
case KnnDistanceType::kInnerProduct: {
458463
for (i64 i = 0; i < result_n; ++i) {

‎src/function/table/knn_scan_data.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ KnnDistance1<f32>::KnnDistance1(KnnDistanceType dist_type) {
5252
dist_func_ = L2Distance<f32, f32, f32, SizeT>;
5353
break;
5454
}
55+
case KnnDistanceType::kCosine: {
56+
dist_func_ = CosineDistance<f32, f32, f32, SizeT>;
57+
break;
58+
}
5559
case KnnDistanceType::kInnerProduct: {
5660
dist_func_ = IPDistance<f32, f32, f32, SizeT>;
5761
break;

‎src/planner/bind_context.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -376,17 +376,17 @@ void BindContext::BoundSearch(ParsedExpr *expr) {
376376
}
377377
auto search_expr = (SearchExpr *)expr;
378378

379-
if(!search_expr->knn_exprs_.empty() && search_expr->fusion_exprs_.empty()) {
379+
if (!search_expr->knn_exprs_.empty() && search_expr->fusion_exprs_.empty()) {
380380
SizeT expr_count = search_expr->knn_exprs_.size();
381-
KnnExpr* first_knn = search_expr->knn_exprs_[0];
381+
KnnExpr *first_knn = search_expr->knn_exprs_[0];
382382
KnnDistanceType first_distance_type = first_knn->distance_type_;
383-
for(SizeT idx = 1; idx < expr_count; ++ idx) {
384-
if(search_expr->knn_exprs_[idx]->distance_type_ != first_distance_type) {
383+
for (SizeT idx = 1; idx < expr_count; ++idx) {
384+
if (search_expr->knn_exprs_[idx]->distance_type_ != first_distance_type) {
385385
// Mixed distance type
386-
return ;
386+
return;
387387
}
388388
}
389-
switch(first_distance_type) {
389+
switch (first_distance_type) {
390390
case KnnDistanceType::kL2:
391391
case KnnDistanceType::kHamming: {
392392
allow_distance = true;

‎src/storage/definition/index_base.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ namespace infinity {
3737

3838
String MetricTypeToString(MetricType metric_type) {
3939
switch (metric_type) {
40+
case MetricType::kMetricCosine: {
41+
return "cos";
42+
}
4043
case MetricType::kMetricInnerProduct: {
4144
return "ip";
4245
}
@@ -50,7 +53,9 @@ String MetricTypeToString(MetricType metric_type) {
5053
}
5154

5255
MetricType StringToMetricType(const String &str) {
53-
if (str == "ip") {
56+
if (str == "cos") {
57+
return MetricType::kMetricCosine;
58+
} else if (str == "ip") {
5459
return MetricType::kMetricInnerProduct;
5560
} else if (str == "l2") {
5661
return MetricType::kMetricL2;

‎src/storage/definition/index_base.cppm

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace infinity {
2525

2626
// TODO shenyushi: use definition in knn_exprs.h
2727
export enum class MetricType {
28+
kMetricCosine,
2829
kMetricInnerProduct,
2930
kMetricL2,
3031
kInvalid,

‎src/storage/knn_index/ann_ivf/ann_ivf_flat.cppm

+6-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ class AnnIVFFlat final : public KnnDistance<typename Compare::DistanceType> {
3838
using DistType = typename Compare::DistanceType;
3939
using ResultHandler = ReservoirResultHandler<Compare>;
4040
static inline DistType Distance(const DistType *x, const DistType *y, u32 dimension) {
41-
if constexpr (metric == MetricType::kMetricL2) {
41+
if constexpr (metric == MetricType::kMetricCosine) {
42+
return CosineDistance<DistType>(x, y, dimension);
43+
} else if constexpr (metric == MetricType::kMetricL2) {
4244
return L2Distance<DistType>(x, y, dimension);
4345
} else if constexpr (metric == MetricType::kMetricInnerProduct) {
4446
return IPDistance<DistType>(x, y, dimension);
@@ -285,4 +287,7 @@ using AnnIVFFlatL2 = AnnIVFFlat<CompareMax<DistType, RowID>, MetricType::kMetric
285287
export template <typename DistType>
286288
using AnnIVFFlatIP = AnnIVFFlat<CompareMin<DistType, RowID>, MetricType::kMetricInnerProduct, KnnDistanceAlgoType::kKnnFlatIp>;
287289

290+
export template <typename DistType>
291+
using AnnIVFFlatCOS = AnnIVFFlat<CompareMin<DistType, RowID>, MetricType::kMetricCosine, KnnDistanceAlgoType::kKnnFlatCosine>;
292+
288293
}; // namespace infinity

‎src/storage/knn_index/ann_ivf/some_simd_functions.cppm

+133-78
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
module;
1616

17+
#include <cmath>
18+
1719
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
1820
#include <immintrin.h>
1921
#elif defined(__GNUC__) && defined(__aarch64__)
@@ -32,105 +34,158 @@ namespace infinity {
3234

3335
// x = ( x7, x6, x5, x4, x3, x2, x1, x0 )
3436
float calc_256_sum_8(__m256 x) {
35-
// high_quad = ( x7, x6, x5, x4 )
36-
const __m128 high_quad = _mm256_extractf128_ps(x, 1);
37-
// low_quad = ( x3, x2, x1, x0 )
38-
const __m128 low_quad = _mm256_castps256_ps128(x);
39-
// sum_quad = ( x3 + x7, x2 + x6, x1 + x5, x0 + x4 )
40-
const __m128 sum_quad = _mm_add_ps(low_quad, high_quad);
41-
// low_dual = ( -, -, x1 + x5, x0 + x4 )
42-
const __m128 low_dual = sum_quad;
43-
// high_dual = ( -, -, x3 + x7, x2 + x6 )
44-
const __m128 high_dual = _mm_movehl_ps(sum_quad, sum_quad);
45-
// sum_dual = ( -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6 )
46-
const __m128 sum_dual = _mm_add_ps(low_dual, high_dual);
47-
// low = ( -, -, -, x0 + x2 + x4 + x6 )
48-
const __m128 low = sum_dual;
49-
// high = ( -, -, -, x1 + x3 + x5 + x7 )
50-
const __m128 high = _mm_shuffle_ps(sum_dual, sum_dual, 0x1);
51-
// sum = ( -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 )
52-
const __m128 sum = _mm_add_ss(low, high);
53-
return _mm_cvtss_f32(sum);
37+
// high_quad = ( x7, x6, x5, x4 )
38+
const __m128 high_quad = _mm256_extractf128_ps(x, 1);
39+
// low_quad = ( x3, x2, x1, x0 )
40+
const __m128 low_quad = _mm256_castps256_ps128(x);
41+
// sum_quad = ( x3 + x7, x2 + x6, x1 + x5, x0 + x4 )
42+
const __m128 sum_quad = _mm_add_ps(low_quad, high_quad);
43+
// low_dual = ( -, -, x1 + x5, x0 + x4 )
44+
const __m128 low_dual = sum_quad;
45+
// high_dual = ( -, -, x3 + x7, x2 + x6 )
46+
const __m128 high_dual = _mm_movehl_ps(sum_quad, sum_quad);
47+
// sum_dual = ( -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6 )
48+
const __m128 sum_dual = _mm_add_ps(low_dual, high_dual);
49+
// low = ( -, -, -, x0 + x2 + x4 + x6 )
50+
const __m128 low = sum_dual;
51+
// high = ( -, -, -, x1 + x3 + x5 + x7 )
52+
const __m128 high = _mm_shuffle_ps(sum_dual, sum_dual, 0x1);
53+
// sum = ( -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 )
54+
const __m128 sum = _mm_add_ss(low, high);
55+
return _mm_cvtss_f32(sum);
5456
}
5557

5658
#endif
5759

5860
#if defined(__AVX2__)
5961

6062
export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
61-
u32 i = 0;
62-
__m256 sum_1 = _mm256_setzero_ps();
63-
__m256 sum_2 = _mm256_setzero_ps();
64-
_mm_prefetch(vector1, _MM_HINT_NTA);
65-
_mm_prefetch(vector2, _MM_HINT_NTA);
66-
for (; i + 16 <= dimension; i += 16) {
67-
_mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
68-
_mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
69-
auto diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
70-
auto diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
71-
auto mul_1 = _mm256_mul_ps(diff_1, diff_1);
72-
auto mul_2 = _mm256_mul_ps(diff_2, diff_2);
73-
// add mul to sum
74-
sum_1 = _mm256_add_ps(sum_1, mul_1);
75-
sum_2 = _mm256_add_ps(sum_2, mul_2);
76-
}
77-
if (i + 8 <= dimension) {
78-
auto diff = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
79-
auto mul = _mm256_mul_ps(diff, diff);
80-
sum_1 = _mm256_add_ps(sum_1, mul);
81-
i += 8;
82-
}
83-
f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
84-
for (; i < dimension; ++i) {
85-
auto diff = vector1[i] - vector2[i];
86-
distance += diff * diff;
87-
}
88-
return distance;
63+
u32 i = 0;
64+
__m256 sum_1 = _mm256_setzero_ps();
65+
__m256 sum_2 = _mm256_setzero_ps();
66+
_mm_prefetch(vector1, _MM_HINT_NTA);
67+
_mm_prefetch(vector2, _MM_HINT_NTA);
68+
for (; i + 16 <= dimension; i += 16) {
69+
_mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
70+
_mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
71+
auto diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
72+
auto diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
73+
auto mul_1 = _mm256_mul_ps(diff_1, diff_1);
74+
auto mul_2 = _mm256_mul_ps(diff_2, diff_2);
75+
// add mul to sum
76+
sum_1 = _mm256_add_ps(sum_1, mul_1);
77+
sum_2 = _mm256_add_ps(sum_2, mul_2);
78+
}
79+
if (i + 8 <= dimension) {
80+
auto diff = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
81+
auto mul = _mm256_mul_ps(diff, diff);
82+
sum_1 = _mm256_add_ps(sum_1, mul);
83+
i += 8;
84+
}
85+
f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
86+
for (; i < dimension; ++i) {
87+
auto diff = vector1[i] - vector2[i];
88+
distance += diff * diff;
89+
}
90+
return distance;
8991
}
9092

91-
#elif defined(__SSE__)
93+
#elif defined(__SSE__)
9294

93-
export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
94-
return F32L2SSEResidual(vector1, vector2, dimension);
95-
}
95+
export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) { return F32L2SSEResidual(vector1, vector2, dimension); }
9696

9797
#endif
9898

9999
#if defined(__AVX2__)
100100

101-
export f32 IPDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
102-
u32 i = 0;
103-
__m256 sum_1 = _mm256_setzero_ps();
104-
__m256 sum_2 = _mm256_setzero_ps();
105-
_mm_prefetch(vector1, _MM_HINT_NTA);
106-
_mm_prefetch(vector2, _MM_HINT_NTA);
107-
for (; i + 16 <= dimension; i += 16) {
108-
_mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
109-
_mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
110-
auto mul_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
111-
auto mul_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
112-
// add mul to sum
113-
sum_1 = _mm256_add_ps(sum_1, mul_1);
114-
sum_2 = _mm256_add_ps(sum_2, mul_2);
115-
}
116-
if (i + 8 <= dimension) {
117-
auto mul = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
118-
sum_1 = _mm256_add_ps(sum_1, mul);
119-
i += 8;
120-
}
121-
f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
122-
for (; i < dimension; ++i) {
123-
distance += vector1[i] * vector2[i];
124-
}
125-
return distance;
101+
export f32 CosineDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
102+
u32 i = 0;
103+
__m256 dot_sum_1 = _mm256_setzero_ps();
104+
__m256 dot_sum_2 = _mm256_setzero_ps();
105+
__m256 norm_v1_1 = _mm256_setzero_ps();
106+
__m256 norm_v1_2 = _mm256_setzero_ps();
107+
__m256 norm_v2_1 = _mm256_setzero_ps();
108+
__m256 norm_v2_2 = _mm256_setzero_ps();
109+
_mm_prefetch(vector1, _MM_HINT_NTA);
110+
_mm_prefetch(vector2, _MM_HINT_NTA);
111+
for (; i + 16 <= dimension; i += 16) {
112+
_mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
113+
_mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
114+
auto dot_mul_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
115+
auto dot_mul_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
116+
auto norm_mul_v1_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector1 + i));
117+
auto norm_mul_v1_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector1 + i + 8));
118+
auto norm_mul_v2_1 = _mm256_mul_ps(_mm256_loadu_ps(vector2 + i), _mm256_loadu_ps(vector2 + i));
119+
auto norm_mul_v2_2 = _mm256_mul_ps(_mm256_loadu_ps(vector2 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
120+
// add mul to sum
121+
dot_sum_1 = _mm256_add_ps(dot_sum_1, dot_mul_1);
122+
dot_sum_2 = _mm256_add_ps(dot_sum_2, dot_mul_2);
123+
norm_v1_1 = _mm256_add_ps(norm_v1_1, norm_mul_v1_1);
124+
norm_v1_2 = _mm256_add_ps(norm_v1_2, norm_mul_v1_2);
125+
norm_v2_1 = _mm256_add_ps(norm_v2_1, norm_mul_v2_1);
126+
norm_v2_2 = _mm256_add_ps(norm_v2_2, norm_mul_v2_2);
127+
}
128+
if (i + 8 <= dimension) {
129+
auto dot_mul = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
130+
auto norm_mul_v1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector1 + i));
131+
auto norm_mul_v2 = _mm256_mul_ps(_mm256_loadu_ps(vector2 + i), _mm256_loadu_ps(vector2 + i));
132+
133+
dot_sum_1 = _mm256_add_ps(dot_sum_1, dot_mul);
134+
norm_v1_1 = _mm256_add_ps(norm_v1_1, norm_mul_v1);
135+
norm_v2_1 = _mm256_add_ps(norm_v2_1, norm_mul_v2);
136+
i += 8;
137+
}
138+
139+
f32 dot = calc_256_sum_8(dot_sum_1) + calc_256_sum_8(dot_sum_2);
140+
f32 norm_v1 = calc_256_sum_8(norm_v1_1) + calc_256_sum_8(norm_v1_2);
141+
f32 norm_v2 = calc_256_sum_8(norm_v2_1) + calc_256_sum_8(norm_v2_2);
142+
for (; i < dimension; ++i) {
143+
dot += vector1[i] * vector2[i];
144+
norm_v1 += vector1[i] * vector1[i];
145+
norm_v2 += vector2[i] * vector2[i];
146+
}
147+
return dot != 0 ? dot / sqrt(norm_v1 * norm_v2) : 0;
126148
}
127149

128-
#elif defined(__SSE__)
150+
#elif defined(__SSE__)
151+
152+
export f32 CosineDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) { return F32CosSSEResidual(vector1, vector2, dimension); }
153+
154+
#endif
155+
156+
#if defined(__AVX2__)
129157

130158
export f32 IPDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
131-
return F32IPSSEResidual(vector1, vector2, dimension);
159+
u32 i = 0;
160+
__m256 sum_1 = _mm256_setzero_ps();
161+
__m256 sum_2 = _mm256_setzero_ps();
162+
_mm_prefetch(vector1, _MM_HINT_NTA);
163+
_mm_prefetch(vector2, _MM_HINT_NTA);
164+
for (; i + 16 <= dimension; i += 16) {
165+
_mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
166+
_mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
167+
auto mul_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
168+
auto mul_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
169+
// add mul to sum
170+
sum_1 = _mm256_add_ps(sum_1, mul_1);
171+
sum_2 = _mm256_add_ps(sum_2, mul_2);
172+
}
173+
if (i + 8 <= dimension) {
174+
auto mul = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
175+
sum_1 = _mm256_add_ps(sum_1, mul);
176+
i += 8;
177+
}
178+
f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
179+
for (; i < dimension; ++i) {
180+
distance += vector1[i] * vector2[i];
181+
}
182+
return distance;
132183
}
133184

185+
#elif defined(__SSE__)
186+
187+
export f32 IPDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) { return F32IPSSEResidual(vector1, vector2, dimension); }
188+
134189
#endif
135190

136191
} // namespace infinity

‎src/storage/knn_index/ann_ivf/vector_distance.cppm

+18
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module;
1616
#include <type_traits>
1717
import stl;
1818
import some_simd_functions;
19+
import hnsw_simd_func;
1920

2021
export module vector_distance;
2122

@@ -34,6 +35,23 @@ DiffType L2Distance(const ElemType1 *vector1, const ElemType2 *vector2, const Di
3435
}
3536
}
3637

38+
export template <typename DiffType, typename ElemType1, typename ElemType2, typename DimType = u32>
39+
DiffType CosineDistance(const ElemType1 *vector1, const ElemType2 *vector2, const DimType dimension) {
40+
if constexpr (std::is_same_v<ElemType1, f32> && std::is_same_v<ElemType2, f32>) {
41+
return F32CosAVX(vector1, vector2, dimension);
42+
} else {
43+
DiffType dot_product{};
44+
DiffType norm1{};
45+
DiffType norm2{};
46+
for (u32 i = 0; i < dimension; ++i) {
47+
dot_product += ((DiffType)vector1[i]) * ((DiffType)vector2[i]);
48+
norm1 += ((DiffType)vector1[i]) * ((DiffType)vector1[i]);
49+
norm2 += ((DiffType)vector2[i]) * ((DiffType)vector2[i]);
50+
}
51+
return dot_product != 0 ? dot_product / sqrt(norm1 * norm2) : 0;
52+
}
53+
}
54+
3755
export template <typename DiffType, typename ElemType1, typename ElemType2, typename DimType = u32>
3856
DiffType IPDistance(const ElemType1 *vector1, const ElemType2 *vector2, const DimType dimension) {
3957
if constexpr (std::is_same_v<ElemType1, f32> && std::is_same_v<ElemType2, f32>) {

‎src/storage/knn_index/knn_distance.cppm

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace infinity {
2626

2727
export enum class KnnDistanceAlgoType {
2828
kInvalid,
29+
kKnnFlatCosine,
2930
kKnnFlatIp,
3031
kKnnFlatIpReservoir,
3132
kKnnFlatIpBlas,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
module;
16+
17+
#include <functional>
18+
19+
export module knn_flat_cos;
20+
21+
import stl;
22+
import knn_distance;
23+
import logger;
24+
import knn_result_handler;
25+
import infinity_exception;
26+
import default_values;
27+
import vector_distance;
28+
import bitmask;
29+
import knn_expr;
30+
import internal_types;
31+
32+
namespace infinity {
33+
34+
export template <typename DistType>
35+
class KnnFlatCOS final : public KnnDistance<DistType> {
36+
using ResultHandler = HeapResultHandler<CompareMin<DistType, RowID>>;
37+
38+
public:
39+
explicit KnnFlatCOS(const DistType *queries, i64 query_count, i64 topk, i64 dimension, EmbeddingDataType elem_data_type)
40+
: KnnDistance<DistType>(KnnDistanceAlgoType::kKnnFlatCosine, elem_data_type, query_count, dimension, topk), queries_(queries) {
41+
42+
id_array_ = MakeUniqueForOverwrite<RowID[]>(topk * query_count);
43+
distance_array_ = MakeUniqueForOverwrite<DistType[]>(topk * query_count);
44+
result_handler_ = MakeUnique<ResultHandler>(query_count, topk, distance_array_.get(), id_array_.get());
45+
}
46+
47+
void Begin() final {
48+
if (begin_ || this->query_count_ == 0) {
49+
return;
50+
}
51+
52+
result_handler_->Begin();
53+
54+
begin_ = true;
55+
}
56+
57+
void Search(const DistType *base, u16 base_count, u32 segment_id, u16 block_id) final {
58+
if (!begin_) {
59+
String error_message = "KnnFlatCOS isn't begin";
60+
LOG_CRITICAL(error_message);
61+
UnrecoverableError(error_message);
62+
}
63+
64+
this->total_base_count_ += base_count;
65+
66+
if (base_count == 0) {
67+
return;
68+
}
69+
70+
u32 segment_offset_start = block_id * DEFAULT_BLOCK_CAPACITY;
71+
72+
for (u64 i = 0; i < this->query_count_; ++i) {
73+
const DistType *x_i = queries_ + i * this->dimension_;
74+
const DistType *y_j = base;
75+
76+
for (u16 j = 0; j < base_count; ++j, y_j += this->dimension_) {
77+
auto cos = CosineDistance<DistType>(x_i, y_j, this->dimension_);
78+
result_handler_->AddResult(i, cos, RowID(segment_id, segment_offset_start + j));
79+
}
80+
}
81+
}
82+
83+
void Search(const DistType *base, u16 base_count, u32 segment_id, u16 block_id, Bitmask &bitmask) final {
84+
if (bitmask.IsAllTrue()) {
85+
Search(base, base_count, segment_id, block_id);
86+
return;
87+
}
88+
if (!begin_) {
89+
String error_message = "KnnFlatCOS isn't begin";
90+
LOG_CRITICAL(error_message);
91+
UnrecoverableError(error_message);
92+
}
93+
94+
this->total_base_count_ += base_count;
95+
96+
if (base_count == 0) {
97+
return;
98+
}
99+
100+
u32 segment_offset_start = block_id * DEFAULT_BLOCK_CAPACITY;
101+
102+
for (u64 i = 0; i < this->query_count_; ++i) {
103+
const DistType *x_i = queries_ + i * this->dimension_;
104+
const DistType *y_j = base;
105+
106+
for (u16 j = 0; j < base_count; ++j, y_j += this->dimension_) {
107+
auto segment_offset = segment_offset_start + j;
108+
if (bitmask.IsTrue(segment_offset)) {
109+
auto cos = CosineDistance<DistType>(x_i, y_j, this->dimension_);
110+
result_handler_->AddResult(i, cos, RowID(segment_id, segment_offset_start + j));
111+
}
112+
}
113+
}
114+
}
115+
116+
void End() final {
117+
if (!begin_) {
118+
return;
119+
}
120+
121+
result_handler_->End();
122+
123+
begin_ = false;
124+
}
125+
126+
[[nodiscard]] inline DistType *GetDistances() const final { return distance_array_.get(); }
127+
128+
[[nodiscard]] inline RowID *GetIDs() const final { return id_array_.get(); }
129+
130+
[[nodiscard]] inline DistType *GetDistanceByIdx(u64 idx) const final {
131+
if (idx >= this->query_count_) {
132+
String error_message = "Query index exceeds the limit";
133+
LOG_CRITICAL(error_message);
134+
UnrecoverableError(error_message);
135+
}
136+
return distance_array_.get() + idx * this->top_k_;
137+
}
138+
139+
[[nodiscard]] inline RowID *GetIDByIdx(u64 idx) const final {
140+
if (idx >= this->query_count_) {
141+
String error_message = "Query index exceeds the limit";
142+
LOG_CRITICAL(error_message);
143+
UnrecoverableError(error_message);
144+
}
145+
return id_array_.get() + idx * this->top_k_;
146+
}
147+
148+
private:
149+
UniquePtr<RowID[]> id_array_{};
150+
UniquePtr<DistType[]> distance_array_{};
151+
152+
UniquePtr<ResultHandler> result_handler_{};
153+
154+
const DistType *queries_{};
155+
bool begin_{false};
156+
};
157+
158+
template class KnnFlatCOS<f32>;
159+
160+
} // namespace infinity

‎src/storage/knn_index/knn_flat/knn_flat_l2.cppm

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public:
7575
const DistType *x_i = queries_ + i * this->dimension_;
7676
const DistType *y_j = base;
7777

78-
for (u16 j = 0; j < base_count; j++, y_j += this->dimension_) {
78+
for (u16 j = 0; j < base_count; ++j, y_j += this->dimension_) {
7979
auto l2 = L2Distance<DistType>(x_i, y_j, this->dimension_);
8080
result_handler_->AddResult(i, l2, RowID(segment_id, segment_offset_start + j));
8181
}

‎src/storage/knn_index/knn_hnsw/abstract_hnsw.cppm

+19-9
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,30 @@ namespace infinity {
3636

3737
export template <typename DataType, typename LabelType>
3838
class AbstractHnsw {
39-
using Hnsw1 = KnnHnsw<PlainIPVecStoreType<DataType>, LabelType>;
40-
using Hnsw2 = KnnHnsw<PlainL2VecStoreType<DataType>, LabelType>;
41-
using Hnsw3 = KnnHnsw<LVQIPVecStoreType<DataType, i8>, LabelType>;
42-
using Hnsw4 = KnnHnsw<LVQL2VecStoreType<DataType, i8>, LabelType>;
39+
using Hnsw1 = KnnHnsw<PlainCosVecStoreType<DataType>, LabelType>;
40+
using Hnsw2 = KnnHnsw<PlainIPVecStoreType<DataType>, LabelType>;
41+
using Hnsw3 = KnnHnsw<PlainL2VecStoreType<DataType>, LabelType>;
42+
using Hnsw4 = KnnHnsw<LVQCosVecStoreType<DataType, i8>, LabelType>;
43+
using Hnsw5 = KnnHnsw<LVQIPVecStoreType<DataType, i8>, LabelType>;
44+
using Hnsw6 = KnnHnsw<LVQL2VecStoreType<DataType, i8>, LabelType>;
4345

4446
public:
4547
AbstractHnsw(void *ptr, const IndexHnsw *index_hnsw) {
4648
switch (index_hnsw->encode_type_) {
4749
case HnswEncodeType::kPlain: {
4850
switch (index_hnsw->metric_type_) {
49-
case MetricType::kMetricInnerProduct: {
51+
case MetricType::kMetricCosine: {
5052
knn_hnsw_ptr_ = reinterpret_cast<Hnsw1 *>(ptr);
5153
break;
5254
}
53-
case MetricType::kMetricL2: {
55+
case MetricType::kMetricInnerProduct: {
5456
knn_hnsw_ptr_ = reinterpret_cast<Hnsw2 *>(ptr);
5557
break;
5658
}
59+
case MetricType::kMetricL2: {
60+
knn_hnsw_ptr_ = reinterpret_cast<Hnsw3 *>(ptr);
61+
break;
62+
}
5763
default: {
5864
String error_message = "HNSW supports inner product and L2 distance.";
5965
LOG_CRITICAL(error_message);
@@ -64,12 +70,16 @@ public:
6470
}
6571
case HnswEncodeType::kLVQ: {
6672
switch (index_hnsw->metric_type_) {
73+
case MetricType::kMetricCosine: {
74+
knn_hnsw_ptr_ = reinterpret_cast<Hnsw4 *>(ptr);
75+
break;
76+
}
6777
case MetricType::kMetricInnerProduct: {
68-
knn_hnsw_ptr_ = reinterpret_cast<Hnsw3 *>(ptr);
78+
knn_hnsw_ptr_ = reinterpret_cast<Hnsw5 *>(ptr);
6979
break;
7080
}
7181
case MetricType::kMetricL2: {
72-
knn_hnsw_ptr_ = reinterpret_cast<Hnsw4 *>(ptr);
82+
knn_hnsw_ptr_ = reinterpret_cast<Hnsw6 *>(ptr);
7383
break;
7484
}
7585
default: {
@@ -172,7 +182,7 @@ public:
172182
}
173183

174184
private:
175-
std::variant<Hnsw1 *, Hnsw2 *, Hnsw3 *, Hnsw4 *> knn_hnsw_ptr_;
185+
std::variant<Hnsw1 *, Hnsw2 *, Hnsw3 *, Hnsw4 *, Hnsw5 *, Hnsw6 *> knn_hnsw_ptr_;
176186
};
177187

178188
} // namespace infinity

‎src/storage/knn_index/knn_hnsw/data_store/data_store.cppm

+13-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module;
1616

1717
#include <cassert>
1818
#include <ostream>
19+
#include <type_traits>
1920

2021
export module data_store;
2122

@@ -44,6 +45,13 @@ public:
4445
using VecStoreMeta = typename VecStoreT::Meta;
4546
using VecStoreInner = typename VecStoreT::Inner;
4647

48+
public:
49+
template <typename T, typename = void>
50+
struct has_compress_type : std::false_type {};
51+
52+
template <typename T>
53+
struct has_compress_type<T, std::void_t<typename T::CompressType>> : std::true_type {};
54+
4755
private:
4856
DataStore(SizeT chunk_size, SizeT max_chunk_n, VecStoreMeta &&vec_store_meta, GraphStoreMeta &&graph_store_meta)
4957
: chunk_size_(chunk_size), max_chunk_n_(max_chunk_n), vec_store_meta_(std::move(vec_store_meta)),
@@ -74,7 +82,11 @@ public:
7482
}
7583

7684
static This Make(SizeT chunk_size, SizeT max_chunk_n, SizeT dim, SizeT Mmax0, SizeT Mmax) {
77-
VecStoreMeta vec_store_meta = VecStoreMeta::Make(dim);
85+
bool normalize = false;
86+
if constexpr (has_compress_type<VecStoreT>::value) {
87+
normalize = true;
88+
}
89+
VecStoreMeta vec_store_meta = VecStoreMeta::Make(dim, normalize);
7890
GraphStoreMeta graph_store_meta = GraphStoreMeta::Make(Mmax0, Mmax);
7991
This ret(chunk_size, max_chunk_n, std::move(vec_store_meta), std::move(graph_store_meta));
8092
ret.cur_vec_num_ = 0;

‎src/storage/knn_index/knn_hnsw/data_store/lvq_vec_store.cppm

+27-4
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,17 @@ private:
7373
}
7474

7575
public:
76-
LVQVecStoreMeta() : dim_(0), compress_data_size_(0) {}
76+
LVQVecStoreMeta() : dim_(0), compress_data_size_(0), normalize_(false) {}
7777
LVQVecStoreMeta(This &&other)
7878
: dim_(std::exchange(other.dim_, 0)), compress_data_size_(std::exchange(other.compress_data_size_, 0)), mean_(std::move(other.mean_)),
79-
global_cache_(std::exchange(other.global_cache_, GlobalCacheType())) {}
79+
global_cache_(std::exchange(other.global_cache_, GlobalCacheType())), normalize_(other.normalize_) {}
8080

8181
static This Make(SizeT dim) { return This(dim); }
82+
static This Make(SizeT dim, bool normalize) {
83+
This ret(dim);
84+
ret.normalize_ = normalize;
85+
return ret;
86+
}
8287

8388
void Save(FileHandler &file_handler) const {
8489
file_handler.Write(&dim_, sizeof(dim_));
@@ -101,7 +106,23 @@ public:
101106
return query;
102107
}
103108

104-
void CompressTo(const DataType *src, LVQData *dest) const {
109+
virtual void CompressTo(const DataType *src, LVQData *dest) const {
110+
if (normalize_) {
111+
DataType norm = 0;
112+
DataType *src_without_const = const_cast<DataType *>(src);
113+
for (SizeT j = 0; j < this->dim_; ++j) {
114+
norm += src_without_const[j] * src_without_const[j];
115+
}
116+
norm = std::sqrt(norm);
117+
if (norm == 0) {
118+
std::fill(dest->compress_vec_, dest->compress_vec_ + this->dim_, 0);
119+
} else {
120+
for (SizeT j = 0; j < this->dim_; ++j) {
121+
src_without_const[j] /= norm;
122+
}
123+
}
124+
}
125+
105126
CompressType *compress = dest->compress_vec_;
106127

107128
DataType lower = std::numeric_limits<DataType>::max();
@@ -187,13 +208,15 @@ private:
187208

188209
void DecompressTo(const LVQData *src, DataType *dest) const { DecompressByMeanTo(src, mean_.get(), dest); };
189210

190-
private:
211+
protected:
191212
SizeT dim_;
192213
SizeT compress_data_size_;
193214

194215
UniquePtr<MeanType[]> mean_;
195216
GlobalCacheType global_cache_;
196217

218+
bool normalize_;
219+
197220
public:
198221
void Dump(std::ostream &os) const {
199222
os << "[CONST] dim: " << dim_ << ", compress_data_size: " << compress_data_size_ << std::endl;

‎src/storage/knn_index/knn_hnsw/data_store/plain_vec_store.cppm

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public:
4545
PlainVecStoreMeta(This &&other) : dim_(std::exchange(other.dim_, 0)) {}
4646

4747
static This Make(SizeT dim) { return This(dim); }
48+
static This Make(SizeT dim, bool) { return This(dim); }
4849

4950
void Save(FileHandler &file_handler) const { file_handler.Write(&dim_, sizeof(dim_)); }
5051

‎src/storage/knn_index/knn_hnsw/data_store/sparse_vec_store.cppm

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ private:
4343

4444
public:
4545
static This Make(SizeT max_dim) { return This(max_dim); }
46+
static This Make(SizeT max_dim, bool) { return This(max_dim); }
4647

4748
void Save(FileHandler &file_handler) const { file_handler.Write(&max_dim_, sizeof(max_dim_)); }
4849

‎src/storage/knn_index/knn_hnsw/data_store/vec_store_type.cppm

+34
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,34 @@ import stl;
2020
import plain_vec_store;
2121
import sparse_vec_store;
2222
import lvq_vec_store;
23+
import dist_func_cos;
2324
import dist_func_l2;
2425
import dist_func_ip;
2526
import dist_func_sparse_ip;
2627
import sparse_util;
2728

2829
namespace infinity {
2930

31+
export template <typename DataT>
32+
class PlainCosVecStoreType {
33+
public:
34+
using DataType = DataT;
35+
using CompressType = void;
36+
using Meta = PlainVecStoreMeta<DataType>;
37+
using Inner = PlainVecStoreInner<DataType>;
38+
using QueryVecType = const DataType *;
39+
using StoreType = typename Meta::StoreType;
40+
using QueryType = typename Meta::QueryType;
41+
using Distance = PlainCosDist<DataType>;
42+
43+
static constexpr bool HasOptimize = false;
44+
};
45+
3046
export template <typename DataT>
3147
class PlainL2VecStoreType {
3248
public:
3349
using DataType = DataT;
50+
using CompressType = void;
3451
using Meta = PlainVecStoreMeta<DataType>;
3552
using Inner = PlainVecStoreInner<DataType>;
3653
using QueryVecType = const DataType *;
@@ -45,6 +62,7 @@ export template <typename DataT>
4562
class PlainIPVecStoreType {
4663
public:
4764
using DataType = DataT;
65+
using CompressType = void;
4866
using Meta = PlainVecStoreMeta<DataType>;
4967
using Inner = PlainVecStoreInner<DataType>;
5068
using QueryVecType = const DataType *;
@@ -59,6 +77,7 @@ export template <typename DataT, typename IndexT>
5977
class SparseIPVecStoreType {
6078
public:
6179
using DataType = DataT;
80+
using CompressType = void;
6281
using Meta = SparseVecStoreMeta<DataT, IndexT>;
6382
using Inner = SparseVecStoreInner<DataT, IndexT>;
6483
using QueryVecType = SparseVecRef<DataT, IndexT>;
@@ -69,6 +88,21 @@ public:
6988
static constexpr bool HasOptimize = false;
7089
};
7190

91+
// Vec-store "type tag" for LVQ (locally-adaptive vector quantization)
// compressed vectors compared with cosine distance: per-element CompressT
// codes plus the LVQCosDist functor and its per-vector/global caches.
export template <typename DataT, typename CompressT>
class LVQCosVecStoreType {
public:
    using DataType = DataT;
    using CompressType = CompressT;
    using Meta = LVQVecStoreMeta<DataType, CompressType, LVQCosCache<DataType, CompressType>>;
    using Inner = LVQVecStoreInner<DataType, CompressType, LVQCosCache<DataType, CompressType>>;
    using QueryVecType = const DataType *;
    using StoreType = typename Meta::StoreType;
    using QueryType = typename Meta::QueryType;
    using Distance = LVQCosDist<DataType, CompressType>;

    // LVQ maintains a running mean, so the store supports re-optimization.
    static constexpr bool HasOptimize = true;
};
105+
72106
export template <typename DataT, typename CompressT>
73107
class LVQL2VecStoreType {
74108
public:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
module;
16+
17+
#include "header.h"
18+
#include <ostream>
19+
20+
import stl;
21+
import logger;
22+
import third_party;
23+
import hnsw_common;
24+
import hnsw_simd_func;
25+
import plain_vec_store;
26+
import lvq_vec_store;
27+
28+
export module dist_func_cos;
29+
30+
namespace infinity {
31+
32+
// Cosine distance functor for plain (uncompressed) vectors.
// Selects a dimension-specialized SIMD kernel once at construction and
// returns the NEGATED cosine similarity, so that "smaller distance ==
// more similar" holds for the HNSW search code.
export template <typename DataType>
class PlainCosDist {
public:
    using VecStoreMeta = PlainVecStoreMeta<DataType>;
    using StoreType = typename VecStoreMeta::StoreType;

private:
    // Kernel signature: similarity(v1, v2, dim).
    using SIMDFuncType = DataType (*)(const DataType *, const DataType *, SizeT);

    // Chosen once in the dim ctor based on compile-time CPU features.
    SIMDFuncType SIMDFunc;

public:
    // Move-only; a moved-from functor is left with a null kernel.
    PlainCosDist() : SIMDFunc(nullptr) {}
    PlainCosDist(PlainCosDist &&other) : SIMDFunc(std::exchange(other.SIMDFunc, nullptr)) {}
    PlainCosDist &operator=(PlainCosDist &&other) {
        if (this != &other) {
            SIMDFunc = std::exchange(other.SIMDFunc, nullptr);
        }
        return *this;
    }
    ~PlainCosDist() = default;

    // Pick the widest available kernel; the "Residual" variants also handle
    // the scalar tail when dim is not a multiple of 16.
    // NOTE(review): only float is dispatched — for any other DataType,
    // SIMDFunc stays nullptr and operator() would dereference null; confirm
    // callers only instantiate this with float.
    PlainCosDist(SizeT dim) {
        if constexpr (std::is_same<DataType, float>()) {
#if defined(USE_AVX512)
            if (dim % 16 == 0) {
                SIMDFunc = F32CosAVX512;
            } else {
                SIMDFunc = F32CosAVX512Residual;
            }
#elif defined(USE_AVX)
            if (dim % 16 == 0) {
                SIMDFunc = F32CosAVX;
            } else {
                SIMDFunc = F32CosAVXResidual;
            }
#elif defined(USE_SSE)
            if (dim % 16 == 0) {
                SIMDFunc = F32CosSSE;
            } else {
                SIMDFunc = F32CosSSEResidual;
            }
#else
            SIMDFunc = F32CosBF;
#endif
        }
    }

    // Distance = -cosine_similarity(v1, v2); the negation turns the
    // max-similarity search into a min-distance search.
    DataType operator()(const StoreType &v1, const StoreType &v2, const VecStoreMeta &vec_store_meta) const {
        return -SIMDFunc(v1, v2, vec_store_meta.dim());
    }
};
84+
85+
// Cached terms used by LVQCosDist to reconstruct the inner product of two
// LVQ-compressed vectors without decompressing them. A compressed vector
// decodes as v[i] = scale * c[i] + bias + mean[i].
export template <typename DataType, typename CompressType>
class LVQCosCache {
public:
    // Per-vector cache: (scale * sum(c[i]), scale * sum(mean[i] * c[i])).
    using LocalCacheType = Pair<DataType, DataType>;
    // Global cache over the shared mean vector: (sum(mean), sum(mean^2)).
    using GlobalCacheType = Pair<DataType, DataType>;

    // NOTE(review): MeanType is presumably an alias from the lvq_vec_store
    // module — confirm it is wide enough to accumulate over dim elements.
    static LocalCacheType MakeLocalCache(const CompressType *c, DataType scale, SizeT dim, const MeanType *mean) {
        i64 norm1 = 0; // i64 accumulator avoids overflow of small code types
        MeanType mean_c = 0;
        for (SizeT i = 0; i < dim; ++i) {
            norm1 += c[i];
            mean_c += mean[i] * c[i];
        }
        return {norm1 * scale, mean_c * scale};
    }

    static GlobalCacheType MakeGlobalCache(const MeanType *mean, SizeT dim) {
        MeanType norm1 = 0;
        MeanType norm2 = 0;
        for (SizeT i = 0; i < dim; ++i) {
            norm1 += mean[i];
            norm2 += mean[i] * mean[i];
        }
        return {norm1, norm2};
    }

    // Debug helper: dump the per-vector cached terms.
    static void DumpLocalCache(std::ostream &os, const LocalCacheType &local_cache) {
        os << "norm1: " << local_cache.first << ", mean_c: " << local_cache.second << std::endl;
    }

    // Debug helper: dump the global (mean-vector) cached terms.
    static void DumpGlobalCache(std::ostream &os, const GlobalCacheType &global_cache) {
        os << "norm1_mean: " << global_cache.first << ", norm2_mean: " << global_cache.second << std::endl;
    }
};
119+
120+
// "Cosine" distance functor for LVQ-compressed vectors. Expands the inner
// product of two decoded vectors (v[i] = scale * c[i] + bias + mean[i])
// from the integer code/code inner product plus the cached terms, then
// negates the result so smaller distance means more similar.
// NOTE(review): the value computed is the reconstructed INNER PRODUCT, not
// a normalized cosine — this equals cosine similarity only if the stored
// vectors are unit-normalized; confirm that invariant upstream.
export template <typename DataType, typename CompressType>
class LVQCosDist {
public:
    using This = LVQCosDist<DataType, CompressType>;
    using VecStoreMeta = LVQVecStoreMeta<DataType, CompressType, LVQCosCache<DataType, CompressType>>;
    using StoreType = typename VecStoreMeta::StoreType;

private:
    // Integer inner-product kernel over the quantized codes.
    using SIMDFuncType = i32 (*)(const CompressType *, const CompressType *, SizeT);

    SIMDFuncType SIMDFunc;

public:
    // Move-only; a moved-from functor is left with a null kernel.
    LVQCosDist() : SIMDFunc(nullptr) {}
    LVQCosDist(LVQCosDist &&other) : SIMDFunc(std::exchange(other.SIMDFunc, nullptr)) {}
    LVQCosDist &operator=(LVQCosDist &&other) {
        if (this != &other) {
            SIMDFunc = std::exchange(other.SIMDFunc, nullptr);
        }
        return *this;
    }
    ~LVQCosDist() = default;
    // Reuse the int8 inner-product kernels: only the code/code IP is needed;
    // all cosine-specific terms come from the LVQCosCache.
    // NOTE(review): only i8 codes are dispatched — other CompressType leaves
    // SIMDFunc null; confirm callers only instantiate with i8.
    LVQCosDist(SizeT dim) {
        if constexpr (std::is_same<CompressType, i8>()) {
#if defined(USE_AVX512)
            if (dim % 16 == 0) {
                SIMDFunc = I8IPAVX512;
            } else {
                SIMDFunc = I8IPAVX512Residual;
            }
#elif defined(USE_AVX)
            if (dim % 16 == 0) {
                SIMDFunc = I8IPAVX;
            } else {
                SIMDFunc = I8IPAVXResidual;
            }
#elif defined(USE_SSE)
            if (dim % 16 == 0) {
                SIMDFunc = I8IPSSE;
            } else {
                SIMDFunc = I8IPSSEResidual;
            }
#else
            SIMDFunc = I8IPBF;
#endif
        }
    }

    DataType operator()(const StoreType &v1, const StoreType &v2, const VecStoreMeta &vec_store_meta) const {
        SizeT dim = vec_store_meta.dim();
        // Integer inner product of the quantized codes.
        i32 c1c2_ip = SIMDFunc(v1->compress_vec_, v2->compress_vec_, dim);
        auto scale1 = v1->scale_;
        auto scale2 = v2->scale_;
        auto bias1 = v1->bias_;
        auto bias2 = v2->bias_;
        // Local caches: norm1_scale = scale * sum(c), mean_c_scale = scale * sum(mean * c).
        auto [norm1_scale_1, mean_c_scale_1] = v1->local_cache_;
        auto [norm1_scale_2, mean_c_scale_2] = v2->local_cache_;
        // Global caches over the shared mean: sum(mean), sum(mean^2).
        auto [norm1_mean, norm2_mean] = vec_store_meta.global_cache();
        // Expansion of sum_i (s1*c1+b1+m)(s2*c2+b2+m) in terms of the caches.
        auto dist = scale1 * scale2 * c1c2_ip + bias2 * norm1_scale_1 + bias1 * norm1_scale_2 + dim * bias1 * bias2 + norm2_mean +
                    (bias1 + bias2) * norm1_mean + mean_c_scale_1 + mean_c_scale_2;

        return -dist; // negate: smaller distance == larger similarity
    }
};
184+
185+
} // namespace infinity

‎src/storage/knn_index/knn_hnsw/simd_func.cppm

+203
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
module;
1616

17+
#include <cmath>
1718
#include "header.h"
1819

1920
import stl;
@@ -38,6 +39,202 @@ void log_m256(const __m256i &value) {
3839
std::cout << "]" << std::endl;
3940
}
4041

42+
// Cosine similarity between pv1 and pv2, scalar ("brute force") fallback.
// Returns the same quantity as the SIMD kernels (F32CosAVX512 / F32CosAVX /
// F32CosSSE): dot(v1, v2) / sqrt(|v1|^2 * |v2|^2), i.e. similarity in
// [-1, 1] — NOT the distance 1 - cos. The previous body returned
// 1 - similarity, so the scalar build reported different scores than the
// SIMD builds. Callers (e.g. PlainCosDist) negate the result to convert
// "larger similarity" into "smaller distance".
// The zero-dot short-circuit mirrors the SIMD kernels and also avoids
// 0/0 = NaN when either vector is all zeros.
export float F32CosBF(const float *pv1, const float *pv2, size_t dim) {
    float dot_product = 0;
    float norm1 = 0;
    float norm2 = 0;
    for (size_t i = 0; i < dim; i++) {
        dot_product += pv1[i] * pv2[i];
        norm1 += pv1[i] * pv1[i];
        norm2 += pv2[i] * pv2[i];
    }
    return dot_product != 0 ? dot_product / sqrt(norm1 * norm2) : 0;
}
53+
54+
#if defined(USE_AVX512)
55+
56+
// Cosine similarity (AVX-512): 16 floats per iteration via FMA, then a
// scalar tail for dim % 16 leftover elements. Returns
// dot / sqrt(|v1|^2 * |v2|^2); the zero-dot guard avoids 0/0 NaN for
// all-zero inputs.
export float F32CosAVX512(const float *pv1, const float *pv2, size_t dim) {
    size_t dim16 = dim >> 4;

    const float *pEnd1 = pv1 + (dim16 << 4);

    __m512 mul = _mm512_set1_ps(0);
    __m512 norm_v1 = _mm512_set1_ps(0);
    __m512 norm_v2 = _mm512_set1_ps(0);

    __m512 v1, v2;

    while (pv1 < pEnd1) {
        v1 = _mm512_loadu_ps(pv1);
        pv1 += 16;
        v2 = _mm512_loadu_ps(pv2);
        pv2 += 16;

        mul = _mm512_fmadd_ps(v1, v2, mul);
        norm_v1 = _mm512_fmadd_ps(v1, v1, norm_v1);
        norm_v2 = _mm512_fmadd_ps(v2, v2, norm_v2);
    }

    float mul_res = _mm512_reduce_add_ps(mul);
    float v1_res = _mm512_reduce_add_ps(norm_v1);
    float v2_res = _mm512_reduce_add_ps(norm_v2);

    // pv1/pv2 were already advanced past the SIMD-processed prefix inside
    // the loop, so they now point at the first tail element. The previous
    // code added (dim & ~15) a second time, reading past the end of both
    // buffers whenever dim % 16 != 0 (the Residual path).
    size_t tail = dim & 15;
    for (size_t i = 0; i < tail; i++) {
        mul_res += pv1[i] * pv2[i];
        v1_res += pv1[i] * pv1[i];
        v2_res += pv2[i] * pv2[i];
    }

    return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
}
93+
94+
// Residual variant for dims not divisible by 16: the main kernel already
// handles the scalar tail, so simply delegate.
export float F32CosAVX512Residual(const float *pv1, const float *pv2, size_t dim) {
    return F32CosAVX512(pv1, pv2, dim);
}
97+
98+
#endif
99+
100+
#if defined(USE_AVX)
101+
102+
export float F32CosAVX(const float *pv1, const float *pv2, size_t dim) {
103+
float PORTABLE_ALIGN32 MulTmpRes[8];
104+
float PORTABLE_ALIGN32 V1TmpRes[8];
105+
float PORTABLE_ALIGN32 V2TmpRes[8];
106+
size_t dim16 = dim >> 4;
107+
108+
const float *pEnd1 = pv1 + (dim16 << 4);
109+
110+
__m256 mul = _mm256_set1_ps(0);
111+
__m256 norm_v1 = _mm256_set1_ps(0);
112+
__m256 norm_v2 = _mm256_set1_ps(0);
113+
114+
__m256 v1, v2;
115+
116+
while (pv1 < pEnd1) {
117+
v1 = _mm256_loadu_ps(pv1);
118+
pv1 += 8;
119+
v2 = _mm256_loadu_ps(pv2);
120+
pv2 += 8;
121+
mul = _mm256_add_ps(mul, _mm256_mul_ps(v1, v2));
122+
norm_v1 = _mm256_add_ps(norm_v1, _mm256_mul_ps(v1, v1));
123+
norm_v2 = _mm256_add_ps(norm_v2, _mm256_mul_ps(v2, v2));
124+
125+
v1 = _mm256_loadu_ps(pv1);
126+
pv1 += 8;
127+
v2 = _mm256_loadu_ps(pv2);
128+
pv2 += 8;
129+
mul = _mm256_add_ps(mul, _mm256_mul_ps(v1, v2));
130+
norm_v1 = _mm256_add_ps(norm_v1, _mm256_mul_ps(v1, v1));
131+
norm_v2 = _mm256_add_ps(norm_v2, _mm256_mul_ps(v2, v2));
132+
}
133+
134+
_mm256_store_ps(MulTmpRes, mul);
135+
_mm256_store_ps(V1TmpRes, norm_v1);
136+
_mm256_store_ps(V2TmpRes, norm_v2);
137+
138+
float mul_res = MulTmpRes[0] + MulTmpRes[1] + MulTmpRes[2] + MulTmpRes[3] + MulTmpRes[4] + MulTmpRes[5] + MulTmpRes[6] + MulTmpRes[7];
139+
float v1_res = V1TmpRes[0] + V1TmpRes[1] + V1TmpRes[2] + V1TmpRes[3] + V1TmpRes[4] + V1TmpRes[5] + V1TmpRes[6] + V1TmpRes[7];
140+
float v2_res = V2TmpRes[0] + V2TmpRes[1] + V2TmpRes[2] + V2TmpRes[3] + V2TmpRes[4] + V2TmpRes[5] + V2TmpRes[6] + V2TmpRes[7];
141+
142+
size_t tail = dim & 15;
143+
const float *pBegin1 = pv1 + (dim & ~15);
144+
const float *pBegin2 = pv2 + (dim & ~15);
145+
for (size_t i = 0; i < tail; i++) {
146+
mul_res += pBegin1[i] * pBegin2[i];
147+
v1_res += pBegin1[i] * pBegin1[i];
148+
v2_res += pBegin2[i] * pBegin2[i];
149+
}
150+
151+
return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
152+
}
153+
154+
// Residual variant for dims not divisible by 16: the main kernel already
// handles the scalar tail, so simply delegate.
export float F32CosAVXResidual(const float *pv1, const float *pv2, size_t dim) {
    return F32CosAVX(pv1, pv2, dim);
}
157+
158+
#endif
159+
160+
#if defined(USE_SSE)
161+
162+
// Cosine similarity (SSE): 16 floats per iteration (four unrolled 4-wide
// multiply/add rounds), horizontal sums via aligned scratch arrays, then a
// scalar tail for dim % 16 leftover elements. Returns
// dot / sqrt(|v1|^2 * |v2|^2); the zero-dot guard avoids 0/0 NaN for
// all-zero inputs.
export float F32CosSSE(const float *pv1, const float *pv2, size_t dim) {
    alignas(16) float MulTmpRes[4];
    alignas(16) float V1TmpRes[4];
    alignas(16) float V2TmpRes[4];
    size_t dim16 = dim >> 4;

    const float *pEnd1 = pv1 + (dim16 << 4);

    __m128 mul = _mm_set1_ps(0);
    __m128 norm_v1 = _mm_set1_ps(0);
    __m128 norm_v2 = _mm_set1_ps(0);

    __m128 v1, v2;

    while (pv1 < pEnd1) {
        v1 = _mm_loadu_ps(pv1);
        pv1 += 4;
        v2 = _mm_loadu_ps(pv2);
        pv2 += 4;
        mul = _mm_add_ps(mul, _mm_mul_ps(v1, v2));
        norm_v1 = _mm_add_ps(norm_v1, _mm_mul_ps(v1, v1));
        norm_v2 = _mm_add_ps(norm_v2, _mm_mul_ps(v2, v2));

        v1 = _mm_loadu_ps(pv1);
        pv1 += 4;
        v2 = _mm_loadu_ps(pv2);
        pv2 += 4;
        mul = _mm_add_ps(mul, _mm_mul_ps(v1, v2));
        norm_v1 = _mm_add_ps(norm_v1, _mm_mul_ps(v1, v1));
        norm_v2 = _mm_add_ps(norm_v2, _mm_mul_ps(v2, v2));

        v1 = _mm_loadu_ps(pv1);
        pv1 += 4;
        v2 = _mm_loadu_ps(pv2);
        pv2 += 4;
        mul = _mm_add_ps(mul, _mm_mul_ps(v1, v2));
        norm_v1 = _mm_add_ps(norm_v1, _mm_mul_ps(v1, v1));
        norm_v2 = _mm_add_ps(norm_v2, _mm_mul_ps(v2, v2));

        v1 = _mm_loadu_ps(pv1);
        pv1 += 4;
        v2 = _mm_loadu_ps(pv2);
        pv2 += 4;
        mul = _mm_add_ps(mul, _mm_mul_ps(v1, v2));
        norm_v1 = _mm_add_ps(norm_v1, _mm_mul_ps(v1, v1));
        norm_v2 = _mm_add_ps(norm_v2, _mm_mul_ps(v2, v2));
    }

    _mm_store_ps(MulTmpRes, mul);
    _mm_store_ps(V1TmpRes, norm_v1);
    _mm_store_ps(V2TmpRes, norm_v2);

    float mul_res = MulTmpRes[0] + MulTmpRes[1] + MulTmpRes[2] + MulTmpRes[3];
    float v1_res = V1TmpRes[0] + V1TmpRes[1] + V1TmpRes[2] + V1TmpRes[3];
    float v2_res = V2TmpRes[0] + V2TmpRes[1] + V2TmpRes[2] + V2TmpRes[3];

    // pv1/pv2 were already advanced past the SIMD-processed prefix inside
    // the loop, so they now point at the first tail element. The previous
    // code added (dim & ~15) a second time, reading past the end of both
    // buffers whenever dim % 16 != 0 (the Residual path).
    size_t tail = dim & 15;
    for (size_t i = 0; i < tail; i++) {
        mul_res += pv1[i] * pv2[i];
        v1_res += pv1[i] * pv1[i];
        v2_res += pv2[i] * pv2[i];
    }

    return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
}
229+
230+
// Residual variant for dims not divisible by 16: the main kernel already
// handles the scalar tail, so simply delegate.
export float F32CosSSEResidual(const float *pv1, const float *pv2, size_t dim) {
    return F32CosSSE(pv1, pv2, dim);
}
233+
234+
#endif
235+
236+
//------------------------------//------------------------------//------------------------------
237+
41238
export int32_t I8IPBF(const int8_t *pv1, const int8_t *pv2, size_t dim) {
42239
int32_t res = 0;
43240
for (size_t i = 0; i < dim; i++) {
@@ -47,6 +244,7 @@ export int32_t I8IPBF(const int8_t *pv1, const int8_t *pv2, size_t dim) {
47244
}
48245

49246
#if defined(USE_AVX512)
247+
50248
export int32_t I8IPAVX512(const int8_t *pv1, const int8_t *pv2, size_t dim) {
51249
size_t dim64 = dim >> 6;
52250
const int8_t *pend1 = pv1 + (dim64 << 6);
@@ -79,6 +277,7 @@ export int32_t I8IPAVX512Residual(const int8_t *pv1, const int8_t *pv2, size_t d
79277
#endif
80278

81279
#if defined(USE_AVX)
280+
82281
export int32_t I8IPAVX(const int8_t *pv1, const int8_t *pv2, size_t dim) {
83282
size_t dim32 = dim >> 5;
84283
const int8_t *pend1 = pv1 + (dim32 << 5);
@@ -116,6 +315,7 @@ export int32_t I8IPAVXResidual(const int8_t *pv1, const int8_t *pv2, size_t dim)
116315
#endif
117316

118317
#if defined(USE_SSE)
318+
119319
export int32_t I8IPSSE(const int8_t *pv1, const int8_t *pv2, size_t dim) {
120320
size_t dim16 = dim >> 4;
121321
const int8_t *pend1 = pv1 + (dim16 << 4);
@@ -235,6 +435,7 @@ export float F32L2AVXResidual(const float *pv1, const float *pv2, size_t dim) {
235435
#endif
236436

237437
#if defined(USE_SSE)
438+
238439
export float F32L2SSE(const float *pv1, const float *pv2, size_t dim) {
239440
alignas(16) float TmpRes[4];
240441
size_t dim16 = dim >> 4;
@@ -251,6 +452,7 @@ export float F32L2SSE(const float *pv1, const float *pv2, size_t dim) {
251452
pv2 += 4;
252453
diff = _mm_sub_ps(v1, v2);
253454
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
455+
254456
v1 = _mm_loadu_ps(pv1);
255457
pv1 += 4;
256458
v2 = _mm_loadu_ps(pv2);
@@ -367,6 +569,7 @@ export float F32IPAVXResidual(const float *pVect1, const float *pVect2, SizeT qt
367569
#endif
368570

369571
#if defined(USE_SSE)
572+
370573
export float F32IPSSE(const float *pVect1, const float *pVect2, SizeT qty) {
371574
alignas(16) float TmpRes[4];
372575

‎src/unit_test/storage/knnindex/knn_hnsw/test_hnsw.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#include "unit_test/base_test.h"
1616
#include <fstream>
17+
#include <gtest/gtest.h>
18+
#include <iostream>
1719
#include <thread>
1820

1921
import stl;
@@ -30,6 +32,7 @@ import data_store;
3032

3133
import dist_func_l2;
3234
import dist_func_ip;
35+
import dist_func_cos;
3336
import vec_store_type;
3437
import hnsw_common;
3538
import infinity_exception;
@@ -256,3 +259,8 @@ TEST_F(HnswAlgTest, test4) {
256259
using Hnsw = KnnHnsw<LVQL2VecStoreType<float, int8_t>, LabelT>;
257260
TestParallel<Hnsw>();
258261
}
262+
263+
// Smoke test: HNSW build/search with the plain (uncompressed) cosine store,
// mirroring the L2/IP variants exercised by the earlier tests.
TEST_F(HnswAlgTest, test5) {
    using Hnsw = KnnHnsw<PlainCosVecStoreType<float>, LabelT>;
    TestSimple<Hnsw>();
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
statement ok
2+
DROP TABLE IF EXISTS test_knn_cos;
3+
4+
statement ok
5+
CREATE TABLE test_knn_cos(c1 INT, c2 EMBEDDING(FLOAT, 4));
6+
7+
# copy to create one block
8+
# the csv has 4 rows, the cosine distance to target([0.3, 0.3, 0.2, 0.2]) is:
9+
# 1. (0.3*0.4+0.3*0.3+0.2*0.2+0.2*0.1) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.4^2+0.3^2+0.2^2+0.1^2)) = 0.96675508
10+
# 2. (0.3*0.3+0.3*0.2+0.2*0.1+0.2*0.4) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.3^2+0.2^2+0.1^2+0.4^2)) = 0.895143593
11+
# 3. (0.3*0.2+0.3*0.1+0.2*0.3+0.2*0.4) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.2^2+0.1^2+0.3^2+0.4^2)) = 0.823532105
12+
# 4. (0.3*0.1+0.3*0.2+0.2*0.3-0.2*0.2) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.1^2+0.2^2+0.3^2+(-0.2)^2)) = 0.50847518
13+
statement ok
14+
COPY test_knn_cos FROM '/var/infinity/test_data/embedding_float_dim4.csv' WITH (DELIMITER ',');
15+
16+
# metric cos will order descendingly. The query will return row 1, 2, 3
17+
query I
18+
SELECT c1 FROM test_knn_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
19+
----
20+
8
21+
6
22+
4
23+
24+
query I
25+
SELECT c2 FROM test_knn_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
26+
----
27+
[0.4,0.3,0.2,0.1]
28+
[0.3,0.2,0.1,0.4]
29+
[0.2,0.1,0.3,0.4]
30+
31+
query II
32+
SELECT c1, ROW_ID(), SIMILARITY() FROM test_knn_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
33+
----
34+
8 3 0.966755
35+
6 2 0.895144
36+
4 1 0.823532
37+
38+
statement error
39+
SELECT c1, ROW_ID() DISTANCE() FROM test_knn_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
40+
41+
# copy to create another new block
42+
# there will has 2 knn_scan operator to scan the blocks, and one merge_knn to merge
43+
statement ok
44+
COPY test_knn_cos FROM '/var/infinity/test_data/embedding_float_dim4.csv' WITH (DELIMITER ',');
45+
46+
# the query will return block 1 row 4, block 2 row 4 and a row 3
47+
query I
48+
SELECT c1 FROM test_knn_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
49+
----
50+
8
51+
8
52+
6
53+
54+
# copy to create another new block
55+
statement ok
56+
COPY test_knn_cos FROM '/var/infinity/test_data/embedding_float_dim4.csv' WITH (DELIMITER ',');
57+
58+
# the query will return row 4 from block 1, 2 and 3
59+
query I
60+
SELECT c1 FROM test_knn_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
61+
----
62+
8
63+
8
64+
8
65+
66+
statement ok
67+
DROP TABLE test_knn_cos;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
statement ok
2+
DROP TABLE IF EXISTS test_knn_hnsw_cos;
3+
4+
statement ok
5+
CREATE TABLE test_knn_hnsw_cos(c1 INT, c2 EMBEDDING(FLOAT, 4));
6+
7+
# copy to create one blocks
8+
# the csv has 4 rows, the cosine similarity to target([0.3, 0.3, 0.2, 0.2]) is:
9+
# 1. (0.3*0.4+0.3*0.3+0.2*0.2+0.2*0.1) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.4^2+0.3^2+0.2^2+0.1^2)) = 0.96675508
10+
# 2. (0.3*0.3+0.3*0.2+0.2*0.1+0.2*0.4) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.3^2+0.2^2+0.1^2+0.4^2)) = 0.895143593
11+
# 3. (0.3*0.2+0.3*0.1+0.2*0.3+0.2*0.4) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.2^2+0.1^2+0.3^2+0.4^2)) = 0.823532105
12+
# 4. (0.3*0.1+0.3*0.2+0.2*0.3-0.2*0.2) / sqrt((0.3^2+0.3^2+0.2^2+0.2^2) * (0.1^2+0.2^2+0.3^2+(-0.2)^2)) = 0.50847518
13+
statement ok
14+
COPY test_knn_hnsw_cos FROM '/var/infinity/test_data/embedding_float_dim4.csv' WITH (DELIMITER ',');
15+
16+
# mertic cos will order ascendingly. The query will return row 4, 3, 2
17+
query I
18+
SELECT c1 FROM test_knn_hnsw_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
19+
----
20+
8
21+
6
22+
4
23+
24+
# copy to create another new block
25+
# there will has 2 knn_scan operator to scan the blocks, and one merge_knn to merge
26+
statement ok
27+
COPY test_knn_hnsw_cos FROM '/var/infinity/test_data/embedding_float_dim4.csv' WITH (DELIMITER ',');
28+
29+
# the query will return block 1 row 4, block 2 row 4, block 1 row 3
30+
query I
31+
SELECT c1 FROM test_knn_hnsw_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3);
32+
----
33+
8
34+
8
35+
6
36+
37+
# create hnsw index on existing 2 segments
38+
statement ok
39+
CREATE INDEX idx1 ON test_knn_hnsw_cos (c2) USING Hnsw WITH (M = 16, ef_construction = 200, metric = cos);
40+
41+
# the query will return block 1 row 4, block 2 row 4 and a row 3
42+
# select with 2 index segment
43+
query I
44+
SELECT c1 FROM test_knn_hnsw_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3) WITH (ef = 4);
45+
----
46+
8
47+
8
48+
6
49+
50+
# copy to create another new block with no index
51+
statement ok
52+
COPY test_knn_hnsw_cos FROM '/var/infinity/test_data/embedding_float_dim4.csv' WITH (DELIMITER ',');
53+
54+
# the query will return row 4 from block 1, 2 and 3
55+
# select with 2 index segment and 1 non-index segment
56+
query I
57+
SELECT c1 FROM test_knn_hnsw_cos SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'cosine', 3) WITH (ef = 4);
58+
----
59+
8
60+
8
61+
8
62+
63+
statement ok
64+
DROP TABLE test_knn_hnsw_cos;

0 commit comments

Comments
 (0)
Please sign in to comment.