diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol index ebe5c609..e6f38031 100644 --- a/contracts/middleware/BaseMiddleware.sol +++ b/contracts/middleware/BaseMiddleware.sol @@ -11,11 +11,11 @@ import {Proxy} from "@openzeppelin/contracts/proxy/Proxy.sol"; contract BaseMiddleware is Proxy { /// @notice The address of the pool manager - IPoolManager public immutable poolManager; + IPoolManager public immutable manager; address public immutable implementation; - constructor(IPoolManager _poolManager, address _impl) { - poolManager = _poolManager; + constructor(IPoolManager _manager, address _impl) { + manager = _manager; implementation = _impl; } diff --git a/contracts/middleware/BaseMiddlewareFactory.sol b/contracts/middleware/BaseMiddlewareFactory.sol index 4b06fd6d..d9385034 100644 --- a/contracts/middleware/BaseMiddlewareFactory.sol +++ b/contracts/middleware/BaseMiddlewareFactory.sol @@ -11,10 +11,10 @@ import {IBaseHook} from "../interfaces/IBaseHook.sol"; contract BaseMiddlewareFactory is IMiddlewareFactory { mapping(address => address) private _implementations; - IPoolManager public immutable poolManager; + IPoolManager public immutable manager; - constructor(IPoolManager _poolManager) { - poolManager = _poolManager; + constructor(IPoolManager _manager) { + manager = _manager; } function getImplementation(address middleware) external view override returns (address implementation) { @@ -29,6 +29,6 @@ contract BaseMiddlewareFactory is IMiddlewareFactory { } function _deployMiddleware(address implementation, bytes32 salt) internal virtual returns (address middleware) { - return address(new BaseMiddleware{salt: salt}(poolManager, implementation)); + return address(new BaseMiddleware{salt: salt}(manager, implementation)); } } diff --git a/contracts/middleware/MiddlewareRemove.sol b/contracts/middleware/MiddlewareRemove.sol index d5e0f8e2..7333d2b4 100644 --- a/contracts/middleware/MiddlewareRemove.sol +++ b/contracts/middleware/MiddlewareRemove.sol @@ -11,17 +11,26 @@ import {BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; import {CustomRevert} from "@uniswap/v4-core/src/libraries/CustomRevert.sol"; +import {NonZeroDeltaCount} from "@uniswap/v4-core/src/libraries/NonZeroDeltaCount.sol"; +import {IExttload} from "@uniswap/v4-core/src/interfaces/IExttload.sol"; contract MiddlewareRemove is BaseMiddleware { using CustomRevert for bytes4; using Hooks for IHooks; error HookPermissionForbidden(address hooks); + error HookModifiedDeltas(); + + struct afterRemoveLiquidityParams { + address sender; + PoolKey key; + IPoolManager.ModifyLiquidityParams params; + } bytes internal constant ZERO_BYTES = bytes(""); uint256 public constant gasLimit = 10000000; - constructor(IPoolManager _poolManager, address _impl) BaseMiddleware(_poolManager, _impl) { + constructor(IPoolManager _manager, address _impl) BaseMiddleware(_manager, _impl) { if (IHooks(address(this)).hasPermission(Hooks.AFTER_REMOVE_LIQUIDITY_RETURNS_DELTA_FLAG)) { HookPermissionForbidden.selector.revertWith(address(this)); } @@ -38,14 +47,24 @@ contract MiddlewareRemove is BaseMiddleware { } function afterRemoveLiquidity( - address, - PoolKey calldata, - IPoolManager.ModifyLiquidityParams calldata, - BalanceDelta, - bytes calldata + address sender, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + BalanceDelta delta, + bytes calldata hookData ) external returns (bytes4, BalanceDelta) { - implementation.delegatecall{gas: gasLimit}(msg.data); - // hook cannot return delta + address(this).delegatecall(abi.encodeWithSelector(this._callAndEnsureZeroDeltas.selector, msg.data)); return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } + + function _callAndEnsureZeroDeltas(bytes calldata data) external { + bytes32 slot = bytes32(NonZeroDeltaCount.NONZERO_DELTA_COUNT_SLOT); + uint256 countBefore = uint256(IExttload(address(manager)).exttload(slot)); + address(implementation).delegatecall(msg.data); + uint256 countAfter = uint256(IExttload(address(manager)).exttload(slot)); + if (countAfter > countBefore) { + // purpousely revert to cause the whole hook to reset + revert HookModifiedDeltas(); + } + } } diff --git a/contracts/middleware/MiddlewareRemoveFactory.sol b/contracts/middleware/MiddlewareRemoveFactory.sol index 20b507bb..5dbca5d0 100644 --- a/contracts/middleware/MiddlewareRemoveFactory.sol +++ b/contracts/middleware/MiddlewareRemoveFactory.sol @@ -6,9 +6,9 @@ import {MiddlewareRemove} from "./MiddlewareRemove.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; contract MiddlewareRemoveFactory is BaseMiddlewareFactory { - constructor(IPoolManager _poolManager) BaseMiddlewareFactory(_poolManager) {} + constructor(IPoolManager _manager) BaseMiddlewareFactory(_manager) {} function _deployMiddleware(address implementation, bytes32 salt) internal override returns (address middleware) { - return address(new MiddlewareRemove{salt: salt}(poolManager, implementation)); + return address(new MiddlewareRemove{salt: salt}(manager, implementation)); } } diff --git a/test/MiddlewareRemoveFactory.t.sol b/test/MiddlewareRemoveFactory.t.sol index 93ea4540..e27042bc 100644 --- a/test/MiddlewareRemoveFactory.t.sol +++ b/test/MiddlewareRemoveFactory.t.sol @@ -61,24 +61,28 @@ contract MiddlewareRemoveFactoryTest is Test, Deployers { function testFeeOnRemove() public { uint160 flags = uint160(Hooks.AFTER_REMOVE_LIQUIDITY_FLAG); - FeeOnRemove feeOnRemove = FeeOnRemove(address(uint160(Hooks.AFTER_REMOVE_LIQUIDITY_FLAG))); + FeeOnRemove feeOnRemove = FeeOnRemove(address(flags)); FeeOnRemove impl = new FeeOnRemove(manager); vm.etch(address(feeOnRemove), address(impl).code); - (address hookAddress, bytes32 salt) = HookMiner.find( + (, bytes32 salt) = HookMiner.find( address(factory), flags, type(MiddlewareRemove).creationCode, abi.encode(address(manager), address(feeOnRemove)) ); - factory.createMiddleware(address(feeOnRemove), salt); + middleware = factory.createMiddleware(address(feeOnRemove), salt); initPoolAndAddLiquidity(currency0, currency1, IHooks(feeOnRemove), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + vm.expectRevert(IPoolManager.CurrencyNotSettled.selector); + removeLiquidity(currency0, currency1, IHooks(feeOnRemove), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + + IHooks noHooks = IHooks(address(0)); + initPoolAndAddLiquidity(currency0, currency1, noHooks, 3000, SQRT_PRICE_1_1, ZERO_BYTES); uint256 initialBalance0 = token0.balanceOf(address(this)); uint256 initialBalance1 = token1.balanceOf(address(this)); - removeLiquidity(currency0, currency1, IHooks(feeOnRemove), 3000, SQRT_PRICE_1_1, ZERO_BYTES); - uint256 outWithFees0 = token0.balanceOf(address(this)) - initialBalance0; - uint256 outWithFees1 = token1.balanceOf(address(this)) - initialBalance1; - console.log(outWithFees0, outWithFees1); + removeLiquidity(currency0, currency1, noHooks, 3000, SQRT_PRICE_1_1, ZERO_BYTES); + uint256 outNormal0 = token0.balanceOf(address(this)) - initialBalance0; + uint256 outNormal1 = token1.balanceOf(address(this)) - initialBalance1; initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); initialBalance0 = token0.balanceOf(address(this)); @@ -86,7 +90,10 @@ contract MiddlewareRemoveFactoryTest is Test, Deployers { removeLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); uint256 out0 = token0.balanceOf(address(this)) - initialBalance0; uint256 out1 = token1.balanceOf(address(this)) - initialBalance1; - console.log(out0, out1); + + // no fees taken + assertEq(outNormal0, out0); + assertEq(outNormal1, out1); } function testVariousFactory() public { @@ -131,6 +138,15 @@ contract MiddlewareRemoveFactoryTest is Test, Deployers { assertEq(factory.getImplementation(hookAddress), implementation); } + function testRevertOnDeltaFlags() public { + uint160 flags = uint160(Hooks.AFTER_REMOVE_LIQUIDITY_RETURNS_DELTA_FLAG); + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), flags, type(MiddlewareRemove).creationCode, abi.encode(address(manager), address(counter)) + ); + vm.expectRevert(MiddlewareRemove.HookPermissionForbidden.selector); + factory.createMiddleware(address(counter), salt); + } + // from BaseMiddlewareFactory.t.sol function testRevertOnSameDeployment() public { uint160 flags = uint160( @@ -138,7 +154,7 @@ contract MiddlewareRemoveFactoryTest is Test, Deployers { | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG ); - (address hookAddress, bytes32 salt) = HookMiner.find( + (, bytes32 salt) = HookMiner.find( address(factory), flags, type(MiddlewareRemove).creationCode, abi.encode(address(manager), address(counter)) ); factory.createMiddleware(address(counter), salt); diff --git a/test/middleware/FeeOnRemove.sol b/test/middleware/FeeOnRemove.sol index bcb8fb1f..f5c9603b 100644 --- a/test/middleware/FeeOnRemove.sol +++ b/test/middleware/FeeOnRemove.sol @@ -38,14 +38,12 @@ contract FeeOnRemove is BaseHook { } function afterRemoveLiquidity( - address, /* sender **/ + address, PoolKey calldata key, - IPoolManager.ModifyLiquidityParams calldata, /* params **/ + IPoolManager.ModifyLiquidityParams calldata, BalanceDelta delta, - bytes calldata /* hookData **/ + bytes calldata ) external override onlyByManager returns (bytes4, BalanceDelta) { - assert(delta.amount0() >= 0 && delta.amount1() >= 0); - uint128 feeAmount0 = uint128(delta.amount0()) * LIQUIDITY_FEE / TOTAL_BIPS; uint128 feeAmount1 = uint128(delta.amount1()) * LIQUIDITY_FEE / TOTAL_BIPS;