diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap new file mode 100644 index 00000000..79ff6983 --- /dev/null +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-protected.snap @@ -0,0 +1 @@ +872991 \ No newline at end of file diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-multi-vanilla.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-vanilla.snap new file mode 100644 index 00000000..98ddcf3c --- /dev/null +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-multi-vanilla.snap @@ -0,0 +1 @@ +475485 \ No newline at end of file diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap new file mode 100644 index 00000000..e70c7a11 --- /dev/null +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-protected.snap @@ -0,0 +1 @@ +229771 \ No newline at end of file diff --git a/.forge-snapshots/MIDDLEWARE_PROTECT-vanilla.snap b/.forge-snapshots/MIDDLEWARE_PROTECT-vanilla.snap new file mode 100644 index 00000000..e1b225a9 --- /dev/null +++ b/.forge-snapshots/MIDDLEWARE_PROTECT-vanilla.snap @@ -0,0 +1 @@ +147725 \ No newline at end of file diff --git a/test/MiddlewareProtectFactory.t.sol b/test/MiddlewareProtectFactory.t.sol index 6ae26484..cbe6a675 100644 --- a/test/MiddlewareProtectFactory.t.sol +++ b/test/MiddlewareProtectFactory.t.sol @@ -25,6 +25,7 @@ import {FrontrunAdd} from "./middleware/FrontrunAdd.sol"; import {LPFeeLibrary} from "@uniswap/v4-core/src/libraries/LPFeeLibrary.sol"; import {GasSnapshot} from "forge-gas-snapshot/GasSnapshot.sol"; import {BaseMiddleware} from "./../src/middleware/BaseMiddleware.sol"; +import {BlankSwapHooks} from "./middleware/BlankSwapHooks.sol"; contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { HookEnabledSwapRouter router; @@ -379,4 +380,32 @@ contract MiddlewareProtectFactoryTest is Test, Deployers, GasSnapshot { assertEq(counterProxy.beforeRemoveLiquidityCount(id), 1); assertEq(counterProxy.afterRemoveLiquidityCount(id), 1); } + + function testMiddlewareRemoveGas() public { + uint160 flags = Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG; + BlankSwapHooks blankSwapHooks = BlankSwapHooks(address(flags)); + vm.etch(address(blankSwapHooks), address(new BlankSwapHooks(manager)).code); + (key,) = initPoolAndAddLiquidity( + currency0, currency1, IHooks(address(blankSwapHooks)), 3000, SQRT_PRICE_1_1, ZERO_BYTES + ); + swap(key, true, 0.0001 ether, ZERO_BYTES); + snapLastCall("MIDDLEWARE_PROTECT-vanilla"); + uint160 maxFeeBips = 0; + (, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(MiddlewareProtect).creationCode, + abi.encode(address(manager), address(blankSwapHooks)) + ); + address hookAddress = factory.createMiddleware(address(blankSwapHooks), salt); + (PoolKey memory protectedKey,) = + initPoolAndAddLiquidity(currency0, currency1, IHooks(hookAddress), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + swap(protectedKey, true, 0.0001 ether, ZERO_BYTES); + snapLastCall("MIDDLEWARE_PROTECT-protected"); + + swap(key, true, 0.01 ether, ZERO_BYTES); + snapLastCall("MIDDLEWARE_PROTECT-multi-vanilla"); + swap(protectedKey, true, 0.01 ether, ZERO_BYTES); + snapLastCall("MIDDLEWARE_PROTECT-multi-protected"); + } } diff --git a/test/middleware/BlankSwapHooks.sol b/test/middleware/BlankSwapHooks.sol new file mode 100644 index 00000000..6fcd8fdb --- /dev/null +++ b/test/middleware/BlankSwapHooks.sol @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../src/base/hooks/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; + +contract BlankSwapHooks is BaseHook { + constructor(IPoolManager _poolManager) BaseHook(_poolManager) {} + + // for testing + function validateHookAddress(BaseHook _this) internal pure override {} + + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: true, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) + external + override + onlyByPoolManager + returns (bytes4, BeforeSwapDelta, uint24) + { + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) + external + override + onlyByPoolManager + returns (bytes4, int128) + { + return (BaseHook.afterSwap.selector, 0); + } +}