Skip to content

Commit

Permalink
feat: unbatch BatchMerkleProof into individual Proof
Browse files Browse the repository at this point in the history
fix: format + naming

fix: optimize by using a BTreeMap

fix: signature into_paths

fix: formatting

feat: Merkle proof unbatching

feat: Merkle proof unbatching 2

feat: Merkle proof unbatching 2
  • Loading branch information
Al-Kindi-0 committed Sep 3, 2022
1 parent 01baed2 commit 0b76b68
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 3 deletions.
2 changes: 1 addition & 1 deletion crypto/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl fmt::Display for MerkleTreeError {
Self::TooManyLeafIndexes(max_indexes, num_indexes) => {
write!(
f,
"number of leaf indexes cannot exceed {}, but was {} provided",
"number of leaf indexes cannot exceed {}, but {} was provided",
max_indexes, num_indexes
)
}
Expand Down
4 changes: 2 additions & 2 deletions crypto/src/hash/sha/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl<B: StarkField> Hasher for Sha3_256<B> {
let mut data = [0; 40];
data[..32].copy_from_slice(&seed.0);
data[32..].copy_from_slice(&value.to_le_bytes());
ByteDigest(sha3::Sha3_256::digest(&data).into())
ByteDigest(sha3::Sha3_256::digest(data).into())
}
}

Expand Down Expand Up @@ -72,7 +72,7 @@ impl ShaHasher {

impl ByteWriter for ShaHasher {
fn write_u8(&mut self, value: u8) {
self.0.update(&[value]);
self.0.update([value]);
}

fn write_u8_slice(&mut self, values: &[u8]) {
Expand Down
179 changes: 179 additions & 0 deletions crypto/src/merkle/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,157 @@ impl<H: Hasher> BatchMerkleProof<H> {
v.remove(&1).ok_or(MerkleTreeError::InvalidProof)
}

/// Computes the uncompressed Merkle paths which aggregate to this proof.
///
/// # Errors
/// Returns an error if:
/// * No indexes were provided (i.e., `indexes` is an empty slice).
/// * Number of provided indexes is greater than 255.
/// * Number of provided indexes does not match the number of leaf nodes in the proof.
pub fn into_paths(self, indexes: &[usize]) -> Result<Vec<Vec<H::Digest>>, MerkleTreeError> {
if indexes.is_empty() {
return Err(MerkleTreeError::TooFewLeafIndexes);
}
if indexes.len() > MAX_PATHS {
return Err(MerkleTreeError::TooManyLeafIndexes(
MAX_PATHS,
indexes.len(),
));
}
if indexes.len() != self.leaves.len() {
return Err(MerkleTreeError::InvalidProof);
}

let mut partial_tree_map = BTreeMap::new();

for (&i, leaf) in indexes.iter().zip(self.leaves.iter()) {
partial_tree_map.insert(i + (1 << (self.depth)), *leaf);
}

let mut buf = [H::Digest::default(); 2];
let mut v = BTreeMap::new();

// replace odd indexes, offset, and sort in ascending order
let original_indexes = indexes;
let index_map = super::map_indexes(indexes, self.depth as usize)?;
let indexes = super::normalize_indexes(indexes);
if indexes.len() != self.nodes.len() {
return Err(MerkleTreeError::InvalidProof);
}

// for each index use values to compute parent nodes
let offset = 2usize.pow(self.depth as u32);
let mut next_indexes: Vec<usize> = Vec::new();
let mut proof_pointers: Vec<usize> = Vec::with_capacity(indexes.len());
for (i, index) in indexes.into_iter().enumerate() {
// copy values of leaf sibling leaf nodes into the buffer
match index_map.get(&index) {
Some(&index1) => {
if self.leaves.len() <= index1 {
return Err(MerkleTreeError::InvalidProof);
}
buf[0] = self.leaves[index1];
match index_map.get(&(index + 1)) {
Some(&index2) => {
if self.leaves.len() <= index2 {
return Err(MerkleTreeError::InvalidProof);
}
buf[1] = self.leaves[index2];
proof_pointers.push(0);
}
None => {
if self.nodes[i].is_empty() {
return Err(MerkleTreeError::InvalidProof);
}
buf[1] = self.nodes[i][0];
proof_pointers.push(1);
}
}
}
None => {
if self.nodes[i].is_empty() {
return Err(MerkleTreeError::InvalidProof);
}
buf[0] = self.nodes[i][0];
match index_map.get(&(index + 1)) {
Some(&index2) => {
if self.leaves.len() <= index2 {
return Err(MerkleTreeError::InvalidProof);
}
buf[1] = self.leaves[index2];
}
None => return Err(MerkleTreeError::InvalidProof),
}
proof_pointers.push(1);
}
}

// hash sibling nodes into their parent and add it to partial_tree
let parent = H::merge(&buf);
partial_tree_map.insert(offset + index, buf[0]);
partial_tree_map.insert((offset + index) ^ 1, buf[1]);
let parent_index = (offset + index) >> 1;
v.insert(parent_index, parent);
next_indexes.push(parent_index);
partial_tree_map.insert(parent_index, parent);
}

// iteratively move up, until we get to the root
for _ in 1..self.depth {
let indexes = next_indexes.clone();
next_indexes.clear();

let mut i = 0;
while i < indexes.len() {
let node_index = indexes[i];
let sibling_index = node_index ^ 1;

// determine the sibling
let sibling = if i + 1 < indexes.len() && indexes[i + 1] == sibling_index {
i += 1;
match v.get(&sibling_index) {
Some(sibling) => *sibling,
None => return Err(MerkleTreeError::InvalidProof),
}
} else {
let pointer = proof_pointers[i];
if self.nodes[i].len() <= pointer {
return Err(MerkleTreeError::InvalidProof);
}
proof_pointers[i] += 1;
self.nodes[i][pointer]
};

// get the node from the map of hashed nodes
let node = match v.get(&node_index) {
Some(node) => node,
None => return Err(MerkleTreeError::InvalidProof),
};

// compute parent node from node and sibling
partial_tree_map.insert(node_index ^ 1, sibling);
let parent = if node_index & 1 != 0 {
H::merge(&[sibling, *node])
} else {
H::merge(&[*node, sibling])
};

// add the parent node to the next set of nodes and partial_tree
let parent_index = node_index >> 1;
v.insert(parent_index, parent);
next_indexes.push(parent_index);
partial_tree_map.insert(parent_index, parent);

i += 1;
}
}

original_indexes
.iter()
.map(|&i| get_path::<H>(i, &partial_tree_map, self.depth as usize))
.collect()
}

// SERIALIZATION / DESERIALIZATION
// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -343,3 +494,31 @@ impl<H: Hasher> BatchMerkleProof<H> {
fn are_siblings(left: usize, right: usize) -> bool {
left & 1 == 0 && right - 1 == left
}

/// Computes the Merkle path from the computed (partial) tree.
pub fn get_path<H: Hasher>(
index: usize,
tree: &BTreeMap<usize, <H as Hasher>::Digest>,
depth: usize,
) -> Result<Vec<H::Digest>, MerkleTreeError> {
let mut index = index + (1 << depth);
let leaf = if let Some(leaf) = tree.get(&index) {
*leaf
} else {
return Err(MerkleTreeError::InvalidProof);
};

let mut proof = vec![leaf];
while index > 1 {
let leaf = if let Some(leaf) = tree.get(&(index ^ 1)) {
*leaf
} else {
return Err(MerkleTreeError::InvalidProof);
};

proof.push(leaf);
index >>= 1;
}

Ok(proof)
}
42 changes: 42 additions & 0 deletions crypto/src/merkle/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,30 @@ fn verify_batch() {
assert!(MerkleTree::verify_batch(tree.root(), &[0, 1, 2, 3, 4, 5, 6, 7], &proof).is_ok());
}

#[test]
fn verify_into_paths() {
let leaves = Digest256::bytes_as_digests(&LEAVES8).to_vec();
let tree = MerkleTree::<Blake3_256>::new(leaves).unwrap();

let proof1 = tree.prove(1).unwrap();
let proof2 = tree.prove(2).unwrap();
let proof1_2 = tree.prove_batch(&[1, 2]).unwrap();
let result = proof1_2.into_paths(&[1, 2]).unwrap();

assert_eq!(proof1, result[0]);
assert_eq!(proof2, result[1]);

let proof3 = tree.prove(3).unwrap();
let proof4 = tree.prove(4).unwrap();
let proof6 = tree.prove(5).unwrap();
let proof3_4_6 = tree.prove_batch(&[3, 4, 5]).unwrap();
let result = proof3_4_6.into_paths(&[3, 4, 5]).unwrap();

assert_eq!(proof3, result[0]);
assert_eq!(proof4, result[1]);
assert_eq!(proof6, result[2]);
}

proptest! {
#[test]
fn prove_n_verify(tree in random_blake3_merkle_tree(128),
Expand Down Expand Up @@ -269,6 +293,24 @@ proptest! {

prop_assert!(proof1 == proof2);
}

#[test]
fn into_paths(tree in random_blake3_merkle_tree(32),
proof_indices in prop::collection::vec(any::<prop::sample::Index>(), 1..30)
) {
let mut indices: Vec<usize> = proof_indices.iter().map(|idx| idx.index(32)).collect();
indices.sort_unstable(); indices.dedup();
let proof1 = tree.prove_batch(&indices[..]).unwrap();

let mut paths_expected = Vec::new();
for &idx in indices.iter() {
paths_expected.push(tree.prove(idx).unwrap());
}

let paths = proof1.into_paths(&indices);

prop_assert!(paths_expected == paths.unwrap());
}
}

// HELPER FUNCTIONS
Expand Down

0 comments on commit 0b76b68

Please sign in to comment.