diff --git a/Cargo.lock b/Cargo.lock index ed4ffde9304d..5f20451f15f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9669,6 +9669,7 @@ dependencies = [ "proptest", "proptest-arbitrary-interop", "rand 0.8.5", + "rayon", "reth-execution-errors", "reth-primitives-traits", "reth-testing-utils", diff --git a/crates/trie/sparse/Cargo.toml b/crates/trie/sparse/Cargo.toml index 205451ef72a8..11cea8be5b56 100644 --- a/crates/trie/sparse/Cargo.toml +++ b/crates/trie/sparse/Cargo.toml @@ -23,6 +23,7 @@ alloy-primitives.workspace = true alloy-rlp.workspace = true # misc +rayon.workspace = true smallvec = { workspace = true, features = ["const_new"] } thiserror.workspace = true diff --git a/crates/trie/sparse/src/blinded.rs b/crates/trie/sparse/src/blinded.rs index a9f0e89c29c1..43db1267bc0d 100644 --- a/crates/trie/sparse/src/blinded.rs +++ b/crates/trie/sparse/src/blinded.rs @@ -7,9 +7,9 @@ use reth_trie_common::Nibbles; /// Factory for instantiating blinded node providers. pub trait BlindedProviderFactory { /// Type capable of fetching blinded account nodes. - type AccountNodeProvider: BlindedProvider; + type AccountNodeProvider: BlindedProvider + Send + Sync; /// Type capable of fetching blinded storage nodes. - type StorageNodeProvider: BlindedProvider; + type StorageNodeProvider: BlindedProvider + Send + Sync; /// Returns blinded account node provider. fn account_node_provider(&self) -> Self::AccountNodeProvider; diff --git a/crates/trie/sparse/src/state.rs b/crates/trie/sparse/src/state.rs index c6c741e002be..1526e8ef073b 100644 --- a/crates/trie/sparse/src/state.rs +++ b/crates/trie/sparse/src/state.rs @@ -301,11 +301,6 @@ impl SparseStateTrie { Ok(()) } - /// Calculates the hashes of the nodes below the provided level. - pub fn calculate_below_level(&mut self, level: usize) { - self.state.calculate_below_level(level); - } - /// Returns storage sparse trie root if the trie has been revealed. pub fn storage_root(&mut self, account: B256) -> Option { self.storages.get_mut(&account).and_then(|trie| trie.root()) @@ -346,7 +341,7 @@ impl SparseStateTrie { } impl SparseStateTrie where - F: BlindedProviderFactory, + F: BlindedProviderFactory + Send + Sync, SparseTrieError: From<::Error> + From<::Error>, { @@ -423,6 +418,11 @@ where storage_trie.remove_leaf(slot)?; Ok(()) } + + /// Calculates the hashes of the nodes below the provided level. + pub fn calculate_below_level(&mut self, level: usize) { + self.state.calculate_below_level(level); + } } #[cfg(test)] diff --git a/crates/trie/sparse/src/trie.rs b/crates/trie/sparse/src/trie.rs index 3f2d6f58bd2e..b4d1062c0a0d 100644 --- a/crates/trie/sparse/src/trie.rs +++ b/crates/trie/sparse/src/trie.rs @@ -1,10 +1,11 @@ use crate::blinded::{BlindedProvider, DefaultBlindedProvider}; use alloy_primitives::{ - hex, keccak256, + keccak256, map::{Entry, HashMap, HashSet}, B256, }; use alloy_rlp::Decodable; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind, SparseTrieResult}; use reth_tracing::tracing::trace; use reth_trie_common::{ @@ -13,7 +14,7 @@ use reth_trie_common::{ TrieNode, CHILD_INDEX_RANGE, EMPTY_ROOT_HASH, }; use smallvec::SmallVec; -use std::{borrow::Cow, fmt}; +use std::{borrow::Cow, fmt, sync::mpsc}; /// Inner representation of the sparse trie. /// Sparse trie is blind by default until nodes are revealed. @@ -115,16 +116,11 @@ impl

SparseTrie

{ pub fn root(&mut self) -> Option { Some(self.as_revealed_mut()?.root()) } - - /// Calculates the hashes of the nodes below the provided level. - pub fn calculate_below_level(&mut self, level: usize) { - self.as_revealed_mut().unwrap().update_rlp_node_level(level); - } } impl

SparseTrie

where - P: BlindedProvider, + P: BlindedProvider + Send + Sync, SparseTrieError: From, { /// Update the leaf node. @@ -140,6 +136,11 @@ where revealed.remove_leaf(path)?; Ok(()) } + + /// Calculates the hashes of the nodes below the provided level. + pub fn calculate_below_level(&mut self, level: usize) { + self.as_revealed_mut().unwrap().update_rlp_node_level(level); + } } /// The representation of revealed sparse trie. @@ -164,8 +165,6 @@ pub struct RevealedSparseTrie

{ prefix_set: PrefixSetMut, /// Retained trie updates. updates: Option, - /// Reusable buffer for RLP encoding of nodes. - rlp_buf: Vec, } impl

fmt::Debug for RevealedSparseTrie

{ @@ -176,7 +175,6 @@ impl

fmt::Debug for RevealedSparseTrie

{ .field("values", &self.values) .field("prefix_set", &self.prefix_set) .field("updates", &self.updates) - .field("rlp_buf", &hex::encode(&self.rlp_buf)) .finish_non_exhaustive() } } @@ -190,7 +188,6 @@ impl Default for RevealedSparseTrie { values: HashMap::default(), prefix_set: PrefixSetMut::default(), updates: None, - rlp_buf: Vec::new(), } } } @@ -208,7 +205,6 @@ impl RevealedSparseTrie { branch_node_hash_masks: HashMap::default(), values: HashMap::default(), prefix_set: PrefixSetMut::default(), - rlp_buf: Vec::new(), updates: None, } .with_updates(retain_updates); @@ -231,7 +227,6 @@ impl

RevealedSparseTrie

{ branch_node_hash_masks: HashMap::default(), values: HashMap::default(), prefix_set: PrefixSetMut::default(), - rlp_buf: Vec::new(), updates: None, } .with_updates(retain_updates); @@ -248,7 +243,6 @@ impl

RevealedSparseTrie

{ values: self.values, prefix_set: self.prefix_set, updates: self.updates, - rlp_buf: self.rlp_buf, } } @@ -523,19 +517,6 @@ impl

RevealedSparseTrie

{ } } - /// Update hashes of the nodes that are located at a level deeper than or equal to the provided - /// depth. Root node has a level of 0. - pub fn update_rlp_node_level(&mut self, depth: usize) { - let mut prefix_set = self.prefix_set.clone().freeze(); - let mut buffers = RlpNodeBuffers::default(); - - let targets = self.get_changed_nodes_at_depth(&mut prefix_set, depth); - for target in targets { - buffers.path_stack.push((target, Some(true))); - self.rlp_node(&mut prefix_set, &mut buffers); - } - } - /// Returns a list of paths to the nodes that were changed according to the prefix set and are /// located at the provided depth when counting from the root node. If there's a leaf at a /// depth less than the provided depth, it will be included in the result. @@ -590,10 +571,19 @@ impl

RevealedSparseTrie

{ fn rlp_node_allocate(&mut self, path: Nibbles, prefix_set: &mut PrefixSet) -> RlpNode { let mut buffers = RlpNodeBuffers::new_with_path(path); - self.rlp_node(prefix_set, &mut buffers) + let (root, updates) = self.rlp_node(prefix_set, &mut buffers, &mut Vec::new()); + self.apply_rlp_node_updates(updates); + root } - fn rlp_node(&mut self, prefix_set: &mut PrefixSet, buffers: &mut RlpNodeBuffers) -> RlpNode { + fn rlp_node( + &self, + prefix_set: &mut PrefixSet, + buffers: &mut RlpNodeBuffers, + rlp_buf: &mut Vec, + ) -> (RlpNode, RlpNodeUpdates) { + let mut rlp_node_updates = RlpNodeUpdates::default(); + 'main: while let Some((path, mut is_in_prefix_set)) = buffers.path_stack.pop() { // Check if the path is in the prefix set. // First, check the cached value. If it's `None`, then check the prefix set, and update @@ -601,7 +591,9 @@ impl

RevealedSparseTrie

{ let mut prefix_set_contains = |path: &Nibbles| *is_in_prefix_set.get_or_insert_with(|| prefix_set.contains(path)); - let (rlp_node, calculated, node_type) = match self.nodes.get_mut(&path).unwrap() { + let mut rlp_node_update = RlpNodeUpdate::default(); + + let (rlp_node, calculated, node_type) = match self.nodes.get(&path).unwrap() { SparseNode::Empty => { (RlpNode::word_rlp(&EMPTY_ROOT_HASH), false, SparseNodeType::Empty) } @@ -613,9 +605,9 @@ impl

RevealedSparseTrie

{ (RlpNode::word_rlp(&hash), false, SparseNodeType::Leaf) } else { let value = self.values.get(&path).unwrap(); - self.rlp_buf.clear(); - let rlp_node = LeafNodeRef { key, value }.rlp(&mut self.rlp_buf); - *hash = rlp_node.as_hash(); + rlp_buf.clear(); + let rlp_node = LeafNodeRef { key, value }.rlp(rlp_buf); + rlp_node_update.hash = rlp_node.as_hash(); (rlp_node, true, SparseNodeType::Leaf) } } @@ -630,9 +622,9 @@ impl

RevealedSparseTrie

{ ) } else if buffers.rlp_node_stack.last().is_some_and(|e| e.0 == child_path) { let (_, child, _, node_type) = buffers.rlp_node_stack.pop().unwrap(); - self.rlp_buf.clear(); - let rlp_node = ExtensionNodeRef::new(key, &child).rlp(&mut self.rlp_buf); - *hash = rlp_node.as_hash(); + rlp_buf.clear(); + let rlp_node = ExtensionNodeRef::new(key, &child).rlp(rlp_buf); + rlp_node_update.hash = rlp_node.as_hash(); ( rlp_node, @@ -746,17 +738,14 @@ impl

RevealedSparseTrie

{ } } - self.rlp_buf.clear(); + rlp_buf.clear(); let branch_node_ref = BranchNodeRef::new(&buffers.branch_value_stack_buf, *state_mask); - let rlp_node = branch_node_ref.rlp(&mut self.rlp_buf); - *hash = rlp_node.as_hash(); + let rlp_node = branch_node_ref.rlp(rlp_buf); // Save a branch node update only if it's not a root node, and we need to // persist updates. - let store_in_db_trie_value = if let Some(updates) = - self.updates.as_mut().filter(|_| retain_updates && !path.is_empty()) - { + let store_in_db_trie_value = if retain_updates && !path.is_empty() { let mut tree_mask_values = tree_mask_values.into_iter().rev(); let mut hash_mask_values = hash_mask_values.into_iter().rev(); let mut tree_mask = TrieMask::default(); @@ -784,14 +773,16 @@ impl

RevealedSparseTrie

{ hashes, hash.filter(|_| path.len() == 0), ); - updates.updated_nodes.insert(path.clone(), branch_node); + rlp_node_update.branch_node = Some(branch_node); } store_in_db_trie } else { false }; - *store_in_db_trie = Some(store_in_db_trie_value); + + rlp_node_update.hash = rlp_node.as_hash(); + rlp_node_update.store_in_db_trie = Some(store_in_db_trie_value); ( rlp_node, @@ -800,17 +791,45 @@ impl

RevealedSparseTrie

{ ) } }; + + if !rlp_node_update.is_empty() { + rlp_node_updates.insert(path.clone(), rlp_node_update); + } + buffers.rlp_node_stack.push((path, rlp_node, calculated, node_type)); } debug_assert_eq!(buffers.rlp_node_stack.len(), 1); - buffers.rlp_node_stack.pop().unwrap().1 + (buffers.rlp_node_stack.pop().unwrap().1, rlp_node_updates) + } + + fn apply_rlp_node_updates(&mut self, rlp_node_updates: RlpNodeUpdates) { + for (path, update) in rlp_node_updates { + if let Some(node) = self.nodes.get_mut(&path) { + match node { + SparseNode::Leaf { hash, .. } | SparseNode::Extension { hash, .. } => { + *hash = update.hash + } + SparseNode::Branch { hash, store_in_db_trie, .. } => { + *hash = update.hash; + *store_in_db_trie = update.store_in_db_trie + } + SparseNode::Empty | SparseNode::Hash(_) => unreachable!(), + } + } + + if let Some(branch_node) = update.branch_node { + if let Some(updates) = self.updates.as_mut() { + updates.updated_nodes.insert(path, branch_node); + } + } + } } } impl

RevealedSparseTrie

where - P: BlindedProvider, + P: BlindedProvider + Send + Sync, SparseTrieError: From, { /// Update the leaf node with provided value. @@ -1110,6 +1129,53 @@ where Ok(()) } + + /// Update hashes of the nodes that are located at a level deeper than or equal to the provided + /// depth. Root node has a level of 0. + pub fn update_rlp_node_level(&mut self, depth: usize) { + let mut prefix_set = self.prefix_set.clone().freeze(); + + let targets = self.get_changed_nodes_at_depth(&mut prefix_set, depth); + let (tx, rx) = mpsc::channel(); + targets + .into_par_iter() + .map_init( + || (prefix_set.clone(), RlpNodeBuffers::default(), Vec::new()), + |(prefix_set, buffers, rlp_node), target| { + buffers.path_stack.push((target, Some(true))); + let (_, updates) = self.rlp_node(prefix_set, buffers, rlp_node); + updates + }, + ) + .for_each_init( + || tx.clone(), + |tx, updates| { + tx.send(updates).unwrap(); + }, + ); + drop(tx); + + for updates in rx { + self.apply_rlp_node_updates(updates); + } + } +} + +/// Updates that [`RevealedSparseTrie::rlp_node`] produced. +type RlpNodeUpdates = HashMap; + +/// An update that [`RevealedSparseTrie::rlp_node`] produced after processing one node. +#[derive(Debug, Default)] +struct RlpNodeUpdate { + hash: Option, + store_in_db_trie: Option, + branch_node: Option, +} + +impl RlpNodeUpdate { + const fn is_empty(&self) -> bool { + self.hash.is_none() && self.store_in_db_trie.is_none() && self.branch_node.is_none() + } } /// Enum representing sparse trie node type. diff --git a/crates/trie/trie/src/witness.rs b/crates/trie/trie/src/witness.rs index e29792146856..482e7f6ee3cc 100644 --- a/crates/trie/trie/src/witness.rs +++ b/crates/trie/trie/src/witness.rs @@ -211,9 +211,9 @@ impl WitnessBlindedProviderFactory { impl BlindedProviderFactory for WitnessBlindedProviderFactory where - F: BlindedProviderFactory, - F::AccountNodeProvider: BlindedProvider, - F::StorageNodeProvider: BlindedProvider, + F: BlindedProviderFactory + Send + Sync, + F::AccountNodeProvider: BlindedProvider + Send + Sync, + F::StorageNodeProvider: BlindedProvider + Send + Sync, { type AccountNodeProvider = WitnessBlindedProvider; type StorageNodeProvider = WitnessBlindedProvider;