diff --git a/engine/core/type.cc b/engine/core/type.cc index a769ac6..c302d88 100644 --- a/engine/core/type.cc +++ b/engine/core/type.cc @@ -101,83 +101,83 @@ std::shared_ptr ToArrowDataType(pb::PrimitiveDataType dtype) { return dt; } -spu::PtType ArrowDataTypeToSpuPtType( +spu::pb::PtType Arrowspu::PtType( const std::shared_ptr& dtype) { - spu::PtType pt; + spu::pb::PtType pt; switch (dtype->id()) { case arrow::Type::BOOL: - pt = spu::PT_I1; + pt = spu::pb::PT_I1; break; case arrow::Type::UINT8: - pt = spu::PT_U8; + pt = spu::pb::PT_U8; break; case arrow::Type::INT8: - pt = spu::PT_I8; + pt = spu::pb::PT_I8; break; case arrow::Type::UINT16: - pt = spu::PT_U16; + pt = spu::pb::PT_U16; break; case arrow::Type::INT16: - pt = spu::PT_I16; + pt = spu::pb::PT_I16; break; case arrow::Type::UINT32: - pt = spu::PT_U32; + pt = spu::pb::PT_U32; break; case arrow::Type::INT32: - pt = spu::PT_I32; + pt = spu::pb::PT_I32; break; case arrow::Type::UINT64: - pt = spu::PT_U64; + pt = spu::pb::PT_U64; break; case arrow::Type::INT64: - pt = spu::PT_I64; + pt = spu::pb::PT_I64; break; case arrow::Type::FLOAT: - pt = spu::PT_F32; + pt = spu::pb::PT_F32; break; case arrow::Type::DOUBLE: - pt = spu::PT_F64; + pt = spu::pb::PT_F64; break; default: - pt = spu::PT_INVALID; + pt = spu::pb::PT_INVALID; } return pt; } -std::shared_ptr SpuPtTypeToArrowDataType(spu::PtType pt_type) { +std::shared_ptr SpuPtTypeToArrowDataType(spu::pb::PtType pt_type) { std::shared_ptr dt; switch (pt_type) { - case spu::PT_I8: + case spu::pb::PT_I8: dt = arrow::int8(); break; - case spu::PT_U8: + case spu::pb::PT_U8: dt = arrow::uint8(); break; - case spu::PT_I16: + case spu::pb::PT_I16: dt = arrow::int16(); break; - case spu::PT_U16: + case spu::pb::PT_U16: dt = arrow::uint16(); break; - case spu::PT_I32: + case spu::pb::PT_I32: dt = arrow::int32(); break; - case spu::PT_U32: + case spu::pb::PT_U32: dt = arrow::uint32(); break; - case spu::PT_I64: + case spu::pb::PT_I64: dt = arrow::int64(); break; - case spu::PT_U64: + case spu::pb::PT_U64: dt = arrow::uint64(); break; - case spu::PT_F32: + case spu::pb::PT_F32: dt = arrow::float32(); break; - case spu::PT_F64: + case spu::pb::PT_F64: dt = arrow::float64(); break; - case spu::PT_I1: + case spu::pb::PT_I1: dt = arrow::boolean(); break; default: @@ -186,36 +186,36 @@ std::shared_ptr SpuPtTypeToArrowDataType(spu::PtType pt_type) { return dt; } -spu::PtType DataTypeToSpuPtType(pb::PrimitiveDataType dtype) { - spu::PtType pt; +spu::pb::PtType spu::PtType(pb::PrimitiveDataType dtype) { + spu::pb::PtType pt; switch (dtype) { case pb::PrimitiveDataType::INT8: - pt = spu::PT_I8; + pt = spu::pb::PT_I8; break; case pb::PrimitiveDataType::INT16: - pt = spu::PT_I16; + pt = spu::pb::PT_I16; break; case pb::PrimitiveDataType::INT32: - pt = spu::PT_I32; + pt = spu::pb::PT_I32; break; case pb::PrimitiveDataType::INT64: case pb::PrimitiveDataType::DATETIME: case pb::PrimitiveDataType::TIMESTAMP: - pt = spu::PT_I64; + pt = spu::pb::PT_I64; break; case pb::PrimitiveDataType::BOOL: - pt = spu::PT_I1; + pt = spu::pb::PT_I1; break; case pb::PrimitiveDataType::FLOAT32: - pt = spu::PT_F32; + pt = spu::pb::PT_F32; break; case pb::PrimitiveDataType::FLOAT64: - pt = spu::PT_F64; + pt = spu::pb::PT_F64; break; default: - pt = spu::PT_INVALID; + pt = spu::pb::PT_INVALID; } return pt; } -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/core/type.h b/engine/core/type.h index fac0522..c029f7e 100644 --- a/engine/core/type.h +++ b/engine/core/type.h @@ -17,7 +17,7 @@ #include "arrow/type.h" #include "api/core.pb.h" -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace scql::engine { enum class Visibility { @@ -37,16 +37,16 @@ pb::PrimitiveDataType FromArrowDataType( std::shared_ptr ToArrowDataType(pb::PrimitiveDataType dtype); /// @brief convert arrow data type to spu plaintext type enum -/// @returns spu::PT_INVALID if @param[in] dtype is not supported -spu::PtType ArrowDataTypeToSpuPtType( +/// @returns spu::pb::PT_INVALID if @param[in] dtype is not supported +spu::pb::PtType Arrowspu::PtType( const std::shared_ptr& dtype); /// @brief convert spu plaintext type enum to arrow data type /// @returns nullptr if @param[in] pt_type is not supported -std::shared_ptr SpuPtTypeToArrowDataType(spu::PtType pt_type); +std::shared_ptr SpuPtTypeToArrowDataType(spu::pb::PtType pt_type); /// @brief convert scql primitive data type to spu plaintext type enum -/// @returns spu::PT_INVALID if @param[in] dtype is not supported -spu::PtType DataTypeToSpuPtType(pb::PrimitiveDataType dtype); +/// @returns spu::pb::PT_INVALID if @param[in] dtype is not supported +spu::pb::PtType spu::PtType(pb::PrimitiveDataType dtype); -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/exe/main.cc b/engine/exe/main.cc index 1fd95a8..b5210bb 100644 --- a/engine/exe/main.cc +++ b/engine/exe/main.cc @@ -226,15 +226,15 @@ std::unique_ptr BuildEngineService( } session_opt.log_options = opts; - std::vector allowed_protocols; + std::vector allowed_protocols; std::vector protocols_str = absl::StrSplit(FLAGS_spu_allowed_protocols, ','); for (auto& protocol_str : protocols_str) { std::string stripped_str(absl::StripAsciiWhitespace(protocol_str)); - spu::ProtocolKind protocol_kind; + spu::pb::ProtocolKind protocol_kind; - YACL_ENFORCE(spu::ProtocolKind_Parse(stripped_str, &protocol_kind), + YACL_ENFORCE(spu::pb::ProtocolKind_Parse(stripped_str, &protocol_kind), fmt::format("invalid protocol provided: {}", stripped_str)); allowed_protocols.push_back(protocol_kind); } @@ -462,4 +462,4 @@ grpc::SslCredentialsOptions LoadSslCredentialsOptions( YACL_ENFORCE(butil::ReadFileToString(butil::FilePath(cert_file), &content)); opts.pem_cert_chain = content; return opts; -} \ No newline at end of file +} diff --git a/engine/framework/session.cc b/engine/framework/session.cc index 3beaa88..68e8a49 100644 --- a/engine/framework/session.cc +++ b/engine/framework/session.cc @@ -63,7 +63,7 @@ bool Session::ValidateSPUContext() { YACL_ENFORCE(spu_ctx_ != nullptr, "SPU context is not initialized successfully."); return std::find(allowed_spu_protocols_.begin(), allowed_spu_protocols_.end(), - spu_ctx_->config().protocol()) != + spu_ctx_->config().protocol) != allowed_spu_protocols_.end(); } @@ -71,7 +71,7 @@ Session::Session(const SessionOptions& session_opt, const pb::JobStartParams& params, pb::DebugOptions debug_opts, yacl::link::ILinkFactory* link_factory, Router* router, DatasourceAdaptorMgr* ds_mgr, - const std::vector& allowed_spu_protocols) + const std::vector& allowed_spu_protocols) : id_(params.job_id()), session_opt_(session_opt), time_zone_(params.time_zone()), @@ -112,10 +112,10 @@ Session::Session(const SessionOptions& session_opt, std::accumulate( allowed_spu_protocols_.begin(), allowed_spu_protocols_.end(), std::string{}, - [](const std::string& acc, const spu::ProtocolKind& protocol) { + [](const std::string& acc, const spu::pb::ProtocolKind& protocol) { return acc.empty() - ? spu::ProtocolKind_Name(protocol) - : acc + ", " + spu::ProtocolKind_Name(protocol); + ? spu::pb::ProtocolKind_Name(protocol) + : acc + ", " + spu::pb::ProtocolKind_Name(protocol); }))); spu::mpc::Factory::RegisterProtocol(spu_ctx_.get(), lctx_); } @@ -397,4 +397,4 @@ std::shared_ptr ActiveLogger(const Session* session) { } return session_logger; } -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/framework/session.h b/engine/framework/session.h index 29b3cd2..5026a18 100644 --- a/engine/framework/session.h +++ b/engine/framework/session.h @@ -100,7 +100,7 @@ class Session { pb::DebugOptions debug_opts, yacl::link::ILinkFactory* link_factory, Router* router, DatasourceAdaptorMgr* ds_mgr, - const std::vector& allowed_spu_protocols); + const std::vector& allowed_spu_protocols); ~Session(); /// @return session id std::string Id() const { return id_; } @@ -269,7 +269,7 @@ class Session { std::shared_ptr psi_logger_ = nullptr; pb::DebugOptions debug_opts_; - const std::vector allowed_spu_protocols_; + const std::vector allowed_spu_protocols_; // for progress exposure std::atomic_int32_t nodes_count_ = -1; @@ -285,4 +285,4 @@ class Session { std::shared_ptr ActiveLogger(const Session* session); size_t CryptoHash(const std::string& str); -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/framework/session_manager.cc b/engine/framework/session_manager.cc index 67f104c..f77a88e 100644 --- a/engine/framework/session_manager.cc +++ b/engine/framework/session_manager.cc @@ -37,7 +37,7 @@ SessionManager::SessionManager( std::unique_ptr link_factory, std::unique_ptr ds_router, std::unique_ptr ds_mgr, int32_t session_timeout_s, - const std::vector& allowed_spu_protocols) + const std::vector& allowed_spu_protocols) : session_opt_(std::move(session_opt)), listener_manager_(listener_manager), link_factory_(std::move(link_factory)), @@ -341,4 +341,4 @@ std::optional SessionManager::GetTimeoutSession() { return std::nullopt; } -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/framework/session_manager.h b/engine/framework/session_manager.h index 06b29da..6d71497 100644 --- a/engine/framework/session_manager.h +++ b/engine/framework/session_manager.h @@ -34,7 +34,7 @@ class SessionManager { std::unique_ptr ds_router, std::unique_ptr ds_mgr, int32_t session_timeout_s, - const std::vector& allowed_spu_protocols); + const std::vector& allowed_spu_protocols); ~SessionManager(); @@ -80,7 +80,7 @@ class SessionManager { std::atomic to_stop_{false}; std::unique_ptr watch_thread_; std::queue session_timeout_queue_; - const std::vector allowed_spu_protocols_; + const std::vector allowed_spu_protocols_; }; -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/framework/session_manager_test.cc b/engine/framework/session_manager_test.cc index a2b89cf..9cabf55 100644 --- a/engine/framework/session_manager_test.cc +++ b/engine/framework/session_manager_test.cc @@ -49,8 +49,8 @@ class SessionManagerTest : public ::testing::Test { factory = std::make_unique(&listener_manager); EXPECT_NE(nullptr, factory.get()); SessionOptions options; - std::vector allowed_spu_protocols = { - spu::ProtocolKind::SEMI2K, spu::ProtocolKind::CHEETAH}; + std::vector allowed_spu_protocols = { + spu::pb::ProtocolKind::SEMI2K, spu::pb::ProtocolKind::CHEETAH}; mgr = std::make_unique(options, &listener_manager, std::move(factory), nullptr, nullptr, 1, allowed_spu_protocols); @@ -75,7 +75,7 @@ TEST_F(SessionManagerTest, Works) { alice->CopyFrom(op::test::BuildParty(op::test::kPartyAlice, 0)); params.mutable_spu_runtime_cfg()->CopyFrom( - op::test::MakeSpuRuntimeConfigForTest(spu::ProtocolKind::SEMI2K)); + op::test::MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::SEMI2K)); } pb::DebugOptions debug_opts; // When @@ -121,14 +121,14 @@ TEST_F(SessionManagerTest, TestSessionCreation) { SessionOptions options; common_params.mutable_spu_runtime_cfg()->CopyFrom( - op::test::MakeSpuRuntimeConfigForTest(spu::ProtocolKind::REF2K)); + op::test::MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::REF2K)); auto create_session = [&](const pb::JobStartParams& params) { pb::DebugOptions debug_opts; // not allowed to create session with REF2K. - std::vector allowed_protocols{spu::ProtocolKind::CHEETAH, - spu::ProtocolKind::SEMI2K, - spu::ProtocolKind::ABY3}; + std::vector allowed_protocols{spu::pb::ProtocolKind::CHEETAH, + spu::pb::ProtocolKind::SEMI2K, + spu::pb::ProtocolKind::ABY3}; EXPECT_THROW(std::make_shared(options, params, debug_opts, &g_mem_link_factory, nullptr, nullptr, allowed_protocols), @@ -151,4 +151,4 @@ TEST_F(SessionManagerTest, TestSessionCreation) { futures[1].get(); } -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/operator/arrow_func_test.cc b/engine/operator/arrow_func_test.cc index 403910c..e2779db 100644 --- a/engine/operator/arrow_func_test.cc +++ b/engine/operator/arrow_func_test.cc @@ -42,7 +42,7 @@ class ArrowFuncTest INSTANTIATE_TEST_SUITE_P( ArrowFuncBatchTest, ArrowFuncTest, testing::Combine( - testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}), + testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}), testing::Values( ArrowFuncTestCase{ .ins = {test::NamedTensor( @@ -190,4 +190,4 @@ void ArrowFuncTest::FeedInputs(ExecContext* ctx, const ArrowFuncTestCase& tc) { test::FeedInputsAsPrivate(ctx, tc.ins); } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/cast.cc b/engine/operator/cast.cc index e791b72..32fd10d 100644 --- a/engine/operator/cast.cc +++ b/engine/operator/cast.cc @@ -73,7 +73,7 @@ void Cast::Execute(ExecContext* ctx) { "string in spu is hash, not support cast"); auto* symbols = ctx->GetSession()->GetDeviceSymbols(); auto* sctx = ctx->GetSession()->GetSpuContext(); - auto to_type = spu::getEncodeType(DataTypeToSpuPtType(output_pb.elem_type())); + auto to_type = spu::getEncodeType(spu::PtType(output_pb.elem_type())); auto value = symbols->getVar(util::SpuVarNameEncoder::GetValueName(input_pb.name())); @@ -93,4 +93,4 @@ void Cast::Execute(ExecContext* ctx) { #endif // SCQL_WITH_NULL } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/coalesce_test.cc b/engine/operator/coalesce_test.cc index fdbaf85..ee69fec 100644 --- a/engine/operator/coalesce_test.cc +++ b/engine/operator/coalesce_test.cc @@ -37,7 +37,7 @@ class CoalesceTest INSTANTIATE_TEST_SUITE_P( CoalesceBatchTest, CoalesceTest, testing::Combine( - testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}), + testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}), testing::Values( // test private status CoalesceTestCase{ @@ -136,4 +136,4 @@ void CoalesceTest::FeedInputs(ExecContext* ctx, const CoalesceTestCase& tc) { test::FeedInputsAsPrivate(ctx, tc.exprs); } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/concat.cc b/engine/operator/concat.cc index c9cec4c..b54e71c 100644 --- a/engine/operator/concat.cc +++ b/engine/operator/concat.cc @@ -53,7 +53,7 @@ void Concat::Execute(ExecContext* ctx) { spu::DataType output_type = spu::DataType::DT_INVALID; if (output_pb.elem_type() != pb::PrimitiveDataType::STRING) { output_type = - spu::getEncodeType(DataTypeToSpuPtType(output_pb.elem_type())); + spu::getEncodeType(spu::PtType(output_pb.elem_type())); } std::vector values; @@ -93,4 +93,4 @@ void Concat::Execute(ExecContext* ctx) { #endif // SCQL_WITH_NULL } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/if_null_test.cc b/engine/operator/if_null_test.cc index 9b759de..5437d45 100644 --- a/engine/operator/if_null_test.cc +++ b/engine/operator/if_null_test.cc @@ -37,7 +37,7 @@ class IfNullTest : public testing::TestWithParam< INSTANTIATE_TEST_SUITE_P( IfNullBatchTest, IfNullTest, testing::Combine( - testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}), + testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}), testing::Values( // test private status IfNullTestCase{ @@ -138,4 +138,4 @@ void IfNullTest::FeedInputs(ExecContext* ctx, const IfNullTestCase& tc) { test::FeedInputsAsPrivate(ctx, {tc.exp, tc.alt}); } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/is_null_test.cc b/engine/operator/is_null_test.cc index 8debce6..2525cef 100644 --- a/engine/operator/is_null_test.cc +++ b/engine/operator/is_null_test.cc @@ -36,7 +36,7 @@ class IsNullTest : public testing::TestWithParam< INSTANTIATE_TEST_SUITE_P( IsNullBatchTest, IsNullTest, testing::Combine( - testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}), + testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}), testing::Values( // test private status IsNullTestCase{ @@ -122,4 +122,4 @@ void IsNullTest::FeedInputs(ExecContext* ctx, const IsNullTestCase& tc) { test::FeedInputsAsPrivate(ctx, {tc.input}); } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/join_test.cc b/engine/operator/join_test.cc index 8436e25..51ba52e 100644 --- a/engine/operator/join_test.cc +++ b/engine/operator/join_test.cc @@ -69,7 +69,7 @@ INSTANTIATE_TEST_SUITE_P( JoinBatchTest, JoinTest, testing::Combine( // any protocol is ok - testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}), + testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}), // TODO: clean duplicated tests testing::Values( // ECDH PSI @@ -749,4 +749,4 @@ void JoinTest::FeedInputs(ExecContext* ctx, } } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/oblivious_group_agg_test.cc b/engine/operator/oblivious_group_agg_test.cc index 9329ebc..f17fcfc 100644 --- a/engine/operator/oblivious_group_agg_test.cc +++ b/engine/operator/oblivious_group_agg_test.cc @@ -119,10 +119,10 @@ pb::ExecNode ObliviousGroupAggTest::MakeExecNode( INSTANTIATE_TEST_SUITE_P( ObliviousGroupSumTest, ObliviousGroupAggTest, testing::Combine( - testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::CHEETAH, 2}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 3}, - test::SpuRuntimeTestCase{spu::ProtocolKind::ABY3, 3}), + testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::CHEETAH, 2}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 3}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::ABY3, 3}), testing::Values( ObliviousGroupAggTestCase{ .op_type = ObliviousGroupSum::kOpType, diff --git a/engine/operator/secret_join.cc b/engine/operator/secret_join.cc index 7149953..8a9c5eb 100644 --- a/engine/operator/secret_join.cc +++ b/engine/operator/secret_join.cc @@ -164,7 +164,7 @@ void SecretJoin::Validate(ExecContext* ctx) { util::AreTensorsStatusMatched(right_output, pb::TENSORSTATUS_SECRET)); // TODO: support ABY3/CHEETAH after spu supported YACL_ENFORCE(ctx->GetSession()->GetSpuContext()->config().protocol() == - spu::ProtocolKind::SEMI2K, + spu::pb::ProtocolKind::SEMI2K, "secret join only support SEMI2K protocol now"); } @@ -288,10 +288,10 @@ void SecretJoin::Execute(ExecContext* ctx) { spu::Value left_perm; spu::Value right_perm; - if (sctx->config().field() == spu::FM64) { + if (sctx->config().field() == spu::pb::FM64) { left_perm = BuildPerm(sctx, pl, seq_shuffled); right_perm = BuildPerm(sctx, pr, seq_shuffled); - } else if (sctx->config().field() == spu::FM128) { + } else if (sctx->config().field() == spu::pb::FM128) { left_perm = BuildPerm(sctx, pl, seq_shuffled); right_perm = BuildPerm(sctx, pr, seq_shuffled); } else { @@ -552,4 +552,4 @@ void SecretJoin::SetEmptyResult(ExecContext* ctx) { } } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/operator/secret_join_test.cc b/engine/operator/secret_join_test.cc index 94151ef..d38f844 100644 --- a/engine/operator/secret_join_test.cc +++ b/engine/operator/secret_join_test.cc @@ -43,8 +43,8 @@ class SecretJoinTest INSTANTIATE_TEST_SUITE_P( SecretJoinSecretTest, SecretJoinTest, testing::Combine( - testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 3}), + testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 3}), testing::Values( SecretJoinTestCase{ .left_key = {test::NamedTensor( diff --git a/engine/operator/test_util.cc b/engine/operator/test_util.cc index 42b536b..f90cdfd 100644 --- a/engine/operator/test_util.cc +++ b/engine/operator/test_util.cc @@ -49,12 +49,12 @@ pb::JobStartParams::Party BuildParty(const std::string& code, int32_t rank) { return party; } -spu::RuntimeConfig MakeSpuRuntimeConfigForTest( - spu::ProtocolKind protocol_kind, bool enable_colocated_optimization) { - spu::RuntimeConfig config; +spu::pb::RuntimeConfig MakeSpuRuntimeConfigForTest( + spu::pb::ProtocolKind protocol_kind, bool enable_colocated_optimization) { + spu::pb::RuntimeConfig config; config.set_protocol(protocol_kind); - config.set_field(spu::FieldType::FM64); - config.set_sigmoid_mode(spu::RuntimeConfig::SIGMOID_REAL); + config.set_field(spu::pb::FieldType::FM64); + config.set_sigmoid_mode(spu::pb::RuntimeConfig::SIGMOID_REAL); config.set_experimental_enable_colocated_optimization( enable_colocated_optimization); @@ -69,16 +69,16 @@ std::shared_ptr Make1PCSession(Router* ds_router, params.set_job_id("1PC-session"); params.set_time_zone("+08:00"); params.mutable_spu_runtime_cfg()->CopyFrom( - MakeSpuRuntimeConfigForTest(spu::ProtocolKind::REF2K)); + MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::REF2K)); SessionOptions options; auto* alice = params.add_parties(); alice->CopyFrom(BuildParty(kPartyAlice, 0)); pb::DebugOptions debug_opts; // When there is only one party involved, the protocol will not be validated, // so the related parameters are dummy. - std::vector allowed_protocols{spu::ProtocolKind::REF2K, - spu::ProtocolKind::CHEETAH, - spu::ProtocolKind::SEMI2K}; + std::vector allowed_protocols{spu::pb::ProtocolKind::REF2K, + spu::pb::ProtocolKind::CHEETAH, + spu::pb::ProtocolKind::SEMI2K}; return std::make_shared(options, params, debug_opts, &g_mem_link_factory, ds_router, ds_mgr, allowed_protocols); @@ -103,9 +103,9 @@ std::vector> MakeMultiPCSession( options.psi_config.psi_curve_type = psi::CURVE_FOURQ; auto create_session = [&](const pb::JobStartParams& params) { pb::DebugOptions debug_opts; - std::vector allowed_protocols{spu::ProtocolKind::CHEETAH, - spu::ProtocolKind::SEMI2K, - spu::ProtocolKind::ABY3}; + std::vector allowed_protocols{spu::pb::ProtocolKind::CHEETAH, + spu::pb::ProtocolKind::SEMI2K, + spu::pb::ProtocolKind::ABY3}; return std::make_shared(options, params, debug_opts, &g_mem_link_factory, nullptr, nullptr, allowed_protocols); diff --git a/engine/operator/test_util.h b/engine/operator/test_util.h index 3fc2054..20e4371 100644 --- a/engine/operator/test_util.h +++ b/engine/operator/test_util.h @@ -29,7 +29,7 @@ #define TestParamNameGenerator(TestCaseClass) \ [](const testing::TestParamInfo& info) { \ return std::to_string(info.index) + \ - spu::ProtocolKind_Name(std::get<0>(info.param).protocol) + "p" + \ + spu::pb::ProtocolKind_Name(std::get<0>(info.param).protocol) + "p" + \ std::to_string(std::get<0>(info.param).party_size) + \ (std::get<0>(info.param).enable_colocated_optimization ? "opt" \ : ""); \ @@ -44,40 +44,40 @@ constexpr char kPartyBob[] = "bob"; constexpr char kPartyCarol[] = "carol"; constexpr const char* kPartyCodes[] = {"alice", "bob", "carol"}; -spu::RuntimeConfig GetSpuRuntimeConfigForTest(); +spu::pb::RuntimeConfig GetSpuRuntimeConfigForTest(); struct SpuRuntimeTestCase { - spu::ProtocolKind protocol; + spu::pb::ProtocolKind protocol; size_t party_size; bool enable_colocated_optimization; }; static const auto SpuTestValues2PC = testing::Values( - test::SpuRuntimeTestCase{spu::ProtocolKind::CHEETAH, 2, true}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2, true}, - test::SpuRuntimeTestCase{spu::ProtocolKind::CHEETAH, 2, false}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2, false}); + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::CHEETAH, 2, true}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2, true}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::CHEETAH, 2, false}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2, false}); static const auto SpuTestValuesMultiPC = testing::Values( - test::SpuRuntimeTestCase{spu::ProtocolKind::CHEETAH, 2, true}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2, true}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 3, true}, - test::SpuRuntimeTestCase{spu::ProtocolKind::ABY3, 3, true}, - test::SpuRuntimeTestCase{spu::ProtocolKind::CHEETAH, 2, false}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2, false}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 3, false}, - test::SpuRuntimeTestCase{spu::ProtocolKind::ABY3, 3, false}); + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::CHEETAH, 2, true}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2, true}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 3, true}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::ABY3, 3, true}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::CHEETAH, 2, false}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2, false}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 3, false}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::ABY3, 3, false}); static const auto SpuTestValuesMultiPCDisableColocated = testing::Values( - test::SpuRuntimeTestCase{spu::ProtocolKind::CHEETAH, 2, false}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2, false}, - test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 3, false}, - test::SpuRuntimeTestCase{spu::ProtocolKind::ABY3, 3, false}); + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::CHEETAH, 2, false}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2, false}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 3, false}, + test::SpuRuntimeTestCase{spu::pb::ProtocolKind::ABY3, 3, false}); pb::JobStartParams::Party BuildParty(const std::string& code, int32_t rank); -spu::RuntimeConfig MakeSpuRuntimeConfigForTest( - spu::ProtocolKind protocol_kind, +spu::pb::RuntimeConfig MakeSpuRuntimeConfigForTest( + spu::pb::ProtocolKind protocol_kind, bool enable_colocated_optimization = false); // make single party session diff --git a/engine/operator/unary_test.cc b/engine/operator/unary_test.cc index 17bfedd..953b4f1 100644 --- a/engine/operator/unary_test.cc +++ b/engine/operator/unary_test.cc @@ -284,7 +284,7 @@ TEST_P(UnaryTest, WorksCorrectly) { if (tc.op == Exp::kOpType) { // The exp operation tends to be less accurate, particularly when using // the CHEETAH protocol. - if (spu_tc.protocol == spu::ProtocolKind::CHEETAH) { + if (spu_tc.protocol == spu::pb::ProtocolKind::CHEETAH) { tolerance = 0.05; } else { tolerance = 0.005; @@ -340,4 +340,4 @@ TEST_P(UnaryTest, WorksCorrectly) { } } -} // namespace scql::engine::op \ No newline at end of file +} // namespace scql::engine::op diff --git a/engine/services/engine_service_impl_test.cc b/engine/services/engine_service_impl_test.cc index f098f6e..cf50e13 100644 --- a/engine/services/engine_service_impl_test.cc +++ b/engine/services/engine_service_impl_test.cc @@ -71,8 +71,8 @@ class EngineServiceImplTest : public ::testing::Test { EXPECT_NE(nullptr, factory.get()); engine_service_options.enable_authorization = true; engine_service_options.credential = "alice_credential"; - std::vector allowed_protocols = { - spu::ProtocolKind::SEMI2K}; + std::vector allowed_protocols = { + spu::pb::ProtocolKind::SEMI2K}; impl = std::make_unique( engine_service_options, std::make_unique(session_options, &listener_manager, @@ -92,7 +92,7 @@ class EngineServiceImplTest : public ::testing::Test { "Credential", fmt::format("{}_credential", op::test::kPartyAlice)); global_params.mutable_spu_runtime_cfg()->CopyFrom( - op::test::MakeSpuRuntimeConfigForTest(spu::ProtocolKind::SEMI2K)); + op::test::MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::SEMI2K)); } } @@ -162,7 +162,7 @@ TEST_F(EngineServiceImplTest, QueryJobStatus) { alice->CopyFrom(op::test::BuildParty(op::test::kPartyAlice, 0)); params.mutable_spu_runtime_cfg()->CopyFrom( - op::test::MakeSpuRuntimeConfigForTest(spu::ProtocolKind::SEMI2K)); + op::test::MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::SEMI2K)); } // When @@ -340,8 +340,8 @@ class EngineServiceImpl2PartiesTest service_options.enable_authorization = true; service_options.credential = "alice_credential"; SessionOptions session_options; - std::vector allowed_protocols = { - spu::ProtocolKind::SEMI2K}; + std::vector allowed_protocols = { + spu::pb::ProtocolKind::SEMI2K}; auto impl = std::make_unique( service_options, std::make_unique( @@ -621,7 +621,7 @@ void EngineServiceImpl2PartiesTest::AddSessionParameters( } params->mutable_spu_runtime_cfg()->CopyFrom( - op::test::MakeSpuRuntimeConfigForTest(spu::ProtocolKind::SEMI2K)); + op::test::MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::SEMI2K)); } void EngineServiceImpl2PartiesTest::AddRunSQLNode( @@ -739,4 +739,4 @@ void EngineServiceImpl2PartiesTest::AddPublishNode( job->add_node_ids(op::Publish::kOpType); } -} // namespace scql::engine \ No newline at end of file +} // namespace scql::engine diff --git a/engine/util/ndarray_to_arrow.cc b/engine/util/ndarray_to_arrow.cc index 3fa4ec3..cd27465 100644 --- a/engine/util/ndarray_to_arrow.cc +++ b/engine/util/ndarray_to_arrow.cc @@ -37,7 +37,7 @@ class NdArrayConverter { : pool_(arrow::default_memory_pool()), arr_(arr), validity_(validity), - pt_type_(spu::PT_INVALID), + pt_type_(spu::pb::PT_INVALID), type_(nullptr), null_bitmap_(nullptr) { length_ = arr_.numel(); @@ -100,7 +100,7 @@ class NdArrayConverter { "NdArrayConverter doesn't support strided arrays"); } - if (pt_type_ == spu::PT_I1) { + if (pt_type_ == spu::pb::PT_I1) { int64_t nbytes = arrow::bit_util::BytesForBits(length_); ARROW_ASSIGN_OR_RAISE(auto buffer, arrow::AllocateBuffer(nbytes, pool_)); @@ -153,7 +153,7 @@ class NdArrayConverter { const spu::NdArrayRef& arr_; const spu::NdArrayRef* validity_; - spu::PtType pt_type_; + spu::pb::PtType pt_type_; int64_t length_; std::shared_ptr type_; @@ -179,7 +179,7 @@ arrow::Status NdArrayConverter::Convert() { if (type_ == nullptr) { return arrow::Status::Invalid( - fmt::format("unsupported spu::PtType {}", spu::PtType_Name(pt_type_))); + fmt::format("unsupported spu::pb::PtType {}", spu::pb::PtType_Name(pt_type_))); } // Visit the type to perform conversion diff --git a/engine/util/spu_io.cc b/engine/util/spu_io.cc index baa5967..8cd45ab 100644 --- a/engine/util/spu_io.cc +++ b/engine/util/spu_io.cc @@ -151,8 +151,8 @@ std::string SpuVarNameEncoder::GetValidityName(const std::string& name) { SpuInfeedHelper::PtView SpuInfeedHelper::ConvertArrowArrayToPtView( const std::shared_ptr& array) { - spu::PtType pt = ArrowDataTypeToSpuPtType(array->type()); - YACL_ENFORCE(pt != spu::PT_INVALID, "unsupported arrow data type: {}", + spu::pb::PtType pt = Arrowspu::PtType(array->type()); + YACL_ENFORCE(pt != spu::pb::PT_INVALID, "unsupported arrow data type: {}", array->type()->ToString()); SpuPtBufferViewConverter converter; @@ -168,7 +168,7 @@ SpuInfeedHelper::PtView SpuInfeedHelper::ConvertArrowArrayToPtView( const uint8_t* null_bitmap = array->null_bitmap_data(); auto validity = spu::PtBufferView(static_cast(null_bitmap), - spu::PT_I1, {array->length()}, {1}, true); + spu::pb::PT_I1, {array->length()}, {1}, true); return PtView(value, validity); #endif // SCQL_WITH_NULL