Skip to content

Commit

Permalink
repo-sync-2024-02-08T11:13:20+0800 (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
usafchn authored Feb 8, 2024
1 parent 16753ea commit 7e52754
Show file tree
Hide file tree
Showing 26 changed files with 421 additions and 159 deletions.
2 changes: 1 addition & 1 deletion yacl/base/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package(default_visibility = ["//visibility:public"])

yacl_cc_library(
name = "exception",
srcs = [],
srcs = ["exception.cc"],
hdrs = ["exception.h"],
deps = [
"@com_github_fmtlib_fmt//:fmtlib",
Expand Down
35 changes: 35 additions & 0 deletions yacl/base/exception.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2024 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "yacl/base/exception.h"

namespace yacl {

std::string GetStacktraceString() {
::yacl::stacktrace_t stacks;
const int dep =
absl::GetStackTrace(stacks.data(), internal::kMaxStackTraceDep, 1);
std::string res;
for (int i = 0; i < dep; ++i) {
std::array<char, 2048> tmp;
const char* symbol = "(unknown)";
if (absl::Symbolize(stacks[i], tmp.data(), tmp.size())) {
symbol = tmp.data();
}
res.append(fmt::format("#{} {}+{}\n", i, symbol, stacks[i]));
}
return res;
}

} // namespace yacl
6 changes: 3 additions & 3 deletions yacl/base/exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ inline std::string Format() { return ""; }
// |- logic_error
// |- runtime_error
// |- io_error

class Exception : public std::exception {
public:
Exception() = default;
Expand Down Expand Up @@ -161,7 +160,7 @@ class LinkError : public NetworkError {
fmt::format("[{}:{}] {}", __FILE__, __LINE__, fmt::format(__VA_ARGS__))

using stacktrace_t = std::array<void*, ::yacl::internal::kMaxStackTraceDep>;
//

// add absl::InitializeSymbolizer to main function to get
// human-readable names stack trace
//
Expand All @@ -170,7 +169,8 @@ using stacktrace_t = std::array<void*, ::yacl::internal::kMaxStackTraceDep>;
// absl::InitializeSymbolizer(argv[0]);
// ...
// }
//

std::string GetStacktraceString();

#define YACL_THROW_HELPER(ExceptionName, AppendStack, ...) \
do { \
Expand Down
1 change: 1 addition & 0 deletions yacl/base/exception_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void CheckExceptionContains(const std::exception& e,

TEST(Exception, StackTrace) {
try {
EXPECT_FALSE(GetStacktraceString().empty());
YACL_THROW("test");
} catch (const Exception& e) {
// e.g.
Expand Down
12 changes: 8 additions & 4 deletions yacl/crypto/primitives/ot/ferret_ote.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ OtSendStore FerretOtExtSend(const std::shared_ptr<link::Context>& ctx,
// simple_map = MakeSimpleMap(option, lpn_param.n);
// }

auto spcot_size = lpn_param.n / lpn_param.t;
for (uint64_t i = 0; i < batch_num; ++i) {
// the ot generated by this batch (including the seeds for next batch if
// necessary)
Expand All @@ -118,7 +119,7 @@ OtSendStore FerretOtExtSend(const std::shared_ptr<link::Context>& ctx,
auto idx_num = lpn_param.t;
auto idx_range = batch_ot_num;
if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) {
MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, working_s);
MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_s);
} else {
YACL_THROW("Not Implemented!");
// MpCotUNSend(ctx, cot_mpcot, simple_map, option, working_s);
Expand Down Expand Up @@ -197,6 +198,7 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr<link::Context>& ctx,
// simple_map = MakeSimpleMap(option, lpn_param.n);
// }

auto spcot_size = lpn_param.n / lpn_param.t;
for (uint64_t i = 0; i < batch_num; ++i) {
// the ot generated by this batch (including the seeds for next batch if
// necessary)
Expand All @@ -208,7 +210,7 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr<link::Context>& ctx,
auto idx_range = batch_ot_num;

if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) {
MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, working_r);
MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_r);
} else {
YACL_THROW("Not Implemented!");
// MpCotUNRecv(ctx, cot_mpcot, simple_map, option, e, working_r);
Expand Down Expand Up @@ -279,6 +281,7 @@ void FerretOtExtSend_cheetah(const std::shared_ptr<link::Context>& ctx,
// simple_map = MakeSimpleMap(option, lpn_param.n);
// }

auto spcot_size = lpn_param.n / lpn_param.t;
for (uint64_t i = 0; i < batch_num; ++i) {
// the ot generated by this batch (including the seeds for next batch if
// necessary)
Expand All @@ -288,7 +291,7 @@ void FerretOtExtSend_cheetah(const std::shared_ptr<link::Context>& ctx,
auto idx_num = lpn_param.t;
auto idx_range = batch_ot_num;
if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) {
MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, working_s);
MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_s);
} else {
YACL_THROW("Not Implemented!");
// MpCotUNSend(ctx, cot_mpcot, simple_map, option, working_s);
Expand Down Expand Up @@ -364,6 +367,7 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr<link::Context>& ctx,
// simple_map = MakeSimpleMap(option, lpn_param.n);
// }

auto spcot_size = lpn_param.n / lpn_param.t;
for (uint64_t i = 0; i < batch_num; ++i) {
// the ot generated by this batch (including the seeds for next batch if
// necessary)
Expand All @@ -375,7 +379,7 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr<link::Context>& ctx,
auto idx_range = batch_ot_num;

if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) {
MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, working_r);
MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_r);
} else {
YACL_THROW("Not Implemented!");
// MpCotUNRecv(ctx, cot_mpcot, simple_map, option, e, working_r);
Expand Down
1 change: 1 addition & 0 deletions yacl/crypto/primitives/ot/ferret_ote.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr<link::Context>& ctx,
const OtRecvStore& base_cot,
const LpnParam& lpn_param, uint64_t ot_num,
absl::Span<uint128_t> out);

} // namespace yacl::crypto
71 changes: 49 additions & 22 deletions yacl/crypto/primitives/ot/ferret_ote_rn.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,47 +30,74 @@ YACL_MODULE_DECLARE("ferret_ote_rn", SecParam::C::k128, SecParam::S::INF);
namespace yacl::crypto {

inline uint64_t MpCotRNHelper(uint64_t idx_num, uint64_t idx_range) {
const auto batch_size = (idx_range + idx_num - 1) / idx_num;
const auto batch_size = idx_range / idx_num;
const auto last_size = idx_range - batch_size * (idx_num - 1);
return math::Log2Ceil(batch_size) * (idx_num - 1) + math::Log2Ceil(last_size);
}

inline void MpCotRNSend(const std::shared_ptr<link::Context>& ctx,
const OtSendStore& cot, uint64_t idx_range,
uint64_t idx_num, absl::Span<uint128_t> out) {
const auto full_size = idx_range;
const auto batch_num = idx_num;
const auto batch_size = full_size / batch_num;
const auto last_size = full_size - (batch_num - 1) * batch_size;
uint64_t idx_num, uint64_t spcot_size,
absl::Span<uint128_t> out) {
const uint64_t full_size = idx_range;
const uint64_t batch_size = spcot_size;
const uint64_t batch_num = math::DivCeil(full_size, batch_size);
YACL_ENFORCE(batch_num <= idx_num);

const uint64_t last_size = full_size - (batch_num - 1) * batch_size;

// for each bin, call single-point cot
for (uint64_t i = 0; i < batch_num; ++i) {
const uint64_t this_size = (i == batch_num - 1) ? last_size : batch_size;
for (uint64_t i = 0; i < batch_num - 1; ++i) {
const auto& cot_slice =
cot.Slice(i * math::Log2Ceil(batch_size),
i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size));

GywzOtExtSend_ferret(ctx, cot_slice, this_size,
out.subspan(i * batch_size, this_size));
i * math::Log2Ceil(batch_size) + math::Log2Ceil(batch_size));
GywzOtExtSend_ferret(ctx, cot_slice, batch_size,
out.subspan(i * batch_size, batch_size));
}
// deal with last batch
if (last_size == 1) {
out[(batch_num - 1) * batch_size] =
cot.GetBlock((batch_num - 1) * math::Log2Ceil(batch_size), 0);
} else {
const auto& cot_slice =
cot.Slice((batch_num - 1) * math::Log2Ceil(batch_size),
(batch_num - 1) * math::Log2Ceil(batch_size) +
math::Log2Ceil(last_size));
GywzOtExtSend_ferret(ctx, cot_slice, last_size,
out.subspan((batch_num - 1) * batch_size, last_size));
}
}

inline void MpCotRNRecv(const std::shared_ptr<link::Context>& ctx,
const OtRecvStore& cot, uint64_t idx_range,
uint64_t idx_num, absl::Span<uint128_t> out) {
const auto full_size = idx_range;
const auto batch_num = idx_num;
const auto batch_size = full_size / batch_num;
const auto last_size = full_size - (batch_num - 1) * batch_size;
uint64_t idx_num, uint64_t spcot_size,
absl::Span<uint128_t> out) {
const uint64_t full_size = idx_range;
const uint64_t batch_size = spcot_size;
const uint64_t batch_num = math::DivCeil(full_size, batch_size);
YACL_ENFORCE(batch_num <= idx_num);

const uint64_t last_size = full_size - (batch_num - 1) * batch_size;

// for each bin, call single-point cot
for (uint64_t i = 0; i < batch_num; ++i) {
const uint64_t this_size = (i == batch_num - 1) ? last_size : batch_size;
for (uint64_t i = 0; i < batch_num - 1; ++i) {
const auto cot_slice =
cot.Slice(i * math::Log2Ceil(batch_size),
i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size));
GywzOtExtRecv_ferret(ctx, cot_slice, this_size,
out.subspan(i * batch_size, this_size));
i * math::Log2Ceil(batch_size) + math::Log2Ceil(batch_size));
GywzOtExtRecv_ferret(ctx, cot_slice, batch_size,
out.subspan(i * batch_size, batch_size));
}
// deal with last batch
if (last_size == 1) {
out[(batch_num - 1) * batch_size] =
cot.GetBlock((batch_num - 1) * math::Log2Ceil(batch_size));
} else {
const auto& cot_slice =
cot.Slice((batch_num - 1) * math::Log2Ceil(batch_size),
(batch_num - 1) * math::Log2Ceil(batch_size) +
math::Log2Ceil(last_size));
GywzOtExtRecv_ferret(ctx, cot_slice, last_size,
out.subspan((batch_num - 1) * batch_size, last_size));
}
}

Expand Down
3 changes: 2 additions & 1 deletion yacl/crypto/primitives/ot/ferret_ote_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ TEST_P(FerretOtExtTest, CheetahWorks) {

INSTANTIATE_TEST_SUITE_P(
Works_Instances, FerretOtExtTest,
testing::Values(FerretParams{10485760, LpnNoiseAsm::RegularNoise},
testing::Values(FerretParams{81921, LpnNoiseAsm::RegularNoise},
FerretParams{10485760, LpnNoiseAsm::RegularNoise},
FerretParams{10485761, LpnNoiseAsm::RegularNoise},
FerretParams{1 << 20, LpnNoiseAsm::RegularNoise}
// FerretParams{1 << 21, LpnNoiseAsm::RegularNoise},
Expand Down
16 changes: 8 additions & 8 deletions yacl/crypto/primitives/ot/gywz_ote.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void GywzOtExtRecv(const std::shared_ptr<link::Context>& ctx,
absl::Span<uint128_t> output) {
const uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);
YACL_ENFORCE_GT(n, index);

// Convert index into ot choices
Expand Down Expand Up @@ -174,7 +174,7 @@ void GywzOtExtSend(const std::shared_ptr<link::Context>& ctx,
absl::Span<uint128_t> output) {
const uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);

// get delta from cot
uint128_t delta = cot.GetDelta();
Expand All @@ -201,7 +201,7 @@ void GywzOtExtRecv_ferret(const std::shared_ptr<link::Context>& ctx,
absl::Span<uint128_t> output) {
uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);
YACL_ENFORCE(cot.Type() == OtStoreType::Compact);

uint32_t index = 0;
Expand Down Expand Up @@ -230,7 +230,7 @@ void GywzOtExtSend_ferret(const std::shared_ptr<link::Context>& ctx,
absl::Span<uint128_t> output) {
uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);
YACL_ENFORCE(cot.Type() == OtStoreType::Compact);

// get delta from cot
Expand All @@ -256,7 +256,7 @@ void GywzOtExtRecv_fixed_index(const std::shared_ptr<link::Context>& ctx,
absl::Span<uint128_t> output) {
const uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);

auto recv_buf = ctx->Recv(ctx->NextRank(), "GYWZ_OTE: messages");
YACL_ENFORCE(recv_buf.size() >=
Expand All @@ -272,7 +272,7 @@ void GywzOtExtSend_fixed_index(const std::shared_ptr<link::Context>& ctx,
absl::Span<uint128_t> output) {
uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);

AlignedVector<uint128_t> left_sums(height);
GywzOtExtSend_fixed_index(cot, n, output, absl::MakeSpan(left_sums));
Expand All @@ -288,7 +288,7 @@ void GywzOtExtRecv_fixed_index(const OtRecvStore& cot, uint32_t n,
absl::Span<uint128_t> recv_msgs) {
const uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);
YACL_ENFORCE(recv_msgs.size() >= height);

uint32_t index = 0;
Expand All @@ -309,7 +309,7 @@ void GywzOtExtSend_fixed_index(const OtSendStore& cot, uint32_t n,
absl::Span<uint128_t> send_msgs) {
uint32_t height = math::Log2Ceil(n);
YACL_ENFORCE(cot.Size() == height);
YACL_ENFORCE_GE(n, (uint32_t)1);
YACL_ENFORCE_GT(n, (uint32_t)1);
YACL_ENFORCE(send_msgs.size() >= height);

uint128_t delta = cot.GetDelta();
Expand Down
26 changes: 26 additions & 0 deletions yacl/crypto/primitives/ot/gywz_ote_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,30 @@ INSTANTIATE_TEST_SUITE_P(TestWork, GywzParamTest,
TestParams{1 << 10}, //
TestParams{1 << 15}));

// Edge Case
// n should be greater than 1
TEST(GywzEdgeTest, Work) {
size_t n = 1;

auto index = RandInRange(n);
auto lctxs = link::test::SetupWorld(2);
uint128_t delta = SecureRandSeed();
auto base_ot = MockCots(math::Log2Ceil(n), delta); // mock many base OTs

std::vector<uint128_t> send_out(n);
std::vector<uint128_t> recv_out(n);

std::future<void> sender = std::async([&] {
ASSERT_THROW(GywzOtExtRecv(lctxs[0], base_ot.recv, n, index,
absl::MakeSpan(recv_out)),
::yacl::Exception);
});
std::future<void> receiver = std::async([&] {
ASSERT_THROW(
GywzOtExtSend(lctxs[1], base_ot.send, n, absl::MakeSpan(send_out)),
::yacl::Exception);
});
sender.get();
receiver.get();
}
} // namespace yacl::crypto
Loading

0 comments on commit 7e52754

Please sign in to comment.