From 109fcdfcd1efaa4173089801daa7817a523c3143 Mon Sep 17 00:00:00 2001 From: Spablob Date: Fri, 15 Nov 2024 14:34:58 +0000 Subject: [PATCH 1/3] prevent royalty policy to propagate up the graph --- contracts/modules/royalty/RoyaltyModule.sol | 42 ++++++++++++--- .../modules/royalty/RoyaltyModule.t.sol | 51 +++---------------- 2 files changed, 43 insertions(+), 50 deletions(-) diff --git a/contracts/modules/royalty/RoyaltyModule.sol b/contracts/modules/royalty/RoyaltyModule.sol index a0d08255..89177f50 100644 --- a/contracts/modules/royalty/RoyaltyModule.sol +++ b/contracts/modules/royalty/RoyaltyModule.sol @@ -497,10 +497,22 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad for (uint256 i = 0; i < parentIpIds.length; i++) { if (parentIpIds[i] == address(0)) revert Errors.RoyaltyModule__ZeroParentIpId(); if (licenseRoyaltyPolicies[i] == address(0)) revert Errors.RoyaltyModule__ZeroRoyaltyPolicy(); - _addToAccumulatedRoyaltyPolicies(parentIpIds[i], licenseRoyaltyPolicies[i]); - address[] memory accParentRoyaltyPolicies = $.accumulatedRoyaltyPolicies[parentIpIds[i]].values(); + // transfer the royalty tokens and add royalty policy to child if it's not contained in the parent + // since if it's already contained then it will be transferred in the next loop already + if (!$.accumulatedRoyaltyPolicies[parentIpIds[i]].contains(licenseRoyaltyPolicies[i])) { + _addToAccumulatedRoyaltyPolicies(ipId, licenseRoyaltyPolicies[i]); + totalRtsRequiredToLink += _transferRoyaltyTokensToPolicy( + parentIpIds[i], + licenseRoyaltyPolicies[i], + licensesPercent[i], + ipRoyaltyVault + ); + } + + // transfer the royalty tokens and add all the parent royalty policies to the child // this loop is limited to accumulatedRoyaltyPoliciesLimit + address[] memory accParentRoyaltyPolicies = $.accumulatedRoyaltyPolicies[parentIpIds[i]].values(); for (uint256 j = 0; j < accParentRoyaltyPolicies.length; j++) { // add the parent ancestor royalty policies to the child _addToAccumulatedRoyaltyPolicies(ipId, accParentRoyaltyPolicies[j]); @@ -508,13 +520,12 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad uint32 licensePercent = accParentRoyaltyPolicies[j] == licenseRoyaltyPolicies[i] ? licensesPercent[i] : 0; - uint32 rtsRequiredToLink = IRoyaltyPolicy(accParentRoyaltyPolicies[j]).getPolicyRtsRequiredToLink( + totalRtsRequiredToLink += _transferRoyaltyTokensToPolicy( parentIpIds[i], - licensePercent + accParentRoyaltyPolicies[j], + licensePercent, + ipRoyaltyVault ); - totalRtsRequiredToLink += rtsRequiredToLink; - if (totalRtsRequiredToLink > MAX_PERCENT) revert Errors.RoyaltyModule__AboveMaxPercent(); - IERC20(ipRoyaltyVault).safeTransfer(accParentRoyaltyPolicies[j], rtsRequiredToLink); } } @@ -537,6 +548,23 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad _getRoyaltyModuleStorage().accumulatedRoyaltyPolicies[ipId].add(royaltyPolicy); } + /// @notice Transfers the royalty tokens to the royalty policy + /// @param parentIpId The parent IP asset + /// @param royaltyPolicy The address of the royalty policy + /// @param licensePercent The license percentage + /// @param ipRoyaltyVault The address of the ipRoyaltyVault + /// @return rtsRequiredToLink The required royalty tokens to link + function _transferRoyaltyTokensToPolicy( + address parentIpId, + address royaltyPolicy, + uint32 licensePercent, + address ipRoyaltyVault + ) internal returns (uint32) { + uint32 rtsRequiredToLink = IRoyaltyPolicy(royaltyPolicy).getPolicyRtsRequiredToLink(parentIpId, licensePercent); + IERC20(ipRoyaltyVault).safeTransfer(royaltyPolicy, rtsRequiredToLink); + return rtsRequiredToLink; + } + /// @notice Handles the payment of royalties /// @param receiverIpId The ipId that receives the royalties /// @param payerAddress The address that pays the royalties diff --git a/test/foundry/modules/royalty/RoyaltyModule.t.sol b/test/foundry/modules/royalty/RoyaltyModule.t.sol index 36e68c92..4e37658d 100644 --- a/test/foundry/modules/royalty/RoyaltyModule.t.sol +++ b/test/foundry/modules/royalty/RoyaltyModule.t.sol @@ -629,41 +629,6 @@ contract TestRoyaltyModule is BaseTest { royaltyModule.onLinkToParents(address(80), parents, licenseRoyaltyPolicies, parentRoyalties, ""); } - function test_RoyaltyModule_onLinkToParents_revert_AboveMaxPercent() public { - address[] memory parents = new address[](3); - address[] memory licenseRoyaltyPolicies = new address[](3); - uint32[] memory parentRoyalties = new uint32[](3); - - // link 80 to 10 + 60 + 70 - parents = new address[](3); - licenseRoyaltyPolicies = new address[](3); - parentRoyalties = new uint32[](3); - parents[0] = address(10); - parents[1] = address(60); - parents[2] = address(70); - licenseRoyaltyPolicies[0] = address(royaltyPolicyLAP); - licenseRoyaltyPolicies[1] = address(royaltyPolicyLRP); - licenseRoyaltyPolicies[2] = address(mockExternalRoyaltyPolicy1); - parentRoyalties[0] = uint32(500 * 10 ** 6); - parentRoyalties[1] = uint32(17 * 10 ** 6); - parentRoyalties[2] = uint32(24 * 10 ** 6); - - vm.startPrank(address(licensingModule)); - ipGraph.addParentIp(address(80), parents); - - // tests royalty stack above 100% - vm.expectRevert(Errors.RoyaltyModule__AboveMaxPercent.selector); - royaltyModule.onLinkToParents(address(80), parents, licenseRoyaltyPolicies, parentRoyalties, ""); - - parentRoyalties[0] = uint32(50 * 10 ** 6); - parentRoyalties[1] = uint32(17 * 10 ** 6); - parentRoyalties[2] = uint32(240 * 10 ** 6); - - // tests royalty token supply above 100% - vm.expectRevert(Errors.RoyaltyModule__AboveMaxPercent.selector); - royaltyModule.onLinkToParents(address(80), parents, licenseRoyaltyPolicies, parentRoyalties, ""); - } - function test_RoyaltyModule_onLinkToParents_revert_RoyaltyModule_NotWhitelistedOrRegisteredRoyaltyPolicy() public { address[] memory parents = new address[](3); address[] memory licenseRoyaltyPolicies = new address[](3); @@ -773,22 +738,22 @@ contract TestRoyaltyModule is BaseTest { address[] memory accRoyaltyPolicies80After = royaltyModule.accumulatedRoyaltyPolicies(address(80)); assertEq(accRoyaltyPolicies80After[0], address(royaltyPolicyLAP)); assertEq(accRoyaltyPolicies80After[1], address(royaltyPolicyLRP)); - assertEq(accRoyaltyPolicies80After[2], address(mockExternalRoyaltyPolicy1)); - assertEq(accRoyaltyPolicies80After[3], address(mockExternalRoyaltyPolicy2)); + assertEq(accRoyaltyPolicies80After[2], address(mockExternalRoyaltyPolicy2)); + assertEq(accRoyaltyPolicies80After[3], address(mockExternalRoyaltyPolicy1)); address[] memory accRoyaltyPolicies10After = royaltyModule.accumulatedRoyaltyPolicies(address(10)); - assertEq(accRoyaltyPolicies10After[0], address(royaltyPolicyLAP)); + assertEq(accRoyaltyPolicies10After.length, 0); address[] memory accRoyaltyPolicies60After = royaltyModule.accumulatedRoyaltyPolicies(address(60)); assertEq(accRoyaltyPolicies60After[0], address(royaltyPolicyLAP)); assertEq(accRoyaltyPolicies60After[1], address(royaltyPolicyLRP)); - assertEq(accRoyaltyPolicies60After[2], address(mockExternalRoyaltyPolicy1)); + assertEq(accRoyaltyPolicies60After.length, 2); address[] memory accRoyaltyPolicies70After = royaltyModule.accumulatedRoyaltyPolicies(address(70)); - assertEq(accRoyaltyPolicies70After[0], address(royaltyPolicyLAP)); - assertEq(accRoyaltyPolicies70After[1], address(royaltyPolicyLRP)); - assertEq(accRoyaltyPolicies70After[2], address(mockExternalRoyaltyPolicy1)); - assertEq(accRoyaltyPolicies70After[3], address(mockExternalRoyaltyPolicy2)); + assertEq(accRoyaltyPolicies70After[0], address(mockExternalRoyaltyPolicy1)); + assertEq(accRoyaltyPolicies70After[1], address(royaltyPolicyLAP)); + assertEq(accRoyaltyPolicies70After[2], address(royaltyPolicyLRP)); + assertEq(accRoyaltyPolicies70After.length, 3); } function test_RoyaltyModule_onLinkToParents_group() public { From 4b0172655ac5f908e41977950a048e618bced6d1 Mon Sep 17 00:00:00 2001 From: Spablob Date: Fri, 15 Nov 2024 14:35:54 +0000 Subject: [PATCH 2/3] format fix --- contracts/modules/royalty/RoyaltyModule.sol | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contracts/modules/royalty/RoyaltyModule.sol b/contracts/modules/royalty/RoyaltyModule.sol index 89177f50..cde3023c 100644 --- a/contracts/modules/royalty/RoyaltyModule.sol +++ b/contracts/modules/royalty/RoyaltyModule.sol @@ -554,7 +554,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad /// @param licensePercent The license percentage /// @param ipRoyaltyVault The address of the ipRoyaltyVault /// @return rtsRequiredToLink The required royalty tokens to link - function _transferRoyaltyTokensToPolicy( + function _transferRoyaltyTokensToPolicy( address parentIpId, address royaltyPolicy, uint32 licensePercent, From e050cb254c00caaa7978c2491ba72296d2368c67 Mon Sep 17 00:00:00 2001 From: Spablob Date: Tue, 19 Nov 2024 08:45:53 +0000 Subject: [PATCH 3/3] gas fix --- contracts/modules/royalty/RoyaltyModule.sol | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contracts/modules/royalty/RoyaltyModule.sol b/contracts/modules/royalty/RoyaltyModule.sol index 3994290e..e5c161e1 100644 --- a/contracts/modules/royalty/RoyaltyModule.sol +++ b/contracts/modules/royalty/RoyaltyModule.sol @@ -561,7 +561,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad address ipRoyaltyVault ) internal returns (uint32) { uint32 rtsRequiredToLink = IRoyaltyPolicy(royaltyPolicy).getPolicyRtsRequiredToLink(parentIpId, licensePercent); - IERC20(ipRoyaltyVault).safeTransfer(royaltyPolicy, rtsRequiredToLink); + if (rtsRequiredToLink > 0) IERC20(ipRoyaltyVault).safeTransfer(royaltyPolicy, rtsRequiredToLink); return rtsRequiredToLink; }