diff --git a/contracts/hooks/examples/TakingFee.sol b/contracts/hooks/examples/TakingFee.sol index dbbe551a..d3294e6c 100644 --- a/contracts/hooks/examples/TakingFee.sol +++ b/contracts/hooks/examples/TakingFee.sol @@ -5,54 +5,46 @@ import {BaseHook} from "../../BaseHook.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; -import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; -import {Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {Currency, CurrencyLibrary} from "@uniswap/v4-core/src/types/Currency.sol"; import {SafeCast} from "@uniswap/v4-core/src/libraries/SafeCast.sol"; import {Owned} from "solmate/auth/Owned.sol"; +import {IUnlockCallback} from "@uniswap/v4-core/src/interfaces/callback/IUnlockCallback.sol"; -contract TakingFee is BaseHook, Owned { - using PoolIdLibrary for PoolKey; +contract TakingFee is BaseHook, IUnlockCallback, Owned { using SafeCast for uint256; uint128 private constant TOTAL_BIPS = 10000; uint128 private constant MAX_BIPS = 100; uint128 public swapFeeBips; - address public treasury = msg.sender; - constructor( - IPoolManager _poolManager, - uint128 _swapFeeBips, - address _treasury - ) BaseHook(_poolManager) Owned(msg.sender) { + struct CallbackData { + address to; + Currency[] currencies; + } + + constructor(IPoolManager _poolManager, uint128 _swapFeeBips) BaseHook(_poolManager) Owned(msg.sender) { swapFeeBips = _swapFeeBips; - treasury = _treasury; } - function getHookPermissions() - public - pure - override - returns (Hooks.Permissions memory) - { - return - Hooks.Permissions({ - beforeInitialize: false, - afterInitialize: false, - beforeAddLiquidity: false, - afterAddLiquidity: false, - beforeRemoveLiquidity: false, - afterRemoveLiquidity: false, - beforeSwap: false, - afterSwap: true, - beforeDonate: false, - afterDonate: false, - beforeSwapReturnDelta: false, - afterSwapReturnDelta: true, - afterAddLiquidityReturnDelta: false, - afterRemoveLiquidityReturnDelta: false - }); + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: false, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: true, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); } function afterSwap( @@ -63,16 +55,15 @@ contract TakingFee is BaseHook, Owned { bytes calldata ) external override returns (bytes4, int128) { // fee will be in the unspecified token of the swap - bool specifiedTokenIs0 = (params.amountSpecified < 0 == - params.zeroForOne); - (Currency feeCurrency, int128 swapAmount) = (specifiedTokenIs0) - ? (key.currency1, delta.amount1()) - : (key.currency0, delta.amount0()); + bool specifiedTokenIs0 = (params.amountSpecified < 0 == params.zeroForOne); + (Currency feeCurrency, int128 swapAmount) = + (specifiedTokenIs0) ? (key.currency1, delta.amount1()) : (key.currency0, delta.amount0()); // if fee is on output, get the absolute output amount if (swapAmount < 0) swapAmount = -swapAmount; uint256 feeAmount = (uint128(swapAmount) * swapFeeBips) / TOTAL_BIPS; - poolManager.take(feeCurrency, treasury, feeAmount); + // mint ERC6909 instead of take to avoid edge case where PM doesn't have enough balance + poolManager.mint(address(this), CurrencyLibrary.toId(feeCurrency), feeAmount); return (BaseHook.afterSwap.selector, feeAmount.toInt128()); } @@ -82,7 +73,26 @@ contract TakingFee is BaseHook, Owned { swapFeeBips = _swapFeeBips; } - function setTreasury(address _treasury) external onlyOwner { - treasury = _treasury; + function withdraw(address to, Currency[] calldata currencies) external onlyOwner { + poolManager.unlock(abi.encode(CallbackData(to, currencies))); + } + + function unlockCallback(bytes calldata rawData) + external + override(IUnlockCallback, BaseHook) + poolManagerOnly + returns (bytes memory) + { + CallbackData memory data = abi.decode(rawData, (CallbackData)); + uint256 length = data.currencies.length; + for (uint256 i = 0; i < length;) { + uint256 amount = poolManager.balanceOf(address(this), CurrencyLibrary.toId(data.currencies[i])); + poolManager.burn(address(this), CurrencyLibrary.toId(data.currencies[i]), amount); + poolManager.take(data.currencies[i], data.to, amount); + unchecked { + i++; + } + } + return ""; } } diff --git a/test/TakingFee.t.sol b/test/TakingFee.t.sol index 18f48869..7a1edd50 100644 --- a/test/TakingFee.t.sol +++ b/test/TakingFee.t.sol @@ -2,7 +2,6 @@ pragma solidity ^0.8.19; import {Test} from "forge-std/Test.sol"; -import {GetSender} from "./shared/GetSender.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; import {TakingFee} from "../contracts/hooks/examples/TakingFee.sol"; import {TakingFeeImplementation} from "./shared/implementation/TakingFeeImplementation.sol"; @@ -12,7 +11,6 @@ import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; -import {TickMath} from "@uniswap/v4-core/src/libraries/TickMath.sol"; import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; import {StateLibrary} from "@uniswap/v4-core/src/libraries/StateLibrary.sol"; @@ -27,6 +25,9 @@ contract TakingFeeTest is Test, Deployers { address constant TREASURY = address(0x1234567890123456789012345678901234567890); uint128 private constant TOTAL_BIPS = 10000; + // rounding for tests to avoid floating point errors + uint128 R = 10; + HookEnabledSwapRouter router; TestERC20 token0; TestERC20 token1; @@ -42,7 +43,7 @@ contract TakingFeeTest is Test, Deployers { token1 = TestERC20(Currency.unwrap(currency1)); vm.record(); - TakingFeeImplementation impl = new TakingFeeImplementation(manager, 25, TREASURY, takingFee); + TakingFeeImplementation impl = new TakingFeeImplementation(manager, 25, takingFee); (, bytes32[] memory writes) = vm.accesses(address(impl)); vm.etch(address(takingFee), address(impl).code); // for each storage key that was written during the hook implementation, copy the value over @@ -63,39 +64,83 @@ contract TakingFeeTest is Test, Deployers { } function testSwapHooks() public { - // rounding for tests - uint128 ROUND_FACTOR = 8; - - // positions were created in setup() assertEq(currency0.balanceOf(TREASURY), 0); assertEq(currency1.balanceOf(TREASURY), 0); - // Perform a test swap // + // Swap exact token0 for token1 // bool zeroForOne = true; - int256 amountSpecified = -1e12; // negative number indicates exact input swap + int256 amountSpecified = -1e12; BalanceDelta swapDelta = swap(key, zeroForOne, amountSpecified, ZERO_BYTES); - // ------------------- // + // ---------------------------- // uint128 output = uint128(swapDelta.amount1()); - assertFalse(output == 0); + assertTrue(output > 0); - uint256 expectedFee = output * TOTAL_BIPS/(TOTAL_BIPS - takingFee.swapFeeBips()) - output; + uint256 expectedFee = output * TOTAL_BIPS / (TOTAL_BIPS - takingFee.swapFeeBips()) - output; - assertEq(currency0.balanceOf(TREASURY), 0); - assertEq(currency1.balanceOf(TREASURY) / ROUND_FACTOR, expectedFee / ROUND_FACTOR); - - // Perform a test swap // + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency1)) / R, expectedFee / R); + + // Swap token0 for exact token1 // bool zeroForOne2 = true; int256 amountSpecified2 = 1e12; // positive number indicates exact output swap BalanceDelta swapDelta2 = swap(key, zeroForOne2, amountSpecified2, ZERO_BYTES); - // ------------------- // - + // ---------------------------- // + uint128 input = uint128(-swapDelta2.amount0()); - assertFalse(input == 0); + assertTrue(output > 0); + + uint128 expectedFee2 = (input * takingFee.swapFeeBips()) / (TOTAL_BIPS + takingFee.swapFeeBips()); + + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency0)) / R, expectedFee2 / R); + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency1)) / R, expectedFee / R); + + // test withdrawing tokens // + Currency[] memory currencies = new Currency[](2); + currencies[0] = key.currency0; + currencies[1] = key.currency1; + takingFee.withdraw(TREASURY, currencies); + assertEq(manager.balanceOf(address(this), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(this), CurrencyLibrary.toId(key.currency1)), 0); + assertEq(currency0.balanceOf(TREASURY) / R, expectedFee2 / R); + assertEq(currency1.balanceOf(TREASURY) / R, expectedFee / R); + } + + function testEdgeCase() public { + // Swap exact token0 for token1 // + bool zeroForOne = true; + int256 amountSpecified = -1e18; + BalanceDelta swapDelta = swap(key, zeroForOne, amountSpecified, ZERO_BYTES); + // ---------------------------- // + + uint128 output = uint128(swapDelta.amount1()); + assertTrue(output > 0); + + uint256 expectedFee = output * TOTAL_BIPS / (TOTAL_BIPS - takingFee.swapFeeBips()) - output; + + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency1)) / R, expectedFee / R); + + // Swap token1 for exact token0 // + bool zeroForOne2 = false; + int256 amountSpecified2 = 1e18; // positive number indicates exact output swap + BalanceDelta swapDelta2 = swap(key, zeroForOne2, amountSpecified2, ZERO_BYTES); + // ---------------------------- // + + uint128 input = uint128(-swapDelta2.amount1()); + assertTrue(output > 0); uint128 expectedFee2 = (input * takingFee.swapFeeBips()) / (TOTAL_BIPS + takingFee.swapFeeBips()); - assertEq(currency0.balanceOf(TREASURY) / ROUND_FACTOR, expectedFee2 / ROUND_FACTOR); - assertEq(currency1.balanceOf(TREASURY) / ROUND_FACTOR, expectedFee / ROUND_FACTOR); + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency0)), 0); + assertEq(manager.balanceOf(address(takingFee), CurrencyLibrary.toId(key.currency1)) / R, (expectedFee + expectedFee2) / R); + + // test withdrawing tokens // + Currency[] memory currencies = new Currency[](2); + currencies[0] = key.currency0; + currencies[1] = key.currency1; + takingFee.withdraw(TREASURY, currencies); + assertEq(currency0.balanceOf(TREASURY) / R, 0); + assertEq(currency1.balanceOf(TREASURY) / R, (expectedFee + expectedFee2) / R); } } diff --git a/test/shared/implementation/TakingFeeImplementation.sol b/test/shared/implementation/TakingFeeImplementation.sol index e5e07237..8c1a0c11 100644 --- a/test/shared/implementation/TakingFeeImplementation.sol +++ b/test/shared/implementation/TakingFeeImplementation.sol @@ -7,7 +7,9 @@ import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; contract TakingFeeImplementation is TakingFee { - constructor(IPoolManager _poolManager, uint128 _swapFeeBips, address _treasury, TakingFee addressToEtch) TakingFee(_poolManager, _swapFeeBips, _treasury) { + constructor(IPoolManager _poolManager, uint128 _swapFeeBips, TakingFee addressToEtch) + TakingFee(_poolManager, _swapFeeBips) + { Hooks.validateHookPermissions(addressToEtch, getHookPermissions()); }