Skip to content

Commit

Permalink
chore: addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
krushimir committed Jan 22, 2025
1 parent f42d597 commit a76506f
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 77 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 0.13.1 (2024-12-26)

- Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355).
- Added parallel implementation of `Smt::compute_mutations` with better performance (#365).

## 0.13.0 (2024-11-24)

Expand Down
8 changes: 0 additions & 8 deletions src/merkle/smt/full/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,6 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}

/// Sequential implementation of [`Smt::compute_mutations()`].
pub fn compute_mutations_sequential(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations_sequential(self, kv_pairs)
}

/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to
/// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
/// updated tree will revert the changes.
Expand Down
157 changes: 90 additions & 67 deletions src/merkle/smt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type UnorderedMap<K, V> = alloc::collections::BTreeMap<K, V>;
type InnerNodes = UnorderedMap<NodeIndex, InnerNode>;
type Leaves<T> = UnorderedMap<u64, T>;
type NodeMutations = UnorderedMap<NodeIndex, NodeMutation>;
type MutatedLeavesResult<K, V> = (Vec<Vec<SubtreeLeaf>>, UnorderedMap<K, V>);
type MutatedSubtreeLeaves = Vec<Vec<SubtreeLeaf>>;

/// An abstract description of a sparse Merkle tree.
///
Expand Down Expand Up @@ -185,7 +185,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
{
#[cfg(feature = "concurrent")]
{
self.compute_mutations_subtree(kv_pairs)
self.compute_mutations_concurrent(kv_pairs)
}
#[cfg(not(feature = "concurrent"))]
{
Expand Down Expand Up @@ -287,13 +287,21 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {

/// Parallel implementation of [`SparseMerkleTree::compute_mutations()`].
///
/// This method recursively tracks mutations across 8-depth subtrees from the bottom up,
/// ultimately reconstructing the complete mutation set for the entire tree.
/// This method computes mutations by recursively processing subtrees in parallel, working from
/// the bottom up. For a tree of depth D with subtrees of depth 8, the process works as
/// follows:
///
/// The implementation is similar to [`SparseMerkleTree::build_subtrees_from_sorted_entries()`],
/// sharing the same constraint that the depth must be a multiple of 8.
/// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf
/// indices. Each subtree covers a range of 256 (2^8) possible leaf positions.
///
/// 2. The subtrees containing modifications are then processed in parallel:
/// - For each modified subtree, compute node mutations from depth D up to depth D-8
/// - Each subtree computation yields a new root at depth D-8 and its associated mutations
///
/// 3. These subtree roots become the "leaves" for the next iteration, which processes the next
/// 8 levels up. This continues until reaching the tree's root at depth 0.
#[cfg(feature = "concurrent")]
fn compute_mutations_subtree(
fn compute_mutations_concurrent(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value>
Expand All @@ -307,7 +315,8 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
sorted_kv_pairs.sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value());

// Convert sorted pairs into mutated leaves and capture any new pairs
let (mut subtree_leaves, new_pairs) = self.sorted_pairs_to_mutated_leaves(sorted_kv_pairs);
let (mut subtree_leaves, new_pairs) =
self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs);
let mut node_mutations = NodeMutations::default();

// Process each depth level in reverse, stepping by the subtree depth
Expand Down Expand Up @@ -521,63 +530,38 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
) -> Self::Leaf;

/// Computes leaves from a set of key-value pairs and current leaf values.
/// Deried from `sorted_pairs_to_leaves`
///
/// TODO: refactor and merge functionality with `sorted_pairs_to_leaves`?
fn sorted_pairs_to_mutated_leaves(
/// Derived from `sorted_pairs_to_leaves`
fn sorted_pairs_to_mutated_subtree_leaves(
&self,
pairs: Vec<(Self::Key, Self::Value)>,
) -> MutatedLeavesResult<Self::Key, Self::Value> {
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));

let mut accumulated_leaves = Vec::with_capacity(pairs.len() / 2);
) -> (MutatedSubtreeLeaves, UnorderedMap<Self::Key, Self::Value>) {
// Map to track new key-value pairs for mutated leaves
let mut new_pairs = UnorderedMap::new();
let mut current_leaf_buffer = Vec::new();

let mut iter = pairs.into_iter().peekable();
while let Some((key, value)) = iter.next() {
let col = Self::key_to_leaf_index(&key).index.value();

if let Some((next_key, _)) = iter.peek() {
let next_col = Self::key_to_leaf_index(next_key).index.value();
debug_assert!(next_col >= col);
}

current_leaf_buffer.push((key.clone(), value));

// If the next pair is the same column, continue accumulating
if iter
.peek()
.is_some_and(|(next_key, _)| Self::key_to_leaf_index(next_key).index.value() == col)
{
continue;
}

// Process buffered pairs
let leaf_pairs = mem::take(&mut current_leaf_buffer);
let mut leaf = self.get_leaf(&key);
let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| {
let mut leaf = self.get_leaf(&leaf_pairs[0].0);

for (key, value) in leaf_pairs {
match new_pairs.get(&key) {
Some(existing_value) if existing_value == &value => continue,
_ => {
leaf = self.construct_prospective_leaf(leaf, &key, &value);
new_pairs.insert(key, value);
},
}
}
// Check if the value has changed
let old_value =
new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));

let hash = Self::hash_leaf(&leaf);
accumulated_leaves.push(SubtreeLeaf { col, hash });
// Skip if the value hasn't changed
if value == old_value {
continue;
}

debug_assert!(current_leaf_buffer.is_empty());
}
// Otherwise, update the leaf and track the new key-value pair
leaf = self.construct_prospective_leaf(leaf, &key, &value);
new_pairs.insert(key, value);
}

let leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
(leaves, new_pairs)
leaf
});
(accumulator.leaves, new_pairs)
}

// Computes the node mutations and the root of a subtree
/// Computes the node mutations and the root of a subtree
fn build_subtree_mutations(
&self,
mut leaves: Vec<SubtreeLeaf>,
Expand All @@ -602,12 +586,11 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
let mut iter = leaves.drain(..).peekable();

while let Some(first_leaf) = iter.next() {
/// This constructs a valid index because next_depth will never exceed the depth of the tree.
// This constructs a valid index because next_depth will never exceed the depth of
// the tree.
let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent();
let parent_node = self.get_inner_node(parent_index);
let (left, right) = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node);

let combined_node = InnerNode { left: left.hash, right: right.hash };
let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node);
let combined_hash = combined_node.hash();

let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth);
Expand Down Expand Up @@ -636,31 +619,38 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
(node_mutations, root_leaf)
}

// Returns the sibling pair based on the first leaf and the current depth
//
// This is a helper function that is used to build the subtree mutations
// The first leaf is the leaf that we are currently processing
// The current depth is the depth of the current subtree
/// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part:
/// - If `first_leaf` is a right child, the left child is copied from the `parent_node`.
/// - If `first_leaf` is a left child, the right child is taken from `iter` if it also mutated
/// or copied from the `parent_node`.
///
/// Returns the `InnerNode` containing the hashes of the sibling pair.
fn fetch_sibling_pair(
iter: &mut core::iter::Peekable<alloc::vec::Drain<SubtreeLeaf>>,
first_leaf: SubtreeLeaf,
parent_node: InnerNode,
) -> (SubtreeLeaf, SubtreeLeaf) {
) -> InnerNode {
let is_right_node = first_leaf.col.is_odd();

if is_right_node {
let left_leaf = SubtreeLeaf {
col: first_leaf.col - 1,
hash: parent_node.left,
};
(left_leaf, first_leaf)
InnerNode {
left: left_leaf.hash,
right: first_leaf.hash,
}
} else {
let right_col = first_leaf.col + 1;
let right_leaf = match iter.peek().copied() {
Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(),
_ => SubtreeLeaf { col: right_col, hash: parent_node.right },
};
(first_leaf, right_leaf)
InnerNode {
left: first_leaf.hash,
right: right_leaf.hash,
}
}
}

Expand Down Expand Up @@ -689,6 +679,39 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
fn sorted_pairs_to_leaves(
pairs: Vec<(Self::Key, Self::Value)>,
) -> PairComputations<u64, Self::Leaf> {
Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| Self::pairs_to_leaf(leaf_pairs))
}

/// Processes sorted key-value pairs to compute leaves for a subtree.
///
/// This function groups key-value pairs by their corresponding column index and processes each
/// group to construct leaves. The actual construction of the leaf is delegated to the
/// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating
/// new leaves or mutating existing ones).
///
/// # Parameters
/// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index
/// column (not simply by key). If the input is not sorted correctly, the function will
/// produce incorrect results and may panic in debug mode.
/// - `process_leaf`: A callback function used to process each group of key-value pairs
/// corresponding to the same column index. The callback takes a vector of key-value pairs for
/// a single column and returns the constructed leaf for that column.
///
/// # Returns
/// A `PairComputations<u64, Self::Leaf>` containing:
/// - `nodes`: A mapping of column indices to the constructed leaves.
/// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each
/// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf.
///
/// # Panics
/// This function will panic in debug mode if the input `pairs` are not sorted by column index.
fn process_sorted_pairs_to_leaves<F>(
pairs: Vec<(Self::Key, Self::Value)>,
mut process_leaf: F,
) -> PairComputations<u64, Self::Leaf>
where
F: FnMut(Vec<(Self::Key, Self::Value)>) -> Self::Leaf,
{
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));

let mut accumulator: PairComputations<u64, Self::Leaf> = Default::default();
Expand Down Expand Up @@ -720,7 +743,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
// Otherwise, the next pair is a different column, or there is no next pair. Either way
// it's time to swap out our buffer.
let leaf_pairs = mem::take(&mut current_leaf_buffer);
let leaf = Self::pairs_to_leaf(leaf_pairs);
let leaf = process_leaf(leaf_pairs);
let hash = Self::hash_leaf(&leaf);

accumulator.nodes.insert(col, leaf);
Expand Down
18 changes: 16 additions & 2 deletions src/merkle/smt/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use alloc::{collections::BTreeMap, vec::Vec};
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};

use rand::{prelude::IteratorRandom, thread_rng, Rng};

Expand Down Expand Up @@ -115,6 +118,17 @@ fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(Rpo
const REMOVAL_PROBABILITY: f64 = 0.2;
let mut rng = thread_rng();

// Assertion to ensure input keys are unique
assert!(
entries
.iter()
.map(|(key, _)| key)
.collect::<BTreeSet<_>>()
.len()
== entries.len(),
"Input entries contain duplicate keys!"
);

let mut sorted_entries: Vec<(RpoDigest, Word)> = entries
.into_iter()
.choose_multiple(&mut rng, updates)
Expand Down Expand Up @@ -452,7 +466,7 @@ fn test_singlethreaded_subtree_mutations() {

let mut node_mutations = NodeMutations::default();

let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_leaves(updates);
let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates);

for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// There's no flat_map_unzip(), so this is the best we can do.
Expand Down

0 comments on commit a76506f

Please sign in to comment.