diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 88594d9..c5946f1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,6 +2,9 @@ name: CI on: push: + branches: + - main + - dev pull_request: workflow_dispatch: diff --git a/.solhint.json b/.solhint.json new file mode 100644 index 0000000..dfc6861 --- /dev/null +++ b/.solhint.json @@ -0,0 +1,17 @@ +{ + "extends": "solhint:recommended", + "plugins": [], + "rules": { + "avoid-suicide": "error", + "avoid-sha3": "warn", + "compiler-version": ["error", "^0.8.0"], + "func-visibility": ["warn", { "ignoreConstructors": true }], + "reason-string": ["warn", { "maxLength": 64 }], + "not-rely-on-time": "warn", + "state-visibility": "error", + "max-line-length": ["warn", 122], + "no-console": "off", + "func-name-mixedcase": "off", + "no-inline-assembly": "off" + } +} diff --git a/lib/yieldnest-vault b/lib/yieldnest-vault index 9a86e09..a9764a5 160000 --- a/lib/yieldnest-vault +++ b/lib/yieldnest-vault @@ -1 +1 @@ -Subproject commit 9a86e096e7231611167374a8b6cc1f0a5779732c +Subproject commit a9764a5b7e489f10b92501cd66acb4fa05c347fe diff --git a/src/BaseKeeper.sol b/src/BaseKeeper.sol index 4684f5d..795eb05 100644 --- a/src/BaseKeeper.sol +++ b/src/BaseKeeper.sol @@ -1,14 +1,26 @@ // SPDX-License-Identifier: BSD-3-Clause pragma solidity ^0.8.24; -// import {IVault} from "lib/yieldnest-vault/src/interface/IVault.sol"; +import {Ownable} from "lib/yieldnest-vault/lib/openzeppelin-contracts/contracts/access/Ownable.sol"; +import {IERC20} from "lib/yieldnest-vault/lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; +import {IProvider} from "lib/yieldnest-vault/src/interface/IProvider.sol"; -contract BaseKeeper { +import {Vault} from "lib/yieldnest-vault/src/Vault.sol"; +import {IVault} from "lib/yieldnest-vault/src/interface/IVault.sol"; +import {Math} from "src/libraries/Math.sol"; + +import {console} from "lib/yieldnest-vault/lib/forge-std/src/console.sol"; + +contract BaseKeeper is Ownable { uint256[] public initialRatios; - uint256[] public finalRatios; + uint256[] public targetRatios; // vault[0] is max vault and rest are underlying vaults address[] public vaults; + address public asset; + IVault public maxVault; + + uint256 public tolerance; struct Transfer { uint256 from; @@ -16,46 +28,111 @@ contract BaseKeeper { uint256 amount; } - function setData(uint256[] memory _initialRatios, uint256[] memory _finalRatios, address[] memory _vaults) public { - require(_initialRatios.length > 1, "Array length must be greater than 1"); - require(_initialRatios.length == _finalRatios.length, "Array lengths must match"); - require(_initialRatios.length == _vaults.length, "Array lengths must match"); + constructor() Ownable(msg.sender) {} - initialRatios = _initialRatios; - finalRatios = _finalRatios; + function setTolerance(uint256 _tolerance) public onlyOwner { + tolerance = _tolerance; + } + + function setData(uint256[] memory _targetRatios, address[] memory _vaults) public onlyOwner { + require(_targetRatios.length > 1, "Array length must be greater than 1"); + require(_targetRatios.length == _vaults.length, "Array lengths must match"); + + targetRatios = _targetRatios; + initialRatios = new uint256[](_targetRatios.length); vaults = _vaults; + for (uint256 i = 0; i < _vaults.length; i++) { + require(isVault(_vaults[i]), "Invalid vault"); + } + + maxVault = IVault(payable(_vaults[0])); + asset = maxVault.asset(); + + require(_totalInitialRatios() == _totalTargetRatios(), "Initial and target ratios must match"); } - function totalInitialRatios() public view returns (uint256) { + function _totalInitialRatios() internal returns (uint256) { + uint256 totalAssets = maxVault.totalAssets(); uint256 total = 0; - for (uint256 i = 0; i < initialRatios.length; i++) { + for (uint256 i = 0; i < targetRatios.length; i++) { + initialRatios[i] = calculateCurrentRatio(vaults[i], totalAssets); total += initialRatios[i]; } + require(total == 1e18, "Initial ratios must add up to 100 %"); return total; } - function totalFinalRatios() public view returns (uint256) { + function _totalTargetRatios() internal view returns (uint256) { uint256 total = 0; - for (uint256 i = 0; i < finalRatios.length; i++) { - total += finalRatios[i]; + for (uint256 i = 0; i < targetRatios.length; i++) { + total += targetRatios[i]; } return total; } - function caculateSteps() public view returns (Transfer[] memory) { - uint256 length = initialRatios.length; + function calculateCurrentRatio(address vault, uint256 totalAssets) public view returns (uint256) { + uint256 balance; + if (vault == address(maxVault)) { + balance = IERC20(asset).balanceOf(address(maxVault)); + } else { + balance = IVault(vault).totalAssets(); + } + + uint256 rate = IProvider(maxVault.provider()).getRate(asset); + + // get current percentage in wad: ((wad*X) / Y) / WAD = percentZ (1e18 = 100%) + uint256 adjustedBalance = Math.wmul(balance, rate); + uint256 currentRatio = Math.wdiv(adjustedBalance, totalAssets); + return currentRatio; + } + + function isVault(address target) public view returns (bool) { + try Vault(payable(target)).VAULT_VERSION() returns (string memory version) { + return bytes(version).length > 0; + } catch { + return false; + } + } + + function shouldRebalance() public view returns (bool) { + uint256 totalAssets = maxVault.totalAssets(); + address vault; + + // Check each underlying asset's ratio. + for (uint256 i = 0; i < vaults.length; i++) { + vault = vaults[i]; + uint256 actualRatio = calculateCurrentRatio(vault, totalAssets); // Calculate current ratio + + // Check if the actual ratio deviates from the target ratio + if (!isWithinTolerance(actualRatio, targetRatios[i])) { + return true; // Rebalancing is required + } + } + // All vaults are within target ratios + return false; + } + + function isWithinTolerance(uint256 actualWAD, uint256 targetWAD) public view returns (bool) { + if (actualWAD >= targetWAD) { + // todo: make tolerance a percentage + return (actualWAD - targetWAD) <= tolerance; // Upper bound + } else { + return (targetWAD - actualWAD) <= tolerance; // Lower bound + } + } + + function rebalance() public returns (Transfer[] memory) { + uint256 length = targetRatios.length; require(length > 1, "Array length must be greater than 1"); - require(length == finalRatios.length, "Array lengths must match"); + require(length == targetRatios.length, "Array lengths must match"); require(length == vaults.length, "Array lengths must match"); - uint256 totalInitial = totalInitialRatios(); - uint256 totalFinal = totalFinalRatios(); - require(totalInitial == totalFinal, "Ratios must add up"); + uint256 totalInitial = _totalInitialRatios(); + uint256 totalFinal = _totalTargetRatios(); - // address baseVault = vaults[0]; - // uint256 totalAssets = IVault(baseVault).totalAssets(); + require(totalInitial == totalFinal, "Ratios must add up"); - uint256 totalAssets = 100; // for testing set to 100 + uint256 totalAssets = maxVault.totalAssets(); uint256[] memory initialAmounts = new uint256[](length); for (uint256 i = 0; i < length; i++) { @@ -64,7 +141,7 @@ contract BaseKeeper { uint256[] memory finalAmounts = new uint256[](length); for (uint256 i = 0; i < length; i++) { - finalAmounts[i] = finalRatios[i] * totalAssets / totalFinal; + finalAmounts[i] = targetRatios[i] * totalAssets / totalFinal; } int256[] memory diffs = new int256[](length); diff --git a/src/libraries/Math.sol b/src/libraries/Math.sol new file mode 100644 index 0000000..854477d --- /dev/null +++ b/src/libraries/Math.sol @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.24; + +/// @dev Max uint256 value that a RAD can represent without overflowing +uint256 constant MAX_RAD = type(uint256).max / RAY; +/// @dev Uint256 representation of 1 RAD +uint256 constant RAD = 10 ** 45; +/// @dev Uint256 representation of 1 RAY +uint256 constant RAY = 10 ** 27; +/// @dev Uint256 representation of 1 WAD +uint256 constant WAD = 10 ** 18; +/// @dev Uint256 representation of 1 year in seconds +uint256 constant YEAR = 365 days; +/// @dev Uint256 representation of 1 hour in seconds +uint256 constant HOUR = 3600; + +/** + * @title Math + * @notice This library contains common math functions + */ +library Math { + // --- Errors --- + + /// @dev Throws when trying to cast a uint256 to an int256 that overflows + error IntOverflow(); + + // --- Math --- + + /** + * @notice Calculates the sum of an unsigned integer and a signed integer + * @param _x Unsigned integer + * @param _y Signed integer + * @return _add Unsigned sum of `_x` and `_y` + */ + function add(uint256 _x, int256 _y) internal pure returns (uint256 _add) { + if (_y >= 0) { + return _x + uint256(_y); + } else { + return _x - uint256(-_y); + } + } + + /** + * @notice Calculates the substraction of an unsigned integer and a signed integer + * @param _x Unsigned integer + * @param _y Signed integer + * @return _sub Unsigned substraction of `_x` and `_y` + */ + function sub(uint256 _x, int256 _y) internal pure returns (uint256 _sub) { + if (_y >= 0) { + return _x - uint256(_y); + } else { + return _x + uint256(-_y); + } + } + + /** + * @notice Calculates the substraction of two unsigned integers + * @param _x Unsigned integer + * @param _y Unsigned integer + * @return _sub Signed substraction of `_x` and `_y` + */ + function sub(uint256 _x, uint256 _y) internal pure returns (int256 _sub) { + return toInt(_x) - toInt(_y); + } + + /** + * @notice Calculates the multiplication of an unsigned integer and a signed integer + * @param _x Unsigned integer + * @param _y Signed integer + * @return _mul Signed multiplication of `_x` and `_y` + */ + function mul(uint256 _x, int256 _y) internal pure returns (int256 _mul) { + return toInt(_x) * _y; + } + + /** + * @notice Calculates the multiplication of two unsigned RAY integers + * @param _x Unsigned RAY integer + * @param _y Unsigned RAY integer + * @return _rmul Unsigned multiplication of `_x` and `_y` in RAY precision + */ + function rmul(uint256 _x, uint256 _y) internal pure returns (uint256 _rmul) { + return (_x * _y) / RAY; + } + + /** + * @notice Calculates the multiplication of an unsigned and a signed RAY integers + * @param _x Unsigned RAY integer + * @param _y Signed RAY integer + * @return _rmul Signed multiplication of `_x` and `_y` in RAY precision + */ + function rmul(uint256 _x, int256 _y) internal pure returns (int256 _rmul) { + return (toInt(_x) * _y) / int256(RAY); + } + + /** + * @notice Calculates the multiplication of two unsigned WAD integers + * @param _x Unsigned WAD integer + * @param _y Unsigned WAD integer + * @return _wmul Unsigned multiplication of `_x` and `_y` in WAD precision + */ + function wmul(uint256 _x, uint256 _y) internal pure returns (uint256 _wmul) { + return (_x * _y) / WAD; + } + + /** + * @notice Calculates the multiplication of an unsigned and a signed WAD integers + * @param _x Unsigned WAD integer + * @param _y Signed WAD integer + * @return _wmul Signed multiplication of `_x` and `_y` in WAD precision + */ + function wmul(uint256 _x, int256 _y) internal pure returns (int256 _wmul) { + return (toInt(_x) * _y) / int256(WAD); + } + + /** + * @notice Calculates the multiplication of two signed WAD integers + * @param _x Signed WAD integer + * @param _y Signed WAD integer + * @return _wmul Signed multiplication of `_x` and `_y` in WAD precision + */ + function wmul(int256 _x, int256 _y) internal pure returns (int256 _wmul) { + return (_x * _y) / int256(WAD); + } + + /** + * @notice Calculates the division of two unsigned RAY integers + * @param _x Unsigned RAY integer + * @param _y Unsigned RAY integer + * @return _rdiv Unsigned division of `_x` by `_y` in RAY precision + */ + function rdiv(uint256 _x, uint256 _y) internal pure returns (uint256 _rdiv) { + return (_x * RAY) / _y; + } + + /** + * @notice Calculates the division of two signed RAY integers + * @param _x Signed RAY integer + * @param _y Signed RAY integer + * @return _rdiv Signed division of `_x` by `_y` in RAY precision + */ + function rdiv(int256 _x, int256 _y) internal pure returns (int256 _rdiv) { + return (_x * int256(RAY)) / _y; + } + + /** + * @notice Calculates the division of two unsigned WAD integers + * @param _x Unsigned WAD integer + * @param _y Unsigned WAD integer + * @return _wdiv Unsigned division of `_x` by `_y` in WAD precision + */ + function wdiv(uint256 _x, uint256 _y) internal pure returns (uint256 _wdiv) { + return (_x * WAD) / _y; + } + + /** + * @notice Calculates the power of an unsigned RAY integer to an unsigned integer + * @param _x Unsigned RAY integer + * @param _n Unsigned integer exponent + * @return _rpow Unsigned `_x` to the power of `_n` in RAY precision + */ + function rpow(uint256 _x, uint256 _n) internal pure returns (uint256 _rpow) { + assembly { + switch _x + case 0 { + switch _n + case 0 { _rpow := RAY } + default { _rpow := 0 } + } + default { + switch mod(_n, 2) + case 0 { _rpow := RAY } + default { _rpow := _x } + let half := div(RAY, 2) // for rounding. + for { _n := div(_n, 2) } _n { _n := div(_n, 2) } { + let _xx := mul(_x, _x) + if iszero(eq(div(_xx, _x), _x)) { revert(0, 0) } + let _xxRound := add(_xx, half) + if lt(_xxRound, _xx) { revert(0, 0) } + _x := div(_xxRound, RAY) + if mod(_n, 2) { + let _zx := mul(_rpow, _x) + if and(iszero(iszero(_x)), iszero(eq(div(_zx, _x), _rpow))) { revert(0, 0) } + let _zxRound := add(_zx, half) + if lt(_zxRound, _zx) { revert(0, 0) } + _rpow := div(_zxRound, RAY) + } + } + } + } + } + + /** + * @notice Calculates the maximum of two unsigned integers + * @param _x Unsigned integer + * @param _y Unsigned integer + * @return _max Unsigned maximum of `_x` and `_y` + */ + function max(uint256 _x, uint256 _y) internal pure returns (uint256 _max) { + _max = (_x >= _y) ? _x : _y; + } + + /** + * @notice Calculates the minimum of two unsigned integers + * @param _x Unsigned integer + * @param _y Unsigned integer + * @return _min Unsigned minimum of `_x` and `_y` + */ + function min(uint256 _x, uint256 _y) internal pure returns (uint256 _min) { + _min = (_x <= _y) ? _x : _y; + } + + /** + * @notice Casts an unsigned integer to a signed integer + * @param _x Unsigned integer + * @return _int Signed integer + * @dev Throws if `_x` is too large to fit in an int256 + */ + function toInt(uint256 _x) internal pure returns (int256 _int) { + _int = int256(_x); + if (_int < 0) revert IntOverflow(); + } + + // --- PI Specific Math --- + + /** + * @notice Calculates the Riemann sum of two signed integers + * @param _x Signed integer + * @param _y Signed integer + * @return _riemannSum Riemann sum of `_x` and `_y` + */ + function riemannSum(int256 _x, int256 _y) internal pure returns (int256 _riemannSum) { + return (_x + _y) / 2; + } + + /** + * @notice Calculates the absolute value of a signed integer + * @param _x Signed integer + * @return _z Unsigned absolute value of `_x` + */ + function absolute(int256 _x) internal pure returns (uint256 _z) { + _z = (_x < 0) ? uint256(-_x) : uint256(_x); + } +} diff --git a/test/unit/keeper.t.sol b/test/unit/keeper.t.sol index d044f30..945ff0f 100644 --- a/test/unit/keeper.t.sol +++ b/test/unit/keeper.t.sol @@ -2,220 +2,203 @@ pragma solidity ^0.8.24; import "lib/forge-std/src/Test.sol"; -import "src/BaseKeeper.sol"; -contract BaseKeeperTest is Test { - BaseKeeper baseKeeper; +import {MainnetActors} from "lib/yieldnest-vault/script/Actors.sol"; +import {SetupVault, Vault, WETH9} from "lib/yieldnest-vault/test/unit/helpers/SetupVault.sol"; +import {BaseKeeper} from "src/BaseKeeper.sol"; - function setUp() public { - baseKeeper = new BaseKeeper(); - } +contract BaseKeeperTest is Test, MainnetActors { + BaseKeeper public baseKeeper; + Vault public maxVault; + Vault public underlyingVault1; + Vault public underlyingVault2; - function testSetData() public { - uint256[] memory initialRatios = new uint256[](3); - uint256[] memory finalRatios = new uint256[](3); - address[] memory vaults = new address[](3); + WETH9 public weth; - initialRatios[0] = 50; - initialRatios[1] = 30; - initialRatios[2] = 20; + uint256 public constant INITIAL_BALANCE = 10 ether; - finalRatios[0] = 40; - finalRatios[1] = 40; - finalRatios[2] = 20; + address public alice = address(0xa11ce); - vaults[0] = address(1); - vaults[1] = address(2); - vaults[2] = address(3); - - baseKeeper.setData(initialRatios, finalRatios, vaults); + function setUp() public { + baseKeeper = new BaseKeeper(); - assertEq(baseKeeper.initialRatios(0), 50); - assertEq(baseKeeper.finalRatios(1), 40); - assertEq(baseKeeper.vaults(2), address(3)); - } + SetupVault setupVault = new SetupVault(); + (maxVault, weth) = setupVault.setup(); + (underlyingVault1,) = setupVault.setup(); + (underlyingVault2,) = setupVault.setup(); - function testTotalInitialRatios() public { - uint256[] memory initialRatios = new uint256[](3); - uint256[] memory finalRatios = new uint256[](3); - address[] memory vaults = new address[](3); + vm.startPrank(ASSET_MANAGER); + maxVault.addAsset(address(underlyingVault1), false); + maxVault.addAsset(address(underlyingVault2), false); + vm.stopPrank(); - initialRatios[0] = 50; - initialRatios[1] = 30; - initialRatios[2] = 20; + // Give Alice some tokens + deal(alice, INITIAL_BALANCE); - finalRatios[0] = 40; - finalRatios[1] = 40; - finalRatios[2] = 20; + vm.startPrank(alice); + weth.deposit{value: INITIAL_BALANCE}(); - vaults[0] = address(1); - vaults[1] = address(2); - vaults[2] = address(3); + // deposit assets + weth.approve(address(maxVault), type(uint256).max); + maxVault.depositAsset(address(weth), INITIAL_BALANCE, alice); + vm.stopPrank(); - baseKeeper.setData(initialRatios, finalRatios, vaults); + vm.label(address(weth), "WETH"); + vm.label(address(maxVault), "Max Vault"); + vm.label(address(underlyingVault1), "Underlying Vault 1"); + vm.label(address(underlyingVault2), "Underlying Vault 2"); + vm.label(address(baseKeeper), "BaseKeeper"); + } - assertEq(baseKeeper.totalInitialRatios(), 100); + function test_ViewFunctions() public view { + assertEq(maxVault.totalAssets(), INITIAL_BALANCE); + assertEq(underlyingVault1.totalAssets(), 0); + assertEq(underlyingVault2.totalAssets(), 0); + assertEq(weth.balanceOf(address(maxVault)), INITIAL_BALANCE); } - function testTotalFinalRatios() public { - uint256[] memory initialRatios = new uint256[](3); + function test_SetData() public { uint256[] memory finalRatios = new uint256[](3); address[] memory vaults = new address[](3); - initialRatios[0] = 50; - initialRatios[1] = 30; - initialRatios[2] = 20; + finalRatios[0] = 40 * 1e16; + finalRatios[1] = 40 * 1e16; + finalRatios[2] = 20 * 1e16; - finalRatios[0] = 40; - finalRatios[1] = 40; - finalRatios[2] = 20; + assertEq(finalRatios[0] + finalRatios[1] + finalRatios[2], 1e18); - vaults[0] = address(1); - vaults[1] = address(2); - vaults[2] = address(3); + vaults[0] = address(maxVault); + vaults[1] = address(underlyingVault1); + vaults[2] = address(underlyingVault2); - baseKeeper.setData(initialRatios, finalRatios, vaults); + baseKeeper.setData(finalRatios, vaults); - assertEq(baseKeeper.totalFinalRatios(), 100); + assertEq(baseKeeper.initialRatios(0), 1e18); + assertEq(baseKeeper.targetRatios(1), 40 * 1e16); + assertEq(baseKeeper.vaults(2), address(underlyingVault2)); } - function testCaculateSteps_ExampleOne() public { - uint256[] memory initialRatios = new uint256[](3); - uint256[] memory finalRatios = new uint256[](3); - address[] memory vaults = new address[](3); - - initialRatios[0] = 50; - initialRatios[1] = 30; - initialRatios[2] = 20; - - finalRatios[0] = 40; - finalRatios[1] = 40; - finalRatios[2] = 20; - - vaults[0] = address(1); - vaults[1] = address(2); - vaults[2] = address(3); - - baseKeeper.setData(initialRatios, finalRatios, vaults); + function test_Rebalance_ExampleOne() public { + setData(0.5e18, 0.25e18, 0.25e18); - BaseKeeper.Transfer[] memory steps = baseKeeper.caculateSteps(); + BaseKeeper.Transfer[] memory steps = baseKeeper.rebalance(); - assertEq(steps.length, 1); + assertEq(steps.length, 2, "Expected 2 steps"); // Validate first step - assertEq(steps[0].from, 0); - assertEq(steps[0].to, 1); - assertEq(steps[0].amount, 10); - } - - function testCaculateSteps_ExampleTwo() public { - uint256[] memory initialRatios = new uint256[](3); - uint256[] memory finalRatios = new uint256[](3); - address[] memory vaults = new address[](3); - - initialRatios[0] = 50; - initialRatios[1] = 30; - initialRatios[2] = 20; + assertEq(steps[0].from, 0, "Expected from for step 0"); + assertEq(steps[0].to, 1, "Expected to for step 0"); + assertEq(steps[0].amount, INITIAL_BALANCE / 4, "Expected amount for step 0"); - finalRatios[0] = 20; - finalRatios[1] = 40; - finalRatios[2] = 40; - - vaults[0] = address(1); - vaults[1] = address(2); - vaults[2] = address(3); + // Validate second step + assertEq(steps[1].from, 0, "Expected from for step 1"); + assertEq(steps[1].to, 2, "Expected to for step 1"); + assertEq(steps[1].amount, INITIAL_BALANCE / 4, "Expected amount for step 1"); + } - baseKeeper.setData(initialRatios, finalRatios, vaults); + function test_Rebalance_ExampleTwo() public { + setData(0.5e18, 0.4e18, 0.1e18); - BaseKeeper.Transfer[] memory steps = baseKeeper.caculateSteps(); + BaseKeeper.Transfer[] memory steps = baseKeeper.rebalance(); assertEq(steps.length, 2); // Validate first step - assertEq(steps[0].from, 0); - assertEq(steps[0].to, 1); - assertEq(steps[0].amount, 10); + assertEq(steps[0].from, 0, "Expected from for step 0"); + assertEq(steps[0].to, 1, "Expected to for step 0"); + assertEq(steps[0].amount, 4 * INITIAL_BALANCE / 10, "Expected amount for step 0"); // Validate second step - assertEq(steps[1].from, 0); - assertEq(steps[1].to, 2); - assertEq(steps[1].amount, 20); + assertEq(steps[1].from, 0, "Expected from for step 1"); + assertEq(steps[1].to, 2, "Expected to for step 1"); + assertEq(steps[1].amount, 1 * INITIAL_BALANCE / 10, "Expected amount for step 1"); } - function testCaculateSteps_ExampleThree() public { - uint256[] memory initialRatios = new uint256[](3); - uint256[] memory finalRatios = new uint256[](3); - address[] memory vaults = new address[](3); - - initialRatios[0] = 50; - initialRatios[1] = 50; - initialRatios[2] = 0; + function test_Rebalance_ExampleThree() public { + setData(0.5e18, 0.2e18, 0.3e18); - finalRatios[0] = 33; - finalRatios[1] = 33; - finalRatios[2] = 34; + BaseKeeper.Transfer[] memory steps = baseKeeper.rebalance(); - vaults[0] = address(1); - vaults[1] = address(2); - vaults[2] = address(3); - - baseKeeper.setData(initialRatios, finalRatios, vaults); - - BaseKeeper.Transfer[] memory steps = baseKeeper.caculateSteps(); - - assertEq(steps.length, 2); + assertEq(steps.length, 2, "Expected 2 steps"); // Validate first step - assertEq(steps[0].from, 0); - assertEq(steps[0].to, 2); - assertEq(steps[0].amount, 17); + assertEq(steps[0].from, 0, "Expected from for step 0"); + assertEq(steps[0].to, 1, "Expected to for step 0"); + assertEq(steps[0].amount, 2 * INITIAL_BALANCE / 10, "Expected amount for step 0"); // Validate second step - assertEq(steps[1].from, 1); - assertEq(steps[1].to, 2); - assertEq(steps[1].amount, 17); + assertEq(steps[1].from, 0, "Expected from for step 1"); + assertEq(steps[1].to, 2, "Expected to for step 1"); + assertEq(steps[1].amount, 3 * INITIAL_BALANCE / 10, "Expected amount for step 1"); } - function testSetDataFailsForMismatchedArrayLengths() public { - uint256[] memory initialRatios = new uint256[](2); + function test_SetDataFailsForMismatchedArrayLengths() public { uint256[] memory finalRatios = new uint256[](3); - address[] memory vaults = new address[](3); - - initialRatios[0] = 50; - initialRatios[1] = 50; + address[] memory vaults = new address[](2); finalRatios[0] = 40; finalRatios[1] = 40; finalRatios[2] = 20; - vaults[0] = address(1); + vaults[0] = address(maxVault); vaults[1] = address(2); - vaults[2] = address(3); vm.expectRevert("Array lengths must match"); - baseKeeper.setData(initialRatios, finalRatios, vaults); + baseKeeper.setData(finalRatios, vaults); } - function testCaculateStepsFailsForUnmatchedRatios() public { - uint256[] memory initialRatios = new uint256[](3); + function test_SetDataFailsForInvalidRatios() public { uint256[] memory finalRatios = new uint256[](3); address[] memory vaults = new address[](3); + uint256 ratio1 = 0.5e17; + uint256 ratio2 = 0.25e17; + uint256 ratio3 = 0.25e17; - initialRatios[0] = 50; - initialRatios[1] = 30; - initialRatios[2] = 20; + vaults[0] = address(maxVault); + vaults[1] = address(underlyingVault1); + vaults[2] = address(underlyingVault2); - finalRatios[0] = 50; - finalRatios[1] = 30; - finalRatios[2] = 10; + finalRatios[0] = ratio1; + finalRatios[1] = ratio2; + finalRatios[2] = ratio3; - vaults[0] = address(1); - vaults[1] = address(2); - vaults[2] = address(3); + vm.expectRevert("Initial and target ratios must match"); + baseKeeper.setData(finalRatios, vaults); + } + + function test_CalculateCurrentRatio() public { + setData(0.5e18, 0.25e18, 0.25e18); + + uint256 currentRatio = baseKeeper.calculateCurrentRatio(address(maxVault), maxVault.totalAssets()); + assertEq(currentRatio, 1e18); + } + + function test_ShouldRebalance() public { + setData(0.5e18, 0.25e18, 0.25e18); + bool shouldRebalance = baseKeeper.shouldRebalance(); + assertEq(shouldRebalance, true); + + setData(1e18, 0, 0); + shouldRebalance = baseKeeper.shouldRebalance(); + assertEq(shouldRebalance, false); + } + + function setData(uint256 ratio1, uint256 ratio2, uint256 ratio3) public { + uint256[] memory finalRatios = new uint256[](3); + address[] memory vaults = new address[](3); + + vaults[0] = address(maxVault); + vaults[1] = address(underlyingVault1); + vaults[2] = address(underlyingVault2); + + finalRatios[0] = ratio1; + finalRatios[1] = ratio2; + finalRatios[2] = ratio3; - baseKeeper.setData(initialRatios, finalRatios, vaults); + baseKeeper.setData(finalRatios, vaults); - vm.expectRevert("Ratios must add up"); - baseKeeper.caculateSteps(); + assertEq(baseKeeper.targetRatios(0), ratio1, "Target ratio 0 incorrect"); + assertEq(baseKeeper.targetRatios(1), ratio2, "Target ratio 1 incorrect"); + assertEq(baseKeeper.targetRatios(2), ratio3, "Target ratio 2 incorrect"); } }