Skip to content

Commit

Permalink
tweak cpp namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
tongke6 committed Mar 5, 2025
1 parent b2a618f commit 186759d
Show file tree
Hide file tree
Showing 22 changed files with 138 additions and 138 deletions.
74 changes: 37 additions & 37 deletions engine/core/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,83 +101,83 @@ std::shared_ptr<arrow::DataType> ToArrowDataType(pb::PrimitiveDataType dtype) {
return dt;
}

spu::PtType ArrowDataTypeToSpuPtType(
spu::pb::PtType ArrowDataTypeToSpuPtType(
const std::shared_ptr<arrow::DataType>& dtype) {
spu::PtType pt;
spu::pb::PtType pt;
switch (dtype->id()) {
case arrow::Type::BOOL:
pt = spu::PT_I1;
pt = spu::pb::PT_I1;
break;
case arrow::Type::UINT8:
pt = spu::PT_U8;
pt = spu::pb::PT_U8;
break;
case arrow::Type::INT8:
pt = spu::PT_I8;
pt = spu::pb::PT_I8;
break;
case arrow::Type::UINT16:
pt = spu::PT_U16;
pt = spu::pb::PT_U16;
break;
case arrow::Type::INT16:
pt = spu::PT_I16;
pt = spu::pb::PT_I16;
break;
case arrow::Type::UINT32:
pt = spu::PT_U32;
pt = spu::pb::PT_U32;
break;
case arrow::Type::INT32:
pt = spu::PT_I32;
pt = spu::pb::PT_I32;
break;
case arrow::Type::UINT64:
pt = spu::PT_U64;
pt = spu::pb::PT_U64;
break;
case arrow::Type::INT64:
pt = spu::PT_I64;
pt = spu::pb::PT_I64;
break;
case arrow::Type::FLOAT:
pt = spu::PT_F32;
pt = spu::pb::PT_F32;
break;
case arrow::Type::DOUBLE:
pt = spu::PT_F64;
pt = spu::pb::PT_F64;
break;
default:
pt = spu::PT_INVALID;
pt = spu::pb::PT_INVALID;
}
return pt;
}

std::shared_ptr<arrow::DataType> SpuPtTypeToArrowDataType(spu::PtType pt_type) {
std::shared_ptr<arrow::DataType> SpuPtTypeToArrowDataType(spu::pb::PtType pt_type) {
std::shared_ptr<arrow::DataType> dt;
switch (pt_type) {
case spu::PT_I8:
case spu::pb::PT_I8:
dt = arrow::int8();
break;
case spu::PT_U8:
case spu::pb::PT_U8:
dt = arrow::uint8();
break;
case spu::PT_I16:
case spu::pb::PT_I16:
dt = arrow::int16();
break;
case spu::PT_U16:
case spu::pb::PT_U16:
dt = arrow::uint16();
break;
case spu::PT_I32:
case spu::pb::PT_I32:
dt = arrow::int32();
break;
case spu::PT_U32:
case spu::pb::PT_U32:
dt = arrow::uint32();
break;
case spu::PT_I64:
case spu::pb::PT_I64:
dt = arrow::int64();
break;
case spu::PT_U64:
case spu::pb::PT_U64:
dt = arrow::uint64();
break;
case spu::PT_F32:
case spu::pb::PT_F32:
dt = arrow::float32();
break;
case spu::PT_F64:
case spu::pb::PT_F64:
dt = arrow::float64();
break;
case spu::PT_I1:
case spu::pb::PT_I1:
dt = arrow::boolean();
break;
default:
Expand All @@ -186,36 +186,36 @@ std::shared_ptr<arrow::DataType> SpuPtTypeToArrowDataType(spu::PtType pt_type) {
return dt;
}

spu::PtType DataTypeToSpuPtType(pb::PrimitiveDataType dtype) {
spu::PtType pt;
spu::pb::PtType DataTypeToSpuPtType(pb::PrimitiveDataType dtype) {
spu::pb::PtType pt;
switch (dtype) {
case pb::PrimitiveDataType::INT8:
pt = spu::PT_I8;
pt = spu::pb::PT_I8;
break;
case pb::PrimitiveDataType::INT16:
pt = spu::PT_I16;
pt = spu::pb::PT_I16;
break;
case pb::PrimitiveDataType::INT32:
pt = spu::PT_I32;
pt = spu::pb::PT_I32;
break;
case pb::PrimitiveDataType::INT64:
case pb::PrimitiveDataType::DATETIME:
case pb::PrimitiveDataType::TIMESTAMP:
pt = spu::PT_I64;
pt = spu::pb::PT_I64;
break;
case pb::PrimitiveDataType::BOOL:
pt = spu::PT_I1;
pt = spu::pb::PT_I1;
break;
case pb::PrimitiveDataType::FLOAT32:
pt = spu::PT_F32;
pt = spu::pb::PT_F32;
break;
case pb::PrimitiveDataType::FLOAT64:
pt = spu::PT_F64;
pt = spu::pb::PT_F64;
break;
default:
pt = spu::PT_INVALID;
pt = spu::pb::PT_INVALID;
}
return pt;
}

} // namespace scql::engine
} // namespace scql::engine
12 changes: 6 additions & 6 deletions engine/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ pb::PrimitiveDataType FromArrowDataType(
std::shared_ptr<arrow::DataType> ToArrowDataType(pb::PrimitiveDataType dtype);

/// @brief convert arrow data type to spu plaintext type enum
/// @returns spu::PT_INVALID if @param[in] dtype is not supported
spu::PtType ArrowDataTypeToSpuPtType(
/// @returns spu::pb::PT_INVALID if @param[in] dtype is not supported
spu::pb::PtType ArrowDataTypeToSpuPtType(
const std::shared_ptr<arrow::DataType>& dtype);

/// @brief convert spu plaintext type enum to arrow data type
/// @returns nullptr if @param[in] pt_type is not supported
std::shared_ptr<arrow::DataType> SpuPtTypeToArrowDataType(spu::PtType pt_type);
std::shared_ptr<arrow::DataType> SpuPtTypeToArrowDataType(spu::pb::PtType pt_type);

/// @brief convert scql primitive data type to spu plaintext type enum
/// @returns spu::PT_INVALID if @param[in] dtype is not supported
spu::PtType DataTypeToSpuPtType(pb::PrimitiveDataType dtype);
/// @returns spu::pb::PT_INVALID if @param[in] dtype is not supported
spu::pb::PtType DataTypeToSpuPtType(pb::PrimitiveDataType dtype);

} // namespace scql::engine
} // namespace scql::engine
8 changes: 4 additions & 4 deletions engine/exe/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,15 +226,15 @@ std::unique_ptr<scql::engine::EngineServiceImpl> BuildEngineService(
}
session_opt.log_options = opts;

std::vector<spu::ProtocolKind> allowed_protocols;
std::vector<spu::pb::ProtocolKind> allowed_protocols;

std::vector<absl::string_view> protocols_str =
absl::StrSplit(FLAGS_spu_allowed_protocols, ',');
for (auto& protocol_str : protocols_str) {
std::string stripped_str(absl::StripAsciiWhitespace(protocol_str));
spu::ProtocolKind protocol_kind;
spu::pb::ProtocolKind protocol_kind;

YACL_ENFORCE(spu::ProtocolKind_Parse(stripped_str, &protocol_kind),
YACL_ENFORCE(spu::pb::ProtocolKind_Parse(stripped_str, &protocol_kind),
fmt::format("invalid protocol provided: {}", stripped_str));
allowed_protocols.push_back(protocol_kind);
}
Expand Down Expand Up @@ -462,4 +462,4 @@ grpc::SslCredentialsOptions LoadSslCredentialsOptions(
YACL_ENFORCE(butil::ReadFileToString(butil::FilePath(cert_file), &content));
opts.pem_cert_chain = content;
return opts;
}
}
10 changes: 5 additions & 5 deletions engine/framework/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Session::Session(const SessionOptions& session_opt,
const pb::JobStartParams& params, pb::DebugOptions debug_opts,
yacl::link::ILinkFactory* link_factory, Router* router,
DatasourceAdaptorMgr* ds_mgr,
const std::vector<spu::ProtocolKind>& allowed_spu_protocols)
const std::vector<spu::pb::ProtocolKind>& allowed_spu_protocols)
: id_(params.job_id()),
session_opt_(session_opt),
time_zone_(params.time_zone()),
Expand Down Expand Up @@ -112,10 +112,10 @@ Session::Session(const SessionOptions& session_opt,
std::accumulate(
allowed_spu_protocols_.begin(), allowed_spu_protocols_.end(),
std::string{},
[](const std::string& acc, const spu::ProtocolKind& protocol) {
[](const std::string& acc, const spu::pb::ProtocolKind& protocol) {
return acc.empty()
? spu::ProtocolKind_Name(protocol)
: acc + ", " + spu::ProtocolKind_Name(protocol);
? spu::pb::ProtocolKind_Name(protocol)
: acc + ", " + spu::pb::ProtocolKind_Name(protocol);
})));
spu::mpc::Factory::RegisterProtocol(spu_ctx_.get(), lctx_);
}
Expand Down Expand Up @@ -397,4 +397,4 @@ std::shared_ptr<spdlog::logger> ActiveLogger(const Session* session) {
}
return session_logger;
}
} // namespace scql::engine
} // namespace scql::engine
6 changes: 3 additions & 3 deletions engine/framework/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class Session {
pb::DebugOptions debug_opts,
yacl::link::ILinkFactory* link_factory, Router* router,
DatasourceAdaptorMgr* ds_mgr,
const std::vector<spu::ProtocolKind>& allowed_spu_protocols);
const std::vector<spu::pb::ProtocolKind>& allowed_spu_protocols);
~Session();
/// @return session id
std::string Id() const { return id_; }
Expand Down Expand Up @@ -269,7 +269,7 @@ class Session {
std::shared_ptr<util::PsiDetailLogger> psi_logger_ = nullptr;
pb::DebugOptions debug_opts_;

const std::vector<spu::ProtocolKind> allowed_spu_protocols_;
const std::vector<spu::pb::ProtocolKind> allowed_spu_protocols_;

// for progress exposure
std::atomic_int32_t nodes_count_ = -1;
Expand All @@ -285,4 +285,4 @@ class Session {
std::shared_ptr<spdlog::logger> ActiveLogger(const Session* session);

size_t CryptoHash(const std::string& str);
} // namespace scql::engine
} // namespace scql::engine
4 changes: 2 additions & 2 deletions engine/framework/session_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ SessionManager::SessionManager(
std::unique_ptr<yacl::link::ILinkFactory> link_factory,
std::unique_ptr<Router> ds_router,
std::unique_ptr<DatasourceAdaptorMgr> ds_mgr, int32_t session_timeout_s,
const std::vector<spu::ProtocolKind>& allowed_spu_protocols)
const std::vector<spu::pb::ProtocolKind>& allowed_spu_protocols)
: session_opt_(std::move(session_opt)),
listener_manager_(listener_manager),
link_factory_(std::move(link_factory)),
Expand Down Expand Up @@ -341,4 +341,4 @@ std::optional<std::string> SessionManager::GetTimeoutSession() {
return std::nullopt;
}

} // namespace scql::engine
} // namespace scql::engine
6 changes: 3 additions & 3 deletions engine/framework/session_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SessionManager {
std::unique_ptr<Router> ds_router,
std::unique_ptr<DatasourceAdaptorMgr> ds_mgr,
int32_t session_timeout_s,
const std::vector<spu::ProtocolKind>& allowed_spu_protocols);
const std::vector<spu::pb::ProtocolKind>& allowed_spu_protocols);

~SessionManager();

Expand Down Expand Up @@ -80,7 +80,7 @@ class SessionManager {
std::atomic<bool> to_stop_{false};
std::unique_ptr<std::thread> watch_thread_;
std::queue<std::string> session_timeout_queue_;
const std::vector<spu::ProtocolKind> allowed_spu_protocols_;
const std::vector<spu::pb::ProtocolKind> allowed_spu_protocols_;
};

} // namespace scql::engine
} // namespace scql::engine
16 changes: 8 additions & 8 deletions engine/framework/session_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class SessionManagerTest : public ::testing::Test {
factory = std::make_unique<TestFactory>(&listener_manager);
EXPECT_NE(nullptr, factory.get());
SessionOptions options;
std::vector<spu::ProtocolKind> allowed_spu_protocols = {
spu::ProtocolKind::SEMI2K, spu::ProtocolKind::CHEETAH};
std::vector<spu::pb::ProtocolKind> allowed_spu_protocols = {
spu::pb::ProtocolKind::SEMI2K, spu::pb::ProtocolKind::CHEETAH};
mgr = std::make_unique<SessionManager>(options, &listener_manager,
std::move(factory), nullptr, nullptr,
1, allowed_spu_protocols);
Expand All @@ -75,7 +75,7 @@ TEST_F(SessionManagerTest, Works) {
alice->CopyFrom(op::test::BuildParty(op::test::kPartyAlice, 0));

params.mutable_spu_runtime_cfg()->CopyFrom(
op::test::MakeSpuRuntimeConfigForTest(spu::ProtocolKind::SEMI2K));
op::test::MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::SEMI2K));
}
pb::DebugOptions debug_opts;
// When
Expand Down Expand Up @@ -121,14 +121,14 @@ TEST_F(SessionManagerTest, TestSessionCreation) {
SessionOptions options;

common_params.mutable_spu_runtime_cfg()->CopyFrom(
op::test::MakeSpuRuntimeConfigForTest(spu::ProtocolKind::REF2K));
op::test::MakeSpuRuntimeConfigForTest(spu::pb::ProtocolKind::REF2K));
auto create_session = [&](const pb::JobStartParams& params) {
pb::DebugOptions debug_opts;

// not allowed to create session with REF2K.
std::vector<spu::ProtocolKind> allowed_protocols{spu::ProtocolKind::CHEETAH,
spu::ProtocolKind::SEMI2K,
spu::ProtocolKind::ABY3};
std::vector<spu::pb::ProtocolKind> allowed_protocols{spu::pb::ProtocolKind::CHEETAH,
spu::pb::ProtocolKind::SEMI2K,
spu::pb::ProtocolKind::ABY3};
EXPECT_THROW(std::make_shared<Session>(options, params, debug_opts,
&g_mem_link_factory, nullptr,
nullptr, allowed_protocols),
Expand All @@ -151,4 +151,4 @@ TEST_F(SessionManagerTest, TestSessionCreation) {
futures[1].get();
}

} // namespace scql::engine
} // namespace scql::engine
4 changes: 2 additions & 2 deletions engine/operator/arrow_func_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class ArrowFuncTest
INSTANTIATE_TEST_SUITE_P(
ArrowFuncBatchTest, ArrowFuncTest,
testing::Combine(
testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}),
testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}),
testing::Values(
ArrowFuncTestCase{
.ins = {test::NamedTensor(
Expand Down Expand Up @@ -190,4 +190,4 @@ void ArrowFuncTest::FeedInputs(ExecContext* ctx, const ArrowFuncTestCase& tc) {
test::FeedInputsAsPrivate(ctx, tc.ins);
}

} // namespace scql::engine::op
} // namespace scql::engine::op
4 changes: 2 additions & 2 deletions engine/operator/coalesce_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class CoalesceTest
INSTANTIATE_TEST_SUITE_P(
CoalesceBatchTest, CoalesceTest,
testing::Combine(
testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}),
testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}),
testing::Values(
// test private status
CoalesceTestCase{
Expand Down Expand Up @@ -136,4 +136,4 @@ void CoalesceTest::FeedInputs(ExecContext* ctx, const CoalesceTestCase& tc) {
test::FeedInputsAsPrivate(ctx, tc.exprs);
}

} // namespace scql::engine::op
} // namespace scql::engine::op
4 changes: 2 additions & 2 deletions engine/operator/if_null_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class IfNullTest : public testing::TestWithParam<
INSTANTIATE_TEST_SUITE_P(
IfNullBatchTest, IfNullTest,
testing::Combine(
testing::Values(test::SpuRuntimeTestCase{spu::ProtocolKind::SEMI2K, 2}),
testing::Values(test::SpuRuntimeTestCase{spu::pb::ProtocolKind::SEMI2K, 2}),
testing::Values(
// test private status
IfNullTestCase{
Expand Down Expand Up @@ -138,4 +138,4 @@ void IfNullTest::FeedInputs(ExecContext* ctx, const IfNullTestCase& tc) {
test::FeedInputsAsPrivate(ctx, {tc.exp, tc.alt});
}

} // namespace scql::engine::op
} // namespace scql::engine::op
Loading

0 comments on commit 186759d

Please sign in to comment.