Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure royalty policy propagation is made only to children #307

Merged
merged 6 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions contracts/modules/royalty/RoyaltyModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -497,24 +497,35 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad
for (uint256 i = 0; i < parentIpIds.length; i++) {
if (parentIpIds[i] == address(0)) revert Errors.RoyaltyModule__ZeroParentIpId();
if (licenseRoyaltyPolicies[i] == address(0)) revert Errors.RoyaltyModule__ZeroRoyaltyPolicy();
_addToAccumulatedRoyaltyPolicies(parentIpIds[i], licenseRoyaltyPolicies[i]);
address[] memory accParentRoyaltyPolicies = $.accumulatedRoyaltyPolicies[parentIpIds[i]].values();

// transfer the royalty tokens and add royalty policy to child if it's not contained in the parent
// since if it's already contained then it will be transferred in the next loop already
if (!$.accumulatedRoyaltyPolicies[parentIpIds[i]].contains(licenseRoyaltyPolicies[i])) {
_addToAccumulatedRoyaltyPolicies(ipId, licenseRoyaltyPolicies[i]);
totalRtsRequiredToLink += _transferRoyaltyTokensToPolicy(
parentIpIds[i],
licenseRoyaltyPolicies[i],
licensesPercent[i],
ipRoyaltyVault
);
}

// transfer the royalty tokens and add all the parent royalty policies to the child
// this loop is limited to accumulatedRoyaltyPoliciesLimit
address[] memory accParentRoyaltyPolicies = $.accumulatedRoyaltyPolicies[parentIpIds[i]].values();
for (uint256 j = 0; j < accParentRoyaltyPolicies.length; j++) {
// add the parent ancestor royalty policies to the child
_addToAccumulatedRoyaltyPolicies(ipId, accParentRoyaltyPolicies[j]);
// transfer the required royalty tokens to each policy
uint32 licensePercent = accParentRoyaltyPolicies[j] == licenseRoyaltyPolicies[i]
? licensesPercent[i]
: 0;
uint32 rtsRequiredToLink = IRoyaltyPolicy(accParentRoyaltyPolicies[j]).getPolicyRtsRequiredToLink(
totalRtsRequiredToLink += _transferRoyaltyTokensToPolicy(
parentIpIds[i],
licensePercent
accParentRoyaltyPolicies[j],
licensePercent,
ipRoyaltyVault
);
totalRtsRequiredToLink += rtsRequiredToLink;
if (totalRtsRequiredToLink > MAX_PERCENT) revert Errors.RoyaltyModule__AboveMaxPercent();
IERC20(ipRoyaltyVault).safeTransfer(accParentRoyaltyPolicies[j], rtsRequiredToLink);
}
}

Expand All @@ -537,6 +548,23 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad
_getRoyaltyModuleStorage().accumulatedRoyaltyPolicies[ipId].add(royaltyPolicy);
}

/// @notice Transfers the royalty tokens to the royalty policy
/// @param parentIpId The parent IP asset
/// @param royaltyPolicy The address of the royalty policy
/// @param licensePercent The license percentage
/// @param ipRoyaltyVault The address of the ipRoyaltyVault
/// @return rtsRequiredToLink The required royalty tokens to link
function _transferRoyaltyTokensToPolicy(
address parentIpId,
address royaltyPolicy,
uint32 licensePercent,
address ipRoyaltyVault
) internal returns (uint32) {
uint32 rtsRequiredToLink = IRoyaltyPolicy(royaltyPolicy).getPolicyRtsRequiredToLink(parentIpId, licensePercent);
if (rtsRequiredToLink > 0) IERC20(ipRoyaltyVault).safeTransfer(royaltyPolicy, rtsRequiredToLink);
return rtsRequiredToLink;
}

/// @notice Handles the payment of royalties
/// @param receiverIpId The ipId that receives the royalties
/// @param payerAddress The address that pays the royalties
Expand Down
51 changes: 8 additions & 43 deletions test/foundry/modules/royalty/RoyaltyModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -630,41 +630,6 @@ contract TestRoyaltyModule is BaseTest {
royaltyModule.onLinkToParents(address(80), parents, licenseRoyaltyPolicies, parentRoyalties, "");
}

function test_RoyaltyModule_onLinkToParents_revert_AboveMaxPercent() public {
address[] memory parents = new address[](3);
address[] memory licenseRoyaltyPolicies = new address[](3);
uint32[] memory parentRoyalties = new uint32[](3);

// link 80 to 10 + 60 + 70
parents = new address[](3);
licenseRoyaltyPolicies = new address[](3);
parentRoyalties = new uint32[](3);
parents[0] = address(10);
parents[1] = address(60);
parents[2] = address(70);
licenseRoyaltyPolicies[0] = address(royaltyPolicyLAP);
licenseRoyaltyPolicies[1] = address(royaltyPolicyLRP);
licenseRoyaltyPolicies[2] = address(mockExternalRoyaltyPolicy1);
parentRoyalties[0] = uint32(500 * 10 ** 6);
parentRoyalties[1] = uint32(17 * 10 ** 6);
parentRoyalties[2] = uint32(24 * 10 ** 6);

vm.startPrank(address(licensingModule));
ipGraph.addParentIp(address(80), parents);

// tests royalty stack above 100%
vm.expectRevert(Errors.RoyaltyModule__AboveMaxPercent.selector);
royaltyModule.onLinkToParents(address(80), parents, licenseRoyaltyPolicies, parentRoyalties, "");

parentRoyalties[0] = uint32(50 * 10 ** 6);
parentRoyalties[1] = uint32(17 * 10 ** 6);
parentRoyalties[2] = uint32(240 * 10 ** 6);

// tests royalty token supply above 100%
vm.expectRevert(Errors.RoyaltyModule__AboveMaxPercent.selector);
royaltyModule.onLinkToParents(address(80), parents, licenseRoyaltyPolicies, parentRoyalties, "");
}

function test_RoyaltyModule_onLinkToParents_revert_RoyaltyModule_NotWhitelistedOrRegisteredRoyaltyPolicy() public {
address[] memory parents = new address[](3);
address[] memory licenseRoyaltyPolicies = new address[](3);
Expand Down Expand Up @@ -774,22 +739,22 @@ contract TestRoyaltyModule is BaseTest {
address[] memory accRoyaltyPolicies80After = royaltyModule.accumulatedRoyaltyPolicies(address(80));
assertEq(accRoyaltyPolicies80After[0], address(royaltyPolicyLAP));
assertEq(accRoyaltyPolicies80After[1], address(royaltyPolicyLRP));
assertEq(accRoyaltyPolicies80After[2], address(mockExternalRoyaltyPolicy1));
assertEq(accRoyaltyPolicies80After[3], address(mockExternalRoyaltyPolicy2));
assertEq(accRoyaltyPolicies80After[2], address(mockExternalRoyaltyPolicy2));
assertEq(accRoyaltyPolicies80After[3], address(mockExternalRoyaltyPolicy1));

address[] memory accRoyaltyPolicies10After = royaltyModule.accumulatedRoyaltyPolicies(address(10));
assertEq(accRoyaltyPolicies10After[0], address(royaltyPolicyLAP));
assertEq(accRoyaltyPolicies10After.length, 0);

address[] memory accRoyaltyPolicies60After = royaltyModule.accumulatedRoyaltyPolicies(address(60));
assertEq(accRoyaltyPolicies60After[0], address(royaltyPolicyLAP));
assertEq(accRoyaltyPolicies60After[1], address(royaltyPolicyLRP));
assertEq(accRoyaltyPolicies60After[2], address(mockExternalRoyaltyPolicy1));
assertEq(accRoyaltyPolicies60After.length, 2);

address[] memory accRoyaltyPolicies70After = royaltyModule.accumulatedRoyaltyPolicies(address(70));
assertEq(accRoyaltyPolicies70After[0], address(royaltyPolicyLAP));
assertEq(accRoyaltyPolicies70After[1], address(royaltyPolicyLRP));
assertEq(accRoyaltyPolicies70After[2], address(mockExternalRoyaltyPolicy1));
assertEq(accRoyaltyPolicies70After[3], address(mockExternalRoyaltyPolicy2));
assertEq(accRoyaltyPolicies70After[0], address(mockExternalRoyaltyPolicy1));
assertEq(accRoyaltyPolicies70After[1], address(royaltyPolicyLAP));
assertEq(accRoyaltyPolicies70After[2], address(royaltyPolicyLRP));
assertEq(accRoyaltyPolicies70After.length, 3);
}

function test_RoyaltyModule_onLinkToParents_group() public {
Expand Down