From 352d23e15a5982e745daeac1742daf8d51a3ac22 Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 01:29:19 +0100 Subject: [PATCH 1/8] feat(ModuleKitHelpers): add verifyModuleStorageWasCleared --- src/test/ModuleKitHelpers.sol | 94 ++++++++++++++++++++- src/test/RhinestoneModuleKit.sol | 12 +++ test/Diff.t.sol | 63 ++++++++++++++ test/SwapTest.t.sol | 122 --------------------------- test/integrations/SmartSession.t.sol | 4 +- test/mocks/MockK1Validator.sol | 3 +- 6 files changed, 172 insertions(+), 126 deletions(-) delete mode 100644 test/SwapTest.t.sol diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index 07a6083e..17f7a0c8 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -15,7 +15,13 @@ import { PackedUserOperation } from "../external/ERC4337.sol"; import { ERC4337Helpers } from "./utils/ERC4337Helpers.sol"; import { HelperBase } from "./helpers/HelperBase.sol"; import { Execution, MODULE_TYPE_HOOK } from "../external/ERC7579.sol"; -import { prank } from "src/test/utils/Vm.sol"; +import { + prank, + VmSafe, + startStateDiffRecording as vmStartStateDiffRecording, + stopAndReturnStateDiff as vmStopAndReturnStateDiff, + getMappingKeyAndParentOf +} from "src/test/utils/Vm.sol"; import { getAccountType as getAccountTypeFromStorage, writeAccountType, @@ -323,6 +329,92 @@ library ModuleKitHelpers { } } } + + /// Start recording the state diff + function startStateDiffRecording(AccountInstance memory) internal { + vmStartStateDiffRecording(); + } + + /// Stop recording the state diff and return the account accesses + function stopAndReturnStateDiff(AccountInstance memory) + internal + returns (VmSafe.AccountAccess[] memory) + { + return vmStopAndReturnStateDiff(); + } + + /// Verifies from an accountAccesses array that storage was correctly cleared after uninstalling + /// a module + function verifyModuleStorageWasCleared( + AccountInstance memory, + VmSafe.AccountAccess[] memory accountAccesses, + address module + ) + internal + view + { + // Track all writes and clears across all accesses + bytes32[] memory allWrittenSlots = new bytes32[](1000); + bytes32[] memory allClearedSlots = new bytes32[](1000); + uint256 totalWritten = 0; + uint256 totalCleared = 0; + + // Loop through account accesses + for (uint256 i; i < accountAccesses.length; i++) { + // Skip tests + if (accountAccesses[i].accessor == address(this)) { + continue; + } + + // If we are accessing the storage of the module check writes and clears + if (accountAccesses[i].account == module) { + // Process all storage accesses for this module + for (uint256 j; j < accountAccesses[i].storageAccesses.length; j++) { + VmSafe.StorageAccess memory access = accountAccesses[i].storageAccesses[j]; + + // Skip reads + if (!access.isWrite) { + continue; + } + + if (access.newValue != bytes32(0)) { + // Record write + allWrittenSlots[totalWritten] = access.slot; + totalWritten++; + } else { + // Record clear + allClearedSlots[totalCleared] = access.slot; + totalCleared++; + } + } + } + } + + // Verify all writes were cleared + for (uint256 i; i < totalWritten; i++) { + bool wasCleared = false; + + for (uint256 j; j < totalCleared; j++) { + if (allWrittenSlots[i] == allClearedSlots[j]) { + wasCleared = true; + break; + } + } + + if (!wasCleared) { + revert("Storage not cleared after uninstalling module"); + } + } + } + + /// Verifies that storage was correctly cleared after uninstalling a module + modifier withUninstallStorageValidation(AccountInstance memory instance, address module) { + instance.startStateDiffRecording(); + _; + VmSafe.AccountAccess[] memory accountAccess = instance.stopAndReturnStateDiff(); + verifyModuleStorageWasCleared(instance, accountAccess, module); + } + /*////////////////////////////////////////////////////////////////////////// CONTROL FLOW //////////////////////////////////////////////////////////////////////////*/ diff --git a/src/test/RhinestoneModuleKit.sol b/src/test/RhinestoneModuleKit.sol index 25eb4857..177b73b5 100644 --- a/src/test/RhinestoneModuleKit.sol +++ b/src/test/RhinestoneModuleKit.sol @@ -29,6 +29,7 @@ import { writeHelper } from "./utils/Storage.sol"; import { ModuleKitHelpers } from "./ModuleKitHelpers.sol"; +import { VmSafe } from "./utils/Vm.sol"; enum AccountType { DEFAULT, @@ -327,4 +328,15 @@ contract RhinestoneModuleKit is AuxiliaryFactory { defaultSessionValidator: ISessionValidator(sessionValidator) }); } + + /*////////////////////////////////////////////////////////////// + STORAGE CLEARING + //////////////////////////////////////////////////////////////*/ + + modifier withModuleStorageClearValidation(AccountInstance memory instance, address module) { + instance.startStateDiffRecording(); + _; + VmSafe.AccountAccess[] memory accountAccess = instance.stopAndReturnStateDiff(); + instance.verifyModuleStorageWasCleared(accountAccess, module); + } } diff --git a/test/Diff.t.sol b/test/Diff.t.sol index 3047a41b..26556ae0 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -14,8 +14,13 @@ import { import { getAccountType, InstalledModule } from "src/test/utils/Storage.sol"; import { toString } from "src/test/utils/Vm.sol"; import { MockValidatorFalse } from "test/mocks/MockValidatorFalse.sol"; +import { MockK1Validator, VALIDATION_SUCCESS } from "test/mocks/MockK1Validator.sol"; +import { VALIDATION_SUCCESS, VALIDATION_FAILED } from "erc7579/interfaces/IERC7579Module.sol"; +import { VmSafe } from "src/test/utils/Vm.sol"; contract ERC7579DifferentialModuleKitLibTest is BaseTest { + event LogAddress(address); + using ModuleKitHelpers for *; using ModuleKitUserOp for *; @@ -626,6 +631,64 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { } } + function test_verifyModuleStorageWasCleared() public { + // Set simulate mode to false + instance.simulateUserOp(false); + // Install a module + address module = address(new MockK1Validator()); + // Start state diff recording + instance.startStateDiffRecording(); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(instance.account) + }); + // Uninstall the module + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); + // Stop state diff recording + VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff(); + // Assert that the module storage was cleared + instance.verifyModuleStorageWasCleared(accountAccesses, module); + } + + function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared() public { + // Set simulate mode to false + instance.simulateUserOp(false); + // Install a module + address module = address(new MockK1Validator()); + // Start state diff recording + instance.startStateDiffRecording(); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(0xffffffffffffffffffff) + }); + // Assert module storage + assertEq( + address(0xffffffffffffffffffff), + MockK1Validator(module).smartAccountOwners(address(instance.account)) + ); + // Stop state diff recording + VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff(); + // Expect revert + vm.expectRevert(); + // Assert that the module storage was cleared + instance.verifyModuleStorageWasCleared(accountAccesses, module); + } + + function test_withModuleStorageClearValidation() public { + // Install a module + address module = address(new MockK1Validator()); + // Install the module + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(VALIDATION_FAILED) + }); + // Uninstall the module + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); + } + /*////////////////////////////////////////////////////////////// EXPECT REVERT //////////////////////////////////////////////////////////////*/ diff --git a/test/SwapTest.t.sol b/test/SwapTest.t.sol deleted file mode 100644 index ae745347..00000000 --- a/test/SwapTest.t.sol +++ /dev/null @@ -1,122 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.25; - -import "test/BaseTest.t.sol"; -import "src/ModuleKit.sol"; -import { ERC7579ExecutorBase } from "src/Modules.sol"; -import { IERC20 } from "forge-std/interfaces/IERC20.sol"; -import { UniswapV3Integration } from "../src/integrations/uniswap/v3/Uniswap.sol"; - -contract TestUniswap is BaseTest { - using ModuleKitHelpers for AccountInstance; - using UniswapV3Integration for *; - - IERC20 tokenA; - IERC20 tokenB; - MockERC20 mockTokenA; - MockERC20 mockTokenB; - - uint256 amountIn = 100_000_000; // Example: 100 tokens of tokenA - uint32 slippage = 1; // 0.1% slippage - - address internal constant USDC_HOLDER = 0x4B16c5dE96EB2117bBE5fd171E4d203624B014aa; // account - // with USDC holdings - address internal constant WETH_HOLDER = 0x57757E3D981446D585Af0D9Ae4d7DF6D64647806; // account - // with WETH holdings - - address constant USDC_ADDRESS = 0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48; - address constant WETH_ADDRESS = 0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2; - - function setUp() public override { - string memory MAINNET_RPC_URL = vm.envString("MAINNET_RPC_URL"); - vm.createSelectFork(MAINNET_RPC_URL); - vm.rollFork(20_426_591); - instance = makeAccountInstance("account1"); - assertTrue(instance.account != address(0)); - - tokenA = IERC20(USDC_ADDRESS); - tokenB = IERC20(WETH_ADDRESS); - - _fundAccountWithTokenA(amountIn); - vm.deal(instance.account, 1 ether); - assertTrue(instance.account.balance == 1 ether); - } - - function _fundAccountWithTokenA(uint256 amount) internal { - vm.startPrank(USDC_HOLDER); - bool success = tokenA.transfer(instance.account, amount); - require(success, "Failed to transfer tokenA to account"); - vm.stopPrank(); - } - - function testApproveAndSwap() public { - address poolAddress = UniswapV3Integration.getPoolAddress(address(tokenA), address(tokenB)); - uint160 sqrtPriceX96 = UniswapV3Integration.getSqrtPriceX96(poolAddress); - emit log_named_uint("Square Root Price X96", sqrtPriceX96); - - uint256 priceRatio = UniswapV3Integration.sqrtPriceX96toPriceRatio(sqrtPriceX96); - - emit log_named_uint("Price Ratio", priceRatio); - - uint256 price = - UniswapV3Integration.priceRatioToPrice(priceRatio, poolAddress, address(tokenA)); - - emit log_named_uint("Price", price); - - bool swapToken0to1 = UniswapV3Integration.checkTokenOrder(address(tokenA), poolAddress); - - uint256 priceRatioLimit; - if (swapToken0to1) { - priceRatioLimit = (priceRatio * (1000 - slippage)) / 1000; - } else { - priceRatioLimit = (priceRatio * (1000 + slippage)) / 1000; - } - - emit log_named_uint("Price Ratio Limit", priceRatioLimit); - - uint256 priceLimit = - UniswapV3Integration.priceRatioToPrice(priceRatioLimit, poolAddress, address(tokenA)); - - emit log_named_uint("Price Limit", priceLimit); - - uint160 sqrtPriceLimitX96 = UniswapV3Integration.priceRatioToSqrtPriceX96(priceRatioLimit); - - emit log_named_uint("sqrtPriceLimitX96", sqrtPriceLimitX96); - - uint256 initialAccountBalanceA = tokenA.balanceOf(instance.account); - uint256 initialAccountBalanceB = tokenB.balanceOf(instance.account); - - emit log_named_uint("Initial Balance of Token A (account)", initialAccountBalanceA); - emit log_named_uint("Initial Balance of Token B (account)", initialAccountBalanceB); - - Execution[] memory swap = UniswapV3Integration.approveAndSwap( - instance.account, tokenA, tokenB, amountIn, sqrtPriceLimitX96 - ); - - for (uint256 i = 0; i < swap.length; i++) { - instance.exec({ - target: swap[i].target, - value: swap[i].value, - callData: swap[i].callData - }); - } - - uint256 finalAccountBalanceA = tokenA.balanceOf(instance.account); - uint256 finalAccountBalanceB = tokenB.balanceOf(instance.account); - - emit log_named_uint("Final Balance of Token A (account)", finalAccountBalanceA); - emit log_named_uint("Final Balance of Token B (account)", finalAccountBalanceB); - - sqrtPriceX96 = UniswapV3Integration.getSqrtPriceX96(poolAddress); - emit log_named_uint("Post Swap Square Root Price X96", sqrtPriceX96); - - require( - finalAccountBalanceA < initialAccountBalanceA, - "Token A balance in account did not decrease" - ); - require( - finalAccountBalanceB > initialAccountBalanceB, - "Token B balance in account did not increase" - ); - } -} diff --git a/test/integrations/SmartSession.t.sol b/test/integrations/SmartSession.t.sol index 6d627879..e481e5c5 100644 --- a/test/integrations/SmartSession.t.sol +++ b/test/integrations/SmartSession.t.sol @@ -232,7 +232,7 @@ contract SmartSessionTest is BaseTest { instance.installModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: address(mockK1Validator), - data: abi.encodePacked(owner.addr) + data: abi.encode(owner.addr) }); // Install smart session @@ -288,7 +288,7 @@ contract SmartSessionTest is BaseTest { instance.installModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: address(mockK1Validator), - data: abi.encodePacked(owner.addr) + data: abi.encode(owner.addr) }); // Install smart session diff --git a/test/mocks/MockK1Validator.sol b/test/mocks/MockK1Validator.sol index e60224f2..4c8504ff 100644 --- a/test/mocks/MockK1Validator.sol +++ b/test/mocks/MockK1Validator.sol @@ -43,7 +43,8 @@ contract MockK1Validator is IValidator { } function onInstall(bytes calldata data) external { - smartAccountOwners[msg.sender] = address(bytes20(data)); + address owner = abi.decode(data, (address)); + smartAccountOwners[msg.sender] = owner; } function onUninstall(bytes calldata data) external { From 4792f3d3e97820e8234bdb9a339edc158a692206 Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 01:30:17 +0100 Subject: [PATCH 2/8] chore: remove unused event --- test/Diff.t.sol | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/Diff.t.sol b/test/Diff.t.sol index 26556ae0..6de6b3ef 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -19,8 +19,6 @@ import { VALIDATION_SUCCESS, VALIDATION_FAILED } from "erc7579/interfaces/IERC75 import { VmSafe } from "src/test/utils/Vm.sol"; contract ERC7579DifferentialModuleKitLibTest is BaseTest { - event LogAddress(address); - using ModuleKitHelpers for *; using ModuleKitUserOp for *; From 35742f59ce889a694710beced26d94e1d966b516 Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 01:41:44 +0100 Subject: [PATCH 3/8] fix: fix algo --- src/test/ModuleKitHelpers.sol | 46 ++++++++++++++++------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index 17f7a0c8..7918fba2 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -353,11 +353,9 @@ library ModuleKitHelpers { internal view { - // Track all writes and clears across all accesses - bytes32[] memory allWrittenSlots = new bytes32[](1000); - bytes32[] memory allClearedSlots = new bytes32[](1000); - uint256 totalWritten = 0; - uint256 totalCleared = 0; + bytes32[] memory seenSlots = new bytes32[](1000); + bytes32[] memory finalValues = new bytes32[](1000); + uint256 numSlots; // Loop through account accesses for (uint256 i; i < accountAccesses.length; i++) { @@ -377,31 +375,29 @@ library ModuleKitHelpers { continue; } - if (access.newValue != bytes32(0)) { - // Record write - allWrittenSlots[totalWritten] = access.slot; - totalWritten++; - } else { - // Record clear - allClearedSlots[totalCleared] = access.slot; - totalCleared++; + // Find if we've seen this slot + bool found; + for (uint256 k; k < numSlots; k++) { + if (seenSlots[k] == access.slot) { + finalValues[k] = access.newValue; + found = true; + break; + } } - } - } - } - // Verify all writes were cleared - for (uint256 i; i < totalWritten; i++) { - bool wasCleared = false; - - for (uint256 j; j < totalCleared; j++) { - if (allWrittenSlots[i] == allClearedSlots[j]) { - wasCleared = true; - break; + // If not seen, add it + if (!found) { + seenSlots[numSlots] = access.slot; + finalValues[numSlots] = access.newValue; + numSlots++; + } } } + } - if (!wasCleared) { + // Check if any slot's final value is non-zero + for (uint256 i; i < numSlots; i++) { + if (finalValues[i] != bytes32(0)) { revert("Storage not cleared after uninstalling module"); } } From c8099704d7a9ff6bf02267780f139e0def6ec3f3 Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 03:46:05 +0100 Subject: [PATCH 4/8] fix: add mising modifier, revert swaptest delete --- test/Diff.t.sol | 8 +- test/integrations/SwapTest.t.sol | 122 +++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 test/integrations/SwapTest.t.sol diff --git a/test/Diff.t.sol b/test/Diff.t.sol index 6de6b3ef..44f69033 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -30,6 +30,7 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { MockTarget internal mockTarget; MockERC20 internal token; + address module; function setUp() public override { super.setUp(); @@ -674,9 +675,12 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { instance.verifyModuleStorageWasCleared(accountAccesses, module); } - function test_withModuleStorageClearValidation() public { + function test_withModuleStorageClearValidation() + public + withModuleStorageClearValidation(instance, module) + { // Install a module - address module = address(new MockK1Validator()); + module = address(new MockK1Validator()); // Install the module instance.installModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, diff --git a/test/integrations/SwapTest.t.sol b/test/integrations/SwapTest.t.sol new file mode 100644 index 00000000..a4458df8 --- /dev/null +++ b/test/integrations/SwapTest.t.sol @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import "test/BaseTest.t.sol"; +import "src/ModuleKit.sol"; +import { ERC7579ExecutorBase } from "src/Modules.sol"; +import { IERC20 } from "forge-std/interfaces/IERC20.sol"; +import { UniswapV3Integration } from "src/integrations/uniswap/v3/Uniswap.sol"; + +contract TestUniswap is BaseTest { + using ModuleKitHelpers for AccountInstance; + using UniswapV3Integration for *; + + IERC20 tokenA; + IERC20 tokenB; + MockERC20 mockTokenA; + MockERC20 mockTokenB; + + uint256 amountIn = 100_000_000; // Example: 100 tokens of tokenA + uint32 slippage = 1; // 0.1% slippage + + address internal constant USDC_HOLDER = 0x4B16c5dE96EB2117bBE5fd171E4d203624B014aa; // account + // with USDC holdings + address internal constant WETH_HOLDER = 0x57757E3D981446D585Af0D9Ae4d7DF6D64647806; // account + // with WETH holdings + + address constant USDC_ADDRESS = 0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48; + address constant WETH_ADDRESS = 0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2; + + function setUp() public override { + string memory MAINNET_RPC_URL = vm.envString("MAINNET_RPC_URL"); + vm.createSelectFork(MAINNET_RPC_URL); + vm.rollFork(20_426_591); + instance = makeAccountInstance("account1"); + assertTrue(instance.account != address(0)); + + tokenA = IERC20(USDC_ADDRESS); + tokenB = IERC20(WETH_ADDRESS); + + _fundAccountWithTokenA(amountIn); + vm.deal(instance.account, 1 ether); + assertTrue(instance.account.balance == 1 ether); + } + + function _fundAccountWithTokenA(uint256 amount) internal { + vm.startPrank(USDC_HOLDER); + bool success = tokenA.transfer(instance.account, amount); + require(success, "Failed to transfer tokenA to account"); + vm.stopPrank(); + } + + function testApproveAndSwap() public { + address poolAddress = UniswapV3Integration.getPoolAddress(address(tokenA), address(tokenB)); + uint160 sqrtPriceX96 = UniswapV3Integration.getSqrtPriceX96(poolAddress); + emit log_named_uint("Square Root Price X96", sqrtPriceX96); + + uint256 priceRatio = UniswapV3Integration.sqrtPriceX96toPriceRatio(sqrtPriceX96); + + emit log_named_uint("Price Ratio", priceRatio); + + uint256 price = + UniswapV3Integration.priceRatioToPrice(priceRatio, poolAddress, address(tokenA)); + + emit log_named_uint("Price", price); + + bool swapToken0to1 = UniswapV3Integration.checkTokenOrder(address(tokenA), poolAddress); + + uint256 priceRatioLimit; + if (swapToken0to1) { + priceRatioLimit = (priceRatio * (1000 - slippage)) / 1000; + } else { + priceRatioLimit = (priceRatio * (1000 + slippage)) / 1000; + } + + emit log_named_uint("Price Ratio Limit", priceRatioLimit); + + uint256 priceLimit = + UniswapV3Integration.priceRatioToPrice(priceRatioLimit, poolAddress, address(tokenA)); + + emit log_named_uint("Price Limit", priceLimit); + + uint160 sqrtPriceLimitX96 = UniswapV3Integration.priceRatioToSqrtPriceX96(priceRatioLimit); + + emit log_named_uint("sqrtPriceLimitX96", sqrtPriceLimitX96); + + uint256 initialAccountBalanceA = tokenA.balanceOf(instance.account); + uint256 initialAccountBalanceB = tokenB.balanceOf(instance.account); + + emit log_named_uint("Initial Balance of Token A (account)", initialAccountBalanceA); + emit log_named_uint("Initial Balance of Token B (account)", initialAccountBalanceB); + + Execution[] memory swap = UniswapV3Integration.approveAndSwap( + instance.account, tokenA, tokenB, amountIn, sqrtPriceLimitX96 + ); + + for (uint256 i = 0; i < swap.length; i++) { + instance.exec({ + target: swap[i].target, + value: swap[i].value, + callData: swap[i].callData + }); + } + + uint256 finalAccountBalanceA = tokenA.balanceOf(instance.account); + uint256 finalAccountBalanceB = tokenB.balanceOf(instance.account); + + emit log_named_uint("Final Balance of Token A (account)", finalAccountBalanceA); + emit log_named_uint("Final Balance of Token B (account)", finalAccountBalanceB); + + sqrtPriceX96 = UniswapV3Integration.getSqrtPriceX96(poolAddress); + emit log_named_uint("Post Swap Square Root Price X96", sqrtPriceX96); + + require( + finalAccountBalanceA < initialAccountBalanceA, + "Token A balance in account did not decrease" + ); + require( + finalAccountBalanceB > initialAccountBalanceB, + "Token B balance in account did not increase" + ); + } +} From 4c94f60076dd3f3af9affeafdb84a26889a17125 Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 04:03:44 +0100 Subject: [PATCH 5/8] fix: force simulate to flase in tests --- test/Diff.t.sol | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Diff.t.sol b/test/Diff.t.sol index 44f69033..8857e09d 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -679,6 +679,8 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { public withModuleStorageClearValidation(instance, module) { + // Set simulate mode to false + instance.simulateUserOp(false); // Install a module module = address(new MockK1Validator()); // Install the module From 6e22992227163f6a6863a8eb3e94d2a1eb6b3f40 Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 04:43:26 +0100 Subject: [PATCH 6/8] feat: add COMPLIANCE env variable and hook to install/uninstall --- src/test/ModuleKitHelpers.sol | 40 ++++++++--- test/Diff.t.sol | 16 +++-- .../MockK1ValidatorUncompliantUninstall.sol | 69 +++++++++++++++++++ 3 files changed, 109 insertions(+), 16 deletions(-) create mode 100644 test/mocks/MockK1ValidatorUncompliantUninstall.sol diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index 7918fba2..be986c04 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -20,7 +20,8 @@ import { VmSafe, startStateDiffRecording as vmStartStateDiffRecording, stopAndReturnStateDiff as vmStopAndReturnStateDiff, - getMappingKeyAndParentOf + getMappingKeyAndParentOf, + envOr } from "src/test/utils/Vm.sol"; import { getAccountType as getAccountTypeFromStorage, @@ -146,6 +147,31 @@ library ModuleKitHelpers { return exec(instance, target, 0, callData); } + /*////////////////////////////////////////////////////////////// + HOOKS + //////////////////////////////////////////////////////////////*/ + + function preEnvHook() internal { + if (envOr("COMPLIANCE", false)) { + if (envOr("SIMULATE", false)) { + revert("Compliance and simulate cannot be used together"); + } else { + // Start state diff recording + vmStartStateDiffRecording(); + } + } + } + + function postEnvHook(AccountInstance memory instance, bytes memory data) internal { + if (envOr("COMPLIANCE", false)) { + address module = abi.decode(data, (address)); + // Stop state diff recording and return account accesses + VmSafe.AccountAccess[] memory accountAccesses = vmStopAndReturnStateDiff(); + // Check if storage was cleared + verifyModuleStorageWasCleared(instance, accountAccesses, module); + } + } + /*////////////////////////////////////////////////////////////////////////// MODULE CONFIG //////////////////////////////////////////////////////////////////////////*/ @@ -159,6 +185,8 @@ library ModuleKitHelpers { internal returns (UserOpData memory userOpData) { + // Run preEnvHook + preEnvHook(); userOpData = instance.getInstallModuleOps( moduleTypeId, module, data, address(instance.defaultValidator) ); @@ -187,6 +215,8 @@ library ModuleKitHelpers { // send userOp to entrypoint userOpData.execUserOps(); + // Run postEnvHook + postEnvHook(instance, abi.encode(module)); } function isModuleInstalled( @@ -403,14 +433,6 @@ library ModuleKitHelpers { } } - /// Verifies that storage was correctly cleared after uninstalling a module - modifier withUninstallStorageValidation(AccountInstance memory instance, address module) { - instance.startStateDiffRecording(); - _; - VmSafe.AccountAccess[] memory accountAccess = instance.stopAndReturnStateDiff(); - verifyModuleStorageWasCleared(instance, accountAccess, module); - } - /*////////////////////////////////////////////////////////////////////////// CONTROL FLOW //////////////////////////////////////////////////////////////////////////*/ diff --git a/test/Diff.t.sol b/test/Diff.t.sol index 8857e09d..e6696b69 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -15,6 +15,8 @@ import { getAccountType, InstalledModule } from "src/test/utils/Storage.sol"; import { toString } from "src/test/utils/Vm.sol"; import { MockValidatorFalse } from "test/mocks/MockValidatorFalse.sol"; import { MockK1Validator, VALIDATION_SUCCESS } from "test/mocks/MockK1Validator.sol"; +import { MockK1ValidatorUncompliantUninstall } from + "test/mocks/MockK1ValidatorUncompliantUninstall.sol"; import { VALIDATION_SUCCESS, VALIDATION_FAILED } from "erc7579/interfaces/IERC7579Module.sol"; import { VmSafe } from "src/test/utils/Vm.sol"; @@ -654,9 +656,7 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { // Set simulate mode to false instance.simulateUserOp(false); // Install a module - address module = address(new MockK1Validator()); - // Start state diff recording - instance.startStateDiffRecording(); + module = address(new MockK1ValidatorUncompliantUninstall()); instance.installModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, @@ -667,12 +667,14 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { address(0xffffffffffffffffffff), MockK1Validator(module).smartAccountOwners(address(instance.account)) ); - // Stop state diff recording - VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff(); // Expect revert vm.expectRevert(); - // Assert that the module storage was cleared - instance.verifyModuleStorageWasCleared(accountAccesses, module); + this.__revertWhen_verifyModuleStorageWasCleared_NotCleared(); + } + + function __revertWhen_verifyModuleStorageWasCleared_NotCleared() public { + // Uninstall + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); } function test_withModuleStorageClearValidation() diff --git a/test/mocks/MockK1ValidatorUncompliantUninstall.sol b/test/mocks/MockK1ValidatorUncompliantUninstall.sol new file mode 100644 index 00000000..e0e4ed89 --- /dev/null +++ b/test/mocks/MockK1ValidatorUncompliantUninstall.sol @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import { + IValidator, + VALIDATION_SUCCESS, + VALIDATION_FAILED, + MODULE_TYPE_VALIDATOR +} from "erc7579/interfaces/IERC7579Module.sol"; +import { PackedUserOperation } from "src/external/ERC4337.sol"; +import { ECDSA } from "solady/utils/ECDSA.sol"; +import { SignatureCheckerLib } from "solady/utils/SignatureCheckerLib.sol"; +import { MessageHashUtils } from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import { EIP1271_MAGIC_VALUE, IERC1271 } from "module-bases/interfaces/IERC1271.sol"; + +contract MockK1ValidatorUncompliantUninstall is IValidator { + bytes4 constant ERC1271_INVALID = 0xffffffff; + mapping(address => address) public smartAccountOwners; + + function validateUserOp( + PackedUserOperation calldata userOp, + bytes32 userOpHash + ) + external + view + returns (uint256 validation) + { + return ECDSA.recover(MessageHashUtils.toEthSignedMessageHash(userOpHash), userOp.signature) + == smartAccountOwners[msg.sender] ? VALIDATION_SUCCESS : VALIDATION_FAILED; + } + + function isValidSignatureWithSender( + address, + bytes32 hash, + bytes calldata signature + ) + external + view + returns (bytes4) + { + return ECDSA.recover(MessageHashUtils.toEthSignedMessageHash(hash), signature) + == smartAccountOwners[msg.sender] ? EIP1271_MAGIC_VALUE : ERC1271_INVALID; + } + + function onInstall(bytes calldata data) external { + address owner = abi.decode(data, (address)); + smartAccountOwners[msg.sender] = owner; + } + + function onUninstall(bytes calldata data) external pure { + data; + } + + function isModuleType(uint256 moduleTypeId) external pure returns (bool) { + return moduleTypeId == MODULE_TYPE_VALIDATOR; + } + + function isOwner(address account, address owner) external view returns (bool) { + return smartAccountOwners[account] == owner; + } + + function isInitialized(address) external pure returns (bool) { + return false; + } + + function getOwner(address account) external view returns (address) { + return smartAccountOwners[account]; + } +} From 8bff79902d91953091b92744efaaef818f8363f1 Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 04:52:45 +0100 Subject: [PATCH 7/8] feat: add get/write compliance --- src/test/ModuleKitHelpers.sol | 13 ++++++++++--- src/test/utils/Storage.sol | 18 ++++++++++++++++++ test/Diff.t.sol | 34 +++++++++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index be986c04..3f75e282 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -29,6 +29,9 @@ import { writeExpectRevert, writeGasIdentifier, writeSimulateUserOp, + writeStorageCompliance, + getStorageCompliance, + getSimulateUserOp, writeAccountEnv, getFactory, getHelper as getHelperFromStorage, @@ -152,8 +155,8 @@ library ModuleKitHelpers { //////////////////////////////////////////////////////////////*/ function preEnvHook() internal { - if (envOr("COMPLIANCE", false)) { - if (envOr("SIMULATE", false)) { + if (envOr("COMPLIANCE", false) || getStorageCompliance()) { + if (envOr("SIMULATE", false) || getSimulateUserOp()) { revert("Compliance and simulate cannot be used together"); } else { // Start state diff recording @@ -163,7 +166,7 @@ library ModuleKitHelpers { } function postEnvHook(AccountInstance memory instance, bytes memory data) internal { - if (envOr("COMPLIANCE", false)) { + if (envOr("COMPLIANCE", false) || getStorageCompliance()) { address module = abi.decode(data, (address)); // Stop state diff recording and return account accesses VmSafe.AccountAccess[] memory accountAccesses = vmStopAndReturnStateDiff(); @@ -465,6 +468,10 @@ library ModuleKitHelpers { writeSimulateUserOp(value); } + function storageCompliance(AccountInstance memory, bool value) internal { + writeStorageCompliance(value); + } + /*////////////////////////////////////////////////////////////////////////// ACCOUNT UTILS //////////////////////////////////////////////////////////////////////////*/ diff --git a/src/test/utils/Storage.sol b/src/test/utils/Storage.sol index 3fd6cf02..f46742a8 100644 --- a/src/test/utils/Storage.sol +++ b/src/test/utils/Storage.sol @@ -80,6 +80,24 @@ function getSimulateUserOp() view returns (bool value) { } } +/*////////////////////////////////////////////////////////////// + STORAGE COMPLIANCE +//////////////////////////////////////////////////////////////*/ + +function writeStorageCompliance(bool value) { + bytes32 slot = keccak256("ModuleKit.StorageCompliance"); + assembly { + sstore(slot, value) + } +} + +function getStorageCompliance() view returns (bool value) { + bytes32 slot = keccak256("ModuleKit.StorageCompliance"); + assembly { + value := sload(slot) + } +} + /*////////////////////////////////////////////////////////////// ACCOUNT ENV //////////////////////////////////////////////////////////////*/ diff --git a/test/Diff.t.sol b/test/Diff.t.sol index e6696b69..44155c7b 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -652,9 +652,14 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { instance.verifyModuleStorageWasCleared(accountAccesses, module); } - function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared() public { + function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared_UsingComplianceFlag() + public + { // Set simulate mode to false instance.simulateUserOp(false); + // Set compliance flag + instance.storageCompliance(true); + // Install a module module = address(new MockK1ValidatorUncompliantUninstall()); instance.installModule({ @@ -672,6 +677,33 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { this.__revertWhen_verifyModuleStorageWasCleared_NotCleared(); } + function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared() public { + // Set simulate mode to false + instance.simulateUserOp(false); + // Install a module + module = address(new MockK1ValidatorUncompliantUninstall()); + // Start state diff recording + instance.startStateDiffRecording(); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: module, + data: abi.encode(0xffffffffffffffffffff) + }); + // Assert module storage + assertEq( + address(0xffffffffffffffffffff), + MockK1Validator(module).smartAccountOwners(address(instance.account)) + ); + // Uninstall the module + instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); + // Stop state diff recording + VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff(); + // Expect revert + vm.expectRevert(); + // Assert that the module storage was cleared + instance.verifyModuleStorageWasCleared(accountAccesses, module); + } + function __revertWhen_verifyModuleStorageWasCleared_NotCleared() public { // Uninstall instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" }); From a43ebe1d406130bbb81b0e397a09e9d685d3ef3a Mon Sep 17 00:00:00 2001 From: highskore Date: Mon, 4 Nov 2024 04:58:03 +0100 Subject: [PATCH 8/8] fix: remove revert on SIM+COMP --- src/test/ModuleKitHelpers.sol | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index 3f75e282..3d940cbe 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -156,12 +156,8 @@ library ModuleKitHelpers { function preEnvHook() internal { if (envOr("COMPLIANCE", false) || getStorageCompliance()) { - if (envOr("SIMULATE", false) || getSimulateUserOp()) { - revert("Compliance and simulate cannot be used together"); - } else { - // Start state diff recording - vmStartStateDiffRecording(); - } + // Start state diff recording + vmStartStateDiffRecording(); } }