diff --git a/src/V3FactoryOwner.sol b/src/V3FactoryOwner.sol index 3150559..0d5f397 100644 --- a/src/V3FactoryOwner.sol +++ b/src/V3FactoryOwner.sol @@ -70,6 +70,9 @@ contract V3FactoryOwner { /// @notice Thrown if the proposed admin is the zero address. error V3FactoryOwner__InvalidAddress(); + /// @notice Thrown when the fees collected from a pool are less than the caller expects. + error V3FactoryOwner__InsufficientFeesCollected(); + /// @param _admin The initial admin address for this deployment. Cannot be zero address. /// @param _factory The v3 factory instance for which this deployment will serve as owner. /// @param _payoutToken The ERC-20 token in which payouts will be denominated. @@ -151,19 +154,26 @@ contract V3FactoryOwner { /// method. /// @param _amount1Requested The amount1Requested param to forward to the pool's collectProtocol /// method. + /// @return _amount0 The amount0 fees collected, returned by the pool's collectProtocol method. + /// @return _amount1 The amount1 fees collected, returned by the pool's collectProtocol method. /// @dev See docs on IUniswapV3PoolOwnerActions for more information on forwarded params. function claimFees( IUniswapV3PoolOwnerActions _pool, address _recipient, uint128 _amount0Requested, uint128 _amount1Requested - ) external { + ) external returns (uint128, uint128) { PAYOUT_TOKEN.safeTransferFrom(msg.sender, address(REWARD_RECEIVER), PAYOUT_AMOUNT); REWARD_RECEIVER.notifyRewardsAmount(PAYOUT_AMOUNT); (uint128 _amount0, uint128 _amount1) = _pool.collectProtocol(_recipient, _amount0Requested, _amount1Requested); + // Protect the caller from receiving less than requested. See `collectProtocol` for context. + if (_amount0 < _amount0Requested || _amount1 < _amount1Requested) { + revert V3FactoryOwner__InsufficientFeesCollected(); + } emit FeesClaimed(address(_pool), msg.sender, _recipient, _amount0, _amount1); + return (_amount0, _amount1); } /// @notice Ensures the msg.sender is the contract admin and reverts otherwise. diff --git a/test/V3FactoryOwner.t.sol b/test/V3FactoryOwner.t.sol index 3487da4..a0c4e48 100644 --- a/test/V3FactoryOwner.t.sol +++ b/test/V3FactoryOwner.t.sol @@ -243,7 +243,7 @@ contract ClaimFees is V3FactoryOwnerTest { assertEq(rewardReceiver.lastParam__notifyRewardsAmount_amount(), _payoutAmount); } - function testFuzz_CallsPoolCollectProtocolMethodWithRecipientAndAmountsRequested( + function testFuzz_CallsPoolCollectProtocolMethodWithRecipientAndAmountsRequestedAndReturnsForwardedFeeAmountsFromPool( uint256 _payoutAmount, address _caller, address _recipient, @@ -257,12 +257,15 @@ contract ClaimFees is V3FactoryOwnerTest { vm.startPrank(_caller); payoutToken.approve(address(factoryOwner), _payoutAmount); - factoryOwner.claimFees(pool, _recipient, _amount0, _amount1); + (uint256 _amount0Collected, uint256 _amount1Collected) = + factoryOwner.claimFees(pool, _recipient, _amount0, _amount1); vm.stopPrank(); assertEq(pool.lastParam__collectProtocol_recipient(), _recipient); assertEq(pool.lastParam__collectProtocol_amount0Requested(), _amount0); assertEq(pool.lastParam__collectProtocol_amount1Requested(), _amount1); + assertEq(_amount0Collected, _amount0); + assertEq(_amount1Collected, _amount1); } function testFuzz_EmitsAnEventWithFeeClaimParameters( @@ -341,4 +344,37 @@ contract ClaimFees is V3FactoryOwnerTest { factoryOwner.claimFees(pool, _recipient, _amount0, _amount1); vm.stopPrank(); } + + function testFuzz_RevertIf_CallerExpectsMoreFeesThanPoolPaysOut( + uint256 _payoutAmount, + address _caller, + address _recipient, + uint128 _amount0Requested, + uint128 _amount1Requested, + uint128 _amount0Collected, + uint128 _amount1Collected + ) public { + _deployFactoryOwnerWithPayoutAmount(_payoutAmount); + vm.assume(_caller != address(0) && _recipient != address(0)); + _amount0Requested = uint128(bound(_amount0Requested, 1, type(uint128).max)); + _amount1Requested = uint128(bound(_amount1Requested, 1, type(uint128).max)); + + // sometimes get less amount0, other times get less amount1 + // uses arbitrary randomness via fuzzed _payoutAmount + if (_payoutAmount % 2 == 0) { + _amount0Collected = uint128(bound(_amount0Collected, 0, _amount0Requested - 1)); + } else { + _amount1Collected = uint128(bound(_amount1Collected, 0, _amount1Requested - 1)); + } + pool.setNextReturn__collectProtocol(_amount0Collected, _amount1Collected); + + payoutToken.mint(_caller, _payoutAmount); + + vm.startPrank(_caller); + payoutToken.approve(address(factoryOwner), _payoutAmount); + + vm.expectRevert(V3FactoryOwner.V3FactoryOwner__InsufficientFeesCollected.selector); + factoryOwner.claimFees(pool, _recipient, _amount0Requested, _amount1Requested); + vm.stopPrank(); + } } diff --git a/test/mocks/MockUniswapV3Pool.sol b/test/mocks/MockUniswapV3Pool.sol index 1c4a75e..1977a8d 100644 --- a/test/mocks/MockUniswapV3Pool.sol +++ b/test/mocks/MockUniswapV3Pool.sol @@ -11,19 +11,29 @@ contract MockUniswapV3Pool is IUniswapV3PoolOwnerActions { uint128 public lastParam__collectProtocol_amount0Requested; uint128 public lastParam__collectProtocol_amount1Requested; + bool public useMockProtocolFeeAmounts = false; + uint128 public mockFeesAmount0; + uint128 public mockFeesAmount1; + function setFeeProtocol(uint8 feeProtocol0, uint8 feeProtocol1) external { lastParam__setFeeProtocol_feeProtocol0 = feeProtocol0; lastParam__setFeeProtocol_feeProtocol1 = feeProtocol1; } + function setNextReturn__collectProtocol(uint128 amount0, uint128 amount1) external { + useMockProtocolFeeAmounts = true; + mockFeesAmount0 = amount0; + mockFeesAmount1 = amount1; + } + function collectProtocol(address recipient, uint128 amount0Requested, uint128 amount1Requested) external - returns (uint128 amount0, uint128 amount1) + returns (uint128, uint128) { lastParam__collectProtocol_recipient = recipient; lastParam__collectProtocol_amount0Requested = amount0Requested; lastParam__collectProtocol_amount1Requested = amount1Requested; - + if (useMockProtocolFeeAmounts) return (mockFeesAmount0, mockFeesAmount1); return (amount0Requested, amount1Requested); } }