diff --git a/contracts/BaseHook.sol b/contracts/BaseHook.sol index 01fc4954..8becad1a 100644 --- a/contracts/BaseHook.sol +++ b/contracts/BaseHook.sol @@ -72,22 +72,22 @@ abstract contract BaseHook is IHooks, SafeCallback { revert HookNotImplemented(); } - function beforeRemoveLiquidity( + function afterAddLiquidity( address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, bytes calldata - ) external virtual returns (bytes4) { + ) external virtual returns (bytes4, BalanceDelta) { revert HookNotImplemented(); } - function afterAddLiquidity( + function beforeRemoveLiquidity( address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, bytes calldata - ) external virtual returns (bytes4, BalanceDelta) { + ) external virtual returns (bytes4) { revert HookNotImplemented(); } diff --git a/contracts/interfaces/IBaseHook.sol b/contracts/interfaces/IBaseHook.sol new file mode 100644 index 00000000..7f404cf2 --- /dev/null +++ b/contracts/interfaces/IBaseHook.sol @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; + +interface IBaseHook is IHooks { + function getHookPermissions() external pure returns (Hooks.Permissions memory); +} diff --git a/contracts/interfaces/IBaseImplementation.sol b/contracts/interfaces/IBaseImplementation.sol new file mode 100644 index 00000000..bbf73419 --- /dev/null +++ b/contracts/interfaces/IBaseImplementation.sol @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IBaseHook} from "./IBaseHook.sol"; + +interface IBaseImplementation is IBaseHook { + function initializeMiddleware(address _middleware) external; +} diff --git a/contracts/interfaces/IMiddlewareFactory.sol b/contracts/interfaces/IMiddlewareFactory.sol new file mode 100644 index 00000000..4af554fe --- /dev/null +++ b/contracts/interfaces/IMiddlewareFactory.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +interface IMiddlewareFactory { + event MiddlewareCreated(address implementation, address middleware); + + /// @notice Returns the implementation address for a given middleware + /// @param middleware The middleware address + /// @return implementation The implementation address + function getImplementation(address middleware) external view returns (address implementation); + + /// @notice Creates a middleware for the given implementation + /// @param implementation The implementation address + /// @param salt The salt to use to deploy the middleware + /// @return middleware The address of the newly created middleware + function createMiddleware(address implementation, bytes32 salt) external returns (address middleware); +} diff --git a/contracts/middleware/BaseImplementation.sol b/contracts/middleware/BaseImplementation.sol new file mode 100644 index 00000000..3f5b034f --- /dev/null +++ b/contracts/middleware/BaseImplementation.sol @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {SafeCallback} from "./../base/SafeCallback.sol"; +import {ImmutableState} from "./../base/ImmutableState.sol"; + +abstract contract BaseImplementation is IHooks, SafeCallback { + error NotSelf(); + error InvalidPool(); + error LockFailure(); + error HookNotImplemented(); + error NotMiddleware(); + error NotMiddlewareFactory(); + + address public immutable middlewareFactory; + address public middleware; + + constructor(IPoolManager _manager, address _middlewareFactory) ImmutableState(_manager) { + middlewareFactory = _middlewareFactory; + } + + /// @dev Only this address may call this function + modifier selfOnly() { + if (msg.sender != address(this)) revert NotSelf(); + _; + } + + /// @dev Only pools with hooks set to this contract may call this function + modifier onlyValidPools(IHooks hooks) { + if (hooks != this) revert InvalidPool(); + _; + } + + modifier onlyByMiddleware() { + if (msg.sender != middleware) revert NotMiddleware(); + _; + } + + function initializeMiddleware(address _middleware) external { + if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); + middleware = _middleware; + } + + function updateDynamicFee(PoolKey calldata key, uint24 fee) external { + if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); + manager.updateDynamicLPFee(key, fee); + } + + function getHookPermissions() public pure virtual returns (Hooks.Permissions memory); + + function _unlockCallback(bytes calldata data) internal virtual override returns (bytes memory) { + (bool success, bytes memory returnData) = address(this).call(data); + if (success) return returnData; + if (returnData.length == 0) revert LockFailure(); + // if the call failed, bubble up the reason + /// @solidity memory-safe-assembly + assembly { + revert(add(returnData, 32), mload(returnData)) + } + } + + function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { + revert HookNotImplemented(); + } + + function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external virtual returns (bytes4) { + revert HookNotImplemented(); + } + + function afterAddLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + revert HookNotImplemented(); + } + + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + revert HookNotImplemented(); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + virtual + returns (bytes4, BeforeSwapDelta, uint24) + { + revert HookNotImplemented(); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + virtual + returns (bytes4, int128) + { + revert HookNotImplemented(); + } + + function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } +} diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol new file mode 100644 index 00000000..1dc064f4 --- /dev/null +++ b/contracts/middleware/BaseMiddleware.sol @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {IBaseHook} from "./../interfaces/IBaseHook.sol"; +import {BaseHook} from "./../BaseHook.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; + +contract BaseMiddleware is IHooks { + using Hooks for BaseMiddleware; + using BeforeSwapDeltaLibrary for BeforeSwapDelta; + + error NotManager(); + + IPoolManager public immutable manager; + IBaseHook public immutable implementation; + + constructor(IPoolManager _manager, IBaseHook _implementation) { + manager = _manager; + implementation = _implementation; + } + + modifier onlyByManager() { + if (msg.sender != address(manager)) revert NotManager(); + _; + } + + function getHookPermissions() public view returns (Hooks.Permissions memory) { + return implementation.getHookPermissions(); + } + + function beforeInitialize(address sender, PoolKey calldata key, uint160 sqrtPriceX96, bytes calldata hookData) + external + virtual + onlyByManager + returns (bytes4) + { + if (msg.sender == address(implementation)) return BaseHook.beforeInitialize.selector; + return implementation.beforeInitialize(sender, key, sqrtPriceX96, hookData); + } + + function afterInitialize( + address sender, + PoolKey calldata key, + uint160 sqrtPriceX96, + int24 tick, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.afterInitialize.selector; + return implementation.afterInitialize(sender, key, sqrtPriceX96, tick, hookData); + } + + function beforeAddLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.beforeAddLiquidity.selector; + return implementation.beforeAddLiquidity(sender, key, params, hookData); + } + + function afterAddLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4, BalanceDelta) { + if (msg.sender == address(implementation)) { + return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + return implementation.afterAddLiquidity(sender, key, params, delta, hookData); + } + + function beforeRemoveLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.beforeRemoveLiquidity.selector; + return implementation.beforeRemoveLiquidity(sender, key, params, hookData); + } + + function afterRemoveLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4, BalanceDelta) { + if (msg.sender == address(implementation)) { + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + return implementation.afterRemoveLiquidity(sender, key, params, delta, hookData); + } + + function beforeSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4 selector, BeforeSwapDelta beforeSwapDelta, uint24 lpFeeOverride) { + if (msg.sender == address(implementation)) { + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + (selector, beforeSwapDelta, lpFeeOverride) = implementation.beforeSwap(sender, key, params, hookData); + if (this.hasPermission(Hooks.BEFORE_SWAP_RETURNS_DELTA_FLAG)) { + manager.take(key.currency0, sender, uint256(uint128(beforeSwapDelta.getSpecifiedDelta()))); + manager.take(key.currency1, sender, uint256(uint128(beforeSwapDelta.getUnspecifiedDelta()))); + } + } + + function afterSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4, int128) { + if (msg.sender == address(implementation)) { + return (BaseHook.afterSwap.selector, 0); + } + return implementation.afterSwap(sender, key, params, delta, hookData); + } + + function beforeDonate( + address sender, + PoolKey calldata key, + uint256 amount0, + uint256 amount1, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.beforeDonate.selector; + return implementation.beforeDonate(sender, key, amount0, amount1, hookData); + } + + function afterDonate( + address sender, + PoolKey calldata key, + uint256 amount0, + uint256 amount1, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.afterDonate.selector; + return implementation.afterDonate(sender, key, amount0, amount1, hookData); + } +} diff --git a/contracts/middleware/BaseMiddlewareFactory.sol b/contracts/middleware/BaseMiddlewareFactory.sol new file mode 100644 index 00000000..2b549a6a --- /dev/null +++ b/contracts/middleware/BaseMiddlewareFactory.sol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; +import {BaseMiddleware} from "./BaseMiddleware.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IBaseHook} from "../interfaces/IBaseHook.sol"; +import {IBaseImplementation} from "../interfaces/IBaseImplementation.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract BaseMiddlewareFactory is IMiddlewareFactory { + error AlreadyInitialized(); + + mapping(address => address) private _implementations; + mapping(address => bool) private _initialized; + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getImplementation(address middleware) external view override returns (address implementation) { + return _implementations[middleware]; + } + + function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { + if (_initialized[implementation]) revert AlreadyInitialized(); + _initialized[implementation] = true; + middleware = _deployMiddleware(implementation, salt); + Hooks.validateHookPermissions(IHooks(middleware), IBaseImplementation(implementation).getHookPermissions()); + IBaseImplementation(implementation).initializeMiddleware(middleware); + _implementations[middleware] = implementation; + emit MiddlewareCreated(implementation, middleware); + } + + function _deployMiddleware(address implementation, bytes32 salt) internal virtual returns (address middleware) { + return address(new BaseMiddleware{salt: salt}(poolManager, IBaseHook(implementation))); + } +} diff --git a/lib/forge-gas-snapshot b/lib/forge-gas-snapshot index 2f884282..9161f7c0 160000 --- a/lib/forge-gas-snapshot +++ b/lib/forge-gas-snapshot @@ -1 +1 @@ -Subproject commit 2f884282b4cd067298e798974f5b534288b13bc2 +Subproject commit 9161f7c0b6c6788a89081e2b3b9c67592b71e689 diff --git a/lib/forge-std b/lib/forge-std index 2b58ecbc..75b3fcf0 160000 --- a/lib/forge-std +++ b/lib/forge-std @@ -1 +1 @@ -Subproject commit 2b58ecbcf3dfde7a75959dc7b4eb3d0670278de6 +Subproject commit 75b3fcf052cc7886327e4c2eac3d1a1f36942b41 diff --git a/test/BaseMiddlewareFactory.t.sol b/test/BaseMiddlewareFactory.t.sol new file mode 100644 index 00000000..3652f84c --- /dev/null +++ b/test/BaseMiddlewareFactory.t.sol @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {BaseMiddleware} from "../contracts/middleware/BaseMiddleware.sol"; +import {BaseMiddlewareFactory} from "./../contracts/middleware/BaseMiddlewareFactory.sol"; +import {HookMiner} from "./utils/HookMiner.sol"; +import {HooksCounter} from "./middleware-implementations/HooksCounter.sol"; +import {BaseImplementation} from "./../contracts/middleware/BaseImplementation.sol"; + +contract BaseMiddlewareFactoryTest is Test, Deployers { + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + + BaseMiddlewareFactory factory; + HooksCounter hooksCounter; + + address middleware; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + factory = new BaseMiddlewareFactory(manager); + hooksCounter = new HooksCounter(manager, address(factory)); + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter)) + ); + middleware = factory.createMiddleware(address(hooksCounter), salt); + assertEq(hookAddress, middleware); + } + + function testRevertOnAlreadyInitialized() public { + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter)) + ); + vm.expectRevert(BaseMiddlewareFactory.AlreadyInitialized.selector); + factory.createMiddleware(address(hooksCounter), salt); + } + + function testRevertOnIncorrectFlags() public { + HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory)); + uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter2)) + ); + address implementation = address(hooksCounter2); + vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); + factory.createMiddleware(implementation, salt); + } + + function testRevertOnIncorrectFlagsMined() public { + HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory)); + address implementation = address(hooksCounter2); + vm.expectRevert(); // HookAddressNotValid + factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); + } + + function testRevertOnIncorrectCaller() public { + vm.expectRevert(BaseImplementation.NotMiddleware.selector); + hooksCounter.afterDonate(address(this), key, 0, 0, ZERO_BYTES); + } + + function testCounters() public { + (PoolKey memory key, PoolId id) = + initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + + assertEq(hooksCounter.beforeInitializeCount(id), 1); + assertEq(hooksCounter.afterInitializeCount(id), 1); + assertEq(hooksCounter.beforeSwapCount(id), 0); + assertEq(hooksCounter.afterSwapCount(id), 0); + assertEq(hooksCounter.beforeAddLiquidityCount(id), 1); + assertEq(hooksCounter.afterAddLiquidityCount(id), 1); + assertEq(hooksCounter.beforeRemoveLiquidityCount(id), 0); + assertEq(hooksCounter.afterRemoveLiquidityCount(id), 0); + assertEq(hooksCounter.beforeDonateCount(id), 0); + assertEq(hooksCounter.afterDonateCount(id), 0); + + assertEq(hooksCounter.lastHookData(), ZERO_BYTES); + swap(key, true, 1, bytes("hi")); + assertEq(hooksCounter.lastHookData(), bytes("hi")); + assertEq(hooksCounter.beforeSwapCount(id), 1); + assertEq(hooksCounter.afterSwapCount(id), 1); + } +} diff --git a/test/middleware-implementations/HooksCounter.sol b/test/middleware-implementations/HooksCounter.sol new file mode 100644 index 00000000..49aeb50a --- /dev/null +++ b/test/middleware-implementations/HooksCounter.sol @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; + +contract HooksCounter is BaseImplementation { + using PoolIdLibrary for PoolKey; + + mapping(PoolId => uint256) public beforeInitializeCount; + mapping(PoolId => uint256) public afterInitializeCount; + + mapping(PoolId => uint256) public beforeSwapCount; + mapping(PoolId => uint256) public afterSwapCount; + + mapping(PoolId => uint256) public beforeAddLiquidityCount; + mapping(PoolId => uint256) public afterAddLiquidityCount; + mapping(PoolId => uint256) public beforeRemoveLiquidityCount; + mapping(PoolId => uint256) public afterRemoveLiquidityCount; + + mapping(PoolId => uint256) public beforeDonateCount; + mapping(PoolId => uint256) public afterDonateCount; + + bytes public lastHookData; + + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} + + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: true, + afterInitialize: true, + beforeAddLiquidity: true, + afterAddLiquidity: true, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: true, + beforeSwap: true, + afterSwap: true, + beforeDonate: true, + afterDonate: true, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeInitialize(address, PoolKey calldata key, uint160, bytes calldata hookData) + external + override + onlyByMiddleware + returns (bytes4) + { + beforeInitializeCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeInitialize.selector; + } + + function afterInitialize(address, PoolKey calldata key, uint160, int24, bytes calldata hookData) + external + override + onlyByMiddleware + returns (bytes4) + { + afterInitializeCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.afterInitialize.selector; + } + + function beforeAddLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4) { + beforeAddLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeAddLiquidity.selector; + } + + function afterAddLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, BalanceDelta) { + afterAddLiquidityCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4) { + beforeRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeRemoveLiquidity.selector; + } + + function afterRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, BalanceDelta) { + afterRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata hookData) + external + override + onlyByMiddleware + returns (bytes4, BeforeSwapDelta, uint24) + { + beforeSwapCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap( + address, + PoolKey calldata key, + IPoolManager.SwapParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, int128) { + afterSwapCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.afterSwap.selector, 0); + } + + function beforeDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) + external + override + onlyByMiddleware + returns (bytes4) + { + beforeDonateCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeDonate.selector; + } + + function afterDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) + external + override + onlyByMiddleware + returns (bytes4) + { + afterDonateCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.afterDonate.selector; + } +} diff --git a/test/utils/HookMiner.sol b/test/utils/HookMiner.sol new file mode 100644 index 00000000..d6b30c40 --- /dev/null +++ b/test/utils/HookMiner.sol @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.21; + +/// @title HookMiner - a library for mining hook addresses +/// @dev This library is intended for `forge test` environments. There may be gotchas when using salts in `forge script` or `forge create` +library HookMiner { + // mask to slice out the bottom 14 bit of the address + uint160 constant FLAG_MASK = 0x3FFF; + + // Maximum number of iterations to find a salt, avoid infinite loops + uint256 constant MAX_LOOP = 100_000; + + /// @notice Find a salt that produces a hook address with the desired `flags` + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param flags The desired flags for the hook address + /// @param creationCode The creation code of a hook contract. Example: `type(Counter).creationCode` + /// @param constructorArgs The encoded constructor arguments of a hook contract. Example: `abi.encode(address(manager))` + /// @return hookAddress salt and corresponding address that was found. The salt can be used in `new Hook{salt: salt}()` + function find(address deployer, uint160 flags, bytes memory creationCode, bytes memory constructorArgs) + internal + view + returns (address, bytes32) + { + address hookAddress; + bytes memory creationCodeWithArgs = abi.encodePacked(creationCode, constructorArgs); + + uint256 salt; + for (salt; salt < MAX_LOOP; salt++) { + hookAddress = computeAddress(deployer, salt, creationCodeWithArgs); + if (uint160(hookAddress) & FLAG_MASK == flags && hookAddress.code.length == 0) { + return (hookAddress, bytes32(salt)); + } + } + revert("HookMiner: could not find salt"); + } + + /// @notice Precompute a contract address deployed via CREATE2 + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param salt The salt used to deploy the hook + /// @param creationCode The creation code of a hook contract + function computeAddress(address deployer, uint256 salt, bytes memory creationCode) + internal + pure + returns (address hookAddress) + { + return address( + uint160(uint256(keccak256(abi.encodePacked(bytes1(0xFF), deployer, salt, keccak256(creationCode))))) + ); + } +}