Skip to content

Commit

Permalink
override deltas edge case with internal delegatecall
Browse files Browse the repository at this point in the history
  • Loading branch information
Jun1on committed Jul 11, 2024
1 parent fcd3b34 commit bb2641c
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 31 deletions.
6 changes: 3 additions & 3 deletions contracts/middleware/BaseMiddleware.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import {Proxy} from "@openzeppelin/contracts/proxy/Proxy.sol";

contract BaseMiddleware is Proxy {
/// @notice The address of the pool manager
IPoolManager public immutable poolManager;
IPoolManager public immutable manager;
address public immutable implementation;

constructor(IPoolManager _poolManager, address _impl) {
poolManager = _poolManager;
constructor(IPoolManager _manager, address _impl) {
manager = _manager;
implementation = _impl;
}

Expand Down
8 changes: 4 additions & 4 deletions contracts/middleware/BaseMiddlewareFactory.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import {IBaseHook} from "../interfaces/IBaseHook.sol";
contract BaseMiddlewareFactory is IMiddlewareFactory {
mapping(address => address) private _implementations;

IPoolManager public immutable poolManager;
IPoolManager public immutable manager;

constructor(IPoolManager _poolManager) {
poolManager = _poolManager;
constructor(IPoolManager _manager) {
manager = _manager;
}

function getImplementation(address middleware) external view override returns (address implementation) {
Expand All @@ -29,6 +29,6 @@ contract BaseMiddlewareFactory is IMiddlewareFactory {
}

function _deployMiddleware(address implementation, bytes32 salt) internal virtual returns (address middleware) {
return address(new BaseMiddleware{salt: salt}(poolManager, implementation));
return address(new BaseMiddleware{salt: salt}(manager, implementation));
}
}
35 changes: 27 additions & 8 deletions contracts/middleware/MiddlewareRemove.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,26 @@ import {BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol";
import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol";
import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol";
import {CustomRevert} from "@uniswap/v4-core/src/libraries/CustomRevert.sol";
import {NonZeroDeltaCount} from "@uniswap/v4-core/src/libraries/NonZeroDeltaCount.sol";
import {IExttload} from "@uniswap/v4-core/src/interfaces/IExttload.sol";

contract MiddlewareRemove is BaseMiddleware {
using CustomRevert for bytes4;
using Hooks for IHooks;

error HookPermissionForbidden(address hooks);
error HookModifiedDeltas();

struct afterRemoveLiquidityParams {
address sender;
PoolKey key;
IPoolManager.ModifyLiquidityParams params;
}

bytes internal constant ZERO_BYTES = bytes("");
uint256 public constant gasLimit = 10000000;

constructor(IPoolManager _poolManager, address _impl) BaseMiddleware(_poolManager, _impl) {
constructor(IPoolManager _manager, address _impl) BaseMiddleware(_manager, _impl) {
if (IHooks(address(this)).hasPermission(Hooks.AFTER_REMOVE_LIQUIDITY_RETURNS_DELTA_FLAG)) {
HookPermissionForbidden.selector.revertWith(address(this));
}
Expand All @@ -38,14 +47,24 @@ contract MiddlewareRemove is BaseMiddleware {
}

function afterRemoveLiquidity(
address,
PoolKey calldata,
IPoolManager.ModifyLiquidityParams calldata,
BalanceDelta,
bytes calldata
address sender,
PoolKey calldata key,
IPoolManager.ModifyLiquidityParams calldata params,
BalanceDelta delta,
bytes calldata hookData
) external returns (bytes4, BalanceDelta) {
implementation.delegatecall{gas: gasLimit}(msg.data);
// hook cannot return delta
address(this).delegatecall(abi.encodeWithSelector(this._callAndEnsureZeroDeltas.selector, msg.data));
return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA);
}

function _callAndEnsureZeroDeltas(bytes calldata data) external {
bytes32 slot = bytes32(NonZeroDeltaCount.NONZERO_DELTA_COUNT_SLOT);
uint256 countBefore = uint256(IExttload(address(manager)).exttload(slot));
address(implementation).delegatecall(msg.data);
uint256 countAfter = uint256(IExttload(address(manager)).exttload(slot));
if (countAfter > countBefore) {
// purpousely revert to cause the whole hook to reset
revert HookModifiedDeltas();
}
}
}
4 changes: 2 additions & 2 deletions contracts/middleware/MiddlewareRemoveFactory.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import {MiddlewareRemove} from "./MiddlewareRemove.sol";
import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol";

contract MiddlewareRemoveFactory is BaseMiddlewareFactory {
constructor(IPoolManager _poolManager) BaseMiddlewareFactory(_poolManager) {}
constructor(IPoolManager _manager) BaseMiddlewareFactory(_manager) {}

function _deployMiddleware(address implementation, bytes32 salt) internal override returns (address middleware) {
return address(new MiddlewareRemove{salt: salt}(poolManager, implementation));
return address(new MiddlewareRemove{salt: salt}(manager, implementation));
}
}
34 changes: 25 additions & 9 deletions test/MiddlewareRemoveFactory.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,39 @@ contract MiddlewareRemoveFactoryTest is Test, Deployers {

function testFeeOnRemove() public {
uint160 flags = uint160(Hooks.AFTER_REMOVE_LIQUIDITY_FLAG);
FeeOnRemove feeOnRemove = FeeOnRemove(address(uint160(Hooks.AFTER_REMOVE_LIQUIDITY_FLAG)));
FeeOnRemove feeOnRemove = FeeOnRemove(address(flags));
FeeOnRemove impl = new FeeOnRemove(manager);
vm.etch(address(feeOnRemove), address(impl).code);
(address hookAddress, bytes32 salt) = HookMiner.find(
(, bytes32 salt) = HookMiner.find(
address(factory),
flags,
type(MiddlewareRemove).creationCode,
abi.encode(address(manager), address(feeOnRemove))
);
factory.createMiddleware(address(feeOnRemove), salt);
middleware = factory.createMiddleware(address(feeOnRemove), salt);

initPoolAndAddLiquidity(currency0, currency1, IHooks(feeOnRemove), 3000, SQRT_PRICE_1_1, ZERO_BYTES);
vm.expectRevert(IPoolManager.CurrencyNotSettled.selector);
removeLiquidity(currency0, currency1, IHooks(feeOnRemove), 3000, SQRT_PRICE_1_1, ZERO_BYTES);

IHooks noHooks = IHooks(address(0));
initPoolAndAddLiquidity(currency0, currency1, noHooks, 3000, SQRT_PRICE_1_1, ZERO_BYTES);
uint256 initialBalance0 = token0.balanceOf(address(this));
uint256 initialBalance1 = token1.balanceOf(address(this));
removeLiquidity(currency0, currency1, IHooks(feeOnRemove), 3000, SQRT_PRICE_1_1, ZERO_BYTES);
uint256 outWithFees0 = token0.balanceOf(address(this)) - initialBalance0;
uint256 outWithFees1 = token1.balanceOf(address(this)) - initialBalance1;
console.log(outWithFees0, outWithFees1);
removeLiquidity(currency0, currency1, noHooks, 3000, SQRT_PRICE_1_1, ZERO_BYTES);
uint256 outNormal0 = token0.balanceOf(address(this)) - initialBalance0;
uint256 outNormal1 = token1.balanceOf(address(this)) - initialBalance1;

initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES);
initialBalance0 = token0.balanceOf(address(this));
initialBalance1 = token1.balanceOf(address(this));
removeLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES);
uint256 out0 = token0.balanceOf(address(this)) - initialBalance0;
uint256 out1 = token1.balanceOf(address(this)) - initialBalance1;
console.log(out0, out1);

// no fees taken
assertEq(outNormal0, out0);
assertEq(outNormal1, out1);
}

function testVariousFactory() public {
Expand Down Expand Up @@ -131,14 +138,23 @@ contract MiddlewareRemoveFactoryTest is Test, Deployers {
assertEq(factory.getImplementation(hookAddress), implementation);
}

function testRevertOnDeltaFlags() public {
uint160 flags = uint160(Hooks.AFTER_REMOVE_LIQUIDITY_RETURNS_DELTA_FLAG);
(address hookAddress, bytes32 salt) = HookMiner.find(
address(factory), flags, type(MiddlewareRemove).creationCode, abi.encode(address(manager), address(counter))
);
vm.expectRevert(MiddlewareRemove.HookPermissionForbidden.selector);
factory.createMiddleware(address(counter), salt);
}

// from BaseMiddlewareFactory.t.sol
function testRevertOnSameDeployment() public {
uint160 flags = uint160(
Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG
| Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG
| Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG
);
(address hookAddress, bytes32 salt) = HookMiner.find(
(, bytes32 salt) = HookMiner.find(
address(factory), flags, type(MiddlewareRemove).creationCode, abi.encode(address(manager), address(counter))
);
factory.createMiddleware(address(counter), salt);
Expand Down
8 changes: 3 additions & 5 deletions test/middleware/FeeOnRemove.sol
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ contract FeeOnRemove is BaseHook {
}

function afterRemoveLiquidity(
address, /* sender **/
address,
PoolKey calldata key,
IPoolManager.ModifyLiquidityParams calldata, /* params **/
IPoolManager.ModifyLiquidityParams calldata,
BalanceDelta delta,
bytes calldata /* hookData **/
bytes calldata
) external override onlyByManager returns (bytes4, BalanceDelta) {
assert(delta.amount0() >= 0 && delta.amount1() >= 0);

uint128 feeAmount0 = uint128(delta.amount0()) * LIQUIDITY_FEE / TOTAL_BIPS;
uint128 feeAmount1 = uint128(delta.amount1()) * LIQUIDITY_FEE / TOTAL_BIPS;

Expand Down

0 comments on commit bb2641c

Please sign in to comment.