From 39c515ff93d21575573fea6c74691034c8555f57 Mon Sep 17 00:00:00 2001 From: x-mass <36629999+x-mass@users.noreply.github.com> Date: Fri, 9 Aug 2024 12:09:22 +0000 Subject: [PATCH] feat: implement MPT --- flake.lock | 4 +- libs/nil_core/CMakeLists.txt | 10 +- .../include/zkevm_framework/core/mpt/mpt.hpp | 56 +- .../include/zkevm_framework/core/mpt/node.hpp | 105 ++++ .../include/zkevm_framework/core/mpt/path.hpp | 83 +++ libs/nil_core/src/CMakeLists.txt | 19 + libs/nil_core/src/mpt/mpt.cpp | 512 ++++++++++++++++++ libs/nil_core/src/mpt/node.cpp | 97 ++++ libs/nil_core/src/mpt/path.cpp | 116 ++++ tests/libs/nil_core/CMakeLists.txt | 1 + tests/libs/nil_core/test_nil_core_mpt.cpp | 122 +++++ 11 files changed, 1118 insertions(+), 7 deletions(-) create mode 100644 libs/nil_core/include/zkevm_framework/core/mpt/node.hpp create mode 100644 libs/nil_core/include/zkevm_framework/core/mpt/path.hpp create mode 100644 libs/nil_core/src/CMakeLists.txt create mode 100644 libs/nil_core/src/mpt/mpt.cpp create mode 100644 libs/nil_core/src/mpt/node.cpp create mode 100644 libs/nil_core/src/mpt/path.cpp create mode 100644 tests/libs/nil_core/test_nil_core_mpt.cpp diff --git a/flake.lock b/flake.lock index b87da97..42705b0 100644 --- a/flake.lock +++ b/flake.lock @@ -34,11 +34,11 @@ "rev": "9f4c875d95a208043f0569d1bba23d6310c89974", "revCount": 969, "type": "git", - "url": "https://github.com/NilFoundation/nil" + "url": "ssh://git@github.com/NilFoundation/nil" }, "original": { "type": "git", - "url": "https://github.com/NilFoundation/nil" + "url": "ssh://git@github.com/NilFoundation/nil" } }, "nil-crypto3": { diff --git a/libs/nil_core/CMakeLists.txt b/libs/nil_core/CMakeLists.txt index f3390a4..125e2cf 100644 --- a/libs/nil_core/CMakeLists.txt +++ b/libs/nil_core/CMakeLists.txt @@ -12,13 +12,15 @@ if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 13.0) message(FATAL_ERROR "NilCore library can be built only with GCC 13+") endif() -add_library(${LIBRARY_NAME} INTERFACE) 
+add_library(${LIBRARY_NAME} SHARED) -target_compile_features(${LIBRARY_NAME} INTERFACE cxx_std_23) +add_subdirectory(src) -target_include_directories(${LIBRARY_NAME} INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_23) -target_link_libraries(${LIBRARY_NAME} INTERFACE intx::intx sszpp::sszpp) +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) + +target_link_libraries(${LIBRARY_NAME} PRIVATE intx::intx sszpp::sszpp) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) diff --git a/libs/nil_core/include/zkevm_framework/core/mpt/mpt.hpp b/libs/nil_core/include/zkevm_framework/core/mpt/mpt.hpp index f2f096a..4f5e9ec 100644 --- a/libs/nil_core/include/zkevm_framework/core/mpt/mpt.hpp +++ b/libs/nil_core/include/zkevm_framework/core/mpt/mpt.hpp @@ -1,9 +1,63 @@ #ifndef ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_MPT_HPP_ #define ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_MPT_HPP_ +#include +#include +#include +#include +#include + +#include "zkevm_framework/core/mpt/node.hpp" + +namespace std { + // To allow using vector of bytes as key in unordered_map + template<> + struct hash> { + size_t operator()(const vector& v) const { + return hash()( + string_view(reinterpret_cast(v.data()), v.size())); + } + }; +} // namespace std + namespace core { namespace mpt { - class Mpt; + + namespace details { + class GetHandler; + class SetHandler; + class DeleteHandler; + struct DeleteResult; + } // namespace details + + class MerklePatriciaTrie { + public: + MerklePatriciaTrie(); + std::vector get(const std::vector& key) const; + void set(const std::vector& key, const std::vector& value); + void remove(const std::vector& key); + + protected: + Path pathFromKey(const std::vector& key) const; + Node getNode(const Reference& ref) const; + Reference storeNode(const Node& node); + std::optional> get(const Reference& 
nodeRef, Path& path) const; + Reference set(const Reference& nodeRef, Path& path, + const std::vector& value); + details::DeleteResult delete_node(const Reference& nodeRef, const Path& path); + + private: + static constexpr size_t kMaxRawKeyLen = 32; + + Reference root_; + std::unordered_map> + nodes_; // Better design is to use other external class for storage + + friend class details::GetHandler; + friend class details::SetHandler; + friend class details::DeleteHandler; + }; + } // namespace mpt } // namespace core diff --git a/libs/nil_core/include/zkevm_framework/core/mpt/node.hpp b/libs/nil_core/include/zkevm_framework/core/mpt/node.hpp new file mode 100644 index 0000000..0fe511d --- /dev/null +++ b/libs/nil_core/include/zkevm_framework/core/mpt/node.hpp @@ -0,0 +1,105 @@ +#ifndef ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_NODE_HPP_ +#define ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_NODE_HPP_ + +#include +#include +#include + +#include "zkevm_framework/core/mpt/path.hpp" + +namespace core { + namespace mpt { + + constexpr size_t kBranchesNum = 16; + + using Bytes = std::vector; + using Reference = Bytes; + + enum class NodeTypeFlag : std::uint8_t { + kLeafNode = 0, + kExtensionNode = 1, + kBranchNode = 2 + }; + + class Serializable { + public: + virtual ~Serializable() = default; + virtual Bytes Encode() const = 0; + }; + + class PathHolder { + public: + PathHolder() = default; + explicit PathHolder(Path path); + + Path path; + }; + + class ValueHolder { + public: + ValueHolder() = default; + explicit ValueHolder(const std::vector& value); + + const std::vector& value() const; + void set_value(const std::vector& new_value); + + protected: + ssz::list value_; + }; + + class LeafNode : public ValueHolder, + public PathHolder, + public ssz::ssz_variable_size_container, + public Serializable { + public: + LeafNode() = default; + LeafNode(const Path& path, const std::vector& new_value); + + Bytes Encode() const override; + + 
SSZ_CONT(path, value_) + }; + + class ExtensionNode : public PathHolder, + public ssz::ssz_variable_size_container, + public Serializable { + public: + ExtensionNode() = default; + ExtensionNode(const Path& path, const Reference& next); + + Bytes Encode() const override; + const Reference& get_next_ref() const; + + SSZ_CONT(path, next_ref_) + + private: + ssz::list next_ref_; + }; + + class BranchNode : public ValueHolder, + public ssz::ssz_variable_size_container, + public Serializable { + public: + BranchNode() = default; + BranchNode(const std::array& refs, + const std::vector& value); + + Bytes Encode() const override; + std::array get_branches() const; + void ClearBranch(std::byte nibble); + void SetBranch(std::byte nibble, const std::vector& value); + + SSZ_CONT(branches_, value_) + + private: + ssz::list, kBranchesNum> branches_; + }; + + using Node = std::variant; + + Node DecodeNode(const Bytes& bytes); + + } // namespace mpt +} // namespace core + +#endif // ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_NODE_HPP_ diff --git a/libs/nil_core/include/zkevm_framework/core/mpt/path.hpp b/libs/nil_core/include/zkevm_framework/core/mpt/path.hpp new file mode 100644 index 0000000..d6de352 --- /dev/null +++ b/libs/nil_core/include/zkevm_framework/core/mpt/path.hpp @@ -0,0 +1,83 @@ +#ifndef ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_PATH_HPP_ +#define ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_PATH_HPP_ + +#include +#include +#include +#include + +namespace core { + namespace mpt { + + class Path : public ssz::ssz_variable_size_container { + public: + Path(); + Path(const Path &other) = default; + Path(const std::vector &data, std::size_t offset = 0); + + int size() const; + bool empty() const; + std::byte operator[](size_t idx) const; + Path operator+(const Path &other) const; + bool operator==(const Path &other) const; + std::byte at(std::size_t idx) const; + + bool StartsWith(const Path &other) const; + 
Path *Consume(std::size_t amount); + Path CommonPrefix(const Path &other) const; + + // Methods used by sszpp. NO hash_tree_root METHOD PROVIDED + constexpr std::size_t ssz_size() const noexcept { return 1 + size() / 2; } + + constexpr void serialize(ssz::ssz_iterator auto result) const { + std::size_t nibblesLen = size(); + bool isOdd = (nibblesLen % 2 == 1); + + // If even size, we just insert empty byte at the beginning + auto prefix = std::byte{0x00}; + + // If odd, we put kOddFlag at the first nibble and the first element of Path after + // it, so we could insert by full bytes afterwards + if (isOdd) { + prefix = kOddFlag | at(0); + } + + std::copy(static_cast(static_cast(&prefix)), + static_cast(static_cast(&prefix)) + 1, + result); + ++result; + + for (std::size_t i = (isOdd ? 1 : 0); i < nibblesLen; i += 2) { + std::byte nextByte = (operator[](i) << 4) | operator[](i + 1); + std::copy( + static_cast(static_cast(&nextByte)), + static_cast(static_cast(&nextByte)) + 1, + result); + ++result; + } + } + + constexpr void deserialize(const std::ranges::sized_range auto &bytes) { + bool isOddLen = (bytes.front() & kOddFlag) == kOddFlag; + if (isOddLen) { + offset_ = 1; + } else { + offset_ = 2; + } + + data_.reserve(std::ranges::size(bytes)); + std::ranges::copy(bytes, std::back_inserter(data_)); + } + + private: + Path ConstructFromPrefix(std::size_t length) const; + + static constexpr std::byte kOddFlag = std::byte{0x10}; + std::vector data_; + std::size_t offset_ = 0; + }; + + } // namespace mpt +} // namespace core + +#endif // ZKEMV_FRAMEWORK_LIBS_NIL_CORE_INCLUDE_ZKEVM_FRAMEWORK_CORE_MPT_MPT_HPP_ diff --git a/libs/nil_core/src/CMakeLists.txt b/libs/nil_core/src/CMakeLists.txt new file mode 100644 index 0000000..9a6faef --- /dev/null +++ b/libs/nil_core/src/CMakeLists.txt @@ -0,0 +1,19 @@ +# SSZ++ can be compiled only with GCC 13+ +if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + message(FATAL_ERROR "Data types library can be built only with GCC 13+") +endif() 
+if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 13.0)
+    message(FATAL_ERROR "Data types library can be built only with GCC 13+")
+endif()
+
+find_package(sszpp REQUIRED)
+
+set(SOURCES
+    mpt/mpt.cpp
+    mpt/node.cpp
+    mpt/path.cpp
+)
+
+target_sources(${LIBRARY_NAME} PRIVATE ${SOURCES})
+
+target_link_libraries(${LIBRARY_NAME} PRIVATE sszpp::sszpp)
diff --git a/libs/nil_core/src/mpt/mpt.cpp b/libs/nil_core/src/mpt/mpt.cpp
new file mode 100644
index 0000000..3b76f73
--- /dev/null
+++ b/libs/nil_core/src/mpt/mpt.cpp
@@ -0,0 +1,530 @@
+#include "zkevm_framework/core/mpt/mpt.hpp"
+
+namespace core {
+    namespace mpt {
+        namespace details {
+
+            template
+            struct overloaded : Ts... {
+                using Ts::operator()...;
+            };
+            template
+            overloaded(Ts...) -> overloaded;
+
+
+            // Visitor for MerklePatriciaTrie::get: matches path_ against one decoded node and
+            // consumes the matched nibbles before descending; yields std::nullopt on mismatch.
+            class GetHandler {
+              public:
+                GetHandler(const MerklePatriciaTrie& mpt, Path& path) : mpt_(mpt), path_(path) {}
+
+                std::optional> operator()(const LeafNode& leaf) const {
+                    if (leaf.path == path_) {
+                        return leaf.value();
+                    }
+                    return std::nullopt;
+                }
+
+                std::optional> operator()(
+                    const ExtensionNode& extensionNode) const {
+                    if (path_.size() >= extensionNode.path.size() &&
+                        path_.StartsWith(extensionNode.path)) {
+                        path_.Consume(extensionNode.path.size());
+                        return mpt_.get(extensionNode.get_next_ref(), path_);
+                    }
+                    return std::nullopt;
+                }
+
+                std::optional> operator()(
+                    const BranchNode& branchNode) const {
+                    if (path_.empty()) {
+                        return branchNode.value();
+                    }
+                    auto next = branchNode.get_branches()[std::to_integer(path_.at(0))];
+                    if (!next.empty()) {
+                        path_.Consume(1);
+                        return mpt_.get(next, path_);
+                    }
+                    return std::nullopt;
+                }
+
+              private:
+                const MerklePatriciaTrie& mpt_;
+                Path& path_;
+            };
+
+            // Visitor for MerklePatriciaTrie::set: rebuilds the visited node so that it holds
+            // value_ at path_, stores every new node and returns the reference to the result.
+            class SetHandler {
+              public:
+                SetHandler(MerklePatriciaTrie& mpt, Path& path, const std::vector& value)
+                    : mpt_(mpt), path_(path), value_(value) {}
+
+                Reference operator()(const LeafNode& leaf) {
+                    if (leaf.path == path_) {
+                        return mpt_.storeNode(LeafNode{path_, value_});
+                    }
+
+                    auto commonPrefix = path_.CommonPrefix(leaf.path);
+
+                    path_.Consume(commonPrefix.size());
+                    auto leafPath = leaf.path;
+                    leafPath.Consume(commonPrefix.size());
+
+                    auto branchReference = createBranchNode(path_, value_, leafPath, leaf.value());
+
+                    if (commonPrefix.size() != 0) {
+                        return mpt_.storeNode(ExtensionNode{commonPrefix, branchReference});
+                    }
+
+                    return branchReference;
+                }
+
+                Reference operator()(const ExtensionNode& extensionNode) {
+                    if (path_.StartsWith(extensionNode.path)) {
+                        path_.Consume(extensionNode.path.size());
+                        auto newReference = mpt_.set(extensionNode.get_next_ref(), path_, value_);
+                        return mpt_.storeNode(ExtensionNode{extensionNode.path, newReference});
+                    }
+
+                    auto commonPrefix = path_.CommonPrefix(extensionNode.path);
+
+                    path_.Consume(commonPrefix.size());
+                    auto extensionPath = extensionNode.path;
+                    extensionPath.Consume(commonPrefix.size());
+
+                    std::array branches{};
+                    std::vector branchValue;
+                    if (path_.size() == 0) {
+                        branchValue = value_;
+                    }
+
+                    createBranchLeaf(path_, value_, branches);
+                    createBranchExtension(extensionPath, extensionNode.get_next_ref(), branches);
+
+                    auto branchReference = mpt_.storeNode(BranchNode{branches, branchValue});
+
+                    if (commonPrefix.size() != 0) {
+                        return mpt_.storeNode(ExtensionNode{commonPrefix, branchReference});
+                    }
+                    return branchReference;
+                }
+
+                Reference operator()(const BranchNode& branchNode) {
+                    if (path_.size() == 0) {
+                        return mpt_.storeNode(BranchNode{branchNode.get_branches(), value_});
+                    }
+
+                    auto nibble = path_.at(0);
+                    path_.Consume(1);
+                    auto newReference = mpt_.set(
+                        branchNode.get_branches()[std::to_integer(nibble)], path_, value_);
+
+                    auto newBranches = branchNode.get_branches();
+                    newBranches[std::to_integer(nibble)] = newReference;
+
+                    return mpt_.storeNode(BranchNode{newBranches, branchNode.value()});
+                }
+
+              private:
+                // If path isn't empty, creates leaf node and stores reference in appropriate
+                // branch.
+                void createBranchLeaf(Path path, const std::vector& val,
+                                      std::array& branches) {
+                    if (path.size() > 0) {
+                        const auto nibble = path.at(0);
+                        path.Consume(1);
+                        auto leaf = mpt_.storeNode(LeafNode(path, val));
+                        branches[std::to_integer(nibble)] = leaf;
+                    }
+                }
+
+                Reference createBranchNode(const Path& lhsPath,
+                                           const std::vector& lhsVal,
+                                           const Path& rhsPath,
+                                           const std::vector& rhsVal) {
+                    if (lhsPath.size() == 0 && rhsPath.size() == 0) {
+                        throw std::runtime_error("invalid action");
+                    }
+
+                    std::array branches;
+
+                    std::vector val;
+                    if (lhsPath.size() == 0) {
+                        val = lhsVal;
+                    } else if (rhsPath.size() == 0) {
+                        val = rhsVal;
+                    }
+                    createBranchLeaf(lhsPath, lhsVal, branches);
+                    createBranchLeaf(rhsPath, rhsVal, branches);
+
+                    return mpt_.storeNode(BranchNode(branches, val));
+                }
+
+                void createBranchExtension(Path path, const Reference& nextRef,
+                                           std::array& branches) {
+                    if (path.size() == 0) {
+                        throw std::runtime_error(
+                            "Path for extension node should contain at least one nibble");
+                    }
+
+                    if (path.size() == 1) {
+                        branches[std::to_integer(path.at(0))] = nextRef;
+                    } else {
+                        std::byte nibble = path.at(0);
+
+                        path.Consume(1);
+
+                        auto reference = mpt_.storeNode(ExtensionNode(path, nextRef));
+                        branches[std::to_integer(nibble)] = reference;
+                    }
+                }
+
+                MerklePatriciaTrie& mpt_;
+                Path& path_;
+                const std::vector& value_;
+            };
+
+            // Outcome a visited node reports back to its parent after a deletion.
+            enum class DeleteAction { Unknown, Deleted, Updated, UselessBranch };
+
+            struct DeletionInfo {
+                Path path;
+                std::optional ref;
+            };
+
+            struct DeleteResult {
+                DeleteAction action;
+                std::optional info;
+            };
+
+            // Visitor for MerklePatriciaTrie::delete_node: removes path_ below one decoded node.
+            // Throws std::runtime_error("Key not found") when the path is not present.
+            class DeleteHandler {
+              public:
+                DeleteHandler(MerklePatriciaTrie& mpt, const Path& path) : mpt_(mpt), path_(path) {}
+
+                DeleteResult operator()(const LeafNode& leaf) {
+                    if (path_ == leaf.path) {
+                        return {DeleteAction::Deleted, std::nullopt};
+                    }
+                    throw std::runtime_error("Key not found");
+                }
+
+                DeleteResult operator()(const ExtensionNode& extensionNode) {
+                    if (!path_.StartsWith(extensionNode.path)) {
+                        throw std::runtime_error("Key not found");
+                    }
+
+                    Path remainingPath = path_;
+                    remainingPath.Consume(extensionNode.path.size());
+                    auto result = mpt_.delete_node(extensionNode.get_next_ref(), remainingPath);
+
+                    switch (result.action) {
+                        case DeleteAction::Deleted:
+                            return {DeleteAction::Deleted, std::nullopt};
+
+                        case DeleteAction::Updated:
+                            if (!result.info || !result.info->ref) {
+                                throw std::runtime_error("Invalid update info");
+                            }
+                            return {
+                                DeleteAction::Updated,
+                                DeletionInfo{{},
+                                             mpt_.storeNode(ExtensionNode{
+                                                 extensionNode.path, result.info->ref.value()})}};
+
+                        case DeleteAction::UselessBranch:
+                            if (!result.info || !result.info->ref) {
+                                throw std::runtime_error("Invalid useless branch info");
+                            }
+                            return handleUselessBranch(extensionNode, result.info.value());
+
+                        default:
+                            throw std::runtime_error("Invalid action");
+                    }
+                }
+
+                DeleteResult operator()(const BranchNode& branchNode) {
+                    DeleteAction action = DeleteAction::Unknown;
+                    std::optional info;
+                    std::byte idx = kNoBranch;
+                    auto branchCopy = branchNode;
+
+                    if (path_.empty() && branchNode.value().empty()) {
+                        throw std::runtime_error("Key not found");
+                    } else if (path_.empty() && !branchNode.value().empty()) {
+                        // The key ends at this branch: only its own value is dropped, so idx
+                        // keeps the kNoBranch sentinel and no child slot must be cleared.
+                        branchCopy.set_value({});
+                        action = DeleteAction::Deleted;
+                    } else {
+                        idx = path_.at(0);
+
+                        if (branchNode.get_branches()[std::to_integer(idx)].empty()) {
+                            throw std::runtime_error("Key not found");
+                        }
+
+                        Path remainingPath = path_;
+                        remainingPath.Consume(1);
+                        auto result = mpt_.delete_node(
+                            branchNode.get_branches()[std::to_integer(idx)],
+                            remainingPath);
+                        action = result.action;
+                        info = result.info;
+                    }
+
+                    return handleBranchDeleteResult(branchCopy, action, info, idx);
+                }
+
+              private:
+                // Marks "no child slot involved" for deletions that only drop a branch's value.
+                static constexpr std::byte kNoBranch{0xFF};
+
+                MerklePatriciaTrie& mpt_;
+                const Path& path_;
+
+                DeleteResult handleUselessBranch(const ExtensionNode& extensionNode,
+                                                 const DeletionInfo& info) {
+                    auto childNode = mpt_.getNode(info.ref.value());
+
+                    return std::visit(
+                        overloaded{
+                            [&](const LeafNode& leafChild) -> DeleteResult {
+                                auto newPath = extensionNode.path + leafChild.path;
+                                return {
+                                    DeleteAction::Updated,
+                                    DeletionInfo{
+                                        {}, mpt_.storeNode(LeafNode{newPath, leafChild.value()})}};
+                            },
+                            [&](const ExtensionNode& extensionChild) -> DeleteResult {
+                                auto newPath = extensionNode.path + extensionChild.path;
+                                return {DeleteAction::Updated,
+                                        DeletionInfo{{},
+                                                     mpt_.storeNode(ExtensionNode{
+                                                         newPath, extensionChild.get_next_ref()})}};
+                            },
+                            [&](const BranchNode& branchChild) -> DeleteResult {
+                                auto newPath = extensionNode.path + info.path;
+                                return {DeleteAction::Updated,
+                                        DeletionInfo{{},
+                                                     mpt_.storeNode(ExtensionNode{
+                                                         newPath, info.ref.value()})}};
+                            }},
+                        childNode);
+                }
+
+                DeleteResult handleBranchDeleteResult(const BranchNode& branchNode,
+                                                      DeleteAction action,
+                                                      const std::optional& info,
+                                                      std::byte idx) {
+                    switch (action) {
+                        case DeleteAction::Deleted:
+                            return handleBranchDeletion(branchNode, idx);
+
+                        case DeleteAction::Updated:
+                        case DeleteAction::UselessBranch:
+                            if (info && info->ref) {
+                                auto updatedBranch = branchNode;
+                                updatedBranch.SetBranch(idx, info->ref.value());
+                                return {DeleteAction::Updated,
+                                        DeletionInfo{{}, mpt_.storeNode(updatedBranch)}};
+                            }
+                            throw std::runtime_error("Invalid update info");
+
+                        default:
+                            throw std::runtime_error("Invalid action");
+                    }
+                }
+
+                DeleteResult handleBranchDeletion(const BranchNode& branchNode, std::byte idx) {
+                    auto branches = branchNode.get_branches();
+                    // NOTE(review): the slot at idx has not been cleared yet, so a child that was
+                    // just deleted is still counted in validBranches -- confirm this is intended.
+                    size_t validBranches =
+                        std::count_if(branches.begin(), branches.end(),
+                                      [](const Reference& ref) { return !ref.empty(); });
+
+                    if (validBranches == 0 && branchNode.value().empty()) {
+                        return {DeleteAction::Deleted, std::nullopt};
+                    } else if (validBranches == 0 && !branchNode.value().empty()) {
+                        Path newPath;  // Empty path
+                        return {DeleteAction::UselessBranch,
+                                DeletionInfo{newPath, mpt_.storeNode(
+                                                          LeafNode{newPath, branchNode.value()})}};
+                    } else if (validBranches == 1 && branchNode.value().empty()) {
+                        return buildNewNodeFromLastBranch(branches);
+                    } else {
+                        auto updatedBranch = branchNode;
+                        // When only the branch value was removed, idx is the sentinel and names
+                        // no slot; clearing with it would touch a branch that was never deleted.
+                        if (idx != kNoBranch) {
+                            updatedBranch.ClearBranch(idx);
+                        }
+                        return {DeleteAction::Updated,
+                                DeletionInfo{{}, mpt_.storeNode(updatedBranch)}};
+                    }
+                }
+
+                DeleteResult buildNewNodeFromLastBranch(
+                    const std::array& branches) {
+                    // Find the index of the only stored branch.
+                    auto it = std::find_if(branches.begin(), branches.end(),
+                                           [](const Reference& ref) { return !ref.empty(); });
+                    if (it == branches.end()) {
+                        throw std::runtime_error("No valid branches found");
+                    }
+
+                    uint8_t idx = std::distance(branches.begin(), it);
+
+                    // Path in leaf will contain one nibble (at this step).
+                    Path prefixNibble({std::byte{idx}}, 1);
+                    auto child = mpt_.getNode(*it);
+
+                    return std::visit(
+                        overloaded{
+                            [&](const LeafNode& leafChild) -> DeleteResult {
+                                auto path = prefixNibble + leafChild.path;
+                                return {DeleteAction::UselessBranch,
+                                        DeletionInfo{path, mpt_.storeNode(
+                                                               LeafNode{path, leafChild.value()})}};
+                            },
+                            [&](const ExtensionNode& extensionChild) -> DeleteResult {
+                                auto path = prefixNibble + extensionChild.path;
+                                return {
+                                    DeleteAction::UselessBranch,
+                                    DeletionInfo{path, mpt_.storeNode(ExtensionNode{
+                                                           path, extensionChild.get_next_ref()})}};
+                            },
+                            [&](const BranchNode&) -> DeleteResult {
+                                return {DeleteAction::UselessBranch,
+                                        DeletionInfo{prefixNibble, mpt_.storeNode(ExtensionNode{
+                                                                       prefixNibble, *it})}};
+                            }},
+                        child);
+                }
+            };
+        }  // namespace details
+
+        // Crypto3 compiles so slow... Use this dummy hash for testing.
+        // Hash we are going to use is still willing to change anyway
+        std::array basicHash(const std::vector& key) {
+            std::array result = {std::byte{0}};
+
+            for (size_t i = 0; i < key.size(); ++i) {
+                result[i % 64] ^= key[i];
+            }
+
+            for (size_t i = 0; i < 64; ++i) {
+                result[i] ^= result[(i + 1) % 64];
+                result[i] = (result[i] << 1) | (result[i] >> 7);
+            }
+
+            return result;
+        }
+
+        Path MerklePatriciaTrie::pathFromKey(const std::vector& key) const {
+            std::vector hashResult;
+            if (key.size() > kMaxRawKeyLen) {  // long keys are replaced by their hash
+                auto hashArray = basicHash(key);
+                hashResult = std::vector(hashArray.begin(), hashArray.end());
+            } else {
+                hashResult = key;
+            }
+
+            return Path(hashResult);
+        }
+
+        MerklePatriciaTrie::MerklePatriciaTrie(){};
+
+        std::vector MerklePatriciaTrie::get(const std::vector& key) const {
+            if (root_.empty()) {
+                throw std::runtime_error("Not initialized MPT");
+            }
+
+            auto path = pathFromKey(key);
+            auto result = get(root_, path);
+
+            if (result) {
+                return *result;
+            }
+            throw std::runtime_error("Key not found");
+        }
+
+        std::optional> MerklePatriciaTrie::get(const Reference& nodeRef,
+                                                Path& path) const {
+            auto node = getNode(nodeRef);
+            return std::visit(details::GetHandler(*this, path), node);
+        }
+
+        void MerklePatriciaTrie::set(const std::vector& key,
+                                     const std::vector& value) {
+            auto path = pathFromKey(key);
+            root_ = set(root_, path, value);
+        }
+
+        void MerklePatriciaTrie::remove(const std::vector& key) {
+            if (root_.empty()) {  // nothing stored: removal is a silent no-op
+                return;
+            }
+
+            auto path = pathFromKey(key);
+            auto result = delete_node(root_, path);  // throws "Key not found" for absent keys
+
+            switch (result.action) {
+                case details::DeleteAction::Deleted: {
+                    root_.clear();  // the root itself went away; result.info is empty here
+                    return;         // must not fall through and dereference the empty info
+                }
+                case details::DeleteAction::Updated:
+                case details::DeleteAction::UselessBranch: {  // both carry the replacement root
+                    root_ = result.info.value().ref.value();  // throws rather than UB if unset
+                    return;
+                }
+                default: {
+                    throw std::runtime_error("remove error");
+                }
+            }
+        }
+
+        details::DeleteResult MerklePatriciaTrie::delete_node(const Reference& nodeRef,
+                                                              const Path& path) {
+            auto node = getNode(nodeRef);
+            return std::visit(details::DeleteHandler(*this, path), node);
+        }
+
+        Node MerklePatriciaTrie::getNode(const Reference& ref) const {
+            if (ref.size() < 32) {  // short references are the encoded node itself (see storeNode)
+                return DecodeNode(ref);
+            }
+            auto it = nodes_.find(ref);
+            if (it == nodes_.end()) {
+                throw std::runtime_error("Node not found");
+            }
+
+            return DecodeNode(it->second);
+        }
+
+        Reference MerklePatriciaTrie::storeNode(const Node& node) {
+            Bytes encoded = std::visit([](const auto& n) { return n.Encode(); }, node);
+            if (encoded.size() < 32) {  // small nodes are inlined instead of being hashed
+                return encoded;
+            }
+            auto keyArr = basicHash(encoded);
+            Bytes key(keyArr.begin(), keyArr.end());
+            nodes_[key] = encoded;
+            return key;
+        }
+
+        Reference MerklePatriciaTrie::set(const Reference& nodeRef, Path& path,
+                                          const std::vector& value) {
+            if (nodeRef.empty()) {  // empty slot: the remaining path becomes a new leaf
+                return storeNode(LeafNode{path, value});
+            }
+
+            auto node = getNode(nodeRef);
+            return std::visit(details::SetHandler(*this, path, value), node);
+        }
+
+    }  // namespace mpt
+}  // namespace core
diff --git a/libs/nil_core/src/mpt/node.cpp b/libs/nil_core/src/mpt/node.cpp
new file mode 100644
index 0000000..05fbe38
--- /dev/null
+++ b/libs/nil_core/src/mpt/node.cpp
@@ -0,0 +1,97 @@
+#include "zkevm_framework/core/mpt/node.hpp"
+
+#include
+#include
+
+namespace core {
+    namespace mpt {
+
+        PathHolder::PathHolder(Path path) : path(std::move(path)) {}
+
+        ValueHolder::ValueHolder(const std::vector& value) { set_value(value); }
+
+        const std::vector& ValueHolder::value() const { return value_.data(); }
+
+        void ValueHolder::set_value(const std::vector& new_value) { value_ = new_value; }
+
+        LeafNode::LeafNode(const Path& path, const std::vector& new_value)
+            : PathHolder(path), ValueHolder(new_value) {}
+
+        Bytes LeafNode::Encode() const {
+            auto serialized = std::vector(ssz_size() + 1);  // one extra byte for the type flag
+            serialized.front() = static_cast(NodeTypeFlag::kLeafNode);
+            ssz::serialize(serialized.begin() + 1, *this);
+            return serialized;
+        }
+
+        ExtensionNode::ExtensionNode(const Path& path, const Reference& next)
+            : PathHolder(path), next_ref_(next) {}
+
+        Bytes ExtensionNode::Encode() const {
+            auto serialized = std::vector(ssz_size() + 1);
+            serialized.front() = static_cast(NodeTypeFlag::kExtensionNode);
+            ssz::serialize(serialized.begin() + 1, *this);
+            return serialized;
+        }
+
+        const Reference& ExtensionNode::get_next_ref() const { return next_ref_.data(); }
+
+        BranchNode::BranchNode(const std::array& refs,
+                               const std::vector& value)
+            : ValueHolder(value) {
+            for (const auto& branch : refs) {
+                branches_.push_back(branch);
+            }
+        }
+
+        Bytes BranchNode::Encode() const {
+            auto serialized = std::vector(ssz_size() + 1);
+            serialized.front() = static_cast(NodeTypeFlag::kBranchNode);
+            ssz::serialize(serialized.begin() + 1, *this);
+            return serialized;
+        }
+
+        std::array BranchNode::get_branches() const {
+            std::array arr;
+            const auto& branches_vec = branches_.data();
+            for (std::size_t i = 0; i < branches_.size(); ++i) {
+                arr[i] = branches_vec[i].data();
+            }
+            return arr;
+        }
+
+        void BranchNode::ClearBranch(std::byte nibble) {
+            branches_[std::to_integer(nibble)].data().clear();
+        }
+
+        void BranchNode::SetBranch(std::byte nibble, const std::vector& value) {
+            branches_[std::to_integer(nibble)] = value;
+        }
+
+        template
+        Node DeserializeNode(const std::span& bytes) {
+            return ssz::deserialize(bytes);
+        }
+
+        Node DecodeNode(const Bytes& bytes) {
+            if (bytes.empty()) {
+                throw std::runtime_error("Empty byte array");
+            }
+
+            const auto type_flag = static_cast(bytes.front());  // first byte selects the node type
+            std::span data_span(bytes.data() + 1, bytes.size() - 1);
+
+            switch (type_flag) {
+                case NodeTypeFlag::kLeafNode:
+                    return DeserializeNode(data_span);
+                case NodeTypeFlag::kExtensionNode:
+                    return DeserializeNode(data_span);
+                case NodeTypeFlag::kBranchNode:
+                    return DeserializeNode(data_span);
+                default:
+                    throw std::runtime_error("Unknown node type");
+            }
+        }
+
+    }  // namespace mpt
+}  // namespace core
diff --git a/libs/nil_core/src/mpt/path.cpp b/libs/nil_core/src/mpt/path.cpp
new file mode
100644 index 0000000..00f3e89 --- /dev/null +++ b/libs/nil_core/src/mpt/path.cpp @@ -0,0 +1,116 @@ +#include "zkevm_framework/core/mpt/path.hpp" + +#include +#include + +namespace core { + namespace mpt { + + Path::Path() = default; + + Path::Path(const std::vector& data, std::size_t offset) + : data_(data), offset_(offset) {} + + int Path::size() const { return data_.size() * 2 - offset_; } + + bool Path::empty() const { return size() == 0; } + + std::byte Path::operator[](size_t idx) const { + idx += offset_; + auto target_byte = data_[idx / 2]; + return (idx % 2 == 0) ? (target_byte >> 4) : (target_byte & std::byte{0x0F}); + } + + bool Path::operator==(const Path& other) const { + if (other.size() != size()) { + return false; + } + for (int i = 0; i < size(); ++i) { + if (operator[](i) != other[i]) { + return false; + } + } + return true; + } + + std::byte Path::at(std::size_t idx) const { + if (idx >= size()) { + throw std::out_of_range("Index out of range"); + } + return operator[](idx); + } + + bool Path::StartsWith(const Path& other) const { + if (other.size() > size()) { + return false; + } + for (int i = 0; i < other.size(); ++i) { + if (operator[](i) != other[i]) { + return false; + } + } + return true; + } + + Path* Path::Consume(std::size_t amount) { + offset_ += amount; + return this; + } + + Path Path::CommonPrefix(const Path& other) const { + int least_len = std::min(size(), other.size()); + int common_len = 0; + for (int i = 0; i < least_len; ++i) { + if (operator[](i) != other[i]) { + break; + } + common_len += 1; + } + return ConstructFromPrefix(common_len); + } + + // Path Path::Combine(const Path& other) const { return *this + other; } + + Path Path::operator+(const Path& other) const { + std::size_t final_size = size() + other.size(); + bool final_is_odd = (final_size % 2) == 1; + std::vector new_data; + new_data.reserve((final_size + 1) / 2); // Round up to nearest byte + + auto at_combined = [this, &other](std::size_t n) -> std::byte { + if 
(n < this->size()) { + return operator[](n); + } + return other[n - this->size()]; + }; + + if (final_is_odd) { + new_data.push_back(at_combined(0) & std::byte{0x0F}); + } + + for (std::size_t i = final_is_odd ? 1 : 0; i < final_size; i += 2) { + std::byte next_byte = + (at_combined(i) << 4) | (at_combined(i + 1) & std::byte{0x0F}); + new_data.push_back(next_byte); + } + + return Path(new_data, final_is_odd ? 1 : 0); + } + + Path Path::ConstructFromPrefix(std::size_t length) const { + std::vector new_data; + bool is_odd_len = length % 2 == 1; + std::size_t pos = 0; + if (is_odd_len) { + new_data.push_back(operator[](pos)); + pos += 1; + } + for (; pos < length; pos += 2) { + new_data.push_back((operator[](pos) << 4) | operator[](pos + 1)); + } + std::size_t new_offset = is_odd_len ? 1 : 0; + return Path(new_data, new_offset); + } + + } // namespace mpt +} // namespace core diff --git a/tests/libs/nil_core/CMakeLists.txt b/tests/libs/nil_core/CMakeLists.txt index 4f9355a..5a1ad6e 100644 --- a/tests/libs/nil_core/CMakeLists.txt +++ b/tests/libs/nil_core/CMakeLists.txt @@ -15,3 +15,4 @@ function(add_nil_core_test target) endfunction() add_nil_core_test(test_nil_core_ssz) +add_nil_core_test(test_nil_core_mpt) diff --git a/tests/libs/nil_core/test_nil_core_mpt.cpp b/tests/libs/nil_core/test_nil_core_mpt.cpp new file mode 100644 index 0000000..da52d17 --- /dev/null +++ b/tests/libs/nil_core/test_nil_core_mpt.cpp @@ -0,0 +1,122 @@ +#include +#include +#include + +#include "gtest/gtest.h" +#include "ssz++.hpp" +#include "zkevm_framework/core/mpt/mpt.hpp" + +using namespace core; +using namespace core::mpt; + +std::vector stringToByteVector(const std::string& str) { + std::vector result; + result.reserve(str.size()); + for (char c : str) { + result.push_back(static_cast(c)); + } + return result; +} + +TEST(NilCoreMerklePatriciaTrieTest, PathSerialization) { + std::string current; + for (std::size_t i = 0; i < 100; ++i) { + auto pathBytes = stringToByteVector(current); + 
Path path{pathBytes}; + + const auto serialized = ssz::serialize(path); + const auto deserialized = ssz::deserialize(serialized); + + ASSERT_EQ(path, deserialized); + + current += 'a' + (i % 26); + } +} + +TEST(NilCoreMerklePatriciaTrieTest, NodeEncodeDecode) { + // TODO: add hardcoded input for deserialization to check binary compatability with cluster + std::string current; + for (std::size_t i = 0; i < 10; ++i) { + auto currentBytes = stringToByteVector(current); + // LeafNode + Path path{currentBytes}; + std::vector value = currentBytes; + LeafNode leafNode{path, value}; + + auto serialized = leafNode.Encode(); + Node node = DecodeNode(serialized); + ASSERT_TRUE(std::holds_alternative(node)); + const auto& leaf = std::get(node); + + ASSERT_EQ(leaf.value(), value); + ASSERT_EQ(leaf.path, path); + + // ExtensionNode + ExtensionNode extensionNode{path, value}; + + serialized = extensionNode.Encode(); + node = DecodeNode(serialized); + ASSERT_TRUE(std::holds_alternative(node)); + const auto& extension = std::get(node); + + ASSERT_EQ(extension.path, path); + ASSERT_EQ(extension.get_next_ref(), value); + + // BranchNode + std::array branches; + for (std::size_t i = 0; i < kBranchesNum; ++i) { + branches[i] = value; + } + BranchNode branchNode(branches, value); + + serialized = branchNode.Encode(); + node = DecodeNode(serialized); + ASSERT_TRUE(std::holds_alternative(node)); + const auto& branch = std::get(node); + + ASSERT_EQ(branch.value(), value); + ASSERT_EQ(branch.get_branches(), branches); + current += 'a' + (i % 26); + } +} + +TEST(NilCoreMerklePatriciaTrieTest, InsertGetMany) { + MerklePatriciaTrie trie; + + std::vector> cases = { + {"do", "verb"}, {"dog", "puppy"}, {"doge", "coin"}, {"horse", "stallion"}}; + + for (const auto& [key, value] : cases) { + ASSERT_NO_THROW(trie.set(stringToByteVector(key), stringToByteVector(value))); + } + + for (const auto& [key, expectedValue] : cases) { + std::vector result; + result = trie.get(stringToByteVector(key)); + 
ASSERT_NO_THROW(result = trie.get(stringToByteVector(key))); + ASSERT_EQ(result, stringToByteVector(expectedValue)); + } +} + +TEST(NilCoreMerklePatriciaTrieTest, TestDelete) { + Path p; + MerklePatriciaTrie trie; + + std::vector> cases = { + {"do", "verb"}, {"dog", "puppy"}, {"doge", "coin"}, {"horse", "stallion"}}; + + for (const auto& [key, value] : cases) { + ASSERT_NO_THROW(trie.set(stringToByteVector(key), stringToByteVector(value))); + } + + ASSERT_NO_THROW(trie.remove(stringToByteVector("do"))); + ASSERT_ANY_THROW(trie.remove(stringToByteVector("do"))); // Can't remove twice + ASSERT_ANY_THROW(trie.remove(stringToByteVector("d"))); // Can't remove absent + ASSERT_NO_THROW(trie.remove(stringToByteVector("doge"))); + + ASSERT_ANY_THROW(trie.get(stringToByteVector("do"))); // Can't access removed + ASSERT_ANY_THROW(trie.get(stringToByteVector("doge"))); // Can't access removed + + ASSERT_NO_THROW(trie.get(stringToByteVector("dog"))); // Can access existing + ASSERT_NO_THROW(trie.get(stringToByteVector("horse"))); // Can access existing +}