Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
piotmag769 committed Apr 17, 2024
1 parent 1d0837b commit fe0d5a3
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub fn declare(
contract_name: &str,
contracts_data: &ContractsData,
) -> Result<ClassHash, CheatcodeError> {
let contract_artifact = contracts_data.contracts.get(contract_name).with_context(|| {
let contract_artifact = contracts_data.get_artifacts_for_contract(contract_name).with_context(|| {
format!("Failed to get contract artifact for name = {contract_name}. Make sure starknet target is correctly defined in Scarb.toml file.")
}).map_err::<EnhancedHintError, _>(From::from)?;

Expand All @@ -27,8 +27,7 @@ pub fn declare(
let contract_class = BlockifierContractClass::V1(contract_class);

let class_hash = *contracts_data
.class_hashes
.get_by_left(contract_name)
.get_class_hash_for_contract(contract_name)
.expect("Failed to get class hash");

match state.get_compiled_contract_class(class_hash) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,49 @@ use std::collections::HashMap;

#[derive(Debug, Clone)]
pub struct ContractsData {
pub contracts: HashMap<String, StarknetContractArtifacts>,
pub source_sierra_paths: HashMap<String, Utf8PathBuf>,
pub class_hashes: BiMap<String, ClassHash>,
pub selectors: HashMap<EntryPointSelector, String>,
contracts: HashMap<String, Contract>,
class_hash_index: BiMap<String, ClassHash>,
selectors: HashMap<EntryPointSelector, String>,
}

#[derive(Debug, Clone)]
struct Contract {
artifacts: StarknetContractArtifacts,
class_hash: ClassHash,
_sierra_source_path: Utf8PathBuf,
}

impl ContractsData {
pub fn try_from(
contracts_artifacts: HashMap<String, StarknetContractArtifacts>,
contracts_sierra_paths: HashMap<String, Utf8PathBuf>,
contracts: HashMap<String, (StarknetContractArtifacts, Utf8PathBuf)>,
) -> Result<Self> {
let parsed_contracts: HashMap<String, SierraClass> = contracts_artifacts
let parsed_contracts: HashMap<String, SierraClass> = contracts
.par_iter()
.map(|(name, artifact)| Ok((name.clone(), serde_json::from_str(&artifact.sierra)?)))
.map(|(name, (artifact, _))| {
Ok((name.clone(), serde_json::from_str(&artifact.sierra)?))
})
.collect::<Result<_>>()?;

let class_hashes: Vec<(String, ClassHash)> = parsed_contracts
.par_iter()
.map(|(name, sierra_class)| Ok((name.clone(), get_class_hash(sierra_class)?)))
.collect::<Result<_>>()?;
let class_hash_index = BiMap::from_iter(class_hashes);

let contracts = contracts
.into_iter()
.map(|(name, (artifacts, sierra_source_path))| {
let class_hash = *class_hash_index.get_by_left(&name).unwrap();
(
name,
Contract {
artifacts,
class_hash,
_sierra_source_path: sierra_source_path,
},
)
})
.collect();

let selectors = parsed_contracts
.into_par_iter()
Expand All @@ -40,12 +63,36 @@ impl ContractsData {
.collect();

Ok(ContractsData {
contracts: contracts_artifacts,
source_sierra_paths: contracts_sierra_paths,
class_hashes: BiMap::from_iter(class_hashes),
contracts,
class_hash_index,
selectors,
})
}

#[must_use]
pub fn get_artifacts_for_contract(&self, name: &str) -> Option<&StarknetContractArtifacts> {
self.contracts.get(name).map(|contract| &contract.artifacts)
}

#[must_use]
pub fn get_class_hash_for_contract(&self, name: &str) -> Option<&ClassHash> {
self.contracts
.get(name)
.map(|contract| &contract.class_hash)
}

#[must_use]
pub fn get_contract_name_from_class_hash(&self, class_hash: &ClassHash) -> Option<&String> {
self.class_hash_index.get_by_right(class_hash)
}

#[must_use]
pub fn get_function_name_from_entry_point_selector(
&self,
entry_point_selector: &EntryPointSelector,
) -> Option<&String> {
self.selectors.get(entry_point_selector)
}
}

fn build_name_selector_map(abi: Vec<AbiEntry>) -> HashMap<EntryPointSelector, String> {
Expand Down
34 changes: 13 additions & 21 deletions crates/cheatnet/tests/cheatcodes/declare.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
use crate::common::{get_contracts, state::create_cached_state};
use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::declare::{
declare, get_class_hash,
};
use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::declare::declare;
use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::CheatcodeError;
use runtime::EnhancedHintError;
use scarb_api::StarknetContractArtifacts;
use starknet_api::core::ClassHash;
use std::collections::HashMap;

fn get_contract_class_hash(
contract_name: &str,
contracts: &HashMap<String, StarknetContractArtifacts>,
) -> ClassHash {
let contract = contracts.get(contract_name).unwrap();
let sierra_class = serde_json::from_str(&contract.sierra).unwrap();
get_class_hash(&sierra_class).unwrap()
}

#[test]
fn declare_simple() {
Expand All @@ -26,9 +12,11 @@ fn declare_simple() {
let contracts_data = get_contracts();

let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap();
let expected_class_hash = get_contract_class_hash(contract_name, &contracts_data.contracts);
let expected_class_hash = contracts_data
.get_class_hash_for_contract(contract_name)
.unwrap();

assert_eq!(class_hash, expected_class_hash);
assert_eq!(class_hash, *expected_class_hash);
}

#[test]
Expand All @@ -41,8 +29,10 @@ fn declare_multiple() {

for contract_name in contract_names {
let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap();
let expected_class_hash = get_contract_class_hash(contract_name, &contracts_data.contracts);
assert_eq!(class_hash, expected_class_hash);
let expected_class_hash = contracts_data
.get_class_hash_for_contract(contract_name)
.unwrap();
assert_eq!(class_hash, *expected_class_hash);
}
}

Expand All @@ -55,8 +45,10 @@ fn declare_same_contract() {
let contracts_data = get_contracts();

let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap();
let expected_class_hash = get_contract_class_hash(contract_name, &contracts_data.contracts);
assert_eq!(class_hash, expected_class_hash);
let expected_class_hash = contracts_data
.get_class_hash_for_contract(contract_name)
.unwrap();
assert_eq!(class_hash, *expected_class_hash);

let output = declare(&mut cached_state, contract_name, &contracts_data);

Expand Down
9 changes: 5 additions & 4 deletions crates/cheatnet/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use cheatnet::state::CheatnetState;
use conversions::IntoConv;
use runtime::starknet::context::build_context;
use scarb_api::metadata::MetadataCommandExt;
use scarb_api::{get_contracts_artifacts_and_sierra_paths, ScarbCommand};
use scarb_api::{get_contracts_artifacts_and_source_sierra_paths, ScarbCommand};
use starknet::core::utils::get_selector_from_name;
use starknet_api::core::PatriciaKey;
use starknet_api::core::{ClassHash, ContractAddress};
Expand Down Expand Up @@ -76,9 +76,10 @@ pub fn get_contracts() -> ContractsData {

let package = scarb_metadata.packages.first().unwrap();

let (contracts_artifacts, contracts_sierra_paths) =
get_contracts_artifacts_and_sierra_paths(&scarb_metadata, &package.id, None).unwrap();
ContractsData::try_from(contracts_artifacts, contracts_sierra_paths).unwrap()
let contracts =
get_contracts_artifacts_and_source_sierra_paths(&scarb_metadata, &package.id, None)
.unwrap();
ContractsData::try_from(contracts).unwrap()
}

pub fn deploy_contract(
Expand Down
6 changes: 4 additions & 2 deletions crates/forge-runner/src/build_trace_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ pub fn build_profiler_call_entry_point(
} = value;

let mut contract_name = class_hash
.and_then(|c| contracts_data.class_hashes.get_by_right(&c))
.and_then(|c| contracts_data.get_contract_name_from_class_hash(&c))
.cloned();
let mut function_name = contracts_data
.get_function_name_from_entry_point_selector(&entry_point_selector)
.cloned();
let mut function_name = contracts_data.selectors.get(&entry_point_selector).cloned();

if entry_point_selector.0
== get_selector_from_name(TEST_ENTRY_POINT_SELECTOR)
Expand Down
18 changes: 10 additions & 8 deletions crates/forge/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use clap::{Parser, Subcommand, ValueEnum};
use configuration::load_package_config;
use forge::scarb::config::ForgeConfig;
use forge::scarb::{
build_contracts_with_scarb, build_test_artifacts_with_scarb, load_test_artifacts,
test_artifacts_path,
build_contracts_with_scarb, build_test_artifacts_with_scarb, get_test_artifacts_path,
load_test_artifacts,
};
use forge::shared_cache::{clean_cache, set_cached_failed_tests_names};
use forge::test_filter::TestsFilter;
Expand All @@ -15,7 +15,7 @@ use forge_runner::test_crate_summary::TestCrateSummary;
use forge_runner::{RunnerConfig, RunnerParams, CACHE_DIR};
use rand::{thread_rng, RngCore};
use scarb_api::{
get_contracts_artifacts_and_sierra_paths,
get_contracts_artifacts_and_source_sierra_paths,
metadata::{Metadata, MetadataCommandExt, PackageMetadata},
package_matches_version_requirement, target_dir_for_workspace, ScarbCommand,
};
Expand Down Expand Up @@ -269,13 +269,15 @@ fn test_workspace(args: TestArgs) -> Result<bool> {
env::set_current_dir(&package.root)?;

let test_artifacts_path =
test_artifacts_path(&snforge_target_dir_path, &package.name);
get_test_artifacts_path(&snforge_target_dir_path, &package.name);
let compiled_test_crates = load_test_artifacts(&test_artifacts_path)?;

let (contracts_artifacts, contracts_sierra_paths) =
get_contracts_artifacts_and_sierra_paths(&scarb_metadata, &package.id, None)?;
let contracts_data =
ContractsData::try_from(contracts_artifacts, contracts_sierra_paths)?;
let contracts = get_contracts_artifacts_and_source_sierra_paths(
&scarb_metadata,
&package.id,
None,
)?;
let contracts_data = ContractsData::try_from(contracts)?;

let forge_config =
load_package_config::<ForgeConfig>(&scarb_metadata, &package.id)?;
Expand Down
5 changes: 4 additions & 1 deletion crates/forge/src/scarb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ pub fn build_test_artifacts_with_scarb(filter: PackagesFilter) -> Result<()> {
}

#[must_use]
pub fn test_artifacts_path(snforge_target_dir_path: &Utf8Path, package_name: &str) -> Utf8PathBuf {
pub fn get_test_artifacts_path(
snforge_target_dir_path: &Utf8Path,
package_name: &str,
) -> Utf8PathBuf {
snforge_target_dir_path.join(format!("{package_name}.snforge_sierra.json"))
}

Expand Down
23 changes: 15 additions & 8 deletions crates/forge/test_utils/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use forge_runner::{
};
use indoc::formatdoc;
use scarb_api::{
get_contracts_artifacts_and_sierra_paths, metadata::MetadataCommandExt, ScarbCommand,
get_contracts_artifacts_and_source_sierra_paths, metadata::MetadataCommandExt, ScarbCommand,
StarknetContractArtifacts,
};
use shared::command::CommandExt;
Expand Down Expand Up @@ -96,11 +96,12 @@ impl Contract {
.find(|package| package.name == "contract")
.unwrap();

let contract = get_contracts_artifacts_and_sierra_paths(&scarb_metadata, &package.id, None)
.unwrap()
.0
.remove(&self.name)
.ok_or(anyhow!("there is no contract with name {}", self.name))?;
let contract =
get_contracts_artifacts_and_source_sierra_paths(&scarb_metadata, &package.id, None)
.unwrap()
.remove(&self.name)
.ok_or(anyhow!("there is no contract with name {}", self.name))?
.0;

Ok((contract.sierra, contract.casm))
}
Expand Down Expand Up @@ -191,15 +192,21 @@ impl<'a> TestCase {
]
}

pub fn contracts(&self) -> Result<HashMap<String, StarknetContractArtifacts>> {
pub fn contracts(&self) -> Result<HashMap<String, (StarknetContractArtifacts, Utf8PathBuf)>> {
self.contracts
.clone()
.into_iter()
.map(|contract| {
let name = contract.name.clone();
let (sierra, casm) = contract.generate_sierra_and_casm()?;

Ok((name, StarknetContractArtifacts { sierra, casm }))
Ok((
name,
(
StarknetContractArtifacts { sierra, casm },
Default::default(),
),
))
})
.collect()
}
Expand Down
7 changes: 3 additions & 4 deletions crates/forge/test_utils/src/running_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use camino::Utf8PathBuf;
use cheatnet::runtime_extensions::forge_runtime_extension::contracts_data::ContractsData;
use forge::block_number_map::BlockNumberMap;
use forge::run;
use forge::scarb::{load_test_artifacts, test_artifacts_path};
use forge::scarb::{get_test_artifacts_path, load_test_artifacts};
use forge::test_filter::TestsFilter;
use forge_runner::test_crate_summary::TestCrateSummary;
use forge_runner::{RunnerConfig, RunnerParams};
use shared::command::CommandExt;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::process::Command;
Expand All @@ -28,7 +27,7 @@ pub fn run_test_case(test: &TestCase) -> Vec<TestCrateSummary> {
.unwrap();

let rt = Runtime::new().expect("Could not instantiate Runtime");
let test_artifacts_path = test_artifacts_path(
let test_artifacts_path = get_test_artifacts_path(
&test.path().unwrap().join("target/dev/snforge"),
"test_package",
);
Expand All @@ -49,7 +48,7 @@ pub fn run_test_case(test: &TestCase) -> Vec<TestCrateSummary> {
None,
)),
Arc::new(RunnerParams::new(
ContractsData::try_from(test.contracts().unwrap(), HashMap::new()).unwrap(),
ContractsData::try_from(test.contracts().unwrap()).unwrap(),
test.env().clone(),
)),
&[],
Expand Down
7 changes: 3 additions & 4 deletions crates/forge/tests/integration/setup_fork.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use indoc::formatdoc;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
Expand All @@ -16,7 +15,7 @@ use tokio::runtime::Runtime;

use cheatnet::runtime_extensions::forge_runtime_extension::contracts_data::ContractsData;
use forge::compiled_raw::RawForkParams;
use forge::scarb::{load_test_artifacts, test_artifacts_path};
use forge::scarb::{get_test_artifacts_path, load_test_artifacts};
use forge_runner::{RunnerConfig, RunnerParams};
use shared::command::CommandExt;
use test_utils::runner::{assert_case_output_contains, assert_failed, assert_passed, Contract};
Expand Down Expand Up @@ -114,7 +113,7 @@ fn fork_aliased_decorator() {
.output_checked()
.unwrap();

let test_artifacts_path = test_artifacts_path(
let test_artifacts_path = get_test_artifacts_path(
&test.path().unwrap().join("target/dev/snforge"),
"test_package",
);
Expand All @@ -136,7 +135,7 @@ fn fork_aliased_decorator() {
None,
)),
Arc::new(RunnerParams::new(
ContractsData::try_from(test.contracts().unwrap(), HashMap::new()).unwrap(),
ContractsData::try_from(test.contracts().unwrap()).unwrap(),
test.env().clone(),
)),
&[ForkTarget::new(
Expand Down
Loading

0 comments on commit fe0d5a3

Please sign in to comment.