diff --git a/.forge-snapshots/FullRangeAddInitialLiquidity.snap b/.forge-snapshots/FullRangeAddInitialLiquidity.snap index cd1e3c37..404cf12a 100644 --- a/.forge-snapshots/FullRangeAddInitialLiquidity.snap +++ b/.forge-snapshots/FullRangeAddInitialLiquidity.snap @@ -1 +1 @@ -311073 \ No newline at end of file +311181 \ No newline at end of file diff --git a/.forge-snapshots/FullRangeAddLiquidity.snap b/.forge-snapshots/FullRangeAddLiquidity.snap index b3de2b4e..a4a14676 100644 --- a/.forge-snapshots/FullRangeAddLiquidity.snap +++ b/.forge-snapshots/FullRangeAddLiquidity.snap @@ -1 +1 @@ -122882 \ No newline at end of file +122990 \ No newline at end of file diff --git a/.forge-snapshots/FullRangeFirstSwap.snap b/.forge-snapshots/FullRangeFirstSwap.snap index 54d0d097..da120795 100644 --- a/.forge-snapshots/FullRangeFirstSwap.snap +++ b/.forge-snapshots/FullRangeFirstSwap.snap @@ -1 +1 @@ -80283 \ No newline at end of file +80220 \ No newline at end of file diff --git a/.forge-snapshots/FullRangeInitialize.snap b/.forge-snapshots/FullRangeInitialize.snap index f81651f8..7a0170eb 100644 --- a/.forge-snapshots/FullRangeInitialize.snap +++ b/.forge-snapshots/FullRangeInitialize.snap @@ -1 +1 @@ -1015169 \ No newline at end of file +1015181 \ No newline at end of file diff --git a/.forge-snapshots/FullRangeRemoveLiquidity.snap b/.forge-snapshots/FullRangeRemoveLiquidity.snap index 265c1dec..feea4936 100644 --- a/.forge-snapshots/FullRangeRemoveLiquidity.snap +++ b/.forge-snapshots/FullRangeRemoveLiquidity.snap @@ -1 +1 @@ -110476 \ No newline at end of file +110566 \ No newline at end of file diff --git a/.forge-snapshots/FullRangeRemoveLiquidityAndRebalance.snap b/.forge-snapshots/FullRangeRemoveLiquidityAndRebalance.snap index dcb62527..e0df7eb7 100644 --- a/.forge-snapshots/FullRangeRemoveLiquidityAndRebalance.snap +++ b/.forge-snapshots/FullRangeRemoveLiquidityAndRebalance.snap @@ -1 +1 @@ -239954 \ No newline at end of file +240044 \ No newline at end of file diff --git a/.forge-snapshots/FullRangeSecondSwap.snap b/.forge-snapshots/FullRangeSecondSwap.snap index 1bf183ef..e68df8d3 100644 --- a/.forge-snapshots/FullRangeSecondSwap.snap +++ b/.forge-snapshots/FullRangeSecondSwap.snap @@ -1 +1 @@ -45993 \ No newline at end of file +45930 \ No newline at end of file diff --git a/.forge-snapshots/FullRangeSwap.snap b/.forge-snapshots/FullRangeSwap.snap index 5630ac05..b50d0ea2 100644 --- a/.forge-snapshots/FullRangeSwap.snap +++ b/.forge-snapshots/FullRangeSwap.snap @@ -1 +1 @@ -79414 \ No newline at end of file +79351 \ No newline at end of file diff --git a/.forge-snapshots/TWAMMSubmitOrder.snap b/.forge-snapshots/TWAMMSubmitOrder.snap index b2759d7f..eb3b0f6b 100644 --- a/.forge-snapshots/TWAMMSubmitOrder.snap +++ b/.forge-snapshots/TWAMMSubmitOrder.snap @@ -1 +1 @@ -122355 \ No newline at end of file +122336 \ No newline at end of file diff --git a/contracts/BaseHook.sol b/contracts/BaseHook.sol index eb75502c..01fc4954 100644 --- a/contracts/BaseHook.sol +++ b/contracts/BaseHook.sol @@ -7,28 +7,19 @@ import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {SafeCallback} from "./base/SafeCallback.sol"; +import {ImmutableState} from "./base/ImmutableState.sol"; -abstract contract BaseHook is IHooks { - error NotPoolManager(); +abstract contract BaseHook is IHooks, SafeCallback { error NotSelf(); error InvalidPool(); error LockFailure(); error HookNotImplemented(); - /// @notice The address of the pool manager - IPoolManager public immutable poolManager; - - constructor(IPoolManager _poolManager) { - poolManager = _poolManager; + constructor(IPoolManager _manager) ImmutableState(_manager) { validateHookAddress(this); } - /// @dev Only the pool manager may call this function - modifier poolManagerOnly() { - if (msg.sender != address(poolManager)) revert NotPoolManager(); - _; - } - /// @dev Only this address may call this function modifier selfOnly() { if (msg.sender != address(this)) revert NotSelf(); @@ -50,7 +41,7 @@ abstract contract BaseHook is IHooks { Hooks.validateHookPermissions(_this, getHookPermissions()); } - function unlockCallback(bytes calldata data) external virtual poolManagerOnly returns (bytes memory) { + function _unlockCallback(bytes calldata data) internal virtual override returns (bytes memory) { (bool success, bytes memory returnData) = address(this).call(data); if (success) return returnData; if (returnData.length == 0) revert LockFailure(); diff --git a/contracts/base/ImmutableState.sol b/contracts/base/ImmutableState.sol new file mode 100644 index 00000000..cce37514 --- /dev/null +++ b/contracts/base/ImmutableState.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.19; + +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; + +contract ImmutableState { + IPoolManager public immutable manager; + + constructor(IPoolManager _manager) { + manager = _manager; + } +} diff --git a/contracts/base/SafeCallback.sol b/contracts/base/SafeCallback.sol new file mode 100644 index 00000000..f985e67c --- /dev/null +++ b/contracts/base/SafeCallback.sol @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {ImmutableState} from "./ImmutableState.sol"; + +abstract contract SafeCallback is ImmutableState, IUnlockCallback { + error NotManager(); + + modifier onlyByManager() { + if (msg.sender != address(manager)) revert NotManager(); + _; + } + + /// @dev We force the onlyByManager modifier by exposing a virtual function after the onlyByManager check. + function unlockCallback(bytes calldata data) external onlyByManager returns (bytes memory) { + return _unlockCallback(data); + } + + function _unlockCallback(bytes calldata data) internal virtual returns (bytes memory); +} diff --git a/contracts/hooks/examples/FeeTaker.sol b/contracts/hooks/examples/FeeTaker.sol new file mode 100644 index 00000000..91d1fb7d --- /dev/null +++ b/contracts/hooks/examples/FeeTaker.sol @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "../../BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; + +abstract contract FeeTaker is BaseHook { + using SafeCast for uint256; + + bytes internal constant ZERO_BYTES = bytes(""); + + constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + + /** + * @notice This hook takes a fee from the unspecified token after a swap. + * @dev This can be overridden if more permissions are needed. + */ + function getHookPermissions() public pure virtual override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: false, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: true, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function afterSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external override onlyByManager returns (bytes4, int128) { + //(Currency currencyUnspecified, amountUnspecified) = key.getUnspecified(params); + + // fee will be in the unspecified token of the swap + bool currency0Specified = (params.amountSpecified < 0 == params.zeroForOne); + (Currency currencyUnspecified, int128 amountUnspecified) = + (currency0Specified) ? (key.currency1, delta.amount1()) : (key.currency0, delta.amount0()); + // if exactOutput swap, get the absolute output amount + if (amountUnspecified < 0) amountUnspecified = -amountUnspecified; + + uint256 feeAmount = _feeAmount(amountUnspecified); + // mint ERC6909 instead of take to avoid edge case where PM doesn't have enough balance + manager.mint(address(this), CurrencyLibrary.toId(currencyUnspecified), feeAmount); + + (bytes4 selector, int128 amount) = _afterSwap(sender, key, params, delta, hookData); + return (selector, feeAmount.toInt128() + amount); + } + + function withdraw(Currency[] calldata currencies) external { + manager.unlock(abi.encode(currencies)); + } + + function _unlockCallback(bytes calldata rawData) internal override returns (bytes memory) { + Currency[] memory currencies = abi.decode(rawData, (Currency[])); + uint256 length = currencies.length; + for (uint256 i = 0; i < length;) { + uint256 amount = manager.balanceOf(address(this), CurrencyLibrary.toId(currencies[i])); + manager.burn(address(this), CurrencyLibrary.toId(currencies[i]), amount); + manager.take(currencies[i], _recipient(), amount); + unchecked { + ++i; + } + } + return ZERO_BYTES; + } + + /** + * @dev This is a virtual function that should be overridden so it returns the fee charged for a given amount. + */ + function _feeAmount(int128 amountUnspecified) internal view virtual returns (uint256); + + /** + * @dev This is a virtual function that should be overridden so it returns the address to receive the fee. + */ + function _recipient() internal view virtual returns (address); + + /** + * @dev This can be overridden to add logic after a swap. + */ + function _afterSwap(address, PoolKey memory, IPoolManager.SwapParams memory, BalanceDelta, bytes calldata) + internal + virtual + returns (bytes4, int128) + { + return (BaseHook.afterSwap.selector, 0); + } +} diff --git a/contracts/hooks/examples/FeeTaking.sol b/contracts/hooks/examples/FeeTaking.sol new file mode 100644 index 00000000..c0099149 --- /dev/null +++ b/contracts/hooks/examples/FeeTaking.sol @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {FeeTaker} from "./FeeTaker.sol"; + +contract FeeTaking is FeeTaker, Owned { + using SafeCast for uint256; + + uint128 private constant TOTAL_BIPS = 10000; + uint128 public immutable swapFeeBips; + address public treasury; + + constructor(IPoolManager _poolManager, uint128 _swapFeeBips, address _owner, address _treasury) + FeeTaker(_poolManager) + Owned(_owner) + { + swapFeeBips = _swapFeeBips; + treasury = _treasury; + } + + function setTreasury(address _treasury) external onlyOwner { + treasury = _treasury; + } + + function _feeAmount(int128 amountUnspecified) internal view override returns (uint256) { + return uint128(amountUnspecified) * swapFeeBips / TOTAL_BIPS; + } + + function _recipient() internal view override returns (address) { + return treasury; + } +} diff --git a/contracts/hooks/examples/FullRange.sol b/contracts/hooks/examples/FullRange.sol index 194be803..191593b8 100644 --- a/contracts/hooks/examples/FullRange.sol +++ b/contracts/hooks/examples/FullRange.sol @@ -26,7 +26,7 @@ import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/type import "../../libraries/LiquidityAmounts.sol"; -contract FullRange is BaseHook, IUnlockCallback { +contract FullRange is BaseHook { using CurrencyLibrary for Currency; using CurrencySettler for Currency; using PoolIdLibrary for PoolKey; @@ -85,7 +85,7 @@ contract FullRange is BaseHook, IUnlockCallback { mapping(PoolId => PoolInfo) public poolInfo; - constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + constructor(IPoolManager _manager) BaseHook(_manager) {} modifier ensure(uint256 deadline) { if (deadline < block.timestamp) revert ExpiredPastDeadline(); @@ -126,13 +126,13 @@ contract FullRange is BaseHook, IUnlockCallback { PoolId poolId = key.toId(); - (uint160 sqrtPriceX96,,,) = poolManager.getSlot0(poolId); + (uint160 sqrtPriceX96,,,) = manager.getSlot0(poolId); if (sqrtPriceX96 == 0) revert PoolNotInitialized(); PoolInfo storage pool = poolInfo[poolId]; - uint128 poolLiquidity = poolManager.getLiquidity(poolId); + uint128 poolLiquidity = manager.getLiquidity(poolId); liquidity = LiquidityAmounts.getLiquidityForAmounts( sqrtPriceX96, @@ -184,7 +184,7 @@ contract FullRange is BaseHook, IUnlockCallback { PoolId poolId = key.toId(); - (uint160 sqrtPriceX96,,,) = poolManager.getSlot0(poolId); + (uint160 sqrtPriceX96,,,) = manager.getSlot0(poolId); if (sqrtPriceX96 == 0) revert PoolNotInitialized(); @@ -260,17 +260,17 @@ contract FullRange is BaseHook, IUnlockCallback { internal returns (BalanceDelta delta) { - delta = abi.decode(poolManager.unlock(abi.encode(CallbackData(msg.sender, key, params))), (BalanceDelta)); + delta = abi.decode(manager.unlock(abi.encode(CallbackData(msg.sender, key, params))), (BalanceDelta)); } function _settleDeltas(address sender, PoolKey memory key, BalanceDelta delta) internal { - key.currency0.settle(poolManager, sender, uint256(int256(-delta.amount0())), false); - key.currency1.settle(poolManager, sender, uint256(int256(-delta.amount1())), false); + key.currency0.settle(manager, sender, uint256(int256(-delta.amount0())), false); + key.currency1.settle(manager, sender, uint256(int256(-delta.amount1())), false); } function _takeDeltas(address sender, PoolKey memory key, BalanceDelta delta) internal { - poolManager.take(key.currency0, sender, uint256(uint128(delta.amount0()))); - poolManager.take(key.currency1, sender, uint256(uint128(delta.amount1()))); + manager.take(key.currency0, sender, uint256(uint128(delta.amount0()))); + manager.take(key.currency1, sender, uint256(uint128(delta.amount1()))); } function _removeLiquidity(PoolKey memory key, IPoolManager.ModifyLiquidityParams memory params) @@ -286,21 +286,16 @@ contract FullRange is BaseHook, IUnlockCallback { uint256 liquidityToRemove = FullMath.mulDiv( uint256(-params.liquidityDelta), - poolManager.getLiquidity(poolId), + manager.getLiquidity(poolId), UniswapV4ERC20(pool.liquidityToken).totalSupply() ); params.liquidityDelta = -(liquidityToRemove.toInt256()); - (delta,) = poolManager.modifyLiquidity(key, params, ZERO_BYTES); + (delta,) = manager.modifyLiquidity(key, params, ZERO_BYTES); pool.hasAccruedFees = false; } - function unlockCallback(bytes calldata rawData) - external - override(IUnlockCallback, BaseHook) - poolManagerOnly - returns (bytes memory) - { + function _unlockCallback(bytes calldata rawData) internal override returns (bytes memory) { CallbackData memory data = abi.decode(rawData, (CallbackData)); BalanceDelta delta; @@ -308,7 +303,7 @@ contract FullRange is BaseHook, IUnlockCallback { delta = _removeLiquidity(data.key, data.params); _takeDeltas(data.sender, data.key, delta); } else { - (delta,) = poolManager.modifyLiquidity(data.key, data.params, ZERO_BYTES); + (delta,) = manager.modifyLiquidity(data.key, data.params, ZERO_BYTES); _settleDeltas(data.sender, data.key, delta); } return abi.encode(delta); @@ -316,12 +311,12 @@ contract FullRange is BaseHook, IUnlockCallback { function _rebalance(PoolKey memory key) public { PoolId poolId = key.toId(); - (BalanceDelta balanceDelta,) = poolManager.modifyLiquidity( + (BalanceDelta balanceDelta,) = manager.modifyLiquidity( key, IPoolManager.ModifyLiquidityParams({ tickLower: MIN_TICK, tickUpper: MAX_TICK, - liquidityDelta: -(poolManager.getLiquidity(poolId).toInt256()), + liquidityDelta: -(manager.getLiquidity(poolId).toInt256()), salt: 0 }), ZERO_BYTES @@ -333,9 +328,9 @@ contract FullRange is BaseHook, IUnlockCallback { ) * FixedPointMathLib.sqrt(FixedPoint96.Q96) ).toUint160(); - (uint160 sqrtPriceX96,,,) = poolManager.getSlot0(poolId); + (uint160 sqrtPriceX96,,,) = manager.getSlot0(poolId); - poolManager.swap( + manager.swap( key, IPoolManager.SwapParams({ zeroForOne: newSqrtPriceX96 < sqrtPriceX96, @@ -353,7 +348,7 @@ contract FullRange is BaseHook, IUnlockCallback { uint256(uint128(balanceDelta.amount1())) ); - (BalanceDelta balanceDeltaAfter,) = poolManager.modifyLiquidity( + (BalanceDelta balanceDeltaAfter,) = manager.modifyLiquidity( key, IPoolManager.ModifyLiquidityParams({ tickLower: MIN_TICK, @@ -368,6 +363,6 @@ contract FullRange is BaseHook, IUnlockCallback { uint128 donateAmount0 = uint128(balanceDelta.amount0() + balanceDeltaAfter.amount0()); uint128 donateAmount1 = uint128(balanceDelta.amount1() + balanceDeltaAfter.amount1()); - poolManager.donate(key, donateAmount0, donateAmount1, ZERO_BYTES); + manager.donate(key, donateAmount0, donateAmount1, ZERO_BYTES); } } diff --git a/contracts/hooks/examples/GeomeanOracle.sol b/contracts/hooks/examples/GeomeanOracle.sol index ec8301a5..df5a9ad1 100644 --- a/contracts/hooks/examples/GeomeanOracle.sol +++ b/contracts/hooks/examples/GeomeanOracle.sol @@ -61,7 +61,7 @@ contract GeomeanOracle is BaseHook { return uint32(block.timestamp); } - constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + constructor(IPoolManager _manager) BaseHook(_manager) {} function getHookPermissions() public pure override returns (Hooks.Permissions memory) { return Hooks.Permissions({ @@ -86,20 +86,20 @@ contract GeomeanOracle is BaseHook { external view override - poolManagerOnly + onlyByManager returns (bytes4) { // This is to limit the fragmentation of pools using this oracle hook. In other words, // there may only be one pool per pair of tokens that use this hook. The tick spacing is set to the maximum // because we only allow max range liquidity in this pool. - if (key.fee != 0 || key.tickSpacing != poolManager.MAX_TICK_SPACING()) revert OnlyOneOraclePoolAllowed(); + if (key.fee != 0 || key.tickSpacing != manager.MAX_TICK_SPACING()) revert OnlyOneOraclePoolAllowed(); return GeomeanOracle.beforeInitialize.selector; } function afterInitialize(address, PoolKey calldata key, uint160, int24, bytes calldata) external override - poolManagerOnly + onlyByManager returns (bytes4) { PoolId id = key.toId(); @@ -110,9 +110,9 @@ contract GeomeanOracle is BaseHook { /// @dev Called before any action that potentially modifies pool price or liquidity, such as swap or modify position function _updatePool(PoolKey calldata key) private { PoolId id = key.toId(); - (, int24 tick,,) = poolManager.getSlot0(id); + (, int24 tick,,) = manager.getSlot0(id); - uint128 liquidity = poolManager.getLiquidity(id); + uint128 liquidity = manager.getLiquidity(id); (states[id].index, states[id].cardinality) = observations[id].write( states[id].index, _blockTimestamp(), tick, liquidity, states[id].cardinality, states[id].cardinalityNext @@ -124,8 +124,8 @@ contract GeomeanOracle is BaseHook { PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata params, bytes calldata - ) external override poolManagerOnly returns (bytes4) { - int24 maxTickSpacing = poolManager.MAX_TICK_SPACING(); + ) external override onlyByManager returns (bytes4) { + int24 maxTickSpacing = manager.MAX_TICK_SPACING(); if ( params.tickLower != TickMath.minUsableTick(maxTickSpacing) || params.tickUpper != TickMath.maxUsableTick(maxTickSpacing) @@ -139,14 +139,14 @@ contract GeomeanOracle is BaseHook { PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata - ) external view override poolManagerOnly returns (bytes4) { + ) external view override onlyByManager returns (bytes4) { revert OraclePoolMustLockLiquidity(); } function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata) external override - poolManagerOnly + onlyByManager returns (bytes4, BeforeSwapDelta, uint24) { _updatePool(key); @@ -163,9 +163,9 @@ contract GeomeanOracle is BaseHook { ObservationState memory state = states[id]; - (, int24 tick,,) = poolManager.getSlot0(id); + (, int24 tick,,) = manager.getSlot0(id); - uint128 liquidity = poolManager.getLiquidity(id); + uint128 liquidity = manager.getLiquidity(id); return observations[id].observe(_blockTimestamp(), secondsAgos, tick, state.index, liquidity, state.cardinality); } diff --git a/contracts/hooks/examples/LimitOrder.sol b/contracts/hooks/examples/LimitOrder.sol index 9ee7a33c..2a8ca909 100644 --- a/contracts/hooks/examples/LimitOrder.sol +++ b/contracts/hooks/examples/LimitOrder.sol @@ -75,7 +75,7 @@ contract LimitOrder is BaseHook { mapping(bytes32 => Epoch) public epochs; mapping(Epoch => EpochInfo) public epochInfos; - constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + constructor(IPoolManager _manager) BaseHook(_manager) {} function getHookPermissions() public pure override returns (Hooks.Permissions memory) { return Hooks.Permissions({ @@ -117,7 +117,7 @@ contract LimitOrder is BaseHook { } function getTick(PoolId poolId) private view returns (int24 tick) { - (, tick,,) = poolManager.getSlot0(poolId); + (, tick,,) = manager.getSlot0(poolId); } function getTickLower(int24 tick, int24 tickSpacing) private pure returns (int24) { @@ -129,7 +129,7 @@ contract LimitOrder is BaseHook { function afterInitialize(address, PoolKey calldata key, uint160, int24 tick, bytes calldata) external override - poolManagerOnly + onlyByManager returns (bytes4) { setTickLowerLast(key.toId(), getTickLower(tick, key.tickSpacing)); @@ -142,7 +142,7 @@ contract LimitOrder is BaseHook { IPoolManager.SwapParams calldata params, BalanceDelta, bytes calldata - ) external override poolManagerOnly returns (bytes4, int128) { + ) external override onlyByManager returns (bytes4, int128) { (int24 tickLower, int24 lower, int24 upper) = _getCrossedTicks(key.toId(), key.tickSpacing); if (lower > upper) return (LimitOrder.afterSwap.selector, 0); @@ -197,10 +197,10 @@ contract LimitOrder is BaseHook { function _unlockCallbackFill(PoolKey calldata key, int24 tickLower, int256 liquidityDelta) private - poolManagerOnly + onlyByManager returns (uint128 amount0, uint128 amount1) { - (BalanceDelta delta,) = poolManager.modifyLiquidity( + (BalanceDelta delta,) = manager.modifyLiquidity( key, IPoolManager.ModifyLiquidityParams({ tickLower: tickLower, @@ -212,10 +212,10 @@ contract LimitOrder is BaseHook { ); if (delta.amount0() > 0) { - poolManager.mint(address(this), key.currency0.toId(), amount0 = uint128(delta.amount0())); + manager.mint(address(this), key.currency0.toId(), amount0 = uint128(delta.amount0())); } if (delta.amount1() > 0) { - poolManager.mint(address(this), key.currency1.toId(), amount1 = uint128(delta.amount1())); + manager.mint(address(this), key.currency1.toId(), amount1 = uint128(delta.amount1())); } } @@ -225,7 +225,7 @@ contract LimitOrder is BaseHook { { if (liquidity == 0) revert ZeroLiquidity(); - poolManager.unlock( + manager.unlock( abi.encodeCall( this.unlockCallbackPlace, (key, tickLower, zeroForOne, int256(uint256(liquidity)), msg.sender) ) @@ -263,7 +263,7 @@ contract LimitOrder is BaseHook { int256 liquidityDelta, address owner ) external selfOnly { - (BalanceDelta delta,) = poolManager.modifyLiquidity( + (BalanceDelta delta,) = manager.modifyLiquidity( key, IPoolManager.ModifyLiquidityParams({ tickLower: tickLower, @@ -277,11 +277,11 @@ contract LimitOrder is BaseHook { if (delta.amount0() < 0) { if (delta.amount1() != 0) revert InRange(); if (!zeroForOne) revert CrossedRange(); - key.currency0.settle(poolManager, owner, uint256(uint128(-delta.amount0())), false); + key.currency0.settle(manager, owner, uint256(uint128(-delta.amount0())), false); } else { if (delta.amount0() != 0) revert InRange(); if (zeroForOne) revert CrossedRange(); - key.currency1.settle(poolManager, owner, uint256(uint128(-delta.amount1())), false); + key.currency1.settle(manager, owner, uint256(uint128(-delta.amount1())), false); } } @@ -298,7 +298,7 @@ contract LimitOrder is BaseHook { uint256 amount0Fee; uint256 amount1Fee; (amount0Fee, amount1Fee) = abi.decode( - poolManager.unlock( + manager.unlock( abi.encodeCall( this.unlockCallbackKill, (key, tickLower, -int256(uint256(liquidity)), to, liquidity == epochInfo.liquidityTotal) @@ -329,7 +329,7 @@ contract LimitOrder is BaseHook { // could be unfairly diluted by a user sychronously placing then killing a limit order to skim off fees. // to prevent this, we allocate all fee revenue to remaining limit order placers, unless this is the last order. if (!removingAllLiquidity) { - (, BalanceDelta deltaFee) = poolManager.modifyLiquidity( + (, BalanceDelta deltaFee) = manager.modifyLiquidity( key, IPoolManager.ModifyLiquidityParams({ tickLower: tickLower, @@ -341,14 +341,14 @@ contract LimitOrder is BaseHook { ); if (deltaFee.amount0() > 0) { - poolManager.mint(address(this), key.currency0.toId(), amount0Fee = uint128(deltaFee.amount0())); + manager.mint(address(this), key.currency0.toId(), amount0Fee = uint128(deltaFee.amount0())); } if (deltaFee.amount1() > 0) { - poolManager.mint(address(this), key.currency1.toId(), amount1Fee = uint128(deltaFee.amount1())); + manager.mint(address(this), key.currency1.toId(), amount1Fee = uint128(deltaFee.amount1())); } } - (BalanceDelta delta,) = poolManager.modifyLiquidity( + (BalanceDelta delta,) = manager.modifyLiquidity( key, IPoolManager.ModifyLiquidityParams({ tickLower: tickLower, @@ -360,10 +360,10 @@ contract LimitOrder is BaseHook { ); if (delta.amount0() > 0) { - key.currency0.take(poolManager, to, uint256(uint128(delta.amount0())), false); + key.currency0.take(manager, to, uint256(uint128(delta.amount0())), false); } if (delta.amount1() > 0) { - key.currency1.take(poolManager, to, uint256(uint128(delta.amount1())), false); + key.currency1.take(manager, to, uint256(uint128(delta.amount1())), false); } } @@ -385,7 +385,7 @@ contract LimitOrder is BaseHook { epochInfo.token1Total -= amount1; epochInfo.liquidityTotal = liquidityTotal - liquidity; - poolManager.unlock( + manager.unlock( abi.encodeCall( this.unlockCallbackWithdraw, (epochInfo.currency0, epochInfo.currency1, amount0, amount1, to) ) @@ -402,17 +402,17 @@ contract LimitOrder is BaseHook { address to ) external selfOnly { if (token0Amount > 0) { - poolManager.burn(address(this), currency0.toId(), token0Amount); - poolManager.take(currency0, to, token0Amount); + manager.burn(address(this), currency0.toId(), token0Amount); + manager.take(currency0, to, token0Amount); } if (token1Amount > 0) { - poolManager.burn(address(this), currency1.toId(), token1Amount); - poolManager.take(currency1, to, token1Amount); + manager.burn(address(this), currency1.toId(), token1Amount); + manager.take(currency1, to, token1Amount); } } function onERC1155Received(address, address, uint256, uint256, bytes calldata) external view returns (bytes4) { - if (msg.sender != address(poolManager)) revert NotPoolManagerToken(); + if (msg.sender != address(manager)) revert NotPoolManagerToken(); return IERC1155Receiver.onERC1155Received.selector; } } diff --git a/contracts/hooks/examples/TWAMM.sol b/contracts/hooks/examples/TWAMM.sol index 5704d765..dc1f3b00 100644 --- a/contracts/hooks/examples/TWAMM.sol +++ b/contracts/hooks/examples/TWAMM.sol @@ -61,7 +61,7 @@ contract TWAMM is BaseHook, ITWAMM { // tokensOwed[token][owner] => amountOwed mapping(Currency => mapping(address => uint256)) public tokensOwed; - constructor(IPoolManager _poolManager, uint256 _expirationInterval) BaseHook(_poolManager) { + constructor(IPoolManager _manager, uint256 _expirationInterval) BaseHook(_manager) { expirationInterval = _expirationInterval; } @@ -88,7 +88,7 @@ contract TWAMM is BaseHook, ITWAMM { external virtual override - poolManagerOnly + onlyByManager returns (bytes4) { // one-time initialization enforced in PoolManager @@ -101,7 +101,7 @@ contract TWAMM is BaseHook, ITWAMM { PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata, bytes calldata - ) external override poolManagerOnly returns (bytes4) { + ) external override onlyByManager returns (bytes4) { executeTWAMMOrders(key); return BaseHook.beforeAddLiquidity.selector; } @@ -109,7 +109,7 @@ contract TWAMM is BaseHook, ITWAMM { function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata) external override - poolManagerOnly + onlyByManager returns (bytes4, BeforeSwapDelta, uint24) { executeTWAMMOrders(key); @@ -143,17 +143,14 @@ contract TWAMM is BaseHook, ITWAMM { /// @inheritdoc ITWAMM function executeTWAMMOrders(PoolKey memory key) public { PoolId poolId = key.toId(); - (uint160 sqrtPriceX96,,,) = poolManager.getSlot0(poolId); + (uint160 sqrtPriceX96,,,) = manager.getSlot0(poolId); State storage twamm = twammStates[poolId]; - (bool zeroForOne, uint160 sqrtPriceLimitX96) = _executeTWAMMOrders( - twamm, poolManager, key, PoolParamsOnExecute(sqrtPriceX96, poolManager.getLiquidity(poolId)) - ); + (bool zeroForOne, uint160 sqrtPriceLimitX96) = + _executeTWAMMOrders(twamm, manager, key, PoolParamsOnExecute(sqrtPriceX96, manager.getLiquidity(poolId))); if (sqrtPriceLimitX96 != 0 && sqrtPriceLimitX96 != sqrtPriceX96) { - poolManager.unlock( - abi.encode(key, IPoolManager.SwapParams(zeroForOne, type(int256).max, sqrtPriceLimitX96)) - ); + manager.unlock(abi.encode(key, IPoolManager.SwapParams(zeroForOne, type(int256).max, sqrtPriceLimitX96))); } } @@ -309,25 +306,25 @@ contract TWAMM is BaseHook, ITWAMM { IERC20Minimal(Currency.unwrap(token)).safeTransfer(to, amountTransferred); } - function unlockCallback(bytes calldata rawData) external override poolManagerOnly returns (bytes memory) { + function _unlockCallback(bytes calldata rawData) internal override returns (bytes memory) { (PoolKey memory key, IPoolManager.SwapParams memory swapParams) = abi.decode(rawData, (PoolKey, IPoolManager.SwapParams)); - BalanceDelta delta = poolManager.swap(key, swapParams, ZERO_BYTES); + BalanceDelta delta = manager.swap(key, swapParams, ZERO_BYTES); if (swapParams.zeroForOne) { if (delta.amount0() < 0) { - key.currency0.settle(poolManager, address(this), uint256(uint128(-delta.amount0())), false); + key.currency0.settle(manager, address(this), uint256(uint128(-delta.amount0())), false); } if (delta.amount1() > 0) { - key.currency1.take(poolManager, address(this), uint256(uint128(delta.amount1())), false); + key.currency1.take(manager, address(this), uint256(uint128(delta.amount1())), false); } } else { if (delta.amount1() < 0) { - key.currency1.settle(poolManager, address(this), uint256(uint128(-delta.amount1())), false); + key.currency1.settle(manager, address(this), uint256(uint128(-delta.amount1())), false); } if (delta.amount0() > 0) { - key.currency0.take(poolManager, address(this), uint256(uint128(delta.amount0())), false); + key.currency0.take(manager, address(this), uint256(uint128(delta.amount0())), false); } } return bytes(""); @@ -346,7 +343,7 @@ contract TWAMM is BaseHook, ITWAMM { /// @param pool The relevant state of the pool function _executeTWAMMOrders( State storage self, - IPoolManager poolManager, + IPoolManager manager, PoolKey memory key, PoolParamsOnExecute memory pool ) internal returns (bool zeroForOne, uint160 newSqrtPriceX96) { @@ -371,7 +368,7 @@ contract TWAMM is BaseHook, ITWAMM { if (orderPool0For1.sellRateCurrent != 0 && orderPool1For0.sellRateCurrent != 0) { pool = _advanceToNewTimestamp( self, - poolManager, + manager, key, AdvanceParams( expirationInterval, @@ -383,7 +380,7 @@ contract TWAMM is BaseHook, ITWAMM { } else { pool = _advanceTimestampForSinglePoolSell( self, - poolManager, + manager, key, AdvanceSingleParams( expirationInterval, @@ -405,14 +402,14 @@ contract TWAMM is BaseHook, ITWAMM { if (orderPool0For1.sellRateCurrent != 0 && orderPool1For0.sellRateCurrent != 0) { pool = _advanceToNewTimestamp( self, - poolManager, + manager, key, AdvanceParams(expirationInterval, block.timestamp, block.timestamp - prevTimestamp, pool) ); } else { pool = _advanceTimestampForSinglePoolSell( self, - poolManager, + manager, key, AdvanceSingleParams( expirationInterval, @@ -440,7 +437,7 @@ contract TWAMM is BaseHook, ITWAMM { function _advanceToNewTimestamp( State storage self, - IPoolManager poolManager, + IPoolManager manager, PoolKey memory poolKey, AdvanceParams memory params ) private returns (PoolParamsOnExecute memory) { @@ -462,13 +459,13 @@ contract TWAMM is BaseHook, ITWAMM { finalSqrtPriceX96 = TwammMath.getNewSqrtPriceX96(executionParams); (bool crossingInitializedTick, int24 tick) = - _isCrossingInitializedTick(params.pool, poolManager, poolKey, finalSqrtPriceX96); + _isCrossingInitializedTick(params.pool, manager, poolKey, finalSqrtPriceX96); unchecked { if (crossingInitializedTick) { uint256 secondsUntilCrossingX96; (params.pool, secondsUntilCrossingX96) = _advanceTimeThroughTickCrossing( self, - poolManager, + manager, poolKey, TickCrossingParams(tick, params.nextTimestamp, secondsElapsedX96, params.pool) ); @@ -503,7 +500,7 @@ contract TWAMM is BaseHook, ITWAMM { function _advanceTimestampForSinglePoolSell( State storage self, - IPoolManager poolManager, + IPoolManager manager, PoolKey memory poolKey, AdvanceSingleParams memory params ) private returns (PoolParamsOnExecute memory) { @@ -518,10 +515,10 @@ contract TWAMM is BaseHook, ITWAMM { ); (bool crossingInitializedTick, int24 tick) = - _isCrossingInitializedTick(params.pool, poolManager, poolKey, finalSqrtPriceX96); + _isCrossingInitializedTick(params.pool, manager, poolKey, finalSqrtPriceX96); if (crossingInitializedTick) { - (, int128 liquidityNetAtTick) = poolManager.getTickLiquidity(poolKey.toId(), tick); + (, int128 liquidityNetAtTick) = manager.getTickLiquidity(poolKey.toId(), tick); uint160 initializedSqrtPrice = TickMath.getSqrtPriceAtTick(tick); uint256 swapDelta0 = SqrtPriceMath.getAmount0Delta( @@ -575,7 +572,7 @@ contract TWAMM is BaseHook, ITWAMM { function _advanceTimeThroughTickCrossing( State storage self, - IPoolManager poolManager, + IPoolManager manager, PoolKey memory poolKey, TickCrossingParams memory params ) private returns (PoolParamsOnExecute memory, uint256) { @@ -605,7 +602,7 @@ contract TWAMM is BaseHook, ITWAMM { unchecked { // update pool - (, int128 liquidityNet) = poolManager.getTickLiquidity(poolKey.toId(), params.initializedTick); + (, int128 liquidityNet) = manager.getTickLiquidity(poolKey.toId(), params.initializedTick); if (initializedSqrtPrice < params.pool.sqrtPriceX96) liquidityNet = -liquidityNet; params.pool.liquidity = liquidityNet < 0 ? params.pool.liquidity - uint128(-liquidityNet) @@ -618,7 +615,7 @@ contract TWAMM is BaseHook, ITWAMM { function _isCrossingInitializedTick( PoolParamsOnExecute memory pool, - IPoolManager poolManager, + IPoolManager manager, PoolKey memory poolKey, uint160 nextSqrtPriceX96 ) internal view returns (bool crossingInitializedTick, int24 nextTickInit) { @@ -634,7 +631,7 @@ contract TWAMM is BaseHook, ITWAMM { unchecked { if (searchingLeft) nextTickInit -= 1; } - (nextTickInit, crossingInitializedTick) = poolManager.getNextInitializedTickWithinOneWord( + (nextTickInit, crossingInitializedTick) = manager.getNextInitializedTickWithinOneWord( poolKey.toId(), nextTickInit, poolKey.tickSpacing, searchingLeft ); nextTickInitFurtherThanTarget = searchingLeft ? nextTickInit <= targetTick : nextTickInit > targetTick; diff --git a/contracts/hooks/examples/VolatilityOracle.sol b/contracts/hooks/examples/VolatilityOracle.sol index ede61bf5..2900632f 100644 --- a/contracts/hooks/examples/VolatilityOracle.sol +++ b/contracts/hooks/examples/VolatilityOracle.sol @@ -12,14 +12,14 @@ contract VolatilityOracle is BaseHook { error MustUseDynamicFee(); - uint32 deployTimestamp; + uint32 immutable deployTimestamp; /// @dev For mocking function _blockTimestamp() internal view virtual returns (uint32) { return uint32(block.timestamp); } - constructor(IPoolManager _poolManager) BaseHook(_poolManager) { + constructor(IPoolManager _manager) BaseHook(_manager) { deployTimestamp = _blockTimestamp(); } @@ -56,7 +56,7 @@ contract VolatilityOracle is BaseHook { uint24 startingFee = 3000; uint32 lapsed = _blockTimestamp() - deployTimestamp; uint24 fee = startingFee + (uint24(lapsed) * 100) / 60; // 100 bps a minute - poolManager.updateDynamicLPFee(key, fee); // initial fee 0.30% + manager.updateDynamicLPFee(key, fee); // initial fee 0.30% } function afterInitialize(address, PoolKey calldata key, uint160, int24, bytes calldata) diff --git a/contracts/interfaces/IMiddlewareFactory.sol b/contracts/interfaces/IMiddlewareFactory.sol new file mode 100644 index 00000000..4af554fe --- /dev/null +++ b/contracts/interfaces/IMiddlewareFactory.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +interface IMiddlewareFactory { + event MiddlewareCreated(address implementation, address middleware); + + /// @notice Returns the implementation address for a given middleware + /// @param middleware The middleware address + /// @return implementation The implementation address + function getImplementation(address middleware) external view returns (address implementation); + + /// @notice Creates a middleware for the given implementation + /// @param implementation The implementation address + /// @param salt The salt to use to deploy the middleware + /// @return middleware The address of the newly created middleware + function createMiddleware(address implementation, bytes32 salt) external returns (address middleware); +} diff --git a/contracts/libraries/ReentrancyState.sol b/contracts/libraries/ReentrancyState.sol new file mode 100644 index 00000000..966d1269 --- /dev/null +++ b/contracts/libraries/ReentrancyState.sol @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +library ReentrancyState { + // bytes32(uint256(keccak256("ReentrancyState")) - 1) + bytes32 constant REENTRANCY_STATE_SLOT = 0xbedc9a60a226d4ae7b727cbc828f66c94c4eead57777428ceab2f04b0efca3a5; + + function unlock() internal { + assembly { + tstore(REENTRANCY_STATE_SLOT, 0) + } + } + + function lockSwap() internal { + assembly { + tstore(REENTRANCY_STATE_SLOT, 1) + } + } + + function lockSwapRemove() internal { + assembly { + tstore(REENTRANCY_STATE_SLOT, 2) + } + } + + function read() internal view returns (uint256 state) { + assembly { + state := tload(REENTRANCY_STATE_SLOT) + } + } + + function swapLocked() internal view returns (bool) { + return read() == 1 || read() == 2; + } + + function removeLocked() internal view returns (bool) { + return read() == 2; + } +} diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol new file mode 100644 index 00000000..ebe5c609 --- /dev/null +++ b/contracts/middleware/BaseMiddleware.sol @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Proxy} from "@openzeppelin/contracts/proxy/Proxy.sol"; + +contract BaseMiddleware is Proxy { + /// @notice The address of the pool manager + IPoolManager public immutable poolManager; + address public immutable implementation; + + constructor(IPoolManager _poolManager, address _impl) { + poolManager = _poolManager; + implementation = _impl; + } + + function _implementation() internal view override returns (address) { + return implementation; + } + + // yo i wanna delete this function but how do i remove this warning + receive() external payable { + _delegate(_implementation()); + } +} diff --git a/contracts/middleware/BaseMiddlewareDefault.txt b/contracts/middleware/BaseMiddlewareDefault.txt new file mode 100644 index 00000000..046c1079 --- /dev/null +++ b/contracts/middleware/BaseMiddlewareDefault.txt @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Proxy} from "@openzeppelin/contracts/proxy/Proxy.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract Middleware is IHooks{ + IPoolManager public immutable poolManager; + Hooks.Permissions private permissions; + IHooks public immutable implementation; + + constructor(IPoolManager _poolManager, address _implementation, uint160 _flags) { + poolManager = _poolManager; + permissions = IHooks(_implementation).getHookPermissions(); + implementation = IHooks(_implementation); + _flags = _flags; + } + + function getHookPermissions() public view returns (Hooks.Permissions memory) { + return permissions; + } + + modifier nonReentrantBefore() private { + if (_status == ENTERED) { + revert ActionBetweenHook(); + } + _status = ENTERED; + _; + } + + modifier nonReentrantAfter() private { + _; + _status = NOT_ENTERED; + } + + function beforeSwap(address sender, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata hookData) + nonReentrantBefore + external + returns (bytes4, BeforeSwapDelta, uint24) + { + try implementation.beforeSwap(sender, key, params, hookData) returns ( + bytes4 selector, + BeforeSwapDelta memory beforeSwapDelta, + uint24 lpFeeOverride + ) { + + return (selector, beforeSwapDelta, lpFeeOverride); + } catch { + return (defaultSelector, defaultBeforeSwapDelta, defaultLpFeeOverride); + } + } + + function afterSwap(...) nonReentrantAfter { try catch... } + function beforeAddLiquidity(...) nonReentrantBefore { try catch... } + function afterAddLiquidity(...) nonReentrantAfter { try catch... } + function beforeRemoveLiquidity(...) nonReentrantBefore { try catch... } + function afterRemoveLiquidity(...) nonReentrantAfter { try catch... } + // who cares about donate lol +} diff --git a/contracts/middleware/MiddlewareProtect.sol b/contracts/middleware/MiddlewareProtect.sol new file mode 100644 index 00000000..204679b2 --- /dev/null +++ b/contracts/middleware/MiddlewareProtect.sol @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BaseMiddleware} from "./BaseMiddleware.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {console} from "../../lib/forge-std/src/console.sol"; +import {BaseHook} from "./../BaseHook.sol"; +import {ReentrancyState} from "./../libraries/ReentrancyState.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract MiddlewareProtect is BaseMiddleware { + using StateLibrary for IPoolManager; + using Hooks for IHooks; + + uint256 public constant gasLimit = 1000000; + + /// @notice Thrown if the address will lead to forbidden flags being set + /// @param hooks The address of the hooks contract + error HookPermissionForbidden(address hooks); + error ForbiddenReturn(); + + constructor(IPoolManager _poolManager, address _impl) BaseMiddleware(_poolManager, _impl) { + // deny any hooks that return deltas + if ( + this.hasPermission(Hooks.BEFORE_SWAP_RETURNS_DELTA_FLAG) + || this.hasPermission(Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG) + || this.hasPermission(Hooks.AFTER_ADD_LIQUIDITY_RETURNS_DELTA_FLAG) + || this.hasPermission(Hooks.AFTER_REMOVE_LIQUIDITY_RETURNS_DELTA_FLAG) + ) { + HookPermissionForbidden.selector.revertWith(address(this)); + } + } + + // block swaps and removes + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + swapNotLocked + returns (bytes4, BeforeSwapDelta, uint24) + { + ReentrancyState.lockSwapRemove(); + console.log("beforeSwap middleware"); + (bool success, bytes memory returnData) = implementation.delegatecall{gas: gasLimit}(msg.data); + require(success); + ReentrancyState.unlock(); + (bytes4 selector, BeforeSwapDelta beforeSwapDelta, uint24 lpFeeOverride) = + abi.decode(returnData, (bytes4, BeforeSwapDelta, uint24)); + if (lpFeeOverride != 0) { + revert ForbiddenReturn(); + } + return (selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + // afterSwap - no protections + + // block swaps + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + returns (bytes4) + { + ReentrancyState.lockSwap(); + console.log("beforeAddLiquidity middleware"); + (bool success, bytes memory returnData) = implementation.delegatecall{gas: gasLimit}(msg.data); + require(success); + ReentrancyState.unlock(); + return abi.decode(returnData, (bytes4)); + } + + // afterAddLiquidity - no protections + + // block swaps and reverts + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external removeNotLocked returns (bytes4) { + ReentrancyState.lockSwap(); + console.log("beforeRemoveLiquidity middleware"); + implementation.delegatecall{gas: gasLimit}(msg.data); + ReentrancyState.unlock(); + return BaseHook.beforeRemoveLiquidity.selector; + } + + // block reverts + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external returns (bytes4, BalanceDelta) { + console.log("afterRemoveLiquidity middleware"); + implementation.delegatecall{gas: gasLimit}(msg.data); + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } +} diff --git a/contracts/middleware/MiddlewareProtectFactory.sol b/contracts/middleware/MiddlewareProtectFactory.sol new file mode 100644 index 00000000..3287a629 --- /dev/null +++ b/contracts/middleware/MiddlewareProtectFactory.sol @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; +import {MiddlewareProtect} from "./MiddlewareProtect.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; + +contract MiddlewareProtectFactory is IMiddlewareFactory { + mapping(address => address) private _implementations; + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getImplementation(address middleware) external view override returns (address implementation) { + return _implementations[middleware]; + } + + function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { + middleware = address(new MiddlewareProtect{salt: salt}(poolManager, implementation)); + _implementations[middleware] = implementation; + emit MiddlewareCreated(implementation, middleware); + } +} diff --git a/contracts/middleware/MiddlewareRemove.sol b/contracts/middleware/MiddlewareRemove.sol new file mode 100644 index 00000000..5df5fc43 --- /dev/null +++ b/contracts/middleware/MiddlewareRemove.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {Proxy} from "@openzeppelin/contracts/proxy/Proxy.sol"; +import {BaseMiddleware} from "./BaseMiddleware.sol"; +import {BaseHook} from "../BaseHook.sol"; +import {console} from "../../lib/forge-std/src/console.sol"; +import {BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; + +contract MiddlewareRemove is BaseMiddleware { + bytes internal constant ZERO_BYTES = bytes(""); + uint256 public constant gasLimit = 10000000; + + constructor(IPoolManager _poolManager, address _impl) BaseMiddleware(_poolManager, _impl) {} + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external returns (bytes4) { + console.log("beforeRemoveLiquidity middleware"); + (bool success, bytes memory returnData) = implementation.delegatecall{gas: gasLimit}(msg.data); + console.log(success); + return BaseHook.beforeRemoveLiquidity.selector; + } + + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external returns (bytes4, BalanceDelta) { + console.log("afterRemoveLiquidity middleware"); + (bool success, bytes memory returnData) = implementation.delegatecall{gas: gasLimit}(msg.data); + console.log(success); + // hook cannot return delta + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } +} diff --git a/contracts/middleware/MiddlewareRemoveFactory.sol b/contracts/middleware/MiddlewareRemoveFactory.sol new file mode 100644 index 00000000..72f4368a --- /dev/null +++ b/contracts/middleware/MiddlewareRemoveFactory.sol @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; +import {MiddlewareRemove} from "./MiddlewareRemove.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; + +contract MiddlewareRemoveFactory is IMiddlewareFactory { + mapping(address => address) private _implementations; + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getImplementation(address middleware) external view override returns (address implementation) { + return _implementations[middleware]; + } + + function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { + middleware = address(new MiddlewareRemove{salt: salt}(poolManager, implementation)); + _implementations[middleware] = implementation; + emit MiddlewareCreated(implementation, middleware); + } +} diff --git a/lib/forge-gas-snapshot b/lib/forge-gas-snapshot index 2f884282..9161f7c0 160000 --- a/lib/forge-gas-snapshot +++ b/lib/forge-gas-snapshot @@ -1 +1 @@ -Subproject commit 2f884282b4cd067298e798974f5b534288b13bc2 +Subproject commit 9161f7c0b6c6788a89081e2b3b9c67592b71e689 diff --git a/lib/forge-std b/lib/forge-std index 2b58ecbc..75b3fcf0 160000 --- a/lib/forge-std +++ b/lib/forge-std @@ -1 +1 @@ -Subproject commit 2b58ecbcf3dfde7a75959dc7b4eb3d0670278de6 +Subproject commit 75b3fcf052cc7886327e4c2eac3d1a1f36942b41 diff --git a/test/BaseMiddleware.t.sol b/test/BaseMiddleware.t.sol new file mode 100644 index 00000000..e2c7bdf0 --- /dev/null +++ b/test/BaseMiddleware.t.sol @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {FeeTakingLite} from "./middleware/FeeTakingLite.sol"; +import {BaseMiddleware} from "../contracts/middleware/BaseMiddleware.sol"; +import {BaseMiddlewareImplementation} from "./shared/implementation/BaseMiddlewareImplementation.sol"; +import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract BaseMiddlewareTest is Test, Deployers { + using PoolIdLibrary for PoolKey; + using StateLibrary for IPoolManager; + + uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; + + address constant TREASURY = address(0x1234567890123456789012345678901234567890); + uint128 private constant TOTAL_BIPS = 10000; + + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + BaseMiddleware baseMiddleware = + BaseMiddleware(payable(address(uint160(Hooks.AFTER_SWAP_FLAG | Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG)))); + PoolId id; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + FeeTakingLite feeTakingLite = new FeeTakingLite(manager); + + vm.record(); + BaseMiddlewareImplementation impl = + new BaseMiddlewareImplementation(manager, address(feeTakingLite), baseMiddleware); + (, bytes32[] memory writes) = vm.accesses(address(impl)); + vm.etch(address(baseMiddleware), address(impl).code); + // for each storage key that was written during the hook implementation, copy the value over + unchecked { + for (uint256 i = 0; i < writes.length; i++) { + bytes32 slot = writes[i]; + vm.store(address(baseMiddleware), slot, vm.load(address(impl), slot)); + } + } + + // key = PoolKey(currency0, currency1, 3000, 60, baseMiddleware); + (key, id) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(baseMiddleware)), 3000, SQRT_PRICE_1_1, ZERO_BYTES + ); + + token0.approve(address(baseMiddleware), type(uint256).max); + token1.approve(address(baseMiddleware), type(uint256).max); + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + } + + // test normal behavior, copied from FeeTaking.t.sol + function testNormal() public { + // Swap exact token0 for token1 // + bool zeroForOne = true; + int256 amountSpecified = -1e12; + BalanceDelta swapDelta = swap(key, zeroForOne, amountSpecified, ZERO_BYTES); + // ---------------------------- // + + uint128 output = uint128(swapDelta.amount1()); + assertTrue(output > 0); + + uint256 expectedFee = calculateFeeForExactInput(output, 25); + + assertEq(manager.balanceOf(address(baseMiddleware), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(baseMiddleware), CurrencyLibrary.toId(key.currency1)), expectedFee); + + // Swap token0 for exact token1 // + bool zeroForOne2 = true; + int256 amountSpecified2 = 1e12; // positive number indicates exact output swap + BalanceDelta swapDelta2 = swap(key, zeroForOne2, amountSpecified2, ZERO_BYTES); + // ---------------------------- // + + uint128 input = uint128(-swapDelta2.amount0()); + assertTrue(output > 0); + + uint256 expectedFee2 = calculateFeeForExactOutput(input, 25); + + assertEq(manager.balanceOf(address(baseMiddleware), CurrencyLibrary.toId(key.currency0)), expectedFee2); + assertEq(manager.balanceOf(address(baseMiddleware), CurrencyLibrary.toId(key.currency1)), expectedFee); + + // test withdrawing tokens // + Currency[] memory currencies = new Currency[](2); + currencies[0] = key.currency0; + currencies[1] = key.currency1; + FeeTakingLite impl = FeeTakingLite(address(baseMiddleware)); + impl.withdraw(TREASURY, currencies); + assertEq(manager.balanceOf(address(baseMiddleware), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(baseMiddleware), CurrencyLibrary.toId(key.currency1)), 0); + assertEq(currency0.balanceOf(TREASURY), expectedFee2); + assertEq(currency1.balanceOf(TREASURY), expectedFee); + } + + function calculateFeeForExactInput(uint256 outputAmount, uint128 feeBips) internal pure returns (uint256) { + return outputAmount * TOTAL_BIPS / (TOTAL_BIPS - feeBips) - outputAmount; + } + + function calculateFeeForExactOutput(uint256 inputAmount, uint128 feeBips) internal pure returns (uint256) { + return (inputAmount * feeBips) / (TOTAL_BIPS + feeBips); + } +} diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol new file mode 100644 index 00000000..52fed845 --- /dev/null +++ b/test/MiddlewareProtectFactory.t.sol @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {HooksFrontrun} from "./middleware/HooksFrontrun.sol"; +import {HooksFrontrunImplementation} from "./shared/implementation/HooksFrontrunImplementation.sol"; +import {MiddlewareProtect} from "../contracts/middleware/MiddlewareProtect.sol"; +import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {HooksRevert} from "./middleware/HooksRevert.sol"; +import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; +import {MiddlewareProtectFactory} from "./../contracts/middleware/MiddlewareProtectFactory.sol"; +import {HookMiner} from "./utils/HookMiner.sol"; + +contract MiddlewareProtectFactoryTest is Test, Deployers { + using PoolIdLibrary for PoolKey; + using StateLibrary for IPoolManager; + + uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; + + address constant TREASURY = address(0x1234567890123456789012345678901234567890); + uint128 private constant TOTAL_BIPS = 10000; + + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + PoolId id; + + MiddlewareProtectFactory factory; + HooksFrontrun hooksFrontrun; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + hooksFrontrun = HooksFrontrun(address(uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG))); + vm.record(); + HooksFrontrunImplementation impl = new HooksFrontrunImplementation(manager, hooksFrontrun); + (, bytes32[] memory writes) = vm.accesses(address(impl)); + vm.etch(address(hooksFrontrun), address(impl).code); + unchecked { + for (uint256 i = 0; i < writes.length; i++) { + bytes32 slot = writes[i]; + vm.store(address(hooksFrontrun), slot, vm.load(address(impl), slot)); + } + } + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + + factory = new MiddlewareProtectFactory(manager); + } + + function testFrontrun() public { + (PoolKey memory key,) = + initPoolAndAddLiquidity(currency0, currency1, IHooks(address(0)), 100, SQRT_PRICE_1_1, ZERO_BYTES); + BalanceDelta swapDelta = swap(key, true, 0.001 ether, ZERO_BYTES); + + (key,) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(hooksFrontrun)), 100, SQRT_PRICE_1_1, ZERO_BYTES + ); + BalanceDelta swapDelta2 = swap(key, true, 0.001 ether, ZERO_BYTES); + + // while both swaps are in the same pool, the second swap is more expensive + assertEq(swapDelta.amount1(), swapDelta2.amount1()); + assertTrue(abs(swapDelta.amount0()) < abs(swapDelta2.amount0())); + assertTrue(manager.balanceOf(address(hooksFrontrun), CurrencyLibrary.toId(key.currency0)) > 0); + } + + function testVariousProtectFactory() public { + uint160 flags = uint160(Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareProtect).creationCode, + abi.encode(address(manager), address(hooksFrontrun)) + ); + testOn(address(hooksFrontrun), salt); + } + + // creates a middleware on an implementation + function testOn(address implementation, bytes32 salt) internal { + address hookAddress = factory.createMiddleware(implementation, salt); + MiddlewareProtect middlewareProtect = MiddlewareProtect(payable(hookAddress)); + + (key,) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(middlewareProtect)), 100, SQRT_PRICE_1_1, ZERO_BYTES + ); + swap(key, true, 0.001 ether, ZERO_BYTES); + //vm.expectRevert(); + } + + function abs(int256 x) internal pure returns (uint256) { + return x >= 0 ? uint256(x) : uint256(-x); + } +} diff --git a/test/MiddlewareRemove.t.sol b/test/MiddlewareRemove.t.sol new file mode 100644 index 00000000..0ecae3f5 --- /dev/null +++ b/test/MiddlewareRemove.t.sol @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {FeeTakingLite} from "./middleware/FeeTakingLite.sol"; +import {MiddlewareRemove} from "../contracts/middleware/MiddlewareRemove.sol"; +import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {HooksRevert} from "./middleware/HooksRevert.sol"; +import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; + +contract MiddlewareRemoveTest is Test, Deployers { + using PoolIdLibrary for PoolKey; + using StateLibrary for IPoolManager; + + uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; + + address constant TREASURY = address(0x1234567890123456789012345678901234567890); + uint128 private constant TOTAL_BIPS = 10000; + + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + PoolId id; + + uint160 nonce = 0; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + } + + function testVarious() public { + FeeTakingLite feeTakingLite = new FeeTakingLite(manager); + testOn( + address(feeTakingLite), + uint160(Hooks.AFTER_SWAP_FLAG | Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG) + ); + HooksRevert hooksRevert = new HooksRevert(manager); + testOn(address(hooksRevert), uint160(Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG)); + testOn(address(hooksRevert), uint160(Hooks.AFTER_REMOVE_LIQUIDITY_FLAG)); + testOn(address(hooksRevert), uint160(Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG)); + HooksOutOfGas hooksOutOfGas = new HooksOutOfGas(manager); + testOn(address(hooksOutOfGas), uint160(Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG)); + testOn(address(hooksOutOfGas), uint160(Hooks.AFTER_REMOVE_LIQUIDITY_FLAG)); + testOn(address(hooksOutOfGas), uint160(Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG)); + } + + // creates a middleware on an implementation + function testOn(address implementation, uint160 flags) internal { + MiddlewareRemove middlewareRemove = MiddlewareRemove(payable(address(nonce << 20 | flags))); + nonce++; + vm.record(); + MiddlewareRemove impl = new MiddlewareRemove(manager, implementation); + (, bytes32[] memory writes) = vm.accesses(address(impl)); + vm.etch(address(middlewareRemove), address(impl).code); + unchecked { + for (uint256 i = 0; i < writes.length; i++) { + bytes32 slot = writes[i]; + vm.store(address(middlewareRemove), slot, vm.load(address(impl), slot)); + } + } + (key, id) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES + ); + + removeLiquidity(currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + } +} diff --git a/test/MiddlewareRemoveFactory.t.sol b/test/MiddlewareRemoveFactory.t.sol new file mode 100644 index 00000000..dea6b7ae --- /dev/null +++ b/test/MiddlewareRemoveFactory.t.sol @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {FeeTakingLite} from "./middleware/FeeTakingLite.sol"; +import {MiddlewareRemove} from "../contracts/middleware/MiddlewareRemove.sol"; +import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {HooksRevert} from "./middleware/HooksRevert.sol"; +import {HooksOutOfGas} from "./middleware/HooksOutOfGas.sol"; +import {MiddlewareRemoveFactory} from "./../contracts/middleware/MiddlewareRemoveFactory.sol"; +import {HookMiner} from "./utils/HookMiner.sol"; + +contract MiddlewareRemoveFactoryTest is Test, Deployers { + using PoolIdLibrary for PoolKey; + using StateLibrary for IPoolManager; + + uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; + + address constant TREASURY = address(0x1234567890123456789012345678901234567890); + uint128 private constant TOTAL_BIPS = 10000; + + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + PoolId id; + + MiddlewareRemoveFactory factory; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + + factory = new MiddlewareRemoveFactory(manager); + } + + function testVariousFactory() public { + FeeTakingLite feeTakingLite = new FeeTakingLite(manager); + uint160 flags = + uint160(Hooks.AFTER_SWAP_FLAG | Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG); + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareRemove).creationCode, + abi.encode(address(manager), address(feeTakingLite)) + ); + testOn(address(feeTakingLite), salt); + + HooksRevert hooksRevert = new HooksRevert(manager); + flags = uint160(Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG); + (hookAddress, salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareRemove).creationCode, + abi.encode(address(manager), address(hooksRevert)) + ); + testOn(address(hooksRevert), salt); + + HooksOutOfGas hooksOutOfGas = new HooksOutOfGas(manager); + flags = uint160(Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG); + (hookAddress, salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareRemove).creationCode, + abi.encode(address(manager), address(hooksOutOfGas)) + ); + testOn(address(hooksOutOfGas), salt); + } + + // creates a middleware on an implementation + function testOn(address implementation, bytes32 salt) internal { + address hookAddress = factory.createMiddleware(implementation, salt); + MiddlewareRemove middlewareRemove = MiddlewareRemove(payable(hookAddress)); + + (key, id) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES + ); + + removeLiquidity(currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + assertEq(factory.getImplementation(hookAddress), implementation); + } +} diff --git a/test/RemoveMiddleware.txt b/test/RemoveMiddleware.txt new file mode 100644 index 00000000..eb5db684 --- /dev/null +++ b/test/RemoveMiddleware.txt @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {FeeTakingLite} from "../contracts/middleware/test/FeeTakingLite.sol"; +import {HooksRevert} from "../contracts/middleware/test/HooksRevert.sol"; +import {MiddlewareRemove} from "../contracts/middleware/MiddlewareRemove.sol"; +import {MiddlewareRemoveImplementation} from "./shared/implementation/MiddlewareRemoveImplementation.sol"; +import {PoolManager} from "@uniswap/v4-core/src/PoolManager.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract MiddlewareRemoveTest is Test, Deployers { + using PoolIdLibrary for PoolKey; + using StateLibrary for IPoolManager; + + uint160 constant SQRT_RATIO_10_1 = 250541448375047931186413801569; + + address constant TREASURY = address(0x1234567890123456789012345678901234567890); + uint128 private constant TOTAL_BIPS = 10000; + + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + MiddlewareRemove middlewareRemove = + MiddlewareRemove(payable(address(uint160(Hooks.AFTER_SWAP_FLAG | Hooks.AFTER_SWAP_RETURNS_DELTA_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG)))); + PoolId id; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + FeeTakingLite feeTakingLite = new FeeTakingLite(manager); + HooksRevert hooksRevert = new HooksRevert(manager); + + vm.record(); + MiddlewareRemoveImplementation impl = + new MiddlewareRemoveImplementation(manager, address(feeTakingLite), middlewareRemove); + (, bytes32[] memory writes) = vm.accesses(address(impl)); + vm.etch(address(middlewareRemove), address(impl).code); + // for each storage key that was written during the hook implementation, copy the value over + unchecked { + for (uint256 i = 0; i < writes.length; i++) { + bytes32 slot = writes[i]; + vm.store(address(middlewareRemove), slot, vm.load(address(impl), slot)); + } + } + + // key = PoolKey(currency0, currency1, 3000, 60, middlewareRemove); + (key, id) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES + ); + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + } + + function testMiddlewareRemove() public { + console.log("testMiddlewareRemove"); + console.log("testMiddlewareRemove"); + console.log("testMiddlewareRemove"); + removeLiquidity(currency0, currency1, IHooks(address(middlewareRemove)), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + } + + // test normal behavior, copied from FeeTaking.t.sol + function testNormal() public { + // Swap exact token0 for token1 // + bool zeroForOne = true; + int256 amountSpecified = -1e12; + BalanceDelta swapDelta = swap(key, zeroForOne, amountSpecified, ZERO_BYTES); + // ---------------------------- // + + uint128 output = uint128(swapDelta.amount1()); + assertTrue(output > 0); + + uint256 expectedFee = calculateFeeForExactInput(output, 25); + + assertEq(manager.balanceOf(address(middlewareRemove), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(middlewareRemove), CurrencyLibrary.toId(key.currency1)), expectedFee); + + // Swap token0 for exact token1 // + bool zeroForOne2 = true; + int256 amountSpecified2 = 1e12; // positive number indicates exact output swap + BalanceDelta swapDelta2 = swap(key, zeroForOne2, amountSpecified2, ZERO_BYTES); + // ---------------------------- // + + uint128 input = uint128(-swapDelta2.amount0()); + assertTrue(output > 0); + + uint256 expectedFee2 = calculateFeeForExactOutput(input, 25); + + assertEq(manager.balanceOf(address(middlewareRemove), CurrencyLibrary.toId(key.currency0)), expectedFee2); + assertEq(manager.balanceOf(address(middlewareRemove), CurrencyLibrary.toId(key.currency1)), expectedFee); + + // test withdrawing tokens // + Currency[] memory currencies = new Currency[](2); + currencies[0] = key.currency0; + currencies[1] = key.currency1; + FeeTakingLite impl = FeeTakingLite(address(middlewareRemove)); + impl.withdraw(TREASURY, currencies); + assertEq(manager.balanceOf(address(middlewareRemove), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(middlewareRemove), CurrencyLibrary.toId(key.currency1)), 0); + assertEq(currency0.balanceOf(TREASURY), expectedFee2); + assertEq(currency1.balanceOf(TREASURY), expectedFee); + } + + function calculateFeeForExactInput(uint256 outputAmount, uint128 feeBips) internal pure returns (uint256) { + return outputAmount * TOTAL_BIPS / (TOTAL_BIPS - feeBips) - outputAmount; + } + + function calculateFeeForExactOutput(uint256 inputAmount, uint128 feeBips) internal pure returns (uint256) { + return (inputAmount * feeBips) / (TOTAL_BIPS + feeBips); + } + + function testVarious() public { + + } + + function testOn(address impl) public { + + } +} diff --git a/test/middleware/CouterPayable.sol b/test/middleware/CouterPayable.sol new file mode 100644 index 00000000..13bcc521 --- /dev/null +++ b/test/middleware/CouterPayable.sol @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; + +contract Counter is BaseHook { + using PoolIdLibrary for PoolKey; + + // NOTE: --------------------------------------------------------- + // state variables should typically be unique to a pool + // a single hook contract should be able to service multiple pools + // --------------------------------------------------------------- + + mapping(PoolId => uint256 count) public beforeSwapCount; + mapping(PoolId => uint256 count) public afterSwapCount; + + mapping(PoolId => uint256 count) public beforeAddLiquidityCount; + mapping(PoolId => uint256 count) public beforeRemoveLiquidityCount; + + constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: true, + afterAddLiquidity: false, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: false, + beforeSwap: true, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + // ----------------------------------------------- + // NOTE: see IHooks.sol for function documentation + // ----------------------------------------------- + + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata) + external + override + returns (bytes4, BeforeSwapDelta, uint24) + { + beforeSwapCount[key.toId()]++; + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + override + returns (bytes4, int128) + { + afterSwapCount[key.toId()]++; + return (BaseHook.afterSwap.selector, 0); + } + + function beforeAddLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external override returns (bytes4) { + beforeAddLiquidityCount[key.toId()]++; + return BaseHook.beforeAddLiquidity.selector; + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external override returns (bytes4) { + beforeRemoveLiquidityCount[key.toId()]++; + return BaseHook.beforeRemoveLiquidity.selector; + } +} diff --git a/test/middleware/FeeTakingLite.sol b/test/middleware/FeeTakingLite.sol new file mode 100644 index 00000000..3781ece6 --- /dev/null +++ b/test/middleware/FeeTakingLite.sol @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract FeeTakingLite is IUnlockCallback { + using SafeCast for uint256; + + bytes internal constant ZERO_BYTES = bytes(""); + uint128 private constant TOTAL_BIPS = 10000; + uint128 private constant MAX_BIPS = 100; + uint128 public constant swapFeeBips = 25; + IPoolManager public immutable poolManager; + + struct CallbackData { + address to; + Currency[] currencies; + } + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getHookPermissions() public pure returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: false, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: true, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function afterSwap( + address, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata + ) external returns (bytes4, int128) { + // fee will be in the unspecified token of the swap + bool currency0Specified = (params.amountSpecified < 0 == params.zeroForOne); + (Currency feeCurrency, int128 swapAmount) = + (currency0Specified) ? (key.currency1, delta.amount1()) : (key.currency0, delta.amount0()); + // if fee is on output, get the absolute output amount + if (swapAmount < 0) swapAmount = -swapAmount; + + uint256 feeAmount = (uint128(swapAmount) * swapFeeBips) / TOTAL_BIPS; + // mint ERC6909 instead of take to avoid edge case where PM doesn't have enough balance + poolManager.mint(address(this), CurrencyLibrary.toId(feeCurrency), feeAmount); + + return (BaseHook.afterSwap.selector, feeAmount.toInt128()); + } + + function setSwapFeeBips(uint128 _swapFeeBips) external pure { + require(_swapFeeBips <= MAX_BIPS); + //swapFeeBips = _swapFeeBips; + } + + function withdraw(address to, Currency[] calldata currencies) external { + poolManager.unlock(abi.encode(CallbackData(to, currencies))); + } + + function unlockCallback(bytes calldata rawData) external override returns (bytes memory) { + CallbackData memory data = abi.decode(rawData, (CallbackData)); + uint256 length = data.currencies.length; + for (uint256 i = 0; i < length;) { + uint256 amount = poolManager.balanceOf(address(this), CurrencyLibrary.toId(data.currencies[i])); + poolManager.burn(address(this), CurrencyLibrary.toId(data.currencies[i]), amount); + poolManager.take(data.currencies[i], data.to, amount); + unchecked { + i++; + } + } + return ZERO_BYTES; + } +} diff --git a/test/middleware/HooksFrontrun.sol b/test/middleware/HooksFrontrun.sol new file mode 100644 index 00000000..c0d23a8f --- /dev/null +++ b/test/middleware/HooksFrontrun.sol @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {console} from "forge-std/console.sol"; +import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; + +contract HooksFrontrun is BaseHook { + using SafeCast for uint256; + + bytes internal constant ZERO_BYTES = bytes(""); + uint160 public constant MIN_PRICE_LIMIT = TickMath.MIN_SQRT_PRICE + 1; + uint160 public constant MAX_PRICE_LIMIT = TickMath.MAX_SQRT_PRICE - 1; + + BalanceDelta swapDelta; + IPoolManager.SwapParams swapParams; + + constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + + function getHookPermissions() public pure virtual override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: true, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) + external + override + onlyByManager + returns (bytes4, BeforeSwapDelta, uint24) + { + swapParams = params; + console.log(params.zeroForOne); + console.logInt(params.amountSpecified); + swapDelta = manager.swap(key, params, ZERO_BYTES); + console.log("beforeDelta"); + console.logInt(swapDelta.amount0()); + console.logInt(swapDelta.amount1()); + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap( + address sender, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) external override onlyByManager returns (bytes4, int128) { + BalanceDelta afterDelta = manager.swap( + key, + IPoolManager.SwapParams( + !swapParams.zeroForOne, + -swapParams.amountSpecified, + swapParams.zeroForOne ? MAX_PRICE_LIMIT : MIN_PRICE_LIMIT + ), + ZERO_BYTES + ); + if (swapParams.zeroForOne) { + int256 profit = afterDelta.amount0() + swapDelta.amount0(); + if (profit > 0) { + // else hook reverts + manager.mint(address(this), key.currency0.toId(), uint256(profit)); + } + } else { + int256 profit = afterDelta.amount1() + swapDelta.amount1(); + if (profit > 0) { + // else hook reverts + manager.mint(address(this), key.currency1.toId(), uint256(profit)); + } + } + console.log("afterDelta"); + console.logInt(afterDelta.amount0()); + console.logInt(afterDelta.amount1()); + return (BaseHook.afterSwap.selector, 0); + } +} diff --git a/test/middleware/HooksOutOfGas.sol b/test/middleware/HooksOutOfGas.sol new file mode 100644 index 00000000..d7b35a37 --- /dev/null +++ b/test/middleware/HooksOutOfGas.sol @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta, toBalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract HooksOutOfGas { + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getHookPermissions() public pure returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: true, + afterInitialize: true, + beforeAddLiquidity: true, + afterAddLiquidity: true, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: true, + beforeSwap: true, + afterSwap: true, + beforeDonate: true, + afterDonate: true, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { + consumeAllGas(); + } + + function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external returns (bytes4) { + consumeAllGas(); + } + + function afterAddLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + consumeAllGas(); + } + + function afterRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external returns (bytes4, BalanceDelta) { + consumeAllGas(); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + virtual + returns (bytes4, BeforeSwapDelta, uint24) + { + consumeAllGas(); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + virtual + returns (bytes4, int128) + { + consumeAllGas(); + } + + function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + consumeAllGas(); + } + + function consumeAllGas() internal view { + while (true) { + //console.log(gasleft()); + // This loop will run indefinitely and consume all available gas. + } + } +} diff --git a/test/middleware/HooksRevert.sol b/test/middleware/HooksRevert.sol new file mode 100644 index 00000000..f3f31401 --- /dev/null +++ b/test/middleware/HooksRevert.sol @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta, toBalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; +import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; + +contract HooksRevert { + error HookNotImplemented(); + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getHookPermissions() public pure returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: true, + afterInitialize: true, + beforeAddLiquidity: true, + afterAddLiquidity: true, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: true, + beforeSwap: true, + afterSwap: true, + beforeDonate: true, + afterDonate: true, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) external virtual returns (bytes4) { + revert HookNotImplemented(); + } + + function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata + ) external returns (bytes4) { + revert HookNotImplemented(); + } + + function afterAddLiquidity( + address, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external virtual returns (bytes4, BalanceDelta) { + revert HookNotImplemented(); + } + + function afterRemoveLiquidity( + address sender, + PoolKey calldata, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata + ) external returns (bytes4, BalanceDelta) { + require(sender == address(0), "nobody can remove"); + return (BaseHook.beforeRemoveLiquidity.selector, toBalanceDelta(0, 0)); + } + + function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + external + virtual + returns (bytes4, BeforeSwapDelta, uint24) + { + revert HookNotImplemented(); + } + + function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + virtual + returns (bytes4, int128) + { + revert HookNotImplemented(); + } + + function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } + + function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + external + virtual + returns (bytes4) + { + revert HookNotImplemented(); + } +} diff --git a/test/shared/implementation/BaseMiddlewareImplementation.sol b/test/shared/implementation/BaseMiddlewareImplementation.sol new file mode 100644 index 00000000..fcabbd48 --- /dev/null +++ b/test/shared/implementation/BaseMiddlewareImplementation.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {BaseHook} from "../../../contracts/BaseHook.sol"; +import {BaseMiddleware} from "../../../contracts/middleware/BaseMiddleware.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; + +contract BaseMiddlewareImplementation is BaseMiddleware { + constructor(IPoolManager _poolManager, address _implementation, BaseMiddleware addressToEtch) + BaseMiddleware(_poolManager, _implementation) + { + Hooks.validateHookPermissions(IHooks(address(addressToEtch)), BaseHook(_implementation).getHookPermissions()); + } +} diff --git a/test/shared/implementation/FeeTakingLiteImplementation.sol b/test/shared/implementation/FeeTakingLiteImplementation.sol new file mode 100644 index 00000000..a9145bc7 --- /dev/null +++ b/test/shared/implementation/FeeTakingLiteImplementation.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {BaseHook} from "../../../contracts/BaseHook.sol"; +import {FeeTakingLite} from "../../middleware/FeeTakingLite.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; + +contract FeeTakingLiteImplementation is FeeTakingLite { + constructor(IPoolManager _poolManager, FeeTakingLite addressToEtch) FeeTakingLite(_poolManager) {} +} diff --git a/test/shared/implementation/HooksFrontrunImplementation.sol b/test/shared/implementation/HooksFrontrunImplementation.sol new file mode 100644 index 00000000..104a07b7 --- /dev/null +++ b/test/shared/implementation/HooksFrontrunImplementation.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {BaseHook} from "../../../contracts/BaseHook.sol"; +import {HooksFrontrun} from "../../middleware/HooksFrontrun.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract HooksFrontrunImplementation is HooksFrontrun { + constructor(IPoolManager _poolManager, HooksFrontrun addressToEtch) HooksFrontrun(_poolManager) { + Hooks.validateHookPermissions(addressToEtch, getHookPermissions()); + } + + // make this a no-op in testing + function validateHookAddress(BaseHook _this) internal pure override {} +} diff --git a/test/utils/HookMiner.sol b/test/utils/HookMiner.sol new file mode 100644 index 00000000..d6b30c40 --- /dev/null +++ b/test/utils/HookMiner.sol @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.21; + +/// @title HookMiner - a library for mining hook addresses +/// @dev This library is intended for `forge test` environments. There may be gotchas when using salts in `forge script` or `forge create` +library HookMiner { + // mask to slice out the bottom 14 bit of the address + uint160 constant FLAG_MASK = 0x3FFF; + + // Maximum number of iterations to find a salt, avoid infinite loops + uint256 constant MAX_LOOP = 100_000; + + /// @notice Find a salt that produces a hook address with the desired `flags` + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param flags The desired flags for the hook address + /// @param creationCode The creation code of a hook contract. Example: `type(Counter).creationCode` + /// @param constructorArgs The encoded constructor arguments of a hook contract. Example: `abi.encode(address(manager))` + /// @return hookAddress salt and corresponding address that was found. The salt can be used in `new Hook{salt: salt}()` + function find(address deployer, uint160 flags, bytes memory creationCode, bytes memory constructorArgs) + internal + view + returns (address, bytes32) + { + address hookAddress; + bytes memory creationCodeWithArgs = abi.encodePacked(creationCode, constructorArgs); + + uint256 salt; + for (salt; salt < MAX_LOOP; salt++) { + hookAddress = computeAddress(deployer, salt, creationCodeWithArgs); + if (uint160(hookAddress) & FLAG_MASK == flags && hookAddress.code.length == 0) { + return (hookAddress, bytes32(salt)); + } + } + revert("HookMiner: could not find salt"); + } + + /// @notice Precompute a contract address deployed via CREATE2 + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param salt The salt used to deploy the hook + /// @param creationCode The creation code of a hook contract + function computeAddress(address deployer, uint256 salt, bytes memory creationCode) + internal + pure + returns (address hookAddress) + { + return address( + uint160(uint256(keccak256(abi.encodePacked(bytes1(0xFF), deployer, salt, keccak256(creationCode))))) + ); + } +}