Skip to content

Commit

Permalink
fix: use sozo manifest when possible to avoid recomputing diff (dojoe…
Browse files Browse the repository at this point in the history
…ngine#2649)

* fix: ensure the local manifest is used when possible

* fix: add missing tag for the models and events

* fix: regenerate db with new migration spawn and move

* tests: add missing tests
  • Loading branch information
glihm authored Nov 7, 2024
1 parent a4adeb3 commit 86365fd
Show file tree
Hide file tree
Showing 14 changed files with 610 additions and 171 deletions.
31 changes: 26 additions & 5 deletions bin/sozo/src/commands/call.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use anyhow::{anyhow, Result};
use clap::Args;
use dojo_types::naming;
use dojo_world::config::calldata_decoder;
use dojo_world::contracts::ContractInfo;
use scarb::core::Config;
use sozo_ops::resource_descriptor::ResourceDescriptor;
use sozo_scarbext::WorkspaceExt;
Expand Down Expand Up @@ -39,6 +41,11 @@ pub struct CallArgs {
#[arg(help = "The block ID (could be a hash, a number, 'pending' or 'latest')")]
pub block_id: Option<String>,

#[arg(long)]
#[arg(help = "If true, sozo will compute the diff of the world from the chain to translate \
tags to addresses.")]
pub diff: bool,

#[command(flatten)]
pub starknet: StarknetOptions,

Expand All @@ -57,8 +64,7 @@ impl CallArgs {
let descriptor = self.tag_or_address.ensure_namespace(&profile_config.namespace.default);

config.tokio_handle().block_on(async {
let (world_diff, provider, _) =
utils::get_world_diff_and_provider(self.starknet.clone(), self.world, &ws).await?;
let local_manifest = ws.read_manifest_profile()?;

let calldata = if let Some(cd) = self.calldata {
calldata_decoder::decode_calldata(&cd)?
Expand All @@ -69,8 +75,21 @@ impl CallArgs {
let contract_address = match &descriptor {
ResourceDescriptor::Address(address) => Some(*address),
ResourceDescriptor::Tag(tag) => {
let selector = naming::compute_selector_from_tag(tag);
world_diff.get_contract_address(selector)
let contracts: HashMap<String, ContractInfo> =
if self.diff || local_manifest.is_none() {
let (world_diff, _, _) = utils::get_world_diff_and_provider(
self.starknet.clone(),
self.world,
&ws,
)
.await?;

(&world_diff).into()
} else {
(&local_manifest.unwrap()).into()
};

contracts.get(tag).map(|c| c.address)
}
ResourceDescriptor::Name(_) => {
unimplemented!("Expected to be a resolved tag with default namespace.")
Expand All @@ -84,6 +103,8 @@ impl CallArgs {
BlockId::Tag(BlockTag::Pending)
};

let (provider, _) = self.starknet.provider(profile_config.env.as_ref())?;

let res = provider
.call(
FunctionCall {
Expand Down
12 changes: 2 additions & 10 deletions bin/sozo/src/commands/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,12 @@ async fn match_event<P: Provider + Send + Sync>(
format!("{:#066x}", e.system_address.0)
};

let (record, _, _) = model::model_get(
tag.clone(),
e.keys.clone(),
world_diff.world_info.address,
provider,
block_id,
)
.await?;
// TODO: for events, we need to pull the schema and print the values accordingly.

(
format!("Event emitted ({})", tag),
format!(
"Selector: {:#066x}\nContract: {}\nKeys: {}\nValues: {}\nData:\n{}",
"Selector: {:#066x}\nContract: {}\nKeys: {}\nValues: {}",
e.selector,
contract_tag,
e.keys
Expand All @@ -358,7 +351,6 @@ async fn match_event<P: Provider + Send + Sync>(
.map(|v| format!("{:#066x}", v))
.collect::<Vec<String>>()
.join(", "),
record
),
)
}
Expand Down
66 changes: 48 additions & 18 deletions bin/sozo/src/commands/execute.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::collections::HashMap;

use anyhow::{anyhow, Result};
use clap::Args;
use dojo_types::naming;
use dojo_utils::{Invoker, TxnConfig};
use dojo_world::config::calldata_decoder;
use dojo_world::contracts::ContractInfo;
use scarb::core::Config;
use sozo_ops::resource_descriptor::ResourceDescriptor;
use sozo_scarbext::WorkspaceExt;
Expand Down Expand Up @@ -39,6 +41,11 @@ pub struct ExecuteArgs {
- no prefix: A cairo felt or any type that fit into one felt.")]
pub calldata: Option<String>,

#[arg(long)]
#[arg(help = "If true, sozo will compute the diff of the world from the chain to translate \
tags to addresses.")]
pub diff: bool,

#[command(flatten)]
pub starknet: StarknetOptions,

Expand Down Expand Up @@ -71,28 +78,44 @@ impl ExecuteArgs {
let txn_config: TxnConfig = self.transaction.into();

config.tokio_handle().block_on(async {
// We could save the world diff computation extracting the account directly from the
// options.
let (world_diff, account, _) = utils::get_world_diff_and_account(
self.account,
self.starknet.clone(),
self.world,
&ws,
&mut None,
)
.await?;

let contract_address = match &descriptor {
ResourceDescriptor::Address(address) => Some(*address),
let local_manifest = ws.read_manifest_profile()?;
let use_diff = self.diff || local_manifest.is_none();

let (contract_address, contracts) = match &descriptor {
ResourceDescriptor::Address(address) => (Some(*address), Default::default()),
ResourceDescriptor::Tag(tag) => {
let selector = naming::compute_selector_from_tag(tag);
world_diff.get_contract_address(selector)
let contracts: HashMap<String, ContractInfo> = if use_diff {
let (world_diff, _, _) = utils::get_world_diff_and_account(
self.account.clone(),
self.starknet.clone(),
self.world,
&ws,
&mut None,
)
.await?;

(&world_diff).into()
} else {
(&local_manifest.unwrap()).into()
};

(contracts.get(tag).map(|c| c.address), contracts)
}
ResourceDescriptor::Name(_) => {
unimplemented!("Expected to be a resolved tag with default namespace.")
}
}
.ok_or_else(|| anyhow!("Contract {descriptor} not found in the world diff."))?;
};

let contract_address = contract_address.ok_or_else(|| {
let mut message = format!("Contract {descriptor} not found in the manifest.");
if self.diff {
message.push_str(
" Run the command again with `--diff` to force the fetch of data from the \
chain.",
);
}
anyhow!(message)
})?;

trace!(
contract=?descriptor,
Expand All @@ -113,6 +136,13 @@ impl ExecuteArgs {
selector: snutils::get_selector_from_name(&self.entrypoint)?,
};

let (provider, _) = self.starknet.provider(profile_config.env.as_ref())?;

let account = self
.account
.account(provider, profile_config.env.as_ref(), &self.starknet, &contracts)
.await?;

let invoker = Invoker::new(&account, txn_config);
// TODO: add walnut back, perhaps at the invoker level.
let tx_result = invoker.invoke(call).await?;
Expand Down
82 changes: 20 additions & 62 deletions bin/sozo/src/commands/options/account/controller.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{bail, Result};
use dojo_world::diff::WorldDiff;
use dojo_world::ResourceType;
use dojo_world::contracts::contract_info::ContractInfo;
use slot::account_sdk::account::session::hash::{Policy, ProvedPolicy};
use slot::account_sdk::account::session::merkle::MerkleTree;
use slot::account_sdk::account::session::SessionAccount;
use slot::session::{FullSessionInfo, PolicyMethod};
use starknet::core::types::contract::{AbiEntry, StateMutability};
use starknet::core::types::Felt;
use starknet::core::utils::get_selector_from_name;
use starknet::macros::felt;
Expand Down Expand Up @@ -35,16 +34,12 @@ pub type ControllerSessionAccount<P> = SessionAccount<Arc<P>>;
/// * Starknet mainnet
/// * Starknet sepolia
/// * Slot hosted networks
#[tracing::instrument(
name = "create_controller",
skip(rpc_url, provider, world_address, world_diff)
)]
#[tracing::instrument(name = "create_controller", skip(rpc_url, provider, contracts))]
pub async fn create_controller<P>(
// Ideally we can get the url from the provider so we dont have to pass an extra url param here
rpc_url: Url,
provider: P,
world_address: Felt,
world_diff: &WorldDiff,
contracts: &HashMap<String, ContractInfo>,
) -> Result<ControllerSessionAccount<P>>
where
P: Provider,
Expand All @@ -62,7 +57,7 @@ where
bail!("No Controller is associated with this account.");
};

let policies = collect_policies(world_address, contract_address, world_diff)?;
let policies = collect_policies(contract_address, contracts)?;

// Check if the session exists, if not create a new one
let session_details = match slot::session::get(chain_id)? {
Expand Down Expand Up @@ -132,37 +127,28 @@ fn is_equal_to_existing(new_policies: &[PolicyMethod], session_info: &FullSessio
/// This function collect all the contracts' methods in the current project according to the
/// project's base manifest ( `/manifests/<profile>/base` ) and convert them into policies.
fn collect_policies(
world_address: Felt,
user_address: Felt,
world_diff: &WorldDiff,
contracts: &HashMap<String, ContractInfo>,
) -> Result<Vec<PolicyMethod>> {
let policies = collect_policies_from_local_world(world_address, user_address, world_diff)?;
let policies = collect_policies_from_contracts(user_address, contracts)?;
trace!(target: "account::controller", policies_count = policies.len(), "Extracted policies from project.");
Ok(policies)
}

fn collect_policies_from_local_world(
world_address: Felt,
fn collect_policies_from_contracts(
user_address: Felt,
world_diff: &WorldDiff,
contracts: &HashMap<String, ContractInfo>,
) -> Result<Vec<PolicyMethod>> {
let mut policies: Vec<PolicyMethod> = Vec::new();

// get methods from all project contracts
for (selector, resource) in world_diff.resources.iter() {
if resource.resource_type() == ResourceType::Contract {
// Safe to unwrap the two methods since the selector comes from the resources registry
// in the local world.
let contract_address = world_diff.get_contract_address(*selector).unwrap();
let sierra_class = world_diff.get_class(*selector).unwrap();

policies_from_abis(&mut policies, &resource.tag(), contract_address, &sierra_class.abi);
for (tag, info) in contracts {
for e in &info.entrypoints {
let policy = PolicyMethod { target: info.address, method: e.clone() };
trace!(target: "account::controller", tag, target = format!("{:#x}", policy.target), method = %policy.method, "Adding policy");
policies.push(policy);
}
}

// get method from world contract
policies_from_abis(&mut policies, "world", world_address, &world_diff.world_info.class.abi);

// special policy for sending declare tx
// corresponds to [account_sdk::account::DECLARATION_SELECTOR]
let method = "__declare_transaction__".to_string();
Expand All @@ -179,39 +165,12 @@ fn collect_policies_from_local_world(
Ok(policies)
}

/// Recursively extract methods and convert them into policies from the all the
/// ABIs in the project.
fn policies_from_abis(
policies: &mut Vec<PolicyMethod>,
contract_tag: &str,
contract_address: Felt,
entries: &[AbiEntry],
) {
for entry in entries {
match entry {
AbiEntry::Function(f) => {
// we only create policies for non-view functions
if let StateMutability::External = f.state_mutability {
let policy =
PolicyMethod { target: contract_address, method: f.name.to_string() };
trace!(target: "account::controller", tag = contract_tag, target = format!("{:#x}", policy.target), method = %policy.method, "Adding policy");
policies.push(policy);
}
}

AbiEntry::Interface(i) => {
policies_from_abis(policies, contract_tag, contract_address, &i.items)
}

_ => {}
}
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use dojo_test_utils::compiler::CompilerTestSetup;
use dojo_world::diff::WorldDiff;
use dojo_world::contracts::ContractInfo;
use scarb::compiler::Profile;
use sozo_scarbext::WorkspaceExt;
use starknet::macros::felt;
Expand All @@ -228,13 +187,12 @@ mod tests {
let ws = scarb::ops::read_workspace(config.manifest_path(), &config)
.unwrap_or_else(|op| panic!("Error building workspace: {op:?}"));

let world_local = ws.load_world_local().unwrap();
let world_diff = WorldDiff::from_local(world_local).unwrap();
let manifest = ws.read_manifest_profile().expect("Failed to read manifest").unwrap();
let contracts: HashMap<String, ContractInfo> = (&manifest).into();

let user_addr = felt!("0x2af9427c5a277474c079a1283c880ee8a6f0f8fbf73ce969c08d88befec1bba");

let policies =
collect_policies(world_diff.world_info.address, user_addr, &world_diff).unwrap();
let policies = collect_policies(user_addr, &contracts).unwrap();

if std::env::var("POLICIES_FIX").is_ok() {
let policies_json = serde_json::to_string_pretty(&policies).unwrap();
Expand Down
Loading

0 comments on commit 86365fd

Please sign in to comment.