diff --git a/contracts/interfaces/modules/royalty/IRoyaltyModule.sol b/contracts/interfaces/modules/royalty/IRoyaltyModule.sol index 701149b1a..783ce3e40 100644 --- a/contracts/interfaces/modules/royalty/IRoyaltyModule.sol +++ b/contracts/interfaces/modules/royalty/IRoyaltyModule.sol @@ -189,6 +189,11 @@ interface IRoyaltyModule is IModule { /// @return isWhitelisted True if the royalty token is whitelisted function isWhitelistedRoyaltyToken(address token) external view returns (bool); + /// @notice Indicates if an address is a royalty vault + /// @param ipRoyaltyVault The address to check + /// @return isIpRoyaltyVault True if the address is a royalty vault + function isIpRoyaltyVault(address ipRoyaltyVault) external view returns (bool); + /// @notice Indicates the royalty vault for a given IP asset /// @param ipId The ID of IP asset function ipRoyaltyVaults(address ipId) external view returns (address); diff --git a/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol b/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol index e7bf7c5a3..cda0af8a7 100644 --- a/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol +++ b/contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol @@ -53,17 +53,24 @@ interface IIpRoyaltyVault { /// @notice Allows token holders to claim revenue token based on the token balance at certain snapshot /// @param snapshotId The snapshot id /// @param tokenList The list of revenue tokens to claim + /// @param claimer The address of the claimer /// @return The amount of revenue tokens claimed for each token - function claimRevenueByTokenBatch( + function claimRevenueOnBehalfByTokenBatch( uint256 snapshotId, - address[] calldata tokenList + address[] calldata tokenList, + address claimer ) external returns (uint256[] memory); /// @notice Allows token holders to claim by a list of snapshot ids based on the token balance at certain snapshot /// @param snapshotIds The list of snapshot ids /// @param token The revenue token to claim + /// @param claimer The address of the claimer /// @return The amount of revenue tokens claimed - function claimRevenueBySnapshotBatch(uint256[] memory snapshotIds, address token) external returns (uint256); + function claimRevenueOnBehalfBySnapshotBatch( + uint256[] memory snapshotIds, + address token, + address claimer + ) external returns (uint256); /// @notice Allows to claim revenue tokens on behalf of the ip royalty vault /// @param snapshotId The snapshot id diff --git a/contracts/lib/Errors.sol b/contracts/lib/Errors.sol index 399425691..8c74d2822 100644 --- a/contracts/lib/Errors.sol +++ b/contracts/lib/Errors.sol @@ -623,6 +623,9 @@ library Errors { /// @notice Zero amount provided. error IpRoyaltyVault__ZeroAmount(); + /// @notice Vaults must claim as self. + error IpRoyaltyVault__VaultsMustClaimAsSelf(); + //////////////////////////////////////////////////////////////////////////// // Vault Controller // //////////////////////////////////////////////////////////////////////////// diff --git a/contracts/modules/royalty/RoyaltyModule.sol b/contracts/modules/royalty/RoyaltyModule.sol index 8fa2f7aeb..6d21a6294 100644 --- a/contracts/modules/royalty/RoyaltyModule.sol +++ b/contracts/modules/royalty/RoyaltyModule.sol @@ -60,6 +60,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad /// @param isWhitelistedRoyaltyToken Indicates if a royalty token is whitelisted /// @param isRegisteredExternalRoyaltyPolicy Indicates if an external royalty policy is registered /// @param ipRoyaltyVaults The royalty vault address for a given IP asset (if any) + /// @param isIpRoyaltyVault Indicates if an address is a royalty vault /// @param globalRoyaltyStack Sum of royalty stack from each whitelisted royalty policy for a given IP asset /// @param accumulatedRoyaltyPolicies The accumulated royalty policies for a given IP asset /// @param totalRevenueTokensReceived The total lifetime revenue tokens received for a given IP asset @@ -74,6 +75,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad mapping(address token => bool) isWhitelistedRoyaltyToken; mapping(address royaltyPolicy => bool) isRegisteredExternalRoyaltyPolicy; mapping(address ipId => address ipRoyaltyVault) ipRoyaltyVaults; + mapping(address ipRoyaltyVault => bool) isIpRoyaltyVault; mapping(address ipId => uint32) globalRoyaltyStack; mapping(address ipId => EnumerableSet.AddressSet) accumulatedRoyaltyPolicies; mapping(address ipId => mapping(address token => uint256)) totalRevenueTokensReceived; @@ -421,6 +423,13 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad return _getRoyaltyModuleStorage().isWhitelistedRoyaltyToken[token]; } + /// @notice Indicates if an address is a royalty vault + /// @param ipRoyaltyVault The address to check + /// @return isIpRoyaltyVault True if the address is a royalty vault + function isIpRoyaltyVault(address ipRoyaltyVault) external view returns (bool) { + return _getRoyaltyModuleStorage().isIpRoyaltyVault[ipRoyaltyVault]; + } + /// @notice Indicates the royalty vault for a given IP asset /// @param ipId The ID of IP asset function ipRoyaltyVaults(address ipId) external view returns (address) { @@ -461,6 +470,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad address ipRoyaltyVault = address(new BeaconProxy(ipRoyaltyVaultBeacon(), "")); IIpRoyaltyVault(ipRoyaltyVault).initialize("Royalty Token", "RT", MAX_PERCENT, ipId, receiver); $.ipRoyaltyVaults[ipId] = ipRoyaltyVault; + $.isIpRoyaltyVault[ipRoyaltyVault] = true; return ipRoyaltyVault; } diff --git a/contracts/modules/royalty/policies/IpRoyaltyVault.sol b/contracts/modules/royalty/policies/IpRoyaltyVault.sol index 312204d6d..791c2799f 100644 --- a/contracts/modules/royalty/policies/IpRoyaltyVault.sol +++ b/contracts/modules/royalty/policies/IpRoyaltyVault.sol @@ -19,6 +19,10 @@ import { Errors } from "../../../lib/Errors.sol"; /// @title Ip Royalty Vault /// @notice Defines the logic for claiming revenue tokens for a given IP +/// @dev [CAUTION] +/// Do not transfer ERC20 tokens directly to the ip royalty vault as they can be lost if the pendingVaultAmount +/// is not updated along with an ERC20 transfer. +/// Use appropriate callpaths that can update the pendingVaultAmount when an ERC20 transfer to the vault is made. contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, ReentrancyGuardUpgradeable { using EnumerableSet for EnumerableSet.AddressSet; using SafeERC20Upgradeable for IERC20Upgradeable; @@ -167,23 +171,28 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @notice Allows token holders to claim revenue token based on the token balance at certain snapshot /// @param snapshotId The snapshot id /// @param tokenList The list of revenue tokens to claim + /// @param claimer The address of the claimer /// @return The amount of revenue tokens claimed for each token - function claimRevenueByTokenBatch( + function claimRevenueOnBehalfByTokenBatch( uint256 snapshotId, - address[] calldata tokenList + address[] calldata tokenList, + address claimer ) external nonReentrant whenNotPaused returns (uint256[] memory) { IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + if (ROYALTY_MODULE.isIpRoyaltyVault(claimer) && msg.sender != claimer) + revert Errors.IpRoyaltyVault__VaultsMustClaimAsSelf(); + uint256[] memory claimableAmounts = new uint256[](tokenList.length); for (uint256 i = 0; i < tokenList.length; i++) { - claimableAmounts[i] = _claimableRevenue(msg.sender, snapshotId, tokenList[i]); + claimableAmounts[i] = _claimableRevenue(claimer, snapshotId, tokenList[i]); if (claimableAmounts[i] == 0) revert Errors.IpRoyaltyVault__NoClaimableTokens(); - $.isClaimedAtSnapshot[snapshotId][msg.sender][tokenList[i]] = true; + $.isClaimedAtSnapshot[snapshotId][claimer][tokenList[i]] = true; $.claimVaultAmount[tokenList[i]] -= claimableAmounts[i]; - IERC20Upgradeable(tokenList[i]).safeTransfer(msg.sender, claimableAmounts[i]); + IERC20Upgradeable(tokenList[i]).safeTransfer(claimer, claimableAmounts[i]); - emit RevenueTokenClaimed(msg.sender, tokenList[i], claimableAmounts[i]); + emit RevenueTokenClaimed(claimer, tokenList[i], claimableAmounts[i]); } return claimableAmounts; @@ -192,25 +201,30 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy /// @notice Allows token holders to claim by a list of snapshot ids based on the token balance at certain snapshot /// @param snapshotIds The list of snapshot ids /// @param token The revenue token to claim + /// @param claimer The address of the claimer /// @return The amount of revenue tokens claimed - function claimRevenueBySnapshotBatch( + function claimRevenueOnBehalfBySnapshotBatch( uint256[] memory snapshotIds, - address token + address token, + address claimer ) external nonReentrant whenNotPaused returns (uint256) { IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage(); + if (ROYALTY_MODULE.isIpRoyaltyVault(claimer) && msg.sender != claimer) + revert Errors.IpRoyaltyVault__VaultsMustClaimAsSelf(); + uint256 claimableAmount; for (uint256 i = 0; i < snapshotIds.length; i++) { - claimableAmount += _claimableRevenue(msg.sender, snapshotIds[i], token); - $.isClaimedAtSnapshot[snapshotIds[i]][msg.sender][token] = true; + claimableAmount += _claimableRevenue(claimer, snapshotIds[i], token); + $.isClaimedAtSnapshot[snapshotIds[i]][claimer][token] = true; } if (claimableAmount == 0) revert Errors.IpRoyaltyVault__NoClaimableTokens(); $.claimVaultAmount[token] -= claimableAmount; - IERC20Upgradeable(token).safeTransfer(msg.sender, claimableAmount); + IERC20Upgradeable(token).safeTransfer(claimer, claimableAmount); - emit RevenueTokenClaimed(msg.sender, token, claimableAmount); + emit RevenueTokenClaimed(claimer, token, claimableAmount); return claimableAmount; } @@ -233,9 +247,10 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy if (!ROYALTY_MODULE.hasAncestorIp(targetIpId, _getIpRoyaltyVaultStorage().ipId)) revert Errors.IpRoyaltyVault__VaultDoesNotBelongToAnAncestor(); - uint256[] memory claimedAmounts = IIpRoyaltyVault(targetIpVault).claimRevenueByTokenBatch( + uint256[] memory claimedAmounts = IIpRoyaltyVault(targetIpVault).claimRevenueOnBehalfByTokenBatch( snapshotId, - tokenList + tokenList, + address(this) ); // only tokens that have claimable revenue higher than zero will be added to the vault @@ -262,7 +277,11 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy if (!ROYALTY_MODULE.hasAncestorIp(targetIpId, _getIpRoyaltyVaultStorage().ipId)) revert Errors.IpRoyaltyVault__VaultDoesNotBelongToAnAncestor(); - uint256 claimedAmount = IIpRoyaltyVault(targetIpVault).claimRevenueBySnapshotBatch(snapshotIds, token); + uint256 claimedAmount = IIpRoyaltyVault(targetIpVault).claimRevenueOnBehalfBySnapshotBatch( + snapshotIds, + token, + address(this) + ); // the token will be added to the vault only if claimable revenue is higher than zero _updateVaultBalance(token, claimedAmount); diff --git a/test/foundry/integration/flows/royalty/Royalty.t.sol b/test/foundry/integration/flows/royalty/Royalty.t.sol index 71d13c00b..88e622bc2 100644 --- a/test/foundry/integration/flows/royalty/Royalty.t.sol +++ b/test/foundry/integration/flows/royalty/Royalty.t.sol @@ -187,7 +187,7 @@ contract Flows_Integration_Disputes is BaseIntegration { uint256 aliceBalanceBefore = mockToken.balanceOf(ipAcct[1]); - IpRoyaltyVault(vault).claimRevenueBySnapshotBatch(snapshotIds, address(mockToken)); + IpRoyaltyVault(vault).claimRevenueOnBehalfBySnapshotBatch(snapshotIds, address(mockToken), ipAcct[1]); uint256 aliceBalanceAfter = mockToken.balanceOf(ipAcct[1]); diff --git a/test/foundry/invariants/IpRoyaltyVault.t.sol b/test/foundry/invariants/IpRoyaltyVault.t.sol index 64668932a..f99b7d6e5 100644 --- a/test/foundry/invariants/IpRoyaltyVault.t.sol +++ b/test/foundry/invariants/IpRoyaltyVault.t.sol @@ -25,11 +25,11 @@ contract IpRoyaltyVaultHarness is Test { } function claimRevenueByTokenBatch(uint256 snapshotId, address[] calldata tokenList) public { - vault.claimRevenueByTokenBatch(snapshotId, tokenList); + vault.claimRevenueOnBehalfByTokenBatch(snapshotId, tokenList, address(this)); } function claimRevenueBySnapshotBatch(uint256[] memory snapshotIds, address token) public { - vault.claimRevenueBySnapshotBatch(snapshotIds, token); + vault.claimRevenueOnBehalfBySnapshotBatch(snapshotIds, token, address(this)); } function claimByTokenBatchAsSelf(uint256 snapshotId, address[] calldata tokenList, address targetIpId) public { diff --git a/test/foundry/mocks/grouping/MockEvenSplitGroupPool.sol b/test/foundry/mocks/grouping/MockEvenSplitGroupPool.sol index a2113b356..7f2886d0e 100644 --- a/test/foundry/mocks/grouping/MockEvenSplitGroupPool.sol +++ b/test/foundry/mocks/grouping/MockEvenSplitGroupPool.sol @@ -136,7 +136,7 @@ contract MockEvenSplitGroupPool is IGroupRewardPool { if (address(vault) == address(0)) return; uint256[] memory snapshotsToClaim = new uint256[](1); snapshotsToClaim[0] = vault.snapshot(); - uint256 royalties = vault.claimRevenueBySnapshotBatch(snapshotsToClaim, token); + uint256 royalties = vault.claimRevenueOnBehalfBySnapshotBatch(snapshotsToClaim, token, address(this)); poolInfo[groupId][token].availableBalance += royalties; poolInfo[groupId][token].accBalance += royalties; groupTokens[groupId].add(token); diff --git a/test/foundry/modules/royalty/IpRoyaltyVault.t.sol b/test/foundry/modules/royalty/IpRoyaltyVault.t.sol index 3ad801be3..f4742489d 100644 --- a/test/foundry/modules/royalty/IpRoyaltyVault.t.sol +++ b/test/foundry/modules/royalty/IpRoyaltyVault.t.sol @@ -203,7 +203,7 @@ contract TestIpRoyaltyVault is BaseTest { address[] memory tokens = new address[](1); tokens[0] = address(USDC); vm.startPrank(address(2)); - ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens); + ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, tokens, address(2)); vm.stopPrank(); // take snapshot @@ -225,10 +225,10 @@ contract TestIpRoyaltyVault is BaseTest { royaltyModule.pause(); vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); - ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC)); + ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(new uint256[](0), address(USDC), u.admin); vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector); - ipRoyaltyVault.claimRevenueByTokenBatch(1, new address[](0)); + ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, new address[](0), u.admin); } function test_IpRoyaltyVault_claimableRevenue() public { @@ -262,7 +262,18 @@ contract TestIpRoyaltyVault is BaseTest { assertEq(claimableRevenueMinHolder, (royaltyAmount * 30e6) / 100e6); } - function test_IpRoyaltyVault_claimRevenueByTokenBatch_revert_claimRevenueByTokenBatch() public { + function test_IpRoyaltyVault_claimRevenueOnBehalfByTokenBatch_revert_VaultsMustClaimAsSelf() public { + // deploy vault + vm.startPrank(address(licensingModule)); + royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), ""); + IpRoyaltyVault ipRoyaltyVault = IpRoyaltyVault(royaltyModule.ipRoyaltyVaults(address(1))); + vm.stopPrank(); + + vm.expectRevert(Errors.IpRoyaltyVault__VaultsMustClaimAsSelf.selector); + IIpRoyaltyVault(ipRoyaltyVault).claimRevenueOnBehalfByTokenBatch(1, new address[](0), address(ipRoyaltyVault)); + } + + function test_IpRoyaltyVault_claimRevenueOnBehalfByTokenBatch_revert_NoClaimableTokens() public { // deploy vault vm.startPrank(address(licensingModule)); royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), ""); @@ -281,10 +292,10 @@ contract TestIpRoyaltyVault is BaseTest { tokens[0] = address(USDC); vm.expectRevert(Errors.IpRoyaltyVault__NoClaimableTokens.selector); - ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens); + ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, tokens, u.admin); } - function test_IpRoyaltyVault_claimRevenueByTokenBatch() public { + function test_IpRoyaltyVault_claimRevenueOnBehalfByTokenBatch() public { // payment is made to vault uint256 royaltyAmount = 100000 * 10 ** 6; USDC.mint(address(2), royaltyAmount); // 100k USDC @@ -323,7 +334,7 @@ contract TestIpRoyaltyVault is BaseTest { emit IIpRoyaltyVault.RevenueTokenClaimed(address(2), address(USDC), expectedAmount); emit IIpRoyaltyVault.RevenueTokenClaimed(address(2), address(LINK), expectedAmount); - ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens); + ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, tokens, address(2)); assertEq(USDC.balanceOf(address(2)) - userUsdcBalanceBefore, expectedAmount); assertEq(LINK.balanceOf(address(2)) - userLinkBalanceBefore, expectedAmount); @@ -335,7 +346,18 @@ contract TestIpRoyaltyVault is BaseTest { assertEq(ipRoyaltyVault.isClaimedAtSnapshot(1, address(2), address(LINK)), true); } - function test_IpRoyaltyVault_claimRevenueBySnapshotBatch_revert_NoClaimableTokens() public { + function test_IpRoyaltyVault_claimRevenueOnBehalfBySnapshotBatch_revert_VaultsMustClaimAsSelf() public { + // deploy vault + vm.startPrank(address(licensingModule)); + royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), ""); + IpRoyaltyVault ipRoyaltyVault = IpRoyaltyVault(royaltyModule.ipRoyaltyVaults(address(1))); + vm.stopPrank(); + + vm.expectRevert(Errors.IpRoyaltyVault__VaultsMustClaimAsSelf.selector); + ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(new uint256[](0), address(USDC), address(ipRoyaltyVault)); + } + + function test_IpRoyaltyVault_claimRevenueOnBehalfBySnapshotBatch_revert_NoClaimableTokens() public { // deploy vault vm.startPrank(address(licensingModule)); royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), ""); @@ -343,10 +365,10 @@ contract TestIpRoyaltyVault is BaseTest { vm.stopPrank(); vm.expectRevert(Errors.IpRoyaltyVault__NoClaimableTokens.selector); - ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC)); + ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(new uint256[](0), address(USDC), u.admin); } - function test_IpRoyaltyVault_claimRevenueBySnapshotBatch() public { + function test_IpRoyaltyVault_claimRevenueOnBehalfBySnapshotBatch() public { uint256 royaltyAmount = 100000 * 10 ** 6; USDC.mint(address(2), royaltyAmount); // 100k USDC @@ -386,7 +408,7 @@ contract TestIpRoyaltyVault is BaseTest { emit IIpRoyaltyVault.RevenueTokenClaimed(address(2), address(USDC), expectedAmount); vm.startPrank(address(2)); - ipRoyaltyVault.claimRevenueBySnapshotBatch(snapshots, address(USDC)); + ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(snapshots, address(USDC), address(2)); assertEq(USDC.balanceOf(address(2)) - userUsdcBalanceBefore, expectedAmount); assertEq(contractUsdcBalanceBefore - USDC.balanceOf(address(ipRoyaltyVault)), expectedAmount); diff --git a/test/foundry/modules/royalty/RoyaltyModule.t.sol b/test/foundry/modules/royalty/RoyaltyModule.t.sol index 46781bd87..4412ea09a 100644 --- a/test/foundry/modules/royalty/RoyaltyModule.t.sol +++ b/test/foundry/modules/royalty/RoyaltyModule.t.sol @@ -439,6 +439,7 @@ contract TestRoyaltyModule is BaseTest { assertEq(ipIdRtBalAfter, royaltyModule.maxPercent()); assertFalse(royaltyModule.ipRoyaltyVaults(licensor) == address(0)); + assertEq(royaltyModule.isIpRoyaltyVault(newVault), true); } function test_RoyaltyModule_onLicenseMinting_NewVaultGroup() public { @@ -459,6 +460,7 @@ contract TestRoyaltyModule is BaseTest { assertEq(groupPoolRtBalAfter, royaltyModule.maxPercent()); assertFalse(royaltyModule.ipRoyaltyVaults(groupId) == address(0)); + assertEq(royaltyModule.isIpRoyaltyVault(newVault), true); } function test_RoyaltyModule_onLicenseMinting_ExistingVault() public { @@ -748,6 +750,7 @@ contract TestRoyaltyModule is BaseTest { uint256 ipId80IpIdRtBalAfter = IERC20(ipRoyaltyVault80).balanceOf(address(80)); assertFalse(royaltyModule.ipRoyaltyVaults(address(80)) == address(0)); + assertEq(royaltyModule.isIpRoyaltyVault(royaltyModule.ipRoyaltyVaults(address(80))), true); assertEq(ipId80RtLAPBalAfter, 0); assertEq(ipId80RtLRPBalAfter, 0); assertEq(ipId80RtLRPParentVaultBalAfter, 0);