From 3f69c05767ebe34964f4579814d3392a13fd02a7 Mon Sep 17 00:00:00 2001 From: Junion <69495294+Jun1on@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:37:09 -0400 Subject: [PATCH] optimize quoter --- .../MIDDLEWARE_PROTECT-multi-protected.snap | 2 +- .../MIDDLEWARE_PROTECT-multi-vanilla.snap | 2 +- .../MIDDLEWARE_PROTECT-protected.snap | 2 +- .../MIDDLEWARE_PROTECT-vanilla.snap | 2 +- src/middleware/CheapQuoter.sol | 190 ++++++++++++++++++ src/middleware/MiddlewareProtect.sol | 18 +- src/middleware/MiddlewareProtectFactory.sol | 10 +- test/MiddlewareProtectFactory.t.sol | 55 +++-- 8 files changed, 242 insertions(+), 39 deletions(-) create mode 100644 src/middleware/CheapQuoter.sol diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap index 4bf2964e..8be44c61 100644 --- a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap @@ -1 +1 @@ -914551 \ No newline at end of file +184020 \ No newline at end of file diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-vanilla.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-vanilla.snap index 98ddcf3c..9d7324da 100644 --- a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-vanilla.snap +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-vanilla.snap @@ -1 +1 @@ -475485 \ No newline at end of file +143854 \ No newline at end of file diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap index 3af65269..afb522c2 100644 --- a/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap @@ -1 +1 @@ -201607 \ No newline at end of file +151513 \ No newline at end of file diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-vanilla.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-vanilla.snap index e1b225a9..bec120fa 100644 --- a/.forge-snapshots/MIDDLEWARE_PROTECT-vanilla.snap +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-vanilla.snap @@ -1 +1 @@ -147725 \ No newline at end of file +124869 \ No newline at end of file diff --git a/src/middleware/CheapQuoter.sol b/src/middleware/CheapQuoter.sol new file mode 100644 index 00000000..caf6bd68 --- /dev/null +++ b/src/middleware/CheapQuoter.sol @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity 0.8.26; + +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {SwapMath} from "@uniswap/v4-core/src/libraries/SwapMath.sol"; +import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; +import "@uniswap/v4-core/src/libraries/SafeCast.sol"; +import {LiquidityMath} from "@uniswap/v4-core/src/libraries/LiquidityMath.sol"; +import {PoolTickBitmap} from "../libraries/PoolTickBitmap.sol"; +import {Slot0, Slot0Library} from "@uniswap/v4-core/src/types/Slot0.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; + +contract CheapQuoter { + IPoolManager public immutable poolManager; + + using SafeCast for uint256; + using SafeCast for int256; + + using Slot0Library for Slot0; + using StateLibrary for IPoolManager; + using PoolIdLibrary for PoolKey; + + struct Slot0Struct { + // the current price + uint160 sqrtPriceX96; + // the current tick + int24 tick; + // tick spacing + int24 tickSpacing; + } + + // used for packing under the stack limit + struct QuoteParams { + bool zeroForOne; + bool exactInput; + uint24 fee; + uint160 sqrtPriceLimitX96; + } + + struct SwapCache { + // the protocol fee for the input token + uint8 feeProtocol; + // liquidity at the beginning of the swap + uint128 liquidityStart; + // the timestamp of the current block + uint32 blockTimestamp; + // the current value of the tick accumulator, computed only if we cross an initialized tick + int56 tickCumulative; + // the current value of seconds per liquidity accumulator, computed only if we cross an initialized tick + uint160 secondsPerLiquidityCumulativeX128; + // whether we've computed and cached the above two accumulators + bool computedLatestObservation; + } + + // the top level state of the swap, the results of which are recorded in storage at the end + struct SwapState { + // the amount remaining to be swapped in/out of the input/output asset + int256 amountSpecifiedRemaining; + // the amount already swapped out/in of the output/input asset + int256 amountCalculated; + // current sqrt(price) + uint160 sqrtPriceX96; + // the tick associated with the current price + int24 tick; + // the global fee growth of the input token + uint256 feeGrowthGlobalX128; + // amount of input token paid as protocol fee + uint128 protocolFee; + // the current liquidity in range + uint128 liquidity; + } + + struct StepComputations { + // the price at the beginning of the step + uint160 sqrtPriceStartX96; + // the next tick to swap to from the current tick in the swap direction + int24 tickNext; + // whether tickNext is initialized or not + bool initialized; + // sqrt(price) for the next tick (1/0) + uint160 sqrtPriceNextX96; + // how much is being swapped in in this step + uint256 amountIn; + // how much is being swapped out + uint256 amountOut; + // how much fee is being paid in + uint256 feeAmount; + } + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function fillSlot0(PoolKey calldata poolKey) private view returns (Slot0Struct memory slot0) { + (slot0.sqrtPriceX96, slot0.tick,,) = poolManager.getSlot0(poolKey.toId()); + slot0.tickSpacing = poolKey.tickSpacing; + return slot0; + } + + function quote(PoolKey calldata poolKey, IPoolManager.SwapParams calldata swapParams) + external + view + returns (int256 quote) + { + QuoteParams memory quoteParams = QuoteParams( + swapParams.zeroForOne, swapParams.amountSpecified < 0, poolKey.fee, swapParams.sqrtPriceLimitX96 + ); + + Slot0Struct memory slot0 = fillSlot0(poolKey); + + SwapState memory state = SwapState({ + amountSpecifiedRemaining: -swapParams.amountSpecified, + amountCalculated: 0, + sqrtPriceX96: slot0.sqrtPriceX96, + tick: slot0.tick, + feeGrowthGlobalX128: 0, + protocolFee: 0, + liquidity: poolManager.getLiquidity(poolKey.toId()) + }); + + while (state.amountSpecifiedRemaining != 0 && state.sqrtPriceX96 != quoteParams.sqrtPriceLimitX96) { + StepComputations memory step; + + step.sqrtPriceStartX96 = state.sqrtPriceX96; + + (step.tickNext, step.initialized) = PoolTickBitmap.nextInitializedTickWithinOneWord( + poolManager, poolKey.toId(), slot0.tickSpacing, state.tick, quoteParams.zeroForOne + ); + + // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds + if (step.tickNext < TickMath.MIN_TICK) { + step.tickNext = TickMath.MIN_TICK; + } else if (step.tickNext > TickMath.MAX_TICK) { + step.tickNext = TickMath.MAX_TICK; + } + + // get the price for the next tick + step.sqrtPriceNextX96 = TickMath.getSqrtPriceAtTick(step.tickNext); + + // compute values to swap to the target tick, price limit, or point where input/output amount is exhausted + (state.sqrtPriceX96, step.amountIn, step.amountOut, step.feeAmount) = SwapMath.computeSwapStep( + state.sqrtPriceX96, + ( + quoteParams.zeroForOne + ? step.sqrtPriceNextX96 < quoteParams.sqrtPriceLimitX96 + : step.sqrtPriceNextX96 > quoteParams.sqrtPriceLimitX96 + ) ? quoteParams.sqrtPriceLimitX96 : step.sqrtPriceNextX96, + state.liquidity, + -state.amountSpecifiedRemaining, + quoteParams.fee + ); + + if (quoteParams.exactInput) { + state.amountSpecifiedRemaining -= (step.amountIn + step.feeAmount).toInt256(); + state.amountCalculated = state.amountCalculated + step.amountOut.toInt256(); + } else { + state.amountSpecifiedRemaining += step.amountOut.toInt256(); + state.amountCalculated = state.amountCalculated - (step.amountIn + step.feeAmount).toInt256(); + } + + // shift tick if we reached the next price + if (state.sqrtPriceX96 == step.sqrtPriceNextX96) { + // if the tick is initialized, run the tick transition + if (step.initialized) { + (, int128 liquidityNet,,) = poolManager.getTickInfo(poolKey.toId(), step.tickNext); + + // if we're moving leftward, we interpret liquidityNet as the opposite sign + // safe because liquidityNet cannot be type(int128).min + if (quoteParams.zeroForOne) liquidityNet = -liquidityNet; + + state.liquidity = LiquidityMath.addDelta(state.liquidity, liquidityNet); + } + + state.tick = quoteParams.zeroForOne ? step.tickNext - 1 : step.tickNext; + } else if (state.sqrtPriceX96 != step.sqrtPriceStartX96) { + // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved + state.tick = TickMath.getTickAtSqrtPrice(state.sqrtPriceX96); + } + + quote = quoteParams.exactInput + ? state.amountCalculated + : state.amountSpecifiedRemaining + swapParams.amountSpecified; + } + } +} diff --git a/src/middleware/MiddlewareProtect.sol b/src/middleware/MiddlewareProtect.sol index a806e698..2d5f8b53 100644 --- a/src/middleware/MiddlewareProtect.sol +++ b/src/middleware/MiddlewareProtect.sol @@ -20,7 +20,7 @@ import {console} from "forge-std/console.sol"; import {LPFeeLibrary} from "@uniswap/v4-core/src/libraries/LPFeeLibrary.sol"; import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; -import {IViewQuoter} from "../interfaces/IViewQuoter.sol"; +import {CheapQuoter} from "./CheapQuoter.sol"; contract MiddlewareProtect is BaseMiddleware { using CustomRevert for bytes4; @@ -44,25 +44,23 @@ contract MiddlewareProtect is BaseMiddleware { bytes internal constant ZERO_BYTES = bytes(""); - IViewQuoter public immutable viewQuoter; + CheapQuoter public immutable cheapQuoter; // todo: use tstore int256 private quote; - constructor(IPoolManager _manager, IViewQuoter _viewQuoter, address _impl) BaseMiddleware(_manager, _impl) { - viewQuoter = _viewQuoter; + constructor(IPoolManager _poolManager, CheapQuoter _cheapQuoter, address _impl) + BaseMiddleware(_poolManager, _impl) + { + cheapQuoter = _cheapQuoter; _ensureValidFlags(); } - function beforeSwap(address sender, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) external returns (bytes4, BeforeSwapDelta, uint24) { - if (params.zeroForOne) { - (, quote,,) = viewQuoter.quoteSingle(key, params); - } else { - (quote,,,) = viewQuoter.quoteSingle(key, params); - } + quote = cheapQuoter.quote(key, params); (bool success, bytes memory returnData) = address(implementation).delegatecall(msg.data); if (!success) { _handleRevert(returnData); diff --git a/src/middleware/MiddlewareProtectFactory.sol b/src/middleware/MiddlewareProtectFactory.sol index aff46d48..0ae5ef6d 100644 --- a/src/middleware/MiddlewareProtectFactory.sol +++ b/src/middleware/MiddlewareProtectFactory.sol @@ -3,7 +3,7 @@ pragma solidity ^0.8.19; import {MiddlewareProtect} from "./MiddlewareProtect.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; -import {IViewQuoter} from "../interfaces/IViewQuoter.sol"; +import {CheapQuoter} from "./CheapQuoter.sol"; contract MiddlewareProtectFactory { event MiddlewareCreated(address implementation, address middleware); @@ -11,11 +11,11 @@ contract MiddlewareProtectFactory { mapping(address => address) private _implementations; IPoolManager public immutable poolManager; - IViewQuoter public immutable viewQuoter; + CheapQuoter public immutable cheapQuoter; - constructor(IPoolManager _poolManager, IViewQuoter _viewQuoter) { + constructor(IPoolManager _poolManager, CheapQuoter _cheapQuoter) { poolManager = _poolManager; - viewQuoter = _viewQuoter; + cheapQuoter = _cheapQuoter; } /** @@ -34,7 +34,7 @@ contract MiddlewareProtectFactory { * @return middleware The address of the newly created middlewareRemove contract. */ function createMiddleware(address implementation, bytes32 salt) external returns (address middleware) { - middleware = address(new MiddlewareProtect{salt: salt}(poolManager, viewQuoter, implementation)); + middleware = address(new MiddlewareProtect{salt: salt}(poolManager, cheapQuoter, implementation)); _implementations[middleware] = implementation; emit MiddlewareCreated(implementation, middleware); } diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol index 03ea73c3..488f8910 100644 --- a/test/MiddlewareProtectFactory.t.sol +++ b/test/MiddlewareProtectFactory.t.sol @@ -26,8 +26,9 @@ import {LPFeeLibrary} from "@uniswap/v4-core/src/libraries/LPFeeLibrary.sol"; import {GasSnapshot} from "forge-gas-snapshot/GasSnapshot.sol"; import {BaseMiddleware} from "./../src/middleware/BaseMiddleware.sol"; import {BlankSwapHooks} from "./middleware/BlankSwapHooks.sol"; -import {ViewQuoter} from "./../src/lens/ViewQuoter.sol"; -import {IViewQuoter} from "./../src/interfaces/IViewQuoter.sol"; +import {CheapQuoter} from "./../src/middleware/CheapQuoter.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { HookEnabledSwapRouter router; @@ -38,7 +39,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { HooksCounter counter; address middleware; HooksFrontrun hooksFrontrun; - IViewQuoter viewQuoter; + CheapQuoter cheapQuoter; uint160 COUNTER_FLAGS = uint160( Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG @@ -54,8 +55,8 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { token0 = TestERC20(Currency.unwrap(currency0)); token1 = TestERC20(Currency.unwrap(currency1)); - viewQuoter = new ViewQuoter(manager); - factory = new MiddlewareProtectFactory(manager, viewQuoter); + cheapQuoter = new CheapQuoter(manager); + factory = new MiddlewareProtectFactory(manager, cheapQuoter); counter = HooksCounter(address(COUNTER_FLAGS)); vm.etch(address(counter), address(new HooksCounter(manager)).code); @@ -66,7 +67,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), COUNTER_FLAGS, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(counter)) + abi.encode(address(manager), address(cheapQuoter), address(counter)) ); middleware = factory.createMiddleware(address(counter), salt); assertEq(hookAddress, middleware); @@ -85,7 +86,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), flags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(hooksReturnDeltas)) + abi.encode(address(manager), address(cheapQuoter), address(hooksReturnDeltas)) ); address implementation = address(hooksReturnDeltas); vm.expectRevert(abi.encodePacked(bytes16(MiddlewareProtect.HookPermissionForbidden.selector), hookAddress)); @@ -116,7 +117,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), flags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(hooksFrontrun)) + abi.encode(address(manager), address(cheapQuoter), address(hooksFrontrun)) ); address implementation = address(hooksFrontrun); address hookAddressCreated = factory.createMiddleware(implementation, salt); @@ -147,7 +148,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), flags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(hooksRevert)) + abi.encode(address(manager), address(cheapQuoter), address(hooksRevert)) ); middleware = factory.createMiddleware(address(hooksRevert), salt); (key,) = initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); @@ -166,7 +167,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), flags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(hooksOutOfGas)) + abi.encode(address(manager), address(cheapQuoter), address(hooksOutOfGas)) ); middleware = factory.createMiddleware(address(hooksOutOfGas), salt); (key,) = initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); @@ -182,7 +183,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { // address(factory), // flags, // type(MiddlewareProtect).creationCode, - // abi.encode(address(manager), address(viewQuoter), address(frontrunAdd)) + // abi.encode(address(manager), address(cheapQuoter), address(frontrunAdd)) // ); // middleware = factory.createMiddleware(address(frontrunAdd), salt); // currency0.transfer(address(frontrunAdd), 1 ether); @@ -303,7 +304,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), thisFlags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), implFlags) + abi.encode(address(manager), address(cheapQuoter), implFlags) ); factory.createMiddleware(address(implFlags), salt); } @@ -319,7 +320,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), flags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(counter)) + abi.encode(address(manager), address(cheapQuoter), address(counter)) ); factory.createMiddleware(address(counter), salt); // second deployment should revert @@ -336,7 +337,7 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { address(factory), incorrectFlags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(counter2)) + abi.encode(address(manager), address(cheapQuoter), address(counter2)) ); address implementation = address(counter2); vm.expectRevert(BaseMiddleware.FlagsMismatch.selector); @@ -392,27 +393,41 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { uint160 flags = Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG; BlankSwapHooks blankSwapHooks = BlankSwapHooks(address(flags)); vm.etch(address(blankSwapHooks), address(new BlankSwapHooks(manager)).code); - (key,) = initPoolAndAddLiquidity( - currency0, currency1, IHooks(address(blankSwapHooks)), 3000, SQRT_PRICE_1_1, ZERO_BYTES + PoolId id; + (key, id) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(blankSwapHooks)), 500, SQRT_PRICE_1_1, ZERO_BYTES ); swap(key, true, 0.0001 ether, ZERO_BYTES); + swap(key, true, 0.0001 ether, ZERO_BYTES); snapLastCall("MIDDLEWARE_PROTECT-vanilla"); uint160 maxFeeBips = 0; (, bytes32 salt) = HookMiner.find( address(factory), flags, type(MiddlewareProtect).creationCode, - abi.encode(address(manager), address(viewQuoter), address(blankSwapHooks)) + abi.encode(address(manager), address(cheapQuoter), address(blankSwapHooks)) ); address hookAddress = factory.createMiddleware(address(blankSwapHooks), salt); (PoolKey memory protectedKey,) = - initPoolAndAddLiquidity(currency0, currency1, IHooks(hookAddress), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + initPoolAndAddLiquidity(currency0, currency1, IHooks(hookAddress), 500, SQRT_PRICE_1_1, ZERO_BYTES); + swap(protectedKey, true, 0.0001 ether, ZERO_BYTES); swap(protectedKey, true, 0.0001 ether, ZERO_BYTES); snapLastCall("MIDDLEWARE_PROTECT-protected"); - swap(key, true, 0.01 ether, ZERO_BYTES); + (, int24 tick,,) = StateLibrary.getSlot0(manager, id); + + IPoolManager.ModifyLiquidityParams memory params = IPoolManager.ModifyLiquidityParams({ + tickLower: tick * 100, + tickUpper: tick * 2, + liquidityDelta: 100e18, + salt: 0 + }); + modifyLiquidityRouter.modifyLiquidity(key, params, ZERO_BYTES); + modifyLiquidityRouter.modifyLiquidity(protectedKey, params, ZERO_BYTES); + + swap(key, true, 0.1 ether, ZERO_BYTES); snapLastCall("MIDDLEWARE_PROTECT-multi-vanilla"); - swap(protectedKey, true, 0.01 ether, ZERO_BYTES); + swap(protectedKey, true, 0.1 ether, ZERO_BYTES); snapLastCall("MIDDLEWARE_PROTECT-multi-protected"); } }