diff --git a/contracts/account/ArgentAccount.cairo b/contracts/account/ArgentAccount.cairo index 24ea5de2..7721e271 100644 --- a/contracts/account/ArgentAccount.cairo +++ b/contracts/account/ArgentAccount.cairo @@ -21,6 +21,12 @@ from contracts.account.library import ( assert_initialized, assert_no_self_call, ) + +// +// @title ArgentAccount +// @author Argent Labs +// @notice Main account for Argent on StarkNet +// ///////////////////// // CONSTANTS @@ -196,6 +202,10 @@ func supportsInterface{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_che // EXTERNAL FUNCTIONS ///////////////////// +// @dev Initialises the account with the signer and an optional guardian. +// Must be called immediately after the account is deployed. +// @param signer The signer public key +// @param guardian The guardian public key @external func initialize{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( signer: felt, guardian: felt @@ -206,6 +216,11 @@ func initialize{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr} return (); } +// @dev Upgrades the implementation of the account and delegate calls {execute_after_upgrade} if additional data is provided. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param implementation The class hash of the new implementation +// @param calldata The calldata to pass to {execute_after_upgrade} +// @return retdata The return of the library call to {execute_after_upgrade} @external func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( implementation: felt, calldata_len: felt, calldata: felt* @@ -226,6 +241,11 @@ func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( } } +// @dev Logic or multicall to execute after an upgrade. +// Can only be called by the account after a call to {upgrade}. +// @param call_array The multicall to execute +// @param calldata The calldata associated to the multicall +// @return retdata An array containing the output of the calls @external func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* @@ -241,6 +261,9 @@ func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range return (retdata_len=retdata_len, retdata=retdata); } +// @dev Changes the signer. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param newSigner The public key of the new signer @external func changeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newSigner: felt @@ -249,6 +272,9 @@ func changeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_pt return (); } +// @dev Changes the guardian. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param newGuardian The public key of the new guardian @external func changeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newGuardian: felt @@ -257,6 +283,9 @@ func changeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ return (); } +// @dev Changes the guardian backup. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param newGuardian The public key of the new guardian backup @external func changeGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newGuardian: felt @@ -265,24 +294,35 @@ func changeGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_ return (); } +// @dev Triggers the escape of the guardian when it is lost or compromised. +// Must be called via {__execute__} and authorised by the signer alone. +// Can override an ongoing escape of the signer. @external func triggerEscapeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { ArgentModel.trigger_escape_guardian(); return (); } +// @dev Triggers the escape of the signer when it is lost or compromised. +// Must be called via {__execute__} and authorised by a guardian alone. +// Cannot override an ongoing escape of the guardian. @external func triggerEscapeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { ArgentModel.trigger_escape_signer(); return (); } +// @dev Cancels an ongoing escape if any. +// Must be called via {__execute__} and authorised by the signer and a guardian. @external func cancelEscape{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { ArgentModel.cancel_escape(); return (); } +// @dev Escapes the guardian after the escape period of 7 days. +// Must be called via {__execute__} and authorised by the signer alone. +// @param newGuardian The public key of the new guardian @external func escapeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newGuardian: felt @@ -291,6 +331,9 @@ func escapeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ return (); } +// @dev Escapes the signer after the escape period of 7 days. +// Must be called via {__execute__} and authorised by a guardian alone. +// @param newSigner The public key of the new signer @external func escapeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newSigner: felt @@ -303,6 +346,8 @@ func escapeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_pt // VIEW FUNCTIONS ///////////////////// +// @dev Gets the current signer +// @return signer The public key of the signer @view func getSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( signer: felt @@ -311,6 +356,8 @@ func getSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( return (signer=res); } +// @dev Gets the current guardian +// @return guardian The public key of the guardian @view func getGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( guardian: felt @@ -319,6 +366,8 @@ func getGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr return (guardian=res); } +// @dev Gets the current guardian backup +// @return guardianBackup The public key of the guardian backup @view func getGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( guardianBackup: felt @@ -327,6 +376,9 @@ func getGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_che return (guardianBackup=res); } +// @dev Gets the details of the ongoing escape +// @return activeAt The timestamp at which the escape can be executed +// @return type The type of the ongoing escape: 0=no escape, 1=guardian escape, 2=signer escape @view func getEscape{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( activeAt: felt, type: felt @@ -335,17 +387,21 @@ func getEscape{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( return (activeAt=activeAt, type=type); } +// @dev Gets the version of the account implementation +// @return version The current version as a short string @view func getVersion() -> (version: felt) { return (version=VERSION); } +// @dev Gets the name of the account implementation +// @return name The name as a short string @view func getName() -> (name: felt) { return (name=NAME); } -// TMP: Remove when isValidSignature() is widely used +// @dev DEPRECATED: Remove when isValidSignature() is widely used @view func is_valid_signature{ syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr diff --git a/contracts/account/ArgentPluginAccount.cairo b/contracts/account/ArgentPluginAccount.cairo index 4395a874..6e14943e 100644 --- a/contracts/account/ArgentPluginAccount.cairo +++ b/contracts/account/ArgentPluginAccount.cairo @@ -2,7 +2,6 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.common.alloc import alloc -from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.math import assert_not_zero from starkware.starknet.common.syscalls import ( library_call, @@ -25,6 +24,12 @@ from contracts.account.library import ( assert_only_self ) +// +// @title ArgentPluginAccount +// @author Argent Labs +// @notice Experimental Argent account supporting plugins +// + /////////////////////// // CONSTANTS /////////////////////// @@ -212,6 +217,9 @@ func supportsInterface{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_che // PLUGIN ////////////////////// +// @dev Adds a new plugin. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param plugin The class hash of the plugin @external func addPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt) { // only called via execute @@ -225,6 +233,9 @@ func addPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( return (); } +// @dev Removes an existing plugin. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param plugin The class hash of the plugin @external func removePlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt) { // only called via execute @@ -239,6 +250,11 @@ func removePlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_pt return (); } +// @dev Executes a library call on one of the enabled plugin of the account. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param plugin The class hash of the plugin +// @param selector The method to execute on the plugin +// @param calldata The call data of the call @external func executeOnPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( plugin: felt, selector: felt, calldata_len: felt, calldata: felt* @@ -257,6 +273,9 @@ func executeOnPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check return (); } +// @dev Checks if a plugin is enabled on the account. +// @param plugin The class hash of the plugin +// @return success True if the plugin is enabled @view func isPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt) -> ( success: felt @@ -292,6 +311,10 @@ func validate_with_plugin{ // EXTERNAL FUNCTIONS ////////////////////// +// @dev Initialises the account with the signer and an optional guardian. +// Must be called immediately after the account is deployed. +// @param signer The signer public key +// @param guardian The guardian public key @external func initialize{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( signer: felt, guardian: felt @@ -302,6 +325,11 @@ func initialize{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr} return (); } +// @dev Upgrades the implementation of the account and delegate calls {execute_after_upgrade} if additional data is provided. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param implementation The class hash of the new implementation +// @param calldata The calldata to pass to {execute_after_upgrade} +// @return retdata The return of the library call to {execute_after_upgrade} @external func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( implementation: felt, calldata_len: felt, calldata: felt* @@ -323,6 +351,11 @@ func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( } } +// @dev Logic or multicall to execute after an upgrade. +// Can only be called by the account after a call to {upgrade}. +// @param call_array The multicall to execute +// @param calldata The calldata associated to the multicall +// @return retdata An array containing the output of the calls @external func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* @@ -338,6 +371,9 @@ func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range return (retdata_len=retdata_len, retdata=retdata); } +// @dev Changes the signer. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param newSigner The public key of the new signer @external func changeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newSigner: felt @@ -346,6 +382,9 @@ func changeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_pt return (); } +// @dev Changes the guardian. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param newGuardian The public key of the new guardian @external func changeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newGuardian: felt @@ -354,6 +393,9 @@ func changeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ return (); } +// @dev Changes the guardian backup. +// Must be called via {__execute__} and authorised by the signer and a guardian. +// @param newGuardian The public key of the new guardian backup @external func changeGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newGuardian: felt @@ -362,24 +404,35 @@ func changeGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_ return (); } +// @dev Triggers the escape of the guardian when it is lost or compromised. +// Must be called via {__execute__} and authorised by the signer alone. +// Can override an ongoing escape of the signer. @external func triggerEscapeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { ArgentModel.trigger_escape_guardian(); return (); } +// @dev Triggers the escape of the signer when it is lost or compromised. +// Must be called via {__execute__} and authorised by a guardian alone. +// Cannot override an ongoing escape of the guardian. @external func triggerEscapeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { ArgentModel.trigger_escape_signer(); return (); } +// @dev Cancels an ongoing escape if any. +// Must be called via {__execute__} and authorised by the signer and a guardian. @external func cancelEscape{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { ArgentModel.cancel_escape(); return (); } +// @dev Escapes the guardian after the escape period of 7 days. +// Must be called via {__execute__} and authorised by the signer alone. +// @param newGuardian The public key of the new guardian @external func escapeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newGuardian: felt @@ -388,6 +441,9 @@ func escapeGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ return (); } +// @dev Escapes the signer after the escape period of 7 days. +// Must be called via {__execute__} and authorised by a guardian alone. +// @param newSigner The public key of the new signer @external func escapeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( newSigner: felt @@ -400,6 +456,8 @@ func escapeSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_pt // VIEW FUNCTIONS ///////////////////// +// @dev Gets the current signer +// @return signer The public key of the signer @view func getSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( signer: felt @@ -408,6 +466,8 @@ func getSigner{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( return (signer=res); } +// @dev Gets the current guardian +// @return guardian The public key of the guardian @view func getGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( guardian: felt @@ -416,6 +476,8 @@ func getGuardian{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr return (guardian=res); } +// @dev Gets the current guardian backup +// @return guardian The public key of the guardian backup @view func getGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( guardianBackup: felt @@ -424,6 +486,9 @@ func getGuardianBackup{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_che return (guardianBackup=res); } +// @dev Gets the details of the ongoing escape +// @return activeAt The timestamp at which the escape can be executed +// @return type The type of the ongoing escape: 0=no escape, 1=guardian escape, 2=signer escape @view func getEscape{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( activeAt: felt, type: felt @@ -432,17 +497,21 @@ func getEscape{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( return (activeAt=activeAt, type=type); } +// @dev Gets the version of the account implementation +// @return version The current version as a short string @view func getVersion() -> (version: felt) { return (version=VERSION); } +// @dev Gets the name of the account implementation +// @return name The name as a short string @view func getName() -> (name: felt) { return (name=NAME); } -// TMP: Remove when isValidSignature() is widely used +// @dev DEPRECATED: Remove when isValidSignature() is widely used @view func is_valid_signature{ syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr diff --git a/requirements.txt b/requirements.txt index d08478cf..0bd67835 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -cairo-lang>=0.10.1 -cairo-nile>=0.9.1 -nile-coverage>=0.2.0 +cairo-lang==0.10.1 +cairo-nile==0.11.0 +nile-coverage==0.2.5.1 pytest>=7.1.2 pytest-asyncio>=0.19.0 \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..9d4e5eb2 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,56 @@ +from typing import Tuple + +import pytest +import asyncio + +from starkware.starknet.services.api.contract_class import ContractClass +from starkware.starknet.testing.contract import DeclaredClass +from starkware.starknet.testing.starknet import Starknet +from utils.utilities import compile + + +@pytest.fixture(scope='module') +def event_loop(): + return asyncio.new_event_loop() + + +@pytest.fixture(scope='module') +async def starknet(): + return await Starknet.empty() + + +@pytest.fixture(scope='module') +def account_cls() -> ContractClass: + return compile('contracts/account/ArgentAccount.cairo') + + +@pytest.fixture(scope='module') +def test_dapp_cls() -> ContractClass: + return compile("contracts/test/TestDapp.cairo") + + +@pytest.fixture(scope='module') +async def proxy_cls() -> ContractClass: + return compile("contracts/upgrade/Proxy.cairo") + + +@pytest.fixture(scope='module') +async def declared_account(starknet: Starknet, account_cls: ContractClass) -> DeclaredClass: + return await starknet.declare(contract_class=account_cls) + + +@pytest.fixture(scope='module') +async def declared_proxy(starknet: Starknet, proxy_cls: ContractClass): + return await starknet.declare(contract_class=proxy_cls) + + +@pytest.fixture(scope='module') +async def deploy_env(starknet: Starknet, declared_account: DeclaredClass, declared_proxy: DeclaredClass, proxy_cls: ContractClass, account_cls: ContractClass): + return starknet, declared_account, account_cls, proxy_cls, declared_proxy + + +@pytest.fixture +def deploy_env_copy(deploy_env: Tuple[Starknet, DeclaredClass, ContractClass, ContractClass, DeclaredClass]): + starknet, account_decl, account_cls, proxy_cls, proxy_decl = deploy_env + starknet_copy = Starknet(starknet.state.copy()) + return starknet_copy, account_decl, account_cls, proxy_cls, proxy_decl \ No newline at end of file diff --git a/test/test_argent_account.py b/test/test_argent_account.py index ff67705b..917b66d9 100644 --- a/test/test_argent_account.py +++ b/test/test_argent_account.py @@ -1,5 +1,6 @@ import pytest -import asyncio + +from starkware.starknet.services.api.contract_class import ContractClass from starkware.starknet.testing.starknet import Starknet from starkware.starknet.definitions.error_codes import StarknetErrorCode from utils.Signer import Signer @@ -27,20 +28,9 @@ ESCAPE_TYPE_GUARDIAN = 1 ESCAPE_TYPE_SIGNER = 2 -@pytest.fixture(scope='module') -def event_loop(): - return asyncio.new_event_loop() @pytest.fixture(scope='module') -def contract_classes(): - account_cls = compile('contracts/account/ArgentAccount.cairo') - dapp_cls = compile("contracts/test/TestDapp.cairo") - - return account_cls, dapp_cls - -@pytest.fixture(scope='module') -async def contract_init(contract_classes): - account_cls, dapp_cls = contract_classes +async def contract_init(account_cls: ContractClass, test_dapp_cls: ContractClass): starknet = await Starknet.empty() account = await starknet.deploy( @@ -56,25 +46,25 @@ async def contract_init(contract_classes): await account_no_guardian.initialize(signer.public_key, 0).execute() dapp = await starknet.deploy( - contract_class=dapp_cls, + contract_class=test_dapp_cls, constructor_calldata=[], ) return starknet.state, account, account_no_guardian, dapp + @pytest.fixture -async def contract_factory(contract_classes, contract_init): - account_cls, dapp_cls = contract_classes +async def contract_factory(account_cls: ContractClass, test_dapp_cls: ContractClass, contract_init): state, account, account_no_guardian, dapp = contract_init _state = state.copy() account = cached_contract(_state, account_cls, account) account_no_guardian = cached_contract(_state, account_cls, account_no_guardian) - dapp = cached_contract(_state, dapp_cls, dapp) + dapp = cached_contract(_state, test_dapp_cls, dapp) return account, account_no_guardian, dapp -@pytest.mark.asyncio + async def test_initializer(contract_factory): account, _, _ = contract_factory # should be configured correctly @@ -89,7 +79,7 @@ async def test_initializer(contract_factory): "argent: already initialized" ) -@pytest.mark.asyncio + async def test_declare(contract_factory): account, _, _ = contract_factory sender = TransactionSender(account) @@ -116,7 +106,8 @@ async def test_declare(contract_factory): tx_exec_info = await sender.declare_class(test_cls, [signer, guardian]) -@pytest.mark.asyncio + + async def test_call_dapp_with_guardian(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -166,7 +157,7 @@ async def test_call_dapp_with_guardian(contract_factory): assert (await dapp.get_number(account.contract_address).call()).result.number == 47 -@pytest.mark.asyncio + async def test_call_dapp_guardian_backup(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -201,7 +192,7 @@ async def test_call_dapp_guardian_backup(contract_factory): assert (await dapp.get_number(account.contract_address).call()).result.number == 47 -@pytest.mark.asyncio + async def test_call_dapp_no_guardian(contract_factory): _, account_no_guardian, dapp = contract_factory sender = TransactionSender(account_no_guardian) @@ -216,7 +207,7 @@ async def test_call_dapp_no_guardian(contract_factory): await sender.send_transaction([(account_no_guardian.contract_address, 'changeSigner', [new_signer.public_key])], [signer]) assert (await account_no_guardian.getSigner().call()).result.signer == (new_signer.public_key) - # should reverts calls that require the guardian to be set + # should revert calls that require the guardian to be set await assert_revert( sender.send_transaction([(account_no_guardian.contract_address, 'triggerEscapeGuardian', [])], [new_signer]), "argent: guardian required" @@ -227,7 +218,7 @@ async def test_call_dapp_no_guardian(contract_factory): await sender.send_transaction([(account_no_guardian.contract_address, 'changeGuardian', [new_guardian.public_key])], [new_signer]) assert (await account_no_guardian.getGuardian().call()).result.guardian == (new_guardian.public_key) -@pytest.mark.asyncio + async def test_multicall(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -255,12 +246,12 @@ async def test_multicall(contract_factory): await sender.send_transaction([(dapp.contract_address, 'set_number', [47]), (dapp.contract_address, 'increase_number', [10])], [signer, guardian]) assert (await dapp.get_number(account.contract_address).call()).result.number == 57 -@pytest.mark.asyncio + async def test_change_signer(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) - assert (await account.getSigner().call()).result.signer == (signer.public_key) + assert (await account.getSigner().call()).result.signer == signer.public_key # should revert with the wrong signer await assert_revert( @@ -286,7 +277,7 @@ async def test_change_signer(contract_factory): assert (await account.getSigner().call()).result.signer == (new_signer.public_key) -@pytest.mark.asyncio + async def test_change_guardian(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -317,7 +308,7 @@ async def test_change_guardian(contract_factory): assert (await account.getGuardian().call()).result.guardian == (new_guardian.public_key) -@pytest.mark.asyncio + async def test_change_guardian_backup(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -346,7 +337,7 @@ async def test_change_guardian_backup(contract_factory): assert (await account.getGuardianBackup().call()).result.guardianBackup == (new_guardian_backup.public_key) -@pytest.mark.asyncio + async def test_change_guardian_backup_when_no_guardian(contract_factory): _, account_no_guardian, dapp = contract_factory sender = TransactionSender(account_no_guardian) @@ -355,7 +346,7 @@ async def test_change_guardian_backup_when_no_guardian(contract_factory): sender.send_transaction([(account_no_guardian.contract_address, 'changeGuardianBackup', [new_guardian_backup.public_key])], [signer]) ) -@pytest.mark.asyncio + async def test_change_guardian_when_guardian_backup(contract_factory): account, _, _ = contract_factory sender = TransactionSender(account) @@ -368,7 +359,7 @@ async def test_change_guardian_when_guardian_backup(contract_factory): "argent: new guardian invalid" ) -@pytest.mark.asyncio + async def test_trigger_escape_guardian_by_signer(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -391,7 +382,7 @@ async def test_trigger_escape_guardian_by_signer(contract_factory): escape = (await account.getEscape().call()).result assert (escape.activeAt == (DEFAULT_TIMESTAMP + ESCAPE_SECURITY_PERIOD) and escape.type == ESCAPE_TYPE_GUARDIAN) -@pytest.mark.asyncio + async def test_trigger_escape_signer_by_guardian(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -414,7 +405,7 @@ async def test_trigger_escape_signer_by_guardian(contract_factory): escape = (await account.getEscape().call()).result assert (escape.activeAt == (DEFAULT_TIMESTAMP + ESCAPE_SECURITY_PERIOD) and escape.type == ESCAPE_TYPE_SIGNER) -@pytest.mark.asyncio + async def test_trigger_escape_signer_by_guardian_backup(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -440,7 +431,7 @@ async def test_trigger_escape_signer_by_guardian_backup(contract_factory): escape = (await account.getEscape().call()).result assert (escape.activeAt == (DEFAULT_TIMESTAMP + ESCAPE_SECURITY_PERIOD) and escape.type == ESCAPE_TYPE_SIGNER) -@pytest.mark.asyncio + async def test_escape_guardian(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -487,7 +478,7 @@ async def test_escape_guardian(contract_factory): escape = (await account.getEscape().call()).result assert (escape.activeAt == 0 and escape.type == 0) -@pytest.mark.asyncio + async def test_escape_signer(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -532,7 +523,7 @@ async def test_escape_signer(contract_factory): escape = (await account.getEscape().call()).result assert (escape.activeAt == 0 and escape.type == 0) -@pytest.mark.asyncio + async def test_signer_overrides_trigger_escape_signer(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -553,7 +544,7 @@ async def test_signer_overrides_trigger_escape_signer(contract_factory): escape = (await account.getEscape().call()).result assert (escape.activeAt == (DEFAULT_TIMESTAMP + 100 + ESCAPE_SECURITY_PERIOD) and escape.type == ESCAPE_TYPE_GUARDIAN) -@pytest.mark.asyncio + async def test_guardian_overrides_trigger_escape_guardian(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -576,7 +567,6 @@ async def test_guardian_overrides_trigger_escape_guardian(contract_factory): ) -@pytest.mark.asyncio async def test_cancel_escape(contract_factory): account, _, dapp = contract_factory sender = TransactionSender(account) @@ -609,7 +599,7 @@ async def test_cancel_escape(contract_factory): escape = (await account.getEscape().call()).result assert (escape.activeAt == 0 and escape.type == 0) -@pytest.mark.asyncio + async def test_is_valid_signature(contract_factory): account, _, dapp = contract_factory hash = 1283225199545181604979924458180358646374088657288769423115053097913173815464 @@ -624,7 +614,7 @@ async def test_is_valid_signature(contract_factory): res = (await account.is_valid_signature(hash, signatures).call()).result assert (res.is_valid == 1) -@pytest.mark.asyncio + async def test_support_interface(contract_factory): account, _, _ = contract_factory diff --git a/test/test_argent_deploy.py b/test/test_argent_deploy.py new file mode 100644 index 00000000..e66d4132 --- /dev/null +++ b/test/test_argent_deploy.py @@ -0,0 +1,29 @@ +from typing import Tuple + + +from starkware.starknet.services.api.contract_class import ContractClass +from starkware.starknet.testing.starknet import Starknet +from starkware.starknet.testing.contract import StarknetContract, DeclaredClass, StarknetContractFunctionInvocation + +from utils.TransactionSender import TransactionSender +from utils.utilities import signer_key_1, signer_key_2, signer_key_3, signer_key_4, assert_revert + + +async def test_validate_deploy(deploy_env_copy: Tuple[Starknet, DeclaredClass, ContractClass, ContractClass, DeclaredClass]): + starknet, account_decl, account_cls, proxy_cls, proxy_decl = deploy_env_copy + + unsigned_tx = await TransactionSender.get_unsigned_deploy_transaction( + proxy_decl=proxy_decl, + account_decl=account_decl, + signer_pub_key=signer_key_1.public_key + ) + signature = TransactionSender.get_signature( + unsigned_tx.calculate_hash(starknet.state.general_config), + signer_keys=signer_key_1 + ) + await TransactionSender.send_deploy_tx( + starknet=starknet, + unsigned_tx=unsigned_tx, + contract_cls=proxy_cls, + signature=signature + ) diff --git a/test/test_proxy.py b/test/test_proxy.py index ed7eb03f..5d1ec67d 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -15,10 +15,6 @@ wrong_guardian = Signer(4) -@pytest.fixture(scope='module') -def event_loop(): - return asyncio.new_event_loop() - @pytest.fixture(scope='module', params=[ "ArgentAccount", "ArgentPluginAccount", @@ -26,6 +22,7 @@ def event_loop(): def account_class(request): return compile(f"contracts/account/{request.param}.cairo") + @pytest.fixture(scope='module') def contract_classes(account_class): proxy_cls = compile("contracts/upgrade/Proxy.cairo") diff --git a/test/utils/Signer.py b/test/utils/Signer.py index 3986054b..66768583 100644 --- a/test/utils/Signer.py +++ b/test/utils/Signer.py @@ -1,7 +1,8 @@ from typing import Tuple from starkware.crypto.signature.signature import private_to_stark_key, sign -class Signer(): + +class Signer: def __init__(self, private_key: int): self.private_key = private_key self.public_key = private_to_stark_key(private_key) diff --git a/test/utils/TransactionSender.py b/test/utils/TransactionSender.py index 3d744452..611b7f47 100644 --- a/test/utils/TransactionSender.py +++ b/test/utils/TransactionSender.py @@ -1,14 +1,18 @@ from typing import Optional, List, Tuple from starkware.starknet.public.abi import get_selector_from_name from starkware.starknet.definitions.general_config import StarknetChainId -from starkware.starknet.testing.contract import StarknetContract +from starkware.starknet.testing.contract import StarknetContract, DeclaredClass from starkware.starknet.core.os.transaction_hash.transaction_hash import calculate_transaction_hash_common, TransactionHashPrefix -from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Declare +from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Declare, DeployAccount from starkware.starknet.business_logic.transaction.objects import InternalTransaction, TransactionExecutionInfo from starkware.starknet.services.api.contract_class import ContractClass from starkware.starknet.core.os.class_hash import compute_class_hash +from starkware.starknet.testing.objects import StarknetCallInfo +from starkware.starknet.testing.starknet import Starknet from utils.Signer import Signer +from utils.utilities import build_contract_with_proxy + TRANSACTION_VERSION = 1 @@ -20,6 +24,28 @@ class TransactionSender: def __init__(self, account: StarknetContract): self.account = account + @staticmethod + def get_signature( + message_hash: int, + signer_keys: Signer, + guardian_keys: Optional[Signer] = None, + backup_guardian_keys:Optional[Signer] = None, + ) -> List[int]: + signers = [signer_keys] + if guardian_keys is not None or backup_guardian_keys is not None: + signers.append(guardian_keys) + if backup_guardian_keys is not None: + signers.append(backup_guardian_keys) + + signatures = [] + for signer in signers: + if signer is None: + signatures += [0, 0] + else: + signatures += list(signer.sign(message_hash)) + + return signatures + async def send_transaction( self, calls: List[Call], @@ -100,6 +126,98 @@ async def declare_class( execution_info = await state.execute_tx(tx=tx) return execution_info + @staticmethod + async def deploy( + starknet: Starknet, + proxy_cls: ContractClass, + proxy_decl: DeclaredClass, + account_decl: DeclaredClass, + account_cls: ContractClass, + signer_keys: Signer, + guardian_keys: Optional[Signer] = None + ): + + unsigned_tx = await TransactionSender.get_unsigned_deploy_transaction( + proxy_decl=proxy_decl, + account_decl=account_decl, + signer_pub_key=signer_keys.public_key, + guardian_pub_key= None if guardian_keys is None else guardian_keys.public_key + ) + + signature = TransactionSender.get_signature( + message_hash=unsigned_tx.calculate_hash(starknet.state.general_config), + signer_keys=signer_keys, + guardian_keys=guardian_keys + ) + + + proxy = await TransactionSender.send_deploy_tx( + starknet=starknet, + unsigned_tx=unsigned_tx, + contract_cls=proxy_cls, + signature=signature + ) + + account = build_contract_with_proxy(proxy=proxy, implementation_abi=account_cls.abi), + + return TransactionSender(account) + + @staticmethod + async def get_unsigned_deploy_transaction( + proxy_decl: DeclaredClass, + account_decl: DeclaredClass, + signer_pub_key: int, + guardian_pub_key: Optional[int] = None, + salt: Optional[int] = None, + ) -> DeployAccount: + initialize_params = [signer_pub_key, 0 if guardian_pub_key is None else guardian_pub_key] + proxy_call_data = [ + account_decl.class_hash, # implementation, + get_selector_from_name('initialize'), # selector + len(initialize_params), # arguments to initialize method + *initialize_params + ] + nonce = 0 + max_fee = 0 + + external_tx = DeployAccount( + class_hash=proxy_decl.class_hash, + contract_address_salt=0 if salt is None else salt, + constructor_calldata=proxy_call_data, + version=TRANSACTION_VERSION, + nonce=nonce, + max_fee=max_fee, + signature=[] + ) + return external_tx + + @staticmethod + async def send_deploy_tx(starknet: Starknet, unsigned_tx: DeployAccount, contract_cls: ContractClass, signature: List[int]) -> StarknetContract: + external_tx = DeployAccount( + class_hash=unsigned_tx.class_hash, + contract_address_salt=unsigned_tx.contract_address_salt, + constructor_calldata=unsigned_tx.constructor_calldata, + version=unsigned_tx.version, + nonce=unsigned_tx.nonce, + max_fee=unsigned_tx.max_fee, + signature=signature + ) + tx_exec_info = await starknet.state.execute_tx(tx=InternalTransaction.from_external( + external_tx=external_tx, + general_config=starknet.state.general_config + )) + + return StarknetContract( + state=starknet.state, + abi=contract_cls.abi, + contract_address=tx_exec_info.call_info.contract_address, + deploy_call_info=StarknetCallInfo.from_internal( + call_info=tx_exec_info.call_info, + result=(), + main_call_events=tx_exec_info.call_info.events + ) + ) + def from_call_to_call_array(calls: List[Call]): call_array = [] diff --git a/test/utils/utilities.py b/test/utils/utilities.py index 364727d6..8a87ee66 100644 --- a/test/utils/utilities.py +++ b/test/utils/utilities.py @@ -1,5 +1,7 @@ import os from typing import Optional, List, Tuple + +from starkware.starknet.public.abi import AbiType from starkware.starknet.testing.contract import StarknetContract from starkware.starknet.testing.state import StarknetState from starkware.starknet.business_logic.state.state import BlockInfo @@ -9,9 +11,17 @@ from starkware.starknet.business_logic.execution.objects import Event, TransactionExecutionInfo from starkware.starknet.compiler.compile import get_selector_from_name from starkware.starknet.services.api.contract_class import ContractClass +from utils.Signer import Signer DEFAULT_TIMESTAMP = 1640991600 + +signer_key_1 = Signer(1) +signer_key_2 = Signer(2) +signer_key_3 = Signer(3) +signer_key_4 = Signer(4) + + def str_to_felt(text: str) -> int: b_text = bytes(text, 'UTF-8') return int.from_bytes(b_text, "big") @@ -68,14 +78,13 @@ def compile(path: str) -> ContractClass: contract_cls = compile_starknet_files([path], debug_info=True) return contract_cls + def cached_contract(state: StarknetState, _class: ContractClass, deployed: StarknetContract) -> StarknetContract: - contract = StarknetContract( + return build_contract( + contract=deployed, state=state, - abi=_class.abi, - contract_address=deployed.contract_address, - deploy_call_info=deployed.deploy_call_info + custom_abi=_class.abi ) - return contract def get_execute_data(tx_exec_info: TransactionExecutionInfo) -> List[int]: @@ -83,3 +92,24 @@ def get_execute_data(tx_exec_info: TransactionExecutionInfo) -> List[int]: ret_execute_size, *ret_execute = raw_data assert ret_execute_size == len(ret_execute), "Unexpected return size" return ret_execute + + +def copy_contract_state(contract: StarknetContract) -> StarknetContract: + return build_contract(contract=contract, state=contract.state.copy()) + + +def build_contract_with_proxy(proxy: StarknetContract, implementation_abi: AbiType) -> StarknetContract: + return build_contract(state=proxy.state, contract=proxy, custom_abi=implementation_abi) + + +def build_contract_with_state(contract: StarknetContract, state: StarknetState) -> StarknetContract: + return build_contract(state=state, contract=contract, custom_abi=contract.abi) + + +def build_contract(contract: StarknetContract, state: StarknetState = None, custom_abi: AbiType = None) -> StarknetContract: + return StarknetContract( + state=contract.state if state is None else state, + abi=contract.abi if custom_abi is None else custom_abi, + contract_address=contract.contract_address, + deploy_call_info=contract.deploy_call_info + )