Skip to content

Commit

Permalink
last things
Browse files Browse the repository at this point in the history
  • Loading branch information
Jun1on committed Jul 10, 2024
1 parent 5bc319e commit 1c250a1
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 42 deletions.
5 changes: 5 additions & 0 deletions contracts/middleware/BaseImplementation.sol
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ abstract contract BaseImplementation is IHooks, SafeCallback {
middleware = _middleware;
}

function updateDynamicFee(PoolKey calldata key, uint24 fee) external {
if (msg.sender != middlewareFactory) revert NotMiddlewareFactory();
manager.updateDynamicLPFee(key, fee);
}

function getHookPermissions() public pure virtual returns (Hooks.Permissions memory);

function _unlockCallback(bytes calldata data) internal virtual override returns (bytes memory) {
Expand Down
11 changes: 9 additions & 2 deletions contracts/middleware/BaseMiddleware.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol";
import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol";

contract BaseMiddleware is IHooks {
using Hooks for BaseMiddleware;
using BeforeSwapDeltaLibrary for BeforeSwapDelta;

error NotManager();

IPoolManager public immutable manager;
Expand Down Expand Up @@ -102,11 +105,15 @@ contract BaseMiddleware is IHooks {
PoolKey calldata key,
IPoolManager.SwapParams calldata params,
bytes calldata hookData
) external virtual onlyByManager returns (bytes4, BeforeSwapDelta, uint24) {
) external virtual onlyByManager returns (bytes4 selector, BeforeSwapDelta beforeSwapDelta, uint24 lpFeeOverride) {
if (msg.sender == address(implementation)) {
return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0);
}
return implementation.beforeSwap(sender, key, params, hookData);
(selector, beforeSwapDelta, lpFeeOverride) = implementation.beforeSwap(sender, key, params, hookData);
if (this.hasPermission(Hooks.BEFORE_SWAP_RETURNS_DELTA_FLAG)) {
manager.take(key.currency0, sender, uint256(uint128(beforeSwapDelta.getSpecifiedDelta())));
manager.take(key.currency1, sender, uint256(uint128(beforeSwapDelta.getUnspecifiedDelta())));
}
}

function afterSwap(
Expand Down
70 changes: 66 additions & 4 deletions test/BaseMiddlewareFactory.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ import {BaseMiddleware} from "../contracts/middleware/BaseMiddleware.sol";
import {BaseMiddlewareFactory} from "./../contracts/middleware/BaseMiddlewareFactory.sol";
import {HookMiner} from "./utils/HookMiner.sol";
import {HooksCounter} from "./middleware-implementations/HooksCounter.sol";
import {BaseImplementation} from "./../contracts/middleware/BaseImplementation.sol";

contract BaseMiddlewareFactoryTest is Test, Deployers {
HookEnabledSwapRouter router;
TestERC20 token0;
TestERC20 token1;
PoolId id;

BaseMiddlewareFactory factory;
HooksCounter hooksCounter;

address middleware;

function setUp() public {
deployFreshManagerAndRouters();
(currency0, currency1) = deployMintAndApprove2Currencies();
Expand All @@ -39,25 +41,85 @@ contract BaseMiddlewareFactoryTest is Test, Deployers {

token0.approve(address(router), type(uint256).max);
token1.approve(address(router), type(uint256).max);

uint160 flags = uint160(
Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG
| Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG
| Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG
);

(address hookAddress, bytes32 salt) = HookMiner.find(
address(factory),
flags,
type(BaseMiddleware).creationCode,
abi.encode(address(manager), address(hooksCounter))
);
middleware = factory.createMiddleware(address(hooksCounter), salt);
assertEq(hookAddress, middleware);
}

function testRevertOnAlreadyInitialized() public {
uint160 flags = uint160(
Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG
| Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG
| Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG
);
(address hookAddress, bytes32 salt) = HookMiner.find(
address(factory),
flags,
type(BaseMiddleware).creationCode,
abi.encode(address(manager), address(hooksCounter))
);
vm.expectRevert(BaseMiddlewareFactory.AlreadyInitialized.selector);
factory.createMiddleware(address(hooksCounter), salt);
}

function testRevertOnIncorrectFlags() public {
HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory));
uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG);

(address hookAddress, bytes32 salt) = HookMiner.find(
address(factory),
flags,
type(BaseMiddleware).creationCode,
abi.encode(address(manager), address(hooksCounter))
abi.encode(address(manager), address(hooksCounter2))
);
address implementation = address(hooksCounter);
address implementation = address(hooksCounter2);
vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress));
factory.createMiddleware(implementation, salt);
}

function testRevertOnIncorrectFlagsMined() public {
address implementation = address(hooksCounter);
HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory));
address implementation = address(hooksCounter2);
vm.expectRevert(); // HookAddressNotValid
factory.createMiddleware(implementation, bytes32("who needs to mine a salt?"));
}

function testRevertOnIncorrectCaller() public {
vm.expectRevert(BaseImplementation.NotMiddleware.selector);
hooksCounter.afterDonate(address(this), key, 0, 0, ZERO_BYTES);
}

function testCounters() public {
(PoolKey memory key, PoolId id) =
initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES);

assertEq(hooksCounter.beforeInitializeCount(id), 1);
assertEq(hooksCounter.afterInitializeCount(id), 1);
assertEq(hooksCounter.beforeSwapCount(id), 0);
assertEq(hooksCounter.afterSwapCount(id), 0);
assertEq(hooksCounter.beforeAddLiquidityCount(id), 1);
assertEq(hooksCounter.afterAddLiquidityCount(id), 1);
assertEq(hooksCounter.beforeRemoveLiquidityCount(id), 0);
assertEq(hooksCounter.afterRemoveLiquidityCount(id), 0);
assertEq(hooksCounter.beforeDonateCount(id), 0);
assertEq(hooksCounter.afterDonateCount(id), 0);

assertEq(hooksCounter.lastHookData(), ZERO_BYTES);
swap(key, true, 1, bytes("hi"));
assertEq(hooksCounter.lastHookData(), bytes("hi"));
assertEq(hooksCounter.beforeSwapCount(id), 1);
assertEq(hooksCounter.afterSwapCount(id), 1);
}
}
112 changes: 76 additions & 36 deletions test/middleware-implementations/HooksCounter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,27 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol";
import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol";
import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol";
import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol";
import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol";

contract HooksCounter is BaseImplementation {
using PoolIdLibrary for PoolKey;

mapping(PoolId => uint256) public beforeInitializeCount;
mapping(PoolId => uint256) public afterInitializeCount;

mapping(PoolId => uint256) public beforeSwapCount;
mapping(PoolId => uint256) public afterSwapCount;

mapping(PoolId => uint256) public beforeAddLiquidityCount;
mapping(PoolId => uint256) public afterAddLiquidityCount;
mapping(PoolId => uint256) public beforeRemoveLiquidityCount;
mapping(PoolId => uint256) public afterRemoveLiquidityCount;

mapping(PoolId => uint256) public beforeDonateCount;
mapping(PoolId => uint256) public afterDonateCount;

bytes public lastHookData;

constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {}

function getHookPermissions() public pure override returns (Hooks.Permissions memory) {
Expand All @@ -31,95 +50,116 @@ contract HooksCounter is BaseImplementation {
});
}

function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata)
function beforeInitialize(address, PoolKey calldata key, uint160, bytes calldata hookData)
external
pure
override
onlyByMiddleware
returns (bytes4)
{
beforeInitializeCount[key.toId()]++;
lastHookData = hookData;
return BaseHook.beforeInitialize.selector;
}

function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata)
function afterInitialize(address, PoolKey calldata key, uint160, int24, bytes calldata hookData)
external
pure
override
onlyByMiddleware
returns (bytes4)
{
afterInitializeCount[key.toId()]++;
lastHookData = hookData;
return BaseHook.afterInitialize.selector;
}

function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata)
external
pure
override
returns (bytes4)
{
return BaseHook.beforeAddLiquidity.selector;
}

function beforeRemoveLiquidity(
function beforeAddLiquidity(
address,
PoolKey calldata,
PoolKey calldata key,
IPoolManager.ModifyLiquidityParams calldata,
bytes calldata
) external pure override returns (bytes4) {
return BaseHook.beforeRemoveLiquidity.selector;
bytes calldata hookData
) external override onlyByMiddleware returns (bytes4) {
beforeAddLiquidityCount[key.toId()]++;
lastHookData = hookData;
return BaseHook.beforeAddLiquidity.selector;
}

function afterAddLiquidity(
address,
PoolKey calldata,
PoolKey calldata key,
IPoolManager.ModifyLiquidityParams calldata,
BalanceDelta,
bytes calldata
) external pure override returns (bytes4, BalanceDelta) {
bytes calldata hookData
) external override onlyByMiddleware returns (bytes4, BalanceDelta) {
afterAddLiquidityCount[key.toId()]++;
lastHookData = hookData;
return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA);
}

function beforeRemoveLiquidity(
address,
PoolKey calldata key,
IPoolManager.ModifyLiquidityParams calldata,
bytes calldata hookData
) external override onlyByMiddleware returns (bytes4) {
beforeRemoveLiquidityCount[key.toId()]++;
lastHookData = hookData;
return BaseHook.beforeRemoveLiquidity.selector;
}

function afterRemoveLiquidity(
address,
PoolKey calldata,
PoolKey calldata key,
IPoolManager.ModifyLiquidityParams calldata,
BalanceDelta,
bytes calldata
) external pure override returns (bytes4, BalanceDelta) {
bytes calldata hookData
) external override onlyByMiddleware returns (bytes4, BalanceDelta) {
afterRemoveLiquidityCount[key.toId()]++;
lastHookData = hookData;
return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA);
}

function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata)
function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata hookData)
external
pure
override
onlyByMiddleware
returns (bytes4, BeforeSwapDelta, uint24)
{
beforeSwapCount[key.toId()]++;
lastHookData = hookData;
return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0);
}

function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata)
external
pure
override
returns (bytes4, int128)
{
function afterSwap(
address,
PoolKey calldata key,
IPoolManager.SwapParams calldata,
BalanceDelta,
bytes calldata hookData
) external override onlyByMiddleware returns (bytes4, int128) {
afterSwapCount[key.toId()]++;
lastHookData = hookData;
return (BaseHook.afterSwap.selector, 0);
}

function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata)
function beforeDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData)
external
pure
override
onlyByMiddleware
returns (bytes4)
{
beforeDonateCount[key.toId()]++;
lastHookData = hookData;
return BaseHook.beforeDonate.selector;
}

function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata)
function afterDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData)
external
pure
override
onlyByMiddleware
returns (bytes4)
{
afterDonateCount[key.toId()]++;
lastHookData = hookData;
return BaseHook.afterDonate.selector;
}
}

0 comments on commit 1c250a1

Please sign in to comment.