Skip to content

Commit

Permalink
fix token pause (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
Spablob authored Jun 26, 2024
1 parent 3b6f87f commit ca6e753
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 40 deletions.
10 changes: 10 additions & 0 deletions contracts/interfaces/modules/royalty/policies/IIpRoyaltyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ interface IIpRoyaltyVault {
/// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to
function collectRoyaltyTokens(address ancestorIpId) external;

/// @notice Collect the accrued tokens (if any)
/// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to
/// @param tokens The list of revenue tokens to claim
function collectAccruedTokens(address ancestorIpId, address[] calldata tokens) external;

/// @notice The ip id to whom this royalty vault belongs to
/// @return The ip id address
function ipId() external view returns (address);
Expand All @@ -79,6 +84,11 @@ interface IIpRoyaltyVault {
/// @param token The address of the revenue token
function ancestorsVaultAmount(address token) external view returns (uint256);

/// @notice The amount of revenue tokens that can be collected by the ancestor
/// @param ancestorIpId The ancestor ipId address
/// @param token The address of the revenue token
function collectableAmount(address ancestorIpId, address token) external view returns (uint256);

/// @notice Indicates whether the ancestor has collected the royalty tokens
/// @param ancestorIpId The ancestor ipId address
function isCollectedByAncestor(address ancestorIpId) external view returns (bool);
Expand Down
63 changes: 39 additions & 24 deletions contracts/modules/royalty/policies/IpRoyaltyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
/// @param unclaimedRoyaltyTokens The amount of unclaimed royalty tokens
/// @param lastSnapshotTimestamp The last snapshotted timestamp
/// @param ancestorsVaultAmount The amount of revenue token in the ancestors vault
/// @param collectableAmount The amount of revenue tokens that can be collected by the ancestor
/// @param isCollectedByAncestor Indicates whether the ancestor has collected the royalty tokens
/// @param claimVaultAmount Amount of revenue token in the claim vault
/// @param claimableAtSnapshot Amount of revenue token claimable at a given snapshot
Expand All @@ -40,6 +41,7 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
uint32 unclaimedRoyaltyTokens;
uint40 lastSnapshotTimestamp;
mapping(address token => uint256 amount) ancestorsVaultAmount;
mapping(address ancestorIpId => mapping(address token => uint256 amount)) collectableAmount;
mapping(address ancestorIpId => bool) isCollectedByAncestor;
mapping(address token => uint256 amount) claimVaultAmount;
mapping(uint256 snapshotId => mapping(address token => uint256 amount)) claimableAtSnapshot;
Expand Down Expand Up @@ -234,15 +236,43 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
// transfer royalty tokens to the ancestor
IERC20Upgradeable(address(this)).safeTransfer(ancestorIpId, ancestorsRoyalties[index]);

// collect accrued revenue tokens (if any)
_collectAccruedTokens(ancestorsRoyalties[index], ancestorIpId);
// save the amount of revenue tokens that are collectable by the ancestor
address[] memory tokenList = $.tokens.values();
uint256 unclaimedTokens = $.unclaimedRoyaltyTokens;
for (uint256 i = 0; i < tokenList.length; ++i) {
// the only case in which unclaimedRoyaltyTokens can be 0 is when the vault is empty and everyone claimed
// in which case the call will revert upstream with IpRoyaltyVault__AlreadyClaimed error
uint256 collectAmount = ($.ancestorsVaultAmount[tokenList[i]] * ancestorsRoyalties[index]) /
unclaimedTokens;
if (collectAmount == 0) continue;

$.collectableAmount[ancestorIpId][tokenList[i]] += collectAmount;
}

$.isCollectedByAncestor[ancestorIpId] = true;
$.unclaimedRoyaltyTokens -= ancestorsRoyalties[index];

emit RoyaltyTokensCollected(ancestorIpId, ancestorsRoyalties[index]);
}

/// @notice Collect the accrued tokens (if any)
/// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to
/// @param tokens The list of revenue tokens to claim
function collectAccruedTokens(address ancestorIpId, address[] calldata tokens) external nonReentrant whenNotPaused {
IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage();

if (DISPUTE_MODULE.isIpTagged($.ipId)) revert Errors.IpRoyaltyVault__IpTagged();

for (uint256 i = 0; i < tokens.length; ++i) {
uint256 collectAmount = $.collectableAmount[ancestorIpId][tokens[i]];
$.ancestorsVaultAmount[tokens[i]] -= collectAmount;
$.collectableAmount[ancestorIpId][tokens[i]] -= collectAmount;
IERC20Upgradeable(tokens[i]).safeTransfer(ancestorIpId, collectAmount);

emit RevenueTokenClaimed(ancestorIpId, tokens[i], collectAmount);
}
}

/// @notice A function to calculate the amount of revenue token claimable by a token holder at certain snapshot
/// @param account The address of the token holder
/// @param snapshotId The snapshot id
Expand All @@ -260,28 +290,6 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
return $.isClaimedAtSnapshot[snapshotId][account][token] ? 0 : (balance * claimableToken) / totalSupply;
}

/// @dev Collect the accrued tokens (if any)
/// @param royaltyTokensToClaim The amount of royalty tokens being claimed by the ancestor
/// @param ancestorIpId The ip id of the ancestor to whom the royalty tokens belong to
function _collectAccruedTokens(uint256 royaltyTokensToClaim, address ancestorIpId) internal {
IpRoyaltyVaultStorage storage $ = _getIpRoyaltyVaultStorage();

address[] memory tokenList = $.tokens.values();

for (uint256 i = 0; i < tokenList.length; ++i) {
// the only case in which unclaimedRoyaltyTokens can be 0 is when the vault is empty and everyone claimed
// in which case the call will revert upstream with IpRoyaltyVault__AlreadyClaimed error
uint256 collectAmount = ($.ancestorsVaultAmount[tokenList[i]] * royaltyTokensToClaim) /
$.unclaimedRoyaltyTokens;
if (collectAmount == 0) continue;

$.ancestorsVaultAmount[tokenList[i]] -= collectAmount;
IERC20Upgradeable(tokenList[i]).safeTransfer(ancestorIpId, collectAmount);

emit RevenueTokenClaimed(ancestorIpId, tokenList[i], collectAmount);
}
}

/// @notice The ip id to whom this royalty vault belongs to
/// @return The ip id address
function ipId() external view returns (address) {
Expand All @@ -304,6 +312,13 @@ contract IpRoyaltyVault is IIpRoyaltyVault, ERC20SnapshotUpgradeable, Reentrancy
return _getIpRoyaltyVaultStorage().ancestorsVaultAmount[token];
}

/// @notice The amount of revenue tokens that can be collected by the ancestor
/// @param ancestorIpId The ancestor ipId address
/// @param token The address of the revenue token
function collectableAmount(address ancestorIpId, address token) external view returns (uint256) {
return _getIpRoyaltyVaultStorage().collectableAmount[ancestorIpId][token];
}

/// @notice Indicates whether the ancestor has collected the royalty tokens
/// @param ancestorIpId The ancestor ipId address
function isCollectedByAncestor(address ancestorIpId) external view returns (bool) {
Expand Down
87 changes: 71 additions & 16 deletions test/foundry/modules/royalty/IpRoyaltyVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ contract TestIpRoyaltyVault is BaseTest {
);
}

function test_IpRoyaltyVault_ClaimRevenueByTokenBatch_revert_Paused() public {
vm.stopPrank();
vm.prank(u.admin);
royaltyPolicyLAP.pause();

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
ipRoyaltyVault.claimRevenueByTokenBatch(1, new address[](0));
}

function test_IpRoyaltyVault_ClaimRevenueByTokenBatch() public {
// payment is made to vault
uint256 royaltyAmount = 100000 * 10 ** 6;
Expand Down Expand Up @@ -217,6 +226,15 @@ contract TestIpRoyaltyVault is BaseTest {
assertEq(ipRoyaltyVault.isClaimedAtSnapshot(1, address(2), address(LINK)), true);
}

function test_IpRoyaltyVault_ClaimRevenueBySnapshotBatch_revert_Paused() public {
vm.stopPrank();
vm.prank(u.admin);
royaltyPolicyLAP.pause();

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC));
}

function test_IpRoyaltyVault_ClaimRevenueBySnapshotBatch() public {
uint256 royaltyAmount = 100000 * 10 ** 6;
USDC.mint(address(3), royaltyAmount); // 100k USDC
Expand Down Expand Up @@ -268,7 +286,7 @@ contract TestIpRoyaltyVault is BaseTest {
ipRoyaltyVault.snapshot();
}

function test_IpRoyaltyVault_Snapshot_revert_paused() public {
function test_IpRoyaltyVault_Snapshot_revert_Paused() public {
// payment is made to vault
vm.stopPrank();
vm.prank(u.admin);
Expand Down Expand Up @@ -348,11 +366,17 @@ contract TestIpRoyaltyVault is BaseTest {
ipRoyaltyVault.claimRevenueByTokenBatch(1, tokens);

ipRoyaltyVault.collectRoyaltyTokens(address(5));
ipRoyaltyVault.collectAccruedTokens(address(5), tokens);
ipRoyaltyVault.collectRoyaltyTokens(address(11));
ipRoyaltyVault.collectAccruedTokens(address(11), tokens);
ipRoyaltyVault.collectRoyaltyTokens(address(12));
ipRoyaltyVault.collectAccruedTokens(address(12), tokens);
ipRoyaltyVault.collectRoyaltyTokens(address(6));
ipRoyaltyVault.collectAccruedTokens(address(6), tokens);
ipRoyaltyVault.collectRoyaltyTokens(address(13));
ipRoyaltyVault.collectAccruedTokens(address(13), tokens);
ipRoyaltyVault.collectRoyaltyTokens(address(14));
ipRoyaltyVault.collectAccruedTokens(address(14), tokens);

// take snapshot
vm.warp(block.timestamp + 7 days + 1);
Expand All @@ -374,6 +398,15 @@ contract TestIpRoyaltyVault is BaseTest {
ipRoyaltyVault.collectRoyaltyTokens(address(0));
}

function test_IpRoyaltyVault_CollectRoyaltyTokens_revert_Paused() public {
vm.stopPrank();
vm.prank(u.admin);
royaltyPolicyLAP.pause();

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
ipRoyaltyVault.collectRoyaltyTokens(address(5));
}

function test_IpRoyaltyVault_CollectRoyaltyTokens() public {
uint256 parentRoyalty = 5 * 10 ** 5;
uint256 royaltyAmount = 100000 * 10 ** 6;
Expand Down Expand Up @@ -404,39 +437,61 @@ contract TestIpRoyaltyVault is BaseTest {

ipRoyaltyVault.collectRoyaltyTokens(address(5));

assertEq(USDC.balanceOf(address(5)) - userUsdcBalanceBefore, accruedCollectableRevenue);
assertEq(contractUsdcBalanceBefore - USDC.balanceOf(address(ipRoyaltyVault)), accruedCollectableRevenue);
assertEq(ipRoyaltyVault.isCollectedByAncestor(address(5)), true);
assertEq(
contractRTBalBefore - IERC20(address(ipRoyaltyVault)).balanceOf(address(ipRoyaltyVault)),
parentRoyalty
);
assertEq(IERC20(address(ipRoyaltyVault)).balanceOf(address(5)) - userRTBalBefore, parentRoyalty);
assertEq(unclaimedRoyaltyTokensBefore - ipRoyaltyVault.unclaimedRoyaltyTokens(), parentRoyalty);
assertEq(
ancestorsVaultAmountBefore - ipRoyaltyVault.ancestorsVaultAmount(address(USDC)),
accruedCollectableRevenue
);
}

function test_IpRoyaltyVault_claimRevenue_revert_paused() public {
function test_IpRoyaltyVault_CollectAccruedTokens_revert_Paused() public {
vm.stopPrank();
vm.prank(u.admin);
royaltyPolicyLAP.pause();

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
ipRoyaltyVault.claimRevenueBySnapshotBatch(new uint256[](0), address(USDC));

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
ipRoyaltyVault.claimRevenueByTokenBatch(1, new address[](0));
ipRoyaltyVault.collectAccruedTokens(address(5), new address[](0));
}

function test_IpRoyaltyVault_collectRoyaltyTokens_revert_paused() public {
function test_IpRoyaltyVault_CollectAccruedTokens() public {
uint256 parentRoyalty = 5 * 10 ** 5;
uint256 royaltyAmount = 100000 * 10 ** 6;
uint256 accruedCollectableRevenue = (royaltyAmount * 5 * 10 ** 5) / royaltyPolicyLAP.TOTAL_RT_SUPPLY();

// payment is made to vault
USDC.mint(address(3), royaltyAmount); // 100k USDC
vm.startPrank(address(3));
USDC.approve(address(royaltyPolicyLAP), royaltyAmount);
royaltyModule.payRoyaltyOnBehalf(address(2), address(3), address(USDC), royaltyAmount);
vm.stopPrank();
vm.prank(u.admin);
royaltyPolicyLAP.pause();

vm.expectRevert(Errors.IpRoyaltyVault__EnforcedPause.selector);
// take snapshot
vm.warp(block.timestamp + 7 days + 1);
ipRoyaltyVault.snapshot();

// collect royalty tokens
ipRoyaltyVault.collectRoyaltyTokens(address(5));

address[] memory tokens = new address[](1);
tokens[0] = address(USDC);

uint256 userUsdcBalanceBefore = USDC.balanceOf(address(5));
uint256 contractUsdcBalanceBefore = USDC.balanceOf(address(ipRoyaltyVault));
uint256 ancestorsVaultAmountBefore = ipRoyaltyVault.ancestorsVaultAmount(address(USDC));
uint256 collectableAmountBefore = ipRoyaltyVault.collectableAmount(address(5), address(USDC));

ipRoyaltyVault.collectAccruedTokens(address(5), tokens);

uint256 userUsdcBalanceAfter = USDC.balanceOf(address(5));
uint256 contractUsdcBalanceAfter = USDC.balanceOf(address(ipRoyaltyVault));
uint256 ancestorsVaultAmountAfter = ipRoyaltyVault.ancestorsVaultAmount(address(USDC));
uint256 collectableAmountAfter = ipRoyaltyVault.collectableAmount(address(5), address(USDC));

assertEq(userUsdcBalanceAfter - userUsdcBalanceBefore, accruedCollectableRevenue);
assertEq(contractUsdcBalanceBefore - contractUsdcBalanceAfter, accruedCollectableRevenue);
assertEq(ancestorsVaultAmountBefore - ancestorsVaultAmountAfter, accruedCollectableRevenue);
assertEq(collectableAmountAfter, 0);
}
}

0 comments on commit ca6e753

Please sign in to comment.