diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index eea236fdaaec..86b6b9f60365 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -18,6 +18,7 @@ use reth_trie_sparse::{ use revm_primitives::{keccak256, EvmState, B256}; use std::{ collections::BTreeMap, + ops::Deref, sync::{ mpsc::{self, Receiver, Sender}, Arc, @@ -84,6 +85,8 @@ pub(crate) enum StateRootMessage { /// Time taken to calculate the root elapsed: Duration, }, + /// Signals state update stream end. + FinishedStateUpdates, } /// Handle to track proof calculation ordering @@ -152,6 +155,25 @@ impl ProofSequencer { } } +/// A wrapper for the sender that signals completion when dropped +#[allow(dead_code)] +pub(crate) struct StateHookSender(Sender); + +impl Deref for StateHookSender { + type Target = Sender; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Drop for StateHookSender { + fn drop(&mut self) { + // Send completion signal when the sender is dropped + let _ = self.0.send(StateRootMessage::FinishedStateUpdates); + } +} + /// Standalone task that receives a transaction state stream and updates relevant /// data structures to calculate state root. /// @@ -354,6 +376,7 @@ where let mut updates_received = 0; let mut proofs_processed = 0; let mut roots_calculated = 0; + let mut updates_finished = false; loop { match self.rx.recv() { @@ -375,6 +398,9 @@ where self.tx.clone(), ); } + StateRootMessage::FinishedStateUpdates => { + updates_finished = true; + } StateRootMessage::ProofCalculated { proof, state_update, sequence_number } => { proofs_processed += 1; trace!( @@ -434,7 +460,7 @@ where std::mem::take(&mut current_state_update), std::mem::take(&mut current_multiproof), ); - } else if all_proofs_received && no_pending { + } else if all_proofs_received && no_pending && updates_finished { debug!( target: "engine::root", total_updates = updates_received, @@ -710,10 +736,13 @@ mod tests { let task = StateRootTask::new(config, tx.clone(), rx); let handle = task.spawn(); + let state_hook_sender = StateHookSender(tx); for update in state_updates { - tx.send(StateRootMessage::StateUpdate(update)).expect("failed to send state"); + state_hook_sender + .send(StateRootMessage::StateUpdate(update)) + .expect("failed to send state"); } - drop(tx); + drop(state_hook_sender); let (root_from_task, _) = handle.wait_for_result().expect("task failed"); let root_from_base = state_root(accumulated_state);