diff --git a/src/common/analyzer/analyzer.cppm b/src/common/analyzer/analyzer.cppm index 2166fafc5e..039130d6aa 100644 --- a/src/common/analyzer/analyzer.cppm +++ b/src/common/analyzer/analyzer.cppm @@ -49,11 +49,23 @@ public: } protected: - typedef void (*HookType)(void *data, const char *text, const u32 len, const u32 offset, const u32 end_offset, const bool is_special_char); + typedef void (*HookType)(void *data, + const char *text, + const u32 len, + const u32 offset, + const u32 end_offset, + const bool is_special_char, + const u16 payload); virtual int AnalyzeImpl(const Term &input, void *data, HookType func) { return -1; } - static void AppendTermList(void *data, const char *text, const u32 len, const u32 offset, const u32 end_offset, const bool is_special_char) { + static void AppendTermList(void *data, + const char *text, + const u32 len, + const u32 offset, + const u32 end_offset, + const bool is_special_char, + const u16 payload) { void **parameters = (void **)data; TermList *output = (TermList *)parameters[0]; Analyzer *analyzer = (Analyzer *)parameters[1]; @@ -62,9 +74,9 @@ protected: return; if (is_special_char && analyzer->convert_to_placeholder_) { if (output->empty() == true || output->back().text_.compare(PLACE_HOLDER) != 0) - output->Add(PLACE_HOLDER.c_str(), PLACE_HOLDER.length(), offset, end_offset); + output->Add(PLACE_HOLDER.c_str(), PLACE_HOLDER.length(), offset, end_offset, payload); } else { - output->Add(text, len, offset, end_offset); + output->Add(text, len, offset, end_offset, payload); } }
diff --git a/src/common/analyzer/analyzer_pool.cpp b/src/common/analyzer/analyzer_pool.cpp index 1dee2c8039..2ada50741a 100644 --- a/src/common/analyzer/analyzer_pool.cpp +++ b/src/common/analyzer/analyzer_pool.cpp @@ -34,6 +34,7 @@ import ngram_analyzer; import rag_analyzer; import whitespace_analyzer; import ik_analyzer; +import rank_features_analyzer; import logger; namespace infinity { @@ -330,6 +331,9 @@ Tuple<UniquePtr<Analyzer>, Status> AnalyzerPool::GetAnalyzer(const std::string_v } return {MakeUnique<WhitespaceAnalyzer>(name.substr(suffix_pos + 1)), Status::OK()}; } + case Str2Int(RANKFEATURES.data()): { + return {MakeUnique<RankFeaturesAnalyzer>(), Status::OK()}; + } default: { if (std::filesystem::is_regular_file(name)) { // Suppose it is a customized Python script analyzer
diff --git a/src/common/analyzer/analyzer_pool.cppm b/src/common/analyzer/analyzer_pool.cppm index 804783cfe7..34c309e77d 100644 --- a/src/common/analyzer/analyzer_pool.cppm +++ b/src/common/analyzer/analyzer_pool.cppm @@ -45,6 +45,7 @@ public: static constexpr std::string_view IK = "ik"; static constexpr std::string_view KEYWORD = "keyword"; static constexpr std::string_view WHITESPACE = "whitespace"; + static constexpr std::string_view RANKFEATURES = "rankfeatures"; private: CacheType cache_{};
diff --git a/src/common/analyzer/common_analyzer.cpp b/src/common/analyzer/common_analyzer.cpp index 8ffcbf6261..1ed5721d69 100644 --- a/src/common/analyzer/common_analyzer.cpp +++ b/src/common/analyzer/common_analyzer.cpp @@ -52,12 +52,12 @@ int CommonLanguageAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType if (is_index_) { if (IsSpecialChar()) { - func(data, token_, len_, offset_, end_offset_, true); + func(data, token_, len_, offset_, end_offset_, true, 0); temp_offset = offset_; continue; } if (is_raw_) { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; continue; } @@ -77,37 +77,37 @@ int CommonLanguageAnalyzer::AnalyzeImpl(const Term &input, void *data, 
HookType bool lowercase_is_different = memcmp(token_, lowercase_term, len_) != 0; if (stemming_term_str_size && stem_only_) { - func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false); + func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false, 0); temp_offset = offset_; } else if (stemming_term_str_size || (case_sensitive_ && contain_lower_ && lowercase_is_different)) { /// have more than one output if (case_sensitive_) { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } else { - func(data, lowercase_term, len_, offset_, end_offset_, false); + func(data, lowercase_term, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } if (stemming_term_str_size) { - func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false); + func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false, 0); temp_offset = offset_; } if (case_sensitive_ && contain_lower_ && lowercase_is_different) { - func(data, lowercase_term, len_, offset_, end_offset_, false); + func(data, lowercase_term, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } } else { /// have only one output if (case_sensitive_) { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } else { - func(data, lowercase_term, len_, offset_, end_offset_, false); + func(data, lowercase_term, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } } } else { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } } diff --git a/src/common/analyzer/ik/ik_analyzer.cpp b/src/common/analyzer/ik/ik_analyzer.cpp index 564e1ffb93..e58bc703a2 100644 --- a/src/common/analyzer/ik/ik_analyzer.cpp +++ b/src/common/analyzer/ik/ik_analyzer.cpp @@ -90,7 +90,7 @@ int IKAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { while ((lexeme = context_->GetNextLexeme()) != nullptr) { std::wstring text = lexeme->GetLexemeText(); String token = CharacterUtil::UTF16ToUTF8(text); - func(data, token.c_str(), token.size(), offset++, 0, false); + func(data, token.c_str(), token.size(), offset++, 0, false, 0); delete lexeme; }; return 0; diff --git a/src/common/analyzer/ngram_analyzer.cpp b/src/common/analyzer/ngram_analyzer.cpp index caf1a39d94..94e3a2673f 100644 --- a/src/common/analyzer/ngram_analyzer.cpp +++ b/src/common/analyzer/ngram_analyzer.cpp @@ -59,7 +59,7 @@ int NGramAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { while (cur < len && NextInString(input.text_.c_str(), len, &cur, &token_start, &token_length)) { if (token_length == 0) continue; - func(data, input.text_.c_str() + token_start, token_length, offset, offset + token_length, false); + func(data, input.text_.c_str() + token_start, token_length, offset, offset + token_length, false, 0); offset++; } diff --git a/src/common/analyzer/rag_analyzer.cpp b/src/common/analyzer/rag_analyzer.cpp index 7aa384ae81..4e900e603c 100644 --- a/src/common/analyzer/rag_analyzer.cpp +++ b/src/common/analyzer/rag_analyzer.cpp @@ -1438,7 +1438,7 @@ int RAGAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { Split(output, blank_pattern_, tokens); unsigned offset = 0; for (auto &t : tokens) { - func(data, t.c_str(), t.size(), offset++, 0, false); + func(data, t.c_str(), t.size(), offset++, 0, false, 0); } 
return 0; }
diff --git a/src/common/analyzer/rank_features_analyzer.cpp b/src/common/analyzer/rank_features_analyzer.cpp new file mode 100644 index 0000000000..89b1ceb8e7 --- /dev/null +++ b/src/common/analyzer/rank_features_analyzer.cpp @@ -0,0 +1,52 @@ +// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +#include <string> +module rank_features_analyzer; +import stl; +import term; +import analyzer; +import third_party; + +namespace infinity { + +u16 FloatToU16(float value) { + if (value < 0.0f) + value = 0.0f; + if (value > 65535.0f) + value = 65535.0f; + return static_cast<u16>(value); +} + +int RankFeaturesAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { + nlohmann::json line_json = nlohmann::json::parse(input.text_); + u32 offset = 0; + for (auto iter = line_json.begin(); iter != line_json.end(); ++iter) { + String key = iter.key(); + String value = iter.value(); + float target = 0; + try { + target = std::stof(value); + } catch (const std::exception &e) { + } + u16 weight = FloatToU16(target); + func(data, key.data(), key.size(), offset++, 0, false, weight); + } + + return 0; +} + +} // namespace infinity
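Each value fed to the new analyzer is expected to be a JSON object mapping feature names to numeric strings; every key becomes one indexed term, and the clamped, truncated weight travels with it as the 16-bit payload. A minimal usage sketch, not part of the patch — it assumes the public Analyzer::Analyze(const Term &, TermList &) entry point that routes AnalyzeImpl through AppendTermList:

    // Sketch: tokenizing one rank-features document (input shape assumed).
    RankFeaturesAnalyzer analyzer;
    TermList terms;
    Term input;
    input.text_ = R"({"politics": "3.5", "sports": "1.2"})";
    analyzer.Analyze(input, terms);
    // terms[0]: text_ == "politics", payload_ == 3  (3.5f truncated by static_cast<u16>)
    // terms[1]: text_ == "sports",   payload_ == 1
    // A value std::stof cannot parse is silently mapped to payload 0.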
diff --git a/src/common/analyzer/rank_features_analyzer.cppm b/src/common/analyzer/rank_features_analyzer.cppm new file mode 100644 index 0000000000..79237ea0f5 --- /dev/null +++ b/src/common/analyzer/rank_features_analyzer.cppm @@ -0,0 +1,35 @@ +// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +export module rank_features_analyzer; +import stl; +import term; +import analyzer; + +namespace infinity { + +export class RankFeaturesAnalyzer : public Analyzer { + String delimiters_{}; + +public: + RankFeaturesAnalyzer() = default; + ~RankFeaturesAnalyzer() override = default; + +protected: + int AnalyzeImpl(const Term &input, void *data, HookType func) override; +}; + +} // namespace infinity
diff --git a/src/common/analyzer/term.cppm b/src/common/analyzer/term.cppm index 3543f96ea2..ea556ce7c6 100644 --- a/src/common/analyzer/term.cppm +++ b/src/common/analyzer/term.cppm @@ -41,11 +41,12 @@ public: export class TermList : public Deque<Term> { public: - void Add(const char *text, const u32 len, const u32 offset, const u32 end_offset) { + void Add(const char *text, const u32 len, const u32 offset, const u32 end_offset, const u16 payload = 0) { push_back(global_temporary_); back().text_.assign(text, len); back().word_offset_ = offset; back().end_offset_ = end_offset; + back().payload_ = payload; } void Add(cppjieba::Word &cut_word) { @@ -54,18 +55,20 @@ public: back().word_offset_ = cut_word.offset; } - void Add(const String &token, const u32 offset, const u32 end_offset) { + void Add(const String &token, const u32 offset, const u32 end_offset, const u16 payload = 0) { push_back(global_temporary_); back().text_ = token; back().word_offset_ = offset; back().end_offset_ = end_offset; + back().payload_ = payload; } - void Add(String &token, const u32 offset, const u32 end_offset) { + void Add(String &token, const u32 offset, const u32 end_offset, const u16 payload = 0) { push_back(global_temporary_); std::swap(back().text_, token); back().word_offset_ = offset; back().end_offset_ = end_offset; + back().payload_ = payload; } private:
diff --git a/src/common/analyzer/whitespace_analyzer.cpp b/src/common/analyzer/whitespace_analyzer.cpp index 0bd02d5134..83a558b6a8 100644 --- a/src/common/analyzer/whitespace_analyzer.cpp +++ b/src/common/analyzer/whitespace_analyzer.cpp @@ -37,7 +37,7 @@ int WhitespaceAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func std::string t; u32 offset = 0; while (is >> t) { - func(data, t.data(), t.size(), offset++, 0, false); + func(data, t.data(), t.size(), offset++, 0, false, 0); } return 0; } else { @@ -49,11 +49,11 @@ int WhitespaceAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func while (search_start < input_text.size()) { const auto found = input_text.find_first_of(delimiters, search_start); if (found == std::string_view::npos) { - func(data, input_text.data() + search_start, input_text.size() - search_start, offset++, 0, false); + func(data, input_text.data() + search_start, input_text.size() - search_start, offset++, 0, false, 0); break; } if (found > search_start) { - func(data, input_text.data() + search_start, found - search_start, offset++, 0, false); + func(data, input_text.data() + search_start, found - search_start, offset++, 0, false, 0); } search_start = found + 1; }
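The TermList::Add overloads above take the payload as a trailing parameter defaulting to 0, so every pre-existing call site keeps compiling unchanged while payload-aware analyzers pass the weight explicitly. Illustrative calls against the signatures shown above:

    TermList terms;
    terms.Add("hello", 5, /*offset*/ 0, /*end_offset*/ 5);      // payload_ stays 0
    terms.Add("world", 5, /*offset*/ 1, /*end_offset*/ 11, 42); // explicit 16-bit weight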
diff --git a/src/executor/hash_table.cpp b/src/executor/hash_table.cpp index 7871ffda7a..d57d892109 100644 --- a/src/executor/hash_table.cpp +++ b/src/executor/hash_table.cpp @@ -16,109 +16,127 @@ module; #include import stl; - +import logical_type; import column_vector; - +import status; import infinity_exception; +import third_party; +import internal_types; module hash_table; -#if 0 namespace infinity { -void HashTable::Init(const Vector<DataType> &types) { - types_ = types; - SizeT type_count = types.size(); + +void HashTableBase::Init(Vector<SharedPtr<DataType>> types) { + types_ = std::move(types); + SizeT key_size = 0; + SizeT type_count = types_.size(); for (SizeT idx = 0; idx < type_count; ++idx) { - const DataType &data_type = types[idx]; + const DataType &data_type = *types_[idx]; switch (data_type.type()) { - case kBoolean: - case kTinyInt: - case kSmallInt: - case kInteger: - case kBigInt: - case kHugeInt: - case kFloat: - case kDouble: - case kDecimal: - case kVarchar: - case kDate: - case kTime: - case kDateTime: - case kMixed: { + case LogicalType::kBoolean: + case LogicalType::kTinyInt: + case LogicalType::kSmallInt: + case LogicalType::kInteger: + case LogicalType::kBigInt: + case LogicalType::kFloat: + case LogicalType::kDouble: + case LogicalType::kDate: + case LogicalType::kTime: + case LogicalType::kDateTime: + case LogicalType::kTimestamp: { + SizeT type_size = data_type.Size(); + key_size += type_size; break; // All these types can be hashed. } - case kTimestamp: - case kInterval: - case kArray: - case kTuple: - case kPoint: - case kLine: - case kLineSeg: - case kBox: -// case kPath: -// case kPolygon: - case kCircle: -// case kBitmap: - case kUuid: -// case kBlob: - case kEmbedding: - case kRowID: - case kNull: - case kMissing: - case kInvalid: { + case LogicalType::kVarchar: { + key_size = 0; + break; // Varchar can be hashed. + } + default: { RecoverableError(Status::NotSupport(fmt::format("Attempt to construct hash key for type: {}", data_type.ToString()))); } } - - SizeT type_size = data_type.Size(); - key_size_ += type_size; + if (key_size == 0) { + break; + } } - // Key layout: col1\0col2\0col3\0. - key_size_ += type_count; + if (key_size) { + // Key layout: col1\0col2\0col3\0. + key_size += type_count; + } + key_size_ = key_size; } -void HashTable::Append(const Vector<SharedPtr<ColumnVector>> &columns, SizeT block_id, SizeT row_count) { - UniquePtr<char[]> hash_key = MakeUnique<char[]>(key_size_); +void HashTableBase::GetHashKey(const Vector<SharedPtr<ColumnVector>> &columns, SizeT row_id, String &hash_key) const { SizeT column_count = columns.size(); - for (SizeT row_id = 0; row_id < row_count; ++row_id) { - std::memset(hash_key.get(), 0, key_size_); - SizeT offset = 0; + hash_key.clear(); + bool has_null = false; + for (SizeT column_id = 0; column_id < column_count; ++column_id) { + if (!columns[column_id]->nulls_ptr_->IsTrue(row_id)) { + hash_key.append(2, '\0'); + has_null = true; + continue; + } - for (SizeT column_id = 0; column_id < column_count; ++column_id) { - char *target_ptr = hash_key.get() + offset; - if (!columns[column_id]->nulls_ptr_->IsTrue(row_id)) { - *(target_ptr) = '\0'; - offset += 2; - continue; - } + const DataType &data_type = *types_[column_id]; - DataType &data_type = types_[column_id]; - if (data_type.type() == kMixed) { - // Only float/boolean/integer/string can be built as hash key. 
Array/Tuple will be treated as null - RecoverableError(Status::NotSupport("Attempt to construct hash key for heterogeneous type")); - } + if (data_type.type() == LogicalType::kVarchar) { + Span<const char> text = columns[column_id]->GetVarchar(row_id); + hash_key.append(text.begin(), text.end()); + } else { + SizeT type_size = types_[column_id]->Size(); + Span<const char> binary(reinterpret_cast<const char *>(columns[column_id]->data() + type_size * row_id), type_size); + hash_key.append(binary.begin(), binary.end()); + } + hash_key += '\0'; + } + if (!has_null && key_size_ && hash_key.size() != key_size_) { + UnrecoverableError(fmt::format("Hash key size mismatch: {} vs {}", hash_key.size(), key_size_)); + } +} - if (data_type.type() == kVarchar) { - VarcharT *vchar_ptr = &((VarcharT *)(columns[column_id]->data_ptr_))[row_id]; - if (vchar_ptr->IsInlined()) { - std::memcpy(target_ptr, vchar_ptr->prefix, vchar_ptr->length); - } else { - std::memcpy(target_ptr, vchar_ptr->ptr, vchar_ptr->length); - } - offset += (vchar_ptr->length + 1); - } else { - SizeT type_size = types_[column_id].Size(); - std::memcpy(target_ptr, columns[column_id]->data_ptr_ + type_size * row_id, type_size); - offset += (type_size + 1); - } +void HashTable::Append(const Vector<SharedPtr<ColumnVector>> &columns, SizeT block_id, SizeT row_count) { + String hash_key; + if (key_size_) { + hash_key.reserve(key_size_); + } + for (SizeT row_id = 0; row_id < row_count; ++row_id) { + GetHashKey(columns, row_id, hash_key); + hash_table_[std::move(hash_key)][block_id].emplace_back(row_id); + } +} + +void MergeHashTable::Append(const Vector<SharedPtr<ColumnVector>> &columns, SizeT block_id, SizeT row_count) { + String hash_key; + if (key_size_) { + hash_key.reserve(key_size_); + } + for (SizeT row_id = 0; row_id < row_count; ++row_id) { + GetHashKey(columns, row_id, hash_key); + if (auto iter = hash_table_.find(hash_key); iter != hash_table_.end()) { + UnrecoverableError("Duplicate key in merge hash table"); + } else { + hash_table_.emplace_hint(iter, std::move(hash_key), Pair(block_id, row_id)); } + } +} - String key(hash_key.get(), key_size_); - hash_table_[key][block_id].emplace_back(row_id); +bool MergeHashTable::GetOrInsert(const Vector<SharedPtr<ColumnVector>> &columns, SizeT row_id, Pair<SizeT, SizeT> &block_row_id) { + String hash_key; + if (key_size_) { + hash_key.reserve(key_size_); + } + GetHashKey(columns, row_id, hash_key); + auto iter = hash_table_.find(hash_key); + if (iter == hash_table_.end()) { + hash_table_.emplace_hint(iter, std::move(hash_key), block_row_id); + return false; } + block_row_id = iter->second; + return true; } -} // namespace infinity -#endif \ No newline at end of file +} // namespace infinity \ No newline at end of file
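For reference, GetHashKey serializes one row as: the raw bytes of each non-null column (fixed-width binary for numeric types, the text bytes for varchar) followed by a '\0' separator, while a null column contributes the two-NUL marker instead; key_size_ is only a checkable expectation when the schema contains no varchar column, because Init() zeroes it upon kVarchar. A hand-built illustration of the expected key for a row (INTEGER 7, VARCHAR "ab"), assuming little-endian storage:

    // Illustration only: the byte string GetHashKey should build for (7, "ab").
    String expected;
    i32 v = 7;
    expected.append(reinterpret_cast<const char *>(&v), sizeof(v)); // 4 raw integer bytes
    expected += '\0';                                               // column separator
    expected += "ab";                                               // varchar bytes, no length prefix
    expected += '\0';                                               // column separator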
diff --git a/src/executor/hash_table.cppm b/src/executor/hash_table.cppm index ed1b41060c..13c58dcfcc 100644 --- a/src/executor/hash_table.cppm +++ b/src/executor/hash_table.cppm @@ -23,18 +23,37 @@ import data_type; namespace infinity { -export class HashTable { +class HashTableBase { public: - void Init(const Vector<DataType> &types); + bool Initialized() const { return !types_.empty(); } - void Append(const Vector<SharedPtr<ColumnVector>> &columns, SizeT block_id, SizeT row_count); + void Init(Vector<SharedPtr<DataType>> types); + + void GetHashKey(const Vector<SharedPtr<ColumnVector>> &columns, SizeT row_id, String &hash_key) const; public: - Vector<DataType> types_{}; + Vector<SharedPtr<DataType>> types_{}; SizeT key_size_{}; +}; +export class HashTable : public HashTableBase { +public: + void Append(const Vector<SharedPtr<ColumnVector>> &columns, SizeT block_id, SizeT row_count); + +public: // Key -> (block id -> row array) HashMap<String, HashMap<SizeT, Vector<SizeT>>> hash_table_{}; }; +export class MergeHashTable : public HashTableBase { +public: + void Append(const Vector<SharedPtr<ColumnVector>> &columns, SizeT block_id, SizeT row_count); + + bool GetOrInsert(const Vector<SharedPtr<ColumnVector>> &columns, SizeT row_id, Pair<SizeT, SizeT> &block_row_id); + +public: + // Key -> (block id, row id) + HashMap<String, Pair<SizeT, SizeT>> hash_table_{}; +}; + } // namespace infinity
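The intended calling pattern for MergeHashTable::GetOrInsert (used by PhysicalMergeAggregate::GroupByMergeAggregateExecute further down): the caller proposes the coordinates it would assign, and the return value says whether the group already exists. An illustrative fragment — groupby_types, groupby_columns, and the loop variables are assumed context:

    MergeHashTable ht;
    ht.Init(groupby_types); // Vector<SharedPtr<DataType>> describing the group-by key columns
    Pair<SizeT, SizeT> loc = {output_block_idx, output_row_idx}; // proposed slot
    if (!ht.GetOrInsert(groupby_columns, row_id, loc)) {
        // key unseen: loc was inserted as-is, append the row as a new group
    } else {
        // key exists: loc now holds the stored (block, row); merge aggregates into that slot
    }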
diff --git a/src/executor/operator/physical_aggregate.cpp b/src/executor/operator/physical_aggregate.cpp index 9cf0b9f804..243ecef7d9 100644 --- a/src/executor/operator/physical_aggregate.cpp +++ b/src/executor/operator/physical_aggregate.cpp @@ -51,12 +51,8 @@ bool PhysicalAggregate::Execute(QueryContext *query_context, OperatorState *oper OperatorState *prev_op_state = operator_state->prev_op_state_; auto *aggregate_operator_state = static_cast<AggregateOperatorState *>(operator_state); - // 1. Execute group-by expressions to generate unique key. - // ExpressionEvaluator groupby_executor; - // groupby_executor.Init(groups_); - - Vector<SharedPtr<ColumnDef>> groupby_columns; SizeT group_count = groups_.size(); + bool task_completed = prev_op_state->Complete(); if (group_count == 0) { // Aggregate without group by expression @@ -64,71 +60,80 @@ bool PhysicalAggregate::Execute(QueryContext *query_context, OperatorState *oper auto result = SimpleAggregateExecute(prev_op_state->data_block_array_, aggregate_operator_state->data_block_array_, aggregate_operator_state->states_, - prev_op_state->Complete()); + task_completed); prev_op_state->data_block_array_.clear(); if (prev_op_state->Complete()) { aggregate_operator_state->SetComplete(); } return result; } -#if 0 + + // 1. Execute group-by expressions to generate unique key. + Vector<SharedPtr<ColumnDef>> groupby_columns; groupby_columns.reserve(group_count); - Vector<DataType> types; - types.reserve(group_count); + Vector<SharedPtr<DataType>> groupby_types; + groupby_types.reserve(group_count); - for(i64 idx = 0; auto& expr: groups_) { - SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx, - MakeShared<DataType>(expr->Type()), - expr->Name(), - std::set<ConstraintType>()); + for (i64 idx = 0; auto &expr : groups_) { + SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx, MakeShared<DataType>(expr->Type()), expr->Name(), std::set<ConstraintType>()); groupby_columns.emplace_back(col_def); - types.emplace_back(expr->Type()); - ++ idx; + groupby_types.emplace_back(MakeShared<DataType>(expr->Type())); + ++idx; } - SharedPtr<TableDef> groupby_tabledef = TableDef::Make(MakeShared<String>("default_db"), MakeShared<String>("groupby"), groupby_columns); + SharedPtr<TableDef> groupby_tabledef = + TableDef::Make(MakeShared<String>("default_db"), MakeShared<String>("groupby"), MakeShared<String>(""), groupby_columns); SharedPtr<DataTable> groupby_table = DataTable::Make(groupby_tabledef, TableType::kIntermediate); - groupby_executor.Execute(input_table_, groupby_table); + // Prepare the expression states + Vector<SharedPtr<ExpressionState>> expr_states; + expr_states.reserve(group_count); + for (const auto &expr : groups_) { + // expression state + expr_states.emplace_back(ExpressionState::CreateState(expr)); + } + + SizeT input_block_count = prev_op_state->data_block_array_.size(); + for (SizeT block_idx = 0; block_idx < input_block_count; ++block_idx) { + DataBlock *input_data_block = prev_op_state->data_block_array_[block_idx].get(); + + groupby_table->data_blocks_.emplace_back(DataBlock::MakeUniquePtr()); + DataBlock *output_data_block = groupby_table->data_blocks_.back().get(); + output_data_block->Init(groupby_types, 1); + + ExpressionEvaluator groupby_executor; + groupby_executor.Init(input_data_block); + + for (SizeT expr_idx = 0; expr_idx < group_count; ++expr_idx) { + groupby_executor.Execute(groups_[expr_idx], expr_states[expr_idx], output_data_block->column_vectors[expr_idx]); + } + output_data_block->Finalize(); + } // 2. Use the unique key to get the row list of the same key. - hash_table_.Init(types); + HashTable &hash_table = aggregate_operator_state->hash_table_; + if (!hash_table.Initialized()) { + hash_table.Init(groupby_types); + } SizeT block_count = groupby_table->DataBlockCount(); - for(SizeT block_id = 0; block_id < block_count; ++ block_id) { - const SharedPtr<DataBlock> &block_ptr = groupby_table->GetDataBlockById(block_id); - hash_table_.Append(block_ptr->column_vectors, block_id, block_ptr->row_count()); + for (SizeT block_id = 0; block_id < block_count; ++block_id) { + const SharedPtr<DataBlock> &block_ptr = groupby_table->GetDataBlockById(block_id); + hash_table.Append(block_ptr->column_vectors, block_id, block_ptr->row_count()); } // 3. For-loop each aggregate function on each group-by bucket, to calculate the result according to the row list SharedPtr<DataTable> output_groupby_table = DataTable::Make(groupby_tabledef, TableType::kIntermediate); - GenerateGroupByResult(groupby_table, output_groupby_table); - + GenerateGroupByResult(groupby_table, output_groupby_table, hash_table); // input table after group by: each block belongs to one group. This is the prerequisite to execute aggregate functions. - SharedPtr<DataTable> grouped_input_table; - { - SizeT column_count = input_table_->ColumnCount(); - Vector<SharedPtr<ColumnDef>> columns; - columns.reserve(column_count); - for(SizeT idx = 0; idx < column_count; ++ idx) { - SharedPtr<DataType> col_type = input_table_->GetColumnTypeById(idx); - String col_name = input_table_->GetColumnNameById(idx); - - SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx, col_type, col_name, std::set<ConstraintType>()); - columns.emplace_back(col_def); - } - - SharedPtr<TableDef> table_def = TableDef::Make(MakeShared<String>("default_db"), MakeShared<String>("grouped_input"), columns); - - grouped_input_table = DataTable::Make(table_def, TableType::kGroupBy); - } - GroupByInputTable(input_table_, grouped_input_table); + Vector<UniquePtr<DataBlock>> grouped_input_datablocks; + GroupByInputTable(prev_op_state->data_block_array_, grouped_input_datablocks, hash_table); // generate output aggregate table SizeT aggregates_count = aggregates_.size(); - if(aggregates_count > 0) { + if (aggregates_count > 0) { SharedPtr<DataTable> output_aggregate_table{}; // Prepare the output table columns @@ -142,53 +147,47 @@ bool PhysicalAggregate::Execute(QueryContext *query_context, OperatorState *oper // Prepare the output block Vector<SharedPtr<DataType>> output_types; output_types.reserve(aggregates_count); + auto &agg_states = aggregate_operator_state->states_; - for(i64 idx = 0; auto& expr: aggregates_) { + AggregateFlag flag = aggregate_operator_state->data_block_array_.empty() + ? (!task_completed ? AggregateFlag::kUninitialized : AggregateFlag::kRunAndFinish) + : (!task_completed ? 
AggregateFlag::kRunning : AggregateFlag::kFinish); + for (i64 idx = 0; auto &expr : aggregates_) { // expression state - expr_states.emplace_back(ExpressionState::CreateState(expr)); + expr_states.emplace_back(ExpressionState::CreateState(std::static_pointer_cast<AggregateExpression>(expr), agg_states[idx].get(), flag)); SharedPtr<DataType> data_type = MakeShared<DataType>(expr->Type()); // column definition - SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx, - data_type, - expr->Name(), - std::set<ConstraintType>()); + SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx, data_type, expr->Name(), std::set<ConstraintType>()); aggregate_columns.emplace_back(col_def); // for output block output_types.emplace_back(data_type); - ++ idx; + ++idx; } // output aggregate table definition - SharedPtr<TableDef> aggregate_tabledef = TableDef::Make(MakeShared<String>("default_db"), - MakeShared<String>("aggregate"), - aggregate_columns); + SharedPtr<TableDef> aggregate_tabledef = + TableDef::Make(MakeShared<String>("default_db"), MakeShared<String>("aggregate"), MakeShared<String>(""), aggregate_columns); output_aggregate_table = DataTable::Make(aggregate_tabledef, TableType::kAggregate); // Loop blocks - HashMap<u64, SharedPtr<DataBlock>> block_map; - SizeT input_data_block_count = grouped_input_table->DataBlockCount(); - for(SizeT block_idx = 0; block_idx < input_data_block_count; ++ block_idx) { + SizeT input_data_block_count = grouped_input_datablocks.size(); + for (SizeT block_idx = 0; block_idx < input_data_block_count; ++block_idx) { SharedPtr<DataBlock> output_data_block = DataBlock::Make(); - output_data_block->Init(output_types); - Vector<SharedPtr<DataBlock>> input_blocks{grouped_input_table->GetDataBlockById(block_idx)}; -// block_map[groupby_index_] = block_ptr; -// block_map[input_table_index_] = block_ptr; + output_data_block->Init(output_types, 1); + DataBlock *input_block = grouped_input_datablocks[block_idx].get(); // Loop aggregate expression ExpressionEvaluator evaluator; - evaluator.Init(input_blocks); - for(SizeT expr_idx = 0; expr_idx < aggregates_count; ++ expr_idx) { - Vector<SharedPtr<ColumnVector>> blocks_column; - blocks_column.emplace_back(output_data_block->column_vectors[expr_idx]); - evaluator.Execute(aggregates_[expr_idx], - expr_states[expr_idx], - blocks_column); - if(blocks_column[0].get() != output_data_block->column_vectors[expr_idx].get()) { + evaluator.Init(input_block); + for (SizeT expr_idx = 0; expr_idx < aggregates_count; ++expr_idx) { + SharedPtr<ColumnVector> blocks_column = output_data_block->column_vectors[expr_idx]; + evaluator.Execute(aggregates_[expr_idx], expr_states[expr_idx], blocks_column); + if (blocks_column.get() != output_data_block->column_vectors[expr_idx].get()) { // column vector in blocks column might be changed to the column vector from column reference. // This check and assignment is to make sure the right column vector is assigned to output_data_block - output_data_block->column_vectors[expr_idx] = blocks_column[0]; + output_data_block->column_vectors[expr_idx] = blocks_column; } } @@ -201,33 +200,32 @@ bool PhysicalAggregate::Execute(QueryContext *query_context, OperatorState *oper } // 4. 
generate the result to output - this->output_ = output_groupby_table; -#endif + output_groupby_table->ShrinkBlocks(); + for (SizeT block_idx = 0; block_idx < output_groupby_table->DataBlockCount(); ++block_idx) { + SharedPtr<DataBlock> output_data_block = output_groupby_table->GetDataBlockById(block_idx); + aggregate_operator_state->data_block_array_.push_back(MakeUnique<DataBlock>(std::move(*output_data_block))); + } + + prev_op_state->data_block_array_.clear(); + if (prev_op_state->Complete()) { + aggregate_operator_state->SetComplete(); + } return true; } -void PhysicalAggregate::GroupByInputTable(const SharedPtr<DataTable> &input_table, SharedPtr<DataTable> &grouped_input_table) { - SizeT column_count = input_table->ColumnCount(); - +void PhysicalAggregate::GroupByInputTable(const Vector<UniquePtr<DataBlock>> &input_datablocks, + Vector<UniquePtr<DataBlock>> &output_datablocks, + HashTable &hash_table) { // 1. Get output table column types. - Vector<SharedPtr<DataType>> types; - types.reserve(column_count); - for (SizeT column_id = 0; column_id < column_count; ++column_id) { - SharedPtr<DataType> input_type = input_table->GetColumnTypeById(column_id); - SharedPtr<DataType> output_type = grouped_input_table->GetColumnTypeById(column_id); - if (*input_type != *output_type) { - Status status = Status::DataTypeMismatch(input_type->ToString(), output_type->ToString()); - RecoverableError(status); - } - types.emplace_back(input_type); - } + Vector<SharedPtr<DataType>> types = input_datablocks.front()->types(); + SizeT column_count = input_datablocks.front()->column_count(); // 2. Generate data blocks and append them into the output table according to the group-by hash table. - const Vector<SharedPtr<DataBlock>> &input_datablocks = input_table->data_blocks_; + // const Vector<SharedPtr<DataBlock>> &input_datablocks = input_table->data_blocks_; for (const auto &item : hash_table.hash_table_) { // 2.1 Each hash bucket will be inserted into one data block - SharedPtr<DataBlock> output_datablock = DataBlock::Make(); + UniquePtr<DataBlock> output_datablock = DataBlock::MakeUniquePtr(); SizeT datablock_size = 0; for (const auto &vec_pair : item.second) { datablock_size += vec_pair.second.size(); @@ -236,124 +234,36 @@ void PhysicalAggregate::GroupByInputTable(const SharedPtr<DataTable> &input_tabl output_datablock->Init(types, datablock_capacity); // Loop each block - SizeT output_row_idx = 0; + SizeT output_data_num = 0; for (const auto &vec_pair : item.second) { SizeT input_block_id = vec_pair.first; - // Loop each row of same block - for (const auto input_offset : vec_pair.second) { - - // For-loop each column - for (SizeT column_id = 0; column_id < column_count; ++column_id) { - switch (types[column_id]->type()) { - case LogicalType::kBoolean: { - ((BooleanT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((BooleanT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kTinyInt: { - ((TinyIntT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((TinyIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kSmallInt: { - ((SmallIntT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((SmallIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kInteger: { - ((IntegerT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((IntegerT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kBigInt: { 
- ((BigIntT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((BigIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kHugeInt: { - String error_message = "Not implement: HugeInt data shuffle"; - UnrecoverableError(error_message); - break; - } - case LogicalType::kFloat: { - ((FloatT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((FloatT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kDouble: { - ((DoubleT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((DoubleT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kDecimal: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kVarchar: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kDate: { - ((DateT *)(output_datablock->column_vectors[column_id]->data()))[output_row_idx] = - ((DateT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kTime: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kDateTime: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kTimestamp: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kInterval: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kMixed: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - default: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - } - } + // For-loop each column + for (SizeT column_id = 0; column_id < column_count; ++column_id) { + // Loop each row of same block + for (const auto input_offset : vec_pair.second) { - ++output_row_idx; + output_datablock->column_vectors[column_id]->AppendWith(*input_datablocks[input_block_id]->column_vectors[column_id], + input_offset, + 1); + ++output_data_num; + } } } - if (output_row_idx != datablock_size) { - String error_message = fmt::format("Expected block size: {}, but only copied data size: {}", datablock_size, output_row_idx); + if (output_data_num != datablock_size * column_count) { + String error_message = + fmt::format("Expected block size: {}, but only copied data size: {}", datablock_size * column_count, output_data_num); UnrecoverableError(error_message); break; } - for (SizeT column_id = 0; column_id < column_count; ++column_id) { - output_datablock->column_vectors[column_id]->Finalize(datablock_size); - } - output_datablock->Finalize(); - grouped_input_table->Append(output_datablock); + output_datablocks.push_back(std::move(output_datablock)); } } -void PhysicalAggregate::GenerateGroupByResult(const SharedPtr<DataTable> &input_table, SharedPtr<DataTable> &output_table) { +void PhysicalAggregate::GenerateGroupByResult(const SharedPtr<DataTable> &input_table, SharedPtr<DataTable> &output_table, HashTable &hash_table) { SizeT column_count = input_table->ColumnCount(); Vector<SharedPtr<DataType>> types; @@ -370,12 +280,10 @@ void 
PhysicalAggregate::GenerateGroupByResult(const SharedPtr<DataTable> &input_ SharedPtr<DataBlock> output_datablock = nullptr; const Vector<SharedPtr<DataBlock>> &input_datablocks = input_table->data_blocks_; -// SizeT row_count = hash_table_.hash_table_.size(); -#if 1 - for (SizeT block_row_idx = 0; const auto &item : hash_table_.hash_table_) { + for (const auto &item : hash_table.hash_table_) { // Each hash bucket will generate one data block. output_datablock = DataBlock::Make(); - output_datablock->Init(types); + output_datablock->Init(types, 1); // Only get the first row(block id and row offset of the block) of the bucket SizeT input_block_id = item.second.begin()->first; @@ -383,209 +291,12 @@ void PhysicalAggregate::GenerateGroupByResult(const SharedPtr<DataTable> &input_ // Only the first position of the column vector has value. for (SizeT column_id = 0; column_id < column_count; ++column_id) { - switch (types[column_id]->type()) { - case LogicalType::kBoolean: { - ((BooleanT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((BooleanT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kTinyInt: { - ((TinyIntT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((TinyIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kSmallInt: { - ((SmallIntT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((SmallIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kInteger: { - ((IntegerT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((IntegerT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kBigInt: { - ((BigIntT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((BigIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kHugeInt: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kFloat: { - ((FloatT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((FloatT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kDouble: { - ((DoubleT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((DoubleT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kDecimal: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kVarchar: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kDate: { - ((DateT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((DateT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kTime: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kDateTime: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kTimestamp: { - String error_message = "Not implement: data shuffle."; - 
UnrecoverableError(error_message); - break; - } - case LogicalType::kInterval: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - case LogicalType::kMixed: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - default: { - String error_message = "Not implement: data shuffle."; - UnrecoverableError(error_message); - break; - } - } - - output_datablock->column_vectors[column_id]->Finalize(block_row_idx + 1); + output_datablock->column_vectors[column_id]->AppendWith(*input_datablocks[input_block_id]->column_vectors[column_id], input_offset, 1); } output_datablock->Finalize(); output_table->Append(output_datablock); } -#else - for (SizeT row_id = 0, block_row_idx = 0; const auto &item : hash_table_.hash_table_) { - // DEFAULT VECTOR SIZE buckets will generate one data block. - if (row_id % DEFAULT_VECTOR_SIZE == 0) { - if (output_datablock.get() != nullptr) { - for (SizeT column_id = 0; column_id < column_count; ++column_id) { - output_datablock->column_vectors[column_id]->tail_index_ = block_row_idx; - } - - output_datablock->Finalize(); - output_table->Append(output_datablock); - block_row_idx = 0; - } - - output_datablock = DataBlock::Make(); - output_datablock->Init(types); - } - - SizeT input_block_id = item.second.begin()->first; - SizeT input_offset = item.second.begin()->second.front(); - - for (SizeT column_id = 0; column_id < column_count; ++column_id) { - switch (types[column_id].type()) { - case LogicalType::kBoolean: { - ((BooleanT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((BooleanT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kTinyInt: { - ((TinyIntT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((TinyIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kSmallInt: { - ((SmallIntT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((SmallIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kInteger: { - ((IntegerT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((IntegerT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kBigInt: { - ((BigIntT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((BigIntT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kHugeInt: { - NotImplementError("HugeInt data shuffle isn't implemented.") - } - case LogicalType::kFloat: { - ((FloatT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((FloatT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kDouble: { - ((DoubleT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx] = - ((DoubleT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; - break; - } - case LogicalType::kDecimal: { - NotImplementError("Decimal data shuffle isn't implemented.") - } - case LogicalType::kVarchar: { - NotImplementError("Varchar data shuffle isn't implemented.") - } - case LogicalType::kDate: { - NotImplementError("Date data shuffle isn't implemented.") - } - 
case LogicalType::kTime: { - NotImplementError("Time data shuffle isn't implemented.") - } - case LogicalType::kDateTime: { - NotImplementError("Datetime data shuffle isn't implemented.") - } - case LogicalType::kTimestamp: { - NotImplementError("Timestamp data shuffle isn't implemented.") - } - case LogicalType::kTimestampTZ: { - NotImplementError("TimestampTZ data shuffle isn't implemented.") - } - case LogicalType::kInterval: { - NotImplementError("Interval data shuffle isn't implemented.") - } - case LogicalType::kMixed: { - NotImplementError("Heterogeneous data shuffle isn't implemented.") - } - default: { - ExecutorError("Unexpected data type") - } - } - } - - ++block_row_idx; - ++row_id; - - if (row_id == row_count) { - // All hash table data are checked - for (SizeT column_id = 0; column_id < column_count; ++column_id) { - output_datablock->column_vectors[column_id]->tail_index_ = block_row_idx; - } - - output_datablock->Finalize(); - output_table->Append(output_datablock); - break; - } - } -#endif } bool PhysicalAggregate::SimpleAggregateExecute(const Vector<UniquePtr<DataBlock>> &input_blocks,
diff --git a/src/executor/operator/physical_aggregate.cppm b/src/executor/operator/physical_aggregate.cppm index 75643104ee..d3c2e7532f 100644 --- a/src/executor/operator/physical_aggregate.cppm +++ b/src/executor/operator/physical_aggregate.cppm @@ -63,13 +63,12 @@ public: return 0; } - void GroupByInputTable(const SharedPtr<DataTable> &input_table, SharedPtr<DataTable> &output_table); + void GroupByInputTable(const Vector<UniquePtr<DataBlock>> &input_blocks, Vector<UniquePtr<DataBlock>> &output_blocks, HashTable &hash_table); - void GenerateGroupByResult(const SharedPtr<DataTable> &input_table, SharedPtr<DataTable> &output_table); + void GenerateGroupByResult(const SharedPtr<DataTable> &input_table, SharedPtr<DataTable> &output_table, HashTable &hash_table); Vector<SharedPtr<BaseExpression>> groups_{}; Vector<SharedPtr<BaseExpression>> aggregates_{}; - HashTable hash_table_; bool SimpleAggregateExecute(const Vector<UniquePtr<DataBlock>> &input_blocks, Vector<UniquePtr<DataBlock>> &output_blocks, @@ -89,7 +88,6 @@ public: Vector<HashRange> GetHashRanges(i64 parallel_count) const; private: - SharedPtr<DataTable> input_table_{}; u64 groupby_index_{}; u64 aggregate_index_{}; };
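To make the regrouping step concrete: HashTable::Append leaves a key -> (block id -> row offsets) map, and GroupByInputTable materializes one output DataBlock per distinct key by copying the listed rows with ColumnVector::AppendWith. A plain-C++ illustration of that reshuffle, with std containers standing in for the engine types:

    #include <cstddef>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
        // key -> list of (input block id, row offset), the shape HashTable::Append builds
        std::map<std::string, std::vector<std::pair<std::size_t, std::size_t>>> hash_table{
            {"blue", {{0, 1}, {1, 0}}},
            {"red", {{0, 0}, {0, 2}}},
        };
        // one output bucket (one DataBlock in the real code) per distinct key
        std::vector<std::vector<std::pair<std::size_t, std::size_t>>> buckets;
        for (const auto &[key, rows] : hash_table) {
            buckets.push_back(rows);
        }
        // buckets[0] holds both "blue" rows, buckets[1] both "red" rows
        return 0;
    }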
diff --git a/src/executor/operator/physical_match.cpp b/src/executor/operator/physical_match.cpp index ea87c6b305..7c37d07cf6 100644 --- a/src/executor/operator/physical_match.cpp +++ b/src/executor/operator/physical_match.cpp @@ -245,7 +245,12 @@ bool PhysicalMatch::ExecuteInner(QueryContext *query_context, OperatorState *ope static_cast<std::chrono::duration<float, std::milli>>(finish_init_query_builder_time - execute_start_time).count())); // 2 build query iterator - FullTextQueryContext full_text_query_context(ft_similarity_, bm25_params_, minimum_should_match_option_, top_n_, match_expr_->index_names_); + FullTextQueryContext full_text_query_context(ft_similarity_, + bm25_params_, + minimum_should_match_option_, + rank_features_option_, + top_n_, + match_expr_->index_names_); full_text_query_context.query_tree_ = MakeUnique<FilterQueryNode>(common_query_filter_.get(), std::move(query_tree_)); const auto query_iterators = CreateQueryIterators(query_builder, full_text_query_context, early_term_algo_, begin_threshold_, score_threshold_); const auto finish_query_builder_time = std::chrono::high_resolution_clock::now(); @@ -329,6 +334,7 @@ PhysicalMatch::PhysicalMatch(const u64 id, const u32 top_n, const SharedPtr<CommonQueryFilter> &common_query_filter, MinimumShouldMatchOption &&minimum_should_match_option, + RankFeaturesOption &&rank_features_option, const f32 score_threshold, const FulltextSimilarity ft_similarity, const BM25Params &bm25_params, @@ -339,7 +345,8 @@ PhysicalMatch::PhysicalMatch(const u64 id, base_table_ref_(std::move(base_table_ref)), match_expr_(std::move(match_expr)), index_reader_(std::move(index_reader)), query_tree_(std::move(query_tree)), begin_threshold_(begin_threshold), early_term_algo_(early_term_algo), top_n_(top_n), common_query_filter_(common_query_filter), minimum_should_match_option_(std::move(minimum_should_match_option)), - score_threshold_(score_threshold), ft_similarity_(ft_similarity), bm25_params_(bm25_params) {} + rank_features_option_(std::move(rank_features_option)), score_threshold_(score_threshold), ft_similarity_(ft_similarity), + bm25_params_(bm25_params) {} PhysicalMatch::~PhysicalMatch() = default;
diff --git a/src/executor/operator/physical_match.cppm b/src/executor/operator/physical_match.cppm index 9db6a0863c..b4c038c223 100644 --- a/src/executor/operator/physical_match.cppm +++ b/src/executor/operator/physical_match.cppm @@ -54,6 +54,7 @@ public: u32 top_n, const SharedPtr<CommonQueryFilter> &common_query_filter, MinimumShouldMatchOption &&minimum_should_match_option, + RankFeaturesOption &&rank_features_option, f32 score_threshold, FulltextSimilarity ft_similarity, const BM25Params &bm25_params, @@ -112,6 +113,8 @@ private: SharedPtr<CommonQueryFilter> common_query_filter_; // for minimum_should_match MinimumShouldMatchOption minimum_should_match_option_{}; + // for rank features + RankFeaturesOption rank_features_option_{}; f32 score_threshold_{}; FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25}; BM25Params bm25_params_;
diff --git a/src/executor/operator/physical_merge_aggregate.cpp b/src/executor/operator/physical_merge_aggregate.cpp index 715b7d835f..6e7c29ea5c 100644 --- a/src/executor/operator/physical_merge_aggregate.cpp +++ b/src/executor/operator/physical_merge_aggregate.cpp @@ -29,6 +29,9 @@ import logical_type; import physical_aggregate; import aggregate_expression; import infinity_exception; +import hash_table; +import column_def; +import column_vector; namespace infinity { @@ -40,8 +43,12 @@ void PhysicalMergeAggregate::Init() {} bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState *operator_state) { auto merge_aggregate_op_state = static_cast<MergeAggregateOperatorState *>(operator_state); - - SimpleMergeAggregateExecute(merge_aggregate_op_state); + auto agg_op = dynamic_cast<PhysicalAggregate *>(this->left()); + if (agg_op->groups_.size() == 0) { + SimpleMergeAggregateExecute(merge_aggregate_op_state); + } else { + GroupByMergeAggregateExecute(merge_aggregate_op_state); + } if (merge_aggregate_op_state->input_complete_) { @@ -57,6 +64,89 @@ bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState return false; } +void PhysicalMergeAggregate::GroupByMergeAggregateExecute(MergeAggregateOperatorState *op_state) { + auto *agg_op = static_cast<PhysicalAggregate *>(this->left()); + SizeT group_count = agg_op->groups_.size(); + MergeHashTable &hash_table = op_state->hash_table_; + + auto &input_block = op_state->input_data_block_; + if (!hash_table.Initialized()) { + Vector<SharedPtr<DataType>> groupby_types; + groupby_types.reserve(group_count); + for (i64 idx = 0; auto &expr : agg_op->groups_) { + SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx, MakeShared<DataType>(expr->Type()), expr->Name(), std::set<ConstraintType>()); + groupby_types.emplace_back(MakeShared<DataType>(expr->Type())); + ++idx; + } + + hash_table.Init(groupby_types); + } + Vector<SharedPtr<ColumnVector>> input_groupby_columns(input_block->column_vectors.begin(), input_block->column_vectors.begin() + group_count); + if (op_state->data_block_array_.empty()) { + hash_table.Append(input_groupby_columns, 0, input_block->row_count()); + op_state->data_block_array_.emplace_back(std::move(input_block)); + 
LOG_TRACE("Physical MergeAggregate execute first block"); + return; + } + + DataBlock *last_data_block = op_state->data_block_array_.back().get(); + for (SizeT row_id = 0; row_id < input_block->row_count(); ++row_id) { + if (last_data_block->available_capacity() == 0) { + Vector> types = last_data_block->types(); + op_state->data_block_array_.emplace_back(DataBlock::MakeUniquePtr()); + last_data_block = op_state->data_block_array_.back().get(); + last_data_block->Init(std::move(types), input_block->capacity()); + } + Pair block_row_id = {op_state->data_block_array_.size() - 1, last_data_block->row_count()}; + bool found = hash_table.GetOrInsert(input_groupby_columns, row_id, block_row_id); + if (!found) { + last_data_block->AppendWith(input_block.get(), row_id, 1); + continue; + } + SizeT agg_count = agg_op->aggregates_.size(); + Pair input_block_row_id = {0, row_id}; + for (SizeT col_idx = group_count; col_idx < group_count + agg_count; ++col_idx) { + auto *agg_expression = static_cast(agg_op->aggregates_[col_idx - group_count].get()); + + auto function_name = agg_expression->aggregate_function_.GetFuncName(); + + auto func_return_type = agg_expression->aggregate_function_.return_type_; + + switch (func_return_type.type()) { + LOG_TRACE("Physical MergeAggregate execute remain block"); + case LogicalType::kTinyInt: { + HandleAggregateFunction(function_name, op_state, col_idx, input_block_row_id, block_row_id); + break; + } + case LogicalType::kSmallInt: { + HandleAggregateFunction(function_name, op_state, col_idx, input_block_row_id, block_row_id); + break; + } + case LogicalType::kInteger: { + HandleAggregateFunction(function_name, op_state, col_idx, input_block_row_id, block_row_id); + break; + } + case LogicalType::kBigInt: { + HandleAggregateFunction(function_name, op_state, col_idx, input_block_row_id, block_row_id); + break; + } + case LogicalType::kFloat: { + HandleAggregateFunction(function_name, op_state, col_idx, input_block_row_id, block_row_id); + break; + } + case LogicalType::kDouble: { + HandleAggregateFunction(function_name, op_state, col_idx, input_block_row_id, block_row_id); + break; + } + default: { + String error_message = "Input value type not Implement"; + UnrecoverableError(error_message); + } + } + } + } +} + void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorState *op_state) { if (op_state->data_block_array_.empty()) { op_state->data_block_array_.emplace_back(std::move(op_state->input_data_block_)); @@ -107,17 +197,21 @@ void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorS } template -void PhysicalMergeAggregate::HandleAggregateFunction(const String &function_name, MergeAggregateOperatorState *op_state, SizeT col_idx) { +void PhysicalMergeAggregate::HandleAggregateFunction(const String &function_name, + MergeAggregateOperatorState *op_state, + SizeT col_idx, + const Pair &input_block_row_id, + const Pair &output_block_row_id) { LOG_TRACE(function_name); if (function_name == "COUNT") { LOG_TRACE("COUNT"); - HandleCount(op_state, col_idx); + HandleCount(op_state, col_idx, input_block_row_id, output_block_row_id); } else if (function_name == "MIN") { - HandleMin(op_state, col_idx); + HandleMin(op_state, col_idx, input_block_row_id, output_block_row_id); } else if (function_name == "MAX") { - HandleMax(op_state, col_idx); + HandleMax(op_state, col_idx, input_block_row_id, output_block_row_id); } else if (function_name == "SUM") { - HandleSum(op_state, col_idx); + HandleSum(op_state, col_idx, input_block_row_id, 
output_block_row_id); } else if (function_name == "COUNT_STAR") { // no action for "COUNT_STAR" } else { @@ -127,27 +221,39 @@ void PhysicalMergeAggregate::HandleAggregateFunction(const String &function_name } template <typename T> -void PhysicalMergeAggregate::HandleMin(MergeAggregateOperatorState *op_state, SizeT col_idx) { +void PhysicalMergeAggregate::HandleMin(MergeAggregateOperatorState *op_state, + SizeT col_idx, + const Pair<SizeT, SizeT> &input_block_row_id, + const Pair<SizeT, SizeT> &output_block_row_id) { MathOperation<T> minOperation = [](T a, T b) -> T { return (a < b) ? a : b; }; - UpdateData<T>(op_state, minOperation, col_idx); + UpdateData<T>(op_state, minOperation, col_idx, input_block_row_id, output_block_row_id); } template <typename T> -void PhysicalMergeAggregate::HandleMax(MergeAggregateOperatorState *op_state, SizeT col_idx) { +void PhysicalMergeAggregate::HandleMax(MergeAggregateOperatorState *op_state, + SizeT col_idx, + const Pair<SizeT, SizeT> &input_block_row_id, + const Pair<SizeT, SizeT> &output_block_row_id) { MathOperation<T> maxOperation = [](T a, T b) -> T { return (a > b) ? a : b; }; - UpdateData<T>(op_state, maxOperation, col_idx); + UpdateData<T>(op_state, maxOperation, col_idx, input_block_row_id, output_block_row_id); } template <typename T> -void PhysicalMergeAggregate::HandleCount(MergeAggregateOperatorState *op_state, SizeT col_idx) { +void PhysicalMergeAggregate::HandleCount(MergeAggregateOperatorState *op_state, + SizeT col_idx, + const Pair<SizeT, SizeT> &input_block_row_id, + const Pair<SizeT, SizeT> &output_block_row_id) { MathOperation<T> countOperation = [](T a, T b) -> T { return a + b; }; - UpdateData<T>(op_state, countOperation, col_idx); + UpdateData<T>(op_state, countOperation, col_idx, input_block_row_id, output_block_row_id); } template <typename T> -void PhysicalMergeAggregate::HandleSum(MergeAggregateOperatorState *op_state, SizeT col_idx) { +void PhysicalMergeAggregate::HandleSum(MergeAggregateOperatorState *op_state, + SizeT col_idx, + const Pair<SizeT, SizeT> &input_block_row_id, + const Pair<SizeT, SizeT> &output_block_row_id) { MathOperation<T> sumOperation = [](T a, T b) -> T { return a + b; }; - UpdateData<T>(op_state, sumOperation, col_idx); + UpdateData<T>(op_state, sumOperation, col_idx, input_block_row_id, output_block_row_id); } template <typename T> @@ -168,11 +274,17 @@ void PhysicalMergeAggregate::WriteValueAtPosition(MergeAggregateOperatorState *o } template <typename T> -void PhysicalMergeAggregate::UpdateData(MergeAggregateOperatorState *op_state, MathOperation<T> operation, SizeT col_idx) { - T input = GetInputData<T>(op_state, 0, col_idx, 0); - T output = GetOutputData<T>(op_state, 0, col_idx, 0); +void PhysicalMergeAggregate::UpdateData(MergeAggregateOperatorState *op_state, + MathOperation<T> operation, + SizeT col_idx, + const Pair<SizeT, SizeT> &input_block_row_id, + const Pair<SizeT, SizeT> &output_block_row_id) { + const auto &[input_block_id, input_row_id] = input_block_row_id; + const auto &[output_block_id, output_row_id] = output_block_row_id; + T input = GetInputData<T>(op_state, input_block_id, col_idx, input_row_id); + T output = GetOutputData<T>(op_state, output_block_id, col_idx, output_row_id); T new_value = operation(input, output); - WriteValueAtPosition<T>(op_state, 0, col_idx, 0, new_value); + WriteValueAtPosition<T>(op_state, output_block_id, col_idx, output_row_id, new_value); } } // namespace infinity
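A worked example of the fold UpdateData performs when a group already exists: the value arriving at (input_block, input_row) is combined with the value stored at (output_block, output_row) and written back to the output slot. MathOperation<T> is assumed here to be a std::function-style alias; its definition is not part of this diff:

    #include <functional>
    #include <iostream>

    template <typename T>
    using MathOperation = std::function<T(T, T)>; // assumption, mirrors the usage above

    int main() {
        MathOperation<long> sum = [](long a, long b) { return a + b; };
        long input = 10;  // partial SUM arriving in the incoming block
        long output = 32; // partial SUM already stored for this group
        std::cout << sum(input, output) << "\n"; // 42, written back via WriteValueAtPosition
        return 0;
    }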
diff --git a/src/executor/operator/physical_merge_aggregate.cppm b/src/executor/operator/physical_merge_aggregate.cppm
index d7378bf9a1..2821764108 100644
--- a/src/executor/operator/physical_merge_aggregate.cppm
+++ b/src/executor/operator/physical_merge_aggregate.cppm
@@ -72,26 +72,48 @@ public:
     void SimpleMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state);
 
+    void GroupByMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state);
+
     template <typename T>
-    void UpdateData(MergeAggregateOperatorState *op_state, MathOperation<T> operation, SizeT col_idx);
+    void UpdateData(MergeAggregateOperatorState *op_state,
+                    MathOperation<T> operation,
+                    SizeT col_idx,
+                    const Pair<SizeT, SizeT> &input_block_row_id,
+                    const Pair<SizeT, SizeT> &output_block_row_id);
 
     template <typename T>
     void WriteValueAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, T value);
 
     template <typename T>
-    void HandleSum(MergeAggregateOperatorState *op_state, SizeT col_idx);
+    void HandleSum(MergeAggregateOperatorState *op_state,
+                   SizeT col_idx,
+                   const Pair<SizeT, SizeT> &input_block_row_id,
+                   const Pair<SizeT, SizeT> &output_block_row_id);
 
     template <typename T>
-    void HandleCount(MergeAggregateOperatorState *op_state, SizeT col_idx);
+    void HandleCount(MergeAggregateOperatorState *op_state,
+                     SizeT col_idx,
+                     const Pair<SizeT, SizeT> &input_block_row_id,
+                     const Pair<SizeT, SizeT> &output_block_row_id);
 
     template <typename T>
-    void HandleMin(MergeAggregateOperatorState *op_state, SizeT col_idx);
+    void HandleMin(MergeAggregateOperatorState *op_state,
+                   SizeT col_idx,
+                   const Pair<SizeT, SizeT> &input_block_row_id,
+                   const Pair<SizeT, SizeT> &output_block_row_id);
 
     template <typename T>
-    void HandleMax(MergeAggregateOperatorState *op_state, SizeT col_idx);
+    void HandleMax(MergeAggregateOperatorState *op_state,
+                   SizeT col_idx,
+                   const Pair<SizeT, SizeT> &input_block_row_id,
+                   const Pair<SizeT, SizeT> &output_block_row_id);
 
     template <typename T>
-    void HandleAggregateFunction(const String &function_name, MergeAggregateOperatorState *op_state, SizeT col_idx);
+    void HandleAggregateFunction(const String &function_name,
+                                 MergeAggregateOperatorState *op_state,
+                                 SizeT col_idx,
+                                 const Pair<SizeT, SizeT> &input_block_row_id = {0, 0},
+                                 const Pair<SizeT, SizeT> &output_block_row_id = {0, 0});
 
     template <typename T>
     Value CreateValue(T value) {
diff --git a/src/executor/operator_state.cppm b/src/executor/operator_state.cppm
index 417052219b..d01311db75 100644
--- a/src/executor/operator_state.cppm
+++ b/src/executor/operator_state.cppm
@@ -37,6 +37,7 @@ import internal_types;
 import column_def;
 import data_type;
 import segment_entry;
+import hash_table;
 
 namespace infinity {
 
@@ -74,6 +75,7 @@ export struct AggregateOperatorState : public OperatorState {
         : OperatorState(PhysicalOperatorType::kAggregate), states_(std::move(states)) {}
 
     Vector<UniquePtr<char[]>> states_;
+    HashTable hash_table_;
 };
 
 // Merge Aggregate
@@ -83,6 +85,7 @@ export struct MergeAggregateOperatorState : public OperatorState {
 
     /// Since merge agg is the first op, no previous operator state. This ptr is to get input data.
     // Vector<UniquePtr<DataBlock>> input_data_blocks_{nullptr};
     UniquePtr<DataBlock> input_data_block_{nullptr};
+    MergeHashTable hash_table_;
     bool input_complete_{false};
 };
 
diff --git a/src/executor/physical_planner.cpp b/src/executor/physical_planner.cpp
index 95098f9025..6f9289f661 100644
--- a/src/executor/physical_planner.cpp
+++ b/src/executor/physical_planner.cpp
@@ -980,6 +980,7 @@ UniquePtr<PhysicalOperator> PhysicalPlanner::BuildMatch(const SharedPtr
                                          logical_match->top_n_,
                                          logical_match->common_query_filter_,
                                          std::move(logical_match->minimum_should_match_option_),
+                                         std::move(logical_match->rank_features_option_),
                                          logical_match->score_threshold_,
                                          logical_match->ft_similarity_,
                                          logical_match->bm25_params_,
diff --git a/src/planner/bound_select_statement.cpp b/src/planner/bound_select_statement.cpp
index 916cb09e7c..44d0923f3e 100644
--- a/src/planner/bound_select_statement.cpp
+++ b/src/planner/bound_select_statement.cpp
@@ -275,6 +275,11 @@ SharedPtr<LogicalNode> BoundSelectStatement::BuildPlan(QueryContext *query_conte
                     match_node->minimum_should_match_option_ = ParseMinimumShouldMatchOption(iter->second);
                 }
 
+                // option: rank_features
+                if (iter = search_ops.options_.find("rank_features"); iter != search_ops.options_.end()) {
+                    match_node->rank_features_option_ = ParseRankFeaturesOption(iter->second);
+                }
+
                 // option: threshold
                 if (iter = search_ops.options_.find("threshold"); iter != search_ops.options_.end()) {
                     match_node->score_threshold_ = DataType::StringToValue<f32>(iter->second);
diff --git a/src/planner/node/logical_match.cppm b/src/planner/node/logical_match.cppm
index 3480b1c692..5f4e974ba9 100644
--- a/src/planner/node/logical_match.cppm
+++ b/src/planner/node/logical_match.cppm
@@ -66,6 +66,7 @@ public:
     SharedPtr<CommonQueryFilter> common_query_filter_{};
     MinimumShouldMatchOption minimum_should_match_option_{};
+    RankFeaturesOption rank_features_option_{};
     f32 score_threshold_{};
     FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25};
     BM25Params bm25_params_;
diff --git a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm
index 72cdcee1b1..ebdbf1e72e 100644
--- a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm
+++ b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm
@@ -91,6 +91,7 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator {
     UniquePtr<QueryNode> query_tree_;
     MinimumShouldMatchOption minimum_should_match_option_;
     u32 minimum_should_match_ = 0;
+    RankFeaturesOption rank_features_option_;
     std::atomic_flag after_optimize_ = {};
     f32 score_threshold_ = {};
     FulltextSimilarity ft_similarity_ = FulltextSimilarity::kBM25;
diff --git a/src/storage/data_block.cpp b/src/storage/data_block.cpp
index ac2d76080e..863f3fe2e1 100644
--- a/src/storage/data_block.cpp
+++ b/src/storage/data_block.cpp
@@ -339,6 +339,7 @@ void DataBlock::AppendWith(const DataBlock *other) {
     for (SizeT idx = 0; idx < column_count; ++idx) {
         this->column_vectors[idx]->AppendWith(*other->column_vectors[idx]);
     }
+    row_count_ += other->row_count_;
 }
 
 void DataBlock::AppendWith(const DataBlock *other, SizeT from, SizeT count) {
@@ -358,6 +359,7 @@ void DataBlock::AppendWith(const DataBlock *other, SizeT from, SizeT count) {
     for (SizeT idx = 0; idx < column_count; ++idx) {
         this->column_vectors[idx]->AppendWith(*other->column_vectors[idx], from, count);
     }
+    row_count_ += count;
 }
 
 void DataBlock::InsertVector(const SharedPtr<ColumnVector> &vector, SizeT index) {
diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp
index 4067953be0..13f9e4c895 100644
--- a/src/storage/data_table.cpp
+++ b/src/storage/data_table.cpp
@@ -156,4 +156,35 @@ const String &DataTable::GetColumnNameById(SizeT idx) const { return definition_
 
 SharedPtr<DataType> DataTable::GetColumnTypeById(SizeT idx) const { return definition_ptr_->columns()[idx]->type(); }
 
+void DataTable::ShrinkBlocks(SizeT block_capacity) {
+    if (data_blocks_.empty()) {
+        return;
+    }
+    auto types = data_blocks_[0]->types();
+    Vector<UniquePtr<DataBlock>> data_blocks = std::move(data_blocks_);
+
+    data_blocks_.emplace_back(DataBlock::MakeUniquePtr());
+    auto *data_block = data_blocks_.back().get();
+    data_block->Init(types, block_capacity);
+    for (SizeT block_i = 0; block_i < data_blocks.size(); ++block_i) {
+        SizeT block_offset = 0;
+        auto *input_block = data_blocks[block_i].get();
+        while (block_offset < input_block->row_count()) {
+            SizeT append_count = std::min(data_block->available_capacity(), input_block->row_count() - block_offset);
+            if (append_count) {
+                data_block->AppendWith(input_block, block_offset, append_count);
+            }
+            block_offset += append_count;
+            if (data_block->available_capacity() == 0) {
+                data_blocks_.emplace_back(DataBlock::MakeUniquePtr());
+                data_block = data_blocks_.back().get();
+                data_block->Init(types, block_capacity);
+            }
+        }
+    }
+    for (auto &data_block : data_blocks_) {
+        data_block->Finalize();
+    }
+}
+
 } // namespace infinity
diff --git a/src/storage/data_table.cppm b/src/storage/data_table.cppm
index 352660a8c5..98c03368b3 100644
--- a/src/storage/data_table.cppm
+++ b/src/storage/data_table.cppm
@@ -26,6 +26,7 @@ import internal_types;
 import third_party;
 import column_def;
 import logger;
+import default_values;
 
 namespace infinity {
 
@@ -97,6 +98,8 @@ public:
     // Currently this method is used in aggregate operator.
     void UnionWith(const SharedPtr<DataTable> &other);
 
+    void ShrinkBlocks(SizeT block_capacity = DEFAULT_VECTOR_SIZE);
+
 public:
     SharedPtr<TableDef> definition_ptr_{};
     SizeT row_count_{0};
diff --git a/src/storage/invertedindex/search/doc_iterator.cppm b/src/storage/invertedindex/search/doc_iterator.cppm
index 017fea1a0a..bb9d3b5093 100644
--- a/src/storage/invertedindex/search/doc_iterator.cppm
+++ b/src/storage/invertedindex/search/doc_iterator.cppm
@@ -44,6 +44,8 @@ export enum class DocIteratorType : u8 {
     kScoreThresholdIterator,
     kKeywordIterator,
     kMustFirstIterator,
+    kRankFeatureDocIterator,
+    kRankFeaturesDocIterator,
 };
 
 export struct DocIteratorEstimateIterateCost {
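`ShrinkBlocks` repacks a list of partially filled blocks front-to-back into fresh fixed-capacity blocks. A standalone sketch of the same compaction loop, with a toy `Block` standing in for `DataBlock` (names and the capacity are illustrative only):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct Block {
    std::vector<int> rows;
    size_t capacity = 8192;
    size_t AvailableCapacity() const { return capacity - rows.size(); }
    void Append(const Block &src, size_t from, size_t count) {
        rows.insert(rows.end(), src.rows.begin() + from, src.rows.begin() + from + count);
    }
};

void ShrinkBlocks(std::vector<Block> &blocks, size_t block_capacity) {
    if (blocks.empty()) return;
    std::vector<Block> old = std::move(blocks); // set the old list aside
    blocks.clear();
    blocks.push_back(Block{{}, block_capacity});
    for (const Block &input : old) {
        size_t offset = 0;
        while (offset < input.rows.size()) {
            // Copy as many rows as fit into the current tail block.
            size_t n = std::min(blocks.back().AvailableCapacity(), input.rows.size() - offset);
            blocks.back().Append(input, offset, n);
            offset += n;
            if (blocks.back().AvailableCapacity() == 0) {
                blocks.push_back(Block{{}, block_capacity}); // start a new full-size block
            }
        }
    }
}
```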
diff --git a/src/storage/invertedindex/search/parse_fulltext_options.cpp b/src/storage/invertedindex/search/parse_fulltext_options.cpp
index 424bca5b98..e6fb128060 100644
--- a/src/storage/invertedindex/search/parse_fulltext_options.cpp
+++ b/src/storage/invertedindex/search/parse_fulltext_options.cpp
@@ -14,7 +14,9 @@
 
 module;
 
+#include
 #include
+
 module parse_fulltext_options;
 
 import stl;
@@ -127,4 +129,60 @@ u32 GetMinimumShouldMatchParameter(const MinimumShouldMatchOption &option_vec, c
                        match_option);
 }
 
+void Split(const std::string_view &input, const String &split_pattern, Vector<String> &result, bool keep_delim = false) {
+    re2::RE2 pattern(split_pattern);
+    re2::StringPiece leftover(input.data(), input.size());
+    re2::StringPiece last_end = leftover;
+    re2::StringPiece extracted_delim_token;
+
+    while (RE2::FindAndConsume(&leftover, pattern, &extracted_delim_token)) {
+        std::string_view token(last_end.data(), extracted_delim_token.data() - last_end.data());
+        if (!token.empty()) {
+            result.push_back(String(token.data(), token.size()));
+        }
+        if (keep_delim)
+            result.push_back(String(extracted_delim_token.data(), extracted_delim_token.size()));
+        last_end = leftover;
+    }
+
+    if (!leftover.empty()) {
+        result.push_back(String(leftover.data(), leftover.size()));
+    }
+}
+
+void ParseRankFeatureOption(std::string_view input_str, RankFeatureOption &feature_option) {
+    Vector<String> feature_strs;
+    Split(input_str, "\\^", feature_strs);
+    if (feature_strs.size() == 2) {
+        feature_option.field_ = feature_strs[0];
+        feature_option.feature_ = feature_strs[1];
+        feature_option.boost_ = 1.0f;
+    } else if (feature_strs.size() == 3) {
+        feature_option.field_ = feature_strs[0];
+        feature_option.feature_ = feature_strs[1];
+        const auto boost_str = feature_strs[2];
+        try {
+            feature_option.boost_ = std::stof(boost_str);
+        } catch (const std::exception &e) {
+            RecoverableError(
+                Status::SyntaxError(std::format("Invalid rank_features parameter format: Failed to parse float value in option '{}'.", input_str)));
+        }
+    } else {
+        RecoverableError(
+            Status::SyntaxError(std::format("Invalid rank_features parameter format: Expect 2 or 3 parts separated by '^', but get: '{}'.", input_str)));
+    }
+}
+
+RankFeaturesOption ParseRankFeaturesOption(std::string_view input_str) {
+    RankFeaturesOption result;
+    Vector<String> feature_strs;
+    Split(input_str, ",", feature_strs);
+    for (auto &feature_str : feature_strs) {
+        RankFeatureOption feature_option;
+        ParseRankFeatureOption(feature_str, feature_option);
+        result.push_back(feature_option);
+    }
+    return result;
+}
+
 } // namespace infinity
diff --git a/src/storage/invertedindex/search/parse_fulltext_options.cppm b/src/storage/invertedindex/search/parse_fulltext_options.cppm
index 5cd4aebc88..5c39cface0 100644
--- a/src/storage/invertedindex/search/parse_fulltext_options.cppm
+++ b/src/storage/invertedindex/search/parse_fulltext_options.cppm
@@ -49,4 +49,14 @@ export struct BM25Params {
     float delta_phrase = 0.0F;
 };
 
+export struct RankFeatureOption {
+    String field_;
+    String feature_;
+    float boost_;
+};
+
+export using RankFeaturesOption = Vector<RankFeatureOption>;
+
+export RankFeaturesOption ParseRankFeaturesOption(std::string_view input_str);
+
 } // namespace infinity
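The option grammar accepted above is a comma-separated list of `field^feature` or `field^feature^boost` entries. A small standalone illustration of the same grammar using plain `std::string` splitting instead of RE2 (types and error handling here are stand-ins, not Infinity's):

```cpp
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

struct RankFeatureOption {
    std::string field, feature;
    float boost = 1.0f; // default when no boost part is given
};

std::vector<RankFeatureOption> ParseRankFeatures(const std::string &input) {
    std::vector<RankFeatureOption> result;
    std::istringstream list(input);
    std::string entry;
    while (std::getline(list, entry, ',')) {
        RankFeatureOption opt;
        auto p1 = entry.find('^');
        if (p1 == std::string::npos) throw std::invalid_argument("expect field^feature[^boost]: " + entry);
        auto p2 = entry.find('^', p1 + 1);
        opt.field = entry.substr(0, p1);
        if (p2 == std::string::npos) {
            opt.feature = entry.substr(p1 + 1);
        } else {
            opt.feature = entry.substr(p1 + 1, p2 - p1 - 1);
            opt.boost = std::stof(entry.substr(p2 + 1)); // throws on a malformed float
        }
        result.push_back(opt);
    }
    return result;
}

int main() {
    for (const auto &o : ParseRankFeatures("tags^science^2.5,tags^news"))
        std::cout << o.field << " / " << o.feature << " * " << o.boost << '\n';
}
```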
diff --git a/src/storage/invertedindex/search/query_builder.cpp b/src/storage/invertedindex/search/query_builder.cpp
index 30676904d7..3b50bff8a0 100644
--- a/src/storage/invertedindex/search/query_builder.cpp
+++ b/src/storage/invertedindex/search/query_builder.cpp
@@ -16,6 +16,9 @@ module;
 
 #include
 #include
+
+#include "query_node.h"
+
 module query_builder;
 
 import stl;
@@ -72,6 +75,17 @@ UniquePtr<DocIterator> QueryBuilder::CreateSearch(FullTextQueryContext &context)
         LOG_DEBUG(std::move(oss).str());
     }
 #endif
+    if (!context.rank_features_option_.empty()) {
+        auto rank_features_node = std::make_unique<RankFeaturesQueryNode>();
+        for (auto rank_feature : context.rank_features_option_) {
+            auto rank_feature_node = std::make_unique<RankFeatureQueryNode>();
+            rank_feature_node->term_ = rank_feature.feature_;
+            rank_feature_node->column_ = rank_feature.field_;
+            rank_feature_node->boost_ = rank_feature.boost_;
+            rank_features_node->Add(std::move(rank_feature_node));
+        }
+        // auto rank_features_iter = rank_features_node->CreateSearch(params);
+    }
     return result;
 }
diff --git a/src/storage/invertedindex/search/query_builder.cppm b/src/storage/invertedindex/search/query_builder.cppm
index 24560b8da0..b43b2bfbc1 100644
--- a/src/storage/invertedindex/search/query_builder.cppm
+++ b/src/storage/invertedindex/search/query_builder.cppm
@@ -36,6 +36,7 @@ export struct FullTextQueryContext {
     const FulltextSimilarity ft_similarity_{};
     const BM25Params bm25_params_{};
     const MinimumShouldMatchOption minimum_should_match_option_{};
+    const RankFeaturesOption rank_features_option_{};
     u32 minimum_should_match_ = 0;
     u32 topn_ = 0;
     EarlyTermAlgo early_term_algo_ = EarlyTermAlgo::kNaive;
@@ -44,10 +45,11 @@ export struct FullTextQueryContext {
     FullTextQueryContext(const FulltextSimilarity ft_similarity,
                          const BM25Params &bm25_params,
                          const MinimumShouldMatchOption &minimum_should_match_option,
+                         const RankFeaturesOption &rank_features_option,
                          const u32 topn,
                          const Vector<String> &index_names)
-        : ft_similarity_(ft_similarity), bm25_params_(bm25_params), minimum_should_match_option_(minimum_should_match_option), topn_(topn),
-          index_names_(index_names) {}
+        : ft_similarity_(ft_similarity), bm25_params_(bm25_params), minimum_should_match_option_(minimum_should_match_option),
+          rank_features_option_(rank_features_option), topn_(topn), index_names_(index_names) {}
 };
 
 export class QueryBuilder {
diff --git a/src/storage/invertedindex/search/query_node.cpp b/src/storage/invertedindex/search/query_node.cpp
index ace49c00c2..ee6fdced94 100644
--- a/src/storage/invertedindex/search/query_node.cpp
+++ b/src/storage/invertedindex/search/query_node.cpp
@@ -27,6 +27,8 @@ import keyword_iterator;
 import must_first_iterator;
 import batch_or_iterator;
 import blockmax_leaf_iterator;
+import rank_feature_doc_iterator;
+import rank_features_doc_iterator;
 
 namespace infinity {
 
@@ -390,6 +392,11 @@ std::unique_ptr<QueryNode> OrQueryNode::InnerGetNewOptimizedQueryTree() {
     }
 }
 
+std::unique_ptr<QueryNode> RankFeaturesQueryNode::InnerGetNewOptimizedQueryTree() {
+    UnrecoverableError("OptimizeInPlaceInner: Unexpected case! RankFeaturesQueryNode should not exist in parser output");
+    return nullptr;
+}
+
 // 4. deal with "and_not":
 // "and_not" does not exist in parser output, it is generated during optimization
 
@@ -429,6 +436,25 @@ std::unique_ptr<DocIterator> TermQueryNode::CreateSearch(const CreateSearchParam
     return search;
 }
 
+std::unique_ptr<DocIterator> RankFeatureQueryNode::CreateSearch(const CreateSearchParams params, bool) const {
+    ColumnID column_id = params.table_entry->GetColumnIdByName(column_);
+    ColumnIndexReader *column_index_reader = params.index_reader->GetColumnIndexReader(column_id, params.index_names_);
+    if (!column_index_reader) {
+        RecoverableError(Status::SyntaxError(fmt::format(R"(Invalid query statement: Column "{}" has no fulltext index)", column_)));
+        return nullptr;
+    }
+
+    bool fetch_position = false;
+    auto posting_iterator = column_index_reader->Lookup(term_, fetch_position);
+    if (!posting_iterator) {
+        return nullptr;
+    }
+    auto search = MakeUnique<RankFeatureDocIterator>(std::move(posting_iterator), column_id, boost_);
+    search->term_ptr_ = &term_;
+    search->column_name_ptr_ = &column_;
+    return search;
+}
+
 std::unique_ptr<DocIterator> PhraseQueryNode::CreateSearch(const CreateSearchParams params, bool) const {
     ColumnID column_id = params.table_entry->GetColumnIdByName(column_);
     ColumnIndexReader *column_index_reader = params.index_reader->GetColumnIndexReader(column_id, params.index_names_);
@@ -685,6 +711,25 @@ std::unique_ptr<DocIterator> OrQueryNode::CreateSearch(const CreateSearchParams
     }
 }
 
+std::unique_ptr<DocIterator> RankFeaturesQueryNode::CreateSearch(const CreateSearchParams params, const bool is_top_level) const {
+    Vector<std::unique_ptr<DocIterator>> sub_doc_iters;
+    sub_doc_iters.reserve(children_.size());
+    const auto next_params = params.RemoveMSM();
+    for (auto &child : children_) {
+        auto iter = child->CreateSearch(next_params, false);
+        if (!iter) {
+            // no need to continue if any child is invalid
+            return nullptr;
+        }
+        sub_doc_iters.emplace_back(std::move(iter));
+    }
+    if (sub_doc_iters.empty()) {
+        return nullptr;
+    } else {
+        return MakeUnique<RankFeaturesDocIterator>(std::move(sub_doc_iters));
+    }
+}
+
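At search time, the rank-features node behaves like a disjunction over its children, and a document's contribution is the boosted sum of the per-child stored weights. A standalone sketch of that combination (in-memory postings and `FeatureIter` are stand-ins for `PostingIterator`/`RankFeatureDocIterator`):

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <map>
#include <vector>

constexpr uint32_t INVALID_DOC = std::numeric_limits<uint32_t>::max();

struct FeatureIter {
    std::map<uint32_t, uint16_t> postings; // doc -> stored feature weight (payload)
    float boost = 1.0f;
    uint32_t Seek(uint32_t doc) const {    // first doc >= doc, or INVALID_DOC
        auto it = postings.lower_bound(doc);
        return it == postings.end() ? INVALID_DOC : it->first;
    }
    float Score(uint32_t doc) const { return postings.at(doc) * boost; }
};

// Disjunctive step: the next matching doc is the minimum of the children's seeks.
uint32_t NextMatch(const std::vector<FeatureIter> &children, uint32_t from, float *score) {
    uint32_t doc = INVALID_DOC;
    for (const auto &c : children) doc = std::min(doc, c.Seek(from));
    if (doc == INVALID_DOC) return doc;
    *score = 0.0f;
    for (const auto &c : children)
        if (c.Seek(doc) == doc) *score += c.Score(doc); // boosted sum over matches
    return doc;
}

int main() {
    std::vector<FeatureIter> children = {{{{1, 10}, {5, 3}}, 2.0f}, {{{5, 7}}, 1.0f}};
    float score = 0.0f;
    for (uint32_t d = 0; (d = NextMatch(children, d, &score)) != INVALID_DOC; ++d)
        std::cout << "doc " << d << " score " << score << '\n'; // doc 1: 20, doc 5: 13
}
```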
 std::unique_ptr<DocIterator> KeywordQueryNode::CreateSearch(const CreateSearchParams params, bool) const {
     Vector<std::unique_ptr<DocIterator>> sub_doc_iters;
     sub_doc_iters.reserve(children_.size());
@@ -753,6 +798,21 @@ void TermQueryNode::GetQueryColumnsTerms(std::vector<std::string> &columns, std:
     terms.push_back(term_);
 }
 
+void RankFeatureQueryNode::PrintTree(std::ostream &os, const std::string &prefix, const bool is_final) const {
+    os << prefix;
+    os << (is_final ? "└──" : "├──");
+    os << QueryNodeTypeToString(type_);
+    os << " (weight: " << weight_ << ")";
+    os << " (column: " << column_ << ")";
+    os << " (term: " << term_ << ")";
+    os << '\n';
+}
+
+void RankFeatureQueryNode::GetQueryColumnsTerms(std::vector<std::string> &columns, std::vector<std::string> &terms) const {
+    columns.push_back(column_);
+    terms.push_back(term_);
+}
+
 void PhraseQueryNode::PrintTree(std::ostream &os, const std::string &prefix, const bool is_final) const {
     os << prefix;
     os << (is_final ? "└──" : "├──");
diff --git a/src/storage/invertedindex/search/query_node.h b/src/storage/invertedindex/search/query_node.h
index 212df86128..6ab1af394c 100644
--- a/src/storage/invertedindex/search/query_node.h
+++ b/src/storage/invertedindex/search/query_node.h
@@ -125,6 +125,20 @@ struct TermQueryNode : public QueryNode {
     void GetQueryColumnsTerms(std::vector<std::string> &columns, std::vector<std::string> &terms) const override;
 };
 
+struct RankFeatureQueryNode : public QueryNode {
+    std::string term_;
+    std::string column_;
+    float boost_;
+
+    RankFeatureQueryNode() : QueryNode(QueryNodeType::TERM) {}
+
+    uint32_t LeafCount() const override { return 1; }
+    void PushDownWeight(float factor) override { MultiplyWeight(factor); }
+    std::unique_ptr<DocIterator> CreateSearch(CreateSearchParams params, bool is_top_level) const override;
+    void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const override;
+    void GetQueryColumnsTerms(std::vector<std::string> &columns, std::vector<std::string> &terms) const override;
+};
+
 struct PhraseQueryNode final : public QueryNode {
     std::vector<std::string> terms_;
     std::string column_;
@@ -191,6 +205,12 @@ struct OrQueryNode final : public MultiQueryNode {
     std::unique_ptr<DocIterator> CreateSearch(CreateSearchParams params, bool is_top_level) const override;
 };
 
+struct RankFeaturesQueryNode final : public MultiQueryNode {
+    RankFeaturesQueryNode() : MultiQueryNode(QueryNodeType::OR) {}
+    std::unique_ptr<QueryNode> InnerGetNewOptimizedQueryTree() override;
+    std::unique_ptr<DocIterator> CreateSearch(CreateSearchParams params, bool is_top_level) const override;
+};
+
 struct KeywordQueryNode final : public MultiQueryNode {
     KeywordQueryNode() : MultiQueryNode(QueryNodeType::KEYWORD) {}
     void PushDownWeight(float factor) override { MultiplyWeight(factor); }
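The iterators introduced next read a per-document `u16` payload as the feature weight; the `rankfeatures` analyzer registered earlier presumably produces those payloads at index time. The sketch below is a hedged guess at that step, not code from the analyzer (which is not shown in this diff): split input such as `"science^5 news^2"` into terms and carry the integer weight in the payload slot.

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct WeightedTerm {
    std::string term;
    uint16_t payload = 1; // assumed default weight when none is given
};

std::vector<WeightedTerm> AnalyzeRankFeatures(const std::string &input) {
    std::vector<WeightedTerm> out;
    std::istringstream ss(input);
    std::string tok;
    while (ss >> tok) {
        WeightedTerm wt;
        if (auto pos = tok.find('^'); pos != std::string::npos) {
            wt.term = tok.substr(0, pos);
            wt.payload = static_cast<uint16_t>(std::stoul(tok.substr(pos + 1))); // throws on bad input
        } else {
            wt.term = tok;
        }
        out.push_back(wt);
    }
    return out;
}

int main() {
    for (const auto &[term, payload] : AnalyzeRankFeatures("science^5 news^2 misc"))
        std::cout << term << " -> " << payload << '\n';
}
```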
diff --git a/src/storage/invertedindex/search/rank_feature_doc_iterator.cpp b/src/storage/invertedindex/search/rank_feature_doc_iterator.cpp
new file mode 100644
index 0000000000..9bdc4d2d4b
--- /dev/null
+++ b/src/storage/invertedindex/search/rank_feature_doc_iterator.cpp
@@ -0,0 +1,68 @@
+// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+module;
+
+#include <cassert>
+#include <ostream>
+
+module rank_feature_doc_iterator;
+
+import stl;
+import logger;
+
+namespace infinity {
+
+RankFeatureDocIterator::RankFeatureDocIterator(UniquePtr<PostingIterator> &&iter, const u64 column_id, float boost)
+    : column_id_(column_id), boost_(boost), iter_(std::move(iter)) {}
+
+RankFeatureDocIterator::~RankFeatureDocIterator() {}
+
+bool RankFeatureDocIterator::Next(RowID doc_id) {
+    assert(doc_id != INVALID_ROWID);
+    if (doc_id_ != INVALID_ROWID && doc_id_ >= doc_id)
+        return true;
+    doc_id_ = iter_->SeekDoc(doc_id);
+    if (doc_id_ == INVALID_ROWID)
+        return false;
+    return true;
+}
+
+float RankFeatureDocIterator::Score() {
+    u16 payload = iter_->GetCurrentDocPayload();
+    float weight = static_cast<float>(payload);
+    return weight * boost_;
+}
+
+void RankFeatureDocIterator::PrintTree(std::ostream &os, const String &prefix, bool is_final) const {
+    os << prefix;
+    os << (is_final ? "└──" : "├──");
+    os << "RankFeatureDocIterator";
+    os << " (column: " << *column_name_ptr_ << ")";
+    os << " (term: " << *term_ptr_ << ")";
+    os << '\n';
+}
+
+void RankFeatureDocIterator::BatchDecodeTo(const RowID buffer_start_doc_id, const RowID buffer_end_doc_id, u16 *payload_ptr) {
+    auto iter_doc_id = iter_->DocID();
+    assert((buffer_start_doc_id <= iter_doc_id && iter_doc_id < buffer_end_doc_id));
+    while (iter_doc_id < buffer_end_doc_id) {
+        const auto pos = iter_doc_id - buffer_start_doc_id;
+        const auto payload = iter_->GetCurrentDocPayload();
+        payload_ptr[pos] = payload;
+        iter_doc_id = iter_->SeekDoc(iter_doc_id + 1);
+    }
+    doc_id_ = iter_doc_id;
+}
+
+} // namespace infinity
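`BatchDecodeTo` scatters each matching document's payload into a fixed window indexed by `doc - start`, leaving the posting iterator parked on the first document at or beyond the window end. A standalone sketch of the same pattern (the `Posting` type is a stand-in for `PostingIterator`):

```cpp
#include <cstdint>
#include <cstddef>
#include <utility>
#include <vector>

constexpr uint32_t INVALID_DOC = UINT32_MAX;

struct Posting { // stand-in for PostingIterator: sorted (doc, payload) pairs
    std::vector<std::pair<uint32_t, uint16_t>> docs;
    size_t cursor = 0;
    uint32_t DocID() const { return cursor < docs.size() ? docs[cursor].first : INVALID_DOC; }
    uint16_t Payload() const { return docs[cursor].second; }
    uint32_t SeekDoc(uint32_t target) { // advance to first doc >= target
        while (cursor < docs.size() && docs[cursor].first < target) ++cursor;
        return DocID();
    }
};

void BatchDecodeTo(Posting &p, uint32_t start, uint32_t end, uint16_t *out) {
    // Walk docs inside [start, end) and write each payload to its window slot.
    for (uint32_t doc = p.DocID(); doc < end; doc = p.SeekDoc(doc + 1)) {
        out[doc - start] = p.Payload();
    }
    // On exit the posting is positioned at the first doc >= end (or exhausted).
}
```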
diff --git a/src/storage/invertedindex/search/rank_feature_doc_iterator.cppm b/src/storage/invertedindex/search/rank_feature_doc_iterator.cppm
new file mode 100644
index 0000000000..4305fbecad
--- /dev/null
+++ b/src/storage/invertedindex/search/rank_feature_doc_iterator.cppm
@@ -0,0 +1,61 @@
+// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+module;
+
+export module rank_feature_doc_iterator;
+
+import stl;
+
+import posting_iterator;
+import index_defines;
+import doc_iterator;
+import internal_types;
+import doc_iterator;
+import third_party;
+
+namespace infinity {
+
+export class RankFeatureDocIterator final : public DocIterator {
+public:
+    RankFeatureDocIterator(UniquePtr<PostingIterator> &&iter, u64 column_id, float boost);
+
+    ~RankFeatureDocIterator() override;
+
+    DocIteratorType GetType() const override { return DocIteratorType::kRankFeatureDocIterator; }
+
+    String Name() const override { return "RankFeatureDocIterator"; }
+
+    void UpdateScoreThreshold(float threshold) override {}
+
+    u32 MatchCount() const override { return 0; }
+
+    bool Next(RowID doc_id) override;
+
+    float Score() override;
+
+    void PrintTree(std::ostream &os, const String &prefix, bool is_final) const override;
+
+    void BatchDecodeTo(RowID buffer_start_doc_id, RowID buffer_end_doc_id, u16 *payload_ptr);
+
+    const String *term_ptr_ = nullptr;
+    const String *column_name_ptr_ = nullptr;
+
+private:
+    u64 column_id_;
+    float boost_ = 1.0f;
+    UniquePtr<PostingIterator> iter_;
+};
+
+} // namespace infinity
diff --git a/src/storage/invertedindex/search/rank_features_doc_iterator.cpp b/src/storage/invertedindex/search/rank_features_doc_iterator.cpp
new file mode 100644
index 0000000000..d7c365c9a5
--- /dev/null
+++ b/src/storage/invertedindex/search/rank_features_doc_iterator.cpp
@@ -0,0 +1,127 @@
+// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+module;
+
+#include <algorithm>
+#include <cassert>
+#include <cstdlib>
+#include <cstring>
+
+module rank_features_doc_iterator;
+
+import stl;
+import third_party;
+import index_defines;
+import rank_feature_doc_iterator;
+import multi_doc_iterator;
+import internal_types;
+import logger;
+import infinity_exception;
+import simd_functions;
+
+namespace infinity {
+
+constexpr u32 BATCH_OR_LEN = 128;
+
+RankFeaturesDocIterator::RankFeaturesDocIterator(Vector<UniquePtr<DocIterator>> &&iterators) : MultiDocIterator(std::move(iterators)) {
+    estimate_iterate_cost_ = {};
+    const SizeT num_iterators = children_.size();
+    for (SizeT i = 0; i < num_iterators; i++) {
+        auto it = dynamic_cast<RankFeatureDocIterator *>(children_[i].get());
+        if (it == nullptr) {
+            UnrecoverableError("RankFeaturesDocIterator only supports RankFeatureDocIterator");
+        }
+    }
+    static_assert(sizeof(f32) == sizeof(u32));
+    memset_bytes_ = sizeof(u32) * BATCH_OR_LEN * 2u * num_iterators;
+    const auto alloc_bytes = sizeof(u32) * BATCH_OR_LEN * (2u + 2u * num_iterators);
+    aligned_buffer_ = std::aligned_alloc(64, alloc_bytes);
+    if (!aligned_buffer_) {
+        UnrecoverableError(fmt::format("{}: Out of memory!", __func__));
+    }
+    match_cnt_ptr_ = static_cast<u32 *>(aligned_buffer_);
+    payload_ptr_ = reinterpret_cast<u16 *>(match_cnt_ptr_ + BATCH_OR_LEN * num_iterators);
+}
+
+RankFeaturesDocIterator::~RankFeaturesDocIterator() { std::free(aligned_buffer_); }
+
+bool RankFeaturesDocIterator::Next(RowID doc_id) {
+    if (doc_id_ != INVALID_ROWID) [[likely]] {
+        if (doc_id_ >= doc_id) {
+            return true;
+        }
+        // now buffer_start_doc_id_ <= doc_id_ < doc_id
+        if (u32 pos = doc_id - buffer_start_doc_id_; pos < BATCH_OR_LEN) {
+            for (; pos < BATCH_OR_LEN; ++pos) {
+                if (match_cnt_ptr_[pos]) {
+                    doc_id_ = buffer_start_doc_id_ + pos;
+                    return true;
+                }
+            }
+            // now need to search from buffer_start_doc_id_ + BATCH_OR_LEN
+            doc_id = buffer_start_doc_id_ + BATCH_OR_LEN;
+        }
+    } else {
+        for (const auto &child : children_) {
+            child->Next(doc_id);
+        }
+    }
+    RowID next_buffer_start_doc_id = INVALID_ROWID;
+    for (const auto &child : children_) {
+        if (child->DocID() != INVALID_ROWID) {
+            child->Next(doc_id);
+            const RowID child_doc_id = child->DocID();
+            next_buffer_start_doc_id = std::min(next_buffer_start_doc_id, child_doc_id);
+        }
+    }
+    if (next_buffer_start_doc_id != INVALID_ROWID) [[likely]] {
+        DecodeFrom(next_buffer_start_doc_id);
+    }
+    doc_id_ = next_buffer_start_doc_id;
+    return doc_id_ != INVALID_ROWID;
+}
+
+u32 RankFeaturesDocIterator::MatchCount() const { return match_cnt_ptr_[doc_id_ - buffer_start_doc_id_]; }
+
+void RankFeaturesDocIterator::DecodeFrom(const RowID buffer_start_doc_id) {
+    buffer_start_doc_id_ = buffer_start_doc_id;
+    std::memset(aligned_buffer_, 0, memset_bytes_);
+    const auto buffer_end_doc_id = buffer_start_doc_id + BATCH_OR_LEN;
+    for (u32 i = 0; i < children_.size(); ++i) {
+        const auto child = children_[i].get();
+        if (const auto child_doc_id = child->DocID(); child_doc_id != INVALID_ROWID) {
+            assert(child_doc_id >= buffer_start_doc_id);
+            if (child_doc_id >= buffer_end_doc_id) {
+                // no need to decode
+                continue;
+            }
+            const auto it = dynamic_cast<RankFeatureDocIterator *>(child);
+            it->BatchDecodeTo(buffer_start_doc_id, buffer_end_doc_id, payload_ptr_ + i * BATCH_OR_LEN);
+        }
+    }
+    for (u32 i = 0; i < BATCH_OR_LEN; ++i) {
+        u32 match_cnt = 0;
+        for (u32 j = 0; j < children_.size(); ++j) {
+            const auto payload = payload_ptr_[j * BATCH_OR_LEN + i];
+            if (payload == 0) {
+                continue;
+            }
+            ++match_cnt;
+        }
+        match_cnt_ptr_[i] = match_cnt;
+    }
+}
+
+} // namespace infinity
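`DecodeFrom` merges all children over one 128-document window: each child scatters its payloads into its own row of the buffer, and a document "matches" if any child stored a non-zero payload for it. A short SIMD-free sketch of that counting step, with plain arrays standing in for the aligned buffer:

```cpp
#include <array>
#include <cstdint>
#include <vector>

constexpr uint32_t WINDOW = 128; // mirrors BATCH_OR_LEN above

// payloads: one row of WINDOW slots per child, zero-filled before decoding.
std::array<uint32_t, WINDOW> CountMatches(const std::vector<std::array<uint16_t, WINDOW>> &payloads) {
    std::array<uint32_t, WINDOW> match_cnt{};
    for (const auto &child_row : payloads)
        for (uint32_t pos = 0; pos < WINDOW; ++pos)
            match_cnt[pos] += child_row[pos] != 0; // non-zero payload counts as a match
    return match_cnt;
}
```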
diff --git a/src/storage/invertedindex/search/rank_features_doc_iterator.cppm b/src/storage/invertedindex/search/rank_features_doc_iterator.cppm
new file mode 100644
index 0000000000..aea46502fa
--- /dev/null
+++ b/src/storage/invertedindex/search/rank_features_doc_iterator.cppm
@@ -0,0 +1,55 @@
+// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+module;
+
+export module rank_features_doc_iterator;
+
+import stl;
+import index_defines;
+import doc_iterator;
+import multi_doc_iterator;
+import internal_types;
+
+namespace infinity {
+
+export class RankFeaturesDocIterator : public MultiDocIterator {
+public:
+    explicit RankFeaturesDocIterator(Vector<UniquePtr<DocIterator>> &&iterators);
+
+    ~RankFeaturesDocIterator() override;
+
+    DocIteratorType GetType() const override { return DocIteratorType::kRankFeaturesDocIterator; }
+
+    String Name() const override { return "RankFeaturesDocIterator"; }
+
+    void UpdateScoreThreshold(float threshold) override {}
+
+    bool Next(RowID doc_id) override;
+
+    float Score() override { return 1.0f; }
+
+    u32 MatchCount() const override;
+
+    void DecodeFrom(RowID buffer_start_doc_id);
+
+private:
+    RowID buffer_start_doc_id_ = INVALID_ROWID;
+    u32 memset_bytes_ = 0;
+    void *aligned_buffer_ = nullptr;
+    u32 *match_cnt_ptr_ = nullptr;
+    u16 *payload_ptr_ = nullptr;
+};
+
+} // namespace infinity
diff --git a/src/unit_test/storage/invertedindex/search/query_builder.cpp b/src/unit_test/storage/invertedindex/search/query_builder.cpp
index fff878a20c..91cb1b3db7 100644
--- a/src/unit_test/storage/invertedindex/search/query_builder.cpp
+++ b/src/unit_test/storage/invertedindex/search/query_builder.cpp
@@ -201,7 +201,7 @@ TEST_F(QueryBuilderTest, test_and) {
     LOG_INFO(oss.str());
     // apply query builder
     Vector<String> hints;
-    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints);
+    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints);
     context.early_term_algo_ = EarlyTermAlgo::kNaive;
     context.query_tree_ = std::move(and_root);
     FakeQueryBuilder fake_query_builder;
@@ -273,7 +273,7 @@ TEST_F(QueryBuilderTest, test_or) {
     LOG_INFO(oss.str());
     // apply query builder
     Vector<String> hints;
-    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints);
+    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints);
     context.early_term_algo_ = EarlyTermAlgo::kNaive;
     context.query_tree_ = std::move(or_root);
     FakeQueryBuilder fake_query_builder;
@@ -351,7 +351,7 @@ TEST_F(QueryBuilderTest, test_and_not) {
     LOG_INFO(oss.str());
     // apply query builder
     Vector<String> hints;
-    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints);
+    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints);
     context.early_term_algo_ = EarlyTermAlgo::kNaive;
     context.query_tree_ = std::move(and_not_root);
     FakeQueryBuilder fake_query_builder;
@@ -435,7 +435,7 @@ TEST_F(QueryBuilderTest, test_and_not2) {
     LOG_INFO(oss.str());
     // apply query builder
     Vector<String> hints;
-    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints);
+    FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints);
     context.early_term_algo_ = EarlyTermAlgo::kNaive;
     context.query_tree_ = std::move(and_not_root);
     FakeQueryBuilder fake_query_builder;
diff --git a/src/unit_test/storage/invertedindex/search/query_match.cpp b/src/unit_test/storage/invertedindex/search/query_match.cpp
index e27fe1a1dd..e65d76ebca 100644
--- a/src/unit_test/storage/invertedindex/search/query_match.cpp
+++ b/src/unit_test/storage/invertedindex/search/query_match.cpp
@@ -339,7 +339,12 @@ void QueryMatchTest::QueryMatch(const String &db_name,
         Status status = Status::ParseMatchExprFailed(match_expr->fields_, match_expr->matching_text_);
         RecoverableError(status);
     }
-    FullTextQueryContext full_text_query_context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, index_hints);
+    FullTextQueryContext full_text_query_context(FulltextSimilarity::kBM25,
+                                                 BM25Params{},
+                                                 MinimumShouldMatchOption{},
+                                                 RankFeaturesOption{},
+                                                 10,
+                                                 index_hints);
     full_text_query_context.early_term_algo_ = EarlyTermAlgo::kNaive;
     full_text_query_context.query_tree_ = std::move(query_tree);
     UniquePtr<DocIterator> doc_iterator = query_builder.CreateSearch(full_text_query_context);
diff --git a/test/sql/dql/aggregate/test_groupby_aggtype.slt b/test/sql/dql/aggregate/test_groupby_aggtype.slt
new file mode 100644
index 0000000000..1d8a540848
--- /dev/null
+++ b/test/sql/dql/aggregate/test_groupby_aggtype.slt
@@ -0,0 +1,52 @@
+statement ok
+DROP TABLE IF EXISTS simple_groupby;
+
+statement ok
+CREATE TABLE simple_groupby (c1 INTEGER, c2 FLOAT);
+
+statement ok
+INSERT INTO simple_groupby VALUES
+(1,1.0),
+(2,2.0),
+(1,3.0),
+(2,4.0),
+(1,5.0);
+
+query IR rowsort
+SELECT c1, SUM(c2) FROM simple_groupby GROUP BY c1;
+----
+1 9.000000
+2 6.000000
+
+query IR rowsort
+SELECT c1, AVG(c2) FROM simple_groupby GROUP BY c1;
+----
+1 3.000000
+2 3.000000
+
+query IR rowsort
+SELECT c1, MIN(c2) FROM simple_groupby GROUP BY c1;
+----
+1 1.000000
+2 2.000000
+
+query IR rowsort
+SELECT c1, MAX(c2) FROM simple_groupby GROUP BY c1;
+----
+1 5.000000
+2 4.000000
+
+query II rowsort
+SELECT c1, COUNT(c2) FROM simple_groupby GROUP BY c1;
+----
+1 3
+2 2
+
+query IF rowsort
+SELECT c1, AVG(c2) FROM simple_groupby GROUP BY c1;
+----
+1 3.000000
+2 3.000000
+
+statement ok
+DROP TABLE simple_groupby;
diff --git a/test/sql/dql/aggregate/test_groupby_complex.slt b/test/sql/dql/aggregate/test_groupby_complex.slt
new file mode 100644
index 0000000000..1c18613090
--- /dev/null
+++ b/test/sql/dql/aggregate/test_groupby_complex.slt
@@ -0,0 +1,86 @@
+statement ok
+DROP TABLE IF EXISTS simple_groupby;
+
+statement ok
+CREATE TABLE simple_groupby (c1 INTEGER, c2 INTEGER, c3 FLOAT);
+
+statement ok
+INSERT INTO simple_groupby VALUES
+(1,1,1.0),
+(2,2,2.0),
+(1,3,3.0),
+(2,1,4.0),
+(1,2,5.0),
+(2,3,6.0),
+(1,1,7.0),
+(2,2,8.0),
+(1,3,1.0),
+(2,1,2.0),
+(1,2,3.0),
+(2,3,4.0);
+
+query IIR rowsort
+SELECT c1, c2, SUM(c3) FROM simple_groupby GROUP BY c1, c2;
+----
+1 1 8.000000
+1 2 8.000000
+1 3 4.000000
+2 1 6.000000
+2 2 10.000000
+2 3 10.000000
+
+query IRI rowsort
+SELECT c1, c3, SUM(c2) FROM simple_groupby GROUP BY c1, c3;
+----
+1 1.000000 4
+1 3.000000 5
+1 5.000000 2
+1 7.000000 1
+2 2.000000 3
+2 4.000000 4
+2 6.000000 3
+2 8.000000 2
+
+query RII rowsort
+SELECT c3, SUM(c1), SUM(c2) FROM simple_groupby GROUP BY c3;
+----
+1.000000 2 4
+2.000000 4 3
+3.000000 2 5
+4.000000 4 4
+5.000000 1 2
+6.000000 2 3
+7.000000 1 1
+8.000000 2 2
+
+query RIII rowsort
+SELECT c3, COUNT(c3), SUM(c1), SUM(c2) FROM simple_groupby GROUP BY c3;
+----
+1.000000 2 2 4
+2.000000 2 4 3
+3.000000 2 2 5
+4.000000 2 4 4
+5.000000 1 1 2
+6.000000 1 2 3
+7.000000 1 1 1
+8.000000 1 2 2
+
+query IIR rowsort
+SELECT c1, c2, SUM(c3) FROM simple_groupby WHERE c1 > 1 GROUP BY c1, c2;
+----
+2 1 6.000000
+2 2 10.000000
+2 3 10.000000
+
+statement ok
+DELETE FROM simple_groupby WHERE c1 <= 1;
+
+query IIR rowsort
+SELECT c1, c2, SUM(c3) FROM simple_groupby GROUP BY c1, c2;
+----
+2 1 6.000000
+2 2 10.000000
+2 3 10.000000
+
+statement ok
+DROP TABLE simple_groupby;
diff --git a/test/sql/dql/aggregate/test_groupby_datatype.slt b/test/sql/dql/aggregate/test_groupby_datatype.slt
new file mode 100644
index 0000000000..3db42bc389
--- /dev/null
+++ b/test/sql/dql/aggregate/test_groupby_datatype.slt
@@ -0,0 +1,106 @@
+statement ok
+DROP TABLE IF EXISTS simple_groupby;
+
+statement ok
+CREATE TABLE simple_groupby (c1 INTEGER, c2 FLOAT, c3 VARCHAR);
+
+statement ok
+INSERT INTO simple_groupby VALUES
+(1,1.0,'abc'),
+(2,2.0,'abcdef'),
+(3,3.0,'abcdefghi'),
+(1,4.0,'abcdefghijkl'),
+(2,5.0,'abcdefghijklmno'),
+(3,6.0,'abcdefghijklmnopqr'),
+(1,1.0,'abcdefghijklmnopqrstu'),
+(2,2.0,'abcdefghijklmnopqrstuvwx'),
+(3,3.0,'abcdefghijklmnopqrstuvwxyz'),
+(1,4.0,'abc'),
+(2,5.0,'abcdef'),
+(3,6.0,'abcdefghi'),
+(1,1.0,'abcdefghijkl'),
+(2,2.0,'abcdefghijklmno'),
+(3,3.0,'abcdefghijklmnopqr'),
+(1,4.0,'abcdefghijklmnopqrstu'),
+(2,5.0,'abcdefghijklmnopqrstuvwx'),
+(3,6.0,'abcdefghijklmnopqrstuvwxyz');
+
+query TIR rowsort
+SELECT c3, SUM(c1), SUM(c2) FROM simple_groupby GROUP BY c3;
+----
+abc 2 5.000000
+abcdef 4 7.000000
+abcdefghi 6 9.000000
+abcdefghijkl 2 5.000000
+abcdefghijklmno 4 7.000000
+abcdefghijklmnopqr 6 9.000000
+abcdefghijklmnopqrstu 2 5.000000
+abcdefghijklmnopqrstuvwx 4 7.000000
+abcdefghijklmnopqrstuvwxyz 6 9.000000
+
+query TII rowsort
+SELECT c3, CHAR_LENGTH(c3), SUM(c1) FROM simple_groupby GROUP BY c3;
+----
+abc 3 2
+abcdef 6 4
+abcdefghi 9 6
+abcdefghijkl 12 2
+abcdefghijklmno 15 4
+abcdefghijklmnopqr 18 6
+abcdefghijklmnopqrstu 21 2
+abcdefghijklmnopqrstuvwx 24 4
+abcdefghijklmnopqrstuvwxyz 26 6
+
+statement ok
+DROP TABLE simple_groupby;
+
+statement ok
+CREATE TABLE simple_groupby (c1 INTEGER, d DATE, dt DATETIME, t TIME, ts TIMESTAMP);
+
+statement ok
+INSERT INTO simple_groupby VALUES
+(1, DATE '1970-01-01', DATETIME '1970-01-01 00:00:00', TIME '00:00:00', TIMESTAMP '1970-01-01 00:00:00'),
+(2, DATE '1970-01-01', DATETIME '1970-01-01 00:00:00', TIME '11:59:59', TIMESTAMP '1970-01-01 11:59:59'),
+(3, DATE '1970-01-01', DATETIME '1970-01-01 00:00:00', TIME '12:00:00', TIMESTAMP '1970-01-01 12:00:00'),
+(4, DATE '1970-01-01', DATETIME '1970-01-01 00:00:00', TIME '23:59:59', TIMESTAMP '1970-01-01 23:59:59'),
+(5, DATE '1970-01-02', DATETIME '1970-01-02 00:00:00', TIME '00:00:00', TIMESTAMP '1970-01-01 00:00:00'),
+(6, DATE '1970-01-02', DATETIME '1970-01-02 00:00:00', TIME '11:59:59', TIMESTAMP '1970-01-01 11:59:59'),
+(7, DATE '1970-01-02', DATETIME '1970-01-02 00:00:00', TIME '12:00:00', TIMESTAMP '1970-01-01 12:00:00'),
+(8, DATE '1970-01-02', DATETIME '1970-01-02 00:00:00', TIME '23:59:59', TIMESTAMP '1970-01-01 23:59:59'),
+(9, DATE '1970-01-03', DATETIME '1970-01-03 00:00:00', TIME '00:00:00', TIMESTAMP '1970-01-01 00:00:00'),
+(10, DATE '1970-01-03', DATETIME '1970-01-03 00:00:00', TIME '11:59:59', TIMESTAMP '1970-01-01 11:59:59'),
+(11, DATE '1970-01-03', DATETIME '1970-01-03 00:00:00', TIME '12:00:00', TIMESTAMP '1970-01-01 12:00:00'),
+(12, DATE '1970-01-03', DATETIME '1970-01-03 00:00:00', TIME '23:59:59', TIMESTAMP '1970-01-01 23:59:59');
+
+query TI rowsort
+SELECT d, SUM(c1) FROM simple_groupby GROUP BY d;
+----
+1970-01-01 10
+1970-01-02 26
+1970-01-03 42
+
+query TI rowsort
+SELECT t, SUM(c1) FROM simple_groupby GROUP BY t;
+----
+00:00:00 15
+11:59:59 18
+12:00:00 21
+23:59:59 24
+
+query TI rowsort
+SELECT dt, SUM(c1) FROM simple_groupby GROUP BY dt;
+----
+1970-01-01 00:00:00 10
+1970-01-02 00:00:00 26
+1970-01-03 00:00:00 42
+
+query TI rowsort
+SELECT ts, SUM(c1) FROM simple_groupby GROUP BY ts;
+----
+1970-01-01 00:00:00 15
+1970-01-01 11:59:59 18
+1970-01-01 12:00:00 21
+1970-01-01 23:59:59 24
+
+statement ok
+DROP TABLE simple_groupby;
\ No newline at end of file
diff --git a/tools/generate_groupby1.py b/tools/generate_groupby1.py
new file mode 100644
index 0000000000..736f19d22e
--- /dev/null
+++ b/tools/generate_groupby1.py
@@ -0,0 +1,114 @@
+import argparse
+import os
+import csv
+import random
+from collections import defaultdict
+
+
+def generate(generate_if_exists: bool, copy_dir: str):
+    data_dir = "./test/data/csv"
+    slt_dir = "./test/sql/dql/aggregate"
+
+    table_name = "test_big_groupby"
+    data_path = data_dir + "/test_big_groupby.csv"
+    data_path2 = data_dir + "/test_big_groupby2.csv"
+    slt_path = slt_dir + "/test_big_groupby.slt"
+    copy_path = copy_dir + "/test_big_groupby.csv"
+    copy_path2 = copy_dir + "/test_big_groupby2.csv"
+
+    os.makedirs(data_dir, exist_ok=True)
+    os.makedirs(slt_dir, exist_ok=True)
+    if (
+        os.path.exists(data_path)
+        and os.path.exists(slt_path)
+        and not generate_if_exists
+    ):
+        print(
+            "File {} and {} already exist. Skip generating.".format(
+                slt_path, data_path
+            )
+        )
+        return
Skip Generating.".format( + slt_path, data_path + ) + ) + return + + row_n = 9000 + group_n1 = 150 + group_n2 = 100 + groupby_c1 = defaultdict(list) + groupby_c2 = defaultdict(list) + for d_path in [data_path, data_path2]: + with open(d_path, "w") as data_file: + writer = csv.writer(data_file) + for i in range(row_n): + c1 = random.randint(0, group_n1 - 1) + c2 = random.randint(0, group_n2 - 1) + writer.writerow([c1, c2]) + groupby_c1[c1].append(c2) + groupby_c2[c2].append(c1) + + with open(slt_path, "w") as slt_file: + slt_file.write("statement ok\n") + slt_file.write("DROP TABLE IF EXISTS {};\n".format(table_name)) + slt_file.write("\n") + + slt_file.write("statement ok\n") + slt_file.write("CREATE TABLE {} (c1 int, c2 int);\n".format(table_name)) + slt_file.write("\n") + + for c_path in [copy_path, copy_path2]: + slt_file.write("statement ok\n") + slt_file.write( + "COPY {} FROM '{}' WITH ( DELIMITER ',', FORMAT CSV );\n".format( + table_name, c_path + ) + ) + slt_file.write("\n") + + slt_file.write("query III rowsort\n") + slt_file.write( + "SELECT c2, COUNT(*), SUM(c1) FROM {} GROUP BY c2;\n".format(table_name) + ) + slt_file.write("----\n") + select_res = [] + for c2, c1_list in groupby_c2.items(): + select_res.append(f"{c2} {len(c1_list)} {sum(c1_list)}\n") + select_res.sort() + for res in select_res: + slt_file.write(res) + slt_file.write("\n") + + slt_file.write("query III rowsort\n") + slt_file.write( + "SELECT c1, COUNT(*), SUM(c2) FROM {} GROUP BY c1;\n".format(table_name) + ) + slt_file.write("----\n") + select_res = [] + for c1, c2_list in groupby_c1.items(): + select_res.append(f"{c1} {len(c2_list)} {sum(c2_list)}\n") + select_res.sort() + for res in select_res: + slt_file.write(res) + slt_file.write("\n") + + slt_file.write("statement ok\n") + slt_file.write("DROP TABLE IF EXISTS {};\n".format(table_name)) + slt_file.write("\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate groupby data for test") + parser.add_argument( + "-g", + "--generate", + type=bool, + default=False, + dest="generate_if_exists", + ) + parser.add_argument( + "-c", + "--copy", + type=str, + default="/var/infinity/test_data", + dest="copy_dir", + ) + args = parser.parse_args() + generate(args.generate_if_exists, args.copy_dir) diff --git a/tools/sqllogictest.py b/tools/sqllogictest.py index 61e919bc52..2d0dd36fcd 100644 --- a/tools/sqllogictest.py +++ b/tools/sqllogictest.py @@ -37,7 +37,7 @@ from generate_tensor_array_parquet import generate as generate26 from generate_multivector_parquet import generate as generate27 from generate_multivector_knn_scan import generate as generate28 - +from generate_groupby1 import generate as generate29 class SpinnerThread(threading.Thread): def __init__(self): @@ -192,6 +192,7 @@ def copy_all(data_dir, copy_dir): generate26(args.generate_if_exists, args.copy) generate27(args.generate_if_exists, args.copy) generate28(args.generate_if_exists, args.copy) + generate29(args.generate_if_exists, args.copy) print("Generate file finshed.")