Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: track installed modules #138

Merged
merged 16 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/ModuleKit.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
pragma solidity ^0.8.23;

/* solhint-disable no-unused-import */
import { UserOpData, AccountInstance, RhinestoneModuleKit } from "./test/RhinestoneModuleKit.sol";
import {
UserOpData,
AccountInstance,
RhinestoneModuleKit,
AccountType
} from "./test/RhinestoneModuleKit.sol";
import { ModuleKitHelpers } from "./test/ModuleKitHelpers.sol";
import { ModuleKitUserOp } from "./test/ModuleKitUserOp.sol";
import { PackedUserOperation } from "./external/ERC4337.sol";
Expand Down
62 changes: 61 additions & 1 deletion src/test/ModuleKitHelpers.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@ import {
writeAccountEnv,
getFactory,
getHelper as getHelperFromStorage,
getAccountEnv as getAccountEnvFromStorage
getAccountEnv as getAccountEnvFromStorage,
getInstalledModules as getInstalledModulesFromStorage,
writeInstalledModule as writeInstalledModuleToStorage,
removeInstalledModule as removeInstalledModuleFromStorage,
InstalledModule
} from "./utils/Storage.sol";
import { recordLogs, VmSafe, getRecordedLogs } from "./utils/Vm.sol";

library ModuleKitHelpers {
/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -262,6 +267,46 @@ library ModuleKitHelpers {
userOpData.entrypoint = instance.aux.entrypoint;
}

function getInstalledModules(
AccountInstance memory instance
)
internal
view
returns (InstalledModule[] memory)
{
return getInstalledModulesFromStorage(instance.account);
}

function writeInstalledModule(
AccountInstance memory instance,
InstalledModule memory module
)
internal
{
writeInstalledModuleToStorage(module, instance.account);
}

function removeInstalledModule(
AccountInstance memory instance,
uint256 moduleType,
address moduleAddress
)
internal
{
// Get installed modules for account
InstalledModule[] memory installedModules = getInstalledModules(instance);
// Find module to remove (not super scalable at high module counts)
for (uint256 i; i < installedModules.length; i++) {
if (
installedModules[i].moduleType == moduleType
&& installedModules[i].moduleAddress == moduleAddress
) {
// Remove module from storage
removeInstalledModuleFromStorage(i, instance.account);
return;
}
}
}
/*//////////////////////////////////////////////////////////////////////////
CONTROL FLOW
//////////////////////////////////////////////////////////////////////////*/
Expand Down Expand Up @@ -332,7 +377,22 @@ library ModuleKitHelpers {
}

function deployAccount(AccountInstance memory instance) internal {
// Record logs to track installed modules
recordLogs();
// Deploy account
HelperBase(instance.accountHelper).deployAccount(instance);
// Parse logs and determine if a module was installed
VmSafe.Log[] memory logs = getRecordedLogs();
for (uint256 i; i < logs.length; i++) {
// ModuleInstalled(uint256, address)
if (
logs[i].topics[0]
== 0xd21d0b289f126c4b473ea641963e766833c2f13866e4ff480abd787c100ef123
) {
(uint256 moduleType, address module) = abi.decode(logs[i].data, (uint256, address));
writeInstalledModuleToStorage(InstalledModule(moduleType, module), logs[i].emitter);
}
}
}

function setAccountType(AccountInstance memory, AccountType env) internal {
Expand Down
33 changes: 32 additions & 1 deletion src/test/utils/ERC4337Helpers.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ import {
getExpectRevert,
writeExpectRevert,
getGasIdentifier,
writeGasIdentifier
writeGasIdentifier,
writeInstalledModule,
getInstalledModules,
removeInstalledModule,
InstalledModule
} from "./Storage.sol";

library ERC4337Helpers {
Expand Down Expand Up @@ -75,6 +79,33 @@ library ERC4337Helpers {
}
}
}
// ModuleInstalled(uint256, address)
else if (
logs[i].topics[0]
== 0xd21d0b289f126c4b473ea641963e766833c2f13866e4ff480abd787c100ef123
) {
(uint256 moduleType, address module) = abi.decode(logs[i].data, (uint256, address));
writeInstalledModule(InstalledModule(moduleType, module), logs[i].emitter);
}
// ModuleUninstalled(uint256, address)
else if (
logs[i].topics[0]
== 0x341347516a9de374859dfda710fa4828b2d48cb57d4fbe4c1149612b8e02276e
) {
(uint256 moduleType, address module) = abi.decode(logs[i].data, (uint256, address));
// Get all installed modules
InstalledModule[] memory installedModules = getInstalledModules(logs[i].emitter);
// Remove the uninstalled module from the list of installed modules
for (uint256 j; j < installedModules.length; j++) {
if (
installedModules[j].moduleAddress == module
&& installedModules[j].moduleType == moduleType
) {
removeInstalledModule(j, logs[i].emitter);
break;
}
}
highskore marked this conversation as resolved.
Show resolved Hide resolved
}
}
isExpectRevert = getExpectRevert();
if (isExpectRevert != 0) {
Expand Down
180 changes: 180 additions & 0 deletions src/test/utils/Storage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,186 @@ function getHelper(string memory helperType) view returns (address helper) {
}
}

/*//////////////////////////////////////////////////////////////
INSTALLED MODULE
//////////////////////////////////////////////////////////////*/

struct InstalledModule {
uint256 moduleType;
address moduleAddress;
}

// Adds new address to the installed module linked list for the given account
// The list is stored in storage as a linked list in the following format:
highskore marked this conversation as resolved.
Show resolved Hide resolved
// ---------------------------------------------------------------------
// | Slot | Value |
// |----------------------------------------------------------|--------|
// | keccak256(abi.encode("ModuleKit.InstalledModuleSlot.")); | length |
// | keccak256(abi.encode("ModuleKit.InstalledModuleHead.")); | head |
// | keccak256(abi.encode("ModuleKit.InstalledModuleTail.")); | tail |
// | keccak256(abi.encode(lengthSlot)) - initially X | element|
// ---------------------------------------------------------------------
//
// The elements are stored in the following way:
// --------------------------
// | Slot | Value |
// |------------------------|
// | X | moduleType |
// | X + 0x20 | moduleAddr |
// | X + 0x40 | prev |
// | X + 0x60 | next |
// --------------------------
function writeInstalledModule(InstalledModule memory module, address account) {
bytes32 lengthSlot = keccak256(
abi.encode("ModuleKit.InstalledModuleSlot.", keccak256(abi.encodePacked(account)))
);
bytes32 headSlot = keccak256(
abi.encode("ModuleKit.InstalledModuleHead.", keccak256(abi.encodePacked(account)))
);
bytes32 tailSlot =
keccak256(abi.encode("ModuleKit.InstalledModuleTail", keccak256(abi.encodePacked(account))));
bytes32 elementSlot = keccak256(abi.encode(lengthSlot));
uint256 moduleType = module.moduleType;
address moduleAddress = module.moduleAddress;
assembly {
// Get the length of the array
let length := sload(lengthSlot)
let nextSlot
let oldTail
switch iszero(length)
case 1 {
// If length is zero, set element slot to head and tail
sstore(headSlot, elementSlot)
sstore(tailSlot, elementSlot)
oldTail := elementSlot
nextSlot := elementSlot
}
default {
oldTail := sload(tailSlot)
// Set the new elemeont slot to the old tail + 0x80
elementSlot := add(oldTail, 0x80)
// Set the old tail next slot to the new element slot
sstore(add(oldTail, 0x60), elementSlot)
// Update tailSlot to point to the new element slot
sstore(tailSlot, elementSlot)
// Set nextSlot to the head slot
nextSlot := sload(headSlot)
}
// Update the length of the list
sstore(lengthSlot, add(length, 1))
// Store the module type and address in the new slot
sstore(elementSlot, moduleType)
sstore(add(elementSlot, 0x20), moduleAddress)
// Store the old tail as the prev slot
sstore(add(elementSlot, 0x40), oldTail)
// Store the head as the next slot
sstore(add(elementSlot, 0x60), nextSlot)
}
}

// Removes a specific installed module
function removeInstalledModule(uint256 index, address account) {
bytes32 lengthSlot = keccak256(
abi.encode("ModuleKit.InstalledModuleSlot.", keccak256(abi.encodePacked(account)))
);
bytes32 headSlot = keccak256(
abi.encode("ModuleKit.InstalledModuleHead.", keccak256(abi.encodePacked(account)))
);
bytes32 tailSlot = keccak256(
abi.encode("ModuleKit.InstalledModuleTail.", keccak256(abi.encodePacked(account)))
);
assembly {
// Get the length of the list
let length := sload(lengthSlot)
// Get the initial element slot
let elementSlot := sload(headSlot)
// Ensure the index is within bounds
if lt(index, length) {
// Traverse to the node to remove
for { let i := 0 } lt(i, index) { i := add(i, 1) } {
elementSlot := sload(add(elementSlot, 0x60))
}

// Get the previous and next slots
let prevSlot := sload(add(elementSlot, 0x40))
let nextSlot := sload(add(elementSlot, 0x60))

// Update the previous slot's next pointer
sstore(add(prevSlot, 0x60), nextSlot)
// Update the next slot's previous pointer
sstore(add(nextSlot, 0x40), prevSlot)

// Handle removing the head
if eq(elementSlot, sload(headSlot)) { sstore(headSlot, nextSlot) }

// Handle removing the tail
if eq(elementSlot, sload(tailSlot)) { sstore(tailSlot, prevSlot) }

// Clear the removed node
sstore(elementSlot, 0)
sstore(add(elementSlot, 0x20), 0)
sstore(add(elementSlot, 0x40), 0)
sstore(add(elementSlot, 0x60), 0)

// Update the length of the list
sstore(lengthSlot, sub(length, 1))
}
}
}

// Returns all installed modules for the given account
function getInstalledModules(address account) view returns (InstalledModule[] memory modules) {
bytes32 lengthSlot = keccak256(
abi.encode("ModuleKit.InstalledModuleSlot.", keccak256(abi.encodePacked(account)))
);
bytes32 headSlot = keccak256(
abi.encode("ModuleKit.InstalledModuleHead.", keccak256(abi.encodePacked(account)))
);
assembly {
// Get the length of the array from storage
let length := sload(lengthSlot)

// Each struct is 64 bytes (32 bytes for moduleType and 32 bytes for moduleAddress)
let structSize := 0x40 // 64 bytes
let size := mul(length, structSize) // Total size for structs
let totalSize := add(add(size, 0x40), mul(0x20, length))

// Allocate memory for the array
let freeMemoryPtr := mload(0x40)
modules := freeMemoryPtr

// Store the length of the array in the first 32 bytes of memory
mstore(modules, length)

// Update the free memory pointer to the end of the allocated memory
mstore(0x40, add(freeMemoryPtr, totalSize))

// Get the head of the linked list
let storageLocation := sload(headSlot)

// Copy the structs from storage to memory
for { let i := 0 } lt(i, length) { i := add(i, 1) } {
// Calculate memory location for this struct
let structLocation :=
add(add(freeMemoryPtr, add(0x40, mul(i, structSize))), mul(0x20, length))

// Load the moduleType and moduleAddress from storage
let moduleType := sload(storageLocation)
let moduleAddress := sload(add(storageLocation, 0x20))

// Store the structLocation into memory
mstore(add(freeMemoryPtr, add(0x20, mul(i, 0x20))), structLocation)

// Store the moduleType and moduleAddress into memory
mstore(structLocation, moduleType)
mstore(add(structLocation, 0x20), moduleAddress)

// Move to the next element in the linked list
storageLocation := sload(add(storageLocation, 0x60))
}
}
}

/*//////////////////////////////////////////////////////////////
STRING STORAGE
//////////////////////////////////////////////////////////////*/
Expand Down
Loading
Loading