Skip to content

Commit ed87d1a

Browse files
authored
Fix: invalid top-n value in KNN which causes a core dump. (infiniflow#806)
### What problem does this PR solve?

Add a range check for TOP-N in KNN: if a non-positive number is passed, a `RecoverableError` is now thrown.

Issue link: infiniflow#776

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
1 parent 3510cfb commit ed87d1a

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

python/test/test_knn.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def test_various_distance_type(self, get_infinity_db, check_data, embedding_data
255255
@pytest.mark.parametrize("check_data", [{"file_name": "tmp_20240116.csv",
256256
"data_dir": common_values.TEST_TMP_DIR}], indirect=True)
257257
@pytest.mark.parametrize("topn", [
258-
2,
259-
10,
260-
# FIXME 0, ERROR: AddressSanitizer: heap-buffer-overflow on address 0x502000009bf0 at pc 0x5bdb2772d2a4 bp 0x7d8ea9efc9d0 sp 0x7d8ea9efc9c8
261-
# FIXME -1, exceeds maximum supported size of 0x10000000000
258+
(2, True),
259+
(10, True),
260+
(0, False),
261+
(-1, False),
262262
pytest.param("word", marks=pytest.mark.skip(reason="struct.error: required argument is not an integer")),
263263
pytest.param({}, marks=pytest.mark.skip(reason="struct.error: required argument is not an integer")),
264264
pytest.param((), marks=pytest.mark.skip(reason="struct.error: required argument is not an integer")),
@@ -283,5 +283,9 @@ def test_various_topn(self, get_infinity_db, check_data, topn):
283283
copy_data("tmp_20240116.csv")
284284
test_csv_dir = "/tmp/infinity/test_data/tmp_20240116.csv"
285285
table_obj.import_data(test_csv_dir, None)
286-
res = table_obj.output(["variant_id"]).knn("gender_vector", [1] * 4, "float", "pl", topn).to_pl()
287-
print(res)
286+
if topn[1]:
287+
res = table_obj.output(["variant_id"]).knn("gender_vector", [1] * 4, "float", "pl", topn[0]).to_pl()
288+
print(res)
289+
else:
290+
with pytest.raises(Exception, match="ERROR:3014*"):
291+
res = table_obj.output(["variant_id"]).knn("gender_vector", [1] * 4, "float", "pl", topn[0]).to_pl()

src/network/thrift_server.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,11 @@ class InfinityServiceHandler final : public infinity_thrift_rpc::InfinityService
406406
knn_expr = nullptr;
407407
}
408408

409+
if (search_expr != nullptr) {
410+
delete search_expr;
411+
search_expr = nullptr;
412+
}
413+
409414
ProcessStatus(response, knn_expr_status);
410415
return;
411416
}
@@ -1165,6 +1170,13 @@ class InfinityServiceHandler final : public infinity_thrift_rpc::InfinityService
11651170
}
11661171

11671172
knn_expr->topn_ = expr.topn;
1173+
if (knn_expr->topn_ <= 0) {
1174+
delete knn_expr;
1175+
knn_expr = nullptr;
1176+
String topn = std::to_string(expr.topn);
1177+
return {nullptr, Status::InvalidParameterValue("topn", topn, "topn should be greater than 0")};
1178+
}
1179+
11681180
knn_expr->opt_params_ = new Vector<InitParameter *>();
11691181
for (auto &param : expr.opt_params) {
11701182
auto init_parameter = new InitParameter();

src/planner/expression_binder.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,10 @@ SharedPtr<BaseExpression> ExpressionBinder::BuildKnnExpr(const KnnExpr &parsed_k
454454
if (parsed_knn_expr.column_expr_->type_ != ParsedExprType::kColumn) {
455455
UnrecoverableError("Knn expression expect a column expression");
456456
}
457+
if (parsed_knn_expr.topn_ <= 0) {
458+
String topn = std::to_string(parsed_knn_expr.topn_);
459+
RecoverableError(Status::InvalidParameterValue("topn", topn, "topn should be greater than 0"));
460+
}
457461
auto expr_ptr = BuildColExpr((ColumnExpr &)*parsed_knn_expr.column_expr_, bind_context_ptr, depth, false);
458462
TypeInfo *type_info = expr_ptr->Type().type_info().get();
459463
if (type_info == nullptr or type_info->type() != TypeInfoType::kEmbedding) {

0 commit comments

Comments
 (0)