diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index 07a6083e..3d940cbe 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -15,13 +15,23 @@ import { PackedUserOperation } from "../external/ERC4337.sol"; import { ERC4337Helpers } from "./utils/ERC4337Helpers.sol"; import { HelperBase } from "./helpers/HelperBase.sol"; import { Execution, MODULE_TYPE_HOOK } from "../external/ERC7579.sol"; -import { prank } from "src/test/utils/Vm.sol"; +import { + prank, + VmSafe, + startStateDiffRecording as vmStartStateDiffRecording, + stopAndReturnStateDiff as vmStopAndReturnStateDiff, + getMappingKeyAndParentOf, + envOr +} from "src/test/utils/Vm.sol"; import { getAccountType as getAccountTypeFromStorage, writeAccountType, writeExpectRevert, writeGasIdentifier, writeSimulateUserOp, + writeStorageCompliance, + getStorageCompliance, + getSimulateUserOp, writeAccountEnv, getFactory, getHelper as getHelperFromStorage, @@ -140,6 +150,27 @@ library ModuleKitHelpers { return exec(instance, target, 0, callData); } + /*////////////////////////////////////////////////////////////// + HOOKS + //////////////////////////////////////////////////////////////*/ + + function preEnvHook() internal { + if (envOr("COMPLIANCE", false) || getStorageCompliance()) { + // Start state diff recording + vmStartStateDiffRecording(); + } + } + + function postEnvHook(AccountInstance memory instance, bytes memory data) internal { + if (envOr("COMPLIANCE", false) || getStorageCompliance()) { + address module = abi.decode(data, (address)); + // Stop state diff recording and return account accesses + VmSafe.AccountAccess[] memory accountAccesses = vmStopAndReturnStateDiff(); + // Check if storage was cleared + verifyModuleStorageWasCleared(instance, accountAccesses, module); + } + } + /*////////////////////////////////////////////////////////////////////////// MODULE CONFIG //////////////////////////////////////////////////////////////////////////*/ @@ -153,6 +184,8 @@ library ModuleKitHelpers { internal returns (UserOpData memory userOpData) { + // Run preEnvHook + preEnvHook(); userOpData = instance.getInstallModuleOps( moduleTypeId, module, data, address(instance.defaultValidator) ); @@ -181,6 +214,8 @@ library ModuleKitHelpers { // send userOp to entrypoint userOpData.execUserOps(); + // Run postEnvHook + postEnvHook(instance, abi.encode(module)); } function isModuleInstalled( @@ -323,6 +358,80 @@ library ModuleKitHelpers { } } } + + /// Start recording the state diff + function startStateDiffRecording(AccountInstance memory) internal { + vmStartStateDiffRecording(); + } + + /// Stop recording the state diff and return the account accesses + function stopAndReturnStateDiff(AccountInstance memory) + internal + returns (VmSafe.AccountAccess[] memory) + { + return vmStopAndReturnStateDiff(); + } + + /// Verifies from an accountAccesses array that storage was correctly cleared after uninstalling + /// a module + function verifyModuleStorageWasCleared( + AccountInstance memory, + VmSafe.AccountAccess[] memory accountAccesses, + address module + ) + internal + view + { + bytes32[] memory seenSlots = new bytes32[](1000); + bytes32[] memory finalValues = new bytes32[](1000); + uint256 numSlots; + + // Loop through account accesses + for (uint256 i; i < accountAccesses.length; i++) { + // Skip tests + if (accountAccesses[i].accessor == address(this)) { + continue; + } + + // If we are accessing the storage of the module check writes and clears + if (accountAccesses[i].account == module) { + // Process all storage accesses for this module + for (uint256 j; j < accountAccesses[i].storageAccesses.length; j++) { + VmSafe.StorageAccess memory access = accountAccesses[i].storageAccesses[j]; + + // Skip reads + if (!access.isWrite) { + continue; + } + + // Find if we've seen this slot + bool found; + for (uint256 k; k < numSlots; k++) { + if (seenSlots[k] == access.slot) { + finalValues[k] = access.newValue; + found = true; + break; + } + } + + // If not seen, add it + if (!found) { + seenSlots[numSlots] = access.slot; + finalValues[numSlots] = access.newValue; + numSlots++; + } + } + } + } + + // Check if any slot's final value is non-zero + for (uint256 i; i < numSlots; i++) { + if (finalValues[i] != bytes32(0)) { + revert("Storage not cleared after uninstalling module"); + } + } + } + /*////////////////////////////////////////////////////////////////////////// CONTROL FLOW //////////////////////////////////////////////////////////////////////////*/ @@ -355,6 +464,10 @@ library ModuleKitHelpers { writeSimulateUserOp(value); } + function storageCompliance(AccountInstance memory, bool value) internal { + writeStorageCompliance(value); + } + /*////////////////////////////////////////////////////////////////////////// ACCOUNT UTILS //////////////////////////////////////////////////////////////////////////*/ diff --git a/src/test/RhinestoneModuleKit.sol b/src/test/RhinestoneModuleKit.sol index 25eb4857..177b73b5 100644 --- a/src/test/RhinestoneModuleKit.sol +++ b/src/test/RhinestoneModuleKit.sol @@ -29,6 +29,7 @@ import { writeHelper } from "./utils/Storage.sol"; import { ModuleKitHelpers } from "./ModuleKitHelpers.sol"; +import { VmSafe } from "./utils/Vm.sol"; enum AccountType { DEFAULT, @@ -327,4 +328,15 @@ contract RhinestoneModuleKit is AuxiliaryFactory { defaultSessionValidator: ISessionValidator(sessionValidator) }); } + + /*////////////////////////////////////////////////////////////// + STORAGE CLEARING + //////////////////////////////////////////////////////////////*/ + + modifier withModuleStorageClearValidation(AccountInstance memory instance, address module) { + instance.startStateDiffRecording(); + _; + VmSafe.AccountAccess[] memory accountAccess = instance.stopAndReturnStateDiff(); + instance.verifyModuleStorageWasCleared(accountAccess, module); + } } diff --git a/src/test/utils/Storage.sol b/src/test/utils/Storage.sol index 3fd6cf02..f46742a8 100644 --- a/src/test/utils/Storage.sol +++ b/src/test/utils/Storage.sol @@ -80,6 +80,24 @@ function getSimulateUserOp() view returns (bool value) { } } +/*////////////////////////////////////////////////////////////// + STORAGE COMPLIANCE +//////////////////////////////////////////////////////////////*/ + +function writeStorageCompliance(bool value) { + bytes32 slot = keccak256("ModuleKit.StorageCompliance"); + assembly { + sstore(slot, value) + } +} + +function getStorageCompliance() view returns (bool value) { + bytes32 slot = keccak256("ModuleKit.StorageCompliance"); + assembly { + value := sload(slot) + } +} + /*////////////////////////////////////////////////////////////// ACCOUNT ENV //////////////////////////////////////////////////////////////*/ diff --git a/test/Diff.t.sol b/test/Diff.t.sol index 3047a41b..44155c7b 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -14,6 +14,11 @@ import { import { getAccountType, InstalledModule } from "src/test/utils/Storage.sol"; import { toString } from "src/test/utils/Vm.sol"; import { MockValidatorFalse } from "test/mocks/MockValidatorFalse.sol"; +import { MockK1Validator, VALIDATION_SUCCESS } from "test/mocks/MockK1Validator.sol"; +import { MockK1ValidatorUncompliantUninstall } from + "test/mocks/MockK1ValidatorUncompliantUninstall.sol"; +import { VALIDATION_SUCCESS, VALIDATION_FAILED } from "erc7579/interfaces/IERC7579Module.sol"; +import { VmSafe } from "src/test/utils/Vm.sol"; contract ERC7579DifferentialModuleKitLibTest is BaseTest { using ModuleKitHelpers for *; @@ -27,6 +32,7 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { MockTarget internal mockTarget; MockERC20 internal token; + address module; function setUp() public override { super.setUp(); @@ -626,6 +632,101 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { } } + function test_verifyModuleStorageWasCleared() public { + // Set simulate mode to false + instance.simulateUserOp(false); + // Install a module + address module = address(new MockK1Validator()); + // Start state diff recording + instance.startStateDiffRecording(); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(instance.account) + }); + // Uninstall the module + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); + // Stop state diff recording + VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff(); + // Assert that the module storage was cleared + instance.verifyModuleStorageWasCleared(accountAccesses, module); + } + + function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared_UsingComplianceFlag() + public + { + // Set simulate mode to false + instance.simulateUserOp(false); + // Set compliance flag + instance.storageCompliance(true); + + // Install a module + module = address(new MockK1ValidatorUncompliantUninstall()); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(0xffffffffffffffffffff) + }); + // Assert module storage + assertEq( + address(0xffffffffffffffffffff), + MockK1Validator(module).smartAccountOwners(address(instance.account)) + ); + // Expect revert + vm.expectRevert(); + this.__revertWhen_verifyModuleStorageWasCleared_NotCleared(); + } + + function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared() public { + // Set simulate mode to false + instance.simulateUserOp(false); + // Install a module + module = address(new MockK1ValidatorUncompliantUninstall()); + // Start state diff recording + instance.startStateDiffRecording(); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(0xffffffffffffffffffff) + }); + // Assert module storage + assertEq( + address(0xffffffffffffffffffff), + MockK1Validator(module).smartAccountOwners(address(instance.account)) + ); + // Uninstall the module + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); + // Stop state diff recording + VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff(); + // Expect revert + vm.expectRevert(); + // Assert that the module storage was cleared + instance.verifyModuleStorageWasCleared(accountAccesses, module); + } + + function __revertWhen_verifyModuleStorageWasCleared_NotCleared() public { + // Uninstall + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); + } + + function test_withModuleStorageClearValidation() + public + withModuleStorageClearValidation(instance, module) + { + // Set simulate mode to false + instance.simulateUserOp(false); + // Install a module + module = address(new MockK1Validator()); + // Install the module + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(VALIDATION_FAILED) + }); + // Uninstall the module + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); + } + /*////////////////////////////////////////////////////////////// EXPECT REVERT //////////////////////////////////////////////////////////////*/ diff --git a/test/integrations/SmartSession.t.sol b/test/integrations/SmartSession.t.sol index 6d627879..e481e5c5 100644 --- a/test/integrations/SmartSession.t.sol +++ b/test/integrations/SmartSession.t.sol @@ -232,7 +232,7 @@ contract SmartSessionTest is BaseTest { instance.installModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: address(mockK1Validator), - data: abi.encodePacked(owner.addr) + data: abi.encode(owner.addr) }); // Install smart session @@ -288,7 +288,7 @@ contract SmartSessionTest is BaseTest { instance.installModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: address(mockK1Validator), - data: abi.encodePacked(owner.addr) + data: abi.encode(owner.addr) }); // Install smart session diff --git a/test/SwapTest.t.sol b/test/integrations/SwapTest.t.sol similarity index 98% rename from test/SwapTest.t.sol rename to test/integrations/SwapTest.t.sol index ae745347..a4458df8 100644 --- a/test/SwapTest.t.sol +++ b/test/integrations/SwapTest.t.sol @@ -5,7 +5,7 @@ import "test/BaseTest.t.sol"; import "src/ModuleKit.sol"; import { ERC7579ExecutorBase } from "src/Modules.sol"; import { IERC20 } from "forge-std/interfaces/IERC20.sol"; -import { UniswapV3Integration } from "../src/integrations/uniswap/v3/Uniswap.sol"; +import { UniswapV3Integration } from "src/integrations/uniswap/v3/Uniswap.sol"; contract TestUniswap is BaseTest { using ModuleKitHelpers for AccountInstance; diff --git a/test/mocks/MockK1Validator.sol b/test/mocks/MockK1Validator.sol index e60224f2..4c8504ff 100644 --- a/test/mocks/MockK1Validator.sol +++ b/test/mocks/MockK1Validator.sol @@ -43,7 +43,8 @@ contract MockK1Validator is IValidator { } function onInstall(bytes calldata data) external { - smartAccountOwners[msg.sender] = address(bytes20(data)); + address owner = abi.decode(data, (address)); + smartAccountOwners[msg.sender] = owner; } function onUninstall(bytes calldata data) external { diff --git a/test/mocks/MockK1ValidatorUncompliantUninstall.sol b/test/mocks/MockK1ValidatorUncompliantUninstall.sol new file mode 100644 index 00000000..e0e4ed89 --- /dev/null +++ b/test/mocks/MockK1ValidatorUncompliantUninstall.sol @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import { + IValidator, + VALIDATION_SUCCESS, + VALIDATION_FAILED, + MODULE_TYPE_VALIDATOR +} from "erc7579/interfaces/IERC7579Module.sol"; +import { PackedUserOperation } from "src/external/ERC4337.sol"; +import { ECDSA } from "solady/utils/ECDSA.sol"; +import { SignatureCheckerLib } from "solady/utils/SignatureCheckerLib.sol"; +import { MessageHashUtils } from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import { EIP1271_MAGIC_VALUE, IERC1271 } from "module-bases/interfaces/IERC1271.sol"; + +contract MockK1ValidatorUncompliantUninstall is IValidator { + bytes4 constant ERC1271_INVALID = 0xffffffff; + mapping(address => address) public smartAccountOwners; + + function validateUserOp( + PackedUserOperation calldata userOp, + bytes32 userOpHash + ) + external + view + returns (uint256 validation) + { + return ECDSA.recover(MessageHashUtils.toEthSignedMessageHash(userOpHash), userOp.signature) + == smartAccountOwners[msg.sender] ? VALIDATION_SUCCESS : VALIDATION_FAILED; + } + + function isValidSignatureWithSender( + address, + bytes32 hash, + bytes calldata signature + ) + external + view + returns (bytes4) + { + return ECDSA.recover(MessageHashUtils.toEthSignedMessageHash(hash), signature) + == smartAccountOwners[msg.sender] ? EIP1271_MAGIC_VALUE : ERC1271_INVALID; + } + + function onInstall(bytes calldata data) external { + address owner = abi.decode(data, (address)); + smartAccountOwners[msg.sender] = owner; + } + + function onUninstall(bytes calldata data) external pure { + data; + } + + function isModuleType(uint256 moduleTypeId) external pure returns (bool) { + return moduleTypeId == MODULE_TYPE_VALIDATOR; + } + + function isOwner(address account, address owner) external view returns (bool) { + return smartAccountOwners[account] == owner; + } + + function isInitialized(address) external pure returns (bool) { + return false; + } + + function getOwner(address account) external view returns (address) { + return smartAccountOwners[account]; + } +}