Skip to content

Commit

Permalink
Refactor thrift server
Browse files Browse the repository at this point in the history
Signed-off-by: jinhai <haijin.chn@gmail.com>
  • Loading branch information
JinHai-CN committed Mar 8, 2024
1 parent 0f8c119 commit 1a80305
Showing 1 changed file with 28 additions and 30 deletions.
58 changes: 28 additions & 30 deletions src/network/thrift_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ using namespace apache::thrift::server;

namespace infinity {

constexpr String kErrorMsgHeader = "THRIFT ERROR";
constexpr String kErrorMsgHeader = "[THRIFT ERROR]";

class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServiceIf {
public:
Expand Down Expand Up @@ -241,7 +241,7 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
values = nullptr;
}
// Free current value list memory
if(value_list != nullptr) {
if (value_list != nullptr) {
for (auto &value_ptr : *value_list) {
delete value_ptr;
value_ptr = nullptr;
Expand All @@ -250,7 +250,7 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
value_list = nullptr;
}

if(parsed_expr != nullptr) {
if (parsed_expr != nullptr) {
delete parsed_expr;
parsed_expr = nullptr;
}
Expand Down Expand Up @@ -408,7 +408,7 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
output_columns = nullptr;
}

if(parsed_expr != nullptr) {
if (parsed_expr != nullptr) {
delete parsed_expr;
parsed_expr = nullptr;
}
Expand All @@ -422,7 +422,6 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
// search expr
SearchExpr *search_expr = nullptr;
if (request.__isset.search_expr) {
Status knn_expr_status;
search_expr = new SearchExpr();
auto search_expr_list = new Vector<ParsedExpr *>();
SizeT knn_expr_count = request.search_expr.knn_exprs.size();
Expand All @@ -431,7 +430,7 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
SizeT total_expr_count = knn_expr_count + match_expr_count + fusion_expr_exists;
search_expr_list->reserve(total_expr_count);
for (SizeT idx = 0; idx < knn_expr_count; ++idx) {
ParsedExpr *knn_expr = GetKnnExprFromProto(knn_expr_status, request.search_expr.knn_exprs[idx]);
auto [knn_expr, knn_expr_status] = GetKnnExprFromProto(request.search_expr.knn_exprs[idx]);
if (!knn_expr_status.ok()) {

if (output_columns != nullptr) {
Expand All @@ -450,7 +449,7 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
search_expr_list = nullptr;
}

if(knn_expr != nullptr) {
if (knn_expr != nullptr) {
delete knn_expr;
knn_expr = nullptr;
}
Expand Down Expand Up @@ -612,7 +611,6 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
// search expr
SearchExpr *search_expr = nullptr;
if (request.__isset.search_expr) {
Status knn_expr_status;
search_expr = new SearchExpr();
auto search_expr_list = new Vector<ParsedExpr *>();
SizeT knn_expr_count = request.search_expr.knn_exprs.size();
Expand All @@ -621,7 +619,7 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
SizeT total_expr_count = knn_expr_count + match_expr_count + fusion_expr_exists;
search_expr_list->reserve(total_expr_count);
for (SizeT idx = 0; idx < knn_expr_count; ++idx) {
ParsedExpr *knn_expr = GetKnnExprFromProto(knn_expr_status, request.search_expr.knn_exprs[idx]);
auto [knn_expr, knn_expr_status] = GetKnnExprFromProto(request.search_expr.knn_exprs[idx]);
if (!knn_expr_status.ok()) {

if (output_columns != nullptr) {
Expand Down Expand Up @@ -1246,11 +1244,11 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
for (auto &args : function_expr.arguments) {
arguments->emplace_back(GetParsedExprFromProto(status, args));
if (!status.ok()) {
if(parsed_expr != nullptr) {
if (parsed_expr != nullptr) {
delete parsed_expr;
parsed_expr = nullptr;
}
if(arguments != nullptr) {
if (arguments != nullptr) {
for (auto &argument : *arguments) {
delete argument;
argument = nullptr;
Expand All @@ -1266,32 +1264,32 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
return parsed_expr;
}

static KnnExpr *GetKnnExprFromProto(Status &status, const infinity_thrift_rpc::KnnExpr &expr) {
static std::tuple<KnnExpr *, Status> GetKnnExprFromProto(const infinity_thrift_rpc::KnnExpr &expr) {
auto knn_expr = new KnnExpr(false);
knn_expr->column_expr_ = GetColumnExprFromProto(expr.column_expr);

knn_expr->distance_type_ = GetDistanceTypeFormProto(expr.distance_type);
if (knn_expr->distance_type_ == KnnDistanceType::kInvalid) {
delete knn_expr;
knn_expr = nullptr;
status = Status::InvalidKnnDistanceType();
return nullptr;
return {nullptr, Status::InvalidKnnDistanceType()};
}
knn_expr->embedding_data_type_ = GetEmbeddingDataTypeFromProto(expr.embedding_data_type);
if (knn_expr->embedding_data_type_ == EmbeddingDataType::kElemInvalid) {
delete knn_expr;
knn_expr = nullptr;
status = Status::InvalidEmbeddingDataType();
return nullptr;
return {nullptr, Status::InvalidEmbeddingDataType()};
}

std::tie(knn_expr->embedding_data_ptr_, knn_expr->dimension_) = GetEmbeddingDataTypeDataPtrFromProto(status, expr.embedding_data);
auto [embedding_data_ptr, dimension, status] = GetEmbeddingDataTypeDataPtrFromProto(expr.embedding_data);
knn_expr->embedding_data_ptr_ = embedding_data_ptr;
knn_expr->dimension_ = dimension;
if (!status.ok()) {
if(knn_expr != nullptr) {
if (knn_expr != nullptr) {
delete knn_expr;
knn_expr = nullptr;
}
return nullptr;
return {nullptr, status};
}

knn_expr->topn_ = expr.topn;
Expand All @@ -1302,7 +1300,7 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
init_parameter->param_value_ = param.param_value;
knn_expr->opt_params_->emplace_back(init_parameter);
}
return knn_expr;
return {knn_expr, status};
}

static MatchExpr *GetMatchExprFromProto(const infinity_thrift_rpc::MatchExpr &expr) {
Expand Down Expand Up @@ -1331,7 +1329,8 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
auto parsed_expr = GetFunctionExprFromProto(status, *expr.type.function_expr);
return parsed_expr;
} else if (expr.type.__isset.knn_expr == true) {
auto parsed_expr = GetKnnExprFromProto(status, *expr.type.knn_expr);
auto [parsed_expr, knn_expr_status] = GetKnnExprFromProto(*expr.type.knn_expr);
status = knn_expr_status;
return parsed_expr;
} else if (expr.type.__isset.match_expr == true) {
auto parsed_expr = GetMatchExprFromProto(*expr.type.match_expr);
Expand Down Expand Up @@ -1381,28 +1380,27 @@ class InfinityServiceHandler : virtual public infinity_thrift_rpc::InfinityServi
}
}

static std::pair<void *, int64_t> GetEmbeddingDataTypeDataPtrFromProto(Status &status, const infinity_thrift_rpc::EmbeddingData &embedding_data) {
static std::tuple<void *, int64_t, Status> GetEmbeddingDataTypeDataPtrFromProto(const infinity_thrift_rpc::EmbeddingData &embedding_data) {
if (embedding_data.__isset.i8_array_value) {
return std::make_pair((void *)embedding_data.i8_array_value.data(), embedding_data.i8_array_value.size());
return {(void *)embedding_data.i8_array_value.data(), embedding_data.i8_array_value.size(), Status::OK()};
} else if (embedding_data.__isset.i16_array_value) {
return std::make_pair((void *)embedding_data.i16_array_value.data(), embedding_data.i16_array_value.size());
return {(void *)embedding_data.i16_array_value.data(), embedding_data.i16_array_value.size(), Status::OK()};
} else if (embedding_data.__isset.i32_array_value) {
return std::make_pair((void *)embedding_data.i32_array_value.data(), embedding_data.i32_array_value.size());
return {(void *)embedding_data.i32_array_value.data(), embedding_data.i32_array_value.size(), Status::OK()};
} else if (embedding_data.__isset.i64_array_value) {
return std::make_pair((void *)embedding_data.i64_array_value.data(), embedding_data.i64_array_value.size());
return {(void *)embedding_data.i64_array_value.data(), embedding_data.i64_array_value.size(), Status::OK()};
} else if (embedding_data.__isset.f32_array_value) {
auto ptr_double = (double *)(embedding_data.f32_array_value.data());
auto ptr_float = (float *)(embedding_data.f32_array_value.data());
for (size_t i = 0; i < embedding_data.f32_array_value.size(); ++i) {
ptr_float[i] = float(ptr_double[i]);
}
return std::make_pair((void *)embedding_data.f32_array_value.data(), embedding_data.f32_array_value.size());
return {(void *)embedding_data.f32_array_value.data(), embedding_data.f32_array_value.size(), Status::OK()};
} else if (embedding_data.__isset.f64_array_value) {
return std::make_pair((void *)embedding_data.f64_array_value.data(), embedding_data.f64_array_value.size());
return {(void *)embedding_data.f64_array_value.data(), embedding_data.f64_array_value.size(), Status::OK()};
} else {
status = Status::InvalidEmbeddingDataType();
return {nullptr, 0, Status::InvalidEmbeddingDataType()};
}
return std::make_pair(nullptr, 0);
}

static Tuple<UpdateExpr *, Status> GetUpdateExprFromProto(const infinity_thrift_rpc::UpdateExpr &update_expr) {
Expand Down

0 comments on commit 1a80305

Please sign in to comment.