From c0271d7f08a1c4a0cedfa638c370874f0dec1c85 Mon Sep 17 00:00:00 2001 From: zeroknots Date: Thu, 14 Mar 2024 13:19:11 +0700 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Installation=20and=20deinstallation?= =?UTF-8?q?=20now=20works=20with=20calldata?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- accounts/safe7579/src/SafeERC7579.sol | 56 ++++++++++++++----- accounts/safe7579/src/core/ModuleManager.sol | 8 +-- .../safe7579/src/interfaces/ISafe7579Init.sol | 20 +++++++ accounts/safe7579/src/utils/Launchpad.sol | 32 +++++++---- accounts/safe7579/test/Base.t.sol | 17 ++++-- accounts/safe7579/test/SafeERC7579.t.sol | 5 +- 6 files changed, 102 insertions(+), 36 deletions(-) create mode 100644 accounts/safe7579/src/interfaces/ISafe7579Init.sol diff --git a/accounts/safe7579/src/SafeERC7579.sol b/accounts/safe7579/src/SafeERC7579.sol index fc9742e0..2d9f1d71 100644 --- a/accounts/safe7579/src/SafeERC7579.sol +++ b/accounts/safe7579/src/SafeERC7579.sol @@ -29,6 +29,7 @@ import { } from "@ERC4337/account-abstraction/contracts/core/UserOperationLib.sol"; import { _packValidationData } from "@ERC4337/account-abstraction/contracts/core/Helpers.sol"; import { IEntryPoint } from "@ERC4337/account-abstraction/contracts/interfaces/IEntryPoint.sol"; +import { ISafe7579Init } from "./interfaces/ISafe7579Init.sol"; /** * @title ERC7579 Adapter for Safe accounts. @@ -36,7 +37,14 @@ import { IEntryPoint } from "@ERC4337/account-abstraction/contracts/interfaces/I * this contract creates full ERC7579 compliance to Safe accounts * @author zeroknots.eth | rhinestone.wtf */ -contract SafeERC7579 is ISafeOp, IERC7579Account, AccessControl, IMSA, HookManager { +contract SafeERC7579 is + ISafeOp, + IERC7579Account, + ISafe7579Init, + AccessControl, + IMSA, + HookManager +{ using UserOperationLib for PackedUserOperation; using ModeLib for ModeCode; using ExecutionLib for bytes; @@ -390,26 +398,46 @@ contract SafeERC7579 is ISafeOp, IERC7579Account, AccessControl, IMSA, HookManag return keccak256(abi.encode(DOMAIN_SEPARATOR_TYPEHASH, block.chainid, this)); } - function initializeAccount(bytes calldata initCode) external payable { + function initializeAccount(bytes calldata callData) external payable override { + // TODO: destructuring callData + } + + function initializeAccount( + ModuleInit[] calldata validators, + ModuleInit[] calldata executors, + ModuleInit[] calldata fallbacks, + ModuleInit[] calldata hooks + ) + public + payable + override + { _initModuleManager(); - ( - address[] memory validator, - bytes[] memory validatorInitcode, - address[] memory executors, - bytes[] memory executorsInitcode - ) = abi.decode(initCode, (address[], bytes[], address[], bytes[])); - - uint256 length = validator.length; - if (length != validatorInitcode.length) revert("Invalid input"); + // InitData memory initDatas = abi.decode(initCode, (InitData)); + + uint256 length = validators.length; for (uint256 i; i < length; i++) { - _installValidator(validator[i], validatorInitcode[i]); + ModuleInit calldata validator = validators[i]; + _installValidator(validator.module, validator.initData); } length = executors.length; - if (length != executorsInitcode.length) revert("Invalid input"); for (uint256 i; i < length; i++) { - _installExecutor(executors[i], executorsInitcode[i]); + ModuleInit calldata executor = executors[i]; + _installExecutor(executor.module, executor.initData); + } + + length = fallbacks.length; + for (uint256 i; i < length; i++) { + ModuleInit calldata fallBack = fallbacks[i]; + _installFallbackHandler(fallBack.module, fallBack.initData); + } + + length = hooks.length; + for (uint256 i; i < length; i++) { + ModuleInit calldata hook = hooks[i]; + _installFallbackHandler(hook.module, hook.initData); } emit Safe7579Initialized(msg.sender); diff --git a/accounts/safe7579/src/core/ModuleManager.sol b/accounts/safe7579/src/core/ModuleManager.sol index de6a1b73..6fa710cd 100644 --- a/accounts/safe7579/src/core/ModuleManager.sol +++ b/accounts/safe7579/src/core/ModuleManager.sol @@ -65,7 +65,7 @@ abstract contract ModuleManager is AccessControl, Receiver, ExecutionHelper { /** * install and initialize validator module */ - function _installValidator(address validator, bytes memory data) internal virtual { + function _installValidator(address validator, bytes calldata data) internal virtual { $validators.push({ account: msg.sender, newEntry: validator }); // Initialize Validator Module via Safe @@ -80,7 +80,7 @@ abstract contract ModuleManager is AccessControl, Receiver, ExecutionHelper { /** * Uninstall and de-initialize validator module */ - function _uninstallValidator(address validator, bytes memory data) internal { + function _uninstallValidator(address validator, bytes calldata data) internal { (address prev, bytes memory disableModuleData) = abi.decode(data, (address, bytes)); $validators.pop({ account: msg.sender, prevEntry: prev, popEntry: validator }); @@ -130,7 +130,7 @@ abstract contract ModuleManager is AccessControl, Receiver, ExecutionHelper { // Manage Executors //////////////////////////////////////////////////// - function _installExecutor(address executor, bytes memory data) internal { + function _installExecutor(address executor, bytes calldata data) internal { SentinelListLib.SentinelList storage $executors = $moduleManager[msg.sender]._executors; $executors.push(executor); // Initialize Executor Module via Safe @@ -268,7 +268,7 @@ abstract contract ModuleManager is AccessControl, Receiver, ExecutionHelper { // Add 20 bytes for the address appended add the end let success := - staticcall(gas(), handler, calldataPtr, add(calldatasize(), 20), 0, 0) + staticcall(gas(), handler, calldataPtr, add(calldatasize(), 20), 0, 0) let returnDataPtr := allocate(returndatasize()) returndatacopy(returnDataPtr, 0, returndatasize()) diff --git a/accounts/safe7579/src/interfaces/ISafe7579Init.sol b/accounts/safe7579/src/interfaces/ISafe7579Init.sol new file mode 100644 index 00000000..eb86459a --- /dev/null +++ b/accounts/safe7579/src/interfaces/ISafe7579Init.sol @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import { CallType } from "erc7579/lib/ModeLib.sol"; + +interface ISafe7579Init { + struct ModuleInit { + address module; + bytes initData; + } + + function initializeAccount( + ModuleInit[] calldata validators, + ModuleInit[] calldata executors, + ModuleInit[] calldata fallbacks, + ModuleInit[] calldata hooks + ) + external + payable; +} diff --git a/accounts/safe7579/src/utils/Launchpad.sol b/accounts/safe7579/src/utils/Launchpad.sol index 1abc5030..b53725e0 100644 --- a/accounts/safe7579/src/utils/Launchpad.sol +++ b/accounts/safe7579/src/utils/Launchpad.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; import { ISafe, SafeERC7579 } from "../SafeERC7579.sol"; +import { ISafe7579Init } from "../interfaces/ISafe7579Init.sol"; /** * Helper contract that gets delegatecalled byt SafeProxy.setup() to setup safe7579 as a module @@ -15,9 +16,22 @@ contract Safe7579Launchpad { SAFE7579Singleton = _safe7579Singleton; } - function initSafe7579(address safe7579, bytes calldata safe7579InitCode) public { + // function initSafe7579(address safe7579, bytes calldata safe7579InitCode) public { + // ISafe(address(this)).enableModule(safe7579); + // SafeERC7579(payable(safe7579)).initializeAccount(safe7579InitCode); + // } + + function initSafe7579( + address safe7579, + ISafe7579Init.ModuleInit[] calldata validators, + ISafe7579Init.ModuleInit[] calldata executors, + ISafe7579Init.ModuleInit[] calldata fallbacks, + ISafe7579Init.ModuleInit[] calldata hooks + ) + public + { ISafe(address(this)).enableModule(safe7579); - SafeERC7579(payable(safe7579)).initializeAccount(safe7579InitCode); + SafeERC7579(payable(safe7579)).initializeAccount(validators, executors, fallbacks, hooks); } function predictSafeAddress( @@ -54,21 +68,17 @@ contract Safe7579Launchpad { function getInitCode( address[] memory signers, uint256 threshold, - address[] calldata validators, - bytes[] calldata validatorsInitCode, - address[] calldata executors, - bytes[] calldata executorsInitCode + ISafe7579Init.ModuleInit[] calldata validators, + ISafe7579Init.ModuleInit[] calldata executors, + ISafe7579Init.ModuleInit[] calldata fallbacks, + ISafe7579Init.ModuleInit[] calldata hooks ) external view returns (bytes memory initCode) { bytes memory safeLaunchPadSetup = abi.encodeCall( - this.initSafe7579, - ( - address(SAFE7579Singleton), - abi.encode(validators, validatorsInitCode, executors, executorsInitCode) - ) + this.initSafe7579, (address(SAFE7579Singleton), validators, executors, fallbacks, hooks) ); // SETUP SAFE initCode = abi.encodeCall( diff --git a/accounts/safe7579/test/Base.t.sol b/accounts/safe7579/test/Base.t.sol index 7e886b48..d1f18061 100644 --- a/accounts/safe7579/test/Base.t.sol +++ b/accounts/safe7579/test/Base.t.sol @@ -48,13 +48,22 @@ contract TestBaseUtil is Test { bytes32 salt; + ISafe7579Init.ModuleInit[] memory validators = new ISafe7579Init.ModuleInit[](1); + validators[0] = + ISafe7579Init.ModuleInit({ module: address(defaultValidator), initData: bytes("") }); + ISafe7579Init.ModuleInit[] memory executors = new ISafe7579Init.ModuleInit[](1); + executors[0] = + ISafe7579Init.ModuleInit({ module: address(defaultExecutor), initData: bytes("") }); + ISafe7579Init.ModuleInit[] memory fallbacks = new ISafe7579Init.ModuleInit[](0); + ISafe7579Init.ModuleInit[] memory hooks = new ISafe7579Init.ModuleInit[](0); + bytes memory initializer = launchpad.getInitCode({ signers: Solarray.addresses(signer1.addr, signer2.addr), threshold: 2, - validators: Solarray.addresses(address(defaultValidator)), - validatorsInitCode: Solarray.bytess(""), - executors: Solarray.addresses(address(defaultExecutor)), - executorsInitCode: Solarray.bytess("") + validators: validators, + executors: executors, + fallbacks: fallbacks, + hooks: hooks }); // computer counterfactual address for SafeProxy safe = Safe( diff --git a/accounts/safe7579/test/SafeERC7579.t.sol b/accounts/safe7579/test/SafeERC7579.t.sol index e1304392..95bdedc2 100644 --- a/accounts/safe7579/test/SafeERC7579.t.sol +++ b/accounts/safe7579/test/SafeERC7579.t.sol @@ -7,6 +7,7 @@ import "erc7579/lib/ExecutionLib.sol"; import { TestBaseUtil, MockTarget, MockFallback } from "./Base.t.sol"; import "forge-std/console2.sol"; + CallType constant CALLTYPE_STATIC = CallType.wrap(0xFE); contract Safe7579Test is TestBaseUtil { @@ -169,13 +170,11 @@ contract Safe7579Test is TestBaseUtil { IERC7579Account(address(safe)).installModule( 3, address(_fallback), abi.encode(MockFallback.target.selector, CALLTYPE_STATIC, "") ); - ( ret, msgSender, context) = MockFallback(address(safe)).target(1337); + (ret, msgSender, context) = MockFallback(address(safe)).target(1337); assertEq(ret, 1337); assertEq(msgSender, address(safe7579)); assertEq(context, address(safe)); - - vm.prank(address(safe)); IERC7579Account(address(safe)).installModule( 3,