diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap index 8be44c61..08051b19 100644 --- a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap @@ -1 +1 @@ -184020 \ No newline at end of file +178969 \ No newline at end of file diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap index afb522c2..d5707a7d 100644 --- a/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap @@ -1 +1 @@ -151513 \ No newline at end of file +149436 \ No newline at end of file diff --git a/src/libraries/Quote.sol b/src/libraries/Quote.sol new file mode 100644 index 00000000..2d35352b --- /dev/null +++ b/src/libraries/Quote.sol @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.24; + +/// @notice This is a temporary library that allows us to use transient storage (tstore/tload) +/// for the quote. +/// TODO: This library can be deleted when we have the transient keyword support in solidity. +library Quote { + // The slot holding the quote. bytes32(uint256(keccak256("Quote")) - 1) + bytes32 internal constant QUOTE_SLOT = 0xbbd426867243227198e50d68cdb6f9a2a3a1c5ef433a2b6e7fcf3f462364310a; + + function read() internal view returns (int256 quote) { + assembly ("memory-safe") { + quote := tload(QUOTE_SLOT) + } + } + + function set(int256 quote) internal { + assembly ("memory-safe") { + tstore(QUOTE_SLOT, quote) + } + } + + function reset() internal { + assembly ("memory-safe") { + tstore(QUOTE_SLOT, 0) + } + } +} diff --git a/src/middleware/CheapQuoter.sol b/src/middleware/CheapQuoter.sol index caf6bd68..8a9d016c 100644 --- a/src/middleware/CheapQuoter.sol +++ b/src/middleware/CheapQuoter.sol @@ -181,10 +181,7 @@ contract CheapQuoter { // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved state.tick = TickMath.getTickAtSqrtPrice(state.sqrtPriceX96); } - - quote = quoteParams.exactInput - ? state.amountCalculated - : state.amountSpecifiedRemaining + swapParams.amountSpecified; } + quote = state.amountCalculated; } } diff --git a/src/middleware/MiddlewareProtect.sol b/src/middleware/MiddlewareProtect.sol index 2d5f8b53..1e16f60f 100644 --- a/src/middleware/MiddlewareProtect.sol +++ b/src/middleware/MiddlewareProtect.sol @@ -10,17 +10,14 @@ import {BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; import {CustomRevert} from "@uniswap/v4-core/src/libraries/CustomRevert.sol"; -import {NonZeroDeltaCount} from "@uniswap/v4-core/src/libraries/NonZeroDeltaCount.sol"; -import {IExttload} from "@uniswap/v4-core/src/interfaces/IExttload.sol"; import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; import {PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; import {BaseMiddleware} from "./BaseMiddleware.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {console} from "forge-std/console.sol"; import {LPFeeLibrary} from "@uniswap/v4-core/src/libraries/LPFeeLibrary.sol"; -import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; import {CheapQuoter} from "./CheapQuoter.sol"; +import {Quote} from "../libraries/Quote.sol"; contract MiddlewareProtect is BaseMiddleware { using CustomRevert for bytes4; @@ -46,9 +43,6 @@ contract MiddlewareProtect is BaseMiddleware { CheapQuoter public immutable cheapQuoter; - // todo: use tstore - int256 private quote; - constructor(IPoolManager _poolManager, CheapQuoter _cheapQuoter, address _impl) BaseMiddleware(_poolManager, _impl) { @@ -60,7 +54,7 @@ contract MiddlewareProtect is BaseMiddleware { external returns (bytes4, BeforeSwapDelta, uint24) { - quote = cheapQuoter.quote(key, params); + Quote.set(cheapQuoter.quote(key, params)); (bool success, bytes memory returnData) = address(implementation).delegatecall(msg.data); if (!success) { _handleRevert(returnData); @@ -77,8 +71,9 @@ contract MiddlewareProtect is BaseMiddleware { ) external returns (bytes4, int128) { IHooks implementation = IHooks(address(implementation)); if (implementation.hasPermission(Hooks.BEFORE_SWAP_FLAG)) { - int256 amountOut = params.zeroForOne ? delta.amount1() : delta.amount0(); - if (amountOut != quote) revert HookModifiedOutput(); + int256 amountActual = params.zeroForOne == params.amountSpecified < 0 ? delta.amount1() : delta.amount0(); + if (amountActual != Quote.read()) revert HookModifiedOutput(); + Quote.reset(); if (!implementation.hasPermission(Hooks.AFTER_SWAP_FLAG)) { return (BaseHook.afterSwap.selector, 0); }