diff --git a/src/MetaMorpho.sol b/src/MetaMorpho.sol index ad0e6c61..3cbb452d 100644 --- a/src/MetaMorpho.sol +++ b/src/MetaMorpho.sol @@ -234,7 +234,7 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph if (newFee != 0 && feeRecipient == address(0)) revert ErrorsLib.ZeroFeeRecipient(); // Accrue fee using the previous fee set before changing it. - _updateLastTotalAssets(_accrueFee()); + _accrueInterest(); // Safe "unchecked" cast because newFee <= MAX_FEE. fee = uint96(newFee); @@ -248,7 +248,7 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph if (newFeeRecipient == address(0) && fee != 0) revert ErrorsLib.ZeroFeeRecipient(); // Accrue fee to the previous fee recipient set before changing it. - _updateLastTotalAssets(_accrueFee()); + _accrueInterest(); feeRecipient = newFeeRecipient; @@ -529,63 +529,49 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph /// @inheritdoc IERC4626 function deposit(uint256 assets, address receiver) public override returns (uint256 shares) { - uint256 newTotalAssets = _accrueFee(); + _accrueInterest(); - // Update `lastTotalAssets` to avoid an inconsistent state in a re-entrant context. - // It is updated again in `_deposit`. - lastTotalAssets = newTotalAssets; - - shares = _convertToSharesWithTotals(assets, totalSupply(), newTotalAssets, Math.Rounding.Floor); + shares = _convertToShares(assets, Math.Rounding.Floor); _deposit(_msgSender(), receiver, assets, shares); } /// @inheritdoc IERC4626 function mint(uint256 shares, address receiver) public override returns (uint256 assets) { - uint256 newTotalAssets = _accrueFee(); + _accrueInterest(); - // Update `lastTotalAssets` to avoid an inconsistent state in a re-entrant context. - // It is updated again in `_deposit`. - lastTotalAssets = newTotalAssets; - - assets = _convertToAssetsWithTotals(shares, totalSupply(), newTotalAssets, Math.Rounding.Ceil); + assets = _convertToAssets(shares, Math.Rounding.Ceil); _deposit(_msgSender(), receiver, assets, shares); } /// @inheritdoc IERC4626 function withdraw(uint256 assets, address receiver, address owner) public override returns (uint256 shares) { - uint256 newTotalAssets = _accrueFee(); + _accrueInterest(); // Do not call expensive `maxWithdraw` and optimistically withdraw assets. - shares = _convertToSharesWithTotals(assets, totalSupply(), newTotalAssets, Math.Rounding.Ceil); - - // `newTotalAssets - assets` may be a little off from `totalAssets()`. - _updateLastTotalAssets(newTotalAssets.zeroFloorSub(assets)); + shares = _convertToShares(assets, Math.Rounding.Ceil); _withdraw(_msgSender(), receiver, owner, assets, shares); } /// @inheritdoc IERC4626 function redeem(uint256 shares, address receiver, address owner) public override returns (uint256 assets) { - uint256 newTotalAssets = _accrueFee(); + _accrueInterest(); // Do not call expensive `maxRedeem` and optimistically redeem shares. - assets = _convertToAssetsWithTotals(shares, totalSupply(), newTotalAssets, Math.Rounding.Floor); - - // `newTotalAssets - assets` may be a little off from `totalAssets()`. - _updateLastTotalAssets(newTotalAssets.zeroFloorSub(assets)); + assets = _convertToAssets(shares, Math.Rounding.Floor); _withdraw(_msgSender(), receiver, owner, assets, shares); } /// @inheritdoc IERC4626 function totalAssets() public view override returns (uint256 assets) { - for (uint256 i; i < withdrawQueue.length; ++i) { - assets += MORPHO.expectedSupplyAssets(_marketParams(withdrawQueue[i]), address(this)); - } + uint256 realTotalAssets = _realTotalAssets(); + + assets = realTotalAssets > lastTotalAssets ? realTotalAssets : lastTotalAssets; } /* ERC4626 (INTERNAL) */ @@ -672,7 +658,6 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph _supplyMorpho(assets); - // `lastTotalAssets + assets` may be a little off from `totalAssets()`. _updateLastTotalAssets(lastTotalAssets + assets); } @@ -686,6 +671,8 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph internal override { + _updateLastTotalAssets(lastTotalAssets - assets); + _withdrawMorpho(assets); super._withdraw(caller, receiver, owner, assets, shares); @@ -880,21 +867,25 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph /// @dev Accrues the fee and mints the fee shares to the fee recipient. /// @return newTotalAssets The vaults total assets after accruing the interest. - function _accrueFee() internal returns (uint256 newTotalAssets) { + function _accrueInterest() internal returns (uint256 newTotalAssets) { uint256 feeShares; (feeShares, newTotalAssets) = _accruedFeeShares(); if (feeShares != 0) _mint(feeRecipient, feeShares); + lastTotalAssets = newTotalAssets; + emit EventsLib.AccrueInterest(newTotalAssets, feeShares); } /// @dev Computes and returns the fee shares (`feeShares`) to mint and the new vault's total assets /// (`newTotalAssets`). function _accruedFeeShares() internal view returns (uint256 feeShares, uint256 newTotalAssets) { - newTotalAssets = totalAssets(); + uint256 realTotalAssets = _realTotalAssets(); + + newTotalAssets = realTotalAssets > lastTotalAssets ? realTotalAssets : lastTotalAssets; - uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets); + uint256 totalInterest = newTotalAssets - lastTotalAssets; if (totalInterest != 0 && fee != 0) { // It is acknowledged that `feeAssets` may be rounded down to 0 if `totalInterest * fee < WAD`. uint256 feeAssets = totalInterest.mulDiv(fee, WAD); @@ -904,4 +895,10 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph _convertToSharesWithTotals(feeAssets, totalSupply(), newTotalAssets - feeAssets, Math.Rounding.Floor); } } + + function _realTotalAssets() internal view returns (uint256 realTotalAssets) { + for (uint256 i; i < withdrawQueue.length; ++i) { + realTotalAssets += MORPHO.expectedSupplyAssets(_marketParams(withdrawQueue[i]), address(this)); + } + } } diff --git a/test/forge/ERC4626Test.sol b/test/forge/ERC4626Test.sol index a96efccb..2b4b04d8 100644 --- a/test/forge/ERC4626Test.sol +++ b/test/forge/ERC4626Test.sol @@ -279,7 +279,7 @@ contract ERC4626Test is IntegrationTest, IMorphoFlashLoanCallback { assets = bound(assets, deposited + 1, type(uint256).max / (deposited + 1)); vm.prank(ONBEHALF); - vm.expectRevert(ErrorsLib.NotEnoughLiquidity.selector); + vm.expectRevert(stdError.arithmeticError); vault.withdraw(assets, RECEIVER, ONBEHALF); } @@ -301,7 +301,7 @@ contract ERC4626Test is IntegrationTest, IMorphoFlashLoanCallback { morpho.borrow(allMarkets[0], 1, 0, BORROWER, BORROWER); vm.startPrank(ONBEHALF); - vm.expectRevert(ErrorsLib.NotEnoughLiquidity.selector); + vm.expectRevert(stdError.arithmeticError); vault.withdraw(assets, RECEIVER, ONBEHALF); } diff --git a/test/forge/SharePriceTest.sol b/test/forge/SharePriceTest.sol new file mode 100644 index 00000000..7090d097 --- /dev/null +++ b/test/forge/SharePriceTest.sol @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +import "./helpers/IntegrationTest.sol"; + +contract ModifiedMorpho { + address public owner; + address public feeRecipient; + mapping(bytes32 => mapping(address => Position)) public position; + mapping(bytes32 => Market) public market; + + function writeTotalSupplyAssets(bytes32 id, uint128 newValue) external { + market[id].totalSupplyAssets = newValue; + } +} + +contract SharePriceTest is IntegrationTest { + using stdStorage for StdStorage; + using MorphoBalancesLib for IMorpho; + using MarketParamsLib for MarketParams; + + bytes modifiedCode = _makeModifiedCode(); + bytes normalCode = address(morpho).code; + + function _makeModifiedCode() internal returns (bytes memory) { + return address(new ModifiedMorpho()).code; + } + + function _writeTotalSupplyAssets(bytes32 id, uint128 newValue) internal { + vm.etch(address(morpho), modifiedCode); + ModifiedMorpho(address(morpho)).writeTotalSupplyAssets(id, newValue); + vm.etch(address(morpho), normalCode); + } + + function setUp() public override { + super.setUp(); + + _setCap(allMarkets[0], CAP); + _sortSupplyQueueIdleLast(); + } + + function test_totalAssetsCannotDecrease(uint256 assets) public { + assets = bound(assets, MIN_TEST_ASSETS, MAX_TEST_ASSETS); + + loanToken.setBalance(SUPPLIER, assets); + + vm.prank(SUPPLIER); + vault.deposit(assets, ONBEHALF); + + uint256 totalAssetsBefore = vault.totalAssets(); + _writeTotalSupplyAssets(Id.unwrap(allMarkets[0].id()), 0); + uint256 totalAssetsAfter = vault.totalAssets(); + + assertGe(totalAssetsAfter, totalAssetsBefore, "totalAssets decreased"); + } + + function invariant_totalAssetsCannotDecrease() public { + uint256 totalAssetsBefore = vault.totalAssets(); + _writeTotalSupplyAssets(Id.unwrap(allMarkets[0].id()), 0); + uint256 totalAssetsAfter = vault.totalAssets(); + + assertGe(totalAssetsAfter, totalAssetsBefore, "totalAssets decreased"); + } +}