Skip to content

Commit

Permalink
Support rank features query:part1 (#2487)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Insert:
```
A `json` string is required for a rank-features column, for example:
[{"Tag1":0.1},{"Tag2":0.2}]
```

Query: 
```
SELECT * FROM table SEARCH MATCH TEXT ('doc', 'second text multiple', 'rank_features= field^tag1^1.0,field^tag2^1.0;topn=10');
```

Issue link: #2309

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
yingfeng authored Jan 22, 2025
1 parent ef36c4f commit bb2bc8d
Show file tree
Hide file tree
Showing 30 changed files with 639 additions and 32 deletions.
20 changes: 16 additions & 4 deletions src/common/analyzer/analyzer.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,23 @@ public:
}

protected:
typedef void (*HookType)(void *data, const char *text, const u32 len, const u32 offset, const u32 end_offset, const bool is_special_char);
typedef void (*HookType)(void *data,
const char *text,
const u32 len,
const u32 offset,
const u32 end_offset,
const bool is_special_char,
const u16 payload);

virtual int AnalyzeImpl(const Term &input, void *data, HookType func) { return -1; }

static void AppendTermList(void *data, const char *text, const u32 len, const u32 offset, const u32 end_offset, const bool is_special_char) {
static void AppendTermList(void *data,
const char *text,
const u32 len,
const u32 offset,
const u32 end_offset,
const bool is_special_char,
const u16 payload) {
void **parameters = (void **)data;
TermList *output = (TermList *)parameters[0];
Analyzer *analyzer = (Analyzer *)parameters[1];
Expand All @@ -62,9 +74,9 @@ protected:
return;
if (is_special_char && analyzer->convert_to_placeholder_) {
if (output->empty() == true || output->back().text_.compare(PLACE_HOLDER) != 0)
output->Add(PLACE_HOLDER.c_str(), PLACE_HOLDER.length(), offset, end_offset);
output->Add(PLACE_HOLDER.c_str(), PLACE_HOLDER.length(), offset, end_offset, payload);
} else {
output->Add(text, len, offset, end_offset);
output->Add(text, len, offset, end_offset, payload);
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/common/analyzer/analyzer_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import ngram_analyzer;
import rag_analyzer;
import whitespace_analyzer;
import ik_analyzer;
import rank_features_analyzer;
import logger;

namespace infinity {
Expand Down Expand Up @@ -330,6 +331,9 @@ Tuple<UniquePtr<Analyzer>, Status> AnalyzerPool::GetAnalyzer(const std::string_v
}
return {MakeUnique<WhitespaceAnalyzer>(name.substr(suffix_pos + 1)), Status::OK()};
}
case Str2Int(RANKFEATURES.data()): {
return {MakeUnique<RankFeaturesAnalyzer>(), Status::OK()};
}
default: {
if (std::filesystem::is_regular_file(name)) {
// Suppose it is a customized Python script analyzer
Expand Down
1 change: 1 addition & 0 deletions src/common/analyzer/analyzer_pool.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public:
static constexpr std::string_view IK = "ik";
static constexpr std::string_view KEYWORD = "keyword";
static constexpr std::string_view WHITESPACE = "whitespace";
static constexpr std::string_view RANKFEATURES = "rankfeatures";

private:
CacheType cache_{};
Expand Down
20 changes: 10 additions & 10 deletions src/common/analyzer/common_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ int CommonLanguageAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType

if (is_index_) {
if (IsSpecialChar()) {
func(data, token_, len_, offset_, end_offset_, true);
func(data, token_, len_, offset_, end_offset_, true, 0);
temp_offset = offset_;
continue;
}
if (is_raw_) {
func(data, token_, len_, offset_, end_offset_, false);
func(data, token_, len_, offset_, end_offset_, false, 0);
temp_offset = offset_;
continue;
}
Expand All @@ -77,37 +77,37 @@ int CommonLanguageAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType
bool lowercase_is_different = memcmp(token_, lowercase_term, len_) != 0;

if (stemming_term_str_size && stem_only_) {
func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false);
func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false, 0);
temp_offset = offset_;
} else if (stemming_term_str_size || (case_sensitive_ && contain_lower_ && lowercase_is_different)) {
/// have more than one output
if (case_sensitive_) {
func(data, token_, len_, offset_, end_offset_, false);
func(data, token_, len_, offset_, end_offset_, false, 0);
temp_offset = offset_;
} else {
func(data, lowercase_term, len_, offset_, end_offset_, false);
func(data, lowercase_term, len_, offset_, end_offset_, false, 0);
temp_offset = offset_;
}
if (stemming_term_str_size) {
func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false);
func(data, stem_term.c_str(), stemming_term_str_size, offset_, end_offset_, false, 0);
temp_offset = offset_;
}
if (case_sensitive_ && contain_lower_ && lowercase_is_different) {
func(data, lowercase_term, len_, offset_, end_offset_, false);
func(data, lowercase_term, len_, offset_, end_offset_, false, 0);
temp_offset = offset_;
}
} else {
/// have only one output
if (case_sensitive_) {
func(data, token_, len_, offset_, end_offset_, false);
func(data, token_, len_, offset_, end_offset_, false, 0);
temp_offset = offset_;
} else {
func(data, lowercase_term, len_, offset_, end_offset_, false);
func(data, lowercase_term, len_, offset_, end_offset_, false, 0);
temp_offset = offset_;
}
}
} else {
func(data, token_, len_, offset_, end_offset_, false);
func(data, token_, len_, offset_, end_offset_, false, 0);
temp_offset = offset_;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/analyzer/ik/ik_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ int IKAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) {
while ((lexeme = context_->GetNextLexeme()) != nullptr) {
std::wstring text = lexeme->GetLexemeText();
String token = CharacterUtil::UTF16ToUTF8(text);
func(data, token.c_str(), token.size(), offset++, 0, false);
func(data, token.c_str(), token.size(), offset++, 0, false, 0);
delete lexeme;
};
return 0;
Expand Down
2 changes: 1 addition & 1 deletion src/common/analyzer/ngram_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ int NGramAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) {
while (cur < len && NextInString(input.text_.c_str(), len, &cur, &token_start, &token_length)) {
if (token_length == 0)
continue;
func(data, input.text_.c_str() + token_start, token_length, offset, offset + token_length, false);
func(data, input.text_.c_str() + token_start, token_length, offset, offset + token_length, false, 0);
offset++;
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/analyzer/rag_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ int RAGAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) {
Split(output, blank_pattern_, tokens);
unsigned offset = 0;
for (auto &t : tokens) {
func(data, t.c_str(), t.size(), offset++, 0, false);
func(data, t.c_str(), t.size(), offset++, 0, false, 0);
}
return 0;
}
Expand Down
52 changes: 52 additions & 0 deletions src/common/analyzer/rank_features_analyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

#include <string>
module rank_features_analyzer;
import stl;
import term;
import analyzer;
import third_party;

namespace infinity {

// Clamps a float weight into the u16 payload range [0, 65535].
// Any fractional part is truncated by the final cast.
u16 FloatToU16(float value) {
    if (value < 0.0f) {
        return 0;
    }
    if (value > 65535.0f) {
        return 65535;
    }
    return static_cast<u16>(value);
}

int RankFeaturesAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) {
nlohmann::json line_json = nlohmann::json::parse(input.text_);
u32 offset = 0;
for (auto iter = line_json.begin(); iter != line_json.end(); ++iter) {
String key = iter.key();
String value = iter.value();
float target = 0;
try {
target = std::stof(value);
} catch (const std::exception &e) {
}
u16 weight = FloatToU16(target);
func(data, key.data(), key.size(), offset++, 0, false, weight);
}

return 0;
}

} // namespace infinity
35 changes: 35 additions & 0 deletions src/common/analyzer/rank_features_analyzer.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module rank_features_analyzer;
import stl;
import term;
import analyzer;

namespace infinity {

// Analyzer for rank-features columns: the input term carries a JSON payload
// mapping feature names to weights, and AnalyzeImpl emits one token per
// feature with the weight stored in the term payload.
export class RankFeaturesAnalyzer : public Analyzer {
    // NOTE(review): delimiters_ is never read in this interface unit —
    // presumably reserved for future tokenization options; confirm before use.
    String delimiters_{};

public:
    RankFeaturesAnalyzer() = default;
    ~RankFeaturesAnalyzer() override = default;

protected:
    // Parses the JSON in `input` and invokes `func` once per feature.
    // Returns 0 on success, a negative value on failure.
    int AnalyzeImpl(const Term &input, void *data, HookType func) override;
};

} // namespace infinity
9 changes: 6 additions & 3 deletions src/common/analyzer/term.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ public:

export class TermList : public Deque<Term> {
public:
void Add(const char *text, const u32 len, const u32 offset, const u32 end_offset) {
void Add(const char *text, const u32 len, const u32 offset, const u32 end_offset, const u16 payload = 0) {
push_back(global_temporary_);
back().text_.assign(text, len);
back().word_offset_ = offset;
back().end_offset_ = end_offset;
back().payload_ = payload;
}

void Add(cppjieba::Word &cut_word) {
Expand All @@ -54,18 +55,20 @@ public:
back().word_offset_ = cut_word.offset;
}

void Add(const String &token, const u32 offset, const u32 end_offset) {
void Add(const String &token, const u32 offset, const u32 end_offset, const u16 payload = 0) {
push_back(global_temporary_);
back().text_ = token;
back().word_offset_ = offset;
back().end_offset_ = end_offset;
back().payload_ = payload;
}

// Appends a term whose text is exchanged with `token` (avoids a copy;
// `token` afterwards holds the previous placeholder text), recording its
// offsets and the optional rank-features payload.
void Add(String &token, const u32 offset, const u32 end_offset, const u16 payload = 0) {
    push_back(global_temporary_);
    Term &term = back();
    std::swap(term.text_, token);
    term.word_offset_ = offset;
    term.end_offset_ = end_offset;
    term.payload_ = payload;
}

private:
Expand Down
6 changes: 3 additions & 3 deletions src/common/analyzer/whitespace_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ int WhitespaceAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func
std::string t;
u32 offset = 0;
while (is >> t) {
func(data, t.data(), t.size(), offset++, 0, false);
func(data, t.data(), t.size(), offset++, 0, false, 0);
}
return 0;
} else {
Expand All @@ -49,11 +49,11 @@ int WhitespaceAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func
while (search_start < input_text.size()) {
const auto found = input_text.find_first_of(delimiters, search_start);
if (found == std::string_view::npos) {
func(data, input_text.data() + search_start, input_text.size() - search_start, offset++, 0, false);
func(data, input_text.data() + search_start, input_text.size() - search_start, offset++, 0, false, 0);
break;
}
if (found > search_start) {
func(data, input_text.data() + search_start, found - search_start, offset++, 0, false);
func(data, input_text.data() + search_start, found - search_start, offset++, 0, false, 0);
}
search_start = found + 1;
}
Expand Down
11 changes: 9 additions & 2 deletions src/executor/operator/physical_match.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,12 @@ bool PhysicalMatch::ExecuteInner(QueryContext *query_context, OperatorState *ope
static_cast<TimeDurationType>(finish_init_query_builder_time - execute_start_time).count()));

// 2 build query iterator
FullTextQueryContext full_text_query_context(ft_similarity_, bm25_params_, minimum_should_match_option_, top_n_, match_expr_->index_names_);
FullTextQueryContext full_text_query_context(ft_similarity_,
bm25_params_,
minimum_should_match_option_,
rank_features_option_,
top_n_,
match_expr_->index_names_);
full_text_query_context.query_tree_ = MakeUnique<FilterQueryNode>(common_query_filter_.get(), std::move(query_tree_));
const auto query_iterators = CreateQueryIterators(query_builder, full_text_query_context, early_term_algo_, begin_threshold_, score_threshold_);
const auto finish_query_builder_time = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -329,6 +334,7 @@ PhysicalMatch::PhysicalMatch(const u64 id,
const u32 top_n,
const SharedPtr<CommonQueryFilter> &common_query_filter,
MinimumShouldMatchOption &&minimum_should_match_option,
RankFeaturesOption &&rank_features_option,
const f32 score_threshold,
const FulltextSimilarity ft_similarity,
const BM25Params &bm25_params,
Expand All @@ -339,7 +345,8 @@ PhysicalMatch::PhysicalMatch(const u64 id,
base_table_ref_(std::move(base_table_ref)), match_expr_(std::move(match_expr)), index_reader_(std::move(index_reader)),
query_tree_(std::move(query_tree)), begin_threshold_(begin_threshold), early_term_algo_(early_term_algo), top_n_(top_n),
common_query_filter_(common_query_filter), minimum_should_match_option_(std::move(minimum_should_match_option)),
score_threshold_(score_threshold), ft_similarity_(ft_similarity), bm25_params_(bm25_params) {}
rank_features_option_(std::move(rank_features_option)), score_threshold_(score_threshold), ft_similarity_(ft_similarity),
bm25_params_(bm25_params) {}

PhysicalMatch::~PhysicalMatch() = default;

Expand Down
3 changes: 3 additions & 0 deletions src/executor/operator/physical_match.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public:
u32 top_n,
const SharedPtr<CommonQueryFilter> &common_query_filter,
MinimumShouldMatchOption &&minimum_should_match_option,
RankFeaturesOption &&rank_features_option,
f32 score_threshold,
FulltextSimilarity ft_similarity,
const BM25Params &bm25_params,
Expand Down Expand Up @@ -112,6 +113,8 @@ private:
SharedPtr<CommonQueryFilter> common_query_filter_;
// for minimum_should_match
MinimumShouldMatchOption minimum_should_match_option_{};
// for rank features
RankFeaturesOption rank_features_option_{};
f32 score_threshold_{};
FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25};
BM25Params bm25_params_;
Expand Down
1 change: 1 addition & 0 deletions src/executor/physical_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ UniquePtr<PhysicalOperator> PhysicalPlanner::BuildMatch(const SharedPtr<LogicalN
logical_match->top_n_,
logical_match->common_query_filter_,
std::move(logical_match->minimum_should_match_option_),
std::move(logical_match->rank_features_option_),
logical_match->score_threshold_,
logical_match->ft_similarity_,
logical_match->bm25_params_,
Expand Down
5 changes: 5 additions & 0 deletions src/planner/bound_select_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ SharedPtr<LogicalNode> BoundSelectStatement::BuildPlan(QueryContext *query_conte
match_node->minimum_should_match_option_ = ParseMinimumShouldMatchOption(iter->second);
}

// option: rank_features
if (iter = search_ops.options_.find("rank_features"); iter != search_ops.options_.end()) {
match_node->rank_features_option_ = ParseRankFeaturesOption(iter->second);
}

// option: threshold
if (iter = search_ops.options_.find("threshold"); iter != search_ops.options_.end()) {
match_node->score_threshold_ = DataType::StringToValue<FloatT>(iter->second);
Expand Down
1 change: 1 addition & 0 deletions src/planner/node/logical_match.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public:

SharedPtr<CommonQueryFilter> common_query_filter_{};
MinimumShouldMatchOption minimum_should_match_option_{};
RankFeaturesOption rank_features_option_{};
f32 score_threshold_{};
FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25};
BM25Params bm25_params_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator {
UniquePtr<QueryNode> query_tree_;
MinimumShouldMatchOption minimum_should_match_option_;
u32 minimum_should_match_ = 0;
RankFeaturesOption rank_features_option_;
std::atomic_flag after_optimize_ = {};
f32 score_threshold_ = {};
FulltextSimilarity ft_similarity_ = FulltextSimilarity::kBM25;
Expand Down
2 changes: 2 additions & 0 deletions src/storage/invertedindex/search/doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ export enum class DocIteratorType : u8 {
kScoreThresholdIterator,
kKeywordIterator,
kMustFirstIterator,
kRankFeatureDocIterator,
kRankFeaturesDocIterator,
};

export struct DocIteratorEstimateIterateCost {
Expand Down
Loading

0 comments on commit bb2bc8d

Please sign in to comment.