diff --git a/package.json b/package.json index 310a8abb..9a424d89 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@rhinestone/modulekit", - "version": "0.4.10", + "version": "0.4.11", "description": "A development kit for building and testing smart account modules.", "license": "GPL-3.0", "author": { diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index 886dd200..af188a29 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -170,7 +170,6 @@ library ModuleKitHelpers { address module ) internal - view returns (bool) { return HelperBase(instance.accountHelper).isModuleInstalled(instance, moduleTypeId, module); @@ -183,7 +182,6 @@ library ModuleKitHelpers { bytes memory data ) internal - view returns (bool) { return HelperBase(instance.accountHelper).isModuleInstalled( @@ -312,7 +310,15 @@ library ModuleKitHelpers { //////////////////////////////////////////////////////////////////////////*/ function expect4337Revert(AccountInstance memory) internal { - writeExpectRevert(1); + writeExpectRevert(""); + } + + function expect4337Revert(AccountInstance memory, bytes4 selector) internal { + writeExpectRevert(abi.encodePacked(selector)); + } + + function expect4337Revert(AccountInstance memory, bytes memory message) internal { + writeExpectRevert(message); } /** diff --git a/src/test/helpers/HelperBase.sol b/src/test/helpers/HelperBase.sol index 07990391..282c4f0b 100644 --- a/src/test/helpers/HelperBase.sol +++ b/src/test/helpers/HelperBase.sol @@ -270,8 +270,8 @@ abstract contract HelperBase { address module ) public - view virtual + deployAccountForAction(instance) returns (bool) { return isModuleInstalled(instance, moduleTypeId, module, ""); @@ -284,8 +284,8 @@ abstract contract HelperBase { bytes memory additionalContext ) public - view virtual + deployAccountForAction(instance) returns (bool) { return IERC7579Account(instance.account).isModuleInstalled( diff --git a/src/test/helpers/KernelHelpers.sol b/src/test/helpers/KernelHelpers.sol index 84c3cdb0..42531e2b 100644 --- a/src/test/helpers/KernelHelpers.sol +++ b/src/test/helpers/KernelHelpers.sol @@ -394,9 +394,9 @@ contract KernelHelpers is HelperBase { bytes memory data ) public - view virtual override + deployAccountForAction(instance) returns (bool) { if (moduleTypeId == MODULE_TYPE_HOOK) { diff --git a/src/test/helpers/SafeHelpers.sol b/src/test/helpers/SafeHelpers.sol index 6ece9e0c..8e485f0e 100644 --- a/src/test/helpers/SafeHelpers.sol +++ b/src/test/helpers/SafeHelpers.sol @@ -245,9 +245,9 @@ contract SafeHelpers is HelperBase { bytes memory data ) public - view virtual override + deployAccountForAction(instance) returns (bool) { if (moduleTypeId == MODULE_TYPE_HOOK) { diff --git a/src/test/utils/ERC4337Helpers.sol b/src/test/utils/ERC4337Helpers.sol index 140680ff..2dcd0746 100644 --- a/src/test/utils/ERC4337Helpers.sol +++ b/src/test/utils/ERC4337Helpers.sol @@ -16,7 +16,8 @@ import { GasParser } from "./gas/GasParser.sol"; import { getSimulateUserOp, getExpectRevert, - writeExpectRevert, + getExpectRevertMessage, + clearExpectRevert, getGasIdentifier, writeGasIdentifier, writeInstalledModule, @@ -31,6 +32,8 @@ library ERC4337Helpers { error UserOperationReverted( bytes32 userOpHash, address sender, uint256 nonce, bytes revertReason ); + error InvalidRevertMessage(bytes4 expected, bytes4 reason); + error InvalidRevertMessageBytes(bytes expected, bytes reason); function exec4337(PackedUserOperation[] memory userOps, IEntryPoint onEntryPoint) internal { uint256 isExpectRevert = getExpectRevert(); @@ -49,10 +52,12 @@ library ERC4337Helpers { // Execute userOps address payable beneficiary = payable(address(0x69)); bytes memory userOpCalldata = abi.encodeCall(IEntryPoint.handleOps, (userOps, beneficiary)); - (bool success,) = address(onEntryPoint).call(userOpCalldata); + (bool success, bytes memory returnData) = address(onEntryPoint).call(userOpCalldata); if (isExpectRevert == 0) { require(success, "UserOperation execution failed"); + } else if (isExpectRevert == 2 && !success) { + checkRevertMessage(returnData); } // Parse logs and determine if a revert happened @@ -68,14 +73,17 @@ library ERC4337Helpers { abi.decode(logs[i].data, (uint256, bool, uint256, uint256)); totalUserOpGas = actualGasUsed; if (!userOpSuccess) { + bytes32 userOpHash = logs[i].topics[1]; if (isExpectRevert == 0) { - bytes32 userOpHash = logs[i].topics[1]; bytes memory revertReason = getUserOpRevertReason(logs, userOpHash); revert UserOperationReverted( userOpHash, address(bytes20(logs[i].topics[2])), nonce, revertReason ); } else { - writeExpectRevert(0); + if (isExpectRevert == 2) { + checkRevertMessage(getUserOpRevertReason(logs, userOpHash)); + } + clearExpectRevert(); } } } @@ -115,7 +123,7 @@ library ERC4337Helpers { require(!success, "UserOperation execution did not fail as expected"); } } - writeExpectRevert(0); + clearExpectRevert(); // Calculate gas for userOp string memory gasIdentifier = getGasIdentifier(); @@ -139,7 +147,7 @@ library ERC4337Helpers { function getUserOpRevertReason( VmSafe.Log[] memory logs, - bytes32 /* userOpHash */ + bytes32 userOpHash ) internal pure @@ -150,12 +158,29 @@ library ERC4337Helpers { if ( logs[i].topics[0] == 0x1c4fada7374c0a9ee8841fc38afe82932dc0f8e69012e927f061a8bae611a201 + && logs[i].topics[1] == userOpHash ) { (, revertReason) = abi.decode(logs[i].data, (uint256, bytes)); } } } + function checkRevertMessage(bytes memory actualReason) internal view { + bytes memory revertMessage = getExpectRevertMessage(); + + if (revertMessage.length == 4) { + bytes4 expected = bytes4(revertMessage); + bytes4 actual = bytes4(actualReason); + if (expected != actual) { + revert InvalidRevertMessage(expected, actual); + } + } else { + if (revertMessage.length != actualReason.length) { + revert InvalidRevertMessageBytes(revertMessage, actualReason); + } + } + } + function calculateGas( PackedUserOperation[] memory userOps, IEntryPoint onEntryPoint, diff --git a/src/test/utils/Storage.sol b/src/test/utils/Storage.sol index 474677d8..3fd6cf02 100644 --- a/src/test/utils/Storage.sol +++ b/src/test/utils/Storage.sol @@ -5,8 +5,18 @@ pragma solidity ^0.8.23; EXPECT REVERT //////////////////////////////////////////////////////////////*/ -function writeExpectRevert(uint256 value) { - bytes32 slot = keccak256("ModuleKit.ExpectSlot"); +function writeExpectRevert(bytes memory message) { + uint256 value = 1; + bytes32 slot = keccak256("ModuleKit.ExpectMessageSlot"); + + if (message.length > 0) { + value = 2; + assembly { + sstore(slot, message) + } + } + + slot = keccak256("ModuleKit.ExpectSlot"); assembly { sstore(slot, value) } @@ -19,6 +29,25 @@ function getExpectRevert() view returns (uint256 value) { } } +function getExpectRevertMessage() view returns (bytes memory data) { + bytes32 slot = keccak256("ModuleKit.ExpectMessageSlot"); + assembly { + data := sload(slot) + } +} + +function clearExpectRevert() { + bytes32 slot = keccak256("ModuleKit.ExpectSlot"); + assembly { + sstore(slot, 0) + } + + slot = keccak256("ModuleKit.ExpectMessageSlot"); + assembly { + sstore(slot, 0) + } +} + /*////////////////////////////////////////////////////////////// GAS IDENTIFIER //////////////////////////////////////////////////////////////*/ diff --git a/test/Diff.t.sol b/test/Diff.t.sol index 9dff6393..3047a41b 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -13,12 +13,14 @@ import { } from "src/external/ERC7579.sol"; import { getAccountType, InstalledModule } from "src/test/utils/Storage.sol"; import { toString } from "src/test/utils/Vm.sol"; +import { MockValidatorFalse } from "test/mocks/MockValidatorFalse.sol"; contract ERC7579DifferentialModuleKitLibTest is BaseTest { using ModuleKitHelpers for *; using ModuleKitUserOp for *; MockValidator internal validator; + MockValidatorFalse internal validatorFalse; MockExecutor internal executor; MockFallback internal fallbackHandler; MockHook internal hook; @@ -33,6 +35,7 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { // Setup modules validator = new MockValidator(); + validatorFalse = new MockValidatorFalse(); hook = new MockHook(); executor = new MockExecutor(); fallbackHandler = new MockFallback(); @@ -103,57 +106,72 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { } function testexec__RevertWhen__ValidationFails() public { - // Create userOperation fields - address receiver = makeAddr("receiver"); - uint256 value = 10 gwei; - bytes memory callData = ""; + // No revert reason + _revertWhen__ValidationFails(""); - // Create userOperation - instance.expect4337Revert(); - // Create userOperation - instance.getExecOps({ - target: receiver, - value: value, - callData: callData, - txValidator: makeAddr("invalidValidator") - }).execUserOps(); + // Revert selector + _revertWhen__ValidationFails(abi.encodePacked(bytes4(0x220266b6))); + + // Revert message + _revertWhen__ValidationFails( + abi.encodeWithSignature("FailedOp(uint256,string)", 0, "AA24 signature error") + ); } function testexec__RevertWhen__ValidationReverts() public { - address revertingValidator = makeAddr("revertingValidator"); - vm.etch(revertingValidator, address(validator).code); + // No revert reason + _revertWhen__ValidationReverts(""); - instance.installModule({ - moduleTypeId: MODULE_TYPE_VALIDATOR, - module: revertingValidator, - data: "" - }); + // Revert selector + _revertWhen__ValidationReverts(abi.encodePacked(bytes4(0x65c8fd4d))); - vm.etch(revertingValidator, hex"fd"); + // Revert message + bytes memory revertMessage; - // Create userOperation fields - address receiver = makeAddr("receiver"); - uint256 value = 10 gwei; - bytes memory callData = ""; + AccountType env = ModuleKitHelpers.getAccountType(); + if (env == AccountType.SAFE) { + revertMessage = abi.encodeWithSignature( + "FailedOpWithRevert(uint256,string,bytes)", + 0, + "AA23 reverted", + abi.encode(bytes4(0xacfdb444)) + ); + } else { + revertMessage = abi.encodeWithSignature( + "FailedOpWithRevert(uint256,string,bytes)", 0, "AA23 reverted", "" + ); + } - // Create userOperation - instance.expect4337Revert(); - // Create userOperation - instance.getExecOps({ - target: receiver, - value: value, - callData: callData, - txValidator: revertingValidator - }).execUserOps(); + _revertWhen__ValidationReverts(revertMessage); } function testexec__RevertWhen__UserOperationFails() public { - // Create userOperation fields - bytes memory callData = abi.encodeWithSelector(MockTarget.setAccessControl.selector, 2); + // Deploy the account first + testexec__Given__TwoInputs(); - // Create userOperation - instance.expect4337Revert(); - instance.exec({ target: address(mockTarget), callData: callData, value: 0 }); + // No revert reason + _revertWhen__UserOperationFails(""); + + bytes memory revertSelector; + bytes memory revertMessage; + + AccountType env = ModuleKitHelpers.getAccountType(); + if (env == AccountType.SAFE) { + revertSelector = abi.encodePacked(bytes4(0xacfdb444)); + revertMessage = abi.encodePacked(bytes4(0xacfdb444)); + } else if (env == AccountType.KERNEL) { + revertSelector = abi.encodePacked(bytes4(0xf21e646b)); + revertMessage = abi.encodePacked(bytes4(0xf21e646b)); + } else { + revertSelector = abi.encodePacked(bytes4(0x08c379a0)); + revertMessage = abi.encodeWithSignature("Error(string)", "MockTarget: not authorized"); + } + + // Revert selector + _revertWhen__UserOperationFails(revertSelector); + + // Revert message + _revertWhen__UserOperationFails(revertMessage); } /*////////////////////////////////////////////////////////////////////////// @@ -608,6 +626,97 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { } } + /*////////////////////////////////////////////////////////////// + EXPECT REVERT + //////////////////////////////////////////////////////////////*/ + + function _revertWhen__ValidationFails(bytes memory revertReason) public { + // Create userOperation fields + address receiver = makeAddr("receiver"); + uint256 value = 10 gwei; + bytes memory callData = ""; + + if (!instance.isModuleInstalled(MODULE_TYPE_VALIDATOR, address(validatorFalse))) { + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: address(validatorFalse), + data: "" + }); + } + + // Expect the revert + if (revertReason.length == 0) { + instance.expect4337Revert(); + } else if (revertReason.length == 4) { + instance.expect4337Revert(bytes4(revertReason)); + } else { + instance.expect4337Revert(revertReason); + } + + // Create userOperation + instance.getExecOps({ + target: receiver, + value: value, + callData: callData, + txValidator: address(validatorFalse) + }).execUserOps(); + } + + function _revertWhen__ValidationReverts(bytes memory revertReason) public { + address revertingValidator = makeAddr("revertingValidator"); + + if (!instance.isModuleInstalled(MODULE_TYPE_VALIDATOR, revertingValidator)) { + vm.etch(revertingValidator, address(validator).code); + + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: revertingValidator, + data: "" + }); + + vm.etch(revertingValidator, hex"fd"); + } + + // Create userOperation fields + address receiver = makeAddr("receiver"); + uint256 value = 10 gwei; + bytes memory callData = ""; + + // Expect the revert + if (revertReason.length == 0) { + instance.expect4337Revert(); + } else if (revertReason.length == 4) { + instance.expect4337Revert(bytes4(revertReason)); + } else { + instance.expect4337Revert(revertReason); + } + + // Create userOperation + instance.getExecOps({ + target: receiver, + value: value, + callData: callData, + txValidator: revertingValidator + }).execUserOps(); + } + + function _revertWhen__UserOperationFails(bytes memory revertReason) public { + // Create userOperation fields + bytes memory callData = abi.encodeWithSelector(MockTarget.setAccessControl.selector, 2); + + // Expect the revert + if (revertReason.length == 0) { + instance.expect4337Revert(); + } else if (revertReason.length == 4) { + instance.expect4337Revert(bytes4(revertReason)); + } else { + instance.expect4337Revert(revertReason); + } + + // Create userOperation + instance.exec({ target: address(mockTarget), callData: callData, value: 0 }); + } + /*////////////////////////////////////////////////////////////// MODIFIERS //////////////////////////////////////////////////////////////*/ diff --git a/test/mocks/MockValidatorFalse.sol b/test/mocks/MockValidatorFalse.sol new file mode 100644 index 00000000..0b032c64 --- /dev/null +++ b/test/mocks/MockValidatorFalse.sol @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +/* solhint-disable no-unused-vars */ +import { ERC7579ValidatorBase } from "src/Modules.sol"; +import { PackedUserOperation } from "src/external/ERC4337.sol"; + +contract MockValidatorFalse is ERC7579ValidatorBase { + function onInstall(bytes calldata data) external virtual override { } + + function onUninstall(bytes calldata data) external virtual override { } + + function validateUserOp( + PackedUserOperation calldata userOp, + bytes32 userOpHash + ) + external + virtual + override + returns (ValidationData) + { + return _packValidationData({ sigFailed: true, validUntil: type(uint48).max, validAfter: 0 }); + } + + function isValidSignatureWithSender( + address sender, + bytes32 hash, + bytes calldata data + ) + external + view + virtual + override + returns (bytes4) + { + return EIP1271_FAILED; + } + + function isModuleType(uint256 typeID) external pure override returns (bool) { + return typeID == TYPE_VALIDATOR; + } + + function isInitialized(address smartAccount) external pure returns (bool) { + return false; + } +}