Skip to content

Commit

Permalink
✨ Safe7579 now supports complex fallbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroknots committed Mar 14, 2024
1 parent f1ad273 commit 0774d2b
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 66 deletions.
26 changes: 26 additions & 0 deletions accounts/safe7579/src/core/ExecutionHelper.sol
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,30 @@ abstract contract ExecutionHelper {
_executeReturnData(safe, execution.target, execution.value, execution.callData);
}
}

/**
* Execute staticcall on Safe, get return value from call
* @dev This function will revert if the call fails
* @param safe address of the safe
* @param target address of the contract to call
* @param value value of the transaction
* @param callData data of the transaction
* @return returnData data returned from the call
*/
function _executeStaticReturnData(
address safe,
address target,
uint256 value,
bytes memory callData
)
internal
view
returns (bytes memory returnData)
{
bool success;
(success, returnData) = safe.staticcall(
abi.encodeCall(ISafe.execTransactionFromModuleReturnData, (target, value, callData, 0))
);
if (!success) revert ExecutionFailed();
}
}
71 changes: 12 additions & 59 deletions accounts/safe7579/src/core/ModuleManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ abstract contract ModuleManager is AccessControl, Receiver, ExecutionHelper {
error InitializerError();
error ValidatorStorageHelperError();
error NoFallbackHandler(bytes4 msgSig);
error FallbackInstalled(bytes4 msgSig);

mapping(address smartAccount => ModuleManagerStorage moduleManagerStorage) internal
$moduleManager;
Expand Down Expand Up @@ -176,16 +177,12 @@ abstract contract ModuleManager is AccessControl, Receiver, ExecutionHelper {
function _installFallbackHandler(address handler, bytes calldata params) internal virtual {
(bytes4 functionSig, CallType calltype, bytes memory initData) =
abi.decode(params, (bytes4, CallType, bytes));
if (_isFallbackHandlerInstalled(functionSig)) revert();
if (_isFallbackHandlerInstalled(functionSig)) revert FallbackInstalled(functionSig);

FallbackHandler storage $fallbacks = $moduleManager[msg.sender]._fallbacks[functionSig];
$fallbacks.calltype = calltype;
$fallbacks.handler = handler;

//
// ModuleManagerStorage storage $mms = $moduleManager[msg.sender];
// $mms.fallbackHandler = handler;
// // Initialize Fallback Module via Safe
_execute({
safe: msg.sender,
target: handler,
Expand Down Expand Up @@ -230,71 +227,27 @@ abstract contract ModuleManager is AccessControl, Receiver, ExecutionHelper {

// FALLBACK
// solhint-disable-next-line no-complex-fallback
fallback() external payable override(Receiver) receiverFallback {
fallback(bytes calldata callData)
external
payable
override(Receiver)
receiverFallback
returns (bytes memory fallbackRet)
{
FallbackHandler storage $fallbackHandler = $moduleManager[msg.sender]._fallbacks[msg.sig];
address handler = $fallbackHandler.handler;
CallType calltype = $fallbackHandler.calltype;
if (handler == address(0)) revert NoFallbackHandler(msg.sig);

if (calltype == CALLTYPE_STATIC) {
assembly {
function allocate(length) -> pos {
pos := mload(0x40)
mstore(0x40, add(pos, length))
}

let calldataPtr := allocate(calldatasize())
calldatacopy(calldataPtr, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
let senderPtr := allocate(20)
mstore(senderPtr, shl(96, caller()))

// Add 20 bytes for the address appended add the end
let success :=
staticcall(gas(), handler, calldataPtr, add(calldatasize(), 20), 0, 0)

let returnDataPtr := allocate(returndatasize())
returndatacopy(returnDataPtr, 0, returndatasize())
if iszero(success) { revert(returnDataPtr, returndatasize()) }
return(returnDataPtr, returndatasize())
}
return _executeStaticReturnData(msg.sender, handler, 0, callData);
}
if (calltype == CALLTYPE_SINGLE) {
assembly {
function allocate(length) -> pos {
pos := mload(0x40)
mstore(0x40, add(pos, length))
}

let calldataPtr := allocate(calldatasize())
calldatacopy(calldataPtr, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
let senderPtr := allocate(20)
mstore(senderPtr, shl(96, caller()))

// Add 20 bytes for the address appended add the end
let success := call(gas(), handler, 0, calldataPtr, add(calldatasize(), 20), 0, 0)

let returnDataPtr := allocate(returndatasize())
returndatacopy(returnDataPtr, 0, returndatasize())
if iszero(success) { revert(returnDataPtr, returndatasize()) }
return(returnDataPtr, returndatasize())
}
return _executeReturnData(msg.sender, handler, 0, callData);
}

if (calltype == CALLTYPE_DELEGATECALL) {
assembly {
calldatacopy(0, 0, calldatasize())
let result := delegatecall(gas(), handler, 0, calldatasize(), 0, 0)
returndatacopy(0, 0, returndatasize())
switch result
case 0 { revert(0, returndatasize()) }
default { return(0, returndatasize()) }
}
return _executeDelegateCallReturnData(msg.sender, handler, callData);
}
}
}
20 changes: 16 additions & 4 deletions accounts/safe7579/test/SafeERC7579.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,24 @@ contract Safe7579Test is TestBaseUtil {
test_initializeAccount();
MockFallback _fallback = new MockFallback();
vm.prank(address(safe));
IERC7579Account(address(safe)).installModule(3, address(_fallback), "");
(uint256 ret, address erc2771Sender, address msgSender) =
MockFallback(address(safe)).target(1337);
IERC7579Account(address(safe)).installModule(
3, address(_fallback), abi.encode(MockFallback.target.selector, CALLTYPE_SINGLE, "")
);
(uint256 ret, address msgSender) = MockFallback(address(safe)).target(1337);

assertEq(ret, 1337);
assertEq(erc2771Sender, address(this));
assertEq(msgSender, address(safe));

vm.prank(address(safe));
IERC7579Account(address(safe)).installModule(
3,
address(_fallback),
abi.encode(MockFallback.target2.selector, CALLTYPE_DELEGATECALL, "")
);
(uint256 _ret, address _this, address _msgSender) =
MockFallback(address(safe)).target2(1337);

assertEq(_ret, 1337);
assertEq(_this, address(safe));
}
}
11 changes: 8 additions & 3 deletions accounts/safe7579/test/mocks/MockFallback.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ import { HandlerContext } from "@safe-global/safe-contracts/contracts/handler/Ha
import { MockFallback as MockFallbackBase } from "@rhinestone/modulekit/src/mocks/MockFallback.sol";

contract MockFallback is MockFallbackBase, HandlerContext {
function target(uint256 value)
function target(uint256 value) external returns (uint256 _value, address msgSender) {
_value = value;
msgSender = msg.sender;
}

function target2(uint256 value)
external
returns (uint256 _value, address erc2771Sender, address msgSender)
returns (uint256 _value, address _this, address msgSender)
{
_value = value;
erc2771Sender = _msgSender();
msgSender = msg.sender;
_this = address(this);
}
}

0 comments on commit 0774d2b

Please sign in to comment.