diff --git a/src/common/analyzer/analyzer.cppm b/src/common/analyzer/analyzer.cppm index 2166fafc5e..039130d6aa 100644 --- a/src/common/analyzer/analyzer.cppm +++ b/src/common/analyzer/analyzer.cppm @@ -49,11 +49,23 @@ public: } protected: - typedef void (*HookType)(void *data, const char *text, const u32 len, const u32 offset, const u32 end_offset, const bool is_special_char); + typedef void (*HookType)(void *data, + const char *text, + const u32 len, + const u32 offset, + const u32 end_offset, + const bool is_special_char, + const u16 payload); virtual int AnalyzeImpl(const Term &input, void *data, HookType func) { return -1; } - static void AppendTermList(void *data, const char *text, const u32 len, const u32 offset, const u32 end_offset, const bool is_special_char) { + static void AppendTermList(void *data, + const char *text, + const u32 len, + const u32 offset, + const u32 end_offset, + const bool is_special_char, + const u16 payload) { void **parameters = (void **)data; TermList *output = (TermList *)parameters[0]; Analyzer *analyzer = (Analyzer *)parameters[1]; @@ -62,9 +74,9 @@ protected: return; if (is_special_char && analyzer->convert_to_placeholder_) { if (output->empty() == true || output->back().text_.compare(PLACE_HOLDER) != 0) - output->Add(PLACE_HOLDER.c_str(), PLACE_HOLDER.length(), offset, end_offset); + output->Add(PLACE_HOLDER.c_str(), PLACE_HOLDER.length(), offset, end_offset, payload); } else { - output->Add(text, len, offset, end_offset); + output->Add(text, len, offset, end_offset, payload); } } diff --git a/src/common/analyzer/analyzer_pool.cpp b/src/common/analyzer/analyzer_pool.cpp index 1dee2c8039..2ada50741a 100644 --- a/src/common/analyzer/analyzer_pool.cpp +++ b/src/common/analyzer/analyzer_pool.cpp @@ -34,6 +34,7 @@ import ngram_analyzer; import rag_analyzer; import whitespace_analyzer; import ik_analyzer; +import rank_features_analyzer; import logger; namespace infinity { @@ -330,6 +331,9 @@ Tuple, Status> AnalyzerPool::GetAnalyzer(const std::string_v } return {MakeUnique(name.substr(suffix_pos + 1)), Status::OK()}; } + case Str2Int(RANKFEATURES.data()): { + return {MakeUnique(), Status::OK()}; + } default: { if (std::filesystem::is_regular_file(name)) { // Suppose it is a customized Python script analyzer diff --git a/src/common/analyzer/analyzer_pool.cppm b/src/common/analyzer/analyzer_pool.cppm index 804783cfe7..34c309e77d 100644 --- a/src/common/analyzer/analyzer_pool.cppm +++ b/src/common/analyzer/analyzer_pool.cppm @@ -45,6 +45,7 @@ public: static constexpr std::string_view IK = "ik"; static constexpr std::string_view KEYWORD = "keyword"; static constexpr std::string_view WHITESPACE = "whitespace"; + static constexpr std::string_view RANKFEATURES = "rankfeatures"; private: CacheType cache_{}; diff --git a/src/common/analyzer/common_analyzer.cpp b/src/common/analyzer/common_analyzer.cpp index 8ffcbf6261..1ed5721d69 100644 --- a/src/common/analyzer/common_analyzer.cpp +++ b/src/common/analyzer/common_analyzer.cpp @@ -52,12 +52,12 @@ int CommonLanguageAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType if (is_index_) { if (IsSpecialChar()) { - func(data, token_, len_, offset_, end_offset_, true); + func(data, token_, len_, offset_, end_offset_, true, 0); temp_offset = offset_; continue; } if (is_raw_) { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; continue; } @@ -77,37 +77,37 @@ int CommonLanguageAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType bool lowercase_is_different = memcmp(token_, lowercase_term, len_) != 0; if (stemming_term_str_size && stem_only_) { - func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false); + func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false, 0); temp_offset = offset_; } else if (stemming_term_str_size || (case_sensitive_ && contain_lower_ && lowercase_is_different)) { /// have more than one output if (case_sensitive_) { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } else { - func(data, lowercase_term, len_, offset_, end_offset_, false); + func(data, lowercase_term, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } if (stemming_term_str_size) { - func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false); + func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false, 0); temp_offset = offset_; } if (case_sensitive_ && contain_lower_ && lowercase_is_different) { - func(data, lowercase_term, len_, offset_, end_offset_, false); + func(data, lowercase_term, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } } else { /// have only one output if (case_sensitive_) { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } else { - func(data, lowercase_term, len_, offset_, end_offset_, false); + func(data, lowercase_term, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } } } else { - func(data, token_, len_, offset_, end_offset_, false); + func(data, token_, len_, offset_, end_offset_, false, 0); temp_offset = offset_; } } diff --git a/src/common/analyzer/ik/ik_analyzer.cpp b/src/common/analyzer/ik/ik_analyzer.cpp index 564e1ffb93..e58bc703a2 100644 --- a/src/common/analyzer/ik/ik_analyzer.cpp +++ b/src/common/analyzer/ik/ik_analyzer.cpp @@ -90,7 +90,7 @@ int IKAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { while ((lexeme = context_->GetNextLexeme()) != nullptr) { std::wstring text = lexeme->GetLexemeText(); String token = CharacterUtil::UTF16ToUTF8(text); - func(data, token.c_str(), token.size(), offset++, 0, false); + func(data, token.c_str(), token.size(), offset++, 0, false, 0); delete lexeme; }; return 0; diff --git a/src/common/analyzer/ngram_analyzer.cpp b/src/common/analyzer/ngram_analyzer.cpp index caf1a39d94..94e3a2673f 100644 --- a/src/common/analyzer/ngram_analyzer.cpp +++ b/src/common/analyzer/ngram_analyzer.cpp @@ -59,7 +59,7 @@ int NGramAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { while (cur < len && NextInString(input.text_.c_str(), len, &cur, &token_start, &token_length)) { if (token_length == 0) continue; - func(data, input.text_.c_str() + token_start, token_length, offset, offset + token_length, false); + func(data, input.text_.c_str() + token_start, token_length, offset, offset + token_length, false, 0); offset++; } diff --git a/src/common/analyzer/rag_analyzer.cpp b/src/common/analyzer/rag_analyzer.cpp index 7aa384ae81..4e900e603c 100644 --- a/src/common/analyzer/rag_analyzer.cpp +++ b/src/common/analyzer/rag_analyzer.cpp @@ -1438,7 +1438,7 @@ int RAGAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { Split(output, blank_pattern_, tokens); unsigned offset = 0; for (auto &t : tokens) { - func(data, t.c_str(), t.size(), offset++, 0, false); + func(data, t.c_str(), t.size(), offset++, 0, false, 0); } return 0; } diff --git a/src/common/analyzer/rank_features_analyzer.cpp b/src/common/analyzer/rank_features_analyzer.cpp new file mode 100644 index 0000000000..89b1ceb8e7 --- /dev/null +++ b/src/common/analyzer/rank_features_analyzer.cpp @@ -0,0 +1,52 @@ +// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +#include +module rank_features_analyzer; +import stl; +import term; +import analyzer; +import third_party; + +namespace infinity { + +u16 FloatToU16(float value) { + if (value < 0.0f) + value = 0.0f; + if (value > 65535.0f) + value = 65535.0f; + return static_cast(value); +} + +int RankFeaturesAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { + nlohmann::json line_json = nlohmann::json::parse(input.text_); + u32 offset = 0; + for (auto iter = line_json.begin(); iter != line_json.end(); ++iter) { + String key = iter.key(); + String value = iter.value(); + float target = 0; + try { + target = std::stof(value); + } catch (const std::exception &e) { + } + u16 weight = FloatToU16(target); + func(data, key.data(), key.size(), offset++, 0, false, weight); + } + + return 0; +} + +} // namespace infinity diff --git a/src/common/analyzer/rank_features_analyzer.cppm b/src/common/analyzer/rank_features_analyzer.cppm new file mode 100644 index 0000000000..79237ea0f5 --- /dev/null +++ b/src/common/analyzer/rank_features_analyzer.cppm @@ -0,0 +1,35 @@ +// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +export module rank_features_analyzer; +import stl; +import term; +import analyzer; + +namespace infinity { + +export class RankFeaturesAnalyzer : public Analyzer { + String delimiters_{}; + +public: + RankFeaturesAnalyzer() = default; + ~RankFeaturesAnalyzer() override = default; + +protected: + int AnalyzeImpl(const Term &input, void *data, HookType func) override; +}; + +} // namespace infinity diff --git a/src/common/analyzer/term.cppm b/src/common/analyzer/term.cppm index 3543f96ea2..ea556ce7c6 100644 --- a/src/common/analyzer/term.cppm +++ b/src/common/analyzer/term.cppm @@ -41,11 +41,12 @@ public: export class TermList : public Deque { public: - void Add(const char *text, const u32 len, const u32 offset, const u32 end_offset) { + void Add(const char *text, const u32 len, const u32 offset, const u32 end_offset, const u16 payload = 0) { push_back(global_temporary_); back().text_.assign(text, len); back().word_offset_ = offset; back().end_offset_ = end_offset; + back().payload_ = payload; } void Add(cppjieba::Word &cut_word) { @@ -54,18 +55,20 @@ public: back().word_offset_ = cut_word.offset; } - void Add(const String &token, const u32 offset, const u32 end_offset) { + void Add(const String &token, const u32 offset, const u32 end_offset, const u16 payload = 0) { push_back(global_temporary_); back().text_ = token; back().word_offset_ = offset; back().end_offset_ = end_offset; + back().payload_ = payload; } - void Add(String &token, const u32 offset, const u32 end_offset) { + void Add(String &token, const u32 offset, const u32 end_offset, const u16 payload = 0) { push_back(global_temporary_); std::swap(back().text_, token); back().word_offset_ = offset; back().end_offset_ = end_offset; + back().payload_ = payload; } private: diff --git a/src/common/analyzer/whitespace_analyzer.cpp b/src/common/analyzer/whitespace_analyzer.cpp index 0bd02d5134..83a558b6a8 100644 --- a/src/common/analyzer/whitespace_analyzer.cpp +++ b/src/common/analyzer/whitespace_analyzer.cpp @@ -37,7 +37,7 @@ int WhitespaceAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func std::string t; u32 offset = 0; while (is >> t) { - func(data, t.data(), t.size(), offset++, 0, false); + func(data, t.data(), t.size(), offset++, 0, false, 0); } return 0; } else { @@ -49,11 +49,11 @@ int WhitespaceAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func while (search_start < input_text.size()) { const auto found = input_text.find_first_of(delimiters, search_start); if (found == std::string_view::npos) { - func(data, input_text.data() + search_start, input_text.size() - search_start, offset++, 0, false); + func(data, input_text.data() + search_start, input_text.size() - search_start, offset++, 0, false, 0); break; } if (found > search_start) { - func(data, input_text.data() + search_start, found - search_start, offset++, 0, false); + func(data, input_text.data() + search_start, found - search_start, offset++, 0, false, 0); } search_start = found + 1; } diff --git a/src/executor/operator/physical_match.cpp b/src/executor/operator/physical_match.cpp index ea87c6b305..7c37d07cf6 100644 --- a/src/executor/operator/physical_match.cpp +++ b/src/executor/operator/physical_match.cpp @@ -245,7 +245,12 @@ bool PhysicalMatch::ExecuteInner(QueryContext *query_context, OperatorState *ope static_cast(finish_init_query_builder_time - execute_start_time).count())); // 2 build query iterator - FullTextQueryContext full_text_query_context(ft_similarity_, bm25_params_, minimum_should_match_option_, top_n_, match_expr_->index_names_); + FullTextQueryContext full_text_query_context(ft_similarity_, + bm25_params_, + minimum_should_match_option_, + rank_features_option_, + top_n_, + match_expr_->index_names_); full_text_query_context.query_tree_ = MakeUnique(common_query_filter_.get(), std::move(query_tree_)); const auto query_iterators = CreateQueryIterators(query_builder, full_text_query_context, early_term_algo_, begin_threshold_, score_threshold_); const auto finish_query_builder_time = std::chrono::high_resolution_clock::now(); @@ -329,6 +334,7 @@ PhysicalMatch::PhysicalMatch(const u64 id, const u32 top_n, const SharedPtr &common_query_filter, MinimumShouldMatchOption &&minimum_should_match_option, + RankFeaturesOption &&rank_features_option, const f32 score_threshold, const FulltextSimilarity ft_similarity, const BM25Params &bm25_params, @@ -339,7 +345,8 @@ PhysicalMatch::PhysicalMatch(const u64 id, base_table_ref_(std::move(base_table_ref)), match_expr_(std::move(match_expr)), index_reader_(std::move(index_reader)), query_tree_(std::move(query_tree)), begin_threshold_(begin_threshold), early_term_algo_(early_term_algo), top_n_(top_n), common_query_filter_(common_query_filter), minimum_should_match_option_(std::move(minimum_should_match_option)), - score_threshold_(score_threshold), ft_similarity_(ft_similarity), bm25_params_(bm25_params) {} + rank_features_option_(std::move(rank_features_option)), score_threshold_(score_threshold), ft_similarity_(ft_similarity), + bm25_params_(bm25_params) {} PhysicalMatch::~PhysicalMatch() = default; diff --git a/src/executor/operator/physical_match.cppm b/src/executor/operator/physical_match.cppm index 9db6a0863c..b4c038c223 100644 --- a/src/executor/operator/physical_match.cppm +++ b/src/executor/operator/physical_match.cppm @@ -54,6 +54,7 @@ public: u32 top_n, const SharedPtr &common_query_filter, MinimumShouldMatchOption &&minimum_should_match_option, + RankFeaturesOption &&rank_features_option, f32 score_threshold, FulltextSimilarity ft_similarity, const BM25Params &bm25_params, @@ -112,6 +113,8 @@ private: SharedPtr common_query_filter_; // for minimum_should_match MinimumShouldMatchOption minimum_should_match_option_{}; + // for rank features + RankFeaturesOption rank_features_option_{}; f32 score_threshold_{}; FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25}; BM25Params bm25_params_; diff --git a/src/executor/physical_planner.cpp b/src/executor/physical_planner.cpp index 95098f9025..6f9289f661 100644 --- a/src/executor/physical_planner.cpp +++ b/src/executor/physical_planner.cpp @@ -980,6 +980,7 @@ UniquePtr PhysicalPlanner::BuildMatch(const SharedPtrtop_n_, logical_match->common_query_filter_, std::move(logical_match->minimum_should_match_option_), + std::move(logical_match->rank_features_option_), logical_match->score_threshold_, logical_match->ft_similarity_, logical_match->bm25_params_, diff --git a/src/planner/bound_select_statement.cpp b/src/planner/bound_select_statement.cpp index 916cb09e7c..44d0923f3e 100644 --- a/src/planner/bound_select_statement.cpp +++ b/src/planner/bound_select_statement.cpp @@ -275,6 +275,11 @@ SharedPtr BoundSelectStatement::BuildPlan(QueryContext *query_conte match_node->minimum_should_match_option_ = ParseMinimumShouldMatchOption(iter->second); } + // option: rank_features + if (iter = search_ops.options_.find("rank_features"); iter != search_ops.options_.end()) { + match_node->rank_features_option_ = ParseRankFeaturesOption(iter->second); + } + // option: threshold if (iter = search_ops.options_.find("threshold"); iter != search_ops.options_.end()) { match_node->score_threshold_ = DataType::StringToValue(iter->second); diff --git a/src/planner/node/logical_match.cppm b/src/planner/node/logical_match.cppm index 3480b1c692..5f4e974ba9 100644 --- a/src/planner/node/logical_match.cppm +++ b/src/planner/node/logical_match.cppm @@ -66,6 +66,7 @@ public: SharedPtr common_query_filter_{}; MinimumShouldMatchOption minimum_should_match_option_{}; + RankFeaturesOption rank_features_option_{}; f32 score_threshold_{}; FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25}; BM25Params bm25_params_; diff --git a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm index 72cdcee1b1..ebdbf1e72e 100644 --- a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm +++ b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm @@ -91,6 +91,7 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator { UniquePtr query_tree_; MinimumShouldMatchOption minimum_should_match_option_; u32 minimum_should_match_ = 0; + RankFeaturesOption rank_features_option_; std::atomic_flag after_optimize_ = {}; f32 score_threshold_ = {}; FulltextSimilarity ft_similarity_ = FulltextSimilarity::kBM25; diff --git a/src/storage/invertedindex/search/doc_iterator.cppm b/src/storage/invertedindex/search/doc_iterator.cppm index 017fea1a0a..bb9d3b5093 100644 --- a/src/storage/invertedindex/search/doc_iterator.cppm +++ b/src/storage/invertedindex/search/doc_iterator.cppm @@ -44,6 +44,8 @@ export enum class DocIteratorType : u8 { kScoreThresholdIterator, kKeywordIterator, kMustFirstIterator, + kRankFeatureDocIterator, + kRankFeaturesDocIterator, }; export struct DocIteratorEstimateIterateCost { diff --git a/src/storage/invertedindex/search/parse_fulltext_options.cpp b/src/storage/invertedindex/search/parse_fulltext_options.cpp index 424bca5b98..e6fb128060 100644 --- a/src/storage/invertedindex/search/parse_fulltext_options.cpp +++ b/src/storage/invertedindex/search/parse_fulltext_options.cpp @@ -14,7 +14,9 @@ module; +#include #include + module parse_fulltext_options; import stl; @@ -127,4 +129,60 @@ u32 GetMinimumShouldMatchParameter(const MinimumShouldMatchOption &option_vec, c match_option); } +void Split(const std::string_view &input, const String &split_pattern, Vector &result, bool keep_delim = false) { + re2::RE2 pattern(split_pattern); + re2::StringPiece leftover(input.data()); + re2::StringPiece last_end = leftover; + re2::StringPiece extracted_delim_token; + + while (RE2::FindAndConsume(&leftover, pattern, &extracted_delim_token)) { + std::string_view token(last_end.data(), extracted_delim_token.data() - last_end.data()); + if (!token.empty()) { + result.push_back(String(token.data(), token.size())); + } + if (keep_delim) + result.push_back(String(extracted_delim_token.data(), extracted_delim_token.size())); + last_end = leftover; + } + + if (!leftover.empty()) { + result.push_back(String(leftover.data(), leftover.size())); + } +} + +void ParseRankFeatureOption(std::string_view input_str, RankFeatureOption &feature_option) { + Vector feature_strs; + Split(input_str, "^", feature_strs); + if (feature_strs.size() == 2) { + feature_option.field_ = feature_strs[0]; + feature_option.feature_ = feature_strs[1]; + feature_option.boost_ = 1.0f; + } else if (feature_strs.size() == 3) { + feature_option.field_ = feature_strs[0]; + feature_option.feature_ = feature_strs[1]; + const auto boost_str = feature_strs[2]; + try { + feature_option.boost_ = std::stof(boost_str); + } catch (const std::exception &e) { + RecoverableError( + Status::SyntaxError(std::format("Invalid rank_features parameter format: Failed to parse float value in option '{}'.", input_str))); + } + } else { + RecoverableError( + Status::SyntaxError(std::format("Invalid rank_features parameter format: Expect 3 parts separated by '^', but get: '{}'.", input_str))); + } +} + +RankFeaturesOption ParseRankFeaturesOption(std::string_view input_str) { + RankFeaturesOption result; + Vector feature_strs; + Split(input_str, ",", feature_strs); + for (auto &feature_str : feature_strs) { + RankFeatureOption feature_option; + ParseRankFeatureOption(feature_str, feature_option); + result.push_back(feature_option); + } + return result; +} + } // namespace infinity diff --git a/src/storage/invertedindex/search/parse_fulltext_options.cppm b/src/storage/invertedindex/search/parse_fulltext_options.cppm index 5cd4aebc88..5c39cface0 100644 --- a/src/storage/invertedindex/search/parse_fulltext_options.cppm +++ b/src/storage/invertedindex/search/parse_fulltext_options.cppm @@ -49,4 +49,14 @@ export struct BM25Params { float delta_phrase = 0.0F; }; +export struct RankFeatureOption { + String field_; + String feature_; + float boost_; +}; + +export using RankFeaturesOption = Vector; + +export RankFeaturesOption ParseRankFeaturesOption(std::string_view input_str); + } // namespace infinity diff --git a/src/storage/invertedindex/search/query_builder.cpp b/src/storage/invertedindex/search/query_builder.cpp index 30676904d7..3b50bff8a0 100644 --- a/src/storage/invertedindex/search/query_builder.cpp +++ b/src/storage/invertedindex/search/query_builder.cpp @@ -16,6 +16,9 @@ module; #include #include + +#include "query_node.h" + module query_builder; import stl; @@ -72,6 +75,17 @@ UniquePtr QueryBuilder::CreateSearch(FullTextQueryContext &context) LOG_DEBUG(std::move(oss).str()); } #endif + if (!context.rank_features_option_.empty()) { + auto rank_features_node = std::make_unique(); + for (auto rank_feature : context.rank_features_option_) { + auto rank_feature_node = std::make_unique(); + rank_feature_node->term_ = rank_feature.feature_; + rank_feature_node->column_ = rank_feature.field_; + rank_feature_node->boost_ = rank_feature.boost_; + rank_features_node->Add(std::move(rank_feature_node)); + } + // auto rank_features_iter = rank_features_node->CreateSearch(params); + } return result; } diff --git a/src/storage/invertedindex/search/query_builder.cppm b/src/storage/invertedindex/search/query_builder.cppm index 24560b8da0..b43b2bfbc1 100644 --- a/src/storage/invertedindex/search/query_builder.cppm +++ b/src/storage/invertedindex/search/query_builder.cppm @@ -36,6 +36,7 @@ export struct FullTextQueryContext { const FulltextSimilarity ft_similarity_{}; const BM25Params bm25_params_{}; const MinimumShouldMatchOption minimum_should_match_option_{}; + const RankFeaturesOption rank_features_option_{}; u32 minimum_should_match_ = 0; u32 topn_ = 0; EarlyTermAlgo early_term_algo_ = EarlyTermAlgo::kNaive; @@ -44,10 +45,11 @@ export struct FullTextQueryContext { FullTextQueryContext(const FulltextSimilarity ft_similarity, const BM25Params &bm25_params, const MinimumShouldMatchOption &minimum_should_match_option, + const RankFeaturesOption &rank_features_option, const u32 topn, const Vector &index_names) - : ft_similarity_(ft_similarity), bm25_params_(bm25_params), minimum_should_match_option_(minimum_should_match_option), topn_(topn), - index_names_(index_names) {} + : ft_similarity_(ft_similarity), bm25_params_(bm25_params), minimum_should_match_option_(minimum_should_match_option), + rank_features_option_(rank_features_option), topn_(topn), index_names_(index_names) {} }; export class QueryBuilder { diff --git a/src/storage/invertedindex/search/query_node.cpp b/src/storage/invertedindex/search/query_node.cpp index ace49c00c2..ee6fdced94 100644 --- a/src/storage/invertedindex/search/query_node.cpp +++ b/src/storage/invertedindex/search/query_node.cpp @@ -27,6 +27,8 @@ import keyword_iterator; import must_first_iterator; import batch_or_iterator; import blockmax_leaf_iterator; +import rank_feature_doc_iterator; +import rank_features_doc_iterator; namespace infinity { @@ -390,6 +392,11 @@ std::unique_ptr OrQueryNode::InnerGetNewOptimizedQueryTree() { } } +std::unique_ptr RankFeaturesQueryNode::InnerGetNewOptimizedQueryTree() { + UnrecoverableError("OptimizeInPlaceInner: Unexpected case! RankFeaturesQueryNode should not exist in parser output"); + return nullptr; +} + // 4. deal with "and_not": // "and_not" does not exist in parser output, it is generated during optimization @@ -429,6 +436,25 @@ std::unique_ptr TermQueryNode::CreateSearch(const CreateSearchParam return search; } +std::unique_ptr RankFeatureQueryNode::CreateSearch(const CreateSearchParams params, bool) const { + ColumnID column_id = params.table_entry->GetColumnIdByName(column_); + ColumnIndexReader *column_index_reader = params.index_reader->GetColumnIndexReader(column_id, params.index_names_); + if (!column_index_reader) { + RecoverableError(Status::SyntaxError(fmt::format(R"(Invalid query statement: Column "{}" has no fulltext index)", column_))); + return nullptr; + } + + bool fetch_position = false; + auto posting_iterator = column_index_reader->Lookup(term_, fetch_position); + if (!posting_iterator) { + return nullptr; + } + auto search = MakeUnique(std::move(posting_iterator), column_id, boost_); + search->term_ptr_ = &term_; + search->column_name_ptr_ = &column_; + return search; +} + std::unique_ptr PhraseQueryNode::CreateSearch(const CreateSearchParams params, bool) const { ColumnID column_id = params.table_entry->GetColumnIdByName(column_); ColumnIndexReader *column_index_reader = params.index_reader->GetColumnIndexReader(column_id, params.index_names_); @@ -685,6 +711,25 @@ std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams } } +std::unique_ptr RankFeaturesQueryNode::CreateSearch(const CreateSearchParams params, const bool is_top_level) const { + Vector> sub_doc_iters; + sub_doc_iters.reserve(children_.size()); + const auto next_params = params.RemoveMSM(); + for (auto &child : children_) { + auto iter = child->CreateSearch(next_params, false); + if (!iter) { + // no need to continue if any child is invalid + return nullptr; + } + sub_doc_iters.emplace_back(std::move(iter)); + } + if (sub_doc_iters.empty()) { + return nullptr; + } else { + return MakeUnique(std::move(sub_doc_iters)); + } +} + std::unique_ptr KeywordQueryNode::CreateSearch(const CreateSearchParams params, bool) const { Vector> sub_doc_iters; sub_doc_iters.reserve(children_.size()); @@ -753,6 +798,21 @@ void TermQueryNode::GetQueryColumnsTerms(std::vector &columns, std: terms.push_back(term_); } +void RankFeatureQueryNode::PrintTree(std::ostream &os, const std::string &prefix, const bool is_final) const { + os << prefix; + os << (is_final ? "└──" : "├──"); + os << QueryNodeTypeToString(type_); + os << " (weight: " << weight_ << ")"; + os << " (column: " << column_ << ")"; + os << " (term: " << term_ << ")"; + os << '\n'; +} + +void RankFeatureQueryNode::GetQueryColumnsTerms(std::vector &columns, std::vector &terms) const { + columns.push_back(column_); + terms.push_back(term_); +} + void PhraseQueryNode::PrintTree(std::ostream &os, const std::string &prefix, const bool is_final) const { os << prefix; os << (is_final ? "└──" : "├──"); diff --git a/src/storage/invertedindex/search/query_node.h b/src/storage/invertedindex/search/query_node.h index 212df86128..6ab1af394c 100644 --- a/src/storage/invertedindex/search/query_node.h +++ b/src/storage/invertedindex/search/query_node.h @@ -125,6 +125,20 @@ struct TermQueryNode : public QueryNode { void GetQueryColumnsTerms(std::vector &columns, std::vector &terms) const override; }; +struct RankFeatureQueryNode : public QueryNode { + std::string term_; + std::string column_; + float boost_; + + RankFeatureQueryNode() : QueryNode(QueryNodeType::TERM) {} + + uint32_t LeafCount() const override { return 1; } + void PushDownWeight(float factor) override { MultiplyWeight(factor); } + std::unique_ptr CreateSearch(CreateSearchParams params, bool is_top_level) const override; + void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const override; + void GetQueryColumnsTerms(std::vector &columns, std::vector &terms) const override; +}; + struct PhraseQueryNode final : public QueryNode { std::vector terms_; std::string column_; @@ -191,6 +205,12 @@ struct OrQueryNode final : public MultiQueryNode { std::unique_ptr CreateSearch(CreateSearchParams params, bool is_top_level) const override; }; +struct RankFeaturesQueryNode final : public MultiQueryNode { + RankFeaturesQueryNode() : MultiQueryNode(QueryNodeType::OR) {} + std::unique_ptr InnerGetNewOptimizedQueryTree() override; + std::unique_ptr CreateSearch(CreateSearchParams params, bool is_top_level) const override; +}; + struct KeywordQueryNode final : public MultiQueryNode { KeywordQueryNode() : MultiQueryNode(QueryNodeType::KEYWORD) {} void PushDownWeight(float factor) override { MultiplyWeight(factor); } diff --git a/src/storage/invertedindex/search/rank_feature_doc_iterator.cpp b/src/storage/invertedindex/search/rank_feature_doc_iterator.cpp new file mode 100644 index 0000000000..9bdc4d2d4b --- /dev/null +++ b/src/storage/invertedindex/search/rank_feature_doc_iterator.cpp @@ -0,0 +1,68 @@ +// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +#include +#include +module rank_feature_doc_iterator; + +import stl; +import logger; + +namespace infinity { + +RankFeatureDocIterator::RankFeatureDocIterator(UniquePtr &&iter, const u64 column_id, float boost) + : column_id_(column_id), boost_(boost), iter_(std::move(iter)) {} + +RankFeatureDocIterator::~RankFeatureDocIterator() {} + +bool RankFeatureDocIterator::Next(RowID doc_id) { + assert(doc_id != INVALID_ROWID); + if (doc_id_ != INVALID_ROWID && doc_id_ >= doc_id) + return true; + doc_id_ = iter_->SeekDoc(doc_id); + if (doc_id_ == INVALID_ROWID) + return false; + return true; +} + +float RankFeatureDocIterator::Score() { + u16 payload = iter_->GetCurrentDocPayload(); + float weight = static_cast(payload); + return weight * boost_; +} + +void RankFeatureDocIterator::PrintTree(std::ostream &os, const String &prefix, bool is_final) const { + os << prefix; + os << (is_final ? "└──" : "├──"); + os << "RankFeatureDocIterator"; + os << " (column: " << *column_name_ptr_ << ")"; + os << " (term: " << *term_ptr_ << ")"; + os << '\n'; +} + +void RankFeatureDocIterator::BatchDecodeTo(const RowID buffer_start_doc_id, const RowID buffer_end_doc_id, u16 *payload_ptr) { + auto iter_doc_id = iter_->DocID(); + assert((buffer_start_doc_id <= iter_doc_id && iter_doc_id < buffer_end_doc_id)); + while (iter_doc_id < buffer_end_doc_id) { + const auto pos = iter_doc_id - buffer_start_doc_id; + const auto payload = iter_->GetCurrentDocPayload(); + payload_ptr[pos] = payload; + iter_doc_id = iter_->SeekDoc(iter_doc_id + 1); + } + doc_id_ = iter_doc_id; +} + +} // namespace infinity diff --git a/src/storage/invertedindex/search/rank_feature_doc_iterator.cppm b/src/storage/invertedindex/search/rank_feature_doc_iterator.cppm new file mode 100644 index 0000000000..4305fbecad --- /dev/null +++ b/src/storage/invertedindex/search/rank_feature_doc_iterator.cppm @@ -0,0 +1,61 @@ +// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +export module rank_feature_doc_iterator; + +import stl; + +import posting_iterator; +import index_defines; +import doc_iterator; +import internal_types; +import doc_iterator; +import third_party; + +namespace infinity { + +export class RankFeatureDocIterator final : public DocIterator { +public: + RankFeatureDocIterator(UniquePtr &&iter, u64 column_id, float boost); + + ~RankFeatureDocIterator() override; + + DocIteratorType GetType() const override { return DocIteratorType::kRankFeatureDocIterator; } + + String Name() const override { return "RankFeatureDocIterator"; } + + void UpdateScoreThreshold(float threshold) override {} + + u32 MatchCount() const override { return 0; } + + bool Next(RowID doc_id) override; + + float Score() override; + + void PrintTree(std::ostream &os, const String &prefix, bool is_final) const override; + + void BatchDecodeTo(RowID buffer_start_doc_id, RowID buffer_end_doc_id, u16 *payload_ptr); + + const String *term_ptr_ = nullptr; + const String *column_name_ptr_ = nullptr; + +private: + u64 column_id_; + float boost_ = 1.0f; + UniquePtr iter_; +}; + +} // namespace infinity diff --git a/src/storage/invertedindex/search/rank_features_doc_iterator.cpp b/src/storage/invertedindex/search/rank_features_doc_iterator.cpp new file mode 100644 index 0000000000..d7c365c9a5 --- /dev/null +++ b/src/storage/invertedindex/search/rank_features_doc_iterator.cpp @@ -0,0 +1,127 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +#include +#include +#include +#include + +module rank_features_doc_iterator; + +import stl; +import third_party; +import index_defines; +import rank_feature_doc_iterator; +import multi_doc_iterator; +import internal_types; +import logger; +import infinity_exception; +import simd_functions; + +namespace infinity { + +constexpr u32 BATCH_OR_LEN = 128; + +RankFeaturesDocIterator::RankFeaturesDocIterator(Vector> &&iterators) : MultiDocIterator(std::move(iterators)) { + estimate_iterate_cost_ = {}; + const SizeT num_iterators = children_.size(); + for (SizeT i = 0; i < num_iterators; i++) { + auto it = dynamic_cast(children_[i].get()); + if (it == nullptr) { + UnrecoverableError("RankFeaturesDocIterator only supports RankFeatureDocIterator"); + } + } + static_assert(sizeof(f32) == sizeof(u32)); + memset_bytes_ = sizeof(u32) * BATCH_OR_LEN * 2u * num_iterators; + const auto alloc_bytes = sizeof(u32) * BATCH_OR_LEN * (2u + 2u * num_iterators); + aligned_buffer_ = std::aligned_alloc(64, alloc_bytes); + if (!aligned_buffer_) { + UnrecoverableError(fmt::format("{}: Out of memory!", __func__)); + } + match_cnt_ptr_ = static_cast(aligned_buffer_); + payload_ptr_ = reinterpret_cast(match_cnt_ptr_ + BATCH_OR_LEN * num_iterators); +} + +RankFeaturesDocIterator::~RankFeaturesDocIterator() { std::free(aligned_buffer_); } + +bool RankFeaturesDocIterator::Next(RowID doc_id) { + if (doc_id_ != INVALID_ROWID) [[likely]] { + if (doc_id_ >= doc_id) { + return true; + } + // now buffer_start_doc_id_ <= doc_id_ < doc_id + if (u32 pos = doc_id - buffer_start_doc_id_; pos < BATCH_OR_LEN) { + for (; pos < BATCH_OR_LEN; ++pos) { + if (match_cnt_ptr_[pos]) { + doc_id_ = buffer_start_doc_id_ + pos; + return true; + } + } + // now need to search from buffer_start_doc_id_ + BATCH_OR_LEN + doc_id = buffer_start_doc_id_ + BATCH_OR_LEN; + } + } else { + for (const auto &child : children_) { + child->Next(doc_id); + } + } + RowID next_buffer_start_doc_id = INVALID_ROWID; + for (const auto &child : children_) { + if (child->DocID() != INVALID_ROWID) { + child->Next(doc_id); + const RowID child_doc_id = child->DocID(); + next_buffer_start_doc_id = std::min(next_buffer_start_doc_id, child_doc_id); + } + } + if (next_buffer_start_doc_id != INVALID_ROWID) [[likely]] { + DecodeFrom(next_buffer_start_doc_id); + } + doc_id_ = next_buffer_start_doc_id; + return doc_id_ != INVALID_ROWID; +} + +u32 RankFeaturesDocIterator::MatchCount() const { return match_cnt_ptr_[doc_id_ - buffer_start_doc_id_]; } + +void RankFeaturesDocIterator::DecodeFrom(const RowID buffer_start_doc_id) { + buffer_start_doc_id_ = buffer_start_doc_id; + std::memset(aligned_buffer_, 0, memset_bytes_); + const auto buffer_end_doc_id = buffer_start_doc_id + BATCH_OR_LEN; + for (u32 i = 0; i < children_.size(); ++i) { + const auto child = children_[i].get(); + if (const auto child_doc_id = child->DocID(); child_doc_id != INVALID_ROWID) { + assert(child_doc_id >= buffer_start_doc_id); + if (child_doc_id >= buffer_end_doc_id) { + // no need to decode + continue; + } + const auto it = dynamic_cast(child); + it->BatchDecodeTo(buffer_start_doc_id, buffer_end_doc_id, payload_ptr_ + i * BATCH_OR_LEN); + } + } + for (u32 i = 0; i < BATCH_OR_LEN; ++i) { + u32 match_cnt = 0; + for (u32 j = 0; j < children_.size(); ++j) { + const auto payload = payload_ptr_[j * BATCH_OR_LEN + i]; + if (payload == 0) { + continue; + } + ++match_cnt; + } + match_cnt_ptr_[i] = match_cnt; + } +} + +} // namespace infinity diff --git a/src/storage/invertedindex/search/rank_features_doc_iterator.cppm b/src/storage/invertedindex/search/rank_features_doc_iterator.cppm new file mode 100644 index 0000000000..aea46502fa --- /dev/null +++ b/src/storage/invertedindex/search/rank_features_doc_iterator.cppm @@ -0,0 +1,55 @@ +// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +export module rank_features_doc_iterator; + +import stl; +import index_defines; +import doc_iterator; +import multi_doc_iterator; +import internal_types; + +namespace infinity { + +export class RankFeaturesDocIterator : public MultiDocIterator { +public: + explicit RankFeaturesDocIterator(Vector> &&iterators); + + ~RankFeaturesDocIterator() override; + + DocIteratorType GetType() const override { return DocIteratorType::kRankFeaturesDocIterator; } + + String Name() const override { return "RankFeaturesDocIterator"; } + + void UpdateScoreThreshold(float threshold) override {} + + bool Next(RowID doc_id) override; + + float Score() override { return 1.0f; } + + u32 MatchCount() const override; + + void DecodeFrom(RowID buffer_start_doc_id); + +private: + RowID buffer_start_doc_id_ = INVALID_ROWID; + u32 memset_bytes_ = 0; + void *aligned_buffer_ = nullptr; + u32 *match_cnt_ptr_ = nullptr; + u16 *payload_ptr_ = nullptr; +}; + +} // namespace infinity diff --git a/src/unit_test/storage/invertedindex/search/query_builder.cpp b/src/unit_test/storage/invertedindex/search/query_builder.cpp index fff878a20c..91cb1b3db7 100644 --- a/src/unit_test/storage/invertedindex/search/query_builder.cpp +++ b/src/unit_test/storage/invertedindex/search/query_builder.cpp @@ -201,7 +201,7 @@ TEST_F(QueryBuilderTest, test_and) { LOG_INFO(oss.str()); // apply query builder Vector hints; - FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints); + FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints); context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(and_root); FakeQueryBuilder fake_query_builder; @@ -273,7 +273,7 @@ TEST_F(QueryBuilderTest, test_or) { LOG_INFO(oss.str()); // apply query builder Vector hints; - FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints); + FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints); context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(or_root); FakeQueryBuilder fake_query_builder; @@ -351,7 +351,7 @@ TEST_F(QueryBuilderTest, test_and_not) { LOG_INFO(oss.str()); // apply query builder Vector hints; - FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints); + FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints); context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(and_not_root); FakeQueryBuilder fake_query_builder; @@ -435,7 +435,7 @@ TEST_F(QueryBuilderTest, test_and_not2) { LOG_INFO(oss.str()); // apply query builder Vector hints; - FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, hints); + FullTextQueryContext context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, RankFeaturesOption{}, 10, hints); context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(and_not_root); FakeQueryBuilder fake_query_builder; diff --git a/src/unit_test/storage/invertedindex/search/query_match.cpp b/src/unit_test/storage/invertedindex/search/query_match.cpp index e27fe1a1dd..e65d76ebca 100644 --- a/src/unit_test/storage/invertedindex/search/query_match.cpp +++ b/src/unit_test/storage/invertedindex/search/query_match.cpp @@ -339,7 +339,12 @@ void QueryMatchTest::QueryMatch(const String &db_name, Status status = Status::ParseMatchExprFailed(match_expr->fields_, match_expr->matching_text_); RecoverableError(status); } - FullTextQueryContext full_text_query_context(FulltextSimilarity::kBM25, BM25Params{}, MinimumShouldMatchOption{}, 10, index_hints); + FullTextQueryContext full_text_query_context(FulltextSimilarity::kBM25, + BM25Params{}, + MinimumShouldMatchOption{}, + RankFeaturesOption{}, + 10, + index_hints); full_text_query_context.early_term_algo_ = EarlyTermAlgo::kNaive; full_text_query_context.query_tree_ = std::move(query_tree); UniquePtr doc_iterator = query_builder.CreateSearch(full_text_query_context);