Skip to content

Commit

Permalink
Add distance and similarity metric as output in KNN search (#1260)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

1. Python SDK can't use _distance to output the L2 distance.
2. Use '_similarity' to represents the IP metric.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
  • Loading branch information
JinHai-CN authored May 30, 2024
1 parent 91ec6de commit 67fdd71
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 38 deletions.
36 changes: 18 additions & 18 deletions python/hello_infinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_english():

res = (
table.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
)

pds_df = pds.DataFrame(res)
Expand All @@ -70,8 +70,8 @@ def test_english():
table_obj = db.get_table("my_table")
qb_result = (
table_obj.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
)
print("------tabular -------")
print("------vector-------")
Expand All @@ -85,10 +85,10 @@ def test_english():

qb_result2 = (
table_obj.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.match("body", "blooms", "topn=1")
.fusion("rrf")
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.match("body", "blooms", "topn=1")
.fusion("rrf")
.to_pl()
)
print("------vector+fulltext-------")
print(qb_result2)
Expand Down Expand Up @@ -171,8 +171,8 @@ def test_chinese():
print("------json-------")
res = (
table.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
)
pds_df = pds.DataFrame(res)
json_data = pds_df.to_json()
Expand All @@ -183,8 +183,8 @@ def test_chinese():
print("------vector-------")
qb_result = (
table_obj.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
)
print(qb_result)

Expand All @@ -199,8 +199,8 @@ def test_chinese():
for question in questions:
qb_result = (
table_obj.output(["num", "body", "_score"])
.match("body", question, "topn=10")
.to_pl()
.match("body", question, "topn=10")
.to_pl()
)
print(f"question: {question}")
print(qb_result)
Expand All @@ -209,10 +209,10 @@ def test_chinese():
for question in questions:
qb_result = (
table_obj.output(["num", "body", "_score"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 10)
.match("body", question, "topn=10")
.fusion("rrf")
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 10)
.match("body", question, "topn=10")
.fusion("rrf")
.to_pl()
)
print(f"question: {question}")
print(qb_result)
Expand Down
5 changes: 5 additions & 0 deletions python/infinity/remote_thrift/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ def output(self, columns: Optional[list]) -> InfinityThriftQueryBuilder:
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
select_list.append(parsed_expr)
case "_similarity":
func_expr = FunctionExpr(function_name="similarity", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
select_list.append(parsed_expr)
case "_distance":
func_expr = FunctionExpr(function_name="distance", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
Expand Down
2 changes: 1 addition & 1 deletion python/infinity/remote_thrift/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def insert(self, data: Union[INSERT_DATA, list[INSERT_DATA]]):
literal_type=ttypes.LiteralType.DoubleTensorArray,
f64_tensor_array_value=value)
else:
raise InfinityException(3069, "Invalid constant expression")
raise InfinityException(3069, f"Invalid constant expression: {type(value)}")

expr_type = ttypes.ParsedExprType(
constant_expr=constant_expression)
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "infinity_sdk"
version = "0.2.0.dev2"
version = "0.2.0.dev3"
dependencies = [
"sqlglot~=11.7.1",
"pydantic~=2.7.1",
Expand Down
6 changes: 3 additions & 3 deletions python/test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_knn_on_vector_column(self, get_infinity_db, check_data, column_name):
copy_data("tmp_20240116.csv")
test_csv_dir = "/var/infinity/test_data/tmp_20240116.csv"
table_obj.import_data(test_csv_dir, None)
res = table_obj.output(["variant_id", "_row_id", "_distance"]).knn(
res = table_obj.output(["variant_id", "_row_id", "_similarity"]).knn(
column_name, [1.0] * 4, "float", "ip", 2).to_pl()
print(res)

Expand Down Expand Up @@ -302,12 +302,12 @@ def test_valid_embedding_data_type(self, get_infinity_db, check_data, embedding_
test_csv_dir = "/var/infinity/test_data/tmp_20240116.csv"
table_obj.import_data(test_csv_dir, None)
if embedding_data_type[1]:
res = table_obj.output(["variant_id"]).knn("gender_vector", embedding_data, embedding_data_type[0],
res = table_obj.output(["variant_id", "_distance"]).knn("gender_vector", embedding_data, embedding_data_type[0],
"l2",
2).to_pl()
print(res)
else:
res = table_obj.output(["variant_id"]).knn("gender_vector", embedding_data, embedding_data_type[0],
res = table_obj.output(["variant_id", "_similarity"]).knn("gender_vector", embedding_data, embedding_data_type[0],
"ip",
2).to_pl()

Expand Down
11 changes: 7 additions & 4 deletions src/function/builtin_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,14 @@ void BuiltinFunctions::RegisterSpecialFunction() {
SharedPtr<SpecialFunction> row_function = MakeShared<SpecialFunction>("ROW_ID", DataType(LogicalType::kBigInt), 1, SpecialType::kRowID);
Catalog::AddSpecialFunction(catalog_ptr_.get(), row_function);

SharedPtr<SpecialFunction> create_ts_function = MakeShared<SpecialFunction>("DISTANCE", DataType(LogicalType::kFloat), 2, SpecialType::kDistance);
Catalog::AddSpecialFunction(catalog_ptr_.get(), create_ts_function);
SharedPtr<SpecialFunction> distance_function = MakeShared<SpecialFunction>("DISTANCE", DataType(LogicalType::kFloat), 2, SpecialType::kDistance);
Catalog::AddSpecialFunction(catalog_ptr_.get(), distance_function);

SharedPtr<SpecialFunction> delete_ts_function = MakeShared<SpecialFunction>("SCORE", DataType(LogicalType::kFloat), 3, SpecialType::kScore);
Catalog::AddSpecialFunction(catalog_ptr_.get(), delete_ts_function);
SharedPtr<SpecialFunction> similarity_function = MakeShared<SpecialFunction>("SIMILARITY", DataType(LogicalType::kFloat), 3, SpecialType::kSimilarity);
Catalog::AddSpecialFunction(catalog_ptr_.get(), similarity_function);

SharedPtr<SpecialFunction> score_function = MakeShared<SpecialFunction>("SCORE", DataType(LogicalType::kFloat), 4, SpecialType::kScore);
Catalog::AddSpecialFunction(catalog_ptr_.get(), score_function);
}

} // namespace infinity
1 change: 1 addition & 0 deletions src/function/special_function.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace infinity {
export enum class SpecialType {
kRowID,
kDistance,
kSimilarity,
kScore,
kCreateTs,
kDeleteTs,
Expand Down
1 change: 1 addition & 0 deletions src/parser/type/data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ std::shared_ptr<DataType> DataType::Deserialize(const nlohmann::json &data_type_
}
case LogicalType::kSparse: {
type_info = SparseInfo::Deserialize(type_info_json);
break;
}
default:
// There's no type_info for other types
Expand Down
39 changes: 39 additions & 0 deletions src/planner/bind_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import column_identifer;
import block_index;
import column_expr;
import logger;
import knn_expr;

namespace infinity {

Expand Down Expand Up @@ -368,6 +369,44 @@ const Binding *BindContext::GetBindingFromCurrentOrParentByName(const String &bi
return binding_iter->second.get();
}

void BindContext::BoundSearch(ParsedExpr *expr) {
if (expr == nullptr) {
return;
}
auto search_expr = (SearchExpr *)expr;

if(!search_expr->knn_exprs_.empty() && search_expr->fusion_exprs_.empty()) {
SizeT expr_count = search_expr->knn_exprs_.size();
KnnExpr* first_knn = search_expr->knn_exprs_[0];
KnnDistanceType first_distance_type = first_knn->distance_type_;
for(SizeT idx = 1; idx < expr_count; ++ idx) {
if(search_expr->knn_exprs_[idx]->distance_type_ != first_distance_type) {
// Mixed distance type
return ;
}
}
switch(first_distance_type) {
case KnnDistanceType::kL2:
case KnnDistanceType::kHamming: {
allow_distance = true;
break;
}
case KnnDistanceType::kInnerProduct:
case KnnDistanceType::kCosine: {
allow_similarity = true;
break;
}
default: {
String error_message = "Invalid KNN metric type";
LOG_ERROR(error_message);
UnrecoverableError(error_message);
}
}
}

allow_score = !search_expr->match_exprs_.empty() || !search_expr->match_tensor_exprs_.empty() || !(search_expr->fusion_exprs_.empty());
}

// void
// BindContext::AddChild(const SharedPtr<BindContext>& child) {
// child->binding_context_id_ = GenerateBindingContextIndex();
Expand Down
11 changes: 2 additions & 9 deletions src/planner/bind_context.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ public:
bool single_row = false;

bool allow_distance = false;
bool allow_similarity = false;
bool allow_score = false;

public:
Expand Down Expand Up @@ -166,15 +167,7 @@ public:

void BoundTable(const String &table_name) { bound_table_set_.insert(table_name); }

void BoundSearch(ParsedExpr *expr) {
if (expr == nullptr) {
return;
}
auto search_expr = (SearchExpr *)expr;

allow_distance = !search_expr->knn_exprs_.empty() && search_expr->fusion_exprs_.empty();
allow_score = !search_expr->match_exprs_.empty() || !search_expr->match_tensor_exprs_.empty() || !(search_expr->fusion_exprs_.empty());
}
void BoundSearch(ParsedExpr *expr);

void AddSubqueryBinding(const String &name,
u64 table_index,
Expand Down
10 changes: 9 additions & 1 deletion src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,15 @@ Optional<SharedPtr<BaseExpression>> ExpressionBinder::TryBuildSpecialFuncExpr(co
switch (special_function_ptr->special_type()) {
case SpecialType::kDistance: {
if (!bind_context_ptr->allow_distance) {
Status status = Status::SyntaxError("DISTANCE() needs to be allowed only when there is only MATCH VECTOR");
Status status = Status::SyntaxError("DISTANCE() needs to be allowed only when there is only MATCH VECTOR with distance metrics, like L2");
LOG_ERROR(status.message());
RecoverableError(status);
}
break;
}
case SpecialType::kSimilarity: {
if (!bind_context_ptr->allow_similarity) {
Status status = Status::SyntaxError("SIMILARITY() needs to be allowed only when there is only MATCH VECTOR with similarity metrics, like Inner product");
LOG_ERROR(status.message());
RecoverableError(status);
}
Expand Down
1 change: 1 addition & 0 deletions src/planner/optimizer/column_remapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ SharedPtr<BaseExpression> BindingRemapper::VisitReplace(const SharedPtr<ColumnEx
column_cnt_ - 1);
}
case SpecialType::kScore:
case SpecialType::kSimilarity:
case SpecialType::kDistance: {
return ReferenceExpression::Make(expression->Type(),
expression->table_name(),
Expand Down
5 changes: 4 additions & 1 deletion test/sql/dql/knn/test_knn_ip.slt
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ SELECT c2 FROM test_knn_ip SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float
0.2,0.1,0.3,0.4

query II
SELECT c1, ROW_ID(), DISTANCE() FROM test_knn_ip SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'ip', 3);
SELECT c1, ROW_ID(), SIMILARITY() FROM test_knn_ip SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'ip', 3);
----
8 3 0.270000
6 2 0.250000
4 1 0.230000

statement error
SELECT c1, ROW_ID(), DISTANCE() FROM test_knn_l2 SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'ip', 3);

# copy to create another new block
# there will has 2 knn_scan operator to scan the blocks, and one merge_knn to merge
statement ok
Expand Down
3 changes: 3 additions & 0 deletions test/sql/dql/knn/test_knn_l2.slt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ SELECT c1, ROW_ID(), DISTANCE() FROM test_knn_l2 SEARCH MATCH VECTOR (c2, [0.3,
6 2 0.060000
4 1 0.100000

statement error
SELECT c1, ROW_ID(), SIMILARITY() FROM test_knn_l2 SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'l2', 3);

# copy to create another new block
# there will has 2 knn_scan operator to scan the blocks, and one merge_knn to merge
statement ok
Expand Down

0 comments on commit 67fdd71

Please sign in to comment.