diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/declare.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/declare.rs index 8a90d35252..ca00f0d92d 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/declare.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/declare.rs @@ -18,7 +18,7 @@ pub fn declare( contract_name: &str, contracts_data: &ContractsData, ) -> Result { - let contract_artifact = contracts_data.contracts.get(contract_name).with_context(|| { + let contract_artifact = contracts_data.get_artifacts_for_contract(contract_name).with_context(|| { format!("Failed to get contract artifact for name = {contract_name}. Make sure starknet target is correctly defined in Scarb.toml file.") }).map_err::(From::from)?; @@ -27,8 +27,7 @@ pub fn declare( let contract_class = BlockifierContractClass::V1(contract_class); let class_hash = *contracts_data - .class_hashes - .get_by_left(contract_name) + .get_class_hash_for_contract(contract_name) .expect("Failed to get class hash"); match state.get_compiled_contract_class(class_hash) { diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/contracts_data.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/contracts_data.rs index 86c0c3a555..28aa61362f 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/contracts_data.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/contracts_data.rs @@ -12,26 +12,49 @@ use std::collections::HashMap; #[derive(Debug, Clone)] pub struct ContractsData { - pub contracts: HashMap, - pub source_sierra_paths: HashMap, - pub class_hashes: BiMap, - pub selectors: HashMap, + contracts: HashMap, + class_hash_index: BiMap, + selectors: HashMap, +} + +#[derive(Debug, Clone)] +struct Contract { + artifacts: StarknetContractArtifacts, + class_hash: ClassHash, + _sierra_source_path: Utf8PathBuf, } impl ContractsData { pub fn try_from( - contracts_artifacts: HashMap, - contracts_sierra_paths: HashMap, + contracts: HashMap, ) -> Result { - let parsed_contracts: HashMap = contracts_artifacts + let parsed_contracts: HashMap = contracts .par_iter() - .map(|(name, artifact)| Ok((name.clone(), serde_json::from_str(&artifact.sierra)?))) + .map(|(name, (artifact, _))| { + Ok((name.clone(), serde_json::from_str(&artifact.sierra)?)) + }) .collect::>()?; let class_hashes: Vec<(String, ClassHash)> = parsed_contracts .par_iter() .map(|(name, sierra_class)| Ok((name.clone(), get_class_hash(sierra_class)?))) .collect::>()?; + let class_hash_index = BiMap::from_iter(class_hashes); + + let contracts = contracts + .into_iter() + .map(|(name, (artifacts, sierra_source_path))| { + let class_hash = *class_hash_index.get_by_left(&name).unwrap(); + ( + name, + Contract { + artifacts, + class_hash, + _sierra_source_path: sierra_source_path, + }, + ) + }) + .collect(); let selectors = parsed_contracts .into_par_iter() @@ -40,12 +63,36 @@ impl ContractsData { .collect(); Ok(ContractsData { - contracts: contracts_artifacts, - source_sierra_paths: contracts_sierra_paths, - class_hashes: BiMap::from_iter(class_hashes), + contracts, + class_hash_index, selectors, }) } + + #[must_use] + pub fn get_artifacts_for_contract(&self, name: &str) -> Option<&StarknetContractArtifacts> { + self.contracts.get(name).map(|contract| &contract.artifacts) + } + + #[must_use] + pub fn get_class_hash_for_contract(&self, name: &str) -> Option<&ClassHash> { + self.contracts + .get(name) + .map(|contract| &contract.class_hash) + } + + #[must_use] + pub fn get_contract_name_from_class_hash(&self, class_hash: &ClassHash) -> Option<&String> { + self.class_hash_index.get_by_right(class_hash) + } + + #[must_use] + pub fn get_function_name_from_entry_point_selector( + &self, + entry_point_selector: &EntryPointSelector, + ) -> Option<&String> { + self.selectors.get(entry_point_selector) + } } fn build_name_selector_map(abi: Vec) -> HashMap { diff --git a/crates/cheatnet/tests/cheatcodes/declare.rs b/crates/cheatnet/tests/cheatcodes/declare.rs index 95eb5e742a..1e8bee25dd 100644 --- a/crates/cheatnet/tests/cheatcodes/declare.rs +++ b/crates/cheatnet/tests/cheatcodes/declare.rs @@ -1,21 +1,7 @@ use crate::common::{get_contracts, state::create_cached_state}; -use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::declare::{ - declare, get_class_hash, -}; +use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::declare::declare; use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::CheatcodeError; use runtime::EnhancedHintError; -use scarb_api::StarknetContractArtifacts; -use starknet_api::core::ClassHash; -use std::collections::HashMap; - -fn get_contract_class_hash( - contract_name: &str, - contracts: &HashMap, -) -> ClassHash { - let contract = contracts.get(contract_name).unwrap(); - let sierra_class = serde_json::from_str(&contract.sierra).unwrap(); - get_class_hash(&sierra_class).unwrap() -} #[test] fn declare_simple() { @@ -26,9 +12,11 @@ fn declare_simple() { let contracts_data = get_contracts(); let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap(); - let expected_class_hash = get_contract_class_hash(contract_name, &contracts_data.contracts); + let expected_class_hash = contracts_data + .get_class_hash_for_contract(contract_name) + .unwrap(); - assert_eq!(class_hash, expected_class_hash); + assert_eq!(class_hash, *expected_class_hash); } #[test] @@ -41,8 +29,10 @@ fn declare_multiple() { for contract_name in contract_names { let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap(); - let expected_class_hash = get_contract_class_hash(contract_name, &contracts_data.contracts); - assert_eq!(class_hash, expected_class_hash); + let expected_class_hash = contracts_data + .get_class_hash_for_contract(contract_name) + .unwrap(); + assert_eq!(class_hash, *expected_class_hash); } } @@ -55,8 +45,10 @@ fn declare_same_contract() { let contracts_data = get_contracts(); let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap(); - let expected_class_hash = get_contract_class_hash(contract_name, &contracts_data.contracts); - assert_eq!(class_hash, expected_class_hash); + let expected_class_hash = contracts_data + .get_class_hash_for_contract(contract_name) + .unwrap(); + assert_eq!(class_hash, *expected_class_hash); let output = declare(&mut cached_state, contract_name, &contracts_data); diff --git a/crates/cheatnet/tests/common/mod.rs b/crates/cheatnet/tests/common/mod.rs index 7b1dffd666..1b1d8d7588 100644 --- a/crates/cheatnet/tests/common/mod.rs +++ b/crates/cheatnet/tests/common/mod.rs @@ -24,7 +24,7 @@ use cheatnet::state::CheatnetState; use conversions::IntoConv; use runtime::starknet::context::build_context; use scarb_api::metadata::MetadataCommandExt; -use scarb_api::{get_contracts_artifacts_and_sierra_paths, ScarbCommand}; +use scarb_api::{get_contracts_artifacts_and_source_sierra_paths, ScarbCommand}; use starknet::core::utils::get_selector_from_name; use starknet_api::core::PatriciaKey; use starknet_api::core::{ClassHash, ContractAddress}; @@ -76,9 +76,10 @@ pub fn get_contracts() -> ContractsData { let package = scarb_metadata.packages.first().unwrap(); - let (contracts_artifacts, contracts_sierra_paths) = - get_contracts_artifacts_and_sierra_paths(&scarb_metadata, &package.id, None).unwrap(); - ContractsData::try_from(contracts_artifacts, contracts_sierra_paths).unwrap() + let contracts = + get_contracts_artifacts_and_source_sierra_paths(&scarb_metadata, &package.id, None) + .unwrap(); + ContractsData::try_from(contracts).unwrap() } pub fn deploy_contract( diff --git a/crates/forge-runner/src/build_trace_data.rs b/crates/forge-runner/src/build_trace_data.rs index 7456ec41db..60fceca56a 100644 --- a/crates/forge-runner/src/build_trace_data.rs +++ b/crates/forge-runner/src/build_trace_data.rs @@ -108,9 +108,11 @@ pub fn build_profiler_call_entry_point( } = value; let mut contract_name = class_hash - .and_then(|c| contracts_data.class_hashes.get_by_right(&c)) + .and_then(|c| contracts_data.get_contract_name_from_class_hash(&c)) + .cloned(); + let mut function_name = contracts_data + .get_function_name_from_entry_point_selector(&entry_point_selector) .cloned(); - let mut function_name = contracts_data.selectors.get(&entry_point_selector).cloned(); if entry_point_selector.0 == get_selector_from_name(TEST_ENTRY_POINT_SELECTOR) diff --git a/crates/forge/src/main.rs b/crates/forge/src/main.rs index e7fb4d1726..3b43ceb467 100644 --- a/crates/forge/src/main.rs +++ b/crates/forge/src/main.rs @@ -4,8 +4,8 @@ use clap::{Parser, Subcommand, ValueEnum}; use configuration::load_package_config; use forge::scarb::config::ForgeConfig; use forge::scarb::{ - build_contracts_with_scarb, build_test_artifacts_with_scarb, load_test_artifacts, - test_artifacts_path, + build_contracts_with_scarb, build_test_artifacts_with_scarb, get_test_artifacts_path, + load_test_artifacts, }; use forge::shared_cache::{clean_cache, set_cached_failed_tests_names}; use forge::test_filter::TestsFilter; @@ -15,7 +15,7 @@ use forge_runner::test_crate_summary::TestCrateSummary; use forge_runner::{RunnerConfig, RunnerParams, CACHE_DIR}; use rand::{thread_rng, RngCore}; use scarb_api::{ - get_contracts_artifacts_and_sierra_paths, + get_contracts_artifacts_and_source_sierra_paths, metadata::{Metadata, MetadataCommandExt, PackageMetadata}, package_matches_version_requirement, target_dir_for_workspace, ScarbCommand, }; @@ -240,13 +240,15 @@ fn test_workspace(args: TestArgs) -> Result { env::set_current_dir(&package.root)?; let test_artifacts_path = - test_artifacts_path(&snforge_target_dir_path, &package.name); + get_test_artifacts_path(&snforge_target_dir_path, &package.name); let compiled_test_crates = load_test_artifacts(&test_artifacts_path)?; - let (contracts_artifacts, contracts_sierra_paths) = - get_contracts_artifacts_and_sierra_paths(&scarb_metadata, &package.id, None)?; - let contracts_data = - ContractsData::try_from(contracts_artifacts, contracts_sierra_paths)?; + let contracts = get_contracts_artifacts_and_source_sierra_paths( + &scarb_metadata, + &package.id, + None, + )?; + let contracts_data = ContractsData::try_from(contracts)?; let forge_config = load_package_config::(&scarb_metadata, &package.id)?; diff --git a/crates/forge/src/scarb.rs b/crates/forge/src/scarb.rs index f31119665e..4dc9abd474 100644 --- a/crates/forge/src/scarb.rs +++ b/crates/forge/src/scarb.rs @@ -44,7 +44,10 @@ pub fn build_test_artifacts_with_scarb(filter: PackagesFilter) -> Result<()> { } #[must_use] -pub fn test_artifacts_path(snforge_target_dir_path: &Utf8Path, package_name: &str) -> Utf8PathBuf { +pub fn get_test_artifacts_path( + snforge_target_dir_path: &Utf8Path, + package_name: &str, +) -> Utf8PathBuf { snforge_target_dir_path.join(format!("{package_name}.snforge_sierra.json")) } diff --git a/crates/forge/test_utils/src/runner.rs b/crates/forge/test_utils/src/runner.rs index 07bde95b99..6124c1ce29 100644 --- a/crates/forge/test_utils/src/runner.rs +++ b/crates/forge/test_utils/src/runner.rs @@ -12,7 +12,7 @@ use forge_runner::{ }; use indoc::formatdoc; use scarb_api::{ - get_contracts_artifacts_and_sierra_paths, metadata::MetadataCommandExt, ScarbCommand, + get_contracts_artifacts_and_source_sierra_paths, metadata::MetadataCommandExt, ScarbCommand, StarknetContractArtifacts, }; use shared::command::CommandExt; @@ -96,11 +96,12 @@ impl Contract { .find(|package| package.name == "contract") .unwrap(); - let contract = get_contracts_artifacts_and_sierra_paths(&scarb_metadata, &package.id, None) - .unwrap() - .0 - .remove(&self.name) - .ok_or(anyhow!("there is no contract with name {}", self.name))?; + let contract = + get_contracts_artifacts_and_source_sierra_paths(&scarb_metadata, &package.id, None) + .unwrap() + .remove(&self.name) + .ok_or(anyhow!("there is no contract with name {}", self.name))? + .0; Ok((contract.sierra, contract.casm)) } @@ -191,7 +192,7 @@ impl<'a> TestCase { ] } - pub fn contracts(&self) -> Result> { + pub fn contracts(&self) -> Result> { self.contracts .clone() .into_iter() @@ -199,7 +200,13 @@ impl<'a> TestCase { let name = contract.name.clone(); let (sierra, casm) = contract.generate_sierra_and_casm()?; - Ok((name, StarknetContractArtifacts { sierra, casm })) + Ok(( + name, + ( + StarknetContractArtifacts { sierra, casm }, + Default::default(), + ), + )) }) .collect() } diff --git a/crates/forge/test_utils/src/running_tests.rs b/crates/forge/test_utils/src/running_tests.rs index fad0b56b4b..44e75c5512 100644 --- a/crates/forge/test_utils/src/running_tests.rs +++ b/crates/forge/test_utils/src/running_tests.rs @@ -3,12 +3,11 @@ use camino::Utf8PathBuf; use cheatnet::runtime_extensions::forge_runtime_extension::contracts_data::ContractsData; use forge::block_number_map::BlockNumberMap; use forge::run; -use forge::scarb::{load_test_artifacts, test_artifacts_path}; +use forge::scarb::{get_test_artifacts_path, load_test_artifacts}; use forge::test_filter::TestsFilter; use forge_runner::test_crate_summary::TestCrateSummary; use forge_runner::{RunnerConfig, RunnerParams}; use shared::command::CommandExt; -use std::collections::HashMap; use std::num::NonZeroU32; use std::path::PathBuf; use std::process::Command; @@ -28,7 +27,7 @@ pub fn run_test_case(test: &TestCase) -> Vec { .unwrap(); let rt = Runtime::new().expect("Could not instantiate Runtime"); - let test_artifacts_path = test_artifacts_path( + let test_artifacts_path = get_test_artifacts_path( &test.path().unwrap().join("target/dev/snforge"), "test_package", ); @@ -49,7 +48,7 @@ pub fn run_test_case(test: &TestCase) -> Vec { None, )), Arc::new(RunnerParams::new( - ContractsData::try_from(test.contracts().unwrap(), HashMap::new()).unwrap(), + ContractsData::try_from(test.contracts().unwrap()).unwrap(), test.env().clone(), )), &[], diff --git a/crates/forge/tests/integration/setup_fork.rs b/crates/forge/tests/integration/setup_fork.rs index d220e513c4..ed2b0a6c52 100644 --- a/crates/forge/tests/integration/setup_fork.rs +++ b/crates/forge/tests/integration/setup_fork.rs @@ -1,5 +1,4 @@ use indoc::formatdoc; -use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; use std::process::Command; @@ -16,7 +15,7 @@ use tokio::runtime::Runtime; use cheatnet::runtime_extensions::forge_runtime_extension::contracts_data::ContractsData; use forge::compiled_raw::RawForkParams; -use forge::scarb::{load_test_artifacts, test_artifacts_path}; +use forge::scarb::{get_test_artifacts_path, load_test_artifacts}; use forge_runner::{RunnerConfig, RunnerParams}; use shared::command::CommandExt; use test_utils::runner::{assert_case_output_contains, assert_failed, assert_passed, Contract}; @@ -114,7 +113,7 @@ fn fork_aliased_decorator() { .output_checked() .unwrap(); - let test_artifacts_path = test_artifacts_path( + let test_artifacts_path = get_test_artifacts_path( &test.path().unwrap().join("target/dev/snforge"), "test_package", ); @@ -136,7 +135,7 @@ fn fork_aliased_decorator() { None, )), Arc::new(RunnerParams::new( - ContractsData::try_from(test.contracts().unwrap(), HashMap::new()).unwrap(), + ContractsData::try_from(test.contracts().unwrap()).unwrap(), test.env().clone(), )), &[ForkTarget::new( diff --git a/crates/scarb-api/src/lib.rs b/crates/scarb-api/src/lib.rs index 80589c9094..6c5ebdfb47 100644 --- a/crates/scarb-api/src/lib.rs +++ b/crates/scarb-api/src/lib.rs @@ -113,14 +113,11 @@ fn try_get_starknet_artifacts_path( } /// Get the map with `StarknetContractArtifacts` for the given package -pub fn get_contracts_artifacts_and_sierra_paths( +pub fn get_contracts_artifacts_and_source_sierra_paths( metadata: &Metadata, package: &PackageId, profile: Option<&str>, -) -> Result<( - HashMap, - HashMap, -)> { +) -> Result> { let target_name = target_name_for_package(metadata, package)?; let target_dir = target_dir_for_workspace(metadata); let maybe_contracts_path = try_get_starknet_artifacts_path( @@ -129,37 +126,33 @@ pub fn get_contracts_artifacts_and_sierra_paths( profile.unwrap_or(metadata.current_profile.as_str()), )?; - let maps = match maybe_contracts_path { - Some(contracts_path) => load_contracts_artifacts_and_sierra_paths(&contracts_path)?, - None => (HashMap::default(), HashMap::default()), + let map = match maybe_contracts_path { + Some(contracts_path) => load_contracts_artifacts_and_source_sierra_paths(&contracts_path)?, + None => HashMap::default(), }; - Ok(maps) + Ok(map) } -fn load_contracts_artifacts_and_sierra_paths( +fn load_contracts_artifacts_and_source_sierra_paths( contracts_path: &Utf8PathBuf, -) -> Result<( - HashMap, - HashMap, -)> { +) -> Result> { let base_path = contracts_path .parent() .ok_or_else(|| anyhow!("Failed to get parent for path = {}", &contracts_path))?; let artifacts = artifacts_for_package(contracts_path)?; - let mut artifacts_map = HashMap::new(); - let mut sierra_paths_map = HashMap::new(); + let mut map = HashMap::new(); for ref contract in artifacts.contracts { let name = contract.contract_name.clone(); let contract_artifacts = StarknetContractArtifacts::from_scarb_contract_artifact(contract, base_path)?; - artifacts_map.insert(name.clone(), contract_artifacts); let sierra_path = base_path.join(contract.artifacts.sierra.clone()); - sierra_paths_map.insert(name, sierra_path); + + map.insert(name.clone(), (contract_artifacts, sierra_path)); } - Ok((artifacts_map, sierra_paths_map)) + Ok(map) } fn compilation_unit_for_package<'a>( @@ -500,9 +493,8 @@ mod tests { .unwrap(); let package = metadata.packages.first().unwrap(); - let contracts = get_contracts_artifacts_and_sierra_paths(&metadata, &package.id, None) - .unwrap() - .0; + let contracts = + get_contracts_artifacts_and_source_sierra_paths(&metadata, &package.id, None).unwrap(); assert!(contracts.contains_key("ERC20")); assert!(contracts.contains_key("HelloStarknet")); @@ -515,8 +507,8 @@ mod tests { ) .unwrap(); let contract = contracts.get("ERC20").unwrap(); - assert_eq!(&sierra_contents_erc20, &contract.sierra); - assert_eq!(&casm_contents_erc20, &contract.casm); + assert_eq!(&sierra_contents_erc20, &contract.0.sierra); + assert_eq!(&casm_contents_erc20, &contract.0.casm); let sierra_contents_erc20 = fs::read_to_string( temp.join("target/dev/basic_package_HelloStarknet.contract_class.json"), @@ -527,8 +519,8 @@ mod tests { ) .unwrap(); let contract = contracts.get("HelloStarknet").unwrap(); - assert_eq!(&sierra_contents_erc20, &contract.sierra); - assert_eq!(&casm_contents_erc20, &contract.casm); + assert_eq!(&sierra_contents_erc20, &contract.0.sierra); + assert_eq!(&casm_contents_erc20, &contract.0.casm); } #[test] diff --git a/crates/sncast/src/helpers/scarb_utils.rs b/crates/sncast/src/helpers/scarb_utils.rs index 9232d6f645..8f13fa8196 100644 --- a/crates/sncast/src/helpers/scarb_utils.rs +++ b/crates/sncast/src/helpers/scarb_utils.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Context, Result}; use camino::{Utf8Path, Utf8PathBuf}; use scarb_api::{ - get_contracts_artifacts_and_sierra_paths, + get_contracts_artifacts_and_source_sierra_paths, metadata::{Metadata, MetadataCommand, PackageMetadata}, ScarbCommand, ScarbCommandError, StarknetContractArtifacts, }; @@ -154,20 +154,25 @@ pub fn build_and_load_artifacts( let metadata = get_scarb_metadata_with_deps(&config.scarb_toml_path)?; if metadata.profiles.contains(&config.profile) { - Ok( - get_contracts_artifacts_and_sierra_paths( - &metadata, - &package.id, - Some(&config.profile), - )? - .0, - ) + Ok(get_contracts_artifacts_and_source_sierra_paths( + &metadata, + &package.id, + Some(&config.profile), + )? + .into_iter() + .map(|(name, (artifacts, _))| (name, artifacts)) + .collect()) } else { let profile = &config.profile; print_as_warning(&anyhow!( "Profile {profile} does not exist in scarb, using default 'dev' profile." )); - Ok(get_contracts_artifacts_and_sierra_paths(&metadata, &package.id, None)?.0) + Ok( + get_contracts_artifacts_and_source_sierra_paths(&metadata, &package.id, None)? + .into_iter() + .map(|(name, (artifacts, _))| (name, artifacts)) + .collect(), + ) } }