diff --git a/src/tracing/builder/geth.rs b/src/tracing/builder/geth.rs index 96ad2fb..9eac9b7 100644 --- a/src/tracing/builder/geth.rs +++ b/src/tracing/builder/geth.rs @@ -10,7 +10,10 @@ use alloy_rpc_types_trace::geth::{ AccountChangeKind, AccountState, CallConfig, CallFrame, DefaultFrame, DiffMode, GethDefaultTracingOptions, PreStateConfig, PreStateFrame, PreStateMode, StructLog, }; -use revm::{db::DatabaseRef, primitives::ResultAndState}; +use revm::{ + db::DatabaseRef, + primitives::{EvmState, ResultAndState}, +}; use std::{ borrow::Cow, collections::{BTreeMap, HashMap, VecDeque}, @@ -225,75 +228,101 @@ impl<'a> GethTraceBuilder<'a> { ResultAndState { state, .. }: &ResultAndState, prestate_config: &PreStateConfig, db: DB, + ) -> Result { + let code_enabled = prestate_config.code_enabled(); + let storage_enabled = prestate_config.storage_enabled(); + if prestate_config.is_diff_mode() { + self.geth_prestate_diff_traces(state, db, code_enabled, storage_enabled) + } else { + self.geth_prestate_pre_traces(state, db, code_enabled, storage_enabled) + } + } + + fn geth_prestate_pre_traces( + &self, + state: &EvmState, + db: DB, + code_enabled: bool, + storage_enabled: bool, ) -> Result { let account_diffs = state.iter().map(|(addr, acc)| (*addr, acc)); + let mut prestate = PreStateMode::default(); - if prestate_config.is_default_mode() { - let mut prestate = PreStateMode::default(); - // we only want changed accounts for things like balance changes etc - for (addr, changed_acc) in account_diffs { - let db_acc = db.basic_ref(addr)?.unwrap_or_default(); - let code = load_account_code(&db, &db_acc); - let mut acc_state = - AccountState::from_account_info(db_acc.nonce, db_acc.balance, code); + // we only want changed accounts for things like balance changes etc + for (addr, changed_acc) in account_diffs { + let db_acc = db.basic_ref(addr)?.unwrap_or_default(); + let code = code_enabled.then(|| load_account_code(&db, &db_acc)).flatten(); + let mut acc_state = AccountState::from_account_info(db_acc.nonce, db_acc.balance, code); - // insert the original value of all modified storage slots + // insert the original value of all modified storage slots + if storage_enabled { for (key, slot) in changed_acc.storage.iter() { acc_state.storage.insert((*key).into(), slot.original_value.into()); } - - prestate.0.insert(addr, acc_state); } - Ok(PreStateFrame::Default(prestate)) - } else { - let mut state_diff = DiffMode::default(); - let mut account_change_kinds = HashMap::with_capacity(account_diffs.len()); - for (addr, changed_acc) in account_diffs { - let db_acc = db.basic_ref(addr)?.unwrap_or_default(); + prestate.0.insert(addr, acc_state); + } + + Ok(PreStateFrame::Default(prestate)) + } + + fn geth_prestate_diff_traces( + &self, + state: &EvmState, + db: DB, + code_enabled: bool, + storage_enabled: bool, + ) -> Result { + let account_diffs = state.iter().map(|(addr, acc)| (*addr, acc)); + let mut state_diff = DiffMode::default(); + let mut account_change_kinds = HashMap::with_capacity(account_diffs.len()); + for (addr, changed_acc) in account_diffs { + let db_acc = db.basic_ref(addr)?.unwrap_or_default(); - let pre_code = load_account_code(&db, &db_acc); + let pre_code = code_enabled.then(|| load_account_code(&db, &db_acc)).flatten(); - let mut pre_state = - AccountState::from_account_info(db_acc.nonce, db_acc.balance, pre_code); + let mut pre_state = + AccountState::from_account_info(db_acc.nonce, db_acc.balance, pre_code); - let mut post_state = AccountState::from_account_info( - changed_acc.info.nonce, - changed_acc.info.balance, - changed_acc.info.code.as_ref().map(|code| code.original_bytes()), - ); + let mut post_state = AccountState::from_account_info( + changed_acc.info.nonce, + changed_acc.info.balance, + changed_acc.info.code.as_ref().map(|code| code.original_bytes()), + ); - // handle storage changes + // handle storage changes + if storage_enabled { for (key, slot) in changed_acc.storage.iter().filter(|(_, slot)| slot.is_changed()) { pre_state.storage.insert((*key).into(), slot.original_value.into()); post_state.storage.insert((*key).into(), slot.present_value.into()); } - - state_diff.pre.insert(addr, pre_state); - state_diff.post.insert(addr, post_state); - - // determine the change type - let pre_change = if changed_acc.is_created() { - AccountChangeKind::Create - } else { - AccountChangeKind::Modify - }; - let post_change = if changed_acc.is_selfdestructed() { - AccountChangeKind::SelfDestruct - } else { - AccountChangeKind::Modify - }; - - account_change_kinds.insert(addr, (pre_change, post_change)); } - // ensure we're only keeping changed entries - state_diff.retain_changed().remove_zero_storage_values(); + state_diff.pre.insert(addr, pre_state); + state_diff.post.insert(addr, post_state); + + // determine the change type + let pre_change = if changed_acc.is_created() { + AccountChangeKind::Create + } else { + AccountChangeKind::Modify + }; + let post_change = if changed_acc.is_selfdestructed() { + AccountChangeKind::SelfDestruct + } else { + AccountChangeKind::Modify + }; - self.diff_traces(&mut state_diff.pre, &mut state_diff.post, account_change_kinds); - Ok(PreStateFrame::Diff(state_diff)) + account_change_kinds.insert(addr, (pre_change, post_change)); } + + // ensure we're only keeping changed entries + state_diff.retain_changed().remove_zero_storage_values(); + + self.diff_traces(&mut state_diff.pre, &mut state_diff.post, account_change_kinds); + Ok(PreStateFrame::Diff(state_diff)) } /// Returns the difference between the pre and post state of the transaction depending on the diff --git a/tests/it/geth.rs b/tests/it/geth.rs index 959ebbd..03db41f 100644 --- a/tests/it/geth.rs +++ b/tests/it/geth.rs @@ -137,8 +137,7 @@ fn test_geth_mux_tracer() { let (addr, mut evm) = deploy_contract(code.into(), deployer, SpecId::LONDON); let call_config = CallConfig { only_top_call: Some(false), with_log: Some(true) }; - let prestate_config = - PreStateConfig { diff_mode: Some(false), disable_code: None, disable_storage: None }; + let prestate_config = PreStateConfig { diff_mode: Some(false), ..Default::default() }; let nested_call_config = CallConfig { only_top_call: Some(true), with_log: Some(false) }; let nested_mux_config = MuxConfig(HashMap::from_iter([(