diff --git a/.forge-snapshots/RouterBytecode.snap b/.forge-snapshots/RouterBytecode.snap index 19b15e18..17ac8e65 100644 --- a/.forge-snapshots/RouterBytecode.snap +++ b/.forge-snapshots/RouterBytecode.snap @@ -1 +1 @@ -5925 \ No newline at end of file +6067 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactIn1Hop.snap b/.forge-snapshots/RouterExactIn1Hop.snap index 3bfc7e5a..017f6b15 100644 --- a/.forge-snapshots/RouterExactIn1Hop.snap +++ b/.forge-snapshots/RouterExactIn1Hop.snap @@ -1 +1 @@ -101468 \ No newline at end of file +101510 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactIn2Hops.snap b/.forge-snapshots/RouterExactIn2Hops.snap index 6f6b2383..517bb875 100644 --- a/.forge-snapshots/RouterExactIn2Hops.snap +++ b/.forge-snapshots/RouterExactIn2Hops.snap @@ -1 +1 @@ -160449 \ No newline at end of file +160475 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactIn3Hops.snap b/.forge-snapshots/RouterExactIn3Hops.snap index 3031f4e5..028f0b10 100644 --- a/.forge-snapshots/RouterExactIn3Hops.snap +++ b/.forge-snapshots/RouterExactIn3Hops.snap @@ -1 +1 @@ -212680 \ No newline at end of file +212691 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactInputSingle.snap b/.forge-snapshots/RouterExactInputSingle.snap index 495f97ed..558ce33b 100644 --- a/.forge-snapshots/RouterExactInputSingle.snap +++ b/.forge-snapshots/RouterExactInputSingle.snap @@ -1 +1 @@ -106771 \ No newline at end of file +106775 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactOut1Hop.snap b/.forge-snapshots/RouterExactOut1Hop.snap index b08c4f7e..727daa90 100644 --- a/.forge-snapshots/RouterExactOut1Hop.snap +++ b/.forge-snapshots/RouterExactOut1Hop.snap @@ -1 +1 @@ -102242 \ No newline at end of file +102284 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactOut2Hops.snap b/.forge-snapshots/RouterExactOut2Hops.snap index f9da51bb..2e896618 100644 --- a/.forge-snapshots/RouterExactOut2Hops.snap +++ b/.forge-snapshots/RouterExactOut2Hops.snap @@ -1 +1 @@ -159812 \ No newline at end of file +159838 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactOut3Hops.snap b/.forge-snapshots/RouterExactOut3Hops.snap index 7563d22f..344aedc5 100644 --- a/.forge-snapshots/RouterExactOut3Hops.snap +++ b/.forge-snapshots/RouterExactOut3Hops.snap @@ -1 +1 @@ -212732 \ No newline at end of file +212743 \ No newline at end of file diff --git a/.forge-snapshots/RouterExactOutputSingle.snap b/.forge-snapshots/RouterExactOutputSingle.snap index 0a49af3a..91ebf967 100644 --- a/.forge-snapshots/RouterExactOutputSingle.snap +++ b/.forge-snapshots/RouterExactOutputSingle.snap @@ -1 +1 @@ -105380 \ No newline at end of file +105384 \ No newline at end of file diff --git a/contracts/V4Router.sol b/contracts/V4Router.sol index 49642064..f4266565 100644 --- a/contracts/V4Router.sol +++ b/contracts/V4Router.sol @@ -23,8 +23,10 @@ abstract contract V4Router is IV4Router { poolManager = _poolManager; } - function _v4Swap(SwapType swapType, bytes memory params) internal { - poolManager.unlock(abi.encode(SwapInfo(swapType, msg.sender, params))); + // @dev The contract inheriting from this contract, and calling _v4Swap must set the payer and recipient securely. + // Allowing any payer or recipient to be passed in could allow users to steal each others' tokens. + function _v4Swap(SwapType swapType, PaymentAddresses memory paymentAddresses, bytes memory params) internal { + poolManager.unlock(abi.encode(SwapInfo(swapType, paymentAddresses, params))); } /// @inheritdoc IUnlockCallback @@ -34,13 +36,17 @@ abstract contract V4Router is IV4Router { SwapInfo memory swapInfo = abi.decode(encodedSwapInfo, (SwapInfo)); if (swapInfo.swapType == SwapType.ExactInput) { - _swapExactInput(abi.decode(swapInfo.params, (IV4Router.ExactInputParams)), swapInfo.msgSender); + _swapExactInput(abi.decode(swapInfo.params, (IV4Router.ExactInputParams)), swapInfo.paymentAddresses); } else if (swapInfo.swapType == SwapType.ExactInputSingle) { - _swapExactInputSingle(abi.decode(swapInfo.params, (IV4Router.ExactInputSingleParams)), swapInfo.msgSender); + _swapExactInputSingle( + abi.decode(swapInfo.params, (IV4Router.ExactInputSingleParams)), swapInfo.paymentAddresses + ); } else if (swapInfo.swapType == SwapType.ExactOutput) { - _swapExactOutput(abi.decode(swapInfo.params, (IV4Router.ExactOutputParams)), swapInfo.msgSender); + _swapExactOutput(abi.decode(swapInfo.params, (IV4Router.ExactOutputParams)), swapInfo.paymentAddresses); } else if (swapInfo.swapType == SwapType.ExactOutputSingle) { - _swapExactOutputSingle(abi.decode(swapInfo.params, (IV4Router.ExactOutputSingleParams)), swapInfo.msgSender); + _swapExactOutputSingle( + abi.decode(swapInfo.params, (IV4Router.ExactOutputSingleParams)), swapInfo.paymentAddresses + ); } else { revert InvalidSwapType(); } @@ -48,20 +54,25 @@ abstract contract V4Router is IV4Router { return bytes(""); } - function _swapExactInputSingle(IV4Router.ExactInputSingleParams memory params, address msgSender) private { + function _swapExactInputSingle( + IV4Router.ExactInputSingleParams memory params, + PaymentAddresses memory paymentAddresses + ) private { _swap( params.poolKey, params.zeroForOne, int256(-int128(params.amountIn)), params.sqrtPriceLimitX96, - msgSender, + paymentAddresses, true, true, params.hookData ); } - function _swapExactInput(IV4Router.ExactInputParams memory params, address msgSender) private { + function _swapExactInput(IV4Router.ExactInputParams memory params, PaymentAddresses memory paymentAddresses) + private + { unchecked { uint256 pathLength = params.path.length; uint128 amountOut; @@ -74,7 +85,7 @@ abstract contract V4Router is IV4Router { zeroForOne, int256(-int128(params.amountIn)), 0, - msgSender, + paymentAddresses, i == 0, i == pathLength - 1, params.path[i].hookData @@ -89,20 +100,25 @@ abstract contract V4Router is IV4Router { } } - function _swapExactOutputSingle(IV4Router.ExactOutputSingleParams memory params, address msgSender) private { + function _swapExactOutputSingle( + IV4Router.ExactOutputSingleParams memory params, + PaymentAddresses memory paymentAddresses + ) private { _swap( params.poolKey, params.zeroForOne, int256(int128(params.amountOut)), params.sqrtPriceLimitX96, - msgSender, + paymentAddresses, true, true, params.hookData ); } - function _swapExactOutput(IV4Router.ExactOutputParams memory params, address msgSender) private { + function _swapExactOutput(IV4Router.ExactOutputParams memory params, PaymentAddresses memory paymentAddresses) + private + { unchecked { uint256 pathLength = params.path.length; uint128 amountIn; @@ -116,7 +132,7 @@ abstract contract V4Router is IV4Router { !oneForZero, int256(int128(params.amountOut)), 0, - msgSender, + paymentAddresses, i == 1, i == pathLength, params.path[i - 1].hookData @@ -135,7 +151,7 @@ abstract contract V4Router is IV4Router { bool zeroForOne, int256 amountSpecified, uint160 sqrtPriceLimitX96, - address msgSender, + PaymentAddresses memory paymentAddresses, bool settle, bool take, bytes memory hookData @@ -154,12 +170,12 @@ abstract contract V4Router is IV4Router { if (zeroForOne) { reciprocalAmount = amountSpecified < 0 ? delta.amount1() : delta.amount0(); - if (settle) _payAndSettle(poolKey.currency0, msgSender, delta.amount0()); - if (take) poolManager.take(poolKey.currency1, msgSender, uint128(delta.amount1())); + if (settle) _payAndSettle(poolKey.currency0, paymentAddresses.payer, delta.amount0()); + if (take) poolManager.take(poolKey.currency1, paymentAddresses.recipient, uint128(delta.amount1())); } else { reciprocalAmount = amountSpecified < 0 ? delta.amount0() : delta.amount1(); - if (settle) _payAndSettle(poolKey.currency1, msgSender, delta.amount1()); - if (take) poolManager.take(poolKey.currency0, msgSender, uint128(delta.amount0())); + if (settle) _payAndSettle(poolKey.currency1, paymentAddresses.payer, delta.amount1()); + if (take) poolManager.take(poolKey.currency0, paymentAddresses.recipient, uint128(delta.amount0())); } } @@ -176,11 +192,11 @@ abstract contract V4Router is IV4Router { poolKey = PoolKey(currency0, currency1, params.fee, params.tickSpacing, params.hooks); } - function _payAndSettle(Currency currency, address msgSender, int128 settleAmount) private { + function _payAndSettle(Currency currency, address payer, int128 settleAmount) private { poolManager.sync(currency); - _pay(Currency.unwrap(currency), msgSender, address(poolManager), uint256(uint128(-settleAmount))); + _pay(Currency.unwrap(currency), payer, uint256(uint128(-settleAmount))); poolManager.settle(currency); } - function _pay(address token, address payer, address recipient, uint256 amount) internal virtual; + function _pay(address token, address payer, uint256 amount) internal virtual; } diff --git a/contracts/interfaces/IV4Router.sol b/contracts/interfaces/IV4Router.sol index 387e2ff0..d54000b6 100644 --- a/contracts/interfaces/IV4Router.sol +++ b/contracts/interfaces/IV4Router.sol @@ -21,14 +21,18 @@ interface IV4Router is IUnlockCallback { struct SwapInfo { SwapType swapType; - address msgSender; + PaymentAddresses paymentAddresses; bytes params; } + struct PaymentAddresses { + address payer; + address recipient; + } + struct ExactInputSingleParams { PoolKey poolKey; bool zeroForOne; - address recipient; uint128 amountIn; uint128 amountOutMinimum; uint160 sqrtPriceLimitX96; @@ -38,7 +42,6 @@ interface IV4Router is IUnlockCallback { struct ExactInputParams { Currency currencyIn; PathKey[] path; - address recipient; uint128 amountIn; uint128 amountOutMinimum; } @@ -46,7 +49,6 @@ interface IV4Router is IUnlockCallback { struct ExactOutputSingleParams { PoolKey poolKey; bool zeroForOne; - address recipient; uint128 amountOut; uint128 amountInMaximum; uint160 sqrtPriceLimitX96; @@ -56,7 +58,6 @@ interface IV4Router is IUnlockCallback { struct ExactOutputParams { Currency currencyOut; PathKey[] path; - address recipient; uint128 amountOut; uint128 amountInMaximum; } diff --git a/test/V4Router.t.sol b/test/V4Router.t.sol index 706d1712..97cf39d8 100644 --- a/test/V4Router.t.sol +++ b/test/V4Router.t.sol @@ -74,7 +74,7 @@ contract V4RouterTest is Test, Deployers, GasSnapshot { uint256 expectedAmountOut = 992054607780215625; IV4Router.ExactInputSingleParams memory params = - IV4Router.ExactInputSingleParams(key0, true, address(this), uint128(amountIn), 0, 0, bytes("")); + IV4Router.ExactInputSingleParams(key0, true, uint128(amountIn), 0, 0, bytes("")); uint256 prevBalance0 = key0.currency0.balanceOf(address(this)); uint256 prevBalance1 = key0.currency1.balanceOf(address(this)); @@ -95,7 +95,7 @@ contract V4RouterTest is Test, Deployers, GasSnapshot { uint256 expectedAmountOut = 992054607780215625; IV4Router.ExactInputSingleParams memory params = - IV4Router.ExactInputSingleParams(key0, false, address(this), uint128(amountIn), 0, 0, bytes("")); + IV4Router.ExactInputSingleParams(key0, false, uint128(amountIn), 0, 0, bytes("")); uint256 prevBalance0 = key0.currency0.balanceOf(address(this)); uint256 prevBalance1 = key0.currency1.balanceOf(address(this)); @@ -212,7 +212,7 @@ contract V4RouterTest is Test, Deployers, GasSnapshot { uint256 expectedAmountIn = 1008049273448486163; IV4Router.ExactOutputSingleParams memory params = - IV4Router.ExactOutputSingleParams(key0, true, address(this), uint128(amountOut), 0, 0, bytes("")); + IV4Router.ExactOutputSingleParams(key0, true, uint128(amountOut), 0, 0, bytes("")); uint256 prevBalance0 = key0.currency0.balanceOf(address(this)); uint256 prevBalance1 = key0.currency1.balanceOf(address(this)); @@ -233,7 +233,7 @@ contract V4RouterTest is Test, Deployers, GasSnapshot { uint256 expectedAmountIn = 1008049273448486163; IV4Router.ExactOutputSingleParams memory params = - IV4Router.ExactOutputSingleParams(key0, false, address(this), uint128(amountOut), 0, 0, bytes("")); + IV4Router.ExactOutputSingleParams(key0, false, uint128(amountOut), 0, 0, bytes("")); uint256 prevBalance0 = key0.currency0.balanceOf(address(this)); uint256 prevBalance1 = key0.currency1.balanceOf(address(this)); @@ -372,7 +372,7 @@ contract V4RouterTest is Test, Deployers, GasSnapshot { function getExactInputParams(MockERC20[] memory _tokenPath, uint256 amountIn) internal - view + pure returns (IV4Router.ExactInputParams memory params) { PathKey[] memory path = new PathKey[](_tokenPath.length - 1); @@ -382,14 +382,13 @@ contract V4RouterTest is Test, Deployers, GasSnapshot { params.currencyIn = Currency.wrap(address(_tokenPath[0])); params.path = path; - params.recipient = address(this); params.amountIn = uint128(amountIn); params.amountOutMinimum = 0; } function getExactOutputParams(MockERC20[] memory _tokenPath, uint256 amountOut) internal - view + pure returns (IV4Router.ExactOutputParams memory params) { PathKey[] memory path = new PathKey[](_tokenPath.length - 1); @@ -399,7 +398,6 @@ contract V4RouterTest is Test, Deployers, GasSnapshot { params.currencyOut = Currency.wrap(address(_tokenPath[_tokenPath.length - 1])); params.path = path; - params.recipient = address(this); params.amountOut = uint128(amountOut); params.amountInMaximum = type(uint128).max; } diff --git a/test/shared/implementation/V4RouterImplementation.sol b/test/shared/implementation/V4RouterImplementation.sol index aac7fde4..fada2418 100644 --- a/test/shared/implementation/V4RouterImplementation.sol +++ b/test/shared/implementation/V4RouterImplementation.sol @@ -10,10 +10,10 @@ contract V4RouterImplementation is V4Router { constructor(IPoolManager _poolManager) V4Router(_poolManager) {} function swap(IV4Router.SwapType swapType, bytes memory params) external { - _v4Swap(swapType, params); + _v4Swap(swapType, PaymentAddresses({payer: msg.sender, recipient: msg.sender}), params); } - function _pay(address token, address payer, address recipient, uint256 amount) internal override { - IERC20Minimal(token).transferFrom(payer, recipient, amount); + function _pay(address token, address payer, uint256 amount) internal override { + IERC20Minimal(token).transferFrom(payer, address(poolManager), amount); } }