
Update flat knn scan scheduler (#1367)
### What problem does this PR solve?

Update the flat KNN scan scheduler:
- Use at most 3 CPUs for brute-force block scans instead of all available CPUs (see the sketch below).
- Improve flat KNN scan performance.
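The parallelism cap works by thresholding on the number of blocks to scan. A minimal standalone sketch of the mapping, with illustrative names (the committed logic lives in `PhysicalKnnScan::InitBlockParallelOption` / `BlockScanTaskCount` below):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Sketch of the new scheduling rule, assuming the thresholds from this PR:
// fewer than 50 blocks -> 1 task, [50, 100) -> 2 tasks, 100 or more -> 3 tasks.
std::uint64_t BlockScanTaskCount(std::uint64_t block_cnt) {
    const std::vector<std::pair<std::uint64_t, std::uint64_t>> options{{50, 2}, {100, 3}};
    std::uint64_t task_n = 1; // default: a single brute-force scan task
    for (const auto &[threshold, job_n] : options) {
        if (block_cnt < threshold) {
            break; // thresholds ascend, so no later entry can apply
        }
        task_n = job_n;
    }
    return task_n;
}
```

So a table with 30 blocks gets 1 scan task, 80 blocks get 2, and anything from 100 up is capped at 3.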

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
- [x] Performance Improvement
yangzq50 authored Jun 21, 2024
1 parent 331e4d7 commit c4c149a
Showing 10 changed files with 256 additions and 60 deletions.
2 changes: 1 addition & 1 deletion python/benchmark/legacy_benchmark/remote_benchmark_knn.py
@@ -220,7 +220,7 @@ def one_thread(rounds, query_path, ground_truth_path, table_name):
results.append(f"Recall@10: {recall_10}")
results.append(f"Recall@100: {recall_100}")

-        conn.disconnect()
+    conn.disconnect()

for result in results:
print(result)
90 changes: 60 additions & 30 deletions src/executor/operator/physical_scan/physical_knn_scan.cpp
@@ -141,6 +141,29 @@ void ReadDataBlock(DataBlock *output,

void PhysicalKnnScan::Init() {}

+void PhysicalKnnScan::InitBlockParallelOption() {
+    // TODO: Set brute force block parallel option
+    // 0. [0, 50), 1 thread
+    // 1. [50, 100), 2 threads
+    block_parallel_options_.emplace_back(50, 2);
+    // 2. [100, +Inf), 3 threads
+    block_parallel_options_.emplace_back(100, 3);
+}
+
+SizeT PhysicalKnnScan::BlockScanTaskCount() const {
+    const u32 block_cnt = block_column_entries_size_;
+    SizeT brute_task_n = 1;
+    for (const auto &[block_n, job_n] : block_parallel_options_) {
+        if (block_cnt < block_n) {
+            break;
+        }
+        brute_task_n = job_n;
+    }
+    return brute_task_n;
+}
+
+SizeT PhysicalKnnScan::TaskletCount() { return BlockScanTaskCount() + index_entries_size_; }

bool PhysicalKnnScan::Execute(QueryContext *query_context, OperatorState *operator_state) {
auto *knn_scan_operator_state = static_cast<KnnScanOperatorState *>(operator_state);
auto elem_type = knn_scan_operator_state->knn_scan_function_data_->knn_scan_shared_data_->elem_type_;
@@ -182,6 +205,7 @@ String PhysicalKnnScan::TableAlias() const { return base_table_ref_->alias_; }
Vector<SizeT> &PhysicalKnnScan::ColumnIDs() const { return base_table_ref_->column_ids_; }

void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: return base entry vector
+    InitBlockParallelOption(); // PlanWithIndex() will be called in physical planner
Txn *txn = query_context->GetTxn();
TransactionID txn_id = txn->TxnID();
TxnTimeStamp begin_ts = txn->BeginTS();
@@ -237,7 +261,9 @@ void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: retu
}
}
}
LOG_TRACE(fmt::format("KnnScan: brute force task: {}, index task: {}", block_column_entries_->size(), index_entries_->size()));
block_column_entries_size_ = block_column_entries_->size();
index_entries_size_ = index_entries_->size();
LOG_TRACE(fmt::format("KnnScan: brute force task: {}, index task: {}", block_column_entries_size_, index_entries_size_));
}

SizeT PhysicalKnnScan::BlockEntryCount() const { return base_table_ref_->block_index_->BlockCount(); }
@@ -259,39 +285,43 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
auto merge_heap = static_cast<MergeKnn<DataType, C> *>(knn_scan_function_data->merge_knn_base_.get());
auto query = static_cast<const DataType *>(knn_scan_shared_data->query_embedding_);

-    SizeT index_task_n = knn_scan_shared_data->index_entries_->size();
-    SizeT brute_task_n = knn_scan_shared_data->block_column_entries_->size();
+    const SizeT index_task_n = knn_scan_shared_data->index_entries_->size();
+    const SizeT brute_task_n = knn_scan_shared_data->block_column_entries_->size();
BlockIndex *block_index = knn_scan_shared_data->table_ref_->block_index_.get();

-    if (u64 block_column_idx = knn_scan_shared_data->current_block_idx_++; block_column_idx < brute_task_n) {
+    if (u64 block_column_idx =
+            knn_scan_function_data->execute_block_scan_job_ ? knn_scan_shared_data->current_block_idx_++ : std::numeric_limits<u64>::max();
+        block_column_idx < brute_task_n) {
LOG_TRACE(fmt::format("KnnScan: {} brute force {}/{}", knn_scan_function_data->task_id_, block_column_idx + 1, brute_task_n));
// brute force
-        BlockColumnEntry *block_column_entry = knn_scan_shared_data->block_column_entries_->at(block_column_idx);
-        const BlockEntry *block_entry = block_column_entry->block_entry();
-        const auto block_id = block_entry->block_id();
-        const SegmentID segment_id = block_entry->GetSegmentEntry()->segment_id();
-        const auto row_count = block_entry->row_count();
-
-        Bitmask bitmask;
-        if (this->CalculateFilterBitmask(segment_id, block_id, row_count, bitmask)) {
-            LOG_TRACE(fmt::format("KnnScan: {} brute force {}/{} not skipped after common_query_filter",
-                                  knn_scan_function_data->task_id_,
-                                  block_column_idx + 1,
-                                  brute_task_n));
-            block_entry->SetDeleteBitmask(begin_ts, bitmask);
-            BufferManager *buffer_mgr = query_context->storage()->buffer_manager();
-            ColumnVector column_vector = block_column_entry->GetColumnVector(buffer_mgr);
-
-            auto data = reinterpret_cast<const DataType *>(column_vector.data());
-            merge_heap->Search(query,
-                               data,
-                               knn_scan_shared_data->dimension_,
-                               dist_func->dist_func_,
-                               row_count,
-                               block_entry->segment_id(),
-                               block_entry->block_id(),
-                               bitmask);
-        }
+        // TODO: now will try to finish all block scan jobs in the task
+        do {
+            BlockColumnEntry *block_column_entry = knn_scan_shared_data->block_column_entries_->at(block_column_idx);
+            const BlockEntry *block_entry = block_column_entry->block_entry();
+            const auto block_id = block_entry->block_id();
+            const SegmentID segment_id = block_entry->GetSegmentEntry()->segment_id();
+            const auto row_count = block_entry->row_count();
+            Bitmask bitmask;
+            if (this->CalculateFilterBitmask(segment_id, block_id, row_count, bitmask)) {
+                // LOG_TRACE(fmt::format("KnnScan: {} brute force {}/{} not skipped after common_query_filter",
+                //                       knn_scan_function_data->task_id_,
+                //                       block_column_idx + 1,
+                //                       brute_task_n));
+                block_entry->SetDeleteBitmask(begin_ts, bitmask);
+                BufferManager *buffer_mgr = query_context->storage()->buffer_manager();
+                ColumnVector column_vector = block_column_entry->GetColumnVector(buffer_mgr);
+                auto data = reinterpret_cast<const DataType *>(column_vector.data());
+                merge_heap->Search(query,
+                                   data,
+                                   knn_scan_shared_data->dimension_,
+                                   dist_func->dist_func_,
+                                   row_count,
+                                   segment_id,
+                                   block_id,
+                                   bitmask);
+            }
+            block_column_idx = knn_scan_shared_data->current_block_idx_++;
+        } while (block_column_idx < brute_task_n);
} else if (u64 index_idx = knn_scan_shared_data->current_index_idx_++; index_idx < index_task_n) {
LOG_TRACE(fmt::format("KnnScan: {} index {}/{}", knn_scan_function_data->task_id_, index_idx + 1, index_task_n));
// with index
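The behavioral change above: a brute-force task no longer handles exactly one block per `ExecuteInternal` call. Each of the few block-scan tasks keeps claiming indices from the shared counter until all blocks are consumed. A minimal sketch of that claim loop, with a bare `std::atomic` standing in for `current_block_idx_` (names and block payloads are illustrative):

```cpp
#include <atomic>
#include <cstdint>
#include <vector>

// Each worker atomically claims block indices until the queue is drained,
// mirroring the do-while over current_block_idx_ in ExecuteInternal.
void BlockScanTask(std::atomic<std::uint64_t> &next_block, const std::vector<int> &blocks) {
    std::uint64_t idx = next_block.fetch_add(1); // claim the first block
    while (idx < blocks.size()) {
        // ... scan blocks[idx] against the query, merge hits into the result heap ...
        idx = next_block.fetch_add(1); // claim the next unscanned block
    }
}
```

Because every task pulls from the same counter, the 1-3 scheduled tasks naturally load-balance across blocks of uneven cost.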
9 changes: 8 additions & 1 deletion src/executor/operator/physical_scan/physical_knn_scan.cppm
@@ -73,7 +73,9 @@ public:

void PlanWithIndex(QueryContext *query_context);

-    SizeT TaskletCount() override { return block_column_entries_->size() + index_entries_->size(); }
+    SizeT BlockScanTaskCount() const;
+
+    SizeT TaskletCount() override;

void FillingTableRefs(HashMap<SizeT, SharedPtr<BaseTableRef>> &table_refs) override {
table_refs.insert({base_table_ref_->table_index_, base_table_ref_});
@@ -88,10 +90,15 @@ public:
SharedPtr<Vector<SharedPtr<DataType>>> output_types_{};
u64 knn_table_index_{};

+    Vector<Pair<u32, u32>> block_parallel_options_;
+    u32 block_column_entries_size_ = 0; // need this value because block_column_entries_ will be moved into KnnScanSharedData
+    u32 index_entries_size_ = 0;
UniquePtr<Vector<BlockColumnEntry *>> block_column_entries_{};
UniquePtr<Vector<SegmentIndexEntry *>> index_entries_{};

private:
+    void InitBlockParallelOption();
+
template <typename DataType, template <typename, typename> typename C>
void ExecuteInternal(QueryContext *query_context, KnnScanOperatorState *operator_state);
};
4 changes: 2 additions & 2 deletions src/function/table/knn_scan_data.cpp
@@ -69,8 +69,8 @@ KnnDistance1<f32>::KnnDistance1(KnnDistanceType dist_type) {

// --------------------------------------------

-KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx)
-    : knn_scan_shared_data_(shared_data), task_id_(current_parallel_idx) {
+KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx, bool execute_block_scan_job)
+    : knn_scan_shared_data_(shared_data), task_id_(current_parallel_idx), execute_block_scan_job_(execute_block_scan_job) {
switch (knn_scan_shared_data_->elem_type_) {
case EmbeddingDataType::kElemFloat: {
Init<f32>();
3 changes: 2 additions & 1 deletion src/function/table/knn_scan_data.cppm
@@ -111,7 +111,7 @@ KnnDistance1<f32>::KnnDistance1(KnnDistanceType dist_type);

export class KnnScanFunctionData final : public TableFunctionData {
public:
-    KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx);
+    KnnScanFunctionData(KnnScanSharedData *shared_data, u32 current_parallel_idx, bool execute_block_scan_job);

~KnnScanFunctionData() final = default;

@@ -122,6 +122,7 @@ private:
public:
KnnScanSharedData *knn_scan_shared_data_;
const u32 task_id_;
+    bool execute_block_scan_job_ = false;

UniquePtr<MergeKnnBase> merge_knn_base_{};
UniquePtr<KnnDistanceBase1> knn_distance_{};
8 changes: 5 additions & 3 deletions src/scheduler/fragment_context.cpp
@@ -158,18 +158,20 @@ UniquePtr<OperatorState> MakeKnnScanState(PhysicalKnnScan *physical_knn_scan, Fr

UniquePtr<OperatorState> operator_state = MakeUnique<KnnScanOperatorState>();
KnnScanOperatorState *knn_scan_op_state_ptr = (KnnScanOperatorState *)(operator_state.get());

+    const bool execute_block_scan_job = task->TaskID() < static_cast<i64>(physical_knn_scan->BlockScanTaskCount());
switch (fragment_ctx->ContextType()) {
case FragmentType::kSerialMaterialize: {
SerialMaterializedFragmentCtx *serial_materialize_fragment_ctx = static_cast<SerialMaterializedFragmentCtx *>(fragment_ctx);
knn_scan_op_state_ptr->knn_scan_function_data_ =
-            MakeUnique<KnnScanFunctionData>(serial_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID());
+            MakeUnique<KnnScanFunctionData>(serial_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID(), execute_block_scan_job);
break;
}
case FragmentType::kParallelMaterialize: {
ParallelMaterializedFragmentCtx *parallel_materialize_fragment_ctx = static_cast<ParallelMaterializedFragmentCtx *>(fragment_ctx);
knn_scan_op_state_ptr->knn_scan_function_data_ =
-            MakeUnique<KnnScanFunctionData>(parallel_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID());
+            MakeUnique<KnnScanFunctionData>(parallel_materialize_fragment_ctx->knn_scan_shared_data_.get(),
+                                            task->TaskID(),
+                                            execute_block_scan_job);
break;
}
default: {
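For context on the flag above: since `TaskletCount()` is now `BlockScanTaskCount() + index_entries_size_`, the first `BlockScanTaskCount()` task IDs become block-scan workers and the remaining IDs each take an index entry. A hedged sketch of that partition, assuming task IDs run from 0 to `TaskletCount() - 1` (names are illustrative):

```cpp
#include <cstdint>

enum class TaskRole { kBlockScan, kIndexScan };

// Sketch of the role split implied by MakeKnnScanState: the first
// block_scan_task_count tasklets share the brute-force block queue,
// and every later tasklet handles one index entry.
TaskRole RoleForTask(std::int64_t task_id, std::int64_t block_scan_task_count) {
    return task_id < block_scan_task_count ? TaskRole::kBlockScan : TaskRole::kIndexScan;
}
```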
25 changes: 15 additions & 10 deletions src/scheduler/task_scheduler.cpp
@@ -46,27 +46,32 @@ namespace infinity {
TaskScheduler::TaskScheduler(Config *config_ptr) { Init(config_ptr); }

void TaskScheduler::Init(Config *config_ptr) {
-    worker_count_ = config_ptr->CPULimit();
+    const u64 cpu_count = Thread::hardware_concurrency();
+    const u64 config_cpu_limit = config_ptr->CPULimit();
+    worker_count_ = std::min(cpu_count, config_cpu_limit);
worker_array_.reserve(worker_count_);
worker_workloads_.resize(worker_count_);
-    u64 cpu_count = Thread::hardware_concurrency();

-    u64 cpu_select_step = cpu_count / worker_count_;
-    if (cpu_select_step >= 2) {
-        cpu_select_step = 2;
-    } else {
-        cpu_select_step = 1;
-    }
+    Vector<u64> cpu_id_vec;
+    cpu_id_vec.reserve(cpu_count);
+    // even cpus first
+    for (u64 cpu_id = 0; cpu_id < cpu_count; cpu_id += 2) {
+        cpu_id_vec.push_back(cpu_id);
+    }
+    // then add odd cpus
+    for (u64 cpu_id = 1; cpu_id < cpu_count; cpu_id += 2) {
+        cpu_id_vec.push_back(cpu_id);
+    }

-    for (u64 cpu_id = 0, worker_id = 0; worker_id < worker_count_; ++worker_id) {
+    for (u64 worker_id = 0; worker_id < worker_count_; ++worker_id) {
+        const u64 cpu_id = cpu_id_vec[worker_id];
UniquePtr<FragmentTaskBlockQueue> worker_queue = MakeUnique<FragmentTaskBlockQueue>();
UniquePtr<Thread> worker_thread = MakeUnique<Thread>(&TaskScheduler::WorkerLoop, this, worker_queue.get(), worker_id);
// Pin the thread to specific cpu
-        ThreadUtil::pin(*worker_thread, cpu_id % cpu_count);
+        ThreadUtil::pin(*worker_thread, cpu_id);

worker_array_.emplace_back(cpu_id, std::move(worker_queue), std::move(worker_thread));
worker_workloads_[worker_id] = 0;
-        cpu_id += cpu_select_step;
}

if (worker_array_.empty()) {
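The worker count is now clamped to `min(hardware_concurrency, CPULimit())`, and pin targets are laid out even IDs first, then odd. On common Linux topologies even and odd logical CPU IDs land on different hyperthread siblings, so this presumably spreads workers across physical cores before doubling up — an inference about the motivation, not something the PR states. A minimal sketch of the ordering:

```cpp
#include <cstdint>
#include <vector>

// Build the pinning order used above: even CPU IDs first, then odd.
std::vector<std::uint64_t> PinningOrder(std::uint64_t cpu_count) {
    std::vector<std::uint64_t> order;
    order.reserve(cpu_count);
    for (std::uint64_t id = 0; id < cpu_count; id += 2) { // 0, 2, 4, ...
        order.push_back(id);
    }
    for (std::uint64_t id = 1; id < cpu_count; id += 2) { // 1, 3, 5, ...
        order.push_back(id);
    }
    return order; // worker i is pinned to order[i] for i < worker_count_
}
```

With `cpu_count = 8` this yields `{0, 2, 4, 6, 1, 3, 5, 7}`, so three workers land on CPUs 0, 2, and 4.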