Skip to content

Commit

Permalink
Allow revenue token claiming to be done on behalf of the claimer (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
Spablob authored Sep 18, 2024
1 parent de6de66 commit e643f5f
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 33 deletions.
5 changes: 5 additions & 0 deletions contracts/interfaces/modules/royalty/IRoyaltyModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ interface IRoyaltyModule is IModule {
/// @return isWhitelisted True if the royalty token is whitelisted
function isWhitelistedRoyaltyToken(address token) external view returns (bool);

/// @notice Indicates if an address is a royalty vault
/// @param ipRoyaltyVault The address to check
/// @return isIpRoyaltyVault True if the address is a royalty vault
function isIpRoyaltyVault(address ipRoyaltyVault) external view returns (bool);

/// @notice Indicates the royalty vault for a given IP asset
/// @param ipId The ID of IP asset
function ipRoyaltyVaults(address ipId) external view returns (address);
Expand Down
13 changes: 10 additions & 3 deletions contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,24 @@ interface IIpRoyaltyVault {
/// @notice Allows token holders to claim revenue token based on the token balance at certain snapshot
/// @param snapshotId The snapshot id
/// @param tokenList The list of revenue tokens to claim
/// @param claimer The address of the claimer
/// @return The amount of revenue tokens claimed for each token
function claimRevenueByTokenBatch(
function claimRevenueOnBehalfByTokenBatch(
uint256 snapshotId,
address[] calldata tokenList
address[] calldata tokenList,
address claimer
) external returns (uint256[] memory);

/// @notice Allows token holders to claim by a list of snapshot ids based on the token balance at certain snapshot
/// @param snapshotIds The list of snapshot ids
/// @param token The revenue token to claim
/// @param claimer The address of the claimer
/// @return The amount of revenue tokens claimed
function claimRevenueBySnapshotBatch(uint256[] memory snapshotIds, address token) external returns (uint256);
function claimRevenueOnBehalfBySnapshotBatch(
uint256[] memory snapshotIds,
address token,
address claimer
) external returns (uint256);

/// @notice Allows to claim revenue tokens on behalf of the ip royalty vault
/// @param snapshotId The snapshot id
Expand Down
3 changes: 3 additions & 0 deletions contracts/lib/Errors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,9 @@ library Errors {
/// @notice Zero amount provided.
error IpRoyaltyVault__ZeroAmount();

/// @notice Vaults must claim as self.
error IpRoyaltyVault__VaultsMustClaimAsSelf();

////////////////////////////////////////////////////////////////////////////
// Vault Controller //
////////////////////////////////////////////////////////////////////////////
Expand Down
10 changes: 10 additions & 0 deletions contracts/modules/royalty/RoyaltyModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad
/// @param isWhitelistedRoyaltyToken Indicates if a royalty token is whitelisted
/// @param isRegisteredExternalRoyaltyPolicy Indicates if an external royalty policy is registered
/// @param ipRoyaltyVaults The royalty vault address for a given IP asset (if any)
/// @param isIpRoyaltyVault Indicates if an address is a royalty vault
/// @param globalRoyaltyStack Sum of royalty stack from each whitelisted royalty policy for a given IP asset
/// @param accumulatedRoyaltyPolicies The accumulated royalty policies for a given IP asset
/// @param totalRevenueTokensReceived The total lifetime revenue tokens received for a given IP asset
Expand All @@ -74,6 +75,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad
mapping(address token => bool) isWhitelistedRoyaltyToken;
mapping(address royaltyPolicy => bool) isRegisteredExternalRoyaltyPolicy;
mapping(address ipId => address ipRoyaltyVault) ipRoyaltyVaults;
mapping(address ipRoyaltyVault => bool) isIpRoyaltyVault;
mapping(address ipId => uint32) globalRoyaltyStack;
mapping(address ipId => EnumerableSet.AddressSet) accumulatedRoyaltyPolicies;
mapping(address ipId => mapping(address token => uint256)) totalRevenueTokensReceived;
Expand Down Expand Up @@ -421,6 +423,13 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad
return _getRoyaltyModuleStorage().isWhitelistedRoyaltyToken[token];
}

/// @notice Indicates if an address is a royalty vault
/// @param ipRoyaltyVault The address to check
/// @return isIpRoyaltyVault True if the address is a royalty vault
function isIpRoyaltyVault(address ipRoyaltyVault) external view returns (bool) {
return _getRoyaltyModuleStorage().isIpRoyaltyVault[ipRoyaltyVault];
}

/// @notice Indicates the royalty vault for a given IP asset
/// @param ipId The ID of IP asset
function ipRoyaltyVaults(address ipId) external view returns (address) {
Expand Down Expand Up @@ -461,6 +470,7 @@ contract RoyaltyModule is IRoyaltyModule, VaultController, ReentrancyGuardUpgrad
address ipRoyaltyVault = address(new BeaconProxy(ipRoyaltyVaultBeacon(), ""));
IIpRoyaltyVault(ipRoyaltyVault).initialize("Royalty Token", "RT", MAX_PERCENT, ipId, receiver);
$.ipRoyaltyVaults[ipId] = ipRoyaltyVault;
$.isIpRoyaltyVault[ipRoyaltyVault] = true;

return ipRoyaltyVault;
}
Expand Down
49 changes: 34 additions & 15 deletions contracts/modules/royalty/policies/IpRoyaltyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ import { Errors } from "../../../lib/Errors.sol";

/// @title Ip Royalty Vault
/// @notice Defines the logic for claiming revenue tokens for a given IP
/// @dev [CAUTION]
/// Do not transfer ERC20 tokens directly to the ip royalty vault as they can be lost if the pendingVaultAmount
/// is not updated along with an ERC20 transfer.
/// Use appropriate callpaths that can update the pendingVaultAmount when an ERC20 transfer to the vault is made.
contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, ReentrancyGuardUpgradeable {
using EnumerableSet for EnumerableSet.AddressSet;
using SafeERC20Upgradeable for IERC20Upgradeable;
Expand Down Expand Up @@ -167,23 +171,28 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
/// @notice Allows token holders to claim revenue token based on the token balance at certain snapshot
/// @param snapshotId The snapshot id
/// @param tokenList The list of revenue tokens to claim
/// @param claimer The address of the claimer
/// @return The amount of revenue tokens claimed for each token
function claimRevenueByTokenBatch(
function claimRevenueOnBehalfByTokenBatch(
uint256 snapshotId,
address[] calldata tokenList
address[] calldata tokenList,
address claimer
) external nonReentrant whenNotPaused returns (uint256[] memory) {
IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage();

if (ROYALTY_MODULE.isIpRoyaltyVault(claimer) && msg.sender != claimer)
revert Errors.IpRoyaltyVault__VaultsMustClaimAsSelf();

uint256[] memory claimableAmounts = new uint256[](tokenList.length);
for (uint256 i = 0; i < tokenList.length; i++) {
claimableAmounts[i] = _claimableRevenue(msg.sender, snapshotId, tokenList[i]);
claimableAmounts[i] = _claimableRevenue(claimer, snapshotId, tokenList[i]);
if (claimableAmounts[i] == 0) revert Errors.IpRoyaltyVault__NoClaimableTokens();

$.isClaimedAtSnapshot[snapshotId][msg.sender][tokenList[i]] = true;
$.isClaimedAtSnapshot[snapshotId][claimer][tokenList[i]] = true;
$.claimVaultAmount[tokenList[i]] -= claimableAmounts[i];
IERC20Upgradeable(tokenList[i]).safeTransfer(msg.sender, claimableAmounts[i]);
IERC20Upgradeable(tokenList[i]).safeTransfer(claimer, claimableAmounts[i]);

emit RevenueTokenClaimed(msg.sender, tokenList[i], claimableAmounts[i]);
emit RevenueTokenClaimed(claimer, tokenList[i], claimableAmounts[i]);
}

return claimableAmounts;
Expand All @@ -192,25 +201,30 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
/// @notice Allows token holders to claim by a list of snapshot ids based on the token balance at certain snapshot
/// @param snapshotIds The list of snapshot ids
/// @param token The revenue token to claim
/// @param claimer The address of the claimer
/// @return The amount of revenue tokens claimed
function claimRevenueBySnapshotBatch(
function claimRevenueOnBehalfBySnapshotBatch(
uint256[] memory snapshotIds,
address token
address token,
address claimer
) external nonReentrant whenNotPaused returns (uint256) {
IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage();

if (ROYALTY_MODULE.isIpRoyaltyVault(claimer) && msg.sender != claimer)
revert Errors.IpRoyaltyVault__VaultsMustClaimAsSelf();

uint256 claimableAmount;
for (uint256 i = 0; i < snapshotIds.length; i++) {
claimableAmount += _claimableRevenue(msg.sender, snapshotIds[i], token);
$.isClaimedAtSnapshot[snapshotIds[i]][msg.sender][token] = true;
claimableAmount += _claimableRevenue(claimer, snapshotIds[i], token);
$.isClaimedAtSnapshot[snapshotIds[i]][claimer][token] = true;
}

if (claimableAmount == 0) revert Errors.IpRoyaltyVault__NoClaimableTokens();

$.claimVaultAmount[token] -= claimableAmount;
IERC20Upgradeable(token).safeTransfer(msg.sender, claimableAmount);
IERC20Upgradeable(token).safeTransfer(claimer, claimableAmount);

emit RevenueTokenClaimed(msg.sender, token, claimableAmount);
emit RevenueTokenClaimed(claimer, token, claimableAmount);

return claimableAmount;
}
Expand All @@ -233,9 +247,10 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
if (!ROYALTY_MODULE.hasAncestorIp(targetIpId, _getIpRoyaltyVaultStorage().ipId))
revert Errors.IpRoyaltyVault__VaultDoesNotBelongToAnAncestor();

uint256[] memory claimedAmounts = IIpRoyaltyVault(targetIpVault).claimRevenueByTokenBatch(
uint256[] memory claimedAmounts = IIpRoyaltyVault(targetIpVault).claimRevenueOnBehalfByTokenBatch(
snapshotId,
tokenList
tokenList,
address(this)
);

// only tokens that have claimable revenue higher than zero will be added to the vault
Expand All @@ -262,7 +277,11 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
if (!ROYALTY_MODULE.hasAncestorIp(targetIpId, _getIpRoyaltyVaultStorage().ipId))
revert Errors.IpRoyaltyVault__VaultDoesNotBelongToAnAncestor();

uint256 claimedAmount = IIpRoyaltyVault(targetIpVault).claimRevenueBySnapshotBatch(snapshotIds, token);
uint256 claimedAmount = IIpRoyaltyVault(targetIpVault).claimRevenueOnBehalfBySnapshotBatch(
snapshotIds,
token,
address(this)
);

// the token will be added to the vault only if claimable revenue is higher than zero
_updateVaultBalance(token, claimedAmount);
Expand Down
2 changes: 1 addition & 1 deletion test/foundry/integration/flows/royalty/Royalty.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ contract Flows_Integration_Disputes is BaseIntegration {

uint256 aliceBalanceBefore = mockToken.balanceOf(ipAcct[1]);

IpRoyaltyVault(vault).claimRevenueBySnapshotBatch(snapshotIds, address(mockToken));
IpRoyaltyVault(vault).claimRevenueOnBehalfBySnapshotBatch(snapshotIds, address(mockToken), ipAcct[1]);

uint256 aliceBalanceAfter = mockToken.balanceOf(ipAcct[1]);

Expand Down
4 changes: 2 additions & 2 deletions test/foundry/invariants/IpRoyaltyVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ contract IpRoyaltyVaultHarness is Test {
}

function claimRevenueByTokenBatch(uint256 snapshotId, address[] calldata tokenList) public {
vault.claimRevenueByTokenBatch(snapshotId, tokenList);
vault.claimRevenueOnBehalfByTokenBatch(snapshotId, tokenList, address(this));
}

function claimRevenueBySnapshotBatch(uint256[] memory snapshotIds, address token) public {
vault.claimRevenueBySnapshotBatch(snapshotIds, token);
vault.claimRevenueOnBehalfBySnapshotBatch(snapshotIds, token, address(this));
}

function claimByTokenBatchAsSelf(uint256 snapshotId, address[] calldata tokenList, address targetIpId) public {
Expand Down
2 changes: 1 addition & 1 deletion test/foundry/mocks/grouping/MockEvenSplitGroupPool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ contract MockEvenSplitGroupPool is IGroupRewardPool {
if (address(vault) == address(0)) return;
uint256[] memory snapshotsToClaim = new uint256[](1);
snapshotsToClaim[0] = vault.snapshot();
uint256 royalties = vault.claimRevenueBySnapshotBatch(snapshotsToClaim, token);
uint256 royalties = vault.claimRevenueOnBehalfBySnapshotBatch(snapshotsToClaim, token, address(this));
poolInfo[groupId][token].availableBalance += royalties;
poolInfo[groupId][token].accBalance += royalties;
groupTokens[groupId].add(token);
Expand Down
44 changes: 33 additions & 11 deletions test/foundry/modules/royalty/IpRoyaltyVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ contract TestIpRoyaltyVault is BaseTest {
address[] memory tokens = new address[](1);
tokens[0] = address(USDC);
vm.startPrank(address(2));
ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens);
ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, tokens, address(2));
vm.stopPrank();

// take snapshot
Expand All @@ -225,10 +225,10 @@ contract TestIpRoyaltyVault is BaseTest {
royaltyModule.pause();

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC));
ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(new uint256[](0), address(USDC), u.admin);

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
ipRoyaltyVault.claimRevenueByTokenBatch(1, new address[](0));
ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, new address[](0), u.admin);
}

function test_IpRoyaltyVault_claimableRevenue() public {
Expand Down Expand Up @@ -262,7 +262,18 @@ contract TestIpRoyaltyVault is BaseTest {
assertEq(claimableRevenueMinHolder, (royaltyAmount * 30e6) / 100e6);
}

function test_IpRoyaltyVault_claimRevenueByTokenBatch_revert_claimRevenueByTokenBatch() public {
function test_IpRoyaltyVault_claimRevenueOnBehalfByTokenBatch_revert_VaultsMustClaimAsSelf() public {
// deploy vault
vm.startPrank(address(licensingModule));
royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), "");
IpRoyaltyVault ipRoyaltyVault = IpRoyaltyVault(royaltyModule.ipRoyaltyVaults(address(1)));
vm.stopPrank();

vm.expectRevert(Errors.IpRoyaltyVault__VaultsMustClaimAsSelf.selector);
IIpRoyaltyVault(ipRoyaltyVault).claimRevenueOnBehalfByTokenBatch(1, new address[](0), address(ipRoyaltyVault));
}

function test_IpRoyaltyVault_claimRevenueOnBehalfByTokenBatch_revert_NoClaimableTokens() public {
// deploy vault
vm.startPrank(address(licensingModule));
royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), "");
Expand All @@ -281,10 +292,10 @@ contract TestIpRoyaltyVault is BaseTest {
tokens[0] = address(USDC);

vm.expectRevert(Errors.IpRoyaltyVault__NoClaimableTokens.selector);
ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens);
ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, tokens, u.admin);
}

function test_IpRoyaltyVault_claimRevenueByTokenBatch() public {
function test_IpRoyaltyVault_claimRevenueOnBehalfByTokenBatch() public {
// payment is made to vault
uint256 royaltyAmount = 100000 * 10 ** 6;
USDC.mint(address(2), royaltyAmount); // 100k USDC
Expand Down Expand Up @@ -323,7 +334,7 @@ contract TestIpRoyaltyVault is BaseTest {
emit IIpRoyaltyVault.RevenueTokenClaimed(address(2), address(USDC), expectedAmount);
emit IIpRoyaltyVault.RevenueTokenClaimed(address(2), address(LINK), expectedAmount);

ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens);
ipRoyaltyVault.claimRevenueOnBehalfByTokenBatch(1, tokens, address(2));

assertEq(USDC.balanceOf(address(2)) - userUsdcBalanceBefore, expectedAmount);
assertEq(LINK.balanceOf(address(2)) - userLinkBalanceBefore, expectedAmount);
Expand All @@ -335,18 +346,29 @@ contract TestIpRoyaltyVault is BaseTest {
assertEq(ipRoyaltyVault.isClaimedAtSnapshot(1, address(2), address(LINK)), true);
}

function test_IpRoyaltyVault_claimRevenueBySnapshotBatch_revert_NoClaimableTokens() public {
function test_IpRoyaltyVault_claimRevenueOnBehalfBySnapshotBatch_revert_VaultsMustClaimAsSelf() public {
// deploy vault
vm.startPrank(address(licensingModule));
royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), "");
IpRoyaltyVault ipRoyaltyVault = IpRoyaltyVault(royaltyModule.ipRoyaltyVaults(address(1)));
vm.stopPrank();

vm.expectRevert(Errors.IpRoyaltyVault__VaultsMustClaimAsSelf.selector);
ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(new uint256[](0), address(USDC), address(ipRoyaltyVault));
}

function test_IpRoyaltyVault_claimRevenueOnBehalfBySnapshotBatch_revert_NoClaimableTokens() public {
// deploy vault
vm.startPrank(address(licensingModule));
royaltyModule.onLicenseMinting(address(1), address(royaltyPolicyLAP), uint32(10 * 10 ** 6), "");
IpRoyaltyVault ipRoyaltyVault = IpRoyaltyVault(royaltyModule.ipRoyaltyVaults(address(1)));
vm.stopPrank();

vm.expectRevert(Errors.IpRoyaltyVault__NoClaimableTokens.selector);
ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC));
ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(new uint256[](0), address(USDC), u.admin);
}

function test_IpRoyaltyVault_claimRevenueBySnapshotBatch() public {
function test_IpRoyaltyVault_claimRevenueOnBehalfBySnapshotBatch() public {
uint256 royaltyAmount = 100000 * 10 ** 6;
USDC.mint(address(2), royaltyAmount); // 100k USDC

Expand Down Expand Up @@ -386,7 +408,7 @@ contract TestIpRoyaltyVault is BaseTest {
emit IIpRoyaltyVault.RevenueTokenClaimed(address(2), address(USDC), expectedAmount);

vm.startPrank(address(2));
ipRoyaltyVault.claimRevenueBySnapshotBatch(snapshots, address(USDC));
ipRoyaltyVault.claimRevenueOnBehalfBySnapshotBatch(snapshots, address(USDC), address(2));

assertEq(USDC.balanceOf(address(2)) - userUsdcBalanceBefore, expectedAmount);
assertEq(contractUsdcBalanceBefore - USDC.balanceOf(address(ipRoyaltyVault)), expectedAmount);
Expand Down
3 changes: 3 additions & 0 deletions test/foundry/modules/royalty/RoyaltyModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ contract TestRoyaltyModule is BaseTest {

assertEq(ipIdRtBalAfter, royaltyModule.maxPercent());
assertFalse(royaltyModule.ipRoyaltyVaults(licensor) == address(0));
assertEq(royaltyModule.isIpRoyaltyVault(newVault), true);
}

function test_RoyaltyModule_onLicenseMinting_NewVaultGroup() public {
Expand All @@ -459,6 +460,7 @@ contract TestRoyaltyModule is BaseTest {

assertEq(groupPoolRtBalAfter, royaltyModule.maxPercent());
assertFalse(royaltyModule.ipRoyaltyVaults(groupId) == address(0));
assertEq(royaltyModule.isIpRoyaltyVault(newVault), true);
}

function test_RoyaltyModule_onLicenseMinting_ExistingVault() public {
Expand Down Expand Up @@ -748,6 +750,7 @@ contract TestRoyaltyModule is BaseTest {
uint256 ipId80IpIdRtBalAfter = IERC20(ipRoyaltyVault80).balanceOf(address(80));

assertFalse(royaltyModule.ipRoyaltyVaults(address(80)) == address(0));
assertEq(royaltyModule.isIpRoyaltyVault(royaltyModule.ipRoyaltyVaults(address(80))), true);
assertEq(ipId80RtLAPBalAfter, 0);
assertEq(ipId80RtLRPBalAfter, 0);
assertEq(ipId80RtLRPParentVaultBalAfter, 0);
Expand Down

0 comments on commit e643f5f

Please sign in to comment.