From ca6e75385e63396856f4454cc4eb390ec2362bd4 Mon Sep 17 00:00:00 2001 From: Spablob <99089658+Spablob@users.noreply.github.com> Date: Thu, 27 Jun 2024 00:34:14 +0100 Subject: [PATCH] fix token pause (#146) --- .../royalty/policies/IIpRoyaltyVault.sol | 10 +++ .../royalty/policies/IpRoyaltyVault.sol | 63 +++++++++----- .../modules/royalty/IpRoyaltyVault.t.sol | 87 +++++++++++++++---- 3 files changed, 120 insertions(+), 40 deletions(-) diff --git a/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol b/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol index eaf7cae4..64de5082 100644 --- a/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol +++ b/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol @@ -65,6 +65,11 @@ interface IIpRoyaltyVault { /// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to function collectRoyaltyTokens(address ancestorIpId) external; + /// @notice Collect the accrued tokens (if any) + /// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to + /// @param tokens The list of revenue tokens to claim + function collectAccruedTokens(address ancestorIpId, address[] calldata tokens) external; + /// @notice The ip id to whom this royalty vault belongs to /// @return The ip id address function ipId() external view returns (address); @@ -79,6 +84,11 @@ interface IIpRoyaltyVault { /// @param token The address of the revenue token function ancestorsVaultAmount(address token) external view returns (uint256); + /// @notice The amount of revenue tokens that can be collected by the ancestor + /// @param ancestorIpId The ancestor ipId address + /// @param token The address of the revenue token + function collectableAmount(address ancestorIpId, address token) external view returns (uint256); + /// @notice Indicates whether the ancestor has collected the royalty tokens /// @param ancestorIpId The ancestor ipId address function isCollectedByAncestor(address ancestorIpId) external view returns (bool); diff --git a/contracts/modules/royalty/policies/IpRoyaltyVault.sol b/contracts/modules/royalty/policies/IpRoyaltyVault.sol index 266d96a9..1d609acd 100644 --- a/contracts/modules/royalty/policies/IpRoyaltyVault.sol +++ b/contracts/modules/royalty/policies/IpRoyaltyVault.sol @@ -28,6 +28,7 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @param unclaimedRoyaltyTokens The amount of unclaimed royalty tokens /// @param lastSnapshotTimestamp The last snapshotted timestamp /// @param ancestorsVaultAmount The amount of revenue token in the ancestors vault + /// @param collectableAmount The amount of revenue tokens that can be collected by the ancestor /// @param isCollectedByAncestor Indicates whether the ancestor has collected the royalty tokens /// @param claimVaultAmount Amount of revenue token in the claim vault /// @param claimableAtSnapshot Amount of revenue token claimable at a given snapshot @@ -40,6 +41,7 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy uint32 unclaimedRoyaltyTokens; uint40 lastSnapshotTimestamp; mapping(address token => uint256 amount) ancestorsVaultAmount; + mapping(address ancestorIpId => mapping(address token => uint256 amount)) collectableAmount; mapping(address ancestorIpId => bool) isCollectedByAncestor; mapping(address token => uint256 amount) claimVaultAmount; mapping(uint256 snapshotId => mapping(address token => uint256 amount)) claimableAtSnapshot; @@ -234,8 +236,18 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy // transfer royalty tokens to the ancestor IERC20Upgradeable(address(this)).safeTransfer(ancestorIpId, ancestorsRoyalties[index]); - // collect accrued revenue tokens (if any) - _collectAccruedTokens(ancestorsRoyalties[index], ancestorIpId); + // save the amount of revenue tokens that are collectable by the ancestor + address[] memory tokenList = $.tokens.values(); + uint256 unclaimedTokens = $.unclaimedRoyaltyTokens; + for (uint256 i = 0; i < tokenList.length; ++i) { + // the only case in which unclaimedRoyaltyTokens can be 0 is when the vault is empty and everyone claimed + // in which case the call will revert upstream with IpRoyaltyVault__AlreadyClaimed error + uint256 collectAmount = ($.ancestorsVaultAmount[tokenList[i]] * ancestorsRoyalties[index]) / + unclaimedTokens; + if (collectAmount == 0) continue; + + $.collectableAmount[ancestorIpId][tokenList[i]] += collectAmount; + } $.isCollectedByAncestor[ancestorIpId] = true; $.unclaimedRoyaltyTokens -= ancestorsRoyalties[index]; @@ -243,6 +255,24 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy emit RoyaltyTokensCollected(ancestorIpId, ancestorsRoyalties[index]); } + /// @notice Collect the accrued tokens (if any) + /// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to + /// @param tokens The list of revenue tokens to claim + function collectAccruedTokens(address ancestorIpId, address[] calldata tokens) external nonReentrant whenNotPaused { + IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + + if (DISPUTE_MODULE.isIpTagged($.ipId)) revert Errors.IpRoyaltyVault__IpTagged(); + + for (uint256 i = 0; i < tokens.length; ++i) { + uint256 collectAmount = $.collectableAmount[ancestorIpId][tokens[i]]; + $.ancestorsVaultAmount[tokens[i]] -= collectAmount; + $.collectableAmount[ancestorIpId][tokens[i]] -= collectAmount; + IERC20Upgradeable(tokens[i]).safeTransfer(ancestorIpId, collectAmount); + + emit RevenueTokenClaimed(ancestorIpId, tokens[i], collectAmount); + } + } + /// @notice A function to calculate the amount of revenue token claimable by a token holder at certain snapshot /// @param account The address of the token holder /// @param snapshotId The snapshot id @@ -260,28 +290,6 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy return $.isClaimedAtSnapshot[snapshotId][account][token] ? 0 : (balance * claimableToken) / totalSupply; } - /// @dev Collect the accrued tokens (if any) - /// @param royaltyTokensToClaim The amount of royalty tokens being claimed by the ancestor - /// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to - function _collectAccruedTokens(uint256 royaltyTokensToClaim, address ancestorIpId) internal { - IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); - - address[] memory tokenList = $.tokens.values(); - - for (uint256 i = 0; i < tokenList.length; ++i) { - // the only case in which unclaimedRoyaltyTokens can be 0 is when the vault is empty and everyone claimed - // in which case the call will revert upstream with IpRoyaltyVault__AlreadyClaimed error - uint256 collectAmount = ($.ancestorsVaultAmount[tokenList[i]] * royaltyTokensToClaim) / - $.unclaimedRoyaltyTokens; - if (collectAmount == 0) continue; - - $.ancestorsVaultAmount[tokenList[i]] -= collectAmount; - IERC20Upgradeable(tokenList[i]).safeTransfer(ancestorIpId, collectAmount); - - emit RevenueTokenClaimed(ancestorIpId, tokenList[i], collectAmount); - } - } - /// @notice The ip id to whom this royalty vault belongs to /// @return The ip id address function ipId() external view returns (address) { @@ -304,6 +312,13 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy return _getIpRoyaltyVaultStorage().ancestorsVaultAmount[token]; } + /// @notice The amount of revenue tokens that can be collected by the ancestor + /// @param ancestorIpId The ancestor ipId address + /// @param token The address of the revenue token + function collectableAmount(address ancestorIpId, address token) external view returns (uint256) { + return _getIpRoyaltyVaultStorage().collectableAmount[ancestorIpId][token]; + } + /// @notice Indicates whether the ancestor has collected the royalty tokens /// @param ancestorIpId The ancestor ipId address function isCollectedByAncestor(address ancestorIpId) external view returns (bool) { diff --git a/test/foundry/modules/royalty/IpRoyaltyVault.t.sol b/test/foundry/modules/royalty/IpRoyaltyVault.t.sol index 06e81420..7715ba1f 100644 --- a/test/foundry/modules/royalty/IpRoyaltyVault.t.sol +++ b/test/foundry/modules/royalty/IpRoyaltyVault.t.sol @@ -168,6 +168,15 @@ contract TestIpRoyaltyVault is BaseTest { ); } + function test_IpRoyaltyVault_ClaimRevenueByTokenBatch_revert_Paused() public { + vm.stopPrank(); + vm.prank(u.admin); + royaltyPolicyLAP.pause(); + + vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); + ipRoyaltyVault.claimRevenueByTokenBatch(1, new address[](0)); + } + function test_IpRoyaltyVault_ClaimRevenueByTokenBatch() public { // payment is made to vault uint256 royaltyAmount = 100000 * 10 ** 6; @@ -217,6 +226,15 @@ contract TestIpRoyaltyVault is BaseTest { assertEq(ipRoyaltyVault.isClaimedAtSnapshot(1, address(2), address(LINK)), true); } + function test_IpRoyaltyVault_ClaimRevenueBySnapshotBatch_revert_Paused() public { + vm.stopPrank(); + vm.prank(u.admin); + royaltyPolicyLAP.pause(); + + vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); + ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC)); + } + function test_IpRoyaltyVault_ClaimRevenueBySnapshotBatch() public { uint256 royaltyAmount = 100000 * 10 ** 6; USDC.mint(address(3), royaltyAmount); // 100k USDC @@ -268,7 +286,7 @@ contract TestIpRoyaltyVault is BaseTest { ipRoyaltyVault.snapshot(); } - function test_IpRoyaltyVault_Snapshot_revert_paused() public { + function test_IpRoyaltyVault_Snapshot_revert_Paused() public { // payment is made to vault vm.stopPrank(); vm.prank(u.admin); @@ -348,11 +366,17 @@ contract TestIpRoyaltyVault is BaseTest { ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens); ipRoyaltyVault.collectRoyaltyTokens(address(5)); + ipRoyaltyVault.collectAccruedTokens(address(5), tokens); ipRoyaltyVault.collectRoyaltyTokens(address(11)); + ipRoyaltyVault.collectAccruedTokens(address(11), tokens); ipRoyaltyVault.collectRoyaltyTokens(address(12)); + ipRoyaltyVault.collectAccruedTokens(address(12), tokens); ipRoyaltyVault.collectRoyaltyTokens(address(6)); + ipRoyaltyVault.collectAccruedTokens(address(6), tokens); ipRoyaltyVault.collectRoyaltyTokens(address(13)); + ipRoyaltyVault.collectAccruedTokens(address(13), tokens); ipRoyaltyVault.collectRoyaltyTokens(address(14)); + ipRoyaltyVault.collectAccruedTokens(address(14), tokens); // take snapshot vm.warp(block.timestamp + 7 days + 1); @@ -374,6 +398,15 @@ contract TestIpRoyaltyVault is BaseTest { ipRoyaltyVault.collectRoyaltyTokens(address(0)); } + function test_IpRoyaltyVault_CollectRoyaltyTokens_revert_Paused() public { + vm.stopPrank(); + vm.prank(u.admin); + royaltyPolicyLAP.pause(); + + vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); + ipRoyaltyVault.collectRoyaltyTokens(address(5)); + } + function test_IpRoyaltyVault_CollectRoyaltyTokens() public { uint256 parentRoyalty = 5 * 10 ** 5; uint256 royaltyAmount = 100000 * 10 ** 6; @@ -404,8 +437,6 @@ contract TestIpRoyaltyVault is BaseTest { ipRoyaltyVault.collectRoyaltyTokens(address(5)); - assertEq(USDC.balanceOf(address(5)) - userUsdcBalanceBefore, accruedCollectableRevenue); - assertEq(contractUsdcBalanceBefore - USDC.balanceOf(address(ipRoyaltyVault)), accruedCollectableRevenue); assertEq(ipRoyaltyVault.isCollectedByAncestor(address(5)), true); assertEq( contractRTBalBefore - IERC20(address(ipRoyaltyVault)).balanceOf(address(ipRoyaltyVault)), @@ -413,30 +444,54 @@ contract TestIpRoyaltyVault is BaseTest { ); assertEq(IERC20(address(ipRoyaltyVault)).balanceOf(address(5)) - userRTBalBefore, parentRoyalty); assertEq(unclaimedRoyaltyTokensBefore - ipRoyaltyVault.unclaimedRoyaltyTokens(), parentRoyalty); - assertEq( - ancestorsVaultAmountBefore - ipRoyaltyVault.ancestorsVaultAmount(address(USDC)), - accruedCollectableRevenue - ); } - function test_IpRoyaltyVault_claimRevenue_revert_paused() public { + function test_IpRoyaltyVault_CollectAccruedTokens_revert_Paused() public { vm.stopPrank(); vm.prank(u.admin); royaltyPolicyLAP.pause(); vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); - ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC)); - - vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); - ipRoyaltyVault.claimRevenueByTokenBatch(1, new address[](0)); + ipRoyaltyVault.collectAccruedTokens(address(5), new address[](0)); } - function test_IpRoyaltyVault_collectRoyaltyTokens_revert_paused() public { + function test_IpRoyaltyVault_CollectAccruedTokens() public { + uint256 parentRoyalty = 5 * 10 ** 5; + uint256 royaltyAmount = 100000 * 10 ** 6; + uint256 accruedCollectableRevenue = (royaltyAmount * 5 * 10 ** 5) / royaltyPolicyLAP.TOTAL_RT_SUPPLY(); + + // payment is made to vault + USDC.mint(address(3), royaltyAmount); // 100k USDC + vm.startPrank(address(3)); + USDC.approve(address(royaltyPolicyLAP), royaltyAmount); + royaltyModule.payRoyaltyOnBehalf(address(2), address(3), address(USDC), royaltyAmount); vm.stopPrank(); - vm.prank(u.admin); - royaltyPolicyLAP.pause(); - vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); + // take snapshot + vm.warp(block.timestamp + 7 days + 1); + ipRoyaltyVault.snapshot(); + + // collect royalty tokens ipRoyaltyVault.collectRoyaltyTokens(address(5)); + + address[] memory tokens = new address[](1); + tokens[0] = address(USDC); + + uint256 userUsdcBalanceBefore = USDC.balanceOf(address(5)); + uint256 contractUsdcBalanceBefore = USDC.balanceOf(address(ipRoyaltyVault)); + uint256 ancestorsVaultAmountBefore = ipRoyaltyVault.ancestorsVaultAmount(address(USDC)); + uint256 collectableAmountBefore = ipRoyaltyVault.collectableAmount(address(5), address(USDC)); + + ipRoyaltyVault.collectAccruedTokens(address(5), tokens); + + uint256 userUsdcBalanceAfter = USDC.balanceOf(address(5)); + uint256 contractUsdcBalanceAfter = USDC.balanceOf(address(ipRoyaltyVault)); + uint256 ancestorsVaultAmountAfter = ipRoyaltyVault.ancestorsVaultAmount(address(USDC)); + uint256 collectableAmountAfter = ipRoyaltyVault.collectableAmount(address(5), address(USDC)); + + assertEq(userUsdcBalanceAfter - userUsdcBalanceBefore, accruedCollectableRevenue); + assertEq(contractUsdcBalanceBefore - contractUsdcBalanceAfter, accruedCollectableRevenue); + assertEq(ancestorsVaultAmountBefore - ancestorsVaultAmountAfter, accruedCollectableRevenue); + assertEq(collectableAmountAfter, 0); } }