Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RunnerParams and RunnerConfig, add test_artifacts_path to ContextData #2032

Merged
merged 6 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/cheatnet/src/forking/cache.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{Context, Result};
use blockifier::blockifier::block::BlockInfo;
use camino::Utf8PathBuf;
use camino::{Utf8Path, Utf8PathBuf};
use fs2::FileExt;
use regex::Regex;
use runtime::starknet::context::SerializableBlockInfo;
Expand Down Expand Up @@ -97,7 +97,7 @@ impl ForkCache {
pub(crate) fn load_or_new(
url: &Url,
block_number: BlockNumber,
cache_dir: &str,
cache_dir: &Utf8Path,
) -> Result<Self> {
let cache_file = cache_file_path_from_fork_config(url, block_number, cache_dir)?;
let mut file = OpenOptions::new()
Expand Down Expand Up @@ -239,14 +239,14 @@ impl ForkCache {
fn cache_file_path_from_fork_config(
url: &Url,
BlockNumber(block_number): BlockNumber,
cache_dir: &str,
cache_dir: &Utf8Path,
) -> Result<Utf8PathBuf> {
let re = Regex::new(r"[^a-zA-Z0-9]").unwrap();

// replace non-alphanumeric characters with underscores
let sanitized_path = re.replace_all(url.as_str(), "_");

let cache_file_path = Utf8PathBuf::from(cache_dir).join(format!(
let cache_file_path = cache_dir.join(format!(
"{sanitized_path}_{block_number}_v{CACHE_VERSION}.json"
));

Expand Down
3 changes: 2 additions & 1 deletion crates/cheatnet/src/forking/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use blockifier::state::errors::StateError::{self, StateReadError, UndeclaredClas
use blockifier::state::state_api::{StateReader, StateResult};
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
use cairo_lang_utils::bigint::BigUintAsHex;
use camino::Utf8Path;
use conversions::{FromConv, IntoConv};
use flate2::read::GzDecoder;
use num_bigint::BigUint;
Expand Down Expand Up @@ -42,7 +43,7 @@ pub struct ForkStateReader {
}

impl ForkStateReader {
pub fn new(url: Url, block_number: BlockNumber, cache_dir: &str) -> Result<Self> {
pub fn new(url: Url, block_number: BlockNumber, cache_dir: &Utf8Path) -> Result<Self> {
Ok(ForkStateReader {
cache: RefCell::new(
ForkCache::load_or_new(&url, block_number, cache_dir)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub fn declare(
contract_name: &str,
contracts_data: &ContractsData,
) -> Result<ClassHash, CheatcodeError> {
let contract_artifact = contracts_data.get_artifacts_for_contract(contract_name).with_context(|| {
let contract_artifact = contracts_data.get_artifacts(contract_name).with_context(|| {
format!("Failed to get contract artifact for name = {contract_name}. Make sure starknet target is correctly defined in Scarb.toml file.")
}).map_err::<EnhancedHintError, _>(From::from)?;

Expand All @@ -27,7 +27,7 @@ pub fn declare(
let contract_class = BlockifierContractClass::V1(contract_class);

let class_hash = *contracts_data
.get_class_hash_for_contract(contract_name)
.get_class_hash(contract_name)
.expect("Failed to get class hash");

match state.get_compiled_contract_class(class_hash) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,50 @@ use starknet::core::utils::get_selector_from_name;
use starknet_api::core::{ClassHash, EntryPointSelector};
use std::collections::HashMap;

#[derive(Debug, Clone)]
type ContractName = String;
type FunctionName = String;

#[derive(Debug, Clone, PartialEq, Default)]
pub struct ContractsData {
contracts: HashMap<String, Contract>,
class_hash_index: BiMap<String, ClassHash>,
selectors: HashMap<EntryPointSelector, String>,
contracts: HashMap<ContractName, ContractData>,
class_hashes: BiMap<ContractName, ClassHash>,
selectors: HashMap<EntryPointSelector, FunctionName>,
}

#[derive(Debug, Clone)]
struct Contract {
#[derive(Debug, Clone, PartialEq)]
struct ContractData {
artifacts: StarknetContractArtifacts,
class_hash: ClassHash,
_sierra_source_path: Utf8PathBuf,
_source_sierra_path: Utf8PathBuf,
}

impl ContractsData {
pub fn try_from(
contracts: HashMap<String, (StarknetContractArtifacts, Utf8PathBuf)>,
contracts: HashMap<ContractName, (StarknetContractArtifacts, Utf8PathBuf)>,
) -> Result<Self> {
let parsed_contracts: HashMap<String, SierraClass> = contracts
let parsed_contracts: HashMap<ContractName, SierraClass> = contracts
.par_iter()
.map(|(name, (artifact, _))| {
Ok((name.clone(), serde_json::from_str(&artifact.sierra)?))
})
.collect::<Result<_>>()?;

let class_hashes: Vec<(String, ClassHash)> = parsed_contracts
let class_hashes: Vec<(ContractName, ClassHash)> = parsed_contracts
.par_iter()
.map(|(name, sierra_class)| Ok((name.clone(), get_class_hash(sierra_class)?)))
.collect::<Result<_>>()?;
let class_hash_index = BiMap::from_iter(class_hashes);
let class_hashes = BiMap::from_iter(class_hashes);

let contracts = contracts
.into_iter()
.map(|(name, (artifacts, sierra_source_path))| {
let class_hash = *class_hash_index.get_by_left(&name).unwrap();
let class_hash = *class_hashes.get_by_left(&name).unwrap();
(
name,
Contract {
ContractData {
artifacts,
class_hash,
_sierra_source_path: sierra_source_path,
_source_sierra_path: sierra_source_path,
},
)
})
Expand All @@ -64,38 +67,40 @@ impl ContractsData {

Ok(ContractsData {
contracts,
class_hash_index,
class_hashes,
selectors,
})
}

#[must_use]
pub fn get_artifacts_for_contract(&self, name: &str) -> Option<&StarknetContractArtifacts> {
self.contracts.get(name).map(|contract| &contract.artifacts)
pub fn get_artifacts(&self, contract_name: &str) -> Option<&StarknetContractArtifacts> {
self.contracts
.get(contract_name)
.map(|contract| &contract.artifacts)
}

#[must_use]
pub fn get_class_hash_for_contract(&self, name: &str) -> Option<&ClassHash> {
pub fn get_class_hash(&self, contract_name: &str) -> Option<&ClassHash> {
self.contracts
.get(name)
.get(contract_name)
.map(|contract| &contract.class_hash)
}

#[must_use]
pub fn get_contract_name_from_class_hash(&self, class_hash: &ClassHash) -> Option<&String> {
self.class_hash_index.get_by_right(class_hash)
pub fn get_contract_name(&self, class_hash: &ClassHash) -> Option<&ContractName> {
self.class_hashes.get_by_right(class_hash)
}

#[must_use]
pub fn get_function_name_from_entry_point_selector(
pub fn get_function_name(
&self,
entry_point_selector: &EntryPointSelector,
) -> Option<&String> {
) -> Option<&FunctionName> {
self.selectors.get(entry_point_selector)
}
}

fn build_name_selector_map(abi: Vec<AbiEntry>) -> HashMap<EntryPointSelector, String> {
fn build_name_selector_map(abi: Vec<AbiEntry>) -> HashMap<EntryPointSelector, FunctionName> {
let mut selector_map = HashMap::new();
for abi_entry in abi {
match abi_entry {
Expand All @@ -112,7 +117,7 @@ fn build_name_selector_map(abi: Vec<AbiEntry>) -> HashMap<EntryPointSelector, St

fn add_simple_abi_entry_to_mapping(
abi_entry: AbiEntry,
selector_map: &mut HashMap<EntryPointSelector, String>,
selector_map: &mut HashMap<EntryPointSelector, FunctionName>,
) {
match abi_entry {
AbiEntry::Function(abi_function) | AbiEntry::L1Handler(abi_function) => {
Expand Down
12 changes: 3 additions & 9 deletions crates/cheatnet/tests/cheatcodes/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ fn declare_simple() {
let contracts_data = get_contracts();

let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap();
let expected_class_hash = contracts_data
.get_class_hash_for_contract(contract_name)
.unwrap();
let expected_class_hash = contracts_data.get_class_hash(contract_name).unwrap();

assert_eq!(class_hash, *expected_class_hash);
}
Expand All @@ -29,9 +27,7 @@ fn declare_multiple() {

for contract_name in contract_names {
let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap();
let expected_class_hash = contracts_data
.get_class_hash_for_contract(contract_name)
.unwrap();
let expected_class_hash = contracts_data.get_class_hash(contract_name).unwrap();
assert_eq!(class_hash, *expected_class_hash);
}
}
Expand All @@ -45,9 +41,7 @@ fn declare_same_contract() {
let contracts_data = get_contracts();

let class_hash = declare(&mut cached_state, contract_name, &contracts_data).unwrap();
let expected_class_hash = contracts_data
.get_class_hash_for_contract(contract_name)
.unwrap();
let expected_class_hash = contracts_data.get_class_hash(contract_name).unwrap();
assert_eq!(class_hash, *expected_class_hash);

let output = declare(&mut cached_state, contract_name, &contracts_data);
Expand Down
3 changes: 2 additions & 1 deletion crates/cheatnet/tests/common/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ pub fn create_fork_cached_state_at(
ExtendedStateReader {
dict_state_reader: build_testing_state(),
fork_state_reader: Some(
ForkStateReader::new(node_url, BlockNumber(block_number), cache_dir).unwrap(),
ForkStateReader::new(node_url, BlockNumber(block_number), cache_dir.into())
.unwrap(),
),
},
GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST),
Expand Down
3 changes: 2 additions & 1 deletion crates/cheatnet/tests/starknet/forking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use blockifier::state::cached_state::{
};
use cairo_felt::Felt252;
use cairo_vm::vm::errors::hint_errors::HintError;
use camino::Utf8Path;
use cheatnet::constants::build_testing_state;
use cheatnet::forking::{cache::CACHE_VERSION, state::ForkStateReader};
use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::CheatcodeError;
Expand Down Expand Up @@ -652,7 +653,7 @@ fn test_calling_nonexistent_url() {
ForkStateReader::new(
nonexistent_url,
BlockNumber(1),
temp_dir.path().to_str().unwrap(),
Utf8Path::from_path(temp_dir.path()).unwrap(),
)
.unwrap(),
),
Expand Down
4 changes: 2 additions & 2 deletions crates/forge-runner/src/build_trace_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ pub fn build_profiler_call_entry_point(
} = value;

let mut contract_name = class_hash
.and_then(|c| contracts_data.get_contract_name_from_class_hash(&c))
.and_then(|c| contracts_data.get_contract_name(&c))
.cloned();
let mut function_name = contracts_data
.get_function_name_from_entry_point_selector(&entry_point_selector)
.get_function_name(&entry_point_selector)
.cloned();

if entry_point_selector.0
Expand Down
92 changes: 92 additions & 0 deletions crates/forge-runner/src/forge_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use camino::Utf8PathBuf;
use cheatnet::runtime_extensions::forge_runtime_extension::contracts_data::ContractsData;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;

#[derive(Debug, PartialEq)]
pub struct ForgeConfig {
pub test_runner_config: Arc<TestRunnerConfig>,
pub output_config: Arc<OutputConfig>,
}

#[derive(Debug, PartialEq)]
pub struct TestRunnerConfig {
pub exit_first: bool,
pub fuzzer_runs: NonZeroU32,
pub fuzzer_seed: u64,
pub max_n_steps: Option<u32>,
pub is_vm_trace_needed: bool,
pub cache_dir: Utf8PathBuf,
pub contracts_data: ContractsData,
pub environment_variables: HashMap<String, String>,
pub test_artifacts_path: Utf8PathBuf,
}

#[derive(Debug, PartialEq)]
pub struct OutputConfig {
pub detailed_resources: bool,
pub execution_data_to_save: ExecutionDataToSave,
}

impl OutputConfig {
#[must_use]
pub fn new(detailed_resources: bool, save_trace_data: bool, build_profile: bool) -> Self {
Self {
detailed_resources,
execution_data_to_save: ExecutionDataToSave::from_flags(save_trace_data, build_profile),
}
}
}

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum ExecutionDataToSave {
None,
Trace,
/// Profile data requires saved trace data
TraceAndProfile,
}

impl ExecutionDataToSave {
#[must_use]
pub fn from_flags(save_trace_data: bool, build_profile: bool) -> Self {
if build_profile {
return ExecutionDataToSave::TraceAndProfile;
}
if save_trace_data {
return ExecutionDataToSave::Trace;
}
ExecutionDataToSave::None
}
}

#[must_use]
pub fn is_vm_trace_needed(execution_data_to_save: ExecutionDataToSave) -> bool {
match execution_data_to_save {
ExecutionDataToSave::Trace | ExecutionDataToSave::TraceAndProfile => true,
ExecutionDataToSave::None => false,
}
}

/// This struct should be constructed on demand to pass only relevant information from
/// [`TestRunnerConfig`] to another function.
pub struct RuntimeConfig<'a> {
pub max_n_steps: Option<u32>,
pub is_vm_trace_needed: bool,
pub cache_dir: &'a Utf8PathBuf,
pub contracts_data: &'a ContractsData,
pub environment_variables: &'a HashMap<String, String>,
}

impl<'a> RuntimeConfig<'a> {
#[must_use]
pub fn from(value: &'a TestRunnerConfig) -> RuntimeConfig<'a> {
Self {
max_n_steps: value.max_n_steps,
is_vm_trace_needed: value.is_vm_trace_needed,
cache_dir: &value.cache_dir,
contracts_data: &value.contracts_data,
environment_variables: &value.environment_variables,
}
}
}
Loading
Loading