diff --git a/.bazelrc b/.bazelrc index 708b01c9..04806c9d 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,6 +31,9 @@ build:macos --host_copt=-Wa,--noexecstack # platform specific config # Bazel will automatic pick platform config since we have enable_platform_specific_config set build:macos --features=-supports_dynamic_linker +build:macos --linkopt="-Wl,-no_warn_duplicate_libraries" +build:macos --copt=-Wno-unused-command-line-argument +build:macos --host_copt=-Wno-unused-command-line-argument build:asan --features=asan build:ubsan --features=ubsan diff --git a/.bazelversion b/.bazelversion index 0df17dd0..c0be8a79 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -6.2.1 \ No newline at end of file +6.4.0 \ No newline at end of file diff --git a/.circleci/continue-config.yml b/.circleci/continue-config.yml index 449958b8..c9547e46 100644 --- a/.circleci/continue-config.yml +++ b/.circleci/continue-config.yml @@ -58,7 +58,7 @@ jobs: path: test_logs.tar.gz macOS_ut_arm64: macos: - xcode: 14.2 + xcode: 15.1 environment: HOMEBREW_NO_AUTO_UPDATE: 1 resource_class: macos.m1.medium.gen1 diff --git a/ALGORITHMS.md b/ALGORITHMS.md new file mode 100644 index 00000000..9758ef4e --- /dev/null +++ b/ALGORITHMS.md @@ -0,0 +1,33 @@ +# Supported Crypto Algorithms + +TODO + +## Primitives +- OT + - Simplest OT : https://eprint.iacr.org/2015/267.pdf + - INKP OT Extension : https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf + - KOS OT Extension : https://eprint.iacr.org/2015/546.pdf + - KKRT OT Extension : https://eprint.iacr.org/2016/799.pdf + - SGRR OT Extension: https://eprint.iacr.org/2019/1084.pdf + - GYWZ OT Extension : https://eprint.iacr.org/2022/1431.pdf + - Ferret OT Extension : https://eprint.iacr.org/2020/924.pdf + - Softspoken OT Extension : https://eprint.iacr.org/2022/192.pdf +- VOLE(over f2k) + - base VOLE : https://eprint.iacr.org/2016/505.pdf + - Silent VOLE : https://eprint.iacr.org/2019/1159.pdf, https://eprint.iacr.org/2021/1150.pdf, https://eprint.iacr.org/2022/1014.pdf + +## Theoretical Tools + +- Random Oracle +- Random Permutation +- Local Linear Code : https://eprint.iacr.org/2020/924.pdf +- Low Density Parity Check Code (Silver Code) : https://eprint.iacr.org/2021/1150.pdf +- Expanding Accumulation Code : https://eprint.iacr.org/2022/1014.pdf +- Correlation-Robust Hash Function : https://eprint.iacr.org/2019/074.pdf +- Circular Correlation-Robust Hash Function : https://eprint.iacr.org/2019/074.pdf + +## Basic (Traditional) algorithms + +- AES +- Hash: SHA2, SM2 +- RSA diff --git a/README.md b/README.md index 0a589c5b..c97c20e7 100644 --- a/README.md +++ b/README.md @@ -15,26 +15,24 @@ Repo layout: - [io](yacl/io/): a simple streaming-based io library. - [link](yacl/link/): a simple rpc-based MPI framework, providing the [SPMD](https://en.wikipedia.org/wiki/SPMD) parallel programming capability. -## Supported Crypto Primitives +## Supported Crypto Algorithms -Oblivious Transfer (and extensions) +See **Full List** of supported algorithms: [ALGORITHMS.md](ALGORITHMS.md) -- [Simplest OT](https://eprint.iacr.org/2015/267.pdf): 1-out-of-2 OT -- [IKNP OTe](https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf): 1-out-of-2 OT extension -- [Ferret OTe](https://eprint.iacr.org/2020/924): 1-out-of-2 OT extension -- [KKRT OTe](https://eprint.iacr.org/2016/799.pdf): 1-out-of-n OT (a.k.a OPRF) -- [SGRR OTe](https://eprint.iacr.org/2019/1084.pdf): (n-1)-out-of-n OTe -- [GYWZ+ OTe](https://eprint.iacr.org/2022/1431.pdf): (n-1)-out-of-n OTe with correlated GGM tree optimizations +**Selected algorithms**: -Distributed Point Function +- Oblivious Transfer (and extensions): [Simplest OT](https://eprint.iacr.org/2015/267.pdf), [IKNP OTe](https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf), [Ferret OTe](https://eprint.iacr.org/2020/924), [KKRT OTe](https://eprint.iacr.org/2016/799.pdf), [SGRR OTe](https://eprint.iacr.org/2019/1084.pdf). +- VOLE: [Silent VOLE](https://eprint.iacr.org/2019/1159.pdf), [Sparse VOLE (GF128)](https://eprint.iacr.org/2019/1084.pdf) +- Distributed Point Function: [BGI16](https://eprint.iacr.org/2018/707.pdf) +- Threshold Proxy-Re-encryption: [umbral with GM](https://github.com/nucypher/umbral-doc/blob/master/umbral-doc.pdf). -- [BGI16](https://eprint.iacr.org/2018/707.pdf) - -Threshold Proxy-Re-encryption +## Build -- A substitute of [umbral](https://github.com/nucypher/umbral-doc/blob/master/umbral-doc.pdf). Our implementation supports SM2, SM3 and SM4. +### Supported platforms -## Build +| | Linux x86_64 | Linux aarch64 | macOS x86_64 | macOS Apple Silicon | Windows x86_64 | Windows WSL2 x86_64 | +|-----|--------------|---------------|--------------|---------------------|----------------|---------------------| +| CPU | yes | yes | yes | yes | no | yes | ### Prerequisite diff --git a/bazel/patches/libtommath.patch b/bazel/patches/libtommath.patch index 9a80d071..8c8c88a9 100644 --- a/bazel/patches/libtommath.patch +++ b/bazel/patches/libtommath.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 8f85249..53e0365 100644 +index dfbcb0f..72f9a46 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,7 @@ set_target_properties(${PROJECT_NAME} PROPERTIES @@ -9,28 +9,18 @@ index 8f85249..53e0365 100644 - PUBLIC_HEADER "${PUBLIC_HEADERS}" + PUBLIC_HEADER "${HEADERS}" ) - + option(COMPILE_LTO "Build with LTO enabled") diff --git a/tommath_private.h b/tommath_private.h -index d88d263..46caa96 100644 +index d319a1d..5f4446e 100644 --- a/tommath_private.h +++ b/tommath_private.h -@@ -188,14 +188,14 @@ MP_STATIC_ASSERT(prec_geq_min_prec, MP_DEFAULT_DIGIT_COUNT >= MP_MIN_DIGIT_COUNT +@@ -17,7 +17,7 @@ + * On Win32 a .def file must be used to specify the exported symbols. + */ + #if defined(__GNUC__) && __GNUC__ >= 4 && !defined(_WIN32) && !defined(__CYGWIN__) +-# define MP_PRIVATE __attribute__ ((visibility ("hidden"))) ++# define MP_PRIVATE + #else + # define MP_PRIVATE #endif - - /* random number source */ --extern MP_PRIVATE mp_err(*s_mp_rand_source)(void *out, size_t size); -+extern mp_err(*s_mp_rand_source)(void *out, size_t size); - - /* lowlevel functions, do not call! */ - MP_PRIVATE bool s_mp_get_bit(const mp_int *a, int b) MP_WUR; - MP_PRIVATE int s_mp_log_2expt(const mp_int *a, mp_digit base) MP_WUR; - - MP_PRIVATE mp_err s_mp_add(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; --MP_PRIVATE mp_err s_mp_div_3(const mp_int *a, mp_int *c, mp_digit *d) MP_WUR; -+mp_err s_mp_div_3(const mp_int *a, mp_int *c, mp_digit *d) MP_WUR; - MP_PRIVATE mp_err s_mp_div_recursive(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r) MP_WUR; - MP_PRIVATE mp_err s_mp_div_school(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d) MP_WUR; - MP_PRIVATE mp_err s_mp_div_small(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d) MP_WUR; - - \ No newline at end of file diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index d32665e3..f42bd61f 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -312,15 +312,15 @@ def _com_github_libtom_libtommath(): maybe( http_archive, name = "com_github_libtom_libtommath", - sha256 = "da0759723645d974b82f134a26a1933a08fee887580132f55482c606ec688188", + sha256 = "dbfdafbaeb51ff92fdd3f2505ec0490f8a9badc2a71b378219856b68d470f0aa", type = "tar.gz", - strip_prefix = "libtommath-7f96509df1a6b44867bbda56bbf2cb92524be8ef", + strip_prefix = "libtommath-8ce69f7b5e2f34620633f4fb5c231045a8dc2f54", patch_args = ["-p1"], patches = [ "@yacl//bazel:patches/libtommath.patch", ], urls = [ - "https://github.com/libtom/libtommath/archive/7f96509df1a6b44867bbda56bbf2cb92524be8ef.tar.gz", + "https://github.com/libtom/libtommath/archive/8ce69f7b5e2f34620633f4fb5c231045a8dc2f54.tar.gz", ], build_file = "@yacl//bazel:libtommath.BUILD", ) diff --git a/bazel/yacl.bzl b/bazel/yacl.bzl index 14b3fd97..6876deb4 100644 --- a/bazel/yacl.bzl +++ b/bazel/yacl.bzl @@ -39,12 +39,22 @@ AES_COPT_FLAGS = select({ ], }) -OMP_LINK_FLAGS = select({ +OMP_DEPS = select({ "@bazel_tools//src/conditions:darwin_x86_64": ["@macos_omp_x64//:openmp"], "@bazel_tools//src/conditions:darwin_arm64": ["@macos_omp_arm64//:openmp"], "//conditions:default": [], }) +OMP_CFLAGS = select({ + "@platforms//os:macos": ["-Xclang", "-fopenmp"], + "//conditions:default": ["-fopenmp"], +}) + +OMP_LINKFLAGS = select({ + "@platforms//os:macos": [], + "//conditions:default": ["-fopenmp"], +}) + def _yacl_copts(): return select({ "@yacl//bazel:yacl_build_as_release": RELEASE_FLAGS, @@ -86,14 +96,11 @@ def yacl_configure_make(**attrs): def yacl_cc_test( copts = [], deps = [], - linkstatic = True, **kwargs): cc_test( copts = _yacl_copts() + copts, deps = deps + [ "@com_google_googletest//:gtest_main", ], - # static link for tcmalloc - linkstatic = True, **kwargs ) diff --git a/yacl/base/buffer_test.cc b/yacl/base/buffer_test.cc index 470ff314..12def9b4 100644 --- a/yacl/base/buffer_test.cc +++ b/yacl/base/buffer_test.cc @@ -14,6 +14,7 @@ #include "yacl/base/buffer.h" +#include #include #include "gtest/gtest.h" @@ -25,8 +26,8 @@ namespace yacl::test { TEST(BufferTest, ParallelWorks) { std::vector v; v.resize(100000); - parallel_for(0, v.size(), 1, [&](int64_t beg, int64_t end) { - for (int64_t i = beg; i < end; ++i) { + parallel_for(0, v.size(), [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { v[i] = Buffer(fmt::format("hello_{}", i)); } }); diff --git a/yacl/base/exception.h b/yacl/base/exception.h index 1a769a03..6a63e7ad 100644 --- a/yacl/base/exception.h +++ b/yacl/base/exception.h @@ -283,7 +283,7 @@ class EnforceNotMet : public Exception { do { \ if (!(condition)) { \ ::yacl::stacktrace_t __stacks__; \ - int __dep__ = absl::GetStackTrace( \ + const int __dep__ = absl::GetStackTrace( \ __stacks__.data(), ::yacl::internal::kMaxStackTraceDep, 0); \ throw ::yacl::EnforceNotMet(__FILE__, __LINE__, #condition, \ ::yacl::internal::Format(__VA_ARGS__), \ @@ -414,4 +414,10 @@ T CheckNotNull(T t) { return t; } +#ifdef NDEBUG +#define WEAK_ENFORCE(condition, ...) ((void)0) +#else +#define WEAK_ENFORCE(condition, ...) YACL_ENFORCE(condition, __VA_ARGS__) +#endif + } // namespace yacl diff --git a/yacl/crypto/base/ecc/ecc_test.cc b/yacl/crypto/base/ecc/ecc_test.cc index 007d36e6..5a999d39 100644 --- a/yacl/crypto/base/ecc/ecc_test.cc +++ b/yacl/crypto/base/ecc/ecc_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -289,7 +290,7 @@ class EcCurveTest : public ::testing::TestWithParam { constexpr int64_t ts = 1 << 15; std::array buf; auto g = ec_->GetGenerator(); - yacl::parallel_for(0, ts, 1, [&](int64_t beg, int64_t end) { + yacl::parallel_for(0, ts, [&](int64_t beg, int64_t end) { auto point = ec_->MulBase(MPInt(beg)); buf[beg] = point; for (int64_t i = beg + 1; i < end; ++i) { @@ -392,7 +393,7 @@ TEST(OpensslMemLeakTest, DISABLED_MulBaseLeaks) { EcGroupFactory::Instance().Create("sm2", ArgLib = "openssl"); std::mutex mutex; - yacl::parallel_for(0, 2, 1, [&](int64_t, int64_t) { + yacl::parallel_for(0, 2, [&](int64_t, int64_t) { std::lock_guard guard(mutex); // memory leaks here even with serial calls. ec->MulBase(0_mp); diff --git a/yacl/crypto/base/ecc/openssl/openssl_test.cc b/yacl/crypto/base/ecc/openssl/openssl_test.cc index 6a66e613..1c9f236a 100644 --- a/yacl/crypto/base/ecc/openssl/openssl_test.cc +++ b/yacl/crypto/base/ecc/openssl/openssl_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "gtest/gtest.h" #include "yacl/crypto/base/ecc/openssl/openssl_group.h" @@ -103,7 +105,7 @@ TEST(OpensslMemLeakTest, MulBaseLeaks) { yacl::crypto::EcGroupFactory::Instance().Create("sm2", ArgLib = "openssl"); - yacl::parallel_for(0, 2, 1, [&](int64_t, int64_t) { + yacl::parallel_for(0, 2, [&](int64_t, int64_t) { // no memory leak here, but the same code in ecc_test.cc leaks. ec->MulBase(0_mp); }); diff --git a/yacl/crypto/base/ecc/pairing_test.cc b/yacl/crypto/base/ecc/pairing_test.cc index e120e0de..bb125642 100644 --- a/yacl/crypto/base/ecc/pairing_test.cc +++ b/yacl/crypto/base/ecc/pairing_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "fmt/ranges.h" #include "gtest/gtest.h" @@ -230,12 +232,14 @@ class PairingCurveTest : public ::testing::TestWithParam { constexpr int64_t ts = 1 << 15; std::array buf; auto g = ec->GetGenerator(); - yacl::parallel_for(0, ts, 1, [&](int64_t beg, int64_t end) { - auto point = ec->MulBase(MPInt(beg)); - buf[beg] = point; - for (int64_t i = beg + 1; i < end; ++i) { - point = ec->Add(point, g); - buf[i] = point; + yacl::parallel_for(0, ts, [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + auto point = ec->MulBase(MPInt(beg)); + buf[beg] = point; + for (int64_t i = beg + 1; i < end; ++i) { + point = ec->Add(point, g); + buf[i] = point; + } } }); @@ -326,14 +330,12 @@ TEST(Pairing_Multi_Instance_Test, Works) { // TODO: temporarily disable mcl pairing-related test, since its weird error // on Intel Mac if (lib_name != "libmcl") { - yacl::parallel_for(0, 10, 1, [&](int64_t x, int64_t y) { - for (int64_t i = x; i < y; i++) { - std::shared_ptr pairing = - PairingGroupFactory::Instance().Create(pairing_name, - ArgLib = lib_name); - pairing->Pairing(pairing->GetG1()->GetGenerator(), - pairing->GetG2()->GetGenerator()); - } + yacl::parallel_for(0, 10, [&](int64_t, int64_t) { + std::shared_ptr pairing = + PairingGroupFactory::Instance().Create(pairing_name, + ArgLib = lib_name); + pairing->Pairing(pairing->GetG1()->GetGenerator(), + pairing->GetG2()->GetGenerator()); }); } } diff --git a/yacl/crypto/tools/BUILD.bazel b/yacl/crypto/tools/BUILD.bazel index b8453ef9..6c07e6a7 100644 --- a/yacl/crypto/tools/BUILD.bazel +++ b/yacl/crypto/tools/BUILD.bazel @@ -84,7 +84,6 @@ yacl_cc_library( ":code_interface", "//yacl/crypto/tools:random_permutation", "//yacl/math:gadget", - "//yacl/utils:thread_pool", ] + select({ "@platforms//cpu:aarch64": [ "@com_github_dltcollab_sse2neon//:sse2neon", @@ -109,8 +108,8 @@ yacl_cc_library( deps = [ ":code_interface", "//yacl/base:block", + "//yacl/base:exception", "//yacl/base:int128", - "//yacl/utils:thread_pool", ] + select({ "@platforms//cpu:aarch64": [ "@com_github_dltcollab_sse2neon//:sse2neon", @@ -136,7 +135,6 @@ yacl_cc_library( ":linear_code", "//yacl/base:block", "//yacl/base:int128", - "//yacl/utils:thread_pool", ] + select({ "@platforms//cpu:aarch64": [ "@com_github_dltcollab_sse2neon//:sse2neon", diff --git a/yacl/crypto/tools/linear_code.h b/yacl/crypto/tools/linear_code.h index d068e2b1..37943201 100644 --- a/yacl/crypto/tools/linear_code.h +++ b/yacl/crypto/tools/linear_code.h @@ -16,7 +16,6 @@ #include #include -#include #include "absl/types/span.h" diff --git a/yacl/math/galois_field/BUILD.bazel b/yacl/math/galois_field/BUILD.bazel new file mode 100644 index 00000000..2d9a3393 --- /dev/null +++ b/yacl/math/galois_field/BUILD.bazel @@ -0,0 +1,53 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") + +package(default_visibility = ["//visibility:public"]) + +yacl_cc_library( + name = "galois_field", + deps = [ + "//yacl/math/galois_field/mpint_field", + ], +) + +yacl_cc_library( + name = "sketch", + hdrs = [ + "gf_scalar.h", + "gf_vector.h", + ], + deps = [ + ":spi", + "//yacl/io/msgpack:buffer", + "//yacl/io/msgpack:spec_traits", + "//yacl/utils:parallel", + "@com_google_absl//absl/types:span", + ], +) + +yacl_cc_library( + name = "spi", + srcs = [ + "gf_spi.cc", + ], + hdrs = [ + "gf_spi.h", + ], + deps = [ + "//yacl/math/mpint", + "//yacl/utils/spi", + ], +) diff --git a/yacl/math/galois_field/gf_scalar.h b/yacl/math/galois_field/gf_scalar.h new file mode 100644 index 00000000..f966cb98 --- /dev/null +++ b/yacl/math/galois_field/gf_scalar.h @@ -0,0 +1,459 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "yacl/io/msgpack/buffer.h" +#include "yacl/io/msgpack/spec_traits.h" +#include "yacl/math/galois_field/gf_spi.h" +#include "yacl/utils/parallel.h" + +namespace yacl::math { + +// Scalar means that Lib based on this class can only process one data at a +// time, but each method can be called concurrently. +// Scalar 表示基于此类实现的库一次只能处理一个数据,但是每个接口都能被并发调用 +template +class GFScalarSketch : public GaloisField { + public: + // if x is scalar, returns bool + // if x is vectored, returns std::vector + virtual bool IsIdentityOne(const T& x) const = 0; + virtual bool IsIdentityZero(const T& x) const = 0; + virtual bool IsInField(const T& x) const = 0; + + virtual bool Equal(const T& x, const T& y) const = 0; + + //==================================// + // operations defined on field // + //==================================// + + // get the additive inverse −a for all elements in set + virtual T Neg(const T& x) const = 0; + virtual void NegInplace(T* x) const = 0; + + // get the multiplicative inverse 1/b for every nonzero element in set + virtual T Inv(const T& x) const = 0; + virtual void InvInplace(T* x) const = 0; + + virtual T Add(const T& x, const T& y) const = 0; + virtual void AddInplace(T* x, const T& y) const = 0; + + virtual T Sub(const T& x, const T& y) const = 0; + virtual void SubInplace(T* x, const T& y) const = 0; + + virtual T Mul(const T& x, const T& y) const = 0; + virtual void MulInplace(T* x, const T& y) const = 0; + + virtual T Div(const T& x, const T& y) const = 0; + virtual void DivInplace(T* x, const T& y) const = 0; + + virtual T Pow(const T& x, const MPInt& y) const = 0; + virtual void PowInplace(T* x, const MPInt& y) const = 0; + + // scalar version: return a random scalar element + virtual T RandomT() const = 0; + + //==================================// + // operations defined on field // + //==================================// + + virtual T DeepCopy(const T& x) const = 0; + + // To human-readable string + virtual std::string ToString(const T& x) const = 0; + + virtual Buffer Serialize(const T& x) const = 0; + // serialize field element(s) to already allocated buffer. + // if buf is nullptr, then calc serialize size only + // @return: the actual size of serialized buffer + virtual size_t Serialize(const T& x, uint8_t* buf, size_t buf_len) const = 0; + + virtual T DeserializeT(ByteContainerView buffer) const = 0; + + private: +#define DefineBoolUnaryFunc(FuncName) \ + Item FuncName(const Item& x) const override { \ + if (x.IsArray()) { \ + auto xsp = x.AsSpan(); \ + /* std::vector cannot write in parallel */ \ + std::vector res; \ + res.resize(xsp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + res[i] = FuncName(xsp[i]); \ + } \ + }); \ + /* convert std::vector to std::vector */ \ + std::vector bv; \ + bv.resize(res.size()); \ + std::copy(res.begin(), res.end(), bv.begin()); \ + return Item::Take(std::move(bv)); \ + } else { \ + return FuncName(x.As()); \ + } \ + } + + // if x is scalar, returns bool + // if x is vectored, returns std::vector + DefineBoolUnaryFunc(IsIdentityOne); + DefineBoolUnaryFunc(IsIdentityZero); + DefineBoolUnaryFunc(IsInField); + + bool Equal(const Item& x, const Item& y) const override { + switch (x, y) { + case OperandType::Scalar2Scalar: { + return Equal(x.As(), y.As()); + } + case OperandType::Vector2Vector: { + auto xsp = x.AsSpan(); + auto ysp = y.AsSpan(); + if (xsp.length() != ysp.length()) { + return false; + } + + std::atomic res = true; + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + if (!res) { + return; + } + + if (!Equal(xsp[i], ysp[i])) { + res.store(false); + return; + } + } + }); + return res.load(); + } + case OperandType::Scalar2Vector: + case OperandType::Vector2Scalar: + return false; + } + YACL_THROW("Bug: please add more case branch"); + } + + //================================// + // operations defined on set // + //================================// + +#define DefineUnaryFunc(FuncName) \ + Item FuncName(const Item& x) const override { \ + using RES_T = decltype(FuncName(std::declval())); \ + \ + if (x.IsArray()) { \ + auto xsp = x.AsSpan(); \ + std::vector res; \ + res.resize(xsp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + res[i] = FuncName(xsp[i]); \ + } \ + }); \ + return Item::Take(std::move(res)); \ + } else { \ + return FuncName(x.As()); \ + } \ + } + +#define DefineUnaryInplaceFunc(FuncName) \ + void FuncName(Item* x) const override { \ + if (x->IsArray()) { \ + auto xsp = x->AsSpan(); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + FuncName(&xsp[i]); \ + } \ + }); \ + } else { \ + FuncName(x->As()); \ + } \ + } + + // get the additive inverse −a for all elements in set + DefineUnaryFunc(Neg); + DefineUnaryInplaceFunc(NegInplace); + + // get the multiplicative inverse 1/b for every nonzero element in set + DefineUnaryFunc(Inv); + DefineUnaryInplaceFunc(InvInplace); + +#define DefineBinaryFunc(FuncName) \ + Item FuncName(const Item& x, const Item& y) const override { \ + using RES_T = \ + decltype(FuncName(std::declval(), std::declval())); \ + \ + switch (x, y) { \ + case OperandType::Scalar2Scalar: { \ + return FuncName(x.As(), y.As()); \ + } \ + case OperandType::Vector2Vector: { \ + auto xsp = x.AsSpan(); \ + auto ysp = y.AsSpan(); \ + YACL_ENFORCE_EQ( \ + xsp.length(), ysp.length(), \ + "operands must have the same length, x.len={}, y.len={}", \ + xsp.length(), ysp.length()); \ + \ + std::vector res; \ + res.resize(xsp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + res[i] = FuncName(xsp[i], ysp[i]); \ + } \ + }); \ + return Item::Take(std::move(res)); \ + } \ + default: \ + YACL_THROW("GFScalarSketch method [{}] doesn't support broadcast now", \ + #FuncName); \ + } \ + } + +#define DefineBinaryInplaceFunc(FuncName) \ + void FuncName(yacl::Item* x, const yacl::Item& y) const override { \ + switch (*x, y) { \ + case OperandType::Scalar2Scalar: { \ + FuncName(x->As(), y.As()); \ + return; \ + } \ + case OperandType::Vector2Vector: { \ + auto xsp = x->AsSpan(); \ + auto ysp = y.AsSpan(); \ + YACL_ENFORCE_EQ( \ + xsp.length(), ysp.length(), \ + "operands must have the same length, x.len={}, y.len={}", \ + xsp.length(), ysp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + FuncName(&xsp[i], ysp[i]); \ + } \ + }); \ + return; \ + } \ + default: \ + YACL_THROW("GFScalarSketch method [{}] doesn't support broadcast now", \ + #FuncName); \ + } \ + } + + DefineBinaryFunc(Add); + DefineBinaryInplaceFunc(AddInplace); + + DefineBinaryFunc(Sub); + DefineBinaryInplaceFunc(SubInplace); + + DefineBinaryFunc(Mul); + DefineBinaryInplaceFunc(MulInplace); + + DefineBinaryFunc(Div); + DefineBinaryInplaceFunc(DivInplace); + + Item Pow(const Item& x, const MPInt& y) const override { + if (x.IsArray()) { + auto xsp = x.AsSpan(); + std::vector res; + res.resize(xsp.length()); + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + res[i] = Pow(xsp[i], y); + } + }); + return Item::Take(std::move(res)); + } else { + return Pow(x.As(), y); + } + } + + void PowInplace(Item* x, const MPInt& y) const override { + if (x->IsArray()) { + auto xsp = x->AsSpan(); + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + Pow(&xsp[i], y); + } + }); + return; + } else { + Pow(x->As(), y); + return; + } + } + + Item Random() const override { return RandomT(); } + + Item Random(size_t count) const override { + std::vector res; + res.resize(count); + yacl::parallel_for(0, count, [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + res[i] = RandomT(); + } + }); + return Item::Take(std::move(res)); + } + + //================================// + // I/O // + //================================// + + DefineUnaryFunc(DeepCopy); + + // To human-readable string + std::string ToString(const Item& x) const override { + if (x.IsArray()) { + auto xsp = x.AsSpan(); + std::string res = "["; + if (!xsp.empty()) { + std::string str = ToString(xsp[0]); + res.reserve(str.size() * xsp.length() * 1.1); + res += str; + } + + for (size_t i = 1; i < xsp.length(); ++i) { + res += ", "; + res += ToString(xsp[i]); + } + res += "]"; + return res; + } else { + return ToString(x.As()); + } + } + + Buffer Serialize(const Item& x) const override { + Buffer buf; + io::StreamBuffer sbuf(&buf); + msgpack::packer packer(sbuf); + + if (x.IsArray()) { + auto xsp = x.AsSpan(); + if (xsp.empty()) { + packer.pack_array(0); + return buf; + } + + // reserve space and pack header + auto item_size = Serialize(xsp[0], nullptr, 0) + 5; + sbuf.Expand(item_size * xsp.length() * 1.1); + packer.pack_array(xsp.length()); + + // todo: need parallel + size_t total_sz = 0; + for (size_t i = 0; i < xsp.length(); ++i) { + auto body_sz = Serialize(xsp[i], nullptr, 0); + total_sz += body_sz; + packer.pack_str(body_sz); + if (sbuf.FreeSize() < body_sz) { + size_t exp_size = (total_sz / (i + 1) + 5) * (xsp.length() - i); + sbuf.Expand(std::max(exp_size, body_sz)); + } + body_sz = Serialize(xsp[i], reinterpret_cast(sbuf.PosLoc()), + sbuf.FreeSize()); + sbuf.IncPos(body_sz); + } + return buf; + } else { + auto& xt = x.As(); + auto sz = Serialize(xt, nullptr, 0); + sbuf.Expand(sz + 5); + + packer.pack_str(sz); // pack header and size + // write payload + auto body_sz = Serialize(xt, reinterpret_cast(sbuf.PosLoc()), + sbuf.FreeSize()); + sbuf.IncPos(body_sz); + return buf; + } + } + + // serialize field element(s) to already allocated buffer. + // if buf is nullptr, then calc approximate serialize size only + // @return: the actual size of serialized buffer + size_t Serialize(const Item& x, uint8_t* buf, size_t buf_len) const override { + if (x.IsArray()) { + auto xsp = x.AsSpan(); + if (buf == nullptr) { // just calc size + buf_len = io::msgpack_traits::HeadSizeOfArray(xsp.length()); + for (size_t i = 0; i < xsp.length(); ++i) { + auto body_sz = Serialize(xsp[i], nullptr, 0); + buf_len += (body_sz + io::msgpack_traits::HeadSizeOfStr(body_sz)); + } + return buf_len; + } + + // actual pack + io::FixedBuffer sbuf(reinterpret_cast(buf), buf_len); + msgpack::packer packer(sbuf); + packer.pack_array(xsp.length()); + for (size_t i = 0; i < xsp.length(); ++i) { + packer.pack_str(Serialize(xsp[i], nullptr, 0)); + auto body_sz = Serialize( + xsp[i], reinterpret_cast(sbuf.PosLoc()), sbuf.FreeSize()); + sbuf.IncPos(body_sz); + } + return sbuf.WrittenSize(); + } else { + auto& xt = x.As(); + if (buf == nullptr) { + auto body_sz = Serialize(xt, nullptr, 0); + return body_sz + io::msgpack_traits::HeadSizeOfStr(body_sz); + } + + io::FixedBuffer sbuf(reinterpret_cast(buf), buf_len); + msgpack::packer packer(sbuf); + + auto sz = Serialize(xt, nullptr, 0); + packer.pack_str(sz); // pack header and size + auto body_sz = Serialize(xt, reinterpret_cast(sbuf.PosLoc()), + sbuf.FreeSize()); + return sbuf.WrittenSize() + body_sz; + } + } + + Item Deserialize(ByteContainerView buffer) const override { + msgpack::object_handle msg = msgpack::unpack( + reinterpret_cast(buffer.data()), buffer.size()); + + auto obj = msg.get(); + switch (obj.type) { + case msgpack::type::STR: + // scalar case + return DeserializeT({obj.via.str.ptr, obj.via.str.size}); + case msgpack::type::ARRAY: { + // vector case + std::vector res; + res.resize(obj.via.array.size); + yacl::parallel_for( + 0, obj.via.array.size, [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + auto str_obj = obj.via.array.ptr[i]; + YACL_ENFORCE(str_obj.type == msgpack::type::STR, + "Deserialize: illegal format"); + res[i] = + DeserializeT({str_obj.via.str.ptr, str_obj.via.str.size}); + } + }); + return Item::Take(std::move(res)); + } + default: + YACL_THROW("Deserialize: unexpected type"); + } + } +}; + +} // namespace yacl::math diff --git a/yacl/math/galois_field/gf_spi.cc b/yacl/math/galois_field/gf_spi.cc new file mode 100644 index 00000000..1b2ae6fc --- /dev/null +++ b/yacl/math/galois_field/gf_spi.cc @@ -0,0 +1,24 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/math/galois_field/gf_spi.h" + +namespace yacl::math { + +GaloisFieldFactory& GaloisFieldFactory::Instance() { + static GaloisFieldFactory factory; + return factory; +} + +} // namespace yacl::math diff --git a/yacl/math/galois_field/gf_spi.h b/yacl/math/galois_field/gf_spi.h new file mode 100644 index 00000000..92c85053 --- /dev/null +++ b/yacl/math/galois_field/gf_spi.h @@ -0,0 +1,123 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "yacl/math/mpint/mp_int.h" +#include "yacl/utils/spi/item.h" +#include "yacl/utils/spi/spi_factory.h" + +namespace yacl::math { + +class GaloisField { + public: + virtual ~GaloisField() = default; + + //================================// + // meta info query // + //================================// + + virtual std::string GetLibraryName() const = 0; + virtual std::string GetFieldName() const = 0; + + // The order of Finite Field will always be k-th power of a prime number p. + // And in extension field, field order and field modulus are different and not + // directly related, which is unlike in normal prime field that field order is + // just field modulus. + // Note, the origin order(p^k) of extension field(degree k>1) is actually + // useless for field computation. So we usually disable `GetOrder` for + // extension field and set it to be 0, except we are dealing within a subfield + // from the upper extension field. + virtual MPInt GetOrder() const = 0; + virtual MPInt GetExtensionDegree() const = 0; // the k of GF(p^k) + virtual MPInt GetBaseFieldOrder() const = 0; // the p of GF(p^k) + + // get the additive identity + virtual Item GetIdentityZero() const = 0; + // get the multiplicative identity + virtual Item GetIdentityOne() const = 0; + + // Below functions: + // - if x is scalar, returns bool + // - if x is vectored, returns std::vector + virtual Item IsIdentityOne(const Item& x) const = 0; + virtual Item IsIdentityZero(const Item& x) const = 0; + virtual Item IsInField(const Item& x) const = 0; + + virtual bool Equal(const Item& x, const Item& y) const = 0; + + //==================================// + // operations defined on field // + //==================================// + + // get the additive inverse −a for all elements in set + virtual Item Neg(const Item& x) const = 0; + virtual void NegInplace(Item* x) const = 0; + + // get the multiplicative inverse 1/b for every nonzero element in set + virtual Item Inv(const Item& x) const = 0; + virtual void InvInplace(Item* x) const = 0; + + virtual Item Add(const Item& x, const Item& y) const = 0; + virtual void AddInplace(Item* x, const Item& y) const = 0; + + virtual Item Sub(const Item& x, const Item& y) const = 0; + virtual void SubInplace(Item* x, const Item& y) const = 0; + + virtual Item Mul(const Item& x, const Item& y) const = 0; + virtual void MulInplace(Item* x, const Item& y) const = 0; + + virtual Item Div(const Item& x, const Item& y) const = 0; + virtual void DivInplace(Item* x, const Item& y) const = 0; + + virtual Item Pow(const Item& x, const MPInt& y) const = 0; + virtual void PowInplace(Item* x, const MPInt& y) const = 0; + + // scalar version: return a random scalar element + virtual Item Random() const = 0; + // vector version: return a vector of 'count' elements + virtual Item Random(size_t count) const = 0; + + //================================// + // I/O // + //================================// + + virtual Item DeepCopy(const Item& x) const = 0; + + // To human-readable string + virtual std::string ToString(const Item& x) const = 0; + + virtual Buffer Serialize(const Item& x) const = 0; + // serialize field element(s) to already allocated buffer. + // if buf is nullptr, then calc serialize size only + // @return: the actual size of serialized buffer + virtual size_t Serialize(const Item& x, uint8_t* buf, + size_t buf_len) const = 0; + + virtual Item Deserialize(ByteContainerView buffer) const = 0; +}; + +class GaloisFieldFactory final : public SpiFactoryBase { + public: + static GaloisFieldFactory& Instance(); +}; + +#define REGISTER_GF_LIBRARY(lib_name, performance, checker, creator) \ + REGISTER_SPI_LIBRARY_HELPER(GaloisFieldFactory, lib_name, performance, \ + checker, creator) + +} // namespace yacl::math diff --git a/yacl/math/galois_field/gf_vector.h b/yacl/math/galois_field/gf_vector.h new file mode 100644 index 00000000..3b348a76 --- /dev/null +++ b/yacl/math/galois_field/gf_vector.h @@ -0,0 +1,169 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "absl/types/span.h" + +#include "yacl/math/galois_field/gf_spi.h" + +namespace yacl::math { + +template +class GFVectorizedSketch : public GaloisField { + public: + // if x is scalar, returns bool + // if x is vectored, returns std::vector + virtual std::vector IsIdentityOne(absl::Span x) const = 0; + virtual std::vector IsIdentityZero(absl::Span x) const = 0; + virtual std::vector IsInField(absl::Span x) const = 0; + + virtual bool Equal(absl::Span x, absl::Span y) const = 0; + + //================================// + // operations defined on set // + //================================// + + // get the additive inverse −a for all elements in set + virtual std::vector Neg(absl::Span x) const = 0; + virtual void NegInplace(absl::Span x) const = 0; + + // get the multiplicative inverse 1/b for every nonzero element in set + virtual std::vector Inv(absl::Span x) const = 0; + virtual void InvInplace(absl::Span x) const = 0; + + virtual std::vector Add(absl::Span x, + absl::Span y) const = 0; + virtual void AddInplace(absl::Span x, absl::Span y) const = 0; + + virtual std::vector Sub(absl::Span x, + absl::Span y) const = 0; + virtual void SubInplace(absl::Span x, absl::Span y) const = 0; + + virtual std::vector Mul(absl::Span x, + absl::Span y) const = 0; + virtual void MulInplace(absl::Span x, absl::Span y) const = 0; + + virtual std::vector Div(absl::Span x, + absl::Span y) const = 0; + virtual void DivInplace(absl::Span x, absl::Span y) const = 0; + + virtual std::vector Pow(absl::Span x, const MPInt& y) const = 0; + virtual void PowInplace(absl::Span x, const MPInt& y) const = 0; + + virtual std::vector RandomT(size_t count) const = 0; + + //================================// + // I/O // + //================================// + + virtual std::vector DeepCopy(absl::Span x) const = 0; + + // To human-readable string + virtual std::string ToString(absl::Span x) const = 0; + + virtual Buffer Serialize(absl::Span x) const = 0; + // serialize field element(s) to already allocated buffer. + // if buf is nullptr, then calc serialize size only + // @return: the actual size of serialized buffer + virtual size_t Serialize(absl::Span x, uint8_t* buf, + size_t buf_len) const = 0; + + virtual std::vector DeserializeT(ByteContainerView buffer) const = 0; + + private: +#define DefineUnaryFunc(FuncName) \ + auto FuncName(const Item& x) const override { \ + return FuncName(x.AsSpan()); \ + } + +#define DefineUnaryInplaceFunc(FuncName) \ + void FuncName(Item* x) const override { return FuncName(x->AsSpan()); } + +#define DefineBinaryFunc(FuncName) \ + auto FuncName(const Item& x, const Item& y) const override { \ + return FuncName(x.AsSpan(), y.AsSpan()); \ + } + +#define DefineBinaryInplaceFunc(FuncName) \ + void FuncName(Item* x, const Item& y) const override { \ + FuncName(x->AsSpan(), y.AsSpan()); \ + } + + // if x is scalar, returns bool + // if x is vectored, returns std::vector + DefineUnaryFunc(IsIdentityOne); + DefineUnaryFunc(IsIdentityZero); + DefineUnaryFunc(IsInField); + DefineBinaryFunc(Equal); + + //==================================// + // operations defined on field // + //==================================// + + // get the additive inverse −a for all elements in set + DefineUnaryFunc(Neg); + DefineUnaryInplaceFunc(NegInplace); + + // get the multiplicative inverse 1/b for every nonzero element in set + DefineUnaryFunc(Inv); + DefineUnaryInplaceFunc(InvInplace); + + DefineBinaryFunc(Add); + DefineBinaryInplaceFunc(AddInplace); + + DefineBinaryFunc(Sub); + DefineBinaryInplaceFunc(SubInplace); + + DefineBinaryFunc(Mul); + DefineBinaryInplaceFunc(MulInplace); + + DefineBinaryFunc(Div); + DefineBinaryInplaceFunc(DivInplace); + + virtual Item Pow(const Item& x, const MPInt& y) const { + return Pow(x.AsSpan(), y); + } + + virtual void PowInplace(Item* x, const MPInt& y) const { + PowInplace(x->AsSpan(), y); + } + + Item Random() const override { return RandomT(1)[0]; } + + Item Random(size_t count) const override { return RandomT(count); } + + //================================// + // I/O // + //================================// + + DefineUnaryFunc(DeepCopy); + + // To human-readable string + DefineUnaryFunc(ToString); + DefineUnaryFunc(Serialize); + + // serialize field element(s) to already allocated buffer. + // if buf is nullptr, then calc serialize size only + // @return: the actual size of serialized buffer + virtual size_t Serialize(const Item& x, uint8_t* buf, size_t buf_len) const { + return Serialize(x.AsSpan(), buf, buf_len); + } + + virtual Item Deserialize(ByteContainerView buffer) const { + return DeserializeT(buffer); + } +}; + +} // namespace yacl::math diff --git a/yacl/math/galois_field/mpint_field/BUILD.bazel b/yacl/math/galois_field/mpint_field/BUILD.bazel new file mode 100644 index 00000000..5798dc84 --- /dev/null +++ b/yacl/math/galois_field/mpint_field/BUILD.bazel @@ -0,0 +1,55 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:yacl.bzl", "yacl_cc_binary", "yacl_cc_library", "yacl_cc_test") + +package(default_visibility = ["//visibility:public"]) + +yacl_cc_library( + name = "configs", + hdrs = ["configs.h"], + deps = [ + "//yacl/math/mpint", + "//yacl/utils/spi/argument", + ], +) + +yacl_cc_library( + name = "mpint_field", + srcs = ["mpint_field.cc"], + hdrs = ["mpint_field.h"], + deps = [ + ":configs", + "//yacl/math/galois_field:sketch", + ], + alwayslink = 1, +) + +yacl_cc_test( + name = "mpint_field_test", + srcs = ["mpint_field_test.cc"], + deps = [ + ":mpint_field", + ], +) + +yacl_cc_binary( + name = "bench", + srcs = ["mpint_field_bench.cc"], + deps = [ + ":configs", + "//yacl/math/galois_field", + "@com_github_google_benchmark//:benchmark", + ], +) diff --git a/yacl/math/galois_field/mpint_field/configs.h b/yacl/math/galois_field/mpint_field/configs.h new file mode 100644 index 00000000..796b6d19 --- /dev/null +++ b/yacl/math/galois_field/mpint_field/configs.h @@ -0,0 +1,41 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/math/mpint/mp_int.h" +#include "yacl/utils/spi/argument/argument.h" + +// How to use mpint field? +// +// > #include "yacl/math/galois_field/gf_spi.h" +// > #include "yacl/math/galois_field/mpint_field/configs.h" +// > +// > void foo() { +// > auto gf = GaloisFieldFactory::Instance().Create("Zn", ArgMod = 13_mp); +// > auto sum = gf->Add(10_mp, 5_mp); // output 2 +// > } +// +// Note 1: Do not include 'mpint_field.h', include 'configs.h' instead. +// Note 2: Get mpint field instance by `GaloisFieldFactory::Instance().Create()` + +namespace yacl::math::mpf { + +inline const std::string kFieldName = "Zp"; +inline const std::string kLibName = "mpint"; + +// Prd-defined options... +DECLARE_ARG(MPInt, Mod); + +} // namespace yacl::math::mpf diff --git a/yacl/math/galois_field/mpint_field/mpint_field.cc b/yacl/math/galois_field/mpint_field/mpint_field.cc new file mode 100644 index 00000000..fae3dc9d --- /dev/null +++ b/yacl/math/galois_field/mpint_field/mpint_field.cc @@ -0,0 +1,152 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/math/galois_field/mpint_field/mpint_field.h" + +#include "yacl/math/galois_field/mpint_field/configs.h" + +namespace yacl::math::mpf { + +DEFINE_ARG(MPInt, Mod); + +REGISTER_GF_LIBRARY(kLibName, 100, MPIntField::Check, MPIntField::Create); + +std::unique_ptr MPIntField::Create(const std::string &field_name, + const SpiArgs &args) { + YACL_ENFORCE(field_name == kFieldName); + auto mod = args.GetRequired(ArgMod); + YACL_ENFORCE(mod.IsPrime(), "ArgMod must be a prime"); + return std::unique_ptr(new MPIntField(std::move(mod))); +} + +bool MPIntField::Check(const std::string &field_name, const SpiArgs &) { + return field_name == kFieldName; +} + +std::string MPIntField::GetLibraryName() const { return kLibName; } + +std::string MPIntField::GetFieldName() const { return kFieldName; } + +MPInt MPIntField::GetOrder() const { return mod_; } + +MPInt MPIntField::GetExtensionDegree() const { return MPInt::_1_; } + +MPInt MPIntField::GetBaseFieldOrder() const { return mod_; } + +Item MPIntField::GetIdentityZero() const { return MPInt::_0_; } + +Item MPIntField::GetIdentityOne() const { return MPInt::_1_; } + +bool MPIntField::IsIdentityOne(const MPInt &x) const { return x == MPInt::_1_; } + +bool MPIntField::IsIdentityZero(const MPInt &x) const { return x.IsZero(); } + +bool MPIntField::IsInField(const MPInt &x) const { + return x.IsNatural() && x < mod_; +} + +bool MPIntField::Equal(const MPInt &x, const MPInt &y) const { return x == y; } + +//==================================// +// operations defined on field // +//==================================// + +MPInt MPIntField::Add(const MPInt &x, const MPInt &y) const { + return x.AddMod(y, mod_); +} + +void MPIntField::AddInplace(MPInt *x, const MPInt &y) const { + MPInt::AddMod(*x, y, mod_, x); +} + +MPInt MPIntField::Neg(const MPInt &x) const { + if (x.IsZero()) { + return x; + } + + WEAK_ENFORCE(IsInField(x), "x is not a valid field element, x={}", x); + return mod_ - x; +} + +void MPIntField::NegInplace(MPInt *x) const { + if (x->IsZero()) { + return; + } + + WEAK_ENFORCE(IsInField(*x), "x is not a valid field element, x={}", *x); + x->NegateInplace(); + AddInplace(x, mod_); + x->DecrOne(); +} + +MPInt MPIntField::Inv(const MPInt &x) const { return x.InvertMod(mod_); } + +void MPIntField::InvInplace(MPInt *x) const { MPInt::InvertMod(*x, mod_, x); } + +MPInt MPIntField::Sub(const MPInt &x, const MPInt &y) const { + return x.SubMod(y, mod_); +} + +void MPIntField::SubInplace(MPInt *x, const MPInt &y) const { + MPInt::SubMod(*x, y, mod_, x); +} + +MPInt MPIntField::Mul(const MPInt &x, const MPInt &y) const { + return x.MulMod(y, mod_); +} + +void MPIntField::MulInplace(MPInt *x, const MPInt &y) const { + MPInt::MulMod(*x, y, mod_, x); +} + +MPInt MPIntField::Div(const MPInt &x, const MPInt &y) const { + return x.MulMod(y.InvertMod(mod_), mod_); +} + +void MPIntField::DivInplace(MPInt *x, const MPInt &y) const { + MPInt::MulMod(*x, y.InvertMod(mod_), mod_, x); +} + +MPInt MPIntField::Pow(const MPInt &x, const MPInt &y) const { + return x.PowMod(y, mod_); +} + +void MPIntField::PowInplace(MPInt *x, const MPInt &y) const { + MPInt::PowMod(*x, y, mod_, x); +} + +MPInt MPIntField::RandomT() const { + MPInt res; + MPInt::RandomLtN(mod_, &res); + return res; +} + +MPInt MPIntField::DeepCopy(const MPInt &x) const { return x; } + +std::string MPIntField::ToString(const MPInt &x) const { return x.ToString(); } + +Buffer MPIntField::Serialize(const MPInt &x) const { return x.Serialize(); } + +size_t MPIntField::Serialize(const MPInt &x, uint8_t *buf, + size_t buf_len) const { + return x.Serialize(buf, buf_len); +} + +MPInt MPIntField::DeserializeT(ByteContainerView buffer) const { + MPInt res; + res.Deserialize(buffer); + return res; +} + +} // namespace yacl::math::mpf diff --git a/yacl/math/galois_field/mpint_field/mpint_field.h b/yacl/math/galois_field/mpint_field/mpint_field.h new file mode 100644 index 00000000..bc0feee0 --- /dev/null +++ b/yacl/math/galois_field/mpint_field/mpint_field.h @@ -0,0 +1,82 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "yacl/math/galois_field/gf_scalar.h" +#include "yacl/math/mpint/mp_int.h" + +namespace yacl::math::mpf { + +class MPIntField : public GFScalarSketch { + public: + static std::unique_ptr Create(const std::string &field_name, + const SpiArgs &args); + static bool Check(const std::string &field_name, const SpiArgs &); + ~MPIntField() override = default; + + std::string GetLibraryName() const override; + std::string GetFieldName() const override; + + MPInt GetOrder() const override; + MPInt GetExtensionDegree() const override; + MPInt GetBaseFieldOrder() const override; + + Item GetIdentityZero() const override; + Item GetIdentityOne() const override; + + bool IsIdentityOne(const MPInt &x) const override; + bool IsIdentityZero(const MPInt &x) const override; + bool IsInField(const MPInt &x) const override; + bool Equal(const MPInt &x, const MPInt &y) const override; + + //==================================// + // operations defined on field // + //==================================// + + MPInt Neg(const MPInt &x) const override; + void NegInplace(MPInt *x) const override; + MPInt Inv(const MPInt &x) const override; + void InvInplace(MPInt *x) const override; + + MPInt Add(const MPInt &x, const MPInt &y) const override; + void AddInplace(MPInt *x, const MPInt &y) const override; + + MPInt Sub(const MPInt &x, const MPInt &y) const override; + void SubInplace(MPInt *x, const MPInt &y) const override; + MPInt Mul(const MPInt &x, const MPInt &y) const override; + void MulInplace(MPInt *x, const MPInt &y) const override; + MPInt Div(const MPInt &x, const MPInt &y) const override; + void DivInplace(MPInt *x, const MPInt &y) const override; + MPInt Pow(const MPInt &x, const MPInt &y) const override; + void PowInplace(MPInt *x, const MPInt &y) const override; + + MPInt RandomT() const override; + + MPInt DeepCopy(const MPInt &x) const override; + std::string ToString(const MPInt &x) const override; + + Buffer Serialize(const MPInt &x) const override; + size_t Serialize(const MPInt &x, uint8_t *buf, size_t buf_len) const override; + MPInt DeserializeT(ByteContainerView buffer) const override; + + private: + explicit MPIntField(MPInt mod) : mod_(std::move(mod)) {} + + MPInt mod_; +}; + +} // namespace yacl::math::mpf diff --git a/yacl/math/galois_field/mpint_field/mpint_field_bench.cc b/yacl/math/galois_field/mpint_field/mpint_field_bench.cc new file mode 100644 index 00000000..6b7818f3 --- /dev/null +++ b/yacl/math/galois_field/mpint_field/mpint_field_bench.cc @@ -0,0 +1,55 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark/benchmark.h" + +#include "yacl/math/galois_field/gf_spi.h" +#include "yacl/math/galois_field/mpint_field/configs.h" +#include "yacl/math/mpint/mp_int.h" + +using yacl::math::MPInt; + +// state.range(0): bits of number +static void BM_MPIntAddMod(benchmark::State& state) { + MPInt m1, m2, mod; + MPInt::RandomExactBits(state.range(0), &m1); + MPInt::RandomExactBits(state.range(0), &m2); + MPInt::RandomExactBits(state.range(0) - 1, &mod); + for (auto _ : state) { + benchmark::DoNotOptimize(m1.AddMod(m2, mod)); + } +} + +// state.range(0): bits of number +static void BM_MpfAdd(benchmark::State& state) { + MPInt m1, m2, mod; + MPInt::RandomExactBits(state.range(0), &m1); + MPInt::RandomExactBits(state.range(0), &m2); + MPInt::RandomExactBits(state.range(0) - 1, &mod); + + auto spi = yacl::math::GaloisFieldFactory::Instance().Create( + "Zp", yacl::ArgLib = "mpint", yacl::math::mpf::ArgMod = mod); + + for (auto _ : state) { + benchmark::DoNotOptimize(spi->Add(m1, m2)); + } +} + +BENCHMARK(BM_MPIntAddMod)->Arg(64)->Arg(1024)->Arg(2048)->Arg(4096); +BENCHMARK(BM_MpfAdd)->Arg(64)->Arg(1024)->Arg(2048)->Arg(4096); + +int main() { + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/yacl/math/galois_field/mpint_field/mpint_field_test.cc b/yacl/math/galois_field/mpint_field/mpint_field_test.cc new file mode 100644 index 00000000..bd0f8544 --- /dev/null +++ b/yacl/math/galois_field/mpint_field/mpint_field_test.cc @@ -0,0 +1,252 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include "yacl/math/galois_field/gf_spi.h" +#include "yacl/math/galois_field/mpint_field/configs.h" + +namespace yacl::math::mpf::test { + +class MPIntFieldTest : public testing::Test {}; + +TEST_F(MPIntFieldTest, AddWorks) { + auto gf = GaloisFieldFactory::Instance().Create(kFieldName, ArgLib = kLibName, + ArgMod = 13_mp); + + EXPECT_EQ(gf->GetLibraryName(), kLibName); + EXPECT_EQ(gf->GetFieldName(), kFieldName); + + EXPECT_EQ(gf->GetOrder(), 13_mp); + EXPECT_TRUE(gf->GetExtensionDegree().IsOne()); + EXPECT_EQ(gf->GetBaseFieldOrder(), 13_mp); + + EXPECT_EQ(gf->GetIdentityZero(), 0_mp); + EXPECT_EQ(gf->GetIdentityOne(), 1_mp); +} + +TEST_F(MPIntFieldTest, ScalarWorks) { + auto gf = GaloisFieldFactory::Instance().Create(kFieldName, ArgLib = kLibName, + ArgMod = 13_mp); + + EXPECT_TRUE((bool)gf->IsIdentityZero(0_mp)); + EXPECT_FALSE((bool)gf->IsIdentityZero(1_mp)); + EXPECT_FALSE((bool)gf->IsIdentityOne(0_mp)); + EXPECT_TRUE((bool)gf->IsIdentityOne(1_mp)); + + EXPECT_TRUE((bool)gf->IsInField(0_mp)); + EXPECT_TRUE((bool)gf->IsInField(1_mp)); + EXPECT_TRUE((bool)gf->IsInField(12_mp)); + EXPECT_FALSE((bool)gf->IsInField(13_mp)); + EXPECT_FALSE((bool)gf->IsInField(-1_mp)); + + EXPECT_TRUE((bool)gf->Equal(0_mp, 0_mp)); + EXPECT_FALSE((bool)gf->Equal(1_mp, 0_mp)); + EXPECT_TRUE((bool)gf->Equal(12_mp, 12_mp)); + + // operands // + EXPECT_EQ(gf->Neg(0_mp), 0_mp); + EXPECT_EQ(gf->Neg(1_mp), 12_mp); + EXPECT_EQ(gf->Neg(6_mp), 7_mp); + EXPECT_EQ(gf->Neg(7_mp), 6_mp); + EXPECT_EQ(gf->Neg(12_mp), 1_mp); + + EXPECT_EQ(gf->Inv(1_mp), 1_mp); + EXPECT_EQ(gf->Inv(2_mp), 7_mp); + EXPECT_EQ(gf->Inv(3_mp), 9_mp); + EXPECT_EQ(gf->Inv(7_mp), 2_mp); + EXPECT_EQ(gf->Inv(9_mp), 3_mp); + EXPECT_ANY_THROW(gf->Inv(0_mp)); // error + + EXPECT_EQ(gf->Add(10_mp, 5_mp), 2_mp); + EXPECT_NE(gf->Add(10_mp, 5_mp), 3_mp); // test item not equal + + EXPECT_EQ(gf->Sub(10_mp, 12_mp), 11_mp); // 23 - 12 + EXPECT_EQ(gf->Sub(0_mp, 0_mp), 0_mp); + EXPECT_EQ(gf->Sub(10_mp, 1_mp), 9_mp); + + EXPECT_EQ(gf->Mul(10_mp, 12_mp), 3_mp); + EXPECT_EQ(gf->Mul(1_mp, 12_mp), 12_mp); + EXPECT_EQ(gf->Mul(0_mp, 12_mp), 0_mp); + + EXPECT_EQ(gf->Div(10_mp, 1_mp), 10_mp); + EXPECT_EQ(gf->Div(10_mp, 2_mp), 5_mp); + EXPECT_EQ(gf->Div(12_mp, 1_mp), 12_mp); + EXPECT_EQ(gf->Div(12_mp, 12_mp), 1_mp); + EXPECT_EQ(gf->Div(3_mp, 10_mp), 12_mp); + EXPECT_EQ(gf->Div(3_mp, 12_mp), 10_mp); + EXPECT_EQ(gf->Div(10_mp, 5_mp), 2_mp); + EXPECT_ANY_THROW(gf->Div(10_mp, 0_mp)); // error + EXPECT_EQ(gf->Div(0_mp, 1_mp), 0_mp); + EXPECT_EQ(gf->Div(0_mp, 11_mp), 0_mp); + + EXPECT_EQ(gf->Pow(10_mp, 0_mp), 1_mp); + EXPECT_EQ(gf->Pow(10_mp, 1_mp), 10_mp); + EXPECT_EQ(gf->Pow(10_mp, 2_mp), 9_mp); + + auto r1 = gf->Random(); + auto r2 = gf->Random(); + EXPECT_TRUE((bool)gf->IsInField(r1)); + EXPECT_TRUE(gf->IsInField(r2).IsAll(true)); + + // I/O // + + MPInt mp1 = 12_mp; + Item mp2 = gf->DeepCopy(mp1); + mp1.DecrOne(); + EXPECT_EQ(mp1, 11_mp); + EXPECT_EQ(mp2, 12_mp); + + EXPECT_EQ(gf->ToString(mp2), "12"); + + // serialize + Buffer buf = gf->Serialize(mp2); + auto mp3 = gf->Deserialize(buf); + EXPECT_EQ(mp3, 12_mp); + + MPInt::RandomExactBits(4096, &mp1); + buf.reset(); + buf.resize(gf->Serialize(mp1, nullptr, 0)); + auto real_sz = gf->Serialize(mp1, buf.data(), buf.size()); + EXPECT_EQ(gf->Deserialize(buf), mp1); + buf.resize(real_sz); + EXPECT_EQ(gf->Deserialize(buf), mp1); +} + +TEST_F(MPIntFieldTest, VectorWorks) { + auto gf = GaloisFieldFactory::Instance().Create(kFieldName, ArgLib = kLibName, + ArgMod = 13_mp); + + // test item format + std::vector a = {1_mp, 2_mp, 3_mp}; + std::vector b = {11_mp, 12_mp, 13_mp}; + auto sum_v = gf->Add(Item::Ref(a), Item::Ref(b)); + ASSERT_TRUE(sum_v.IsArray()); + ASSERT_FALSE(sum_v.IsView()); + + // test item (not) equal + auto sum_sp = sum_v.AsSpan(); + EXPECT_EQ(sum_sp.length(), 3); + EXPECT_EQ(sum_sp, absl::MakeConstSpan({12_mp, 1_mp, 3_mp})); + EXPECT_NE(sum_sp, absl::MakeConstSpan({12_mp, 1_mp})); + EXPECT_NE(sum_sp, absl::MakeConstSpan({12_mp, 1_mp, 3_mp, 12_mp})); + EXPECT_NE(sum_sp, absl::MakeConstSpan({12_mp, 1_mp, 4_mp})); + + EXPECT_EQ(gf->IsIdentityZero(Item::Take({0_mp, 1_mp, 2_mp, 0_mp})), + std::vector({true, false, false, true})); + EXPECT_NE(gf->IsIdentityZero(Item::Take({0_mp, 1_mp, 2_mp, 0_mp})), + std::vector({true, false, false, false})); + EXPECT_EQ(gf->IsIdentityOne(Item::Take({0_mp, 1_mp, 2_mp, 0_mp})), + std::vector({false, true, false, false})); + EXPECT_EQ(gf->IsInField( + Item::Take({0_mp, -1_mp, 2_mp, 12_mp, 13_mp, 50_mp})), + std::vector({true, false, true, true, false, false})); + + // test gf->Equal + EXPECT_TRUE((bool)gf->Equal(Item::Take({}), Item::Take({}))); + EXPECT_TRUE( + (bool)gf->Equal(Item::Take({0_mp}), Item::Take({0_mp}))); + EXPECT_TRUE((bool)gf->Equal(Item::Take({0_mp, 10_mp}), + Item::Take({0_mp, 10_mp}))); + EXPECT_TRUE((bool)gf->Equal(Item::Take({0_mp, 10_mp, 5_mp, 7_mp}), + Item::Take({0_mp, 10_mp, 5_mp, 7_mp}))); + EXPECT_FALSE((bool)gf->Equal(Item::Take({0_mp, 10_mp, 5_mp, 7_mp}), + Item::Take({0_mp, 10_mp, 5_mp}))); + EXPECT_FALSE((bool)gf->Equal(Item::Take({0_mp, 10_mp, 5_mp}), + Item::Take({0_mp, 10_mp, 5_mp, 7_mp}))); + EXPECT_FALSE((bool)gf->Equal(Item::Take({0_mp, 10_mp, 5_mp, 6_mp}), + Item::Take({0_mp, 10_mp, 5_mp, 7_mp}))); + EXPECT_FALSE((bool)gf->Equal(Item::Take({0_mp, 10_mp, 5_mp, 7_mp}), + Item::Take({1_mp, 10_mp, 5_mp, 7_mp}))); + EXPECT_FALSE((bool)gf->Equal(Item::Take({0_mp, 1_mp, 2_mp, 3_mp}), + Item::Take({3_mp, 2_mp, 1_mp, 0_mp}))); + + // operands // + + EXPECT_EQ(gf->Neg(Item::Take({0_mp, 1_mp, 2_mp, 3_mp})), + std::vector({0_mp, 12_mp, 11_mp, 10_mp})); + EXPECT_EQ(gf->Inv(Item::Take({1_mp, 2_mp, 3_mp})), + std::vector({1_mp, 7_mp, 9_mp})); + EXPECT_ANY_THROW(gf->Inv(Item::Take({0_mp, 2_mp, 3_mp}))); // error + + EXPECT_EQ(gf->Add(Item::Take({0_mp, 1_mp, 2_mp, 3_mp}), + Item::Take({7_mp, 6_mp, 5_mp, 4_mp})), + std::vector({7_mp, 7_mp, 7_mp, 7_mp})); + EXPECT_EQ(gf->Sub(Item::Take({0_mp, 1_mp, 2_mp, 3_mp}), + Item::Take({7_mp, 6_mp, 5_mp, 4_mp})), + std::vector({6_mp, 8_mp, 10_mp, 12_mp})); + EXPECT_EQ(gf->Mul(Item::Take({0_mp, 1_mp, 2_mp, 3_mp}), + Item::Take({7_mp, 6_mp, 5_mp, 4_mp})), + std::vector({0_mp, 6_mp, 10_mp, 12_mp})); + EXPECT_EQ(gf->Div(Item::Take({0_mp, 1_mp, 3_mp, 3_mp}), + Item::Take({7_mp, 1_mp, 12_mp, 10_mp})), + std::vector({0_mp, 1_mp, 10_mp, 12_mp})); + EXPECT_EQ(gf->Pow(Item::Take({0_mp, 1_mp, 3_mp, 4_mp}), 2_mp), + std::vector({0_mp, 1_mp, 9_mp, 3_mp})); + + auto r1 = gf->Random(1000); + auto check1 = gf->IsInField(r1); + EXPECT_TRUE(check1.IsAll(true)); + + auto r2 = gf->Random(1000); + EXPECT_TRUE(gf->IsInField(r2).IsAll(true)); + EXPECT_FALSE(gf->Equal(r1, r2)); + EXPECT_FALSE(gf->Sub(r1, r2).IsAll(0_mp)); +} + +TEST_F(MPIntFieldTest, VectorIoWorks) { + MPInt mod; + MPInt::RandPrimeOver(1024, &mod, PrimeType::Normal); + auto gf = GaloisFieldFactory::Instance().Create(kFieldName, ArgLib = kLibName, + ArgMod = mod); + + // subspan + auto item1 = Item::Take({0_mp, 1_mp, 2_mp, 3_mp}); + auto item2 = item1.SubSpan(0, 1); + ASSERT_TRUE(item2.IsView()); + ASSERT_FALSE(item2.IsReadOnly()); + ASSERT_TRUE(item2.IsHoldType>()); + ASSERT_TRUE(gf->Equal(item2, Item::Take({0_mp}))); + + // deepcopy + Item item3 = gf->DeepCopy(item1); + ASSERT_TRUE(gf->Equal(item1, item3)); + + item2.AsSpan()[0] = 10_mp; + ASSERT_TRUE(gf->Equal(item1, Item::Take({10_mp, 1_mp, 2_mp, 3_mp}))); + ASSERT_TRUE(gf->Equal(item3, Item::Take({0_mp, 1_mp, 2_mp, 3_mp}))); + + // to string + EXPECT_EQ(gf->ToString(item2), "[10]"); + EXPECT_EQ(gf->ToString(item3), "[0, 1, 2, 3]"); + + // serialize + Buffer buf = gf->Serialize(item3); + auto item4 = gf->Deserialize(buf); + EXPECT_TRUE(gf->Equal(item3, item4)); + + std::vector vt; + vt.resize(1024); + for (int i = 0; i < 1024; ++i) { + MPInt::RandomExactBits(i, &vt[i]); + } + + item1 = Item::Ref(vt); + buf.resize(gf->Serialize(item1, nullptr, 0)); + auto real_sz = gf->Serialize(item1, buf.data(), buf.size()); + EXPECT_EQ(real_sz, buf.size()); + EXPECT_EQ(gf->Deserialize(buf), vt); +} + +} // namespace yacl::math::mpf::test diff --git a/yacl/math/mpint/BUILD.bazel b/yacl/math/mpint/BUILD.bazel index 2d19d3aa..1ccfd42e 100644 --- a/yacl/math/mpint/BUILD.bazel +++ b/yacl/math/mpint/BUILD.bazel @@ -24,6 +24,7 @@ yacl_cc_library( name = "mp_int_enforce", hdrs = ["mp_int_enforce.h"], deps = [ + "//yacl/base:exception", "@com_github_fmtlib_fmt//:fmtlib", "@com_github_libtom_libtommath//:libtommath", ], @@ -49,6 +50,7 @@ yacl_cc_library( srcs = ["tommath_ext_types.cc"], hdrs = ["tommath_ext_types.h"], deps = [ + ":mp_int_enforce", "//yacl/base:int128", "@com_github_libtom_libtommath//:libtommath", ], @@ -88,7 +90,7 @@ yacl_cc_test( ) yacl_cc_test( - name = "mp_ext_test", + name = "mpx_test", srcs = ["tommath_ext_test.cc"], deps = [ ":tommath_ext_features", diff --git a/yacl/math/mpint/montgomery_math.cc b/yacl/math/mpint/montgomery_math.cc index 79c84847..c9073992 100644 --- a/yacl/math/mpint/montgomery_math.cc +++ b/yacl/math/mpint/montgomery_math.cc @@ -16,7 +16,8 @@ namespace yacl::math { -MontgomerySpace::MontgomerySpace(const MPInt &mod) { +MontgomerySpace::MontgomerySpace(const MPInt &mod) : identity_(0) { + // init identity_ to 0 to make sure memory is allocated YACL_ENFORCE(!mod.IsNegative() && mod.IsOdd(), "modulus must be a positive odd number"); mod_ = mod; diff --git a/yacl/math/mpint/mp_int.cc b/yacl/math/mpint/mp_int.cc index bb4ff71d..6adf7903 100644 --- a/yacl/math/mpint/mp_int.cc +++ b/yacl/math/mpint/mp_int.cc @@ -35,7 +35,12 @@ const MPInt MPInt::_0_(0); const MPInt MPInt::_1_(1); const MPInt MPInt::_2_(2); -MPInt::MPInt() { MPINT_ENFORCE_OK(mp_init(&n_)); } +MPInt::MPInt() { + // Use mpx_init instead of mp_init is hazardous. + // Therefore, any changes to MPInt’s public interface must be tested + // carefully. + mpx_init(&n_); +} MPInt::MPInt(const std::string &num, size_t radix) { MPINT_ENFORCE_OK(mp_init(&n_)); @@ -70,12 +75,12 @@ MPInt &MPInt::operator=(MPInt &&other) noexcept { template <> int8_t MPInt::Get() const { - return mp_get_i8(&n_); + return mpx_get_i8(&n_); } template <> int16_t MPInt::Get() const { - return mp_get_i16(&n_); + return mpx_get_i16(&n_); } template <> @@ -90,17 +95,17 @@ int64_t MPInt::Get() const { template <> int128_t MPInt::Get() const { - return mp_get_i128(&n_); + return mpx_get_i128(&n_); } template <> uint8_t MPInt::Get() const { - return mp_get_mag_u8(&n_); + return mpx_get_mag_u8(&n_); } template <> uint16_t MPInt::Get() const { - return mp_get_mag_u16(&n_); + return mpx_get_mag_u16(&n_); } template <> @@ -123,7 +128,7 @@ unsigned long MPInt::Get() const { // NOLINT: macOS uint64_t is ull template <> uint128_t MPInt::Get() const { - return mp_get_mag_u128(&n_); + return mpx_get_mag_u128(&n_); } template <> @@ -143,77 +148,73 @@ MPInt MPInt::Get() const { template <> void MPInt::Set(int8_t value) { - mp_set_i8(&n_, value); + mpx_set_i8(&n_, value); } template <> void MPInt::Set(int16_t value) { - mp_set_i16(&n_, value); + mpx_set_i16(&n_, value); } template <> void MPInt::Set(int32_t value) { - mp_set_i32(&n_, value); + mpx_set_i32(&n_, value); } template <> void MPInt::Set(int64_t value) { - MPINT_ENFORCE_OK(mp_grow(&n_, 2)); - mp_set_i64(&n_, value); + mpx_set_i64(&n_, value); } #ifdef __APPLE__ template <> void MPInt::Set(long value) { // NOLINT: macOS int64_t is ll static_assert(sizeof(long) == 8); - mp_set_i64(&n_, value); + mpx_set_i64(&n_, value); } #endif template <> void MPInt::Set(int128_t value) { MPINT_ENFORCE_OK(mp_grow(&n_, 3)); - mp_set_i128(&n_, value); + mpx_set_i128(&n_, value); } template <> void MPInt::Set(uint8_t value) { - mp_set_u8(&n_, value); + mpx_set_u8(&n_, value); } template <> void MPInt::Set(uint16_t value) { - mp_set_u16(&n_, value); + mpx_set_u16(&n_, value); } template <> void MPInt::Set(uint32_t value) { - mp_set_u32(&n_, value); + mpx_set_u32(&n_, value); } template <> void MPInt::Set(uint64_t value) { - MPINT_ENFORCE_OK(mp_grow(&n_, 2)); - mp_set_u64(&n_, value); + mpx_set_u64(&n_, value); } #ifdef __APPLE__ template <> void MPInt::Set(unsigned long value) { // NOLINT: macOS uint64_t is ull static_assert(sizeof(unsigned long) == 8); - mp_set_u64(&n_, value); + mpx_set_u64(&n_, value); } #endif template <> void MPInt::Set(uint128_t value) { - MPINT_ENFORCE_OK(mp_grow(&n_, 3)); - mp_set_u128(&n_, value); + mpx_set_u128(&n_, value); } template <> void MPInt::Set(float value) { - MPINT_ENFORCE_OK(mp_grow(&n_, 2)); MPINT_ENFORCE_OK(mp_set_double(&n_, value)); } @@ -281,11 +282,11 @@ void MPInt::SetZero() { mp_zero(&n_); } uint8_t MPInt::operator[](int idx) const { return GetBit(idx); } -uint8_t MPInt::GetBit(int idx) const { return mp_ext_get_bit(n_, idx); } +uint8_t MPInt::GetBit(int idx) const { return mpx_get_bit(n_, idx); } -void MPInt::SetBit(int idx, uint8_t bit) { mp_ext_set_bit(&n_, idx, bit); } +void MPInt::SetBit(int idx, uint8_t bit) { mpx_set_bit(&n_, idx, bit); } -size_t MPInt::BitCount() const { return mp_ext_count_bits_fast(n_); } +size_t MPInt::BitCount() const { return mpx_count_bits_fast(n_); } bool MPInt::operator>=(const MPInt &other) const { return Compare(other) >= 0; } bool MPInt::operator<=(const MPInt &other) const { return Compare(other) <= 0; } @@ -421,23 +422,27 @@ std::ostream &operator<<(std::ostream &os, const MPInt &an_int) { } MPInt &MPInt::DecrOne() & { + mpx_reserve(&n_, 1); MPINT_ENFORCE_OK(mp_decr(&n_)); return *this; } MPInt &MPInt::IncrOne() & { + mpx_reserve(&n_, 1); MPINT_ENFORCE_OK(mp_incr(&n_)); return *this; } -MPInt MPInt::DecrOne() && { +MPInt &&MPInt::DecrOne() && { + mpx_reserve(&n_, 1); MPINT_ENFORCE_OK(mp_decr(&n_)); - return *this; + return std::move(*this); } -MPInt MPInt::IncrOne() && { +MPInt &&MPInt::IncrOne() && { + mpx_reserve(&n_, 1); MPINT_ENFORCE_OK(mp_incr(&n_)); - return *this; + return std::move(*this); } MPInt MPInt::Abs() const { @@ -498,11 +503,13 @@ void MPInt::MulMod(const MPInt &a, const MPInt &b, const MPInt &mod, MPInt *d) { } void MPInt::Pow(const MPInt &a, uint32_t b, MPInt *c) { + mpx_reserve(&c->n_, MP_BITS_TO_DIGITS(mpx_count_bits_fast(a.n_) * b)); MPINT_ENFORCE_OK(mp_expt_n(&a.n_, b, &c->n_)); } MPInt MPInt::Pow(uint32_t b) const { MPInt res; + mpx_reserve(&res.n_, mpx_count_bits_fast(n_) * b); MPINT_ENFORCE_OK(mp_expt_n(&n_, b, &res.n_)); return res; } @@ -571,7 +578,7 @@ void MPInt::RandomRoundUp(size_t bit_size, MPInt *r) { } void MPInt::RandomExactBits(size_t bit_size, MPInt *r) { - mp_ext_rand_bits(&r->n_, bit_size); + mpx_rand_bits(&r->n_, bit_size); } void MPInt::RandomMonicExactBits(size_t bit_size, MPInt *r) { @@ -592,7 +599,7 @@ void MPInt::RandPrimeOver(size_t bit_size, MPInt *out, PrimeType prime_type) { int trials = mp_prime_rabin_miller_trials(bit_size); if (prime_type == PrimeType::FastSafe) { - mp_ext_safe_prime_rand(&out->n_, trials, bit_size); + mpx_safe_prime_rand(&out->n_, trials, bit_size); } else { MPINT_ENFORCE_OK(mp_prime_rand(&out->n_, trials, bit_size, static_cast(prime_type))); @@ -623,22 +630,22 @@ std::string MPInt::ToHexString() const { return ToRadixString(16); } std::string MPInt::ToString() const { return ToRadixString(10); } yacl::Buffer MPInt::Serialize() const { - size_t size = mp_ext_serialize_size(n_); + size_t size = mpx_serialize_size(n_); yacl::Buffer buffer(size); - mp_ext_serialize(n_, buffer.data(), size); + mpx_serialize(n_, buffer.data(), size); return buffer; } size_t MPInt::Serialize(uint8_t *buf, size_t buf_len) const { if (buf == nullptr) { - return mp_ext_serialize_size(n_); + return mpx_serialize_size(n_); } - return mp_ext_serialize(n_, buf, buf_len); + return mpx_serialize(n_, buf, buf_len); } void MPInt::Deserialize(yacl::ByteContainerView buffer) { - mp_ext_deserialize(&n_, buffer.data(), buffer.size()); + mpx_deserialize(&n_, buffer.data(), buffer.size()); } yacl::Buffer MPInt::ToBytes(size_t byte_len, Endian endian) const { @@ -648,13 +655,13 @@ yacl::Buffer MPInt::ToBytes(size_t byte_len, Endian endian) const { } void MPInt::ToBytes(unsigned char *buf, size_t buf_len, Endian endian) const { - mp_ext_to_bytes(n_, buf, buf_len, endian); + mpx_to_bytes(n_, buf, buf_len, endian); } yacl::Buffer MPInt::ToMagBytes(Endian endian) const { - size_t size = mp_ext_mag_bytes_size(n_); + size_t size = mpx_mag_bytes_size(n_); yacl::Buffer buffer(size); - mp_ext_to_mag_bytes(n_, buffer.data(), size, endian); + mpx_to_mag_bytes(n_, buffer.data(), size, endian); return buffer; } @@ -664,14 +671,14 @@ yacl::Buffer MPInt::ToMagBytes(Endian endian) const { size_t MPInt::ToMagBytes(unsigned char *buf, size_t buf_len, Endian endian) const { if (buf == nullptr) { - return mp_ext_mag_bytes_size(n_); + return mpx_mag_bytes_size(n_); } - return mp_ext_to_mag_bytes(n_, buf, buf_len, endian); + return mpx_to_mag_bytes(n_, buf, buf_len, endian); } void MPInt::FromMagBytes(yacl::ByteContainerView buffer, Endian endian) { - mp_ext_from_mag_bytes(&n_, buffer.data(), buffer.size(), endian); + mpx_from_mag_bytes(&n_, buffer.data(), buffer.size(), endian); } } // namespace yacl::math diff --git a/yacl/math/mpint/mp_int.h b/yacl/math/mpint/mp_int.h index fc67db63..6d25633e 100644 --- a/yacl/math/mpint/mp_int.h +++ b/yacl/math/mpint/mp_int.h @@ -180,8 +180,8 @@ class MPInt { MPInt &DecrOne() &; MPInt &IncrOne() &; - [[nodiscard]] MPInt DecrOne() &&; - [[nodiscard]] MPInt IncrOne() &&; + [[nodiscard]] MPInt &&DecrOne() &&; + [[nodiscard]] MPInt &&IncrOne() &&; [[nodiscard]] MPInt Abs() const; @@ -392,6 +392,7 @@ class MPInt { friend class MontgomerySpace; }; +// for fmtlib inline auto format_as(const MPInt &i) { return fmt::streamed(i); } } // namespace yacl::math diff --git a/yacl/math/mpint/mp_int_enforce.h b/yacl/math/mpint/mp_int_enforce.h index 31d0634e..773c46af 100644 --- a/yacl/math/mpint/mp_int_enforce.h +++ b/yacl/math/mpint/mp_int_enforce.h @@ -17,10 +17,12 @@ #include "fmt/ostream.h" #include "libtommath/tommath.h" +#include "yacl/base/exception.h" + namespace fmt { template <> struct formatter : ostream_formatter {}; } // namespace fmt #define MPINT_ENFORCE_OK(MP_ERR, ...) \ - YACL_ENFORCE_EQ((MP_ERR), MP_OKAY, __VA_ARGS__) + YACL_ENFORCE_EQ((MP_ERR), MP_OKAY, ##__VA_ARGS__) diff --git a/yacl/math/mpint/mp_int_test.cc b/yacl/math/mpint/mp_int_test.cc index e2f38ca2..eb540bfd 100644 --- a/yacl/math/mpint/mp_int_test.cc +++ b/yacl/math/mpint/mp_int_test.cc @@ -29,6 +29,35 @@ TEST_F(MPIntTest, CompareWorks) { EXPECT_TRUE(x1 <= x2); EXPECT_TRUE(x2 > x1); EXPECT_TRUE(x2 >= x1); + + EXPECT_EQ(x1.CompareAbs(-256_mp), 0); + EXPECT_EQ(x1.CompareAbs(-257_mp), -1); + EXPECT_EQ(x1.CompareAbs(-255_mp), 1); +} + +TEST_F(MPIntTest, BitOpsWorks) { + MPInt x; + x.IncrOne(); + ASSERT_TRUE(x.IsOne()); + ASSERT_EQ(x << 2, 4_mp); + x <<= 3; + ASSERT_EQ(x, 8_mp); + + EXPECT_EQ(x >> 1, 4_mp); + x >>= 3; // 8 >> 3 + ASSERT_EQ(x, MPInt::_1_); + x >>= 2; + ASSERT_TRUE(x.IsZero()); + + x.Set("0011", 2); + MPInt y("0101", 2); + EXPECT_EQ(x & y, MPInt(0b0101 & 0b0011)); + EXPECT_EQ(x | y, MPInt(0b0101 | 0b0011)); + EXPECT_EQ(x ^ y, MPInt(0b0101 ^ 0b0011)); + + ASSERT_EQ(x &= y, MPInt(0b0001)); + ASSERT_EQ(x |= y, MPInt(0b0101)); + ASSERT_EQ(x ^= y, MPInt(0)); } TEST_F(MPIntTest, ArithmeticWorks) { @@ -44,10 +73,6 @@ TEST_F(MPIntTest, ArithmeticWorks) { EXPECT_TRUE(x1 / x2 == MPInt(23 / 37)); EXPECT_TRUE(x2 / x1 == MPInt(37 / 23)); - EXPECT_EQ(x1.AddMod(x2, MPInt(5)), MPInt((23 + 37) % 5)); - EXPECT_EQ(x2.SubMod(x1, MPInt(5)), MPInt((37 - 23) % 5)); - EXPECT_EQ(x1.MulMod(x2, MPInt(5)), MPInt((23 * 37) % 5)); - MPInt c; MPInt::Add(x1, x2, &c); EXPECT_TRUE(c == MPInt(23 + 37)); @@ -55,8 +80,32 @@ TEST_F(MPIntTest, ArithmeticWorks) { EXPECT_TRUE(c == MPInt(23 - 37)); MPInt::Mul(x1, x2, &c); EXPECT_TRUE(c == MPInt(23 * 37)); - EXPECT_TRUE(x1.Mul(3) == MPInt(23 * 3)); + MPInt::Div3(x1, &c); + EXPECT_EQ(c, MPInt(23 / 3)); + + EXPECT_EQ(x1.AddMod(x2, MPInt(5)), MPInt((23 + 37) % 5)); + EXPECT_EQ(x2.SubMod(x1, MPInt(5)), MPInt((37 - 23) % 5)); + EXPECT_EQ(x1.MulMod(x2, MPInt(5)), MPInt((23 * 37) % 5)); + + MPInt::AddMod(x1, x2, 7_mp, &c); + EXPECT_EQ(c, MPInt((23 + 37) % 7)); + MPInt::SubMod(x1, x2, 7_mp, &c); + EXPECT_EQ(c, MPInt((37 - 23) % 7)); + MPInt::MulMod(x1, x2, 7_mp, &c); + EXPECT_EQ(c, MPInt((23 * 37) % 7)); + + // Test inplace version + x1.Set(1234); + ASSERT_EQ(x1 += 10_mp, 1244_mp); + ASSERT_EQ(x1 -= 10_mp, 1234_mp); + ASSERT_EQ(x1 *= 10_mp, 12340_mp); + ASSERT_EQ(x1 /= 10_mp, 1234_mp); + ASSERT_EQ(x1 %= 10_mp, 4_mp); + x1.IncrOne(); + ASSERT_EQ(x1, 5_mp); + + x1 = 23_mp; x1.MulInplace(3); EXPECT_TRUE(x1 == MPInt(23 * 3)); } @@ -72,22 +121,74 @@ TEST_F(MPIntTest, PowWorks) { MPInt::_2_.Pow(255) - 19_mp, "0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed"_mp); + MPInt out; + MPInt::Pow(MPInt::_2_, 1111, &out); + EXPECT_EQ(out, 1_mp << 1111); + + out.Set(123); + out.PowInplace(3); + EXPECT_EQ(out, MPInt(123 * 123 * 123)); + // PowMod EXPECT_EQ(MPInt::_2_.PowMod(0_mp, 5_mp), MPInt::_1_); EXPECT_EQ(MPInt::_2_.PowMod(1_mp, 5_mp), MPInt::_2_); EXPECT_EQ(MPInt::_2_.PowMod(2_mp, 5_mp), 4_mp); EXPECT_EQ(MPInt::_2_.PowMod(3_mp, 5_mp), 3_mp); + + MPInt::PowMod(-324_mp, MPInt::_2_, 13_mp, &out); + EXPECT_EQ(out.Get(), 324 * 324 % 13); +} + +TEST_F(MPIntTest, InvertModWorks) { + MPInt a(667); + MPInt::InvertMod(a, MPInt(561613), &a); + EXPECT_EQ(842, a.Get()); + EXPECT_EQ(842, a.Get()); + EXPECT_EQ(842, a.Get()); + EXPECT_EQ(842, a.Get()); + EXPECT_EQ(842_mp, a.Get()); +} + +TEST_F(MPIntTest, LcmGcdWorks) { + MPInt x1, x2, x3; + MPInt::RandPrimeOver(82, &x1, PrimeType::Normal); + MPInt::RandPrimeOver(100, &x2, PrimeType::Normal); + MPInt::RandPrimeOver(150, &x3, PrimeType::Normal); + + MPInt out; + MPInt::Gcd(x1, x1, &out); + EXPECT_EQ(out, x1); + MPInt::Gcd(x1, x2, &out); + EXPECT_EQ(out, 1_mp); + MPInt::Gcd(x1 * x2, x2 * x3, &out); + EXPECT_EQ(out, x2); + + MPInt::Lcm(x1, x1, &out); + EXPECT_EQ(out, x1); + MPInt::Lcm(x1, x2, &out); + EXPECT_EQ(out, x1 * x2); + MPInt::Lcm(x1 * x2, x2 * x3, &out); + EXPECT_EQ(out, x1 * x2 * x3); } TEST_F(MPIntTest, CtorZeroWorks) { MPInt x; EXPECT_TRUE(x.IsZero()); + EXPECT_EQ(x.Get(), 0); + EXPECT_EQ(x.Get(), 0); MPInt x2(0); EXPECT_TRUE(x2.IsZero()); + EXPECT_EQ(x.Get(), 0); + EXPECT_EQ(x.Get(), 0); MPInt x3(0, 2048); EXPECT_TRUE(x3.IsZero()); + EXPECT_EQ(x.Get(), 0); + EXPECT_EQ(x.Get(), 0); + + EXPECT_EQ(MPInt().IncrOne(), MPInt::_1_); + EXPECT_EQ(MPInt().DecrOne(), -MPInt::_1_); } TEST_F(MPIntTest, CtorWorks) { @@ -99,8 +200,8 @@ TEST_F(MPIntTest, CtorWorks) { EXPECT_TRUE(x1.Compare(x2) < 0); EXPECT_EQ(MPInt("0").Get(), 0); - EXPECT_EQ(MPInt("0777").Get(), 511); - EXPECT_EQ(MPInt("520").Get(), 520); + EXPECT_EQ(MPInt("0777").Get(), 511); + EXPECT_EQ(MPInt("520").Get(), 520); EXPECT_EQ(MPInt("0xabc").Get(), 2748); EXPECT_EQ(MPInt("0xABC").Get(), 2748); EXPECT_EQ(MPInt("0Xabc").Get(), 2748); @@ -118,10 +219,35 @@ TEST_F(MPIntTest, CtorWorks) { EXPECT_EQ(MPInt("+0Xabc").Get(), 2748); } -TEST_F(MPIntTest, InvertModWorks) { - MPInt a(667); - MPInt::InvertMod(a, MPInt(561613), &a); - EXPECT_EQ(842, a.Get()); +TEST_F(MPIntTest, SetWorks) { + MPInt a; + a.Set(static_cast(-100)); + EXPECT_EQ(a, -100_mp); + a.Set(static_cast(-1000)); + EXPECT_EQ(a, -1000_mp); + a.Set(static_cast(-10000)); + EXPECT_EQ(a, -10000_mp); + a.Set(static_cast(-100000)); + EXPECT_EQ(a, -100000_mp); + a.Set(static_cast(-1000000)); + EXPECT_EQ(a, -1000000_mp); + // for macOS + a.Set(static_cast(-123456)); + EXPECT_EQ(a, -123456_mp); + + a.Set(static_cast(100)); + EXPECT_EQ(a, 100_mp); + a.Set(static_cast(1000)); + EXPECT_EQ(a, 1000_mp); + a.Set(static_cast(10000)); + EXPECT_EQ(a, 10000_mp); + a.Set(static_cast(100000)); + EXPECT_EQ(a, 100000_mp); + a.Set(static_cast(1000000)); + EXPECT_EQ(a, 1000000_mp); + // for macOS + a.Set(static_cast(123456)); + EXPECT_EQ(a, 123456_mp); } TEST_F(MPIntTest, ToStringWorks) { @@ -176,6 +302,7 @@ TEST_F(MPIntTest, MagBytesWorks) { MPInt x2(-1234567890); yacl::Buffer x2_buf = x2.ToMagBytes(); ASSERT_TRUE(x2_buf.size() > 0); + ASSERT_EQ(x2_buf.size(), x2.ToMagBytes(nullptr, 0)); MPInt x2_value; x2_value.FromMagBytes(x2_buf); @@ -192,7 +319,7 @@ TEST_F(MPIntTest, MagBytesWorks) { EXPECT_EQ(buf.data()[0], 0x34); EXPECT_EQ(buf.data()[1], 0x12); - buf = a.ToMagBytes(Endian::big); + a.ToMagBytes(buf.data(), buf.size(), Endian::big); ASSERT_EQ(buf.size(), 2); EXPECT_EQ(buf.data()[0], 0x12); EXPECT_EQ(buf.data()[1], 0x34); diff --git a/yacl/math/mpint/tommath_ext_features.cc b/yacl/math/mpint/tommath_ext_features.cc index a078c69b..cd9c0195 100644 --- a/yacl/math/mpint/tommath_ext_features.cc +++ b/yacl/math/mpint/tommath_ext_features.cc @@ -131,7 +131,7 @@ bool is_pocklington_criterion_satisfied(const mp_int *p) { // Miller-Rabin and Baillie-PSW for `p`. // If `q` and `p` are found to be prime, return them as a result. If not, go // back to the point 1. -void mp_ext_safe_prime_rand(mp_int *p, int t, int psize) { +void mpx_safe_prime_rand(mp_int *p, int t, int psize) { uint8_t maskAND, maskOR_msb, maskOR_lsb; int maskOR_msb_offset; bool res; @@ -179,7 +179,7 @@ void mp_ext_safe_prime_rand(mp_int *p, int t, int psize) { /* read it in */ /* TODO: casting only for now until all lengths have been changed to the * type "size_t"*/ - mp_ext_from_mag_bytes(&q, tmp, (size_t)bsize, Endian::big); + mpx_from_mag_bytes(&q, tmp, (size_t)bsize, Endian::big); // Find a odd number `q` among q, q+2, .... , (1 << 20) satisfy: // 1. co-prime to `small_primes`. @@ -224,7 +224,7 @@ void mp_ext_safe_prime_rand(mp_int *p, int t, int psize) { } while (true); } -void mp_ext_rand_bits(mp_int *out, int64_t bits) { +void mpx_rand_bits(mp_int *out, int64_t bits) { if (bits <= 0) { mp_zero(out); return; @@ -286,7 +286,7 @@ int count_bits_debruijn(uint64_t v) { return bitPatternToLog2[(v * 0x6c04f118e9966f6bULL) >> 57]; } -int mp_ext_count_bits_fast(const mp_int &a) { +int mpx_count_bits_fast(const mp_int &a) { if (a.used == 0) { return 0; } @@ -294,8 +294,8 @@ int mp_ext_count_bits_fast(const mp_int &a) { return (a.used - 1) * MP_DIGIT_BIT + count_bits_debruijn(a.dp[a.used - 1]); } -void mp_ext_to_bytes(const mp_int &num, unsigned char *buf, int64_t byte_len, - Endian endian) { +void mpx_to_bytes(const mp_int &num, unsigned char *buf, int64_t byte_len, + Endian endian) { YACL_ENFORCE(MP_DIGIT_BIT % 4 == 0, "Unsupported MP_DIGIT_BIT {}", MP_DIGIT_BIT); @@ -340,18 +340,18 @@ void mp_ext_to_bytes(const mp_int &num, unsigned char *buf, int64_t byte_len, } } -size_t mp_ext_mag_bytes_size(const mp_int &num) { - return (mp_ext_count_bits_fast(num) + CHAR_BIT - 1) / CHAR_BIT; +size_t mpx_mag_bytes_size(const mp_int &num) { + return (mpx_count_bits_fast(num) + CHAR_BIT - 1) / CHAR_BIT; } -size_t mp_ext_to_mag_bytes(const mp_int &num, uint8_t *buf, size_t buf_len, - Endian endian) { +size_t mpx_to_mag_bytes(const mp_int &num, uint8_t *buf, size_t buf_len, + Endian endian) { static_assert(MP_DIGIT_BIT % 4 == 0, "Unsupported MP_DIGIT_BIT"); if (num.used == 0) { return 0; } - auto min_bytes = mp_ext_mag_bytes_size(num); + auto min_bytes = mpx_mag_bytes_size(num); YACL_ENFORCE(buf_len >= min_bytes, "buf is too small to store mp_int, buf_size={}, required={}", buf_len, min_bytes); @@ -391,8 +391,8 @@ size_t mp_ext_to_mag_bytes(const mp_int &num, uint8_t *buf, size_t buf_len, return pos; } -void mp_ext_from_mag_bytes(mp_int *num, const uint8_t *buf, size_t buf_len, - Endian endian) { +void mpx_from_mag_bytes(mp_int *num, const uint8_t *buf, size_t buf_len, + Endian endian) { if (buf_len == 0) { mp_zero(num); } @@ -441,19 +441,19 @@ void mp_ext_from_mag_bytes(mp_int *num, const uint8_t *buf, size_t buf_len, // │ // sign bit // D = data/payload; S = sign bit -size_t mp_ext_serialize_size(const mp_int &num) { - return mp_ext_count_bits_fast(num) / CHAR_BIT + 1; +size_t mpx_serialize_size(const mp_int &num) { + return mpx_count_bits_fast(num) / CHAR_BIT + 1; } -size_t mp_ext_serialize(const mp_int &num, uint8_t *buf, size_t buf_len) { - auto total_buf = mp_ext_serialize_size(num); +size_t mpx_serialize(const mp_int &num, uint8_t *buf, size_t buf_len) { + auto total_buf = mpx_serialize_size(num); YACL_ENFORCE(buf_len >= total_buf, "buf is too small, min required={}, actual={}", total_buf, buf_len); // store num in Little-Endian buf[total_buf - 1] = 0; - auto value_buf = mp_ext_to_mag_bytes(num, buf, buf_len, Endian::little); + auto value_buf = mpx_to_mag_bytes(num, buf, buf_len, Endian::little); YACL_ENFORCE(total_buf == value_buf || total_buf == value_buf + 1, "bug: buf len mismatch, {} vs {}", total_buf, value_buf); // write sign bit @@ -462,15 +462,15 @@ size_t mp_ext_serialize(const mp_int &num, uint8_t *buf, size_t buf_len) { return total_buf; } -void mp_ext_deserialize(mp_int *num, const uint8_t *buf, size_t buf_len) { +void mpx_deserialize(mp_int *num, const uint8_t *buf, size_t buf_len) { YACL_ENFORCE(buf_len > 0, "mp_int deserialize: empty buffer"); // since buf is const, we cannot clear the sign bit - mp_ext_from_mag_bytes(num, buf, buf_len, Endian::little); + mpx_from_mag_bytes(num, buf, buf_len, Endian::little); num->sign = ((buf[buf_len - 1] >> 7) == 1 ? MP_NEG : MP_ZPOS); - mp_ext_set_bit(num, buf_len * CHAR_BIT - 1, 0); // clear sign bit + mpx_set_bit(num, buf_len * CHAR_BIT - 1, 0); // clear sign bit } -uint8_t mp_ext_get_bit(const mp_int &a, int index) { +uint8_t mpx_get_bit(const mp_int &a, int index) { int limb = index / MP_DIGIT_BIT; if (limb >= a.used) { return 0; @@ -479,7 +479,7 @@ uint8_t mp_ext_get_bit(const mp_int &a, int index) { return (a.dp[limb] >> (index % MP_DIGIT_BIT)) & 1; } -void mp_ext_set_bit(mp_int *a, int index, uint8_t value) { +void mpx_set_bit(mp_int *a, int index, uint8_t value) { int limb = index / MP_DIGIT_BIT; if (limb >= a->alloc) { MPINT_ENFORCE_OK(mp_grow(a, limb + 1)); diff --git a/yacl/math/mpint/tommath_ext_features.h b/yacl/math/mpint/tommath_ext_features.h index cdba5963..a50d4ef7 100644 --- a/yacl/math/mpint/tommath_ext_features.h +++ b/yacl/math/mpint/tommath_ext_features.h @@ -22,30 +22,30 @@ namespace yacl::math { // Reference: https://eprint.iacr.org/2003/186.pdf // libtommath style -void mp_ext_safe_prime_rand(mp_int *out, int t, int size); +void mpx_safe_prime_rand(mp_int *out, int t, int size); -void mp_ext_rand_bits(mp_int *out, int64_t bits); +void mpx_rand_bits(mp_int *out, int64_t bits); // Convert num to bytes and output to buf -void mp_ext_to_bytes(const mp_int &num, unsigned char *buf, int64_t byte_len, - Endian endian = Endian::native); +void mpx_to_bytes(const mp_int &num, unsigned char *buf, int64_t byte_len, + Endian endian = Endian::native); -size_t mp_ext_mag_bytes_size(const mp_int &num); -size_t mp_ext_to_mag_bytes(const mp_int &num, uint8_t *buf, size_t buf_len, - Endian endian = Endian::native); -void mp_ext_from_mag_bytes(mp_int *num, const uint8_t *buf, size_t buf_len, - Endian endian = Endian::native); +size_t mpx_mag_bytes_size(const mp_int &num); +size_t mpx_to_mag_bytes(const mp_int &num, uint8_t *buf, size_t buf_len, + Endian endian = Endian::native); +void mpx_from_mag_bytes(mp_int *num, const uint8_t *buf, size_t buf_len, + Endian endian = Endian::native); // returns the number of bits in an int // Faster than tommath's native mp_count_bits() method -int mp_ext_count_bits_fast(const mp_int &a); +int mpx_count_bits_fast(const mp_int &a); -size_t mp_ext_serialize_size(const mp_int &num); -size_t mp_ext_serialize(const mp_int &num, uint8_t *buf, size_t buf_len); -void mp_ext_deserialize(mp_int *num, const uint8_t *buf, size_t buf_len); +size_t mpx_serialize_size(const mp_int &num); +size_t mpx_serialize(const mp_int &num, uint8_t *buf, size_t buf_len); +void mpx_deserialize(mp_int *num, const uint8_t *buf, size_t buf_len); // return 0 or 1 -uint8_t mp_ext_get_bit(const mp_int &a, int index); -void mp_ext_set_bit(mp_int *a, int index, uint8_t value); +uint8_t mpx_get_bit(const mp_int &a, int index); +void mpx_set_bit(mp_int *a, int index, uint8_t value); } // namespace yacl::math diff --git a/yacl/math/mpint/tommath_ext_test.cc b/yacl/math/mpint/tommath_ext_test.cc index 05112041..c1369b8c 100644 --- a/yacl/math/mpint/tommath_ext_test.cc +++ b/yacl/math/mpint/tommath_ext_test.cc @@ -40,20 +40,20 @@ TEST(TommathExtTest, CountBits) { mp_int n; MP_ASSERT_OK(mp_init_i32(&n, 0)); ON_SCOPE_EXIT([&] { mp_clear(&n); }); - EXPECT_EQ(mp_ext_count_bits_fast(n), 0); + EXPECT_EQ(mpx_count_bits_fast(n), 0); MP_ASSERT_OK(mp_incr(&n)); - EXPECT_EQ(mp_ext_count_bits_fast(n), 1); + EXPECT_EQ(mpx_count_bits_fast(n), 1); for (int i = 0; i < 4096; ++i) { MP_ASSERT_OK(mp_mul_2(&n, &n)); - EXPECT_EQ(mp_ext_count_bits_fast(n), mp_count_bits(&n)); + EXPECT_EQ(mpx_count_bits_fast(n), mp_count_bits(&n)); } mp_zero(&n); MP_ASSERT_OK(mp_incr(&n)); for (int i = 0; i < 128; ++i) { - EXPECT_EQ(mp_ext_count_bits_fast(n), i + 1) << Info(n); + EXPECT_EQ(mpx_count_bits_fast(n), i + 1) << Info(n); MP_ASSERT_OK(mp_mul_2(&n, &n)); if (i % 2 == 1) { MP_ASSERT_OK(mp_incr(&n)); @@ -65,7 +65,7 @@ TEST(TommathExtTest, CountBitsRandom) { mp_int n; MP_ASSERT_OK(mp_init_i64(&n, 0)); ON_SCOPE_EXIT([&] { mp_clear(&n); }); - EXPECT_EQ(mp_ext_count_bits_fast(n), 0); + EXPECT_EQ(mpx_count_bits_fast(n), 0); std::random_device rd; std::mt19937_64 gen(rd()); @@ -79,7 +79,7 @@ TEST(TommathExtTest, CountBitsRandom) { a >>= 1; } - EXPECT_EQ(mp_ext_count_bits_fast(n), bits) << Info(n); + EXPECT_EQ(mpx_count_bits_fast(n), bits) << Info(n); } for (int64_t i = 0; i < 1000000; ++i) { @@ -91,7 +91,7 @@ TEST(TommathExtTest, CountBitsRandom) { a >>= 1; } - EXPECT_EQ(mp_ext_count_bits_fast(n), bits) << Info(n); + EXPECT_EQ(mpx_count_bits_fast(n), bits) << Info(n); } } @@ -106,14 +106,14 @@ TEST(TommathExtTest, Serialize) { for (int64_t bits = 0; bits < 4097; ++bits) { for (int64_t i = 0; i < 100; ++i) { - mp_ext_rand_bits(&a, bits); + mpx_rand_bits(&a, bits); if (i % 2 == 0) { MP_ASSERT_OK(mp_neg(&a, &a)); } - auto sz = mp_ext_serialize_size(a); - mp_ext_serialize(a, buf, sz); + auto sz = mpx_serialize_size(a); + mpx_serialize(a, buf, sz); - mp_ext_deserialize(&b, buf, sz); + mpx_deserialize(&b, buf, sz); ASSERT_EQ(mp_cmp(&a, &b), 0) << "a is " << Info(a) << "\nb is " << Info(b); } @@ -132,25 +132,25 @@ TEST(TommathExtTest, GetBit) { int idx = 0; while (s != 0) { - EXPECT_EQ(s & 1, mp_ext_get_bit(n, idx)); - mp_ext_set_bit(&new_n, idx, s & 1); + EXPECT_EQ(s & 1, mpx_get_bit(n, idx)); + mpx_set_bit(&new_n, idx, s & 1); s >>= 1; ++idx; } EXPECT_TRUE(mp_cmp(&n, &new_n) == 0); - mp_ext_set_bit(&new_n, 666, 0); + mpx_set_bit(&new_n, 666, 0); EXPECT_TRUE(mp_cmp(&n, &new_n) == 0); - mp_ext_set_bit(&new_n, 1000, 1); - EXPECT_EQ(mp_ext_get_bit(new_n, 999), 0); - EXPECT_EQ(mp_ext_get_bit(new_n, 1000), 1); - EXPECT_EQ(mp_ext_get_bit(new_n, 1001), 0); - EXPECT_EQ(mp_ext_count_bits_fast(new_n), 1001); + mpx_set_bit(&new_n, 1000, 1); + EXPECT_EQ(mpx_get_bit(new_n, 999), 0); + EXPECT_EQ(mpx_get_bit(new_n, 1000), 1); + EXPECT_EQ(mpx_get_bit(new_n, 1001), 0); + EXPECT_EQ(mpx_count_bits_fast(new_n), 1001); - mp_ext_set_bit(&new_n, 1000, 0); - EXPECT_EQ(mp_ext_get_bit(new_n, 1000), 0); - EXPECT_LE(mp_ext_count_bits_fast(new_n), 64); + mpx_set_bit(&new_n, 1000, 0); + EXPECT_EQ(mpx_get_bit(new_n, 1000), 0); + EXPECT_LE(mpx_count_bits_fast(new_n), 64); } TEST(TommathExtTest, MpDivd) { @@ -168,7 +168,7 @@ TEST(TommathExtTest, MpDivd) { ON_SCOPE_EXIT([&] { mp_clear(&mp_res); }); for (uint64_t i = 1000; i < 100000; ++i) { - mp_ext_rand_bits(&a, i / 10); + mpx_rand_bits(&a, i / 10); mp_digit res; MP_ASSERT_OK(mp_div_d(&a, d, nullptr, &res)); diff --git a/yacl/math/mpint/tommath_ext_types.cc b/yacl/math/mpint/tommath_ext_types.cc index d91e65da..0d4a33cd 100644 --- a/yacl/math/mpint/tommath_ext_types.cc +++ b/yacl/math/mpint/tommath_ext_types.cc @@ -15,32 +15,46 @@ #include "yacl/math/mpint/tommath_ext_types.h" #include -#include // memset + +#include "yacl/math/mpint/mp_int_enforce.h" + +extern "C" { +#include "libtommath/tommath_private.h" +} // Following macros are copied from tommath_private.h #define MP_MIN(x, y) (((x) < (y)) ? (x) : (y)) #define MP_MAX(x, y) (((x) > (y)) ? (x) : (y)) #define MP_SIZEOF_BITS(type) ((size_t)CHAR_BIT * sizeof(type)) -#define MP_ZERO_DIGITS(mem, digits) \ - do { \ - int zd_ = (digits); \ - if (zd_ > 0) { \ - memset((mem), 0, sizeof(mp_digit) * (size_t)zd_); \ - } \ - } while (0) - -#define MP_INIT_INT(name, set, type) \ - mp_err name(mp_int *a, type b) { \ - mp_err err; \ - if ((err = mp_init(a)) != MP_OKAY) { \ - return err; \ - } \ - set(a, b); \ - return MP_OKAY; \ + +void mpx_init(mp_int *a) { + a->dp = nullptr; + a->used = 0; + a->alloc = 0; + a->sign = MP_ZPOS; +} + +void mpx_reserve(mp_int *a, size_t n_digits) { + if (a->dp == nullptr) { + a->dp = static_cast(MP_CALLOC(n_digits, sizeof(mp_digit))); + YACL_ENFORCE(a->dp != nullptr); + a->alloc = n_digits; + return; } -#define MP_SET_UNSIGNED(name, type) \ + MPINT_ENFORCE_OK(mp_grow(a, 1)); +} + +#define MPX_INIT_INT(name, set, type) \ + mp_err name(mp_int *a, type b) { \ + mpx_init(a); \ + set(a, b); \ + return MP_OKAY; \ + } + +#define MPX_SET_UNSIGNED(name, type) \ void name(mp_int *a, type b) { \ + MPINT_ENFORCE_OK(mp_grow(a, MP_BYTES_TO_DIGITS(sizeof(type)))); \ int i = 0; \ while (b != 0u) { \ a->dp[i++] = ((mp_digit)b & MP_MASK); \ @@ -51,18 +65,18 @@ } \ a->used = i; \ a->sign = MP_ZPOS; \ - MP_ZERO_DIGITS(a->dp + a->used, a->alloc - a->used); \ + s_mp_zero_digs(a->dp + a->used, a->alloc - a->used); \ } -#define MP_SET_SIGNED(name, uname, type, utype) \ - void name(mp_int *a, type b) { \ - uname(a, (b < 0) ? -(utype)b : (utype)b); \ - if (b < 0) { \ - a->sign = MP_NEG; \ - } \ +#define MPX_SET_SIGNED(name, uname, type, utype) \ + void name(mp_int *a, type b) { \ + uname(a, (b < 0) ? -(utype)b : (utype)b); \ + if (b < 0) { \ + a->sign = MP_NEG; \ + } \ } -#define MP_GET_MAG(name, type) \ +#define MPX_GET_MAG(name, type) \ type name(const mp_int *a) { \ unsigned i = MP_MIN( \ (unsigned)a->used, \ @@ -78,28 +92,40 @@ return res; \ } -#define MP_GET_SIGNED(name, mag, type, utype) \ +#define MPX_GET_SIGNED(name, mag, type, utype) \ type name(const mp_int *a) { \ utype res = mag(a); \ return (a->sign == MP_NEG) ? -(type)res : (type)res; \ } // define int8 related functions. -MP_SET_UNSIGNED(mp_set_u8, uint8_t) -MP_SET_SIGNED(mp_set_i8, mp_set_u8, int8_t, uint8_t) -MP_GET_MAG(mp_get_mag_u8, uint8_t) -MP_GET_SIGNED(mp_get_i8, mp_get_mag_u8, int8_t, uint8_t) +MPX_SET_UNSIGNED(mpx_set_u8, uint8_t) +MPX_SET_SIGNED(mpx_set_i8, mpx_set_u8, int8_t, uint8_t) +MPX_GET_MAG(mpx_get_mag_u8, uint8_t) +MPX_GET_SIGNED(mpx_get_i8, mpx_get_mag_u8, int8_t, uint8_t) // define int16 related functions. -MP_SET_UNSIGNED(mp_set_u16, uint16_t) -MP_SET_SIGNED(mp_set_i16, mp_set_u16, int16_t, uint16_t) -MP_GET_MAG(mp_get_mag_u16, uint16_t) -MP_GET_SIGNED(mp_get_i16, mp_get_mag_u16, int16_t, uint16_t) +MPX_SET_UNSIGNED(mpx_set_u16, uint16_t) +MPX_SET_SIGNED(mpx_set_i16, mpx_set_u16, int16_t, uint16_t) +MPX_GET_MAG(mpx_get_mag_u16, uint16_t) +MPX_GET_SIGNED(mpx_get_i16, mpx_get_mag_u16, int16_t, uint16_t) + +// define int32 related functions. +MPX_SET_UNSIGNED(mpx_set_u32, uint32_t) +MPX_SET_SIGNED(mpx_set_i32, mpx_set_u32, int32_t, uint32_t) +MPX_GET_MAG(mpx_get_mag_u32, uint32_t) +MPX_GET_SIGNED(mpx_get_i32, mpx_get_mag_u32, int32_t, uint32_t) + +// define int64 related functions. +MPX_SET_UNSIGNED(mpx_set_u64, uint64_t) +MPX_SET_SIGNED(mpx_set_i64, mpx_set_u64, int64_t, uint64_t) +MPX_GET_MAG(mpx_get_mag_u64, uint64_t) +MPX_GET_SIGNED(mpx_get_i64, mpx_get_mag_u64, int64_t, uint64_t) // define int128 related functions. -MP_INIT_INT(mp_init_i128, mp_set_i128, int128_t) -MP_INIT_INT(mp_init_u128, mp_set_u128, uint128_t) -MP_SET_UNSIGNED(mp_set_u128, uint128_t) -MP_SET_SIGNED(mp_set_i128, mp_set_u128, int128_t, uint128_t) -MP_GET_MAG(mp_get_mag_u128, uint128_t) -MP_GET_SIGNED(mp_get_i128, mp_get_mag_u128, int128_t, uint128_t) +MPX_INIT_INT(mpx_init_i128, mpx_set_i128, int128_t) +MPX_INIT_INT(mpx_init_u128, mpx_set_u128, uint128_t) +MPX_SET_UNSIGNED(mpx_set_u128, uint128_t) +MPX_SET_SIGNED(mpx_set_i128, mpx_set_u128, int128_t, uint128_t) +MPX_GET_MAG(mpx_get_mag_u128, uint128_t) +MPX_GET_SIGNED(mpx_get_i128, mpx_get_mag_u128, int128_t, uint128_t) diff --git a/yacl/math/mpint/tommath_ext_types.h b/yacl/math/mpint/tommath_ext_types.h index 5ea98f96..d3a0fad4 100644 --- a/yacl/math/mpint/tommath_ext_types.h +++ b/yacl/math/mpint/tommath_ext_types.h @@ -18,29 +18,46 @@ #include "yacl/base/int128.h" -// define int8 related functions. -void mp_set_u8(mp_int *a, uint8_t b); -void mp_set_i8(mp_int *a, int8_t b); +#define MP_BITS_TO_DIGITS(bits) ((bits) + MP_DIGIT_BIT - 1) / MP_DIGIT_BIT +#define MP_BYTES_TO_DIGITS(bytes) MP_BITS_TO_DIGITS((bytes)*CHAR_BIT) -uint8_t mp_get_mag_u8(const mp_int *a); -int8_t mp_get_i8(const mp_int *a); -#define mp_get_u8(a) ((uint8_t)mp_get_i8(a)) +void mpx_init(mp_int *a); +void mpx_reserve(mp_int *a, size_t n_digits); -// define int16 related functions. -void mp_set_u16(mp_int *a, uint16_t b); -void mp_set_i16(mp_int *a, int16_t b); +// define int8 related functions. +void mpx_set_u8(mp_int *a, uint8_t b); +void mpx_set_i8(mp_int *a, int8_t b); +uint8_t mpx_get_mag_u8(const mp_int *a); +int8_t mpx_get_i8(const mp_int *a); +#define mpx_get_u8(a) ((uint8_t)mpx_get_i8(a)) -uint16_t mp_get_mag_u16(const mp_int *a); -int16_t mp_get_i16(const mp_int *a); -#define mp_get_u16(a) ((uint16_t)mp_get_i16(a)) +// define int16 related functions. +void mpx_set_u16(mp_int *a, uint16_t b); +void mpx_set_i16(mp_int *a, int16_t b); +uint16_t mpx_get_mag_u16(const mp_int *a); +int16_t mpx_get_i16(const mp_int *a); +#define mpx_get_u16(a) ((uint16_t)mpx_get_i16(a)) + +// define int32 related functions. +void mpx_set_u32(mp_int *a, uint32_t b); +void mpx_set_i32(mp_int *a, int32_t b); +uint32_t mpx_get_mag_u32(const mp_int *a); +int32_t mpx_get_i32(const mp_int *a); +#define mpx_get_u32(a) ((uint32_t)mpx_get_i32(a)) + +// define int64 related functions. +void mpx_set_u64(mp_int *a, uint64_t b); +void mpx_set_i64(mp_int *a, int64_t b); +uint64_t mpx_get_mag_u64(const mp_int *a); +int64_t mpx_get_i64(const mp_int *a); +#define mpx_get_u64(a) ((uint64_t)mpx_get_i64(a)) // define int128 related functions. -mp_err mp_init_i128(mp_int *a, int128_t b) MP_WUR; -mp_err mp_init_u128(mp_int *a, uint128_t b) MP_WUR; - -void mp_set_u128(mp_int *a, uint128_t b); -void mp_set_i128(mp_int *a, int128_t b); - -uint128_t mp_get_mag_u128(const mp_int *a); -int128_t mp_get_i128(const mp_int *a); -#define mp_get_u128(a) ((uint128_t)mp_get_i128(a)) +mp_err mpx_init_i128(mp_int *a, int128_t b) MP_WUR; +mp_err mpx_init_u128(mp_int *a, uint128_t b) MP_WUR; + +void mpx_set_u128(mp_int *a, uint128_t b); +void mpx_set_i128(mp_int *a, int128_t b); +uint128_t mpx_get_mag_u128(const mp_int *a); +int128_t mpx_get_i128(const mp_int *a); +#define mpx_get_u128(a) ((uint128_t)mpx_get_i128(a)) diff --git a/yacl/utils/BUILD.bazel b/yacl/utils/BUILD.bazel index 0def6e89..ebe4e1a9 100644 --- a/yacl/utils/BUILD.bazel +++ b/yacl/utils/BUILD.bazel @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:yacl.bzl", "yacl_cc_binary", "yacl_cc_library", "yacl_cc_test") -load("@rules_proto//proto:defs.bzl", "proto_library") load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("//bazel:yacl.bzl", "OMP_CFLAGS", "OMP_DEPS", "OMP_LINKFLAGS", "yacl_cc_binary", "yacl_cc_library", "yacl_cc_test") package(default_visibility = ["//visibility:public"]) @@ -64,45 +64,16 @@ yacl_cc_test( ], ) -yacl_cc_library( - name = "thread_pool", - srcs = ["thread_pool.cc"], - hdrs = ["thread_pool.h"], - deps = [ - "//yacl/base:exception", - ], -) - -yacl_cc_test( - name = "thread_pool_test", - srcs = ["thread_pool_test.cc"], - deps = [ - ":thread_pool", - ], -) - yacl_cc_library( name = "parallel", srcs = [ - "parallel_common.cc", - "parallel_native.cc", + "parallel.cc", ], hdrs = [ "parallel.h", - "parallel_native.h", ], - visibility = ["//visibility:public"], deps = [ ":thread_pool", - "//yacl/base:exception", - ], -) - -yacl_cc_test( - name = "parallel_test", - srcs = ["parallel_test.cc"], - deps = [ - ":parallel", ], ) @@ -257,3 +228,32 @@ yacl_cc_test( ":platform_utils", ], ) + +yacl_cc_test( + name = "parallel_test", + srcs = ["parallel_test.cc"], + deps = [ + ":parallel", + "//yacl/base:exception", + ], +) + +yacl_cc_binary( + name = "parallel_bench", + srcs = ["parallel_bench.cc"], + copts = OMP_CFLAGS, + linkopts = OMP_LINKFLAGS, + deps = [ + ":parallel", + "@com_github_google_benchmark//:benchmark", + ] + OMP_DEPS, +) + +yacl_cc_library( + name = "thread_pool", + srcs = ["thread_pool.cc"], + hdrs = ["thread_pool.h"], + deps = [ + "//yacl/base:exception", + ], +) diff --git a/yacl/utils/parallel_native.cc b/yacl/utils/parallel.cc similarity index 79% rename from yacl/utils/parallel_native.cc rename to yacl/utils/parallel.cc index 7700a76a..e280ddfa 100644 --- a/yacl/utils/parallel_native.cc +++ b/yacl/utils/parallel.cc @@ -1,11 +1,51 @@ -// Copyright (c) 2016 Facebook Inc. +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/utils/parallel.h" + #include #include +#include -#include "yacl/utils/parallel.h" #include "yacl/utils/thread_pool.h" namespace yacl { +namespace { + +size_t get_env_num_threads(const char* var_name, size_t def_value = 0) { + try { + if (auto* value = std::getenv(var_name)) { + int nthreads = std::stoi(value); + YACL_ENFORCE(nthreads > 0); + return nthreads; + } + } catch (const std::exception& e) { + YACL_THROW("Invalid {} variable value: {}", var_name, e.what()); + } + return def_value; +} + +} // namespace + +int intraop_default_num_threads() { + size_t nthreads = get_env_num_threads("YACL_NUM_THREADS", 0); + if (nthreads == 0) { + nthreads = ThreadPool::DefaultNumThreads(); + } + return nthreads; +} + namespace { // used with _set_in_parallel_region to mark master thread // as in parallel region while executing parallel primitives diff --git a/yacl/utils/parallel.h b/yacl/utils/parallel.h index a94319e8..4aaaf1a2 100644 --- a/yacl/utils/parallel.h +++ b/yacl/utils/parallel.h @@ -1,17 +1,22 @@ -// Copyright (c) 2016 Facebook Inc. +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include "yacl/base/exception.h" namespace yacl { -namespace internal { -// This parameter is heuristically chosen to determine the minimum number of -// work that warrants parallelism. For example, when summing an array, it is -// deemed inefficient to parallelise over arrays shorter than 32768. Further, -// no parallel algorithm (such as parallel_reduce) should split work into -// smaller than GRAIN_SIZE chunks. -constexpr int64_t GRAIN_SIZE = 32768; -} // namespace internal inline int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; } @@ -23,13 +28,31 @@ void set_num_threads(int); // Returns the number of threads used in parallel region int get_num_threads(); +int get_thread_id(); +bool in_parallel_region(); + +// Returns number of intra-op threads used by default +int intraop_default_num_threads(); -// Returns the current thread number (starting from 0) -// in the current parallel region, or 0 in the sequential region -int get_thread_num(); +namespace internal { -// Checks whether the code runs in parallel region -bool in_parallel_region(); +inline std::tuple calc_num_tasks_and_chunk_size( + int64_t begin, int64_t end, int64_t grain_size) { + if ((end - begin) < grain_size) { + return std::make_tuple(1, std::max(static_cast(0), end - begin)); + } + // Choose number of tasks based on grain size and number of threads. + size_t chunk_size = divup((end - begin), get_num_threads()); + // Make sure each task is at least grain_size size. + chunk_size = std::max(static_cast(grain_size), chunk_size); + size_t num_tasks = divup((end - begin), chunk_size); + return std::make_tuple(num_tasks, chunk_size); +} + +void _parallel_run(int64_t begin, int64_t end, int64_t grain_size, + const std::function& f); + +} // namespace internal /* parallel_for @@ -48,8 +71,21 @@ states from the current thread to the worker threads. This means for example that Tensor operations CANNOT be used in the body of your function, only data pointers. */ +template inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, - const std::function& f); + F&& f) { + YACL_ENFORCE(grain_size > 0); + if (begin >= end) { + return; + } + if ((end - begin) < grain_size || in_parallel_region()) { + f(begin, end); + return; + } + internal::_parallel_run(begin, end, grain_size, + [f](int64_t fstart, int64_t fend, + size_t /* unused */) { f(fstart, fend); }); +} inline void parallel_for(int64_t begin, int64_t end, const std::function& f) { @@ -88,15 +124,31 @@ body of your function, only data pointers. [1] https://software.intel.com/en-us/node/506154 */ -template -inline RES_T parallel_reduce( - int64_t begin, int64_t end, int64_t grain_size, - const std::function& reduce_f, - const std::function& combine_f); - -// Returns number of intra-op threads used by default -int intraop_default_num_threads(); +template +inline scalar_t parallel_reduce(const int64_t begin, const int64_t end, + const int64_t grain_size, const F& reduce_f, + const SF& combine_f) { + YACL_ENFORCE(grain_size > 0); + YACL_ENFORCE(begin < end, "begin={}, end={}", begin, end); + + if ((end - begin) < grain_size || in_parallel_region()) { + return reduce_f(begin, end); + } + + size_t num_tasks; + size_t chunk_size; + std::tie(num_tasks, chunk_size) = + internal::calc_num_tasks_and_chunk_size(begin, end, grain_size); + std::vector results(num_tasks); + internal::_parallel_run(begin, end, grain_size, + [&](int64_t fstart, int64_t fend, size_t task_id) { + results[task_id] = reduce_f(fstart, fend); + }); + auto result = results[0]; + for (size_t i = 1; i < results.size(); ++i) { + result = combine_f(result, results[i]); + } + return result; +} } // namespace yacl - -#include "yacl/utils/parallel_native.h" diff --git a/yacl/utils/parallel_bench.cc b/yacl/utils/parallel_bench.cc new file mode 100644 index 00000000..f1f073ef --- /dev/null +++ b/yacl/utils/parallel_bench.cc @@ -0,0 +1,75 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark/benchmark.h" +#include "omp.h" + +#include "yacl/utils/parallel.h" + +namespace yacl::bench { + +constexpr int64_t kTestSize = 100000; + +static void BM_OpenMp(benchmark::State& state) { + [[maybe_unused]] int64_t sum = 0; + for (auto _ : state) { +#pragma omp parallel for + for (int64_t i = 0; i < kTestSize; ++i) { + sum ^= i; + } + } +} +BENCHMARK(BM_OpenMp); + +static void BM_BatchFor(benchmark::State& state) { + [[maybe_unused]] int64_t sum = 0; + for (auto _ : state) { + parallel_for(1, kTestSize, state.range(0), [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + sum ^= i; + } + }); + } +} +BENCHMARK(BM_BatchFor) + ->Arg(1) + ->Arg(10) + ->Arg(100) + ->Arg(kTestSize / omp_get_max_threads()) + ->Arg(kTestSize / omp_get_max_threads() + 1); + +static void BM_AutoBatchSizeFor(benchmark::State& state) { + [[maybe_unused]] int64_t sum = 0; + for (auto _ : state) { + parallel_for(1, kTestSize, [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + sum ^= i; + } + }); + } +} +BENCHMARK(BM_AutoBatchSizeFor); + +} // namespace yacl::bench + +int main(int argc, char** argv) { + ::benchmark::Initialize(&argc, argv); + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) { + return 1; + } + + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + return 0; +} diff --git a/yacl/utils/parallel_common.cc b/yacl/utils/parallel_common.cc deleted file mode 100644 index d4eb242d..00000000 --- a/yacl/utils/parallel_common.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2016 Facebook Inc. -#include - -#include "yacl/utils/parallel.h" -#include "yacl/utils/thread_pool.h" - -namespace yacl { -namespace { - -size_t get_env_num_threads(const char* var_name, size_t def_value = 0) { - try { - if (auto* value = std::getenv(var_name)) { - int nthreads = std::stoi(value); - YACL_ENFORCE(nthreads > 0); - return nthreads; - } - } catch (const std::exception& e) { - YACL_THROW("Invalid {} variable value: {}", var_name, e.what()); - } - return def_value; -} - -} // namespace - -int intraop_default_num_threads() { - size_t nthreads = get_env_num_threads("YACL_NUM_THREADS", 0); - if (nthreads == 0) { - nthreads = ThreadPool::DefaultNumThreads(); - } - return nthreads; -} - -} // namespace yacl diff --git a/yacl/utils/parallel_native.h b/yacl/utils/parallel_native.h deleted file mode 100644 index 30743628..00000000 --- a/yacl/utils/parallel_native.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2016 Facebook Inc. -#pragma once - -#include -#include -#include -#include -#include - -#include "yacl/utils/parallel.h" - -namespace yacl { -namespace internal { - -inline std::tuple calc_num_tasks_and_chunk_size( - int64_t begin, int64_t end, int64_t grain_size) { - if ((end - begin) < grain_size) { - return std::make_tuple(1, std::max(static_cast(0), end - begin)); - } - // Choose number of tasks based on grain size and number of threads. - size_t chunk_size = divup((end - begin), get_num_threads()); - // Make sure each task is at least grain_size size. - chunk_size = std::max(static_cast(grain_size), chunk_size); - size_t num_tasks = divup((end - begin), chunk_size); - return std::make_tuple(num_tasks, chunk_size); -} - -void _parallel_run(int64_t begin, int64_t end, int64_t grain_size, - const std::function& f); - -} // namespace internal - -inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, - const std::function& f) { - YACL_ENFORCE(grain_size > 0); - if (begin >= end) { - return; - } - if ((end - begin) < grain_size || in_parallel_region()) { - f(begin, end); - return; - } - internal::_parallel_run(begin, end, grain_size, - [f](int64_t fstart, int64_t fend, - size_t /* unused */) { f(fstart, fend); }); -} - -template -inline RES_T parallel_reduce( - int64_t begin, int64_t end, int64_t grain_size, - const std::function& reduce_f, - const std::function& combine_f) { - YACL_ENFORCE(grain_size > 0); - YACL_ENFORCE(begin < end, "begin={}, end={}", begin, end); - - if ((end - begin) < grain_size || in_parallel_region()) { - return reduce_f(begin, end); - } - - size_t num_tasks; - size_t chunk_size; - std::tie(num_tasks, chunk_size) = - internal::calc_num_tasks_and_chunk_size(begin, end, grain_size); - std::vector results(num_tasks); - RES_T* results_data = results.data(); - internal::_parallel_run( - begin, end, grain_size, - [&reduce_f, results_data](int64_t fstart, int64_t fend, size_t task_id) { - results_data[task_id] = reduce_f(fstart, fend); - }); - RES_T result = results[0]; - for (size_t i = 1; i < results.size(); ++i) { - result = combine_f(result, results[i]); - } - return result; -} - -} // namespace yacl diff --git a/yacl/utils/parallel_test.cc b/yacl/utils/parallel_test.cc index 54267a0f..a7a8211a 100644 --- a/yacl/utils/parallel_test.cc +++ b/yacl/utils/parallel_test.cc @@ -14,35 +14,54 @@ #include "yacl/utils/parallel.h" +#include #include #include "gtest/gtest.h" +#include "yacl/base/exception.h" + namespace yacl { -struct Param { - int num_threads; - int data_size; - int grain_size; -}; +TEST(ParallelTest, ParallelForTest) { + std::vector data(200); + std::iota(data.begin(), data.end(), 0); -class ParallelTest : public testing::TestWithParam {}; + parallel_for(0, data.size(), [&data](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + data[i] *= 2; + } + }); -TEST_P(ParallelTest, ParallelForTest) { - auto param = GetParam(); + for (size_t i = 0; i < data.size(); ++i) { + ASSERT_EQ(i * 2, data[i]); + } +} - init_num_threads(); - set_num_threads(param.num_threads); +TEST(ParallelTest, ParallelForBatchedTest) { + std::vector data(200); + std::iota(data.begin(), data.end(), 0); + + parallel_for(0, data.size(), 50, [&data](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + data[i] *= 2; + } + }); + + for (size_t i = 0; i < data.size(); ++i) { + ASSERT_EQ(i * 2, data[i]); + } +} - std::vector data(param.data_size); +TEST(ParallelTest, ParallelForBatchedWithTrailingTest) { + std::vector data(210); std::iota(data.begin(), data.end(), 0); - parallel_for(0, data.size(), param.grain_size, - [&data](int64_t beg, int64_t end) { - for (int64_t i = beg; i < end; ++i) { - data[i] *= 2; - } - }); + parallel_for(0, data.size(), 50, [&data](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + data[i] *= 2; + } + }); for (size_t i = 0; i < data.size(); ++i) { ASSERT_EQ(i * 2, data[i]); @@ -50,28 +69,23 @@ TEST_P(ParallelTest, ParallelForTest) { } TEST(ParallelTest, ParallelWithExceptionTest) { - init_num_threads(); - set_num_threads(4); - + EXPECT_THROW( + parallel_for(0, 1000, + [](int64_t, int64_t) { throw RuntimeError("surprise"); }), + RuntimeError); EXPECT_THROW( parallel_for(0, 1000, 1, [](int64_t, int64_t) { throw RuntimeError("surprise"); }), RuntimeError); } -TEST_P(ParallelTest, ParallelReduceTest) { - auto param = GetParam(); - - init_num_threads(); - set_num_threads(param.num_threads); - - std::vector data(param.data_size); +TEST(ParallelTest, ParallelReduceTest) { + std::vector data(500); std::iota(data.begin(), data.end(), 0); int expect_sum = std::accumulate(data.begin(), data.end(), 0); - int total_sum = parallel_reduce( - 0, data.size(), param.grain_size, - [&data](int64_t beg, int64_t end) { + 0, data.size(), 1, + [&data](int64_t beg, int64_t end) -> int { int partial_sum = data[beg]; for (int64_t i = beg + 1; i < end; ++i) { partial_sum += data[i]; @@ -79,12 +93,7 @@ TEST_P(ParallelTest, ParallelReduceTest) { return partial_sum; }, [](int a, int b) { return a + b; }); - ASSERT_EQ(expect_sum, total_sum); } -INSTANTIATE_TEST_SUITE_P(ParallelTestSuit, ParallelTest, - testing::Values(Param{4, 123, 10}, Param{4, 123, 50}, - Param{4, 123, 200})); - } // namespace yacl diff --git a/yacl/utils/spi/BUILD.bazel b/yacl/utils/spi/BUILD.bazel index 41c88356..96705766 100644 --- a/yacl/utils/spi/BUILD.bazel +++ b/yacl/utils/spi/BUILD.bazel @@ -32,6 +32,7 @@ yacl_cc_library( deps = [ "//yacl/base:exception", "//yacl/math/mpint", + "//yacl/utils:parallel", "@com_google_absl//absl/types:span", ], ) diff --git a/yacl/utils/spi/argument/arg_set.h b/yacl/utils/spi/argument/arg_set.h index 8113c833..bbf5c403 100644 --- a/yacl/utils/spi/argument/arg_set.h +++ b/yacl/utils/spi/argument/arg_set.h @@ -27,7 +27,8 @@ class SpiArgs : public std::map { // If the user sets this parameter, but the type is not T, then an exception // is thrown template - T Get(const SpiArgKey &key, const T &default_value) const { + T Get(const SpiArgKey &key, + const typename SpiArgKey::ValueType &default_value) const { auto it = find((key.Key())); if (it == end()) { return default_value; diff --git a/yacl/utils/spi/item.cc b/yacl/utils/spi/item.cc index d96d6991..50811e0f 100644 --- a/yacl/utils/spi/item.cc +++ b/yacl/utils/spi/item.cc @@ -20,12 +20,24 @@ namespace yacl { namespace { -std::string TryRead(const std::any &v) { -#define TRY_TYPE(type) \ - if (t == typeid(type)) { \ - return fmt::to_string(std::any_cast(v)); \ +#define TRY_TYPE(type) \ + if (t == typeid(type)) { \ + return fmt::to_string(std::any_cast(v)); \ + } \ + if (t == typeid(absl::Span)) { \ + const auto &c = std::any_cast>(v); \ + return fmt::to_string(fmt::join(c, ", ")); \ + } \ + if (t == typeid(absl::Span)) { \ + const auto &c = std::any_cast>(v); \ + return fmt::to_string(fmt::join(c, ", ")); \ + } \ + if (t == typeid(std::vector)) { \ + const auto &c = std::any_cast>(v); \ + return fmt::to_string(fmt::join(c, ", ")); \ } +std::string TryRead(const std::any &v) { const auto &t = v.type(); TRY_TYPE(bool); TRY_TYPE(int8_t); @@ -46,11 +58,36 @@ std::string TryRead(const std::any &v) { } // namespace +template <> +bool Item::IsAll(const bool &element) const { + if (!HasValue()) { + return false; + } + + if (!IsArray()) { + return As() == element; + } + + if (IsView()) { + absl::Span real = + IsReadOnly() ? As>() : As>(); + return IsAllSameTo(real, element); + } + + auto &real = As>(); + for (const auto &item : real) { + if (item != element) { + return false; + } + } + return true; +} + std::string Item::ToString() const { if (IsArray()) { - return fmt::format("{} Item, element_type={}, RO={}", + return fmt::format("{} Item, element_type={}, RO={}, Content={}", IsView() ? "Span" : "Vector", v_.type().name(), - IsReadOnly()); + IsReadOnly(), TryRead(v_)); } else { return fmt::format("Scalar item, type={}, RO={}, Content={}", v_.type().name(), IsReadOnly(), TryRead(v_)); diff --git a/yacl/utils/spi/item.h b/yacl/utils/spi/item.h index 055180ae..f3c30b14 100644 --- a/yacl/utils/spi/item.h +++ b/yacl/utils/spi/item.h @@ -15,10 +15,14 @@ #pragma once #include +#include +#include +#include #include "absl/types/span.h" #include "yacl/base/exception.h" +#include "yacl/utils/parallel.h" namespace yacl { @@ -141,7 +145,7 @@ class Item { } template - std::enable_if_t, T> As() { + std::enable_if_t, T> As() { // As a pointer try { return std::any_cast>(&v_); } catch (const std::bad_any_cast& e) { @@ -206,6 +210,9 @@ class Item { // non-const value -> const T return As>(); } else { + static_assert(!std::is_same_v, + "Call AsSpan on a vector item is not allowed"); + // vector // non-const value -> const T return absl::MakeConstSpan(As>()); @@ -236,6 +243,10 @@ class Item { return !operator==(other); } + // is every element in item equals "element" + template + bool IsAll(const T& element) const; + OperandType operator,(const Item& other) const { return static_cast(((meta_ & 1) << 1) | (other.meta_ & 1)); }; @@ -309,6 +320,24 @@ class Item { return item; } + template + bool IsAllSameTo(absl::Span real, const T& expected) const { + std::atomic res = true; + yacl::parallel_for(0, real.length(), [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + if (!res) { + return; + } + + if (real[i] != expected) { + res.store(false); + return; + } + } + }); + return res.load(); + } + // The format of meta: // bit 0 -> is array? 0 - scalar; 1 - array // bit 1 -> is view ? 0 - hold value; 1 - ref/view @@ -319,4 +348,15 @@ class Item { std::ostream& operator<<(std::ostream& os, const Item& a); +template <> +bool Item::IsAll(const bool& element) const; + +template +bool Item::IsAll(const T& element) const { + if (!HasValue()) { + return false; + } + return IsAllSameTo(AsSpan(), element); +} + } // namespace yacl diff --git a/yacl/utils/spi/spi_factory_test.cc b/yacl/utils/spi/spi_factory_test.cc index 57afa863..0d020ebc 100644 --- a/yacl/utils/spi/spi_factory_test.cc +++ b/yacl/utils/spi/spi_factory_test.cc @@ -83,13 +83,11 @@ class MockQuantumLib : public MockPheSpi { static std::unique_ptr Create(const std::string &phe_name, const SpiArgs &args) { YACL_ENFORCE(phe_name == "elgamal"); - return std::make_unique( - args.Get(Curve, "ed25519")); + return std::make_unique(args.Get(Curve, "ed25519")); } static bool Check(const std::string &phe_name, const SpiArgs &args) { - return phe_name == "elgamal" && - args.Get(Curve, "ed25519") == "ed25519"; + return phe_name == "elgamal" && args.Get(Curve, "ed25519") == "ed25519"; } std::string ToString() override { diff --git a/yacl/utils/thread_pool_test.cc b/yacl/utils/thread_pool_test.cc deleted file mode 100644 index 33438503..00000000 --- a/yacl/utils/thread_pool_test.cc +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) 2020 Ant Financial Inc. All rights reserved. - -#include "yacl/utils/thread_pool.h" - -#include -#include - -#include "gtest/gtest.h" - -namespace yacl { -namespace test { - -constexpr static size_t kThreadPoolSize = 3; - -namespace { -class Timer { - public: - Timer() { begin_point_ = std::chrono::steady_clock::now(); } - - double GetElapsedTimeInMs() const { - auto end_point = std::chrono::steady_clock::now(); - double span = std::chrono::duration_cast( - end_point - begin_point_) - .count(); - return span / 1000.0; - } - - private: - std::chrono::steady_clock::time_point begin_point_; -}; -} // namespace - -class ThreadPoolTest : public ::testing::Test { - public: - ThreadPoolTest() : thread_pool_(kThreadPoolSize) {} - - protected: - ThreadPool thread_pool_; -}; - -TEST_F(ThreadPoolTest, InThreadPoolTest) { - ASSERT_EQ(thread_pool_.NumThreads(), kThreadPoolSize); - ASSERT_FALSE(thread_pool_.InThreadPool()); - - auto caller_id = std::this_thread::get_id(); - auto ret = thread_pool_.Submit([&caller_id]() { - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - return std::this_thread::get_id() == caller_id; - }); - ASSERT_FALSE(ret.get()); - - ret = thread_pool_.Submit([this]() { return thread_pool_.InThreadPool(); }); - ASSERT_TRUE(ret.get()); -} - -TEST_F(ThreadPoolTest, DISABLED_ParallelTest) { - Timer timer; - - std::future futures[kThreadPoolSize]; - for (auto& future : futures) { - future = thread_pool_.Submit( - []() { std::this_thread::sleep_for(std::chrono::milliseconds(100)); }); - } - - EXPECT_LT(timer.GetElapsedTimeInMs(), 80); - for (auto& future : futures) { - future.wait(); - } - EXPECT_GE(timer.GetElapsedTimeInMs(), 100); - EXPECT_LT(timer.GetElapsedTimeInMs(), 200); -} - -TEST_F(ThreadPoolTest, MoreTasksTest) { - std::atomic sum(0); - - std::future futures[kThreadPoolSize * 10]; - for (auto& future : futures) { - future = thread_pool_.Submit([&sum]() { - for (int32_t i = 0; i < 10000; ++i) { - ++sum; - } - }); - } - - // wait all - for (auto& feature : futures) { - feature.get(); - } - - EXPECT_EQ(sum.load(), 10000 * kThreadPoolSize * 10); -} - -TEST_F(ThreadPoolTest, ParamsTest) { - auto func1 = [](int a) { return a; }; - auto func2 = [](int a, long b) -> int { return a + b; }; - auto func3 = [](int a, int b, const uint32_t& c) -> int { return a + b + c; }; - - std::vector> futures; - for (int i = 0; i < 600; i += 6) { - futures.push_back(thread_pool_.Submit(func1, i)); - futures.push_back(thread_pool_.Submit(func2, i + 1, i + 2)); - futures.push_back(thread_pool_.Submit(func3, i + 3, i + 4, i + 5)); - } - - // get all - int sum = 0; - for (auto& feature : futures) { - sum += feature.get(); - } - - EXPECT_EQ(sum, 600 * 599 / 2); // 即 0..599 之和 -} - -TEST_F(ThreadPoolTest, ExceptionTest) { - std::future futures[7]; - futures[0] = thread_pool_.Submit([]() { throw RuntimeError(); }); - futures[1] = thread_pool_.Submit([]() { throw IoError(); }); - futures[2] = thread_pool_.Submit([]() { throw LogicError(); }); - futures[3] = thread_pool_.Submit([]() { throw std::exception(); }); - futures[4] = thread_pool_.Submit([]() { throw 1L; }); - futures[5] = thread_pool_.Submit([]() { throw "hello"; }); - futures[6] = - thread_pool_.Submit([]() { throw std::string("is anybody here"); }); - - // wait() always no throw - for (auto& future : futures) { - EXPECT_NO_THROW(future.wait()); - } - - EXPECT_THROW(futures[0].get(), RuntimeError); - EXPECT_THROW(futures[1].get(), IoError); - EXPECT_THROW(futures[2].get(), LogicError); - EXPECT_THROW(futures[3].get(), std::exception); - EXPECT_THROW(futures[4].get(), long); - EXPECT_THROW(futures[5].get(), const char*); - EXPECT_THROW(futures[6].get(), std::string); - - auto one_more_future = - thread_pool_.Submit([]() { return "no throw, just return"; }); - EXPECT_NO_THROW(one_more_future.get()); -} - -} // namespace test -} // namespace yacl