From c4c149a8844bd976e2e20a06d30e422f0e231d86 Mon Sep 17 00:00:00 2001
From: yangzq50 <58433399+yangzq50@users.noreply.github.com>
Date: Fri, 21 Jun 2024 19:29:23 +0800
Subject: [PATCH] Update flat knn scan scheduler (#1367)

### What problem does this PR solve?

Update the flat knn scan scheduler:
- Try to use at most 3 CPUs for brute-force block scans instead of all CPUs.
- Improve flat knn scan performance.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
- [x] Performance Improvement

---
 .../legacy_benchmark/remote_benchmark_knn.py  |   2 +-
 .../physical_scan/physical_knn_scan.cpp       |  90 +++++++----
 .../physical_scan/physical_knn_scan.cppm      |   9 +-
 src/function/table/knn_scan_data.cpp          |   4 +-
 src/function/table/knn_scan_data.cppm         |   3 +-
 src/scheduler/fragment_context.cpp            |   8 +-
 src/scheduler/task_scheduler.cpp              |  25 +--
 .../knn_index/ann_ivf/some_simd_functions.cpp | 152 ++++++++++++++++++
 .../ann_ivf/some_simd_functions.cppm          |   4 +-
 test/sql/explain/explain_fusion.slt           |  19 +--
 10 files changed, 256 insertions(+), 60 deletions(-)
 create mode 100644 src/storage/knn_index/ann_ivf/some_simd_functions.cpp

diff --git a/python/benchmark/legacy_benchmark/remote_benchmark_knn.py b/python/benchmark/legacy_benchmark/remote_benchmark_knn.py
index 9db61c7fc8..5ba0e3f00b 100644
--- a/python/benchmark/legacy_benchmark/remote_benchmark_knn.py
+++ b/python/benchmark/legacy_benchmark/remote_benchmark_knn.py
@@ -220,7 +220,7 @@ def one_thread(rounds, query_path, ground_truth_path, table_name):
         results.append(f"Recall@10: {recall_10}")
         results.append(f"Recall@100: {recall_100}")

-        conn.disconnect()
+    conn.disconnect()

     for result in results:
         print(result)

diff --git a/src/executor/operator/physical_scan/physical_knn_scan.cpp b/src/executor/operator/physical_scan/physical_knn_scan.cpp
index 389e57f3de..b04e99179d 100644
--- a/src/executor/operator/physical_scan/physical_knn_scan.cpp
+++ b/src/executor/operator/physical_scan/physical_knn_scan.cpp
@@ -141,6 +141,29 @@ void ReadDataBlock(DataBlock *output,

 void PhysicalKnnScan::Init() {}

+void PhysicalKnnScan::InitBlockParallelOption() {
+    // TODO: Set brute force block parallel option
+    // 0. [0, 50), 1 thread
+    // 1. [50, 100), 2 threads
+    block_parallel_options_.emplace_back(50, 2);
+    // 2. [100, +Inf), 3 threads
+    block_parallel_options_.emplace_back(100, 3);
+}
+
+SizeT PhysicalKnnScan::BlockScanTaskCount() const {
+    const u32 block_cnt = block_column_entries_size_;
+    SizeT brute_task_n = 1;
+    for (const auto &[block_n, job_n] : block_parallel_options_) {
+        if (block_cnt < block_n) {
+            break;
+        }
+        brute_task_n = job_n;
+    }
+    return brute_task_n;
+}
+
+SizeT PhysicalKnnScan::TaskletCount() { return BlockScanTaskCount() + index_entries_size_; }
+
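The two helpers added above decide how many tasks get assigned to brute-force block scanning. For clarity, here is a minimal standalone sketch of that threshold lookup, using plain `std::` types in place of the project's `Vector`/`Pair` aliases (the function name is hypothetical):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Standalone sketch of the lookup in BlockScanTaskCount(). Each
// (threshold, tasks) entry applies once the block count reaches the
// threshold; the result is capped at 3, matching the "at most 3 CPUs"
// goal stated in the PR description.
size_t BruteForceTaskCount(uint32_t block_cnt) {
    const std::vector<std::pair<uint32_t, uint32_t>> options{{50, 2}, {100, 3}};
    size_t task_n = 1; // fewer than 50 blocks: a single task
    for (const auto &[threshold, tasks] : options) {
        if (block_cnt < threshold) {
            break;
        }
        task_n = tasks;
    }
    return task_n;
}
// BruteForceTaskCount(30) == 1, BruteForceTaskCount(80) == 2,
// BruteForceTaskCount(500) == 3
```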
 bool PhysicalKnnScan::Execute(QueryContext *query_context, OperatorState *operator_state) {
     auto *knn_scan_operator_state = static_cast<KnnScanOperatorState *>(operator_state);
     auto elem_type = knn_scan_operator_state->knn_scan_function_data_->knn_scan_shared_data_->elem_type_;
@@ -182,6 +205,7 @@ String PhysicalKnnScan::TableAlias() const { return base_table_ref_->alias_; }
 Vector<SizeT> &PhysicalKnnScan::ColumnIDs() const { return base_table_ref_->column_ids_; }
 
 void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: return base entry vector
+    InitBlockParallelOption(); // PlanWithIndex() will be called in physical planner
     Txn *txn = query_context->GetTxn();
     TransactionID txn_id = txn->TxnID();
     TxnTimeStamp begin_ts = txn->BeginTS();
@@ -237,7 +261,9 @@ void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: retu
             }
         }
     }
-    LOG_TRACE(fmt::format("KnnScan: brute force task: {}, index task: {}", block_column_entries_->size(), index_entries_->size()));
+    block_column_entries_size_ = block_column_entries_->size();
+    index_entries_size_ = index_entries_->size();
+    LOG_TRACE(fmt::format("KnnScan: brute force task: {}, index task: {}", block_column_entries_size_, index_entries_size_));
 }
 
 SizeT PhysicalKnnScan::BlockEntryCount() const { return base_table_ref_->block_index_->BlockCount(); }
@@ -259,39 +285,43 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
     auto merge_heap = static_cast<MergeKnn<DataType, C> *>(knn_scan_function_data->merge_knn_base_.get());
     auto query = static_cast<const DataType *>(knn_scan_shared_data->query_embedding_);
 
-    SizeT index_task_n = knn_scan_shared_data->index_entries_->size();
-    SizeT brute_task_n = knn_scan_shared_data->block_column_entries_->size();
+    const SizeT index_task_n = knn_scan_shared_data->index_entries_->size();
+    const SizeT brute_task_n = knn_scan_shared_data->block_column_entries_->size();
     BlockIndex *block_index = knn_scan_shared_data->table_ref_->block_index_.get();
 
-    if (u64 block_column_idx = knn_scan_shared_data->current_block_idx_++; block_column_idx < brute_task_n) {
+    if (u64 block_column_idx =
+            knn_scan_function_data->execute_block_scan_job_ ? knn_scan_shared_data->current_block_idx_++ : std::numeric_limits<u64>::max();
+        block_column_idx < brute_task_n) {
         LOG_TRACE(fmt::format("KnnScan: {} brute force {}/{}", knn_scan_function_data->task_id_, block_column_idx + 1, brute_task_n));
         // brute force
-        BlockColumnEntry *block_column_entry = knn_scan_shared_data->block_column_entries_->at(block_column_idx);
-        const BlockEntry *block_entry = block_column_entry->block_entry();
-        const auto block_id = block_entry->block_id();
-        const SegmentID segment_id = block_entry->GetSegmentEntry()->segment_id();
-        const auto row_count = block_entry->row_count();
-
-        Bitmask bitmask;
-        if (this->CalculateFilterBitmask(segment_id, block_id, row_count, bitmask)) {
-            LOG_TRACE(fmt::format("KnnScan: {} brute force {}/{} not skipped after common_query_filter",
-                                  knn_scan_function_data->task_id_,
-                                  block_column_idx + 1,
-                                  brute_task_n));
-            block_entry->SetDeleteBitmask(begin_ts, bitmask);
-            BufferManager *buffer_mgr = query_context->storage()->buffer_manager();
-            ColumnVector column_vector = block_column_entry->GetColumnVector(buffer_mgr);
-
-            auto data = reinterpret_cast<const DataType *>(column_vector.data());
-            merge_heap->Search(query,
-                               data,
-                               knn_scan_shared_data->dimension_,
-                               dist_func->dist_func_,
-                               row_count,
-                               block_entry->segment_id(),
-                               block_entry->block_id(),
-                               bitmask);
-        }
+        // TODO: now will try to finish all block scan jobs in this task
+        do {
+            BlockColumnEntry *block_column_entry = knn_scan_shared_data->block_column_entries_->at(block_column_idx);
+            const BlockEntry *block_entry = block_column_entry->block_entry();
+            const auto block_id = block_entry->block_id();
+            const SegmentID segment_id = block_entry->GetSegmentEntry()->segment_id();
+            const auto row_count = block_entry->row_count();
+            Bitmask bitmask;
+            if (this->CalculateFilterBitmask(segment_id, block_id, row_count, bitmask)) {
+                // LOG_TRACE(fmt::format("KnnScan: {} brute force {}/{} not skipped after common_query_filter",
+                //                       knn_scan_function_data->task_id_,
+                //                       block_column_idx + 1,
+                //                       brute_task_n));
+                block_entry->SetDeleteBitmask(begin_ts, bitmask);
+                BufferManager *buffer_mgr = query_context->storage()->buffer_manager();
+                ColumnVector column_vector = block_column_entry->GetColumnVector(buffer_mgr);
+                auto data = reinterpret_cast<const DataType *>(column_vector.data());
+                merge_heap->Search(query,
+                                   data,
+                                   knn_scan_shared_data->dimension_,
+                                   dist_func->dist_func_,
+                                   row_count,
+                                   segment_id,
+                                   block_id,
+                                   bitmask);
+            }
+            block_column_idx = knn_scan_shared_data->current_block_idx_++;
+        } while (block_column_idx < brute_task_n);
     } else if (u64 index_idx = knn_scan_shared_data->current_index_idx_++; index_idx < index_task_n) {
         LOG_TRACE(fmt::format("KnnScan: {} index {}/{}", knn_scan_function_data->task_id_, index_idx + 1, index_task_n));
         // with index
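The rewritten branch above no longer handles a single block per call: each task flagged with `execute_block_scan_job_` keeps claiming block indices from the shared atomic counter `current_block_idx_` until the range is drained, so every block is scanned exactly once regardless of how many tasks participate. A minimal sketch of that claiming pattern, assuming a plain `std::atomic` counter and a hypothetical `ProcessBlock` stand-in for the bitmask-plus-`Search` work:

```cpp
#include <atomic>
#include <cstdint>
#include <limits>

// Sketch of the block-claiming loop. Unflagged tasks get a sentinel index
// that fails the bound check, so they skip straight past the block scan.
void BlockScanLoop(std::atomic<uint64_t> &next_block, uint64_t block_count, bool execute_block_scan_job) {
    uint64_t idx = execute_block_scan_job ? next_block.fetch_add(1) : std::numeric_limits<uint64_t>::max();
    while (idx < block_count) {
        // ProcessBlock(idx); // hypothetical: filter bitmask, then merge_heap->Search(...)
        idx = next_block.fetch_add(1); // claim the next unscanned block
    }
}
```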
diff --git a/src/executor/operator/physical_scan/physical_knn_scan.cppm b/src/executor/operator/physical_scan/physical_knn_scan.cppm
index d44d6ea307..c896b53b22 100644
--- a/src/executor/operator/physical_scan/physical_knn_scan.cppm
+++ b/src/executor/operator/physical_scan/physical_knn_scan.cppm
@@ -73,7 +73,9 @@ public:
 
     void PlanWithIndex(QueryContext *query_context);
 
-    SizeT TaskletCount() override { return block_column_entries_->size() + index_entries_->size(); }
+    SizeT BlockScanTaskCount() const;
+
+    SizeT TaskletCount() override;
 
     void FillingTableRefs(HashMap<u64, SharedPtr<BaseTableRef>> &table_refs) override {
         table_refs.insert({base_table_ref_->table_index_, base_table_ref_});
@@ -88,10 +90,15 @@ public:
     SharedPtr<Vector<SharedPtr<DataType>>> output_types_{};
     u64 knn_table_index_{};
 
+    Vector<Pair<u32, u32>> block_parallel_options_;
+    u32 block_column_entries_size_ = 0; // need this value because block_column_entries_ will be moved into KnnScanSharedData
+    u32 index_entries_size_ = 0;
     UniquePtr<Vector<BlockColumnEntry *>> block_column_entries_{};
     UniquePtr<Vector<SegmentIndexEntry *>> index_entries_{};
 
 private:
+    void InitBlockParallelOption();
+
     template <typename DataType, template <typename, typename> typename C>
     void ExecuteInternal(QueryContext *query_context, KnnScanOperatorState *operator_state);
 };

diff --git a/src/function/table/knn_scan_data.cpp b/src/function/table/knn_scan_data.cpp
index 9e4674ecff..922ee819c9 100644
--- a/src/function/table/knn_scan_data.cpp
+++ b/src/function/table/knn_scan_data.cpp
@@ -69,8 +69,8 @@ KnnDistance1<f32>::KnnDistance1(KnnDistanceType dist_type) {
 
 // --------------------------------------------
 
-KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx)
-    : knn_scan_shared_data_(shared_data), task_id_(current_parallel_idx) {
+KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx, bool execute_block_scan_job)
+    : knn_scan_shared_data_(shared_data), task_id_(current_parallel_idx), execute_block_scan_job_(execute_block_scan_job) {
     switch (knn_scan_shared_data_->elem_type_) {
         case EmbeddingDataType::kElemFloat: {
             Init<f32>();

diff --git a/src/function/table/knn_scan_data.cppm b/src/function/table/knn_scan_data.cppm
index 274afdee58..02fcf83314 100644
--- a/src/function/table/knn_scan_data.cppm
+++ b/src/function/table/knn_scan_data.cppm
@@ -111,7 +111,7 @@ KnnDistance1<f32>::KnnDistance1(KnnDistanceType dist_type);
 
 export class KnnScanFunctionData final : public TableFunctionData {
 public:
-    KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx);
+    KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx, bool execute_block_scan_job);
 
     ~KnnScanFunctionData() final = default;
 
@@ -122,6 +122,7 @@ private:
 public:
     KnnScanSharedData *knn_scan_shared_data_;
     const u32 task_id_;
+    bool execute_block_scan_job_ = false;
 
     UniquePtr<MergeKnnBase> merge_knn_base_{};
     UniquePtr<KnnDistanceBase1> knn_distance_{};

diff --git a/src/scheduler/fragment_context.cpp b/src/scheduler/fragment_context.cpp
index c6f2087735..b7f40c375d 100644
--- a/src/scheduler/fragment_context.cpp
+++ b/src/scheduler/fragment_context.cpp
@@ -158,18 +158,20 @@ UniquePtr<OperatorState> MakeKnnScanState(PhysicalKnnScan *physical_knn_scan, Fr
     UniquePtr<OperatorState> operator_state = MakeUnique<KnnScanOperatorState>();
     KnnScanOperatorState *knn_scan_op_state_ptr = (KnnScanOperatorState *)(operator_state.get());
-
+    const bool execute_block_scan_job = task->TaskID() < static_cast<i64>(physical_knn_scan->BlockScanTaskCount());
     switch (fragment_ctx->ContextType()) {
         case FragmentType::kSerialMaterialize: {
             SerialMaterializedFragmentCtx *serial_materialize_fragment_ctx = static_cast<SerialMaterializedFragmentCtx *>(fragment_ctx);
             knn_scan_op_state_ptr->knn_scan_function_data_ =
-                MakeUnique<KnnScanFunctionData>(serial_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID());
+                MakeUnique<KnnScanFunctionData>(serial_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID(), execute_block_scan_job);
             break;
         }
         case FragmentType::kParallelMaterialize: {
             ParallelMaterializedFragmentCtx *parallel_materialize_fragment_ctx = static_cast<ParallelMaterializedFragmentCtx *>(fragment_ctx);
             knn_scan_op_state_ptr->knn_scan_function_data_ =
-                MakeUnique<KnnScanFunctionData>(parallel_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID());
+                MakeUnique<KnnScanFunctionData>(parallel_materialize_fragment_ctx->knn_scan_shared_data_.get(),
+                                                task->TaskID(),
+                                                execute_block_scan_job);
             break;
         }
         default: {
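This is where the flag is decided: tasklets whose IDs fall below `BlockScanTaskCount()` share the brute-force block range, and the rest go straight to index scans (flagged tasks also fall through to index scans once the blocks are drained, per the `else if` above). A sketch of the resulting partition, assuming `BlockScanTaskCount() == 2` and four index entries; the names here are hypothetical, not the project's API:

```cpp
#include <cstdint>
#include <vector>

// With 2 block-scan tasks and 4 index entries, TaskletCount() == 6:
// tasks 0-1 get execute_block_scan_job == true, tasks 2-5 get false.
struct TaskPlan {
    uint64_t task_id;
    bool execute_block_scan_job;
};

std::vector<TaskPlan> PlanTasks(uint64_t block_scan_task_count, uint64_t index_entry_count) {
    std::vector<TaskPlan> plan;
    const uint64_t total = block_scan_task_count + index_entry_count;
    for (uint64_t id = 0; id < total; ++id) {
        plan.push_back({id, id < block_scan_task_count});
    }
    return plan;
}
```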
diff --git a/src/scheduler/task_scheduler.cpp b/src/scheduler/task_scheduler.cpp
index 24e7ca92b3..8f4359ee26 100644
--- a/src/scheduler/task_scheduler.cpp
+++ b/src/scheduler/task_scheduler.cpp
@@ -46,27 +46,32 @@ namespace infinity {
 
 TaskScheduler::TaskScheduler(Config *config_ptr) { Init(config_ptr); }
 
 void TaskScheduler::Init(Config *config_ptr) {
-    worker_count_ = config_ptr->CPULimit();
+    const u64 cpu_count = Thread::hardware_concurrency();
+    const u64 config_cpu_limit = config_ptr->CPULimit();
+    worker_count_ = std::min(cpu_count, config_cpu_limit);
     worker_array_.reserve(worker_count_);
     worker_workloads_.resize(worker_count_);
-    u64 cpu_count = Thread::hardware_concurrency();
-    u64 cpu_select_step = cpu_count / worker_count_;
-    if (cpu_select_step >= 2) {
-        cpu_select_step = 2;
-    } else {
-        cpu_select_step = 1;
+    Vector<u64> cpu_id_vec;
+    cpu_id_vec.reserve(cpu_count);
+    // even CPUs first
+    for (u64 cpu_id = 0; cpu_id < cpu_count; cpu_id += 2) {
+        cpu_id_vec.push_back(cpu_id);
+    }
+    // then add odd CPUs
+    for (u64 cpu_id = 1; cpu_id < cpu_count; cpu_id += 2) {
+        cpu_id_vec.push_back(cpu_id);
     }
-    for (u64 cpu_id = 0, worker_id = 0; worker_id < worker_count_; ++worker_id) {
+    for (u64 worker_id = 0; worker_id < worker_count_; ++worker_id) {
+        const u64 cpu_id = cpu_id_vec[worker_id];
         UniquePtr<FIFOQueue> worker_queue = MakeUnique<FIFOQueue>();
         UniquePtr<Thread> worker_thread = MakeUnique<Thread>(&TaskScheduler::WorkerLoop, this, worker_queue.get(), worker_id);
         // Pin the thread to specific cpu
-        ThreadUtil::pin(*worker_thread, cpu_id % cpu_count);
+        ThreadUtil::pin(*worker_thread, cpu_id);
 
         worker_array_.emplace_back(cpu_id, std::move(worker_queue), std::move(worker_thread));
         worker_workloads_[worker_id] = 0;
-        cpu_id += cpu_select_step;
     }
 
     if (worker_array_.empty()) {
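A sketch of the pinning order the scheduler now builds. The even-first ordering presumably spreads workers across physical cores on machines where logical CPUs 2k and 2k+1 are SMT siblings of one physical core; that layout is common on Linux but not universal, so treat the topology claim as an assumption:

```cpp
#include <cstdint>
#include <vector>

// List even CPU ids first, then odd ones, so the first worker_count_
// workers land on distinct physical cores before doubling up on siblings
// (under the even/odd-sibling assumption described above).
std::vector<uint64_t> PinningOrder(uint64_t cpu_count) {
    std::vector<uint64_t> order;
    order.reserve(cpu_count);
    for (uint64_t id = 0; id < cpu_count; id += 2) { // even ids first
        order.push_back(id);
    }
    for (uint64_t id = 1; id < cpu_count; id += 2) { // then odd ids
        order.push_back(id);
    }
    return order; // e.g. cpu_count = 8 -> {0, 2, 4, 6, 1, 3, 5, 7}
}
```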
diff --git a/src/storage/knn_index/ann_ivf/some_simd_functions.cpp b/src/storage/knn_index/ann_ivf/some_simd_functions.cpp
new file mode 100644
index 0000000000..b8e7414b29
--- /dev/null
+++ b/src/storage/knn_index/ann_ivf/some_simd_functions.cpp
@@ -0,0 +1,152 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+module;
+
+#include <immintrin.h>
+
+module some_simd_functions;
+
+import stl;
+import emvb_simd_funcs;
+
+#if defined(__x86_64__) && (defined(__clang_major__) && (__clang_major__ > 10))
+#define IMPRECISE_FUNCTION_BEGIN _Pragma("float_control(precise, off, push)")
+#define IMPRECISE_FUNCTION_END _Pragma("float_control(pop)")
+#define IMPRECISE_LOOP _Pragma("clang loop vectorize(enable) interleave(enable)")
+#else
+#define IMPRECISE_FUNCTION_BEGIN
+#define IMPRECISE_FUNCTION_END
+#define IMPRECISE_LOOP
+#endif
+
+IMPRECISE_FUNCTION_BEGIN
+float fvec_L2sqr(const float *x, const float *y, const size_t d) {
+    float res = 0.0f;
+    IMPRECISE_LOOP
+    for (size_t i = 0; i < d; ++i) {
+        const float tmp = x[i] - y[i];
+        res += tmp * tmp;
+    }
+    return res;
+}
+IMPRECISE_FUNCTION_END
+
+namespace infinity {
+
+inline f32 L2Distance_simd_128(const f32 *vector1, const f32 *vector2, u32) {
+    __m256 diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+    __m256 diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+    __m256 diff_3 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 16), _mm256_loadu_ps(vector2 + 16));
+    __m256 diff_4 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 24), _mm256_loadu_ps(vector2 + 24));
+    __m256 sum_1 = _mm256_mul_ps(diff_1, diff_1);
+    __m256 sum_2 = _mm256_mul_ps(diff_2, diff_2);
+    __m256 sum_3 = _mm256_mul_ps(diff_3, diff_3);
+    __m256 sum_4 = _mm256_mul_ps(diff_4, diff_4);
+    vector1 += 32;
+    vector2 += 32;
+    diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+    diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+    diff_3 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 16), _mm256_loadu_ps(vector2 + 16));
+    diff_4 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 24), _mm256_loadu_ps(vector2 + 24));
+    sum_1 = _mm256_fmadd_ps(diff_1, diff_1, sum_1);
+    sum_2 = _mm256_fmadd_ps(diff_2, diff_2, sum_2);
+    sum_3 = _mm256_fmadd_ps(diff_3, diff_3, sum_3);
+    sum_4 = _mm256_fmadd_ps(diff_4, diff_4, sum_4);
+    vector1 += 32;
+    vector2 += 32;
+    diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+    diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+    diff_3 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 16), _mm256_loadu_ps(vector2 + 16));
+    diff_4 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 24), _mm256_loadu_ps(vector2 + 24));
+    sum_1 = _mm256_fmadd_ps(diff_1, diff_1, sum_1);
+    sum_2 = _mm256_fmadd_ps(diff_2, diff_2, sum_2);
+    sum_3 = _mm256_fmadd_ps(diff_3, diff_3, sum_3);
+    sum_4 = _mm256_fmadd_ps(diff_4, diff_4, sum_4);
+    vector1 += 32;
+    vector2 += 32;
+    diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+    diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+    diff_3 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 16), _mm256_loadu_ps(vector2 + 16));
+    diff_4 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 24), _mm256_loadu_ps(vector2 + 24));
+    sum_1 = _mm256_fmadd_ps(diff_1, diff_1, sum_1);
+    sum_2 = _mm256_fmadd_ps(diff_2, diff_2, sum_2);
+    sum_3 = _mm256_fmadd_ps(diff_3, diff_3, sum_3);
+    sum_4 = _mm256_fmadd_ps(diff_4, diff_4, sum_4);
+    __m256 sum_half_1 = _mm256_add_ps(sum_1, sum_2);
+    __m256 sum_half_2 = _mm256_add_ps(sum_3, sum_4);
+    __m256 sum = _mm256_add_ps(sum_half_1, sum_half_2);
+    return hsum256_ps_avx(sum);
+}
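For readability, here is a compact loop restatement of the hand-unrolled 128-dim kernel above (a sketch, not the patch's code): four independent accumulators let consecutive FMAs overlap in the pipeline, with a single horizontal reduction at the end. It is self-contained (no project helpers) and needs AVX2/FMA support, e.g. `-mavx2 -mfma`:

```cpp
#include <immintrin.h>

// Loop form of the 128-float squared-L2 kernel: 4 independent AVX
// accumulators hide FMA latency; the hand-unrolled version in this patch
// additionally removes the loop counter for the fixed 128-float case.
float L2Sqr128Loop(const float *v1, const float *v2) {
    __m256 sum[4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()};
    for (int base = 0; base < 128; base += 32) {
        for (int j = 0; j < 4; ++j) {
            const __m256 d = _mm256_sub_ps(_mm256_loadu_ps(v1 + base + 8 * j), _mm256_loadu_ps(v2 + base + 8 * j));
            sum[j] = _mm256_fmadd_ps(d, d, sum[j]); // sum[j] += d * d
        }
    }
    const __m256 total = _mm256_add_ps(_mm256_add_ps(sum[0], sum[1]), _mm256_add_ps(sum[2], sum[3]));
    float out[8]; // manual horizontal add instead of the project's hsum256_ps_avx
    _mm256_storeu_ps(out, total);
    return out[0] + out[1] + out[2] + out[3] + out[4] + out[5] + out[6] + out[7];
}
```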
+
+inline f32 L2Distance_simd_16_multi(const f32 *vector1, const f32 *vector2, const u32 dimension) {
+    if (dimension < 16) {
+        return fvec_L2sqr(vector1, vector2, dimension);
+    }
+    __m256 diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+    __m256 diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+    __m256 sum_1 = _mm256_mul_ps(diff_1, diff_1);
+    __m256 sum_2 = _mm256_mul_ps(diff_2, diff_2);
+    u32 pos = 16;
+    while (pos + 16 <= dimension) {
+        vector1 += 16;
+        vector2 += 16;
+        diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+        diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+        sum_1 = _mm256_fmadd_ps(diff_1, diff_1, sum_1);
+        sum_2 = _mm256_fmadd_ps(diff_2, diff_2, sum_2);
+        pos += 16;
+    }
+    __m256 sum = _mm256_add_ps(sum_1, sum_2);
+    float distance = hsum256_ps_avx(sum);
+    if (pos < dimension) {
+        distance += fvec_L2sqr(vector1 + 16, vector2 + 16, dimension - pos);
+    }
+    return distance;
+}
+
+f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, const u32 dimension) {
+    if (dimension == 128) {
+        return L2Distance_simd_128(vector1, vector2, dimension);
+    }
+    if (dimension % 32 != 0) {
+        return L2Distance_simd_16_multi(vector1, vector2, dimension);
+    }
+    __m256 diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+    __m256 diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+    __m256 diff_3 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 16), _mm256_loadu_ps(vector2 + 16));
+    __m256 diff_4 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 24), _mm256_loadu_ps(vector2 + 24));
+    __m256 sum_1 = _mm256_mul_ps(diff_1, diff_1);
+    __m256 sum_2 = _mm256_mul_ps(diff_2, diff_2);
+    __m256 sum_3 = _mm256_mul_ps(diff_3, diff_3);
+    __m256 sum_4 = _mm256_mul_ps(diff_4, diff_4);
+    for (u32 pos = 32; pos + 32 <= dimension; pos += 32) {
+        vector1 += 32;
+        vector2 += 32;
+        diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
+        diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 8), _mm256_loadu_ps(vector2 + 8));
+        diff_3 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 16), _mm256_loadu_ps(vector2 + 16));
+        diff_4 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + 24), _mm256_loadu_ps(vector2 + 24));
+        sum_1 = _mm256_fmadd_ps(diff_1, diff_1, sum_1);
+        sum_2 = _mm256_fmadd_ps(diff_2, diff_2, sum_2);
+        sum_3 = _mm256_fmadd_ps(diff_3, diff_3, sum_3);
+        sum_4 = _mm256_fmadd_ps(diff_4, diff_4, sum_4);
+    }
+    __m256 sum_half_1 = _mm256_add_ps(sum_1, sum_2);
+    __m256 sum_half_2 = _mm256_add_ps(sum_3, sum_4);
+    __m256 sum = _mm256_add_ps(sum_half_1, sum_half_2);
+    return hsum256_ps_avx(sum);
+}
+
+} // namespace infinity
\ No newline at end of file

diff --git a/src/storage/knn_index/ann_ivf/some_simd_functions.cppm b/src/storage/knn_index/ann_ivf/some_simd_functions.cppm
index 481a7773d7..8609feedcd 100644
--- a/src/storage/knn_index/ann_ivf/some_simd_functions.cppm
+++ b/src/storage/knn_index/ann_ivf/some_simd_functions.cppm
@@ -59,7 +59,9 @@ float calc_256_sum_8(__m256 x) {
 
 #if defined(__AVX2__)
 
-export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
+export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension);
+
+export f32 L2Distance_simd_old(const f32 *vector1, const f32 *vector2, u32 dimension) {
     u32 i = 0;
     __m256 sum_1 = _mm256_setzero_ps();
     __m256 sum_2 = _mm256_setzero_ps();
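The exported `L2Distance_simd` above picks one of three code paths by dimension: a fully unrolled kernel for 128, a 32-floats-per-iteration loop for other multiples of 32, and a 16-floats-per-iteration loop with a scalar tail otherwise. A self-contained scalar model of just that dispatch (the path strings are descriptive, not project API; dimensions below 16 additionally fall back to the plain scalar loop inside the 16-float path):

```cpp
#include <cstdint>
#include <cstdio>

// Reproduces only which path a given dimension takes, not the arithmetic.
const char *L2DispatchPath(uint32_t dim) {
    if (dim == 128) return "fully unrolled 128-dim kernel";
    if (dim % 32 == 0) return "32-floats-per-iteration loop";
    return "16-floats-per-iteration loop + scalar tail";
}

int main() {
    for (uint32_t dim : {128u, 96u, 100u, 7u}) {
        std::printf("dim=%u -> %s\n", dim, L2DispatchPath(dim));
    }
}
```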
diff --git a/test/sql/explain/explain_fusion.slt b/test/sql/explain/explain_fusion.slt
index 91267d2a23..f2be3c5b50 100644
--- a/test/sql/explain/explain_fusion.slt
+++ b/test/sql/explain/explain_fusion.slt
@@ -21,19 +21,16 @@ PROJECT (5)
  - filter for secondary index: None
  - filter except secondary index: 10 > CAST(num (#1) AS BigInt)
  - output columns: [title, num, __score, __rowid]
--> MERGE KNN (7)
+-> KNN SCAN (3)
+ - table name: explain_fusion(default_db.explain_fusion)
  - table index: #5
+ - embedding info: vec
+ - element type: FLOAT32
+ - dimension: 4
+ - distance type: L2
+ - query embedding: [0,-10,0,0.7]
+ - filter: 10 > CAST(num (#1) AS BigInt)
  - output columns: [title, num, __score, __rowid]
- -> KNN SCAN (3)
-  - table name: explain_fusion(default_db.explain_fusion)
-  - table index: #5
-  - embedding info: vec
-  - element type: FLOAT32
-  - dimension: 4
-  - distance type: L2
-  - query embedding: [0,-10,0,0.7]
-  - filter: 10 > CAST(num (#1) AS BigInt)
-  - output columns: [title, num, __score, __rowid]

query I
EXPLAIN SELECT title FROM explain_fusion SEARCH MATCH TENSOR (t, [0.0, -10.0, 0.0, 0.7, 9.2, 45.6, -55.8, 3.5], 'float', 'maxsim') WHERE 10 > num;