From d4370d08733763b52a857fd04b034dece28564f3 Mon Sep 17 00:00:00 2001 From: Philippe Gonday Date: Tue, 17 Dec 2024 10:12:01 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB()=20Gas=20optim=20after=20review?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../modules/InitialLockupPeriodModule.sol | 79 +++++++++++-------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/contracts/compliance/modular/modules/InitialLockupPeriodModule.sol b/contracts/compliance/modular/modules/InitialLockupPeriodModule.sol index 1acf5449..d89fe5ba 100644 --- a/contracts/compliance/modular/modules/InitialLockupPeriodModule.sol +++ b/contracts/compliance/modular/modules/InitialLockupPeriodModule.sol @@ -92,8 +92,13 @@ contract InitialLockupPeriodModule is AbstractModuleUpgradeable { uint256 releaseTimestamp; } + struct LockedDetails { + uint256 totalLocked; + LockedTokens[] lockedTokens; + } + mapping(address compliance => uint256 lockupPeriod) private _lockupPeriods; - mapping(address compliance => mapping(address user => LockedTokens[])) private _lockedTokens; + mapping(address compliance => mapping(address user => LockedDetails)) private _lockedDetails; /// @dev initializes the contract and sets the initial state. function initialize() external initializer { @@ -114,9 +119,9 @@ contract InitialLockupPeriodModule is AbstractModuleUpgradeable { return; } - (uint256 lockedAmount, uint256 unlockedAmount) = _calculateLockedAmount(msg.sender, _from); + LockedDetails storage lockedDetails = _lockedDetails[msg.sender][_from]; uint256 freeAmount = - IToken(IModularCompliance(msg.sender).getTokenBound()).balanceOf(_from) - lockedAmount - unlockedAmount; + IToken(IModularCompliance(msg.sender).getTokenBound()).balanceOf(_from) - lockedDetails.totalLocked; if (_value > freeAmount) { _updateLockedTokens(_from, _value - freeAmount); } @@ -124,7 +129,9 @@ contract InitialLockupPeriodModule is AbstractModuleUpgradeable { /// @inheritdoc IModule function moduleMintAction(address _to, uint256 _value) external override onlyComplianceCall { - _lockedTokens[msg.sender][_to].push( + LockedDetails storage lockedDetails = _lockedDetails[msg.sender][_to]; + lockedDetails.totalLocked += _value; + lockedDetails.lockedTokens.push( LockedTokens({ amount: _value, releaseTimestamp: block.timestamp + _lockupPeriods[msg.sender] @@ -134,15 +141,20 @@ contract InitialLockupPeriodModule is AbstractModuleUpgradeable { /// @inheritdoc IModule function moduleBurnAction(address _from, uint256 _value) external override onlyComplianceCall { - (uint256 lockedAmount, uint256 unlockedAmount) = _calculateLockedAmount(msg.sender, _from); + LockedDetails storage lockedDetails = _lockedDetails[msg.sender][_from]; uint256 previousBalance = IToken(IModularCompliance(msg.sender).getTokenBound()).balanceOf(_from) + _value; + uint256 freeAmount = previousBalance - lockedDetails.totalLocked; + + if (freeAmount < _value) { + // We need to calculate more accurately the free amount, as totalLocked can include now unlocked tokens. + freeAmount = freeAmount + _calculateUnlockedAmount(lockedDetails); + } require( - (previousBalance - lockedAmount) >= _value, - InsufficientBalanceTokensLocked(_from, _value, previousBalance - lockedAmount) + freeAmount >= _value, + InsufficientBalanceTokensLocked(_from, _value, freeAmount) ); - uint256 freeAmount = previousBalance - lockedAmount - unlockedAmount; if (_value > freeAmount) { _updateLockedTokens(_from, _value - freeAmount); } @@ -151,10 +163,15 @@ contract InitialLockupPeriodModule is AbstractModuleUpgradeable { /// @inheritdoc IModule function moduleCheck(address _from, address /*_to*/, uint256 _value, address _compliance) external view override returns (bool) { - (uint256 lockedAmount, ) = _calculateLockedAmount(_compliance, _from); + if (_from == address(0)) { + return true; + } + + LockedDetails storage lockedDetails = _lockedDetails[_compliance][_from]; + uint256 balance = IToken(IModularCompliance(_compliance).getTokenBound()).balanceOf(_from); - return _from == address(0) - || IToken(IModularCompliance(_compliance).getTokenBound()).balanceOf(_from) - lockedAmount >= _value; + return (balance - lockedDetails.totalLocked) >= _value + || (balance - lockedDetails.totalLocked + _calculateUnlockedAmount(lockedDetails)) >= _value; } /// @inheritdoc IModule @@ -176,22 +193,22 @@ contract InitialLockupPeriodModule is AbstractModuleUpgradeable { /// @param _user the address of the user. /// @param _value the amount of tokens to unlock. function _updateLockedTokens(address _user, uint256 _value) internal { - LockedTokens[] storage lockedTokens = _lockedTokens[msg.sender][_user]; - for (uint256 i; _value > 0 && i < lockedTokens.length; ) { - if (lockedTokens[i].releaseTimestamp <= block.timestamp) { - if (_value >= lockedTokens[i].amount) { - _value -= lockedTokens[i].amount; + LockedDetails storage lockedDetails = _lockedDetails[msg.sender][_user]; + for (uint256 i; _value > 0 && i < lockedDetails.lockedTokens.length; ) { + if (lockedDetails.lockedTokens[i].releaseTimestamp <= block.timestamp) { + if (_value >= lockedDetails.lockedTokens[i].amount) { + _value -= lockedDetails.lockedTokens[i].amount; // Remove entry - if (i == lockedTokens.length - 1) { - lockedTokens.pop(); + if (i == lockedDetails.lockedTokens.length - 1) { + lockedDetails.lockedTokens.pop(); break; } else { - lockedTokens[i] = lockedTokens[lockedTokens.length - 1]; - lockedTokens.pop(); + lockedDetails.lockedTokens[i] = lockedDetails.lockedTokens[lockedDetails.lockedTokens.length - 1]; + lockedDetails.lockedTokens.pop(); } } else { - lockedTokens[i].amount -= _value; + lockedDetails.lockedTokens[i].amount -= _value; break; } } @@ -201,20 +218,14 @@ contract InitialLockupPeriodModule is AbstractModuleUpgradeable { } } - /// @dev calculates the locked amount of tokens for a user. - /// @param _compliance the address of the compliance contract. - /// @param _user the address of the user. - /// @return _lockedAmount the locked amount of tokens. + /// @dev calculates the unlocked amount of tokens for a user. + /// @param _details the locked details of the user. /// @return _unlockedAmount the unlocked amount of tokens. - function _calculateLockedAmount(address _compliance, address _user) internal view - returns (uint256 _lockedAmount, uint256 _unlockedAmount) { - uint256 periodsLength = _lockedTokens[_compliance][_user].length; - for (uint256 i; i < periodsLength; i++) { - if (_lockedTokens[_compliance][_user][i].releaseTimestamp > block.timestamp) { - _lockedAmount += _lockedTokens[_compliance][_user][i].amount; - } - else { - _unlockedAmount += _lockedTokens[_compliance][_user][i].amount; + function _calculateUnlockedAmount(LockedDetails storage _details) internal view + returns (uint256 _unlockedAmount) { + for (uint256 i; i < _details.lockedTokens.length; i++) { + if (_details.lockedTokens[i].releaseTimestamp <= block.timestamp) { + _unlockedAmount += _details.lockedTokens[i].amount; } } }