diff --git a/contracts/SupplyVault.sol b/contracts/SupplyVault.sol index 0e6e2c74..a157dfee 100644 --- a/contracts/SupplyVault.sol +++ b/contracts/SupplyVault.sol @@ -56,7 +56,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { uint256 public timelock; /// @dev Stores the total assets owned by this vault when the fee was last accrued. - uint256 lastTotalAssets; + uint256 public lastTotalAssets; ConfigSet private _config; diff --git a/test/forge/BaseTest.sol b/test/forge/BaseTest.sol index f4a71433..382436cc 100644 --- a/test/forge/BaseTest.sol +++ b/test/forge/BaseTest.sol @@ -9,6 +9,8 @@ import {UtilsLib} from "@morpho-blue/libraries/UtilsLib.sol"; import {ERC20Mock as ERC20} from "contracts/mocks/ERC20Mock.sol"; import {OracleMock as Oracle} from "contracts/mocks/OracleMock.sol"; +import {WAD} from "@morpho-blue/libraries/MathLib.sol"; +import {Math} from "@openzeppelin/contracts/token/ERC20/extensions/ERC4626.sol"; import {SupplyVault, IERC20, ErrorsLib, Pending, MarketAllocation} from "contracts/SupplyVault.sol"; import {Morpho, MarketParamsLib, MarketParams, SharesMathLib, Id} from "@morpho-blue/Morpho.sol"; diff --git a/test/forge/FeeTest.sol b/test/forge/FeeTest.sol index f86d4c6d..9c882ec6 100644 --- a/test/forge/FeeTest.sol +++ b/test/forge/FeeTest.sol @@ -4,9 +4,10 @@ pragma solidity ^0.8.0; import "./BaseTest.sol"; contract FeeTest is BaseTest { + using Math for uint256; using MarketParamsLib for MarketParams; - uint256 internal constant FEE = 0.1 ether; // 10% + uint256 constant internal FEE = 0.1 ether; // 10% function setUp() public override { super.setUp(); @@ -24,21 +25,64 @@ contract FeeTest is BaseTest { vm.stopPrank(); } - function testShouldMintSharesToFeeRecipient(uint256 amount) public { + function testLastTotalAssets(uint256 amount) public { + amount = bound(amount, MIN_TEST_AMOUNT, MAX_TEST_AMOUNT); + + _setFee(); + + assertEq(vault.lastTotalAssets(), 0); + + vm.prank(SUPPLIER); + vault.deposit(amount, SUPPLIER); + + // Update lastTotalAssets + vm.prank(SUPPLIER); + vault.withdraw(10, SUPPLIER, SUPPLIER); + + assertEq(vault.lastTotalAssets(), amount); + } + + function testAccounting() public { + uint256 amount = MAX_TEST_AMOUNT; + _setFee(); + assertEq(vault.lastTotalAssets(), 0); + + vm.prank(SUPPLIER); + vault.deposit(amount, SUPPLIER); + + // Update lastTotalAssets + vm.prank(SUPPLIER); + vault.withdraw(10, SUPPLIER, SUPPLIER); + + // supplier balance 9999999999999999999999999988 + // fee recipient balance 1111111111111111111111111111 + console2.log("supplier balance", vault.balanceOf(SUPPLIER)); + console2.log("fee recipient balance", vault.balanceOf(FEE_RECIPIENT)); + } + + function testShouldMintSharesToFeeRecipient(uint256 amount) public { amount = bound(amount, MIN_TEST_AMOUNT, MAX_TEST_AMOUNT); + _setFee(); + vm.prank(SUPPLIER); uint256 shares = vault.deposit(amount, SUPPLIER); - _borrow(allMarkets[0], amount / 2); + uint256 lastTotalAssets = vault.lastTotalAssets(); vm.warp(block.timestamp + 365 days); + uint256 totalAssetsAfter = vault.totalAssets(); + uint256 interest = totalAssetsAfter - lastTotalAssets; + uint256 feeAmount = interest.mulDiv(FEE, WAD); + uint256 feeShares = feeAmount.mulDiv(vault.totalSupply() + 1, totalAssetsAfter - feeAmount + 1, Math.Rounding.Down); + vm.prank(SUPPLIER); - vault.redeem(shares / 3, SUPPLIER, SUPPLIER); + vault.redeem(shares / 10, SUPPLIER, SUPPLIER); assertGt(vault.balanceOf(FEE_RECIPIENT), 0, "fee recipient balance is zero"); + assertEq(vault.balanceOf(FEE_RECIPIENT), feeShares, "fee recipient balance"); } }