diff --git a/src/ModuleKit.sol b/src/ModuleKit.sol index d0b17dc8..35ebc6b5 100644 --- a/src/ModuleKit.sol +++ b/src/ModuleKit.sol @@ -2,7 +2,12 @@ pragma solidity ^0.8.23; /* solhint-disable no-unused-import */ -import { UserOpData, AccountInstance, RhinestoneModuleKit } from "./test/RhinestoneModuleKit.sol"; +import { + UserOpData, + AccountInstance, + RhinestoneModuleKit, + AccountType +} from "./test/RhinestoneModuleKit.sol"; import { ModuleKitHelpers } from "./test/ModuleKitHelpers.sol"; import { ModuleKitUserOp } from "./test/ModuleKitUserOp.sol"; import { PackedUserOperation } from "./external/ERC4337.sol"; diff --git a/src/test/ModuleKitHelpers.sol b/src/test/ModuleKitHelpers.sol index 34209afe..886dd200 100644 --- a/src/test/ModuleKitHelpers.sol +++ b/src/test/ModuleKitHelpers.sol @@ -23,8 +23,13 @@ import { writeAccountEnv, getFactory, getHelper as getHelperFromStorage, - getAccountEnv as getAccountEnvFromStorage + getAccountEnv as getAccountEnvFromStorage, + getInstalledModules as getInstalledModulesFromStorage, + writeInstalledModule as writeInstalledModuleToStorage, + removeInstalledModule as removeInstalledModuleFromStorage, + InstalledModule } from "./utils/Storage.sol"; +import { recordLogs, VmSafe, getRecordedLogs } from "./utils/Vm.sol"; library ModuleKitHelpers { /*////////////////////////////////////////////////////////////////////////// @@ -262,6 +267,46 @@ library ModuleKitHelpers { userOpData.entrypoint = instance.aux.entrypoint; } + function getInstalledModules( + AccountInstance memory instance + ) + internal + view + returns (InstalledModule[] memory) + { + return getInstalledModulesFromStorage(instance.account); + } + + function writeInstalledModule( + AccountInstance memory instance, + InstalledModule memory module + ) + internal + { + writeInstalledModuleToStorage(module, instance.account); + } + + function removeInstalledModule( + AccountInstance memory instance, + uint256 moduleType, + address moduleAddress + ) + internal + { + // Get installed modules for account + InstalledModule[] memory installedModules = getInstalledModules(instance); + // Find module to remove (not super scalable at high module counts) + for (uint256 i; i < installedModules.length; i++) { + if ( + installedModules[i].moduleType == moduleType + && installedModules[i].moduleAddress == moduleAddress + ) { + // Remove module from storage + removeInstalledModuleFromStorage(i, instance.account); + return; + } + } + } /*////////////////////////////////////////////////////////////////////////// CONTROL FLOW //////////////////////////////////////////////////////////////////////////*/ @@ -332,7 +377,22 @@ library ModuleKitHelpers { } function deployAccount(AccountInstance memory instance) internal { + // Record logs to track installed modules + recordLogs(); + // Deploy account HelperBase(instance.accountHelper).deployAccount(instance); + // Parse logs and determine if a module was installed + VmSafe.Log[] memory logs = getRecordedLogs(); + for (uint256 i; i < logs.length; i++) { + // ModuleInstalled(uint256, address) + if ( + logs[i].topics[0] + == 0xd21d0b289f126c4b473ea641963e766833c2f13866e4ff480abd787c100ef123 + ) { + (uint256 moduleType, address module) = abi.decode(logs[i].data, (uint256, address)); + writeInstalledModuleToStorage(InstalledModule(moduleType, module), logs[i].emitter); + } + } } function setAccountType(AccountInstance memory, AccountType env) internal { diff --git a/src/test/utils/ERC4337Helpers.sol b/src/test/utils/ERC4337Helpers.sol index ec7d9819..140680ff 100644 --- a/src/test/utils/ERC4337Helpers.sol +++ b/src/test/utils/ERC4337Helpers.sol @@ -18,7 +18,11 @@ import { getExpectRevert, writeExpectRevert, getGasIdentifier, - writeGasIdentifier + writeGasIdentifier, + writeInstalledModule, + getInstalledModules, + removeInstalledModule, + InstalledModule } from "./Storage.sol"; library ERC4337Helpers { @@ -75,6 +79,33 @@ library ERC4337Helpers { } } } + // ModuleInstalled(uint256, address) + else if ( + logs[i].topics[0] + == 0xd21d0b289f126c4b473ea641963e766833c2f13866e4ff480abd787c100ef123 + ) { + (uint256 moduleType, address module) = abi.decode(logs[i].data, (uint256, address)); + writeInstalledModule(InstalledModule(moduleType, module), logs[i].emitter); + } + // ModuleUninstalled(uint256, address) + else if ( + logs[i].topics[0] + == 0x341347516a9de374859dfda710fa4828b2d48cb57d4fbe4c1149612b8e02276e + ) { + (uint256 moduleType, address module) = abi.decode(logs[i].data, (uint256, address)); + // Get all installed modules + InstalledModule[] memory installedModules = getInstalledModules(logs[i].emitter); + // Remove the uninstalled module from the list of installed modules + for (uint256 j; j < installedModules.length; j++) { + if ( + installedModules[j].moduleAddress == module + && installedModules[j].moduleType == moduleType + ) { + removeInstalledModule(j, logs[i].emitter); + break; + } + } + } } isExpectRevert = getExpectRevert(); if (isExpectRevert != 0) { diff --git a/src/test/utils/Storage.sol b/src/test/utils/Storage.sol index d16800a2..474677d8 100644 --- a/src/test/utils/Storage.sol +++ b/src/test/utils/Storage.sol @@ -133,6 +133,186 @@ function getHelper(string memory helperType) view returns (address helper) { } } +/*////////////////////////////////////////////////////////////// + INSTALLED MODULE +//////////////////////////////////////////////////////////////*/ + +struct InstalledModule { + uint256 moduleType; + address moduleAddress; +} + +// Adds new address to the installed module linked list for the given account +// The list is stored in storage as a linked list in the following format: +// --------------------------------------------------------------------- +// | Slot | Value | +// |----------------------------------------------------------|--------| +// | keccak256(abi.encode("ModuleKit.InstalledModuleSlot.")); | length | +// | keccak256(abi.encode("ModuleKit.InstalledModuleHead.")); | head | +// | keccak256(abi.encode("ModuleKit.InstalledModuleTail.")); | tail | +// | keccak256(abi.encode(lengthSlot)) - initially X | element| +// --------------------------------------------------------------------- +// +// The elements are stored in the following way: +// -------------------------- +// | Slot | Value | +// |------------------------| +// | X | moduleType | +// | X + 0x20 | moduleAddr | +// | X + 0x40 | prev | +// | X + 0x60 | next | +// -------------------------- +function writeInstalledModule(InstalledModule memory module, address account) { + bytes32 lengthSlot = keccak256( + abi.encode("ModuleKit.InstalledModuleSlot.", keccak256(abi.encodePacked(account))) + ); + bytes32 headSlot = keccak256( + abi.encode("ModuleKit.InstalledModuleHead.", keccak256(abi.encodePacked(account))) + ); + bytes32 tailSlot = + keccak256(abi.encode("ModuleKit.InstalledModuleTail", keccak256(abi.encodePacked(account)))); + bytes32 elementSlot = keccak256(abi.encode(lengthSlot)); + uint256 moduleType = module.moduleType; + address moduleAddress = module.moduleAddress; + assembly { + // Get the length of the array + let length := sload(lengthSlot) + let nextSlot + let oldTail + switch iszero(length) + case 1 { + // If length is zero, set element slot to head and tail + sstore(headSlot, elementSlot) + sstore(tailSlot, elementSlot) + oldTail := elementSlot + nextSlot := elementSlot + } + default { + oldTail := sload(tailSlot) + // Set the new elemeont slot to the old tail + 0x80 + elementSlot := add(oldTail, 0x80) + // Set the old tail next slot to the new element slot + sstore(add(oldTail, 0x60), elementSlot) + // Update tailSlot to point to the new element slot + sstore(tailSlot, elementSlot) + // Set nextSlot to the head slot + nextSlot := sload(headSlot) + } + // Update the length of the list + sstore(lengthSlot, add(length, 1)) + // Store the module type and address in the new slot + sstore(elementSlot, moduleType) + sstore(add(elementSlot, 0x20), moduleAddress) + // Store the old tail as the prev slot + sstore(add(elementSlot, 0x40), oldTail) + // Store the head as the next slot + sstore(add(elementSlot, 0x60), nextSlot) + } +} + +// Removes a specific installed module +function removeInstalledModule(uint256 index, address account) { + bytes32 lengthSlot = keccak256( + abi.encode("ModuleKit.InstalledModuleSlot.", keccak256(abi.encodePacked(account))) + ); + bytes32 headSlot = keccak256( + abi.encode("ModuleKit.InstalledModuleHead.", keccak256(abi.encodePacked(account))) + ); + bytes32 tailSlot = keccak256( + abi.encode("ModuleKit.InstalledModuleTail.", keccak256(abi.encodePacked(account))) + ); + assembly { + // Get the length of the list + let length := sload(lengthSlot) + // Get the initial element slot + let elementSlot := sload(headSlot) + // Ensure the index is within bounds + if lt(index, length) { + // Traverse to the node to remove + for { let i := 0 } lt(i, index) { i := add(i, 1) } { + elementSlot := sload(add(elementSlot, 0x60)) + } + + // Get the previous and next slots + let prevSlot := sload(add(elementSlot, 0x40)) + let nextSlot := sload(add(elementSlot, 0x60)) + + // Update the previous slot's next pointer + sstore(add(prevSlot, 0x60), nextSlot) + // Update the next slot's previous pointer + sstore(add(nextSlot, 0x40), prevSlot) + + // Handle removing the head + if eq(elementSlot, sload(headSlot)) { sstore(headSlot, nextSlot) } + + // Handle removing the tail + if eq(elementSlot, sload(tailSlot)) { sstore(tailSlot, prevSlot) } + + // Clear the removed node + sstore(elementSlot, 0) + sstore(add(elementSlot, 0x20), 0) + sstore(add(elementSlot, 0x40), 0) + sstore(add(elementSlot, 0x60), 0) + + // Update the length of the list + sstore(lengthSlot, sub(length, 1)) + } + } +} + +// Returns all installed modules for the given account +function getInstalledModules(address account) view returns (InstalledModule[] memory modules) { + bytes32 lengthSlot = keccak256( + abi.encode("ModuleKit.InstalledModuleSlot.", keccak256(abi.encodePacked(account))) + ); + bytes32 headSlot = keccak256( + abi.encode("ModuleKit.InstalledModuleHead.", keccak256(abi.encodePacked(account))) + ); + assembly { + // Get the length of the array from storage + let length := sload(lengthSlot) + + // Each struct is 64 bytes (32 bytes for moduleType and 32 bytes for moduleAddress) + let structSize := 0x40 // 64 bytes + let size := mul(length, structSize) // Total size for structs + let totalSize := add(add(size, 0x40), mul(0x20, length)) + + // Allocate memory for the array + let freeMemoryPtr := mload(0x40) + modules := freeMemoryPtr + + // Store the length of the array in the first 32 bytes of memory + mstore(modules, length) + + // Update the free memory pointer to the end of the allocated memory + mstore(0x40, add(freeMemoryPtr, totalSize)) + + // Get the head of the linked list + let storageLocation := sload(headSlot) + + // Copy the structs from storage to memory + for { let i := 0 } lt(i, length) { i := add(i, 1) } { + // Calculate memory location for this struct + let structLocation := + add(add(freeMemoryPtr, add(0x40, mul(i, structSize))), mul(0x20, length)) + + // Load the moduleType and moduleAddress from storage + let moduleType := sload(storageLocation) + let moduleAddress := sload(add(storageLocation, 0x20)) + + // Store the structLocation into memory + mstore(add(freeMemoryPtr, add(0x20, mul(i, 0x20))), structLocation) + + // Store the moduleType and moduleAddress into memory + mstore(structLocation, moduleType) + mstore(add(structLocation, 0x20), moduleAddress) + + // Move to the next element in the linked list + storageLocation := sload(add(storageLocation, 0x60)) + } + } +} + /*////////////////////////////////////////////////////////////// STRING STORAGE //////////////////////////////////////////////////////////////*/ diff --git a/test/Diff.t.sol b/test/Diff.t.sol index dbbd7b8f..9dff6393 100644 --- a/test/Diff.t.sol +++ b/test/Diff.t.sol @@ -11,7 +11,8 @@ import { MODULE_TYPE_FALLBACK, CALLTYPE_SINGLE } from "src/external/ERC7579.sol"; -import { getAccountType } from "src/test/utils/Storage.sol"; +import { getAccountType, InstalledModule } from "src/test/utils/Storage.sol"; +import { toString } from "src/test/utils/Vm.sol"; contract ERC7579DifferentialModuleKitLibTest is BaseTest { using ModuleKitHelpers for *; @@ -181,6 +182,171 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { assertTrue(validator1Enabled); } + function test_getInstalledModules() public whenEnvIsNotKernel { + address newValidator = address(new MockValidator()); + address newValidator1 = address(new MockValidator()); + vm.label(newValidator, "2nd validator"); + + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator, + data: "" + }); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator1, + data: "" + }); + + // Assert installed modules + this._getModulesAndAssert( + abi.encode( + 2, [newValidator, newValidator1], [MODULE_TYPE_VALIDATOR, MODULE_TYPE_VALIDATOR] + ), + instance + ); + + address newExecutor = address(new MockExecutor()); + instance.installModule({ moduleTypeId: MODULE_TYPE_EXECUTOR, module: newExecutor, data: "" }); + + // Assert installed modules + this._getModulesAndAssert( + abi.encode( + 3, + [newValidator, newValidator1, newExecutor], + [MODULE_TYPE_VALIDATOR, MODULE_TYPE_VALIDATOR, MODULE_TYPE_EXECUTOR] + ), + instance + ); + } + + function test_getInstalledModules_DifferentInstances() public whenEnvIsNotKernel { + address newValidator = address(new MockValidator()); + address newValidator1 = address(new MockValidator()); + vm.label(newValidator, "2nd validator"); + + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator, + data: "" + }); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator1, + data: "" + }); + + // Assert installed modules + this._getModulesAndAssert( + abi.encode( + 2, [newValidator, newValidator1], [MODULE_TYPE_VALIDATOR, MODULE_TYPE_VALIDATOR] + ), + instance + ); + + address newExecutor = address(new MockExecutor()); + instance.installModule({ moduleTypeId: MODULE_TYPE_EXECUTOR, module: newExecutor, data: "" }); + + // Assert installed modules + this._getModulesAndAssert( + abi.encode( + 3, + [newValidator, newValidator1, newExecutor], + [MODULE_TYPE_VALIDATOR, MODULE_TYPE_VALIDATOR, MODULE_TYPE_EXECUTOR] + ), + instance + ); + + // Deploy new instance using current env + AccountInstance memory newInstance = makeAccountInstance("newSalt"); + assertTrue(newInstance.account.code.length == 0); + newInstance.deployAccount(); + assertTrue(newInstance.account.code.length > 0); + + // Install modules on new instance + newInstance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator, + data: "" + }); + newInstance.installModule({ + moduleTypeId: MODULE_TYPE_EXECUTOR, + module: newExecutor, + data: "" + }); + + // Assert installed modules on new instance + this._getModulesAndAssert( + abi.encode( + 2, [newValidator, newExecutor], [MODULE_TYPE_VALIDATOR, MODULE_TYPE_EXECUTOR] + ), + newInstance + ); + + // Old instance modules should still be the same + this._getModulesAndAssert( + abi.encode( + 3, + [newValidator, newValidator1, newExecutor], + [MODULE_TYPE_VALIDATOR, MODULE_TYPE_VALIDATOR, MODULE_TYPE_EXECUTOR] + ), + instance + ); + } + + function test_getInstalledModules_AfterUninstall() public whenEnvIsNotKernel { + address newValidator = address(new MockValidator()); + address newValidator1 = address(new MockValidator()); + vm.label(newValidator, "2nd validator"); + + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator, + data: "" + }); + instance.installModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator1, + data: "" + }); + + // Assert installed modules + this._getModulesAndAssert( + abi.encode( + 2, [newValidator, newValidator1], [MODULE_TYPE_VALIDATOR, MODULE_TYPE_VALIDATOR] + ), + instance + ); + + address newExecutor = address(new MockExecutor()); + instance.installModule({ moduleTypeId: MODULE_TYPE_EXECUTOR, module: newExecutor, data: "" }); + + // Assert installed modules + this._getModulesAndAssert( + abi.encode( + 3, // length + [newValidator, newValidator1, newExecutor], // expectedAddresses + [MODULE_TYPE_VALIDATOR, MODULE_TYPE_VALIDATOR, MODULE_TYPE_EXECUTOR] // expectedTypes + ), + instance + ); + + // Uninstall module + instance.uninstallModule({ + moduleTypeId: MODULE_TYPE_VALIDATOR, + module: newValidator, + data: "" + }); + + // Assert installed modules + this._getModulesAndAssert( + abi.encode( + 2, [newValidator1, newExecutor], [MODULE_TYPE_VALIDATOR, MODULE_TYPE_EXECUTOR] + ), + instance + ); + } + function testRemoveValidator() public { address newValidator = address(new MockValidator()); instance.installModule({ @@ -398,7 +564,7 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { } /*////////////////////////////////////////////////////////////// - INTERNAL + HELPERS //////////////////////////////////////////////////////////////*/ function _usingAccountEnv(string memory env) internal usingAccountEnv(env.toAccountType()) { @@ -409,4 +575,49 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest { assertTrue(newInstance.account.code.length > 0); } + + function _getModulesAndAssert( + bytes calldata expectedResultBytes, + AccountInstance memory _instance + ) + public + view + { + InstalledModule[] memory modules = _instance.getInstalledModules(); + // Parse length + uint256 length = abi.decode(expectedResultBytes[0:32], (uint256)); + // Parse addresses and types + address[] memory expectedAddresses = new address[](length); + uint256[] memory expectedTypes = new uint256[](length); + for (uint256 i = 0; i < length; i++) { + expectedAddresses[i] = + abi.decode(expectedResultBytes[32 + i * 32:64 + i * 32], (address)); + expectedTypes[i] = abi.decode( + expectedResultBytes[32 + length * 32 + i * 32:64 + length * 32 + i * 32], (uint256) + ); + } + // Assert expected modules length + assertTrue( + modules.length == length + (instance.getAccountType() == AccountType.SAFE ? 1 : 0) + ); + // AccountType.SAFE has 1 extra module added during setup, skip it + uint256 index = instance.getAccountType() == AccountType.SAFE ? 1 : 0; + for (uint256 i = 0; i < length; i++) { + assertTrue(modules[index + i].moduleAddress == expectedAddresses[i]); + assertTrue(modules[index + i].moduleType == expectedTypes[i]); + } + } + + /*////////////////////////////////////////////////////////////// + MODIFIERS + //////////////////////////////////////////////////////////////*/ + + // Used to skip tests when env is kernel as they don't emit events on module installation + modifier whenEnvIsNotKernel() { + AccountType env = ModuleKitHelpers.getAccountType(); + if (env == AccountType.KERNEL) { + return; + } + _; + } }