From 0bc44d3ea2c6883ec72c3834f38cacc2c0db843a Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Fri, 5 Jul 2024 14:06:03 -0400 Subject: [PATCH 1/8] create basemiddleware --- contracts/BaseHook.sol | 10 +- contracts/interfaces/IBaseHook.sol | 9 ++ contracts/middleware/BaseImplementation.sol | 143 +++++++++++++++++++ contracts/middleware/BaseMiddleware.sol | 144 ++++++++++++++++++++ contracts/middleware/Forwarder.sol | 44 ++++++ contracts/middleware/MiddlewareProtect.sol | 96 +++++++++++++ lib/forge-gas-snapshot | 2 +- lib/forge-std | 2 +- 8 files changed, 443 insertions(+), 7 deletions(-) create mode 100644 contracts/interfaces/IBaseHook.sol create mode 100644 contracts/middleware/BaseImplementation.sol create mode 100644 contracts/middleware/BaseMiddleware.sol create mode 100644 contracts/middleware/Forwarder.sol create mode 100644 contracts/middleware/MiddlewareProtect.sol diff --git a/contracts/BaseHook.sol b/contracts/BaseHook.sol index 01fc4954..8becad1a 100644 --- a/contracts/BaseHook.sol +++ b/contracts/BaseHook.sol @@ -72,22 +72,22 @@ abstract contract BaseHook is IHooks, SafeCallback { revert HookNotImplemented(); } - function beforeRemoveLiquidity( + function afterAddLiquidity( address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, bytes calldata - ) external virtual returns (bytes4) { + ) external virtual returns (bytes4, BalanceDelta) { revert HookNotImplemented(); } - function afterAddLiquidity( + function beforeRemoveLiquidity( address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, bytes calldata - ) external virtual returns (bytes4, BalanceDelta) { + ) external virtual returns (bytes4) { revert HookNotImplemented(); } diff --git a/contracts/interfaces/IBaseHook.sol b/contracts/interfaces/IBaseHook.sol new file mode 100644 index 00000000..7f404cf2 --- /dev/null +++ b/contracts/interfaces/IBaseHook.sol @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; + +interface IBaseHook is IHooks { + function getHookPermissions() external pure returns (Hooks.Permissions memory); +} diff --git a/contracts/middleware/BaseImplementation.sol b/contracts/middleware/BaseImplementation.sol new file mode 100644 index 00000000..1dfee1a8 --- /dev/null +++ b/contracts/middleware/BaseImplementation.sol @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {SafeCallback} from "./../base/SafeCallback.sol"; +import {ImmutableState} from "./../base/ImmutableState.sol"; + +abstract contract BaseImplementation is IHooks, SafeCallback { + error NotSelf(); + error InvalidPool(); + error LockFailure(); + error HookNotImplemented(); + error NotMiddleware(); + error NotMiddlewareFactory(); + + address public immutable middlewareFactory; + address public middleware; + + constructor(IPoolManager _manager, address _middlewareFactory) ImmutableState(_manager) { + middlewareFactory = _middlewareFactory; + } + + /// @dev Only this address may call this function + modifier selfOnly() { + if (msg.sender != address(this)) revert NotSelf(); + _; + } + + /// @dev Only pools with hooks set to this contract may call this function + modifier onlyValidPools(IHooks hooks) { + if (hooks != this) revert InvalidPool(); + _; + } + + modifier onlyByMiddleware() { + if (msg.sender != middleware) revert NotMiddleware(); + _; + } + + function initializeMiddleware(address _middleware) external { + if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); + middleware = _middleware; + } + + function getHookPermissions() public pure virtual returns (Hooks.Permissions memory); + + function _unlockCallback(bytes calldata data) internal virtual override returns (bytes memory) { + (bool success, bytes memory returnData) = address(this).call(data); + if (success) return returnData; + if (returnData.length == 0) revert LockFailure(); + // if the call failed, bubble up the reason + /// @solidity memory-safe-assembly + assembly { + revert(add(returnData, 32), mload(returnData)) + } + } + + function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { + revert HookNotImplemented(); + } + + function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external virtual returns (bytes4) { + revert HookNotImplemented(); + } + + function afterAddLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + revert HookNotImplemented(); + } + + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + revert HookNotImplemented(); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + virtual + returns (bytes4, BeforeSwapDelta, uint24) + { + revert HookNotImplemented(); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + virtual + returns (bytes4, int128) + { + revert HookNotImplemented(); + } + + function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } +} diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol new file mode 100644 index 00000000..9a68ddf5 --- /dev/null +++ b/contracts/middleware/BaseMiddleware.sol @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {IBaseHook} from "./../interfaces/IBaseHook.sol"; +import {BaseHook} from "./../BaseHook.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; + +abstract contract BaseMiddleware is IHooks { + error NotManager(); + + IPoolManager public immutable manager; + IBaseHook public immutable implementation; + + constructor(IPoolManager _manager, IBaseHook _implementation) { + manager = _manager; + implementation = _implementation; + } + + modifier onlyByManager() { + if (msg.sender != address(manager)) revert NotManager(); + _; + } + + function getHookPermissions() public view returns (Hooks.Permissions memory) { + return implementation.getHookPermissions(); + } + + function beforeInitialize(address sender, PoolKey calldata key, uint160 sqrtPriceX96, bytes calldata hookData) + external + virtual + onlyByManager + returns (bytes4) + { + if (msg.sender == address(implementation)) return BaseHook.beforeInitialize.selector; + return implementation.beforeInitialize(sender, key, sqrtPriceX96, hookData); + } + + function afterInitialize( + address sender, + PoolKey calldata key, + uint160 sqrtPriceX96, + int24 tick, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.afterInitialize.selector; + return implementation.afterInitialize(sender, key, sqrtPriceX96, tick, hookData); + } + + function beforeAddLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.beforeAddLiquidity.selector; + return implementation.beforeAddLiquidity(sender, key, params, hookData); + } + + function afterAddLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4, BalanceDelta) { + if (msg.sender == address(implementation)) { + return (BaseHook.afterAddLiquidity.selector, BalanceDelta.ZERO_DELTA); + } + return implementation.afterAddLiquidity(sender, key, params, delta, hookData); + } + + function beforeRemoveLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.beforeRemoveLiquidity.selector; + return implementation.beforeRemoveLiquidity(sender, key, params, hookData); + } + + function afterRemoveLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4, BalanceDelta) { + if (msg.sender == address(implementation)) { + return (BaseHook.afterRemoveLiquidity.selector, BalanceDelta.ZERO_DELTA); + } + return implementation.afterRemoveLiquidity(sender, key, params, delta, hookData); + } + + function beforeSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4, BeforeSwapDelta, uint24) { + if (msg.sender == address(implementation)) return (BaseHook.beforeSwap.selector, BeforeSwapDelta.ZERO_DELTA, 0); + return implementation.beforeSwap(sender, key, params, hookData); + } + + function afterSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4, int128) { + if (msg.sender == address(implementation)) { + return (BaseHook.afterSwap.selector, 0); + } + return implementation.afterSwap(sender, key, params, delta, hookData); + } + + function beforeDonate( + address sender, + PoolKey calldata key, + uint256 amount0, + uint256 amount1, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.beforeDonate.selector; + return implementation.beforeDonate(sender, key, amount0, amount1, hookData); + } + + function afterDonate( + address sender, + PoolKey calldata key, + uint256 amount0, + uint256 amount1, + bytes calldata hookData + ) external virtual onlyByManager returns (bytes4) { + if (msg.sender == address(implementation)) return BaseHook.afterDonate.selector; + return implementation.afterDonate(sender, key, amount0, amount1, hookData); + } +} diff --git a/contracts/middleware/Forwarder.sol b/contracts/middleware/Forwarder.sol new file mode 100644 index 00000000..31afb012 --- /dev/null +++ b/contracts/middleware/Forwarder.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: UNLICENSED +// openzeppelin proxy but call instead of delegatecall + +pragma solidity ^0.8.24; + +/** + * @dev This abstract contract provides a fallback function that forwards all calls to another contract + */ +abstract contract Forwarder { + function _forward(address implementation) internal virtual { + assembly { + // Copy msg.data. We take full control of memory in this inline assembly + // block because it will not return to Solidity code. We overwrite the + // Solidity scratch pad at memory position 0. + calldatacopy(0, 0, calldatasize()) + + // Call the implementation. + // out and outsize are 0 because we don't know the size yet. + let result := call(gas(), implementation, 0, 0, calldatasize(), 0, 0) + + // Copy the returned data. + returndatacopy(0, 0, returndatasize()) + + switch result + // call returns 0 on error. + case 0 { revert(0, returndatasize()) } + default { return(0, returndatasize()) } + } + } + + /** + * @dev This is a virtual function that should be overridden so it returns the address to which the fallback function + * and {_fallback} should delegate. + */ + function _implementation() internal view virtual returns (address); + + /** + * @dev Fallback function that delegates calls to the address returned by `_implementation()`. Will run if no other + * function in the contract matches the call data. + */ + fallback() external payable virtual { + _forward(_implementation()); + } +} diff --git a/contracts/middleware/MiddlewareProtect.sol b/contracts/middleware/MiddlewareProtect.sol new file mode 100644 index 00000000..8c828c5e --- /dev/null +++ b/contracts/middleware/MiddlewareProtect.sol @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BaseMiddleware} from "./BaseMiddleware.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {console} from "../../lib/forge-std/src/console.sol"; +import {BaseHook} from "./../BaseHook.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract MiddlewareProtect is BaseMiddleware { + using StateLibrary for IPoolManager; + using Hooks for IHooks; + + /// @notice Thrown if the address will lead to forbidden flags being set + /// @param hooks The address of the hooks contract + error HookPermissionForbidden(address hooks); + error ForbiddenReturn(); + + uint256 public constant gasLimit = 1000000; + + constructor(IPoolManager _poolManager, IHooks _implementation) BaseMiddleware(_poolManager, _implementation) { + // deny any hooks that return deltas + if ( + _implementation.hasPermission(Hooks.BEFORE_SWAP_RETURNS_DELTA_FLAG) + || _implementation.hasPermission(Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG) + || _implementation.hasPermission(Hooks.AFTER_ADD_LIQUIDITY_RETURNS_DELTA_FLAG) + || _implementation.hasPermission(Hooks.AFTER_REMOVE_LIQUIDITY_RETURNS_DELTA_FLAG) + ) { + HookPermissionForbidden.selector.revertWith(address(this)); + } + } + + // block swaps and removes + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + swapNotLocked + returns (bytes4, BeforeSwapDelta, uint24) + { + ReentrancyState.lockSwapRemove(); + console.log("beforeSwap middleware"); + (bytes4 selector, BeforeSwapDelta beforeSwapDelta, uint24 lpFeeOverride) = + implementation.beforeSwap(sender, key, params, hookData); + if (lpFeeOverride != 0) { + revert ForbiddenReturn(); + } + return (selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + // afterSwap - no protections + + // block swaps + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + returns (bytes4) + { + ReentrancyState.lockSwap(); + console.log("beforeAddLiquidity middleware"); + selector = implementation.beforeSwap(sender, key, params, hookData); + ReentrancyState.unlock(); + return selector; + } + + // afterAddLiquidity - no protections + + // block swaps and reverts + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external removeNotLocked returns (bytes4) { + ReentrancyState.lockSwap(); + console.log("beforeRemoveLiquidity middleware"); + implementation.call{gas: gasLimit}(msg.data); + ReentrancyState.unlock(); + return BaseHook.beforeRemoveLiquidity.selector; + } + + // block reverts + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external returns (bytes4, BalanceDelta) { + console.log("afterRemoveLiquidity middleware"); + implementation.delegatecall{gas: gasLimit}(msg.data); + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } +} diff --git a/lib/forge-gas-snapshot b/lib/forge-gas-snapshot index 2f884282..9161f7c0 160000 --- a/lib/forge-gas-snapshot +++ b/lib/forge-gas-snapshot @@ -1 +1 @@ -Subproject commit 2f884282b4cd067298e798974f5b534288b13bc2 +Subproject commit 9161f7c0b6c6788a89081e2b3b9c67592b71e689 diff --git a/lib/forge-std b/lib/forge-std index 2b58ecbc..75b3fcf0 160000 --- a/lib/forge-std +++ b/lib/forge-std @@ -1 +1 @@ -Subproject commit 2b58ecbcf3dfde7a75959dc7b4eb3d0670278de6 +Subproject commit 75b3fcf052cc7886327e4c2eac3d1a1f36942b41 From 738a751f36e8b660d361316da2727fadd0262b8d Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Mon, 8 Jul 2024 01:00:41 -0400 Subject: [PATCH 2/8] fix compilation error --- contracts/hooks/examples/FeeTaker.sol | 104 ++++++++++++++ contracts/hooks/examples/FeeTaking.sol | 35 +++++ contracts/interfaces/IMiddlewareFactory.sol | 17 +++ contracts/libraries/ReentrancyState.sol | 39 ++++++ contracts/middleware/BaseMiddleware.sol | 12 +- contracts/middleware/MiddlewareProtect.sol | 85 ++++++++---- .../middleware/MiddlewareProtectFactory.sol | 27 ++++ test/MiddlewareProtectFactory.t.sol | 112 +++++++++++++++ test/middleware/HooksFrontrun.sol | 96 +++++++++++++ test/middleware/HooksOutOfGas.sol | 129 ++++++++++++++++++ test/middleware/HooksRevert.sol | 125 +++++++++++++++++ .../HooksFrontrunImplementation.sol | 16 +++ test/utils/HookMiner.sol | 52 +++++++ 13 files changed, 819 insertions(+), 30 deletions(-) create mode 100644 contracts/hooks/examples/FeeTaker.sol create mode 100644 contracts/hooks/examples/FeeTaking.sol create mode 100644 contracts/interfaces/IMiddlewareFactory.sol create mode 100644 contracts/libraries/ReentrancyState.sol create mode 100644 contracts/middleware/MiddlewareProtectFactory.sol create mode 100644 test/MiddlewareProtectFactory.t.sol create mode 100644 test/middleware/HooksFrontrun.sol create mode 100644 test/middleware/HooksOutOfGas.sol create mode 100644 test/middleware/HooksRevert.sol create mode 100644 test/shared/implementation/HooksFrontrunImplementation.sol create mode 100644 test/utils/HookMiner.sol diff --git a/contracts/hooks/examples/FeeTaker.sol b/contracts/hooks/examples/FeeTaker.sol new file mode 100644 index 00000000..91d1fb7d --- /dev/null +++ b/contracts/hooks/examples/FeeTaker.sol @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "../../BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; + +abstract contract FeeTaker is BaseHook { + using SafeCast for uint256; + + bytes internal constant ZERO_BYTES = bytes(""); + + constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + + /** + * @notice This hook takes a fee from the unspecified token after a swap. + * @dev This can be overridden if more permissions are needed. + */ + function getHookPermissions() public pure virtual override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: false, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: true, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function afterSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external override onlyByManager returns (bytes4, int128) { + //(Currency currencyUnspecified, amountUnspecified) = key.getUnspecified(params); + + // fee will be in the unspecified token of the swap + bool currency0Specified = (params.amountSpecified < 0 == params.zeroForOne); + (Currency currencyUnspecified, int128 amountUnspecified) = + (currency0Specified) ? (key.currency1, delta.amount1()) : (key.currency0, delta.amount0()); + // if exactOutput swap, get the absolute output amount + if (amountUnspecified < 0) amountUnspecified = -amountUnspecified; + + uint256 feeAmount = _feeAmount(amountUnspecified); + // mint ERC6909 instead of take to avoid edge case where PM doesn't have enough balance + manager.mint(address(this), CurrencyLibrary.toId(currencyUnspecified), feeAmount); + + (bytes4 selector, int128 amount) = _afterSwap(sender, key, params, delta, hookData); + return (selector, feeAmount.toInt128() + amount); + } + + function withdraw(Currency[] calldata currencies) external { + manager.unlock(abi.encode(currencies)); + } + + function _unlockCallback(bytes calldata rawData) internal override returns (bytes memory) { + Currency[] memory currencies = abi.decode(rawData, (Currency[])); + uint256 length = currencies.length; + for (uint256 i = 0; i < length;) { + uint256 amount = manager.balanceOf(address(this), CurrencyLibrary.toId(currencies[i])); + manager.burn(address(this), CurrencyLibrary.toId(currencies[i]), amount); + manager.take(currencies[i], _recipient(), amount); + unchecked { + ++i; + } + } + return ZERO_BYTES; + } + + /** + * @dev This is a virtual function that should be overridden so it returns the fee charged for a given amount. + */ + function _feeAmount(int128 amountUnspecified) internal view virtual returns (uint256); + + /** + * @dev This is a virtual function that should be overridden so it returns the address to receive the fee. + */ + function _recipient() internal view virtual returns (address); + + /** + * @dev This can be overridden to add logic after a swap. + */ + function _afterSwap(address, PoolKey memory, IPoolManager.SwapParams memory, BalanceDelta, bytes calldata) + internal + virtual + returns (bytes4, int128) + { + return (BaseHook.afterSwap.selector, 0); + } +} diff --git a/contracts/hooks/examples/FeeTaking.sol b/contracts/hooks/examples/FeeTaking.sol new file mode 100644 index 00000000..c0099149 --- /dev/null +++ b/contracts/hooks/examples/FeeTaking.sol @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {FeeTaker} from "./FeeTaker.sol"; + +contract FeeTaking is FeeTaker, Owned { + using SafeCast for uint256; + + uint128 private constant TOTAL_BIPS = 10000; + uint128 public immutable swapFeeBips; + address public treasury; + + constructor(IPoolManager _poolManager, uint128 _swapFeeBips, address _owner, address _treasury) + FeeTaker(_poolManager) + Owned(_owner) + { + swapFeeBips = _swapFeeBips; + treasury = _treasury; + } + + function setTreasury(address _treasury) external onlyOwner { + treasury = _treasury; + } + + function _feeAmount(int128 amountUnspecified) internal view override returns (uint256) { + return uint128(amountUnspecified) * swapFeeBips / TOTAL_BIPS; + } + + function _recipient() internal view override returns (address) { + return treasury; + } +} diff --git a/contracts/interfaces/IMiddlewareFactory.sol b/contracts/interfaces/IMiddlewareFactory.sol new file mode 100644 index 00000000..4af554fe --- /dev/null +++ b/contracts/interfaces/IMiddlewareFactory.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +interface IMiddlewareFactory { + event MiddlewareCreated(address implementation, address middleware); + + /// @notice Returns the implementation address for a given middleware + /// @param middleware The middleware address + /// @return implementation The implementation address + function getImplementation(address middleware) external view returns (address implementation); + + /// @notice Creates a middleware for the given implementation + /// @param implementation The implementation address + /// @param salt The salt to use to deploy the middleware + /// @return middleware The address of the newly created middleware + function createMiddleware(address implementation, bytes32 salt) external returns (address middleware); +} diff --git a/contracts/libraries/ReentrancyState.sol b/contracts/libraries/ReentrancyState.sol new file mode 100644 index 00000000..966d1269 --- /dev/null +++ b/contracts/libraries/ReentrancyState.sol @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +library ReentrancyState { + // bytes32(uint256(keccak256("ReentrancyState")) - 1) + bytes32 constant REENTRANCY_STATE_SLOT = 0xbedc9a60a226d4ae7b727cbc828f66c94c4eead57777428ceab2f04b0efca3a5; + + function unlock() internal { + assembly { + tstore(REENTRANCY_STATE_SLOT, 0) + } + } + + function lockSwap() internal { + assembly { + tstore(REENTRANCY_STATE_SLOT, 1) + } + } + + function lockSwapRemove() internal { + assembly { + tstore(REENTRANCY_STATE_SLOT, 2) + } + } + + function read() internal view returns (uint256 state) { + assembly { + state := tload(REENTRANCY_STATE_SLOT) + } + } + + function swapLocked() internal view returns (bool) { + return read() == 1 || read() == 2; + } + + function removeLocked() internal view returns (bool) { + return read() == 2; + } +} diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol index 9a68ddf5..3fcd5ab1 100644 --- a/contracts/middleware/BaseMiddleware.sol +++ b/contracts/middleware/BaseMiddleware.sol @@ -6,9 +6,9 @@ import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; import {IBaseHook} from "./../interfaces/IBaseHook.sol"; import {BaseHook} from "./../BaseHook.sol"; -import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; abstract contract BaseMiddleware is IHooks { error NotManager(); @@ -69,7 +69,7 @@ abstract contract BaseMiddleware is IHooks { bytes calldata hookData ) external virtual onlyByManager returns (bytes4, BalanceDelta) { if (msg.sender == address(implementation)) { - return (BaseHook.afterAddLiquidity.selector, BalanceDelta.ZERO_DELTA); + return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } return implementation.afterAddLiquidity(sender, key, params, delta, hookData); } @@ -92,7 +92,7 @@ abstract contract BaseMiddleware is IHooks { bytes calldata hookData ) external virtual onlyByManager returns (bytes4, BalanceDelta) { if (msg.sender == address(implementation)) { - return (BaseHook.afterRemoveLiquidity.selector, BalanceDelta.ZERO_DELTA); + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } return implementation.afterRemoveLiquidity(sender, key, params, delta, hookData); } @@ -103,7 +103,9 @@ abstract contract BaseMiddleware is IHooks { IPoolManager.SwapParams calldata params, bytes calldata hookData ) external virtual onlyByManager returns (bytes4, BeforeSwapDelta, uint24) { - if (msg.sender == address(implementation)) return (BaseHook.beforeSwap.selector, BeforeSwapDelta.ZERO_DELTA, 0); + if (msg.sender == address(implementation)) { + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } return implementation.beforeSwap(sender, key, params, hookData); } diff --git a/contracts/middleware/MiddlewareProtect.sol b/contracts/middleware/MiddlewareProtect.sol index 8c828c5e..bc39b752 100644 --- a/contracts/middleware/MiddlewareProtect.sol +++ b/contracts/middleware/MiddlewareProtect.sol @@ -9,58 +9,93 @@ import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/type import {console} from "../../lib/forge-std/src/console.sol"; import {BaseHook} from "./../BaseHook.sol"; import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; -import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {IBaseHook} from "./../interfaces/IBaseHook.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {ReentrancyState} from "./../libraries/ReentrancyState.sol"; +import {LPFeeLibrary} from "@uniswap/v4-core/src/libraries/LPFeeLibrary.sol"; +import {CustomRevert} from "@uniswap/v4-core/src/libraries/CustomRevert.sol"; contract MiddlewareProtect is BaseMiddleware { + using CustomRevert for bytes4; using StateLibrary for IPoolManager; - using Hooks for IHooks; + using LPFeeLibrary for uint24; /// @notice Thrown if the address will lead to forbidden flags being set /// @param hooks The address of the hooks contract error HookPermissionForbidden(address hooks); error ForbiddenReturn(); + error InvalidFee(); + error ActionBetweenHook(); uint256 public constant gasLimit = 1000000; - constructor(IPoolManager _poolManager, IHooks _implementation) BaseMiddleware(_poolManager, _implementation) { + constructor(IPoolManager _poolManager, IBaseHook _implementation) BaseMiddleware(_poolManager, _implementation) { + Hooks.Permissions memory permissions = _implementation.getHookPermissions(); // deny any hooks that return deltas if ( - _implementation.hasPermission(Hooks.BEFORE_SWAP_RETURNS_DELTA_FLAG) - || _implementation.hasPermission(Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG) - || _implementation.hasPermission(Hooks.AFTER_ADD_LIQUIDITY_RETURNS_DELTA_FLAG) - || _implementation.hasPermission(Hooks.AFTER_REMOVE_LIQUIDITY_RETURNS_DELTA_FLAG) + permissions.beforeSwapReturnDelta || permissions.afterSwapReturnDelta + || permissions.afterAddLiquidityReturnDelta || permissions.afterRemoveLiquidityReturnDelta ) { HookPermissionForbidden.selector.revertWith(address(this)); } } - // block swaps and removes - function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + modifier swapNotLocked() { + if (ReentrancyState.swapLocked()) { + revert ActionBetweenHook(); + } + _; + } + + modifier removeNotLocked() { + if (ReentrancyState.removeLocked()) { + revert ActionBetweenHook(); + } + _; + } + + function beforeInitialize(address sender, PoolKey calldata key, uint160 sqrtPriceX96, bytes calldata hookData) external - swapNotLocked - returns (bytes4, BeforeSwapDelta, uint24) + override + onlyByManager + returns (bytes4) { + if (key.fee.isDynamicFee()) revert InvalidFee(); + if (msg.sender == address(implementation)) { + return BaseHook.beforeInitialize.selector; + } + return implementation.beforeInitialize(sender, key, sqrtPriceX96, hookData); + } + + // block swaps and removes + function beforeSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + bytes calldata hookData + ) external override swapNotLocked returns (bytes4, BeforeSwapDelta, uint24) { ReentrancyState.lockSwapRemove(); console.log("beforeSwap middleware"); - (bytes4 selector, BeforeSwapDelta beforeSwapDelta, uint24 lpFeeOverride) = - implementation.beforeSwap(sender, key, params, hookData); - if (lpFeeOverride != 0) { - revert ForbiddenReturn(); + if (msg.sender == address(implementation)) { + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } - return (selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + (bytes4 selector, BeforeSwapDelta delta, uint24 fee) = implementation.beforeSwap(sender, key, params, hookData); + ReentrancyState.unlock(); + return (selector, delta, fee); } // afterSwap - no protections // block swaps - function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) - external - returns (bytes4) - { + function beforeAddLiquidity( + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + bytes calldata hookData + ) external override returns (bytes4) { ReentrancyState.lockSwap(); console.log("beforeAddLiquidity middleware"); - selector = implementation.beforeSwap(sender, key, params, hookData); + bytes4 selector = implementation.beforeAddLiquidity(sender, key, params, hookData); ReentrancyState.unlock(); return selector; } @@ -73,10 +108,10 @@ contract MiddlewareProtect is BaseMiddleware { PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata - ) external removeNotLocked returns (bytes4) { + ) external override removeNotLocked returns (bytes4) { ReentrancyState.lockSwap(); console.log("beforeRemoveLiquidity middleware"); - implementation.call{gas: gasLimit}(msg.data); + address(implementation).call{gas: gasLimit}(msg.data); ReentrancyState.unlock(); return BaseHook.beforeRemoveLiquidity.selector; } @@ -88,9 +123,9 @@ contract MiddlewareProtect is BaseMiddleware { IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, bytes calldata - ) external returns (bytes4, BalanceDelta) { + ) external override returns (bytes4, BalanceDelta) { console.log("afterRemoveLiquidity middleware"); - implementation.delegatecall{gas: gasLimit}(msg.data); + address(implementation).call{gas: gasLimit}(msg.data); return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } } diff --git a/contracts/middleware/MiddlewareProtectFactory.sol b/contracts/middleware/MiddlewareProtectFactory.sol new file mode 100644 index 00000000..7e7e225b --- /dev/null +++ b/contracts/middleware/MiddlewareProtectFactory.sol @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; +import {MiddlewareProtect} from "./MiddlewareProtect.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IBaseHook} from "../interfaces/IBaseHook.sol"; + +contract MiddlewareProtectFactory is IMiddlewareFactory { + mapping(address => address) private _implementations; + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getImplementation(address middleware) external view override returns (address implementation) { + return _implementations[middleware]; + } + + function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { + middleware = address(new MiddlewareProtect{salt: salt}(poolManager, IBaseHook(implementation))); + _implementations[middleware] = implementation; + emit MiddlewareCreated(implementation, middleware); + } +} diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol new file mode 100644 index 00000000..52fed845 --- /dev/null +++ b/test/MiddlewareProtectFactory.t.sol @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {HooksFrontrun} from "./middleware/HooksFrontrun.sol"; +import {HooksFrontrunImplementation} from "./shared/implementation/HooksFrontrunImplementation.sol"; +import {MiddlewareProtect} from "../contracts/middleware/MiddlewareProtect.sol"; +import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {HooksRevert} from "./middleware/HooksRevert.sol"; +import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; +import {MiddlewareProtectFactory} from "./../contracts/middleware/MiddlewareProtectFactory.sol"; +import {HookMiner} from "./utils/HookMiner.sol"; + +contract MiddlewareProtectFactoryTest is Test, Deployers { + using PoolIdLibrary for PoolKey; + using StateLibrary for IPoolManager; + + uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; + + address constant TREASURY = address(0x1234567890123456789012345678901234567890); + uint128 private constant TOTAL_BIPS = 10000; + + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + PoolId id; + + MiddlewareProtectFactory factory; + HooksFrontrun hooksFrontrun; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + hooksFrontrun = HooksFrontrun(address(uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG))); + vm.record(); + HooksFrontrunImplementation impl = new HooksFrontrunImplementation(manager, hooksFrontrun); + (, bytes32[] memory writes) = vm.accesses(address(impl)); + vm.etch(address(hooksFrontrun), address(impl).code); + unchecked { + for (uint256 i = 0; i < writes.length; i++) { + bytes32 slot = writes[i]; + vm.store(address(hooksFrontrun), slot, vm.load(address(impl), slot)); + } + } + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + + factory = new MiddlewareProtectFactory(manager); + } + + function testFrontrun() public { + (PoolKey memory key,) = + initPoolAndAddLiquidity(currency0, currency1, IHooks(address(0)), 100, SQRT_PRICE_1_1, ZERO_BYTES); + BalanceDelta swapDelta = swap(key, true, 0.001 ether, ZERO_BYTES); + + (key,) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(hooksFrontrun)), 100, SQRT_PRICE_1_1, ZERO_BYTES + ); + BalanceDelta swapDelta2 = swap(key, true, 0.001 ether, ZERO_BYTES); + + // while both swaps are in the same pool, the second swap is more expensive + assertEq(swapDelta.amount1(), swapDelta2.amount1()); + assertTrue(abs(swapDelta.amount0()) < abs(swapDelta2.amount0())); + assertTrue(manager.balanceOf(address(hooksFrontrun), CurrencyLibrary.toId(key.currency0)) > 0); + } + + function testVariousProtectFactory() public { + uint160 flags = uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareProtect).creationCode, + abi.encode(address(manager), address(hooksFrontrun)) + ); + testOn(address(hooksFrontrun), salt); + } + + // creates a middleware on an implementation + function testOn(address implementation, bytes32 salt) internal { + address hookAddress = factory.createMiddleware(implementation, salt); + MiddlewareProtect middlewareProtect = MiddlewareProtect(payable(hookAddress)); + + (key,) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(middlewareProtect)), 100, SQRT_PRICE_1_1, ZERO_BYTES + ); + swap(key, true, 0.001 ether, ZERO_BYTES); + //vm.expectRevert(); + } + + function abs(int256 x) internal pure returns (uint256) { + return x >= 0 ? uint256(x) : uint256(-x); + } +} diff --git a/test/middleware/HooksFrontrun.sol b/test/middleware/HooksFrontrun.sol new file mode 100644 index 00000000..c0d23a8f --- /dev/null +++ b/test/middleware/HooksFrontrun.sol @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {console} from "forge-std/console.sol"; +import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; + +contract HooksFrontrun is BaseHook { + using SafeCast for uint256; + + bytes internal constant ZERO_BYTES = bytes(""); + uint160 public constant MIN_PRICE_LIMIT = TickMath.MIN_SQRT_PRICE + 1; + uint160 public constant MAX_PRICE_LIMIT = TickMath.MAX_SQRT_PRICE - 1; + + BalanceDelta swapDelta; + IPoolManager.SwapParams swapParams; + + constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + + function getHookPermissions() public pure virtual override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: true, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) + external + override + onlyByManager + returns (bytes4, BeforeSwapDelta, uint24) + { + swapParams = params; + console.log(params.zeroForOne); + console.logInt(params.amountSpecified); + swapDelta = manager.swap(key, params, ZERO_BYTES); + console.log("beforeDelta"); + console.logInt(swapDelta.amount0()); + console.logInt(swapDelta.amount1()); + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external override onlyByManager returns (bytes4, int128) { + BalanceDelta afterDelta = manager.swap( + key, + IPoolManager.SwapParams( + !swapParams.zeroForOne, + -swapParams.amountSpecified, + swapParams.zeroForOne ? MAX_PRICE_LIMIT : MIN_PRICE_LIMIT + ), + ZERO_BYTES + ); + if (swapParams.zeroForOne) { + int256 profit = afterDelta.amount0() + swapDelta.amount0(); + if (profit > 0) { + // else hook reverts + manager.mint(address(this), key.currency0.toId(), uint256(profit)); + } + } else { + int256 profit = afterDelta.amount1() + swapDelta.amount1(); + if (profit > 0) { + // else hook reverts + manager.mint(address(this), key.currency1.toId(), uint256(profit)); + } + } + console.log("afterDelta"); + console.logInt(afterDelta.amount0()); + console.logInt(afterDelta.amount1()); + return (BaseHook.afterSwap.selector, 0); + } +} diff --git a/test/middleware/HooksOutOfGas.sol b/test/middleware/HooksOutOfGas.sol new file mode 100644 index 00000000..d7b35a37 --- /dev/null +++ b/test/middleware/HooksOutOfGas.sol @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta, toBalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract HooksOutOfGas { + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getHookPermissions() public pure returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: true, + afterInitialize: true, + beforeAddLiquidity: true, + afterAddLiquidity: true, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: true, + beforeSwap: true, + afterSwap: true, + beforeDonate: true, + afterDonate: true, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { + consumeAllGas(); + } + + function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external returns (bytes4) { + consumeAllGas(); + } + + function afterAddLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + consumeAllGas(); + } + + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external returns (bytes4, BalanceDelta) { + consumeAllGas(); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + virtual + returns (bytes4, BeforeSwapDelta, uint24) + { + consumeAllGas(); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + virtual + returns (bytes4, int128) + { + consumeAllGas(); + } + + function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function consumeAllGas() internal view { + while (true) { + //console.log(gasleft()); + // This loop will run indefinitely and consume all available gas. + } + } +} diff --git a/test/middleware/HooksRevert.sol b/test/middleware/HooksRevert.sol new file mode 100644 index 00000000..f3f31401 --- /dev/null +++ b/test/middleware/HooksRevert.sol @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta, toBalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract HooksRevert { + error HookNotImplemented(); + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getHookPermissions() public pure returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: true, + afterInitialize: true, + beforeAddLiquidity: true, + afterAddLiquidity: true, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: true, + beforeSwap: true, + afterSwap: true, + beforeDonate: true, + afterDonate: true, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { + revert HookNotImplemented(); + } + + function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external returns (bytes4) { + revert HookNotImplemented(); + } + + function afterAddLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + revert HookNotImplemented(); + } + + function afterRemoveLiquidity( + address sender, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external returns (bytes4, BalanceDelta) { + require(sender == address(0), "nobody can remove"); + return (BaseHook.beforeRemoveLiquidity.selector, toBalanceDelta(0, 0)); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + virtual + returns (bytes4, BeforeSwapDelta, uint24) + { + revert HookNotImplemented(); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + virtual + returns (bytes4, int128) + { + revert HookNotImplemented(); + } + + function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } +} diff --git a/test/shared/implementation/HooksFrontrunImplementation.sol b/test/shared/implementation/HooksFrontrunImplementation.sol new file mode 100644 index 00000000..104a07b7 --- /dev/null +++ b/test/shared/implementation/HooksFrontrunImplementation.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {BaseHook} from "../../../contracts/BaseHook.sol"; +import {HooksFrontrun} from "../../middleware/HooksFrontrun.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract HooksFrontrunImplementation is HooksFrontrun { + constructor(IPoolManager _poolManager, HooksFrontrun addressToEtch) HooksFrontrun(_poolManager) { + Hooks.validateHookPermissions(addressToEtch, getHookPermissions()); + } + + // make this a no-op in testing + function validateHookAddress(BaseHook _this) internal pure override {} +} diff --git a/test/utils/HookMiner.sol b/test/utils/HookMiner.sol new file mode 100644 index 00000000..d6b30c40 --- /dev/null +++ b/test/utils/HookMiner.sol @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.21; + +/// @title HookMiner - a library for mining hook addresses +/// @dev This library is intended for `forge test` environments. There may be gotchas when using salts in `forge script` or `forge create` +library HookMiner { + // mask to slice out the bottom 14 bit of the address + uint160 constant FLAG_MASK = 0x3FFF; + + // Maximum number of iterations to find a salt, avoid infinite loops + uint256 constant MAX_LOOP = 100_000; + + /// @notice Find a salt that produces a hook address with the desired `flags` + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param flags The desired flags for the hook address + /// @param creationCode The creation code of a hook contract. Example: `type(Counter).creationCode` + /// @param constructorArgs The encoded constructor arguments of a hook contract. Example: `abi.encode(address(manager))` + /// @return hookAddress salt and corresponding address that was found. The salt can be used in `new Hook{salt: salt}()` + function find(address deployer, uint160 flags, bytes memory creationCode, bytes memory constructorArgs) + internal + view + returns (address, bytes32) + { + address hookAddress; + bytes memory creationCodeWithArgs = abi.encodePacked(creationCode, constructorArgs); + + uint256 salt; + for (salt; salt < MAX_LOOP; salt++) { + hookAddress = computeAddress(deployer, salt, creationCodeWithArgs); + if (uint160(hookAddress) & FLAG_MASK == flags && hookAddress.code.length == 0) { + return (hookAddress, bytes32(salt)); + } + } + revert("HookMiner: could not find salt"); + } + + /// @notice Precompute a contract address deployed via CREATE2 + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param salt The salt used to deploy the hook + /// @param creationCode The creation code of a hook contract + function computeAddress(address deployer, uint256 salt, bytes memory creationCode) + internal + pure + returns (address hookAddress) + { + return address( + uint160(uint256(keccak256(abi.encodePacked(bytes1(0xFF), deployer, salt, keccak256(creationCode))))) + ); + } +} From 3406833d294710c6d62d3a1a299a7be550dcafe5 Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Mon, 8 Jul 2024 01:10:10 -0400 Subject: [PATCH 3/8] hook correctly errors --- contracts/interfaces/IBaseImplementation.sol | 6 ++++++ contracts/middleware/BaseImplementation.sol | 2 ++ .../middleware/MiddlewareProtectFactory.sol | 2 ++ test/MiddlewareProtectFactory.t.sol | 7 +++---- test/middleware/HooksFrontrun.sol | 7 ++++--- .../HooksFrontrunImplementation.sol | 16 ---------------- 6 files changed, 17 insertions(+), 23 deletions(-) create mode 100644 contracts/interfaces/IBaseImplementation.sol delete mode 100644 test/shared/implementation/HooksFrontrunImplementation.sol diff --git a/contracts/interfaces/IBaseImplementation.sol b/contracts/interfaces/IBaseImplementation.sol new file mode 100644 index 00000000..f05f9238 --- /dev/null +++ b/contracts/interfaces/IBaseImplementation.sol @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +interface IBaseImplementation { + function initializeMiddleware(address _middleware) external; +} diff --git a/contracts/middleware/BaseImplementation.sol b/contracts/middleware/BaseImplementation.sol index 1dfee1a8..5a38569f 100644 --- a/contracts/middleware/BaseImplementation.sol +++ b/contracts/middleware/BaseImplementation.sol @@ -9,6 +9,7 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; import {SafeCallback} from "./../base/SafeCallback.sol"; import {ImmutableState} from "./../base/ImmutableState.sol"; +import {console} from "forge-std/console.sol"; abstract contract BaseImplementation is IHooks, SafeCallback { error NotSelf(); @@ -44,6 +45,7 @@ abstract contract BaseImplementation is IHooks, SafeCallback { function initializeMiddleware(address _middleware) external { if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); + console.log("initializeMiddleware"); middleware = _middleware; } diff --git a/contracts/middleware/MiddlewareProtectFactory.sol b/contracts/middleware/MiddlewareProtectFactory.sol index 7e7e225b..db9f76dd 100644 --- a/contracts/middleware/MiddlewareProtectFactory.sol +++ b/contracts/middleware/MiddlewareProtectFactory.sol @@ -5,6 +5,7 @@ import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; import {MiddlewareProtect} from "./MiddlewareProtect.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {IBaseHook} from "../interfaces/IBaseHook.sol"; +import {IBaseImplementation} from "../interfaces/IBaseImplementation.sol"; contract MiddlewareProtectFactory is IMiddlewareFactory { mapping(address => address) private _implementations; @@ -21,6 +22,7 @@ contract MiddlewareProtectFactory is IMiddlewareFactory { function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { middleware = address(new MiddlewareProtect{salt: salt}(poolManager, IBaseHook(implementation))); + IBaseImplementation(implementation).initializeMiddleware(middleware); _implementations[middleware] = implementation; emit MiddlewareCreated(implementation, middleware); } diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol index 52fed845..7bbcecdd 100644 --- a/test/MiddlewareProtectFactory.t.sol +++ b/test/MiddlewareProtectFactory.t.sol @@ -4,7 +4,6 @@ pragma solidity ^0.8.19; import {Test} from "forge-std/Test.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {HooksFrontrun} from "./middleware/HooksFrontrun.sol"; -import {HooksFrontrunImplementation} from "./shared/implementation/HooksFrontrunImplementation.sol"; import {MiddlewareProtect} from "../contracts/middleware/MiddlewareProtect.sol"; import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; @@ -48,9 +47,11 @@ contract MiddlewareProtectFactoryTest is Test, Deployers { token0 = TestERC20(Currency.unwrap(currency0)); token1 = TestERC20(Currency.unwrap(currency1)); + factory = new MiddlewareProtectFactory(manager); + hooksFrontrun = HooksFrontrun(address(uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG))); vm.record(); - HooksFrontrunImplementation impl = new HooksFrontrunImplementation(manager, hooksFrontrun); + HooksFrontrun impl = new HooksFrontrun(manager, address(factory)); (, bytes32[] memory writes) = vm.accesses(address(impl)); vm.etch(address(hooksFrontrun), address(impl).code); unchecked { @@ -62,8 +63,6 @@ contract MiddlewareProtectFactoryTest is Test, Deployers { token0.approve(address(router), type(uint256).max); token1.approve(address(router), type(uint256).max); - - factory = new MiddlewareProtectFactory(manager); } function testFrontrun() public { diff --git a/test/middleware/HooksFrontrun.sol b/test/middleware/HooksFrontrun.sol index c0d23a8f..6ecf000b 100644 --- a/test/middleware/HooksFrontrun.sol +++ b/test/middleware/HooksFrontrun.sol @@ -11,8 +11,9 @@ import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; import {console} from "forge-std/console.sol"; import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; +import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; -contract HooksFrontrun is BaseHook { +contract HooksFrontrun is BaseImplementation { using SafeCast for uint256; bytes internal constant ZERO_BYTES = bytes(""); @@ -22,7 +23,7 @@ contract HooksFrontrun is BaseHook { BalanceDelta swapDelta; IPoolManager.SwapParams swapParams; - constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} function getHookPermissions() public pure virtual override returns (Hooks.Permissions memory) { return Hooks.Permissions({ @@ -46,7 +47,7 @@ contract HooksFrontrun is BaseHook { function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) external override - onlyByManager + onlyByMiddleware returns (bytes4, BeforeSwapDelta, uint24) { swapParams = params; diff --git a/test/shared/implementation/HooksFrontrunImplementation.sol b/test/shared/implementation/HooksFrontrunImplementation.sol deleted file mode 100644 index 104a07b7..00000000 --- a/test/shared/implementation/HooksFrontrunImplementation.sol +++ /dev/null @@ -1,16 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.19; - -import {BaseHook} from "../../../contracts/BaseHook.sol"; -import {HooksFrontrun} from "../../middleware/HooksFrontrun.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; - -contract HooksFrontrunImplementation is HooksFrontrun { - constructor(IPoolManager _poolManager, HooksFrontrun addressToEtch) HooksFrontrun(_poolManager) { - Hooks.validateHookPermissions(addressToEtch, getHookPermissions()); - } - - // make this a no-op in testing - function validateHookAddress(BaseHook _this) internal pure override {} -} From f101078296ee2a2bb50cd25cb9cc5f44ad5d3173 Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Mon, 8 Jul 2024 14:35:21 -0400 Subject: [PATCH 4/8] prevent double reinitialize BaseImplementation --- contracts/middleware/BaseImplementation.sol | 4 ++-- test/MiddlewareProtectFactory.t.sol | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/contracts/middleware/BaseImplementation.sol b/contracts/middleware/BaseImplementation.sol index 5a38569f..57407a24 100644 --- a/contracts/middleware/BaseImplementation.sol +++ b/contracts/middleware/BaseImplementation.sol @@ -9,7 +9,6 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; import {SafeCallback} from "./../base/SafeCallback.sol"; import {ImmutableState} from "./../base/ImmutableState.sol"; -import {console} from "forge-std/console.sol"; abstract contract BaseImplementation is IHooks, SafeCallback { error NotSelf(); @@ -18,6 +17,7 @@ abstract contract BaseImplementation is IHooks, SafeCallback { error HookNotImplemented(); error NotMiddleware(); error NotMiddlewareFactory(); + error AlreadyInitialized(); address public immutable middlewareFactory; address public middleware; @@ -45,7 +45,7 @@ abstract contract BaseImplementation is IHooks, SafeCallback { function initializeMiddleware(address _middleware) external { if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); - console.log("initializeMiddleware"); + if (middleware != address(0)) revert AlreadyInitialized(); middleware = _middleware; } diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol index 7bbcecdd..aaf02230 100644 --- a/test/MiddlewareProtectFactory.t.sol +++ b/test/MiddlewareProtectFactory.t.sol @@ -81,7 +81,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers { assertTrue(manager.balanceOf(address(hooksFrontrun), CurrencyLibrary.toId(key.currency0)) > 0); } - function testVariousProtectFactory() public { + function testRevertOnFrontrun() public { uint160 flags = uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG); (address hookAddress, bytes32 salt) = HookMiner.find( @@ -90,19 +90,16 @@ contract MiddlewareProtectFactoryTest is Test, Deployers { type(MiddlewareProtect).creationCode, abi.encode(address(manager), address(hooksFrontrun)) ); - testOn(address(hooksFrontrun), salt); - } - - // creates a middleware on an implementation - function testOn(address implementation, bytes32 salt) internal { - address hookAddress = factory.createMiddleware(implementation, salt); + address implementation = address(hooksFrontrun); + address hookAddressCreated = factory.createMiddleware(implementation, salt); + assertEq(hookAddressCreated, hookAddress); MiddlewareProtect middlewareProtect = MiddlewareProtect(payable(hookAddress)); (key,) = initPoolAndAddLiquidity( currency0, currency1, IHooks(address(middlewareProtect)), 100, SQRT_PRICE_1_1, ZERO_BYTES ); + vm.expectRevert(MiddlewareProtect.ActionBetweenHook.selector); swap(key, true, 0.001 ether, ZERO_BYTES); - //vm.expectRevert(); } function abs(int256 x) internal pure returns (uint256) { From c55572409e722c6bbc863c5475baa2c255ddd277 Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Mon, 8 Jul 2024 18:44:32 -0400 Subject: [PATCH 5/8] patch things up --- contracts/hooks/examples/FeeTaker.sol | 104 -------------- contracts/hooks/examples/FeeTaking.sol | 35 ----- contracts/interfaces/IBaseImplementation.sol | 4 +- contracts/middleware/BaseImplementation.sol | 2 - contracts/middleware/BaseMiddleware.sol | 2 +- .../middleware/BaseMiddlewareFactory.sol | 41 ++++++ contracts/middleware/Forwarder.sol | 44 ------ contracts/middleware/MiddlewareProtect.sol | 15 ++- .../middleware/MiddlewareProtectFactory.sol | 24 +--- contracts/middleware/MiddlewareRemove.sol | 42 ++++++ .../middleware/MiddlewareRemoveFactory.sol | 15 +++ test/MiddlewareProtectFactory.t.sol | 41 ++++++ test/MiddlewareRemoveFactory.t.sol | 127 ++++++++++++++++++ test/middleware/HooksDoNothing.sol | 125 +++++++++++++++++ test/middleware/HooksFrontrun.sol | 23 +--- test/middleware/HooksOutOfGas.sol | 78 ++++------- test/middleware/HooksReturnDeltas.sol | 51 +++++++ test/middleware/HooksRevert.sol | 83 ++++-------- 18 files changed, 517 insertions(+), 339 deletions(-) delete mode 100644 contracts/hooks/examples/FeeTaker.sol delete mode 100644 contracts/hooks/examples/FeeTaking.sol create mode 100644 contracts/middleware/BaseMiddlewareFactory.sol delete mode 100644 contracts/middleware/Forwarder.sol create mode 100644 contracts/middleware/MiddlewareRemove.sol create mode 100644 contracts/middleware/MiddlewareRemoveFactory.sol create mode 100644 test/MiddlewareRemoveFactory.t.sol create mode 100644 test/middleware/HooksDoNothing.sol create mode 100644 test/middleware/HooksReturnDeltas.sol diff --git a/contracts/hooks/examples/FeeTaker.sol b/contracts/hooks/examples/FeeTaker.sol deleted file mode 100644 index 91d1fb7d..00000000 --- a/contracts/hooks/examples/FeeTaker.sol +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {BaseHook} from "../../BaseHook.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; -import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; - -abstract contract FeeTaker is BaseHook { - using SafeCast for uint256; - - bytes internal constant ZERO_BYTES = bytes(""); - - constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} - - /** - * @notice This hook takes a fee from the unspecified token after a swap. - * @dev This can be overridden if more permissions are needed. - */ - function getHookPermissions() public pure virtual override returns (Hooks.Permissions memory) { - return Hooks.Permissions({ - beforeInitialize: false, - afterInitialize: false, - beforeAddLiquidity: false, - afterAddLiquidity: false, - beforeRemoveLiquidity: false, - afterRemoveLiquidity: false, - beforeSwap: false, - afterSwap: true, - beforeDonate: false, - afterDonate: false, - beforeSwapReturnDelta: false, - afterSwapReturnDelta: true, - afterAddLiquidityReturnDelta: false, - afterRemoveLiquidityReturnDelta: false - }); - } - - function afterSwap( - address sender, - PoolKey calldata key, - IPoolManager.SwapParams calldata params, - BalanceDelta delta, - bytes calldata hookData - ) external override onlyByManager returns (bytes4, int128) { - //(Currency currencyUnspecified, amountUnspecified) = key.getUnspecified(params); - - // fee will be in the unspecified token of the swap - bool currency0Specified = (params.amountSpecified < 0 == params.zeroForOne); - (Currency currencyUnspecified, int128 amountUnspecified) = - (currency0Specified) ? (key.currency1, delta.amount1()) : (key.currency0, delta.amount0()); - // if exactOutput swap, get the absolute output amount - if (amountUnspecified < 0) amountUnspecified = -amountUnspecified; - - uint256 feeAmount = _feeAmount(amountUnspecified); - // mint ERC6909 instead of take to avoid edge case where PM doesn't have enough balance - manager.mint(address(this), CurrencyLibrary.toId(currencyUnspecified), feeAmount); - - (bytes4 selector, int128 amount) = _afterSwap(sender, key, params, delta, hookData); - return (selector, feeAmount.toInt128() + amount); - } - - function withdraw(Currency[] calldata currencies) external { - manager.unlock(abi.encode(currencies)); - } - - function _unlockCallback(bytes calldata rawData) internal override returns (bytes memory) { - Currency[] memory currencies = abi.decode(rawData, (Currency[])); - uint256 length = currencies.length; - for (uint256 i = 0; i < length;) { - uint256 amount = manager.balanceOf(address(this), CurrencyLibrary.toId(currencies[i])); - manager.burn(address(this), CurrencyLibrary.toId(currencies[i]), amount); - manager.take(currencies[i], _recipient(), amount); - unchecked { - ++i; - } - } - return ZERO_BYTES; - } - - /** - * @dev This is a virtual function that should be overridden so it returns the fee charged for a given amount. - */ - function _feeAmount(int128 amountUnspecified) internal view virtual returns (uint256); - - /** - * @dev This is a virtual function that should be overridden so it returns the address to receive the fee. - */ - function _recipient() internal view virtual returns (address); - - /** - * @dev This can be overridden to add logic after a swap. - */ - function _afterSwap(address, PoolKey memory, IPoolManager.SwapParams memory, BalanceDelta, bytes calldata) - internal - virtual - returns (bytes4, int128) - { - return (BaseHook.afterSwap.selector, 0); - } -} diff --git a/contracts/hooks/examples/FeeTaking.sol b/contracts/hooks/examples/FeeTaking.sol deleted file mode 100644 index c0099149..00000000 --- a/contracts/hooks/examples/FeeTaking.sol +++ /dev/null @@ -1,35 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; -import {Owned} from "solmate/auth/Owned.sol"; -import {FeeTaker} from "./FeeTaker.sol"; - -contract FeeTaking is FeeTaker, Owned { - using SafeCast for uint256; - - uint128 private constant TOTAL_BIPS = 10000; - uint128 public immutable swapFeeBips; - address public treasury; - - constructor(IPoolManager _poolManager, uint128 _swapFeeBips, address _owner, address _treasury) - FeeTaker(_poolManager) - Owned(_owner) - { - swapFeeBips = _swapFeeBips; - treasury = _treasury; - } - - function setTreasury(address _treasury) external onlyOwner { - treasury = _treasury; - } - - function _feeAmount(int128 amountUnspecified) internal view override returns (uint256) { - return uint128(amountUnspecified) * swapFeeBips / TOTAL_BIPS; - } - - function _recipient() internal view override returns (address) { - return treasury; - } -} diff --git a/contracts/interfaces/IBaseImplementation.sol b/contracts/interfaces/IBaseImplementation.sol index f05f9238..bbf73419 100644 --- a/contracts/interfaces/IBaseImplementation.sol +++ b/contracts/interfaces/IBaseImplementation.sol @@ -1,6 +1,8 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; -interface IBaseImplementation { +import {IBaseHook} from "./IBaseHook.sol"; + +interface IBaseImplementation is IBaseHook { function initializeMiddleware(address _middleware) external; } diff --git a/contracts/middleware/BaseImplementation.sol b/contracts/middleware/BaseImplementation.sol index 57407a24..1dfee1a8 100644 --- a/contracts/middleware/BaseImplementation.sol +++ b/contracts/middleware/BaseImplementation.sol @@ -17,7 +17,6 @@ abstract contract BaseImplementation is IHooks, SafeCallback { error HookNotImplemented(); error NotMiddleware(); error NotMiddlewareFactory(); - error AlreadyInitialized(); address public immutable middlewareFactory; address public middleware; @@ -45,7 +44,6 @@ abstract contract BaseImplementation is IHooks, SafeCallback { function initializeMiddleware(address _middleware) external { if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); - if (middleware != address(0)) revert AlreadyInitialized(); middleware = _middleware; } diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol index 3fcd5ab1..8074aace 100644 --- a/contracts/middleware/BaseMiddleware.sol +++ b/contracts/middleware/BaseMiddleware.sol @@ -10,7 +10,7 @@ import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/Bala import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -abstract contract BaseMiddleware is IHooks { +contract BaseMiddleware is IHooks { error NotManager(); IPoolManager public immutable manager; diff --git a/contracts/middleware/BaseMiddlewareFactory.sol b/contracts/middleware/BaseMiddlewareFactory.sol new file mode 100644 index 00000000..2b549a6a --- /dev/null +++ b/contracts/middleware/BaseMiddlewareFactory.sol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; +import {BaseMiddleware} from "./BaseMiddleware.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IBaseHook} from "../interfaces/IBaseHook.sol"; +import {IBaseImplementation} from "../interfaces/IBaseImplementation.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract BaseMiddlewareFactory is IMiddlewareFactory { + error AlreadyInitialized(); + + mapping(address => address) private _implementations; + mapping(address => bool) private _initialized; + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getImplementation(address middleware) external view override returns (address implementation) { + return _implementations[middleware]; + } + + function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { + if (_initialized[implementation]) revert AlreadyInitialized(); + _initialized[implementation] = true; + middleware = _deployMiddleware(implementation, salt); + Hooks.validateHookPermissions(IHooks(middleware), IBaseImplementation(implementation).getHookPermissions()); + IBaseImplementation(implementation).initializeMiddleware(middleware); + _implementations[middleware] = implementation; + emit MiddlewareCreated(implementation, middleware); + } + + function _deployMiddleware(address implementation, bytes32 salt) internal virtual returns (address middleware) { + return address(new BaseMiddleware{salt: salt}(poolManager, IBaseHook(implementation))); + } +} diff --git a/contracts/middleware/Forwarder.sol b/contracts/middleware/Forwarder.sol deleted file mode 100644 index 31afb012..00000000 --- a/contracts/middleware/Forwarder.sol +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -// openzeppelin proxy but call instead of delegatecall - -pragma solidity ^0.8.24; - -/** - * @dev This abstract contract provides a fallback function that forwards all calls to another contract - */ -abstract contract Forwarder { - function _forward(address implementation) internal virtual { - assembly { - // Copy msg.data. We take full control of memory in this inline assembly - // block because it will not return to Solidity code. We overwrite the - // Solidity scratch pad at memory position 0. - calldatacopy(0, 0, calldatasize()) - - // Call the implementation. - // out and outsize are 0 because we don't know the size yet. - let result := call(gas(), implementation, 0, 0, calldatasize(), 0, 0) - - // Copy the returned data. - returndatacopy(0, 0, returndatasize()) - - switch result - // call returns 0 on error. - case 0 { revert(0, returndatasize()) } - default { return(0, returndatasize()) } - } - } - - /** - * @dev This is a virtual function that should be overridden so it returns the address to which the fallback function - * and {_fallback} should delegate. - */ - function _implementation() internal view virtual returns (address); - - /** - * @dev Fallback function that delegates calls to the address returned by `_implementation()`. Will run if no other - * function in the contract matches the call data. - */ - fallback() external payable virtual { - _forward(_implementation()); - } -} diff --git a/contracts/middleware/MiddlewareProtect.sol b/contracts/middleware/MiddlewareProtect.sol index bc39b752..be081b5b 100644 --- a/contracts/middleware/MiddlewareProtect.sol +++ b/contracts/middleware/MiddlewareProtect.sol @@ -8,7 +8,6 @@ import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/Bala import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; import {console} from "../../lib/forge-std/src/console.sol"; import {BaseHook} from "./../BaseHook.sol"; -import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; import {IBaseHook} from "./../interfaces/IBaseHook.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {ReentrancyState} from "./../libraries/ReentrancyState.sol"; @@ -17,7 +16,6 @@ import {CustomRevert} from "@uniswap/v4-core/src/libraries/CustomRevert.sol"; contract MiddlewareProtect is BaseMiddleware { using CustomRevert for bytes4; - using StateLibrary for IPoolManager; using LPFeeLibrary for uint24; /// @notice Thrown if the address will lead to forbidden flags being set @@ -74,11 +72,11 @@ contract MiddlewareProtect is BaseMiddleware { IPoolManager.SwapParams calldata params, bytes calldata hookData ) external override swapNotLocked returns (bytes4, BeforeSwapDelta, uint24) { - ReentrancyState.lockSwapRemove(); - console.log("beforeSwap middleware"); if (msg.sender == address(implementation)) { return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } + ReentrancyState.lockSwapRemove(); + console.log("beforeSwap middleware"); (bytes4 selector, BeforeSwapDelta delta, uint24 fee) = implementation.beforeSwap(sender, key, params, hookData); ReentrancyState.unlock(); return (selector, delta, fee); @@ -93,6 +91,9 @@ contract MiddlewareProtect is BaseMiddleware { IPoolManager.ModifyLiquidityParams calldata params, bytes calldata hookData ) external override returns (bytes4) { + if (msg.sender == address(implementation)) { + return BaseHook.beforeAddLiquidity.selector; + } ReentrancyState.lockSwap(); console.log("beforeAddLiquidity middleware"); bytes4 selector = implementation.beforeAddLiquidity(sender, key, params, hookData); @@ -109,6 +110,9 @@ contract MiddlewareProtect is BaseMiddleware { IPoolManager.ModifyLiquidityParams calldata, bytes calldata ) external override removeNotLocked returns (bytes4) { + if (msg.sender == address(implementation)) { + return BaseHook.beforeRemoveLiquidity.selector; + } ReentrancyState.lockSwap(); console.log("beforeRemoveLiquidity middleware"); address(implementation).call{gas: gasLimit}(msg.data); @@ -124,6 +128,9 @@ contract MiddlewareProtect is BaseMiddleware { BalanceDelta, bytes calldata ) external override returns (bytes4, BalanceDelta) { + if (msg.sender == address(implementation)) { + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } console.log("afterRemoveLiquidity middleware"); address(implementation).call{gas: gasLimit}(msg.data); return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); diff --git a/contracts/middleware/MiddlewareProtectFactory.sol b/contracts/middleware/MiddlewareProtectFactory.sol index db9f76dd..ad86dfaa 100644 --- a/contracts/middleware/MiddlewareProtectFactory.sol +++ b/contracts/middleware/MiddlewareProtectFactory.sol @@ -1,29 +1,15 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; -import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; import {MiddlewareProtect} from "./MiddlewareProtect.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {IBaseHook} from "../interfaces/IBaseHook.sol"; -import {IBaseImplementation} from "../interfaces/IBaseImplementation.sol"; +import {BaseMiddlewareFactory} from "./BaseMiddlewareFactory.sol"; -contract MiddlewareProtectFactory is IMiddlewareFactory { - mapping(address => address) private _implementations; +contract MiddlewareProtectFactory is BaseMiddlewareFactory { + constructor(IPoolManager _poolManager) BaseMiddlewareFactory(_poolManager) {} - IPoolManager public immutable poolManager; - - constructor(IPoolManager _poolManager) { - poolManager = _poolManager; - } - - function getImplementation(address middleware) external view override returns (address implementation) { - return _implementations[middleware]; - } - - function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { - middleware = address(new MiddlewareProtect{salt: salt}(poolManager, IBaseHook(implementation))); - IBaseImplementation(implementation).initializeMiddleware(middleware); - _implementations[middleware] = implementation; - emit MiddlewareCreated(implementation, middleware); + function _deployMiddleware(address implementation, bytes32 salt) internal override returns (address middleware) { + return address(new MiddlewareProtect{salt: salt}(poolManager, IBaseHook(implementation))); } } diff --git a/contracts/middleware/MiddlewareRemove.sol b/contracts/middleware/MiddlewareRemove.sol new file mode 100644 index 00000000..23c63784 --- /dev/null +++ b/contracts/middleware/MiddlewareRemove.sol @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BaseMiddleware} from "./BaseMiddleware.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BaseHook} from "./../BaseHook.sol"; +import {IBaseHook} from "./../interfaces/IBaseHook.sol"; +import {console} from "../../lib/forge-std/src/console.sol"; + +contract MiddlewareRemove is BaseMiddleware { + uint256 public constant gasLimit = 1000000; + + constructor(IPoolManager _poolManager, IBaseHook _implementation) BaseMiddleware(_poolManager, _implementation) {} + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external override returns (bytes4) { + console.log("beforeRemoveLiquidity middleware"); + (bool success, bytes memory returnData) = address(implementation).call{gas: gasLimit}(msg.data); + console.log(success); + return BaseHook.beforeRemoveLiquidity.selector; + } + + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external override returns (bytes4, BalanceDelta) { + console.log("afterRemoveLiquidity middleware"); + (bool success, bytes memory returnData) = address(implementation).call{gas: gasLimit}(msg.data); + console.log(success); + // hook cannot return delta + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } +} diff --git a/contracts/middleware/MiddlewareRemoveFactory.sol b/contracts/middleware/MiddlewareRemoveFactory.sol new file mode 100644 index 00000000..5b45fb70 --- /dev/null +++ b/contracts/middleware/MiddlewareRemoveFactory.sol @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {MiddlewareRemove} from "./MiddlewareRemove.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IBaseHook} from "../interfaces/IBaseHook.sol"; +import {BaseMiddlewareFactory} from "./BaseMiddlewareFactory.sol"; + +contract MiddlewareRemoveFactory is BaseMiddlewareFactory { + constructor(IPoolManager _poolManager) BaseMiddlewareFactory(_poolManager) {} + + function _deployMiddleware(address implementation, bytes32 salt) internal override returns (address middleware) { + return address(new MiddlewareRemove{salt: salt}(poolManager, IBaseHook(implementation))); + } +} diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol index aaf02230..fabf2600 100644 --- a/test/MiddlewareProtectFactory.t.sol +++ b/test/MiddlewareProtectFactory.t.sol @@ -21,6 +21,8 @@ import {HooksRevert} from "./middleware/HooksRevert.sol"; import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; import {MiddlewareProtectFactory} from "./../contracts/middleware/MiddlewareProtectFactory.sol"; import {HookMiner} from "./utils/HookMiner.sol"; +import {HooksReturnDeltas} from "./middleware/HooksReturnDeltas.sol"; +import {HooksDoNothing} from "./middleware/HooksDoNothing.sol"; contract MiddlewareProtectFactoryTest is Test, Deployers { using PoolIdLibrary for PoolKey; @@ -37,6 +39,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers { PoolId id; MiddlewareProtectFactory factory; + HooksDoNothing hooksDoNothing; HooksFrontrun hooksFrontrun; function setUp() public { @@ -48,6 +51,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers { token1 = TestERC20(Currency.unwrap(currency1)); factory = new MiddlewareProtectFactory(manager); + hooksDoNothing = new HooksDoNothing(manager, address(factory)); hooksFrontrun = HooksFrontrun(address(uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG))); vm.record(); @@ -65,7 +69,44 @@ contract MiddlewareProtectFactoryTest is Test, Deployers { token1.approve(address(router), type(uint256).max); } + function testRevertOnIncorrectFlags() public { + uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareProtect).creationCode, + abi.encode(address(manager), address(hooksDoNothing)) + ); + address implementation = address(hooksDoNothing); + vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); + factory.createMiddleware(implementation, salt); + } + + function testRevertOnIncorrectFlagsMined() public { + address implementation = address(hooksDoNothing); + vm.expectRevert(); // HookAddressNotValid + factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); + } + + function testRevertOnDeltas() public { + HooksReturnDeltas hooksReturnDeltas = new HooksReturnDeltas(manager, address(factory)); + uint160 flags = uint160(Hooks.AFTER_SWAP_FLAG | Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareProtect).creationCode, + abi.encode(address(manager), address(hooksReturnDeltas)) + ); + address implementation = address(hooksReturnDeltas); + console.log(hookAddress); + vm.expectRevert(abi.encodePacked(bytes16(MiddlewareProtect.HookPermissionForbidden.selector), hookAddress)); + factory.createMiddleware(implementation, salt); + } + function testFrontrun() public { + return; (PoolKey memory key,) = initPoolAndAddLiquidity(currency0, currency1, IHooks(address(0)), 100, SQRT_PRICE_1_1, ZERO_BYTES); BalanceDelta swapDelta = swap(key, true, 0.001 ether, ZERO_BYTES); diff --git a/test/MiddlewareRemoveFactory.t.sol b/test/MiddlewareRemoveFactory.t.sol new file mode 100644 index 00000000..6440ebf0 --- /dev/null +++ b/test/MiddlewareRemoveFactory.t.sol @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {MiddlewareRemove} from "../contracts/middleware/MiddlewareRemove.sol"; +import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {HooksRevert} from "./middleware/HooksRevert.sol"; +import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; +import {MiddlewareRemoveFactory} from "./../contracts/middleware/MiddlewareRemoveFactory.sol"; +import {HookMiner} from "./utils/HookMiner.sol"; +import {HooksDoNothing} from "./middleware/HooksDoNothing.sol"; + +contract MiddlewareRemoveFactoryTest is Test, Deployers { + using PoolIdLibrary for PoolKey; + using StateLibrary for IPoolManager; + + uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; + + address constant TREASURY = address(0x1234567890123456789012345678901234567890); + uint128 private constant TOTAL_BIPS = 10000; + + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + PoolId id; + + MiddlewareRemoveFactory factory; + HooksDoNothing hooksDoNothing; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + + factory = new MiddlewareRemoveFactory(manager); + hooksDoNothing = new HooksDoNothing(manager, address(factory)); + } + + function testRevertOnIncorrectFlags() public { + uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareRemove).creationCode, + abi.encode(address(manager), address(hooksDoNothing)) + ); + address implementation = address(hooksDoNothing); + vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); + factory.createMiddleware(implementation, salt); + } + + function testRevertOnIncorrectFlagsMined() public { + address implementation = address(hooksDoNothing); + vm.expectRevert(); // HookAddressNotValid + factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); + } + + function testVariousFactory() public { + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_ADD_LIQUIDITY_FLAG + | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG + | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareRemove).creationCode, + abi.encode(address(manager), address(hooksDoNothing)) + ); + _testOn(address(hooksDoNothing), salt); + + HooksRevert hooksRevert = new HooksRevert(manager, address(factory)); + flags = uint160( + Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_SWAP_FLAG + | Hooks.AFTER_SWAP_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + (hookAddress, salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareRemove).creationCode, + abi.encode(address(manager), address(hooksRevert)) + ); + _testOn(address(hooksRevert), salt); + + HooksOutOfGas hooksOutOfGas = new HooksOutOfGas(manager, address(factory)); + (hookAddress, salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareRemove).creationCode, + abi.encode(address(manager), address(hooksOutOfGas)) + ); + _testOn(address(hooksOutOfGas), salt); + } + + // creates a middleware on an implementation + function _testOn(address implementation, bytes32 salt) internal { + address hookAddress = factory.createMiddleware(implementation, salt); + MiddlewareRemove middlewareRemove = MiddlewareRemove(payable(hookAddress)); + + (key, id) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES + ); + + // removeLiquidity does not fail + removeLiquidity(currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + assertEq(factory.getImplementation(hookAddress), implementation); + } +} diff --git a/test/middleware/HooksDoNothing.sol b/test/middleware/HooksDoNothing.sol new file mode 100644 index 00000000..63b84b62 --- /dev/null +++ b/test/middleware/HooksDoNothing.sol @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; + +contract HooksDoNothing is BaseImplementation { + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} + + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: true, + afterInitialize: true, + beforeAddLiquidity: true, + afterAddLiquidity: true, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: true, + beforeSwap: true, + afterSwap: true, + beforeDonate: true, + afterDonate: true, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) + external + pure + override + returns (bytes4) + { + return BaseHook.beforeInitialize.selector; + } + + function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + external + pure + override + returns (bytes4) + { + return BaseHook.afterInitialize.selector; + } + + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + pure + override + returns (bytes4) + { + return BaseHook.beforeAddLiquidity.selector; + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external pure override returns (bytes4) { + return BaseHook.beforeRemoveLiquidity.selector; + } + + function afterAddLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external pure override returns (bytes4, BalanceDelta) { + return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external pure override returns (bytes4, BalanceDelta) { + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + pure + override + returns (bytes4, BeforeSwapDelta, uint24) + { + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + pure + override + returns (bytes4, int128) + { + return (BaseHook.afterSwap.selector, 0); + } + + function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + pure + override + returns (bytes4) + { + return BaseHook.beforeDonate.selector; + } + + function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + pure + override + returns (bytes4) + { + return BaseHook.afterDonate.selector; + } +} diff --git a/test/middleware/HooksFrontrun.sol b/test/middleware/HooksFrontrun.sol index 6ecf000b..afee4223 100644 --- a/test/middleware/HooksFrontrun.sol +++ b/test/middleware/HooksFrontrun.sol @@ -6,10 +6,8 @@ import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {console} from "forge-std/console.sol"; import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; @@ -51,22 +49,16 @@ contract HooksFrontrun is BaseImplementation { returns (bytes4, BeforeSwapDelta, uint24) { swapParams = params; - console.log(params.zeroForOne); - console.logInt(params.amountSpecified); swapDelta = manager.swap(key, params, ZERO_BYTES); - console.log("beforeDelta"); - console.logInt(swapDelta.amount0()); - console.logInt(swapDelta.amount1()); return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } - function afterSwap( - address sender, - PoolKey calldata key, - IPoolManager.SwapParams calldata params, - BalanceDelta delta, - bytes calldata hookData - ) external override onlyByManager returns (bytes4, int128) { + function afterSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + override + onlyByManager + returns (bytes4, int128) + { BalanceDelta afterDelta = manager.swap( key, IPoolManager.SwapParams( @@ -89,9 +81,6 @@ contract HooksFrontrun is BaseImplementation { manager.mint(address(this), key.currency1.toId(), uint256(profit)); } } - console.log("afterDelta"); - console.logInt(afterDelta.amount0()); - console.logInt(afterDelta.amount1()); return (BaseHook.afterSwap.selector, 0); } } diff --git a/test/middleware/HooksOutOfGas.sol b/test/middleware/HooksOutOfGas.sol index d7b35a37..354b85bb 100644 --- a/test/middleware/HooksOutOfGas.sol +++ b/test/middleware/HooksOutOfGas.sol @@ -5,27 +5,21 @@ import {BaseHook} from "./../../contracts/BaseHook.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BalanceDelta, toBalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; -import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; -import {Owned} from "solmate/auth/Owned.sol"; -import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; -import {console} from "../../../lib/forge-std/src/console.sol"; +import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -contract HooksOutOfGas { - IPoolManager public immutable poolManager; +contract HooksOutOfGas is BaseImplementation { + uint256 public counter; - constructor(IPoolManager _poolManager) { - poolManager = _poolManager; - } + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} - function getHookPermissions() public pure returns (Hooks.Permissions memory) { + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { return Hooks.Permissions({ - beforeInitialize: true, - afterInitialize: true, - beforeAddLiquidity: true, - afterAddLiquidity: true, + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, beforeRemoveLiquidity: true, afterRemoveLiquidity: true, beforeSwap: true, @@ -39,43 +33,14 @@ contract HooksOutOfGas { }); } - function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { - consumeAllGas(); - } - - function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) - external - virtual - returns (bytes4) - { - consumeAllGas(); - } - - function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) - external - virtual - returns (bytes4) - { - consumeAllGas(); - } - function beforeRemoveLiquidity( address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata - ) external returns (bytes4) { - consumeAllGas(); - } - - function afterAddLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, - bytes calldata - ) external virtual returns (bytes4, BalanceDelta) { + ) external override returns (bytes4) { consumeAllGas(); + return BaseHook.beforeRemoveLiquidity.selector; } function afterRemoveLiquidity( @@ -84,45 +49,50 @@ contract HooksOutOfGas { IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, bytes calldata - ) external returns (bytes4, BalanceDelta) { + ) external override returns (bytes4, BalanceDelta) { consumeAllGas(); + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) external - virtual + override returns (bytes4, BeforeSwapDelta, uint24) { consumeAllGas(); + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) external - virtual + override returns (bytes4, int128) { consumeAllGas(); + return (BaseHook.afterSwap.selector, 0); } function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) external - virtual + override returns (bytes4) { consumeAllGas(); + return BaseHook.beforeDonate.selector; } function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) external - virtual + override returns (bytes4) { consumeAllGas(); + return BaseHook.afterDonate.selector; } - function consumeAllGas() internal view { + function consumeAllGas() internal { while (true) { - //console.log(gasleft()); + counter++; // This loop will run indefinitely and consume all available gas. } } diff --git a/test/middleware/HooksReturnDeltas.sol b/test/middleware/HooksReturnDeltas.sol new file mode 100644 index 00000000..df12fa4c --- /dev/null +++ b/test/middleware/HooksReturnDeltas.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; + +contract HooksReturnDeltas is BaseImplementation { + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} + + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: true, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: true, + afterSwapReturnDelta: true, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + pure + override + returns (bytes4, BeforeSwapDelta, uint24) + { + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + pure + override + returns (bytes4, int128) + { + return (BaseHook.afterSwap.selector, 0); + } +} diff --git a/test/middleware/HooksRevert.sol b/test/middleware/HooksRevert.sol index f3f31401..7386a561 100644 --- a/test/middleware/HooksRevert.sol +++ b/test/middleware/HooksRevert.sol @@ -5,29 +5,21 @@ import {BaseHook} from "./../../contracts/BaseHook.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; import {BalanceDelta, toBalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; -import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; -import {Owned} from "solmate/auth/Owned.sol"; -import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; -import {console} from "../../../lib/forge-std/src/console.sol"; -contract HooksRevert { - error HookNotImplemented(); +contract HooksRevert is BaseImplementation { + error AlwaysReverts(); - IPoolManager public immutable poolManager; + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} - constructor(IPoolManager _poolManager) { - poolManager = _poolManager; - } - - function getHookPermissions() public pure returns (Hooks.Permissions memory) { + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { return Hooks.Permissions({ - beforeInitialize: true, - afterInitialize: true, - beforeAddLiquidity: true, - afterAddLiquidity: true, + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, beforeRemoveLiquidity: true, afterRemoveLiquidity: true, beforeSwap: true, @@ -41,43 +33,14 @@ contract HooksRevert { }); } - function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { - revert HookNotImplemented(); - } - - function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) - external - virtual - returns (bytes4) - { - revert HookNotImplemented(); - } - - function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) - external - virtual - returns (bytes4) - { - revert HookNotImplemented(); - } - - function beforeRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - bytes calldata - ) external returns (bytes4) { - revert HookNotImplemented(); - } - function afterAddLiquidity( address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, bytes calldata - ) external virtual returns (bytes4, BalanceDelta) { - revert HookNotImplemented(); + ) external pure override returns (bytes4, BalanceDelta) { + revert AlwaysReverts(); } function afterRemoveLiquidity( @@ -86,40 +49,44 @@ contract HooksRevert { IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, bytes calldata - ) external returns (bytes4, BalanceDelta) { + ) external pure override returns (bytes4, BalanceDelta) { require(sender == address(0), "nobody can remove"); return (BaseHook.beforeRemoveLiquidity.selector, toBalanceDelta(0, 0)); } function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) external - virtual + pure + override returns (bytes4, BeforeSwapDelta, uint24) { - revert HookNotImplemented(); + revert AlwaysReverts(); } function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) external - virtual + pure + override returns (bytes4, int128) { - revert HookNotImplemented(); + revert AlwaysReverts(); } function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) external - virtual + pure + override returns (bytes4) { - revert HookNotImplemented(); + revert AlwaysReverts(); } function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) external - virtual + pure + override returns (bytes4) { - revert HookNotImplemented(); + revert AlwaysReverts(); } } From 27b8d9cadf536238a95c4b3f534fe6be5ee50d19 Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Mon, 8 Jul 2024 19:02:59 -0400 Subject: [PATCH 6/8] trim files in branch to only base --- contracts/middleware/MiddlewareProtect.sol | 138 ---------------- .../middleware/MiddlewareProtectFactory.sol | 15 -- contracts/middleware/MiddlewareRemove.sol | 42 ----- .../middleware/MiddlewareRemoveFactory.sol | 15 -- test/BaseMiddlewareFactory.t.sol | 63 ++++++++ test/MiddlewareProtectFactory.t.sol | 149 ------------------ test/MiddlewareRemoveFactory.t.sol | 127 --------------- .../HooksCounter.sol} | 2 +- test/middleware/HooksFrontrun.sol | 86 ---------- test/middleware/HooksOutOfGas.sol | 99 ------------ test/middleware/HooksReturnDeltas.sol | 51 ------ test/middleware/HooksRevert.sol | 92 ----------- 12 files changed, 64 insertions(+), 815 deletions(-) delete mode 100644 contracts/middleware/MiddlewareProtect.sol delete mode 100644 contracts/middleware/MiddlewareProtectFactory.sol delete mode 100644 contracts/middleware/MiddlewareRemove.sol delete mode 100644 contracts/middleware/MiddlewareRemoveFactory.sol create mode 100644 test/BaseMiddlewareFactory.t.sol delete mode 100644 test/MiddlewareProtectFactory.t.sol delete mode 100644 test/MiddlewareRemoveFactory.t.sol rename test/{middleware/HooksDoNothing.sol => middleware-implementations/HooksCounter.sol} (98%) delete mode 100644 test/middleware/HooksFrontrun.sol delete mode 100644 test/middleware/HooksOutOfGas.sol delete mode 100644 test/middleware/HooksReturnDeltas.sol delete mode 100644 test/middleware/HooksRevert.sol diff --git a/contracts/middleware/MiddlewareProtect.sol b/contracts/middleware/MiddlewareProtect.sol deleted file mode 100644 index be081b5b..00000000 --- a/contracts/middleware/MiddlewareProtect.sol +++ /dev/null @@ -1,138 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BaseMiddleware} from "./BaseMiddleware.sol"; -import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {console} from "../../lib/forge-std/src/console.sol"; -import {BaseHook} from "./../BaseHook.sol"; -import {IBaseHook} from "./../interfaces/IBaseHook.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {ReentrancyState} from "./../libraries/ReentrancyState.sol"; -import {LPFeeLibrary} from "@uniswap/v4-core/src/libraries/LPFeeLibrary.sol"; -import {CustomRevert} from "@uniswap/v4-core/src/libraries/CustomRevert.sol"; - -contract MiddlewareProtect is BaseMiddleware { - using CustomRevert for bytes4; - using LPFeeLibrary for uint24; - - /// @notice Thrown if the address will lead to forbidden flags being set - /// @param hooks The address of the hooks contract - error HookPermissionForbidden(address hooks); - error ForbiddenReturn(); - error InvalidFee(); - error ActionBetweenHook(); - - uint256 public constant gasLimit = 1000000; - - constructor(IPoolManager _poolManager, IBaseHook _implementation) BaseMiddleware(_poolManager, _implementation) { - Hooks.Permissions memory permissions = _implementation.getHookPermissions(); - // deny any hooks that return deltas - if ( - permissions.beforeSwapReturnDelta || permissions.afterSwapReturnDelta - || permissions.afterAddLiquidityReturnDelta || permissions.afterRemoveLiquidityReturnDelta - ) { - HookPermissionForbidden.selector.revertWith(address(this)); - } - } - - modifier swapNotLocked() { - if (ReentrancyState.swapLocked()) { - revert ActionBetweenHook(); - } - _; - } - - modifier removeNotLocked() { - if (ReentrancyState.removeLocked()) { - revert ActionBetweenHook(); - } - _; - } - - function beforeInitialize(address sender, PoolKey calldata key, uint160 sqrtPriceX96, bytes calldata hookData) - external - override - onlyByManager - returns (bytes4) - { - if (key.fee.isDynamicFee()) revert InvalidFee(); - if (msg.sender == address(implementation)) { - return BaseHook.beforeInitialize.selector; - } - return implementation.beforeInitialize(sender, key, sqrtPriceX96, hookData); - } - - // block swaps and removes - function beforeSwap( - address sender, - PoolKey calldata key, - IPoolManager.SwapParams calldata params, - bytes calldata hookData - ) external override swapNotLocked returns (bytes4, BeforeSwapDelta, uint24) { - if (msg.sender == address(implementation)) { - return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); - } - ReentrancyState.lockSwapRemove(); - console.log("beforeSwap middleware"); - (bytes4 selector, BeforeSwapDelta delta, uint24 fee) = implementation.beforeSwap(sender, key, params, hookData); - ReentrancyState.unlock(); - return (selector, delta, fee); - } - - // afterSwap - no protections - - // block swaps - function beforeAddLiquidity( - address sender, - PoolKey calldata key, - IPoolManager.ModifyLiquidityParams calldata params, - bytes calldata hookData - ) external override returns (bytes4) { - if (msg.sender == address(implementation)) { - return BaseHook.beforeAddLiquidity.selector; - } - ReentrancyState.lockSwap(); - console.log("beforeAddLiquidity middleware"); - bytes4 selector = implementation.beforeAddLiquidity(sender, key, params, hookData); - ReentrancyState.unlock(); - return selector; - } - - // afterAddLiquidity - no protections - - // block swaps and reverts - function beforeRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - bytes calldata - ) external override removeNotLocked returns (bytes4) { - if (msg.sender == address(implementation)) { - return BaseHook.beforeRemoveLiquidity.selector; - } - ReentrancyState.lockSwap(); - console.log("beforeRemoveLiquidity middleware"); - address(implementation).call{gas: gasLimit}(msg.data); - ReentrancyState.unlock(); - return BaseHook.beforeRemoveLiquidity.selector; - } - - // block reverts - function afterRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, - bytes calldata - ) external override returns (bytes4, BalanceDelta) { - if (msg.sender == address(implementation)) { - return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); - } - console.log("afterRemoveLiquidity middleware"); - address(implementation).call{gas: gasLimit}(msg.data); - return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); - } -} diff --git a/contracts/middleware/MiddlewareProtectFactory.sol b/contracts/middleware/MiddlewareProtectFactory.sol deleted file mode 100644 index ad86dfaa..00000000 --- a/contracts/middleware/MiddlewareProtectFactory.sol +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.19; - -import {MiddlewareProtect} from "./MiddlewareProtect.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {IBaseHook} from "../interfaces/IBaseHook.sol"; -import {BaseMiddlewareFactory} from "./BaseMiddlewareFactory.sol"; - -contract MiddlewareProtectFactory is BaseMiddlewareFactory { - constructor(IPoolManager _poolManager) BaseMiddlewareFactory(_poolManager) {} - - function _deployMiddleware(address implementation, bytes32 salt) internal override returns (address middleware) { - return address(new MiddlewareProtect{salt: salt}(poolManager, IBaseHook(implementation))); - } -} diff --git a/contracts/middleware/MiddlewareRemove.sol b/contracts/middleware/MiddlewareRemove.sol deleted file mode 100644 index 23c63784..00000000 --- a/contracts/middleware/MiddlewareRemove.sol +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BaseMiddleware} from "./BaseMiddleware.sol"; -import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {BaseHook} from "./../BaseHook.sol"; -import {IBaseHook} from "./../interfaces/IBaseHook.sol"; -import {console} from "../../lib/forge-std/src/console.sol"; - -contract MiddlewareRemove is BaseMiddleware { - uint256 public constant gasLimit = 1000000; - - constructor(IPoolManager _poolManager, IBaseHook _implementation) BaseMiddleware(_poolManager, _implementation) {} - - function beforeRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - bytes calldata - ) external override returns (bytes4) { - console.log("beforeRemoveLiquidity middleware"); - (bool success, bytes memory returnData) = address(implementation).call{gas: gasLimit}(msg.data); - console.log(success); - return BaseHook.beforeRemoveLiquidity.selector; - } - - function afterRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, - bytes calldata - ) external override returns (bytes4, BalanceDelta) { - console.log("afterRemoveLiquidity middleware"); - (bool success, bytes memory returnData) = address(implementation).call{gas: gasLimit}(msg.data); - console.log(success); - // hook cannot return delta - return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); - } -} diff --git a/contracts/middleware/MiddlewareRemoveFactory.sol b/contracts/middleware/MiddlewareRemoveFactory.sol deleted file mode 100644 index 5b45fb70..00000000 --- a/contracts/middleware/MiddlewareRemoveFactory.sol +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.19; - -import {MiddlewareRemove} from "./MiddlewareRemove.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {IBaseHook} from "../interfaces/IBaseHook.sol"; -import {BaseMiddlewareFactory} from "./BaseMiddlewareFactory.sol"; - -contract MiddlewareRemoveFactory is BaseMiddlewareFactory { - constructor(IPoolManager _poolManager) BaseMiddlewareFactory(_poolManager) {} - - function _deployMiddleware(address implementation, bytes32 salt) internal override returns (address middleware) { - return address(new MiddlewareRemove{salt: salt}(poolManager, IBaseHook(implementation))); - } -} diff --git a/test/BaseMiddlewareFactory.t.sol b/test/BaseMiddlewareFactory.t.sol new file mode 100644 index 00000000..bef9110b --- /dev/null +++ b/test/BaseMiddlewareFactory.t.sol @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {BaseMiddleware} from "../contracts/middleware/BaseMiddleware.sol"; +import {BaseMiddlewareFactory} from "./../contracts/middleware/BaseMiddlewareFactory.sol"; +import {HookMiner} from "./utils/HookMiner.sol"; +import {HooksCounter} from "./middleware-implementations/HooksCounter.sol"; + +contract BaseMiddlewareFactoryTest is Test, Deployers { + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + PoolId id; + + BaseMiddlewareFactory factory; + HooksCounter hooksCounter; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + factory = new BaseMiddlewareFactory(manager); + hooksCounter = new HooksCounter(manager, address(factory)); + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + } + + function testRevertOnIncorrectFlags() public { + uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter)) + ); + address implementation = address(hooksCounter); + vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); + factory.createMiddleware(implementation, salt); + } + + function testRevertOnIncorrectFlagsMined() public { + address implementation = address(hooksCounter); + vm.expectRevert(); // HookAddressNotValid + factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); + } +} diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol deleted file mode 100644 index fabf2600..00000000 --- a/test/MiddlewareProtectFactory.t.sol +++ /dev/null @@ -1,149 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.19; - -import {Test} from "forge-std/Test.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {HooksFrontrun} from "./middleware/HooksFrontrun.sol"; -import {MiddlewareProtect} from "../contracts/middleware/MiddlewareProtect.sol"; -import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; -import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; -import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; -import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; -import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; -import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; -import {console} from "../../../lib/forge-std/src/console.sol"; -import {HooksRevert} from "./middleware/HooksRevert.sol"; -import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; -import {MiddlewareProtectFactory} from "./../contracts/middleware/MiddlewareProtectFactory.sol"; -import {HookMiner} from "./utils/HookMiner.sol"; -import {HooksReturnDeltas} from "./middleware/HooksReturnDeltas.sol"; -import {HooksDoNothing} from "./middleware/HooksDoNothing.sol"; - -contract MiddlewareProtectFactoryTest is Test, Deployers { - using PoolIdLibrary for PoolKey; - using StateLibrary for IPoolManager; - - uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; - - address constant TREASURY = address(0x1234567890123456789012345678901234567890); - uint128 private constant TOTAL_BIPS = 10000; - - HookEnabledSwapRouter router; - TestERC20 token0; - TestERC20 token1; - PoolId id; - - MiddlewareProtectFactory factory; - HooksDoNothing hooksDoNothing; - HooksFrontrun hooksFrontrun; - - function setUp() public { - deployFreshManagerAndRouters(); - (currency0, currency1) = deployMintAndApprove2Currencies(); - - router = new HookEnabledSwapRouter(manager); - token0 = TestERC20(Currency.unwrap(currency0)); - token1 = TestERC20(Currency.unwrap(currency1)); - - factory = new MiddlewareProtectFactory(manager); - hooksDoNothing = new HooksDoNothing(manager, address(factory)); - - hooksFrontrun = HooksFrontrun(address(uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG))); - vm.record(); - HooksFrontrun impl = new HooksFrontrun(manager, address(factory)); - (, bytes32[] memory writes) = vm.accesses(address(impl)); - vm.etch(address(hooksFrontrun), address(impl).code); - unchecked { - for (uint256 i = 0; i < writes.length; i++) { - bytes32 slot = writes[i]; - vm.store(address(hooksFrontrun), slot, vm.load(address(impl), slot)); - } - } - - token0.approve(address(router), type(uint256).max); - token1.approve(address(router), type(uint256).max); - } - - function testRevertOnIncorrectFlags() public { - uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); - - (address hookAddress, bytes32 salt) = HookMiner.find( - address(factory), - flags, - type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(hooksDoNothing)) - ); - address implementation = address(hooksDoNothing); - vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); - factory.createMiddleware(implementation, salt); - } - - function testRevertOnIncorrectFlagsMined() public { - address implementation = address(hooksDoNothing); - vm.expectRevert(); // HookAddressNotValid - factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); - } - - function testRevertOnDeltas() public { - HooksReturnDeltas hooksReturnDeltas = new HooksReturnDeltas(manager, address(factory)); - uint160 flags = uint160(Hooks.AFTER_SWAP_FLAG | Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG); - - (address hookAddress, bytes32 salt) = HookMiner.find( - address(factory), - flags, - type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(hooksReturnDeltas)) - ); - address implementation = address(hooksReturnDeltas); - console.log(hookAddress); - vm.expectRevert(abi.encodePacked(bytes16(MiddlewareProtect.HookPermissionForbidden.selector), hookAddress)); - factory.createMiddleware(implementation, salt); - } - - function testFrontrun() public { - return; - (PoolKey memory key,) = - initPoolAndAddLiquidity(currency0, currency1, IHooks(address(0)), 100, SQRT_PRICE_1_1, ZERO_BYTES); - BalanceDelta swapDelta = swap(key, true, 0.001 ether, ZERO_BYTES); - - (key,) = initPoolAndAddLiquidity( - currency0, currency1, IHooks(address(hooksFrontrun)), 100, SQRT_PRICE_1_1, ZERO_BYTES - ); - BalanceDelta swapDelta2 = swap(key, true, 0.001 ether, ZERO_BYTES); - - // while both swaps are in the same pool, the second swap is more expensive - assertEq(swapDelta.amount1(), swapDelta2.amount1()); - assertTrue(abs(swapDelta.amount0()) < abs(swapDelta2.amount0())); - assertTrue(manager.balanceOf(address(hooksFrontrun), CurrencyLibrary.toId(key.currency0)) > 0); - } - - function testRevertOnFrontrun() public { - uint160 flags = uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG); - - (address hookAddress, bytes32 salt) = HookMiner.find( - address(factory), - flags, - type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(hooksFrontrun)) - ); - address implementation = address(hooksFrontrun); - address hookAddressCreated = factory.createMiddleware(implementation, salt); - assertEq(hookAddressCreated, hookAddress); - MiddlewareProtect middlewareProtect = MiddlewareProtect(payable(hookAddress)); - - (key,) = initPoolAndAddLiquidity( - currency0, currency1, IHooks(address(middlewareProtect)), 100, SQRT_PRICE_1_1, ZERO_BYTES - ); - vm.expectRevert(MiddlewareProtect.ActionBetweenHook.selector); - swap(key, true, 0.001 ether, ZERO_BYTES); - } - - function abs(int256 x) internal pure returns (uint256) { - return x >= 0 ? uint256(x) : uint256(-x); - } -} diff --git a/test/MiddlewareRemoveFactory.t.sol b/test/MiddlewareRemoveFactory.t.sol deleted file mode 100644 index 6440ebf0..00000000 --- a/test/MiddlewareRemoveFactory.t.sol +++ /dev/null @@ -1,127 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.19; - -import {Test} from "forge-std/Test.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {MiddlewareRemove} from "../contracts/middleware/MiddlewareRemove.sol"; -import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; -import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; -import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; -import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; -import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; -import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; -import {console} from "../../../lib/forge-std/src/console.sol"; -import {HooksRevert} from "./middleware/HooksRevert.sol"; -import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; -import {MiddlewareRemoveFactory} from "./../contracts/middleware/MiddlewareRemoveFactory.sol"; -import {HookMiner} from "./utils/HookMiner.sol"; -import {HooksDoNothing} from "./middleware/HooksDoNothing.sol"; - -contract MiddlewareRemoveFactoryTest is Test, Deployers { - using PoolIdLibrary for PoolKey; - using StateLibrary for IPoolManager; - - uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; - - address constant TREASURY = address(0x1234567890123456789012345678901234567890); - uint128 private constant TOTAL_BIPS = 10000; - - HookEnabledSwapRouter router; - TestERC20 token0; - TestERC20 token1; - PoolId id; - - MiddlewareRemoveFactory factory; - HooksDoNothing hooksDoNothing; - - function setUp() public { - deployFreshManagerAndRouters(); - (currency0, currency1) = deployMintAndApprove2Currencies(); - - router = new HookEnabledSwapRouter(manager); - token0 = TestERC20(Currency.unwrap(currency0)); - token1 = TestERC20(Currency.unwrap(currency1)); - - token0.approve(address(router), type(uint256).max); - token1.approve(address(router), type(uint256).max); - - factory = new MiddlewareRemoveFactory(manager); - hooksDoNothing = new HooksDoNothing(manager, address(factory)); - } - - function testRevertOnIncorrectFlags() public { - uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); - - (address hookAddress, bytes32 salt) = HookMiner.find( - address(factory), - flags, - type(MiddlewareRemove).creationCode, - abi.encode(address(manager), address(hooksDoNothing)) - ); - address implementation = address(hooksDoNothing); - vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); - factory.createMiddleware(implementation, salt); - } - - function testRevertOnIncorrectFlagsMined() public { - address implementation = address(hooksDoNothing); - vm.expectRevert(); // HookAddressNotValid - factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); - } - - function testVariousFactory() public { - uint160 flags = uint160( - Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_ADD_LIQUIDITY_FLAG - | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG - | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG - ); - (address hookAddress, bytes32 salt) = HookMiner.find( - address(factory), - flags, - type(MiddlewareRemove).creationCode, - abi.encode(address(manager), address(hooksDoNothing)) - ); - _testOn(address(hooksDoNothing), salt); - - HooksRevert hooksRevert = new HooksRevert(manager, address(factory)); - flags = uint160( - Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_SWAP_FLAG - | Hooks.AFTER_SWAP_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG - ); - (hookAddress, salt) = HookMiner.find( - address(factory), - flags, - type(MiddlewareRemove).creationCode, - abi.encode(address(manager), address(hooksRevert)) - ); - _testOn(address(hooksRevert), salt); - - HooksOutOfGas hooksOutOfGas = new HooksOutOfGas(manager, address(factory)); - (hookAddress, salt) = HookMiner.find( - address(factory), - flags, - type(MiddlewareRemove).creationCode, - abi.encode(address(manager), address(hooksOutOfGas)) - ); - _testOn(address(hooksOutOfGas), salt); - } - - // creates a middleware on an implementation - function _testOn(address implementation, bytes32 salt) internal { - address hookAddress = factory.createMiddleware(implementation, salt); - MiddlewareRemove middlewareRemove = MiddlewareRemove(payable(hookAddress)); - - (key, id) = initPoolAndAddLiquidity( - currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES - ); - - // removeLiquidity does not fail - removeLiquidity(currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES); - assertEq(factory.getImplementation(hookAddress), implementation); - } -} diff --git a/test/middleware/HooksDoNothing.sol b/test/middleware-implementations/HooksCounter.sol similarity index 98% rename from test/middleware/HooksDoNothing.sol rename to test/middleware-implementations/HooksCounter.sol index 63b84b62..fa6d5e5b 100644 --- a/test/middleware/HooksDoNothing.sol +++ b/test/middleware-implementations/HooksCounter.sol @@ -9,7 +9,7 @@ import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/type import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -contract HooksDoNothing is BaseImplementation { +contract HooksCounter is BaseImplementation { constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} function getHookPermissions() public pure override returns (Hooks.Permissions memory) { diff --git a/test/middleware/HooksFrontrun.sol b/test/middleware/HooksFrontrun.sol deleted file mode 100644 index afee4223..00000000 --- a/test/middleware/HooksFrontrun.sol +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {BaseHook} from "./../../contracts/BaseHook.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; -import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; -import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; -import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; - -contract HooksFrontrun is BaseImplementation { - using SafeCast for uint256; - - bytes internal constant ZERO_BYTES = bytes(""); - uint160 public constant MIN_PRICE_LIMIT = TickMath.MIN_SQRT_PRICE + 1; - uint160 public constant MAX_PRICE_LIMIT = TickMath.MAX_SQRT_PRICE - 1; - - BalanceDelta swapDelta; - IPoolManager.SwapParams swapParams; - - constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} - - function getHookPermissions() public pure virtual override returns (Hooks.Permissions memory) { - return Hooks.Permissions({ - beforeInitialize: false, - afterInitialize: false, - beforeAddLiquidity: false, - afterAddLiquidity: false, - beforeRemoveLiquidity: false, - afterRemoveLiquidity: false, - beforeSwap: true, - afterSwap: true, - beforeDonate: false, - afterDonate: false, - beforeSwapReturnDelta: false, - afterSwapReturnDelta: false, - afterAddLiquidityReturnDelta: false, - afterRemoveLiquidityReturnDelta: false - }); - } - - function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) - external - override - onlyByMiddleware - returns (bytes4, BeforeSwapDelta, uint24) - { - swapParams = params; - swapDelta = manager.swap(key, params, ZERO_BYTES); - return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); - } - - function afterSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) - external - override - onlyByManager - returns (bytes4, int128) - { - BalanceDelta afterDelta = manager.swap( - key, - IPoolManager.SwapParams( - !swapParams.zeroForOne, - -swapParams.amountSpecified, - swapParams.zeroForOne ? MAX_PRICE_LIMIT : MIN_PRICE_LIMIT - ), - ZERO_BYTES - ); - if (swapParams.zeroForOne) { - int256 profit = afterDelta.amount0() + swapDelta.amount0(); - if (profit > 0) { - // else hook reverts - manager.mint(address(this), key.currency0.toId(), uint256(profit)); - } - } else { - int256 profit = afterDelta.amount1() + swapDelta.amount1(); - if (profit > 0) { - // else hook reverts - manager.mint(address(this), key.currency1.toId(), uint256(profit)); - } - } - return (BaseHook.afterSwap.selector, 0); - } -} diff --git a/test/middleware/HooksOutOfGas.sol b/test/middleware/HooksOutOfGas.sol deleted file mode 100644 index 354b85bb..00000000 --- a/test/middleware/HooksOutOfGas.sol +++ /dev/null @@ -1,99 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {BaseHook} from "./../../contracts/BaseHook.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; -import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; - -contract HooksOutOfGas is BaseImplementation { - uint256 public counter; - - constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} - - function getHookPermissions() public pure override returns (Hooks.Permissions memory) { - return Hooks.Permissions({ - beforeInitialize: false, - afterInitialize: false, - beforeAddLiquidity: false, - afterAddLiquidity: false, - beforeRemoveLiquidity: true, - afterRemoveLiquidity: true, - beforeSwap: true, - afterSwap: true, - beforeDonate: true, - afterDonate: true, - beforeSwapReturnDelta: false, - afterSwapReturnDelta: false, - afterAddLiquidityReturnDelta: false, - afterRemoveLiquidityReturnDelta: false - }); - } - - function beforeRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - bytes calldata - ) external override returns (bytes4) { - consumeAllGas(); - return BaseHook.beforeRemoveLiquidity.selector; - } - - function afterRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, - bytes calldata - ) external override returns (bytes4, BalanceDelta) { - consumeAllGas(); - return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); - } - - function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) - external - override - returns (bytes4, BeforeSwapDelta, uint24) - { - consumeAllGas(); - return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); - } - - function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) - external - override - returns (bytes4, int128) - { - consumeAllGas(); - return (BaseHook.afterSwap.selector, 0); - } - - function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) - external - override - returns (bytes4) - { - consumeAllGas(); - return BaseHook.beforeDonate.selector; - } - - function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) - external - override - returns (bytes4) - { - consumeAllGas(); - return BaseHook.afterDonate.selector; - } - - function consumeAllGas() internal { - while (true) { - counter++; - // This loop will run indefinitely and consume all available gas. - } - } -} diff --git a/test/middleware/HooksReturnDeltas.sol b/test/middleware/HooksReturnDeltas.sol deleted file mode 100644 index df12fa4c..00000000 --- a/test/middleware/HooksReturnDeltas.sol +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {BaseHook} from "./../../contracts/BaseHook.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; -import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; - -contract HooksReturnDeltas is BaseImplementation { - constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} - - function getHookPermissions() public pure override returns (Hooks.Permissions memory) { - return Hooks.Permissions({ - beforeInitialize: false, - afterInitialize: false, - beforeAddLiquidity: false, - afterAddLiquidity: false, - beforeRemoveLiquidity: false, - afterRemoveLiquidity: false, - beforeSwap: true, - afterSwap: true, - beforeDonate: false, - afterDonate: false, - beforeSwapReturnDelta: true, - afterSwapReturnDelta: true, - afterAddLiquidityReturnDelta: false, - afterRemoveLiquidityReturnDelta: false - }); - } - - function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) - external - pure - override - returns (bytes4, BeforeSwapDelta, uint24) - { - return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); - } - - function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) - external - pure - override - returns (bytes4, int128) - { - return (BaseHook.afterSwap.selector, 0); - } -} diff --git a/test/middleware/HooksRevert.sol b/test/middleware/HooksRevert.sol deleted file mode 100644 index 7386a561..00000000 --- a/test/middleware/HooksRevert.sol +++ /dev/null @@ -1,92 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; - -import {BaseHook} from "./../../contracts/BaseHook.sol"; -import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; -import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; -import {BalanceDelta, toBalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; - -contract HooksRevert is BaseImplementation { - error AlwaysReverts(); - - constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} - - function getHookPermissions() public pure override returns (Hooks.Permissions memory) { - return Hooks.Permissions({ - beforeInitialize: false, - afterInitialize: false, - beforeAddLiquidity: false, - afterAddLiquidity: false, - beforeRemoveLiquidity: true, - afterRemoveLiquidity: true, - beforeSwap: true, - afterSwap: true, - beforeDonate: true, - afterDonate: true, - beforeSwapReturnDelta: false, - afterSwapReturnDelta: false, - afterAddLiquidityReturnDelta: false, - afterRemoveLiquidityReturnDelta: false - }); - } - - function afterAddLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, - bytes calldata - ) external pure override returns (bytes4, BalanceDelta) { - revert AlwaysReverts(); - } - - function afterRemoveLiquidity( - address sender, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, - bytes calldata - ) external pure override returns (bytes4, BalanceDelta) { - require(sender == address(0), "nobody can remove"); - return (BaseHook.beforeRemoveLiquidity.selector, toBalanceDelta(0, 0)); - } - - function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) - external - pure - override - returns (bytes4, BeforeSwapDelta, uint24) - { - revert AlwaysReverts(); - } - - function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) - external - pure - override - returns (bytes4, int128) - { - revert AlwaysReverts(); - } - - function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) - external - pure - override - returns (bytes4) - { - revert AlwaysReverts(); - } - - function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) - external - pure - override - returns (bytes4) - { - revert AlwaysReverts(); - } -} From 5bc319e7562490f5ae460735b50ed23f4300eb69 Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Mon, 8 Jul 2024 19:05:29 -0400 Subject: [PATCH 7/8] remove ReentrancyState --- contracts/libraries/ReentrancyState.sol | 39 ------------------------- 1 file changed, 39 deletions(-) delete mode 100644 contracts/libraries/ReentrancyState.sol diff --git a/contracts/libraries/ReentrancyState.sol b/contracts/libraries/ReentrancyState.sol deleted file mode 100644 index 966d1269..00000000 --- a/contracts/libraries/ReentrancyState.sol +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.20; - -library ReentrancyState { - // bytes32(uint256(keccak256("ReentrancyState")) - 1) - bytes32 constant REENTRANCY_STATE_SLOT = 0xbedc9a60a226d4ae7b727cbc828f66c94c4eead57777428ceab2f04b0efca3a5; - - function unlock() internal { - assembly { - tstore(REENTRANCY_STATE_SLOT, 0) - } - } - - function lockSwap() internal { - assembly { - tstore(REENTRANCY_STATE_SLOT, 1) - } - } - - function lockSwapRemove() internal { - assembly { - tstore(REENTRANCY_STATE_SLOT, 2) - } - } - - function read() internal view returns (uint256 state) { - assembly { - state := tload(REENTRANCY_STATE_SLOT) - } - } - - function swapLocked() internal view returns (bool) { - return read() == 1 || read() == 2; - } - - function removeLocked() internal view returns (bool) { - return read() == 2; - } -} From 1c250a1d20a6682af549176441e988d91ceb6e58 Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:37:58 -0400 Subject: [PATCH 8/8] last things --- contracts/middleware/BaseImplementation.sol | 5 + contracts/middleware/BaseMiddleware.sol | 11 +- test/BaseMiddlewareFactory.t.sol | 70 ++++++++++- .../HooksCounter.sol | 112 ++++++++++++------ 4 files changed, 156 insertions(+), 42 deletions(-) diff --git a/contracts/middleware/BaseImplementation.sol b/contracts/middleware/BaseImplementation.sol index 1dfee1a8..3f5b034f 100644 --- a/contracts/middleware/BaseImplementation.sol +++ b/contracts/middleware/BaseImplementation.sol @@ -47,6 +47,11 @@ abstract contract BaseImplementation is IHooks, SafeCallback { middleware = _middleware; } + function updateDynamicFee(PoolKey calldata key, uint24 fee) external { + if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); + manager.updateDynamicLPFee(key, fee); + } + function getHookPermissions() public pure virtual returns (Hooks.Permissions memory); function _unlockCallback(bytes calldata data) internal virtual override returns (bytes memory) { diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol index 8074aace..1dc064f4 100644 --- a/contracts/middleware/BaseMiddleware.sol +++ b/contracts/middleware/BaseMiddleware.sol @@ -11,6 +11,9 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; contract BaseMiddleware is IHooks { + using Hooks for BaseMiddleware; + using BeforeSwapDeltaLibrary for BeforeSwapDelta; + error NotManager(); IPoolManager public immutable manager; @@ -102,11 +105,15 @@ contract BaseMiddleware is IHooks { PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata hookData - ) external virtual onlyByManager returns (bytes4, BeforeSwapDelta, uint24) { + ) external virtual onlyByManager returns (bytes4 selector, BeforeSwapDelta beforeSwapDelta, uint24 lpFeeOverride) { if (msg.sender == address(implementation)) { return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } - return implementation.beforeSwap(sender, key, params, hookData); + (selector, beforeSwapDelta, lpFeeOverride) = implementation.beforeSwap(sender, key, params, hookData); + if (this.hasPermission(Hooks.BEFORE_SWAP_RETURNS_DELTA_FLAG)) { + manager.take(key.currency0, sender, uint256(uint128(beforeSwapDelta.getSpecifiedDelta()))); + manager.take(key.currency1, sender, uint256(uint128(beforeSwapDelta.getUnspecifiedDelta()))); + } } function afterSwap( diff --git a/test/BaseMiddlewareFactory.t.sol b/test/BaseMiddlewareFactory.t.sol index bef9110b..3652f84c 100644 --- a/test/BaseMiddlewareFactory.t.sol +++ b/test/BaseMiddlewareFactory.t.sol @@ -16,16 +16,18 @@ import {BaseMiddleware} from "../contracts/middleware/BaseMiddleware.sol"; import {BaseMiddlewareFactory} from "./../contracts/middleware/BaseMiddlewareFactory.sol"; import {HookMiner} from "./utils/HookMiner.sol"; import {HooksCounter} from "./middleware-implementations/HooksCounter.sol"; +import {BaseImplementation} from "./../contracts/middleware/BaseImplementation.sol"; contract BaseMiddlewareFactoryTest is Test, Deployers { HookEnabledSwapRouter router; TestERC20 token0; TestERC20 token1; - PoolId id; BaseMiddlewareFactory factory; HooksCounter hooksCounter; + address middleware; + function setUp() public { deployFreshManagerAndRouters(); (currency0, currency1) = deployMintAndApprove2Currencies(); @@ -39,25 +41,85 @@ contract BaseMiddlewareFactoryTest is Test, Deployers { token0.approve(address(router), type(uint256).max); token1.approve(address(router), type(uint256).max); + + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter)) + ); + middleware = factory.createMiddleware(address(hooksCounter), salt); + assertEq(hookAddress, middleware); + } + + function testRevertOnAlreadyInitialized() public { + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter)) + ); + vm.expectRevert(BaseMiddlewareFactory.AlreadyInitialized.selector); + factory.createMiddleware(address(hooksCounter), salt); } function testRevertOnIncorrectFlags() public { + HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory)); uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); (address hookAddress, bytes32 salt) = HookMiner.find( address(factory), flags, type(BaseMiddleware).creationCode, - abi.encode(address(manager), address(hooksCounter)) + abi.encode(address(manager), address(hooksCounter2)) ); - address implementation = address(hooksCounter); + address implementation = address(hooksCounter2); vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); factory.createMiddleware(implementation, salt); } function testRevertOnIncorrectFlagsMined() public { - address implementation = address(hooksCounter); + HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory)); + address implementation = address(hooksCounter2); vm.expectRevert(); // HookAddressNotValid factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); } + + function testRevertOnIncorrectCaller() public { + vm.expectRevert(BaseImplementation.NotMiddleware.selector); + hooksCounter.afterDonate(address(this), key, 0, 0, ZERO_BYTES); + } + + function testCounters() public { + (PoolKey memory key, PoolId id) = + initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + + assertEq(hooksCounter.beforeInitializeCount(id), 1); + assertEq(hooksCounter.afterInitializeCount(id), 1); + assertEq(hooksCounter.beforeSwapCount(id), 0); + assertEq(hooksCounter.afterSwapCount(id), 0); + assertEq(hooksCounter.beforeAddLiquidityCount(id), 1); + assertEq(hooksCounter.afterAddLiquidityCount(id), 1); + assertEq(hooksCounter.beforeRemoveLiquidityCount(id), 0); + assertEq(hooksCounter.afterRemoveLiquidityCount(id), 0); + assertEq(hooksCounter.beforeDonateCount(id), 0); + assertEq(hooksCounter.afterDonateCount(id), 0); + + assertEq(hooksCounter.lastHookData(), ZERO_BYTES); + swap(key, true, 1, bytes("hi")); + assertEq(hooksCounter.lastHookData(), bytes("hi")); + assertEq(hooksCounter.beforeSwapCount(id), 1); + assertEq(hooksCounter.afterSwapCount(id), 1); + } } diff --git a/test/middleware-implementations/HooksCounter.sol b/test/middleware-implementations/HooksCounter.sol index fa6d5e5b..49aeb50a 100644 --- a/test/middleware-implementations/HooksCounter.sol +++ b/test/middleware-implementations/HooksCounter.sol @@ -8,8 +8,27 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; contract HooksCounter is BaseImplementation { + using PoolIdLibrary for PoolKey; + + mapping(PoolId => uint256) public beforeInitializeCount; + mapping(PoolId => uint256) public afterInitializeCount; + + mapping(PoolId => uint256) public beforeSwapCount; + mapping(PoolId => uint256) public afterSwapCount; + + mapping(PoolId => uint256) public beforeAddLiquidityCount; + mapping(PoolId => uint256) public afterAddLiquidityCount; + mapping(PoolId => uint256) public beforeRemoveLiquidityCount; + mapping(PoolId => uint256) public afterRemoveLiquidityCount; + + mapping(PoolId => uint256) public beforeDonateCount; + mapping(PoolId => uint256) public afterDonateCount; + + bytes public lastHookData; + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} function getHookPermissions() public pure override returns (Hooks.Permissions memory) { @@ -31,95 +50,116 @@ contract HooksCounter is BaseImplementation { }); } - function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) + function beforeInitialize(address, PoolKey calldata key, uint160, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + beforeInitializeCount[key.toId()]++; + lastHookData = hookData; return BaseHook.beforeInitialize.selector; } - function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + function afterInitialize(address, PoolKey calldata key, uint160, int24, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + afterInitializeCount[key.toId()]++; + lastHookData = hookData; return BaseHook.afterInitialize.selector; } - function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) - external - pure - override - returns (bytes4) - { - return BaseHook.beforeAddLiquidity.selector; - } - - function beforeRemoveLiquidity( + function beforeAddLiquidity( address, - PoolKey calldata, + PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata, - bytes calldata - ) external pure override returns (bytes4) { - return BaseHook.beforeRemoveLiquidity.selector; + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4) { + beforeAddLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeAddLiquidity.selector; } function afterAddLiquidity( address, - PoolKey calldata, + PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, - bytes calldata - ) external pure override returns (bytes4, BalanceDelta) { + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, BalanceDelta) { + afterAddLiquidityCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } + function beforeRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4) { + beforeRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeRemoveLiquidity.selector; + } + function afterRemoveLiquidity( address, - PoolKey calldata, + PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, - bytes calldata - ) external pure override returns (bytes4, BalanceDelta) { + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, BalanceDelta) { + afterRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } - function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4, BeforeSwapDelta, uint24) { + beforeSwapCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } - function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) - external - pure - override - returns (bytes4, int128) - { + function afterSwap( + address, + PoolKey calldata key, + IPoolManager.SwapParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, int128) { + afterSwapCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.afterSwap.selector, 0); } - function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + function beforeDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + beforeDonateCount[key.toId()]++; + lastHookData = hookData; return BaseHook.beforeDonate.selector; } - function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + function afterDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + afterDonateCount[key.toId()]++; + lastHookData = hookData; return BaseHook.afterDonate.selector; } }