Skip to content

Commit

Permalink
Merge pull request #60 from morpho-labs/refactor/w-mul-down
Browse files Browse the repository at this point in the history
refactor(math): revert to int256
  • Loading branch information
MathisGD authored Nov 9, 2023
2 parents 15b0f33 + bcc4cf1 commit 9614561
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
32 changes: 18 additions & 14 deletions src/SpeedJumpIrm.sol
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ contract AdaptativeCurveIrm is IIrm {
address public immutable MORPHO;
/// @notice Curve steepness (scaled by WAD).
/// @dev Verified to be greater than 1 at construction.
uint256 public immutable CURVE_STEEPNESS;
int256 public immutable CURVE_STEEPNESS;
/// @notice Adjustment speed (scaled by WAD).
/// @dev The speed is per second, so the rate moves at a speed of ADJUSTMENT_SPEED * err each second (while being
/// continuously compounded). A typical value for the ADJUSTMENT_SPEED would be 10 ethers / 365 days.
uint256 public immutable ADJUSTMENT_SPEED;
/// @dev Verified to be non-negative at construction.
int256 public immutable ADJUSTMENT_SPEED;
/// @notice Target utilization (scaled by WAD).
/// @dev Verified to be strictly between 0 and 1 at construction.
uint256 public immutable TARGET_UTILIZATION;
int256 public immutable TARGET_UTILIZATION;
/// @notice Initial rate at target (scaled by WAD).
uint256 public immutable INITIAL_RATE_AT_TARGET;

Expand Down Expand Up @@ -78,9 +79,12 @@ contract AdaptativeCurveIrm is IIrm {
require(initialRateAtTarget <= MAX_RATE_AT_TARGET, ErrorsLib.INPUT_TOO_LARGE);

MORPHO = morpho;
CURVE_STEEPNESS = curveSteepness;
ADJUSTMENT_SPEED = adjustmentSpeed;
TARGET_UTILIZATION = targetUtilization;
// Safe "unchecked" cast.
CURVE_STEEPNESS = int256(curveSteepness);
// Safe "unchecked" cast.
ADJUSTMENT_SPEED = int256(adjustmentSpeed);
// Safe "unchecked" cast.
TARGET_UTILIZATION = int256(targetUtilization);
INITIAL_RATE_AT_TARGET = initialRateAtTarget;
}

Expand Down Expand Up @@ -110,20 +114,19 @@ contract AdaptativeCurveIrm is IIrm {
/// @dev Returns avgBorrowRate and newRateAtTarget.
/// @dev Assumes that the inputs `marketParams` and `id` match.
function _borrowRate(Id id, Market memory market) private view returns (uint256, uint256) {
uint256 utilization =
market.totalSupplyAssets > 0 ? market.totalBorrowAssets.wDivDown(market.totalSupplyAssets) : 0;
// Safe "unchecked" cast because the utilization is smaller than 1 (scaled by WAD).
int256 utilization =
int256(market.totalSupplyAssets > 0 ? market.totalBorrowAssets.wDivDown(market.totalSupplyAssets) : 0);

uint256 errNormFactor = utilization > TARGET_UTILIZATION ? WAD - TARGET_UTILIZATION : TARGET_UTILIZATION;
// Safe "unchecked" int256 casts because utilization <= WAD, TARGET_UTILIZATION < WAD and errNormFactor <= WAD.
int256 err = (int256(utilization) - int256(TARGET_UTILIZATION)).wDivDown(int256(errNormFactor));
int256 errNormFactor = utilization > TARGET_UTILIZATION ? WAD_INT - TARGET_UTILIZATION : TARGET_UTILIZATION;
int256 err = (utilization - TARGET_UTILIZATION).wDivDown(errNormFactor);

uint256 startRateAtTarget = rateAtTarget[id];

// First interaction.
if (startRateAtTarget == 0) {
return (_curve(INITIAL_RATE_AT_TARGET, err), INITIAL_RATE_AT_TARGET);
} else {
// Safe "unchecked" cast because ADJUSTMENT_SPEED <= type(int256).max.
// Note that the speed is assumed constant between two interactions, but in theory it increases because of
// interests. So the rate will be slightly underestimated.
int256 speed = ADJUSTMENT_SPEED.wMulDown(err);
Expand Down Expand Up @@ -164,8 +167,9 @@ contract AdaptativeCurveIrm is IIrm {
/// r = ((1-1/C)*err + 1) * rateAtTarget if err < 0
/// ((C-1)*err + 1) * rateAtTarget else.
function _curve(uint256 _rateAtTarget, int256 err) private view returns (uint256) {
uint256 steeringCoeff =
(err < 0 ? WAD - WAD.wDivDown(CURVE_STEEPNESS) : CURVE_STEEPNESS - WAD).wMulDown(_rateAtTarget);
// Safe "unchecked" cast of _rateAtTarget because _rateAtTarget <= MAX_RATE_AT_TARGET.
int256 steeringCoeff = (err < 0 ? WAD_INT - WAD_INT.wDivDown(CURVE_STEEPNESS) : CURVE_STEEPNESS - WAD_INT)
.wMulDown(int256(_rateAtTarget));
// Safe "unchecked" cast of _rateAtTarget because _rateAtTarget <= MAX_RATE_AT_TARGET.
// Safe "unchecked" cast of the result because r >= 0.
return uint256(steeringCoeff.wMulDown(err) + int256(_rateAtTarget));
Expand Down
5 changes: 2 additions & 3 deletions src/libraries/MathLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ library MathLib {
}
}

function wMulDown(uint256 a, int256 b) internal pure returns (int256) {
require(a <= uint256(type(int256).max));
return int256(a) * b / WAD_INT;
function wMulDown(int256 a, int256 b) internal pure returns (int256) {
return a * b / WAD_INT;
}

function wDivDown(int256 a, int256 b) internal pure returns (int256) {
Expand Down
35 changes: 19 additions & 16 deletions test/SpeedJumpIrmTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ contract AdaptativeCurveIrmTest is Test {

event BorrowRateUpdate(Id indexed id, uint256 avgBorrowRate, uint256 rateAtTarget);

uint256 internal constant CURVE_STEEPNESS = 4 ether;
uint256 internal constant ADJUSTMENT_SPEED = 50 ether / uint256(365 days);
uint256 internal constant TARGET_UTILIZATION = 0.9 ether;
int256 internal constant CURVE_STEEPNESS = 4 ether;
int256 internal constant ADJUSTMENT_SPEED = 50 ether / int256(365 days);
int256 internal constant TARGET_UTILIZATION = 0.9 ether;
uint256 internal constant INITIAL_RATE_AT_TARGET = 0.01 ether / uint256(365 days);

AdaptativeCurveIrm internal irm;
MarketParams internal marketParams = MarketParams(address(0), address(0), address(0), address(0), 0);

function setUp() public {
irm =
new AdaptativeCurveIrm(address(this), CURVE_STEEPNESS, ADJUSTMENT_SPEED, TARGET_UTILIZATION, INITIAL_RATE_AT_TARGET);
new AdaptativeCurveIrm(address(this), uint256(CURVE_STEEPNESS), uint256(ADJUSTMENT_SPEED), uint256(TARGET_UTILIZATION), INITIAL_RATE_AT_TARGET);
vm.warp(90 days);
}

Expand Down Expand Up @@ -176,14 +176,18 @@ contract AdaptativeCurveIrmTest is Test {
Market memory market;
market.totalBorrowAssets = 9 ether;
market.totalSupplyAssets = 10 ether;
assertGt(irm.borrowRate(marketParams, market), irm.MIN_RATE_AT_TARGET().wDivDown(CURVE_STEEPNESS));
assertGt(
irm.borrowRate(marketParams, market), uint256(int256(irm.MIN_RATE_AT_TARGET()).wDivDown(CURVE_STEEPNESS))
);
}

function invariantMaxRateAtTarget() public {
Market memory market;
market.totalBorrowAssets = 9 ether;
market.totalSupplyAssets = 10 ether;
assertLt(irm.borrowRate(marketParams, market), irm.MAX_RATE_AT_TARGET().wMulDown(CURVE_STEEPNESS));
assertLt(
irm.borrowRate(marketParams, market), uint256(int256(irm.MAX_RATE_AT_TARGET()).wMulDown(CURVE_STEEPNESS))
);
}

function _expectedRateAtTarget(Id id, Market memory market) internal view returns (uint256) {
Expand Down Expand Up @@ -221,25 +225,24 @@ contract AdaptativeCurveIrmTest is Test {
// Safe "unchecked" cast because err >= -1 (in WAD).
if (err < 0) {
return uint256(
(WAD - WAD.wDivDown(CURVE_STEEPNESS)).wMulDown(rateAtTarget).wMulDown(err) + int256(rateAtTarget)
(WAD_INT - WAD_INT.wDivDown(CURVE_STEEPNESS)).wMulDown(int256(rateAtTarget)).wMulDown(err)
+ int256(rateAtTarget)
);
} else {
return uint256((CURVE_STEEPNESS - WAD).wMulDown(rateAtTarget).wMulDown(err) + int256(rateAtTarget));
return
uint256((CURVE_STEEPNESS - WAD_INT).wMulDown(int256(rateAtTarget)).wMulDown(err) + int256(rateAtTarget));
}
}

function _err(Market memory market) internal pure returns (int256) {
function _err(Market memory market) internal pure returns (int256 err) {
if (market.totalSupplyAssets == 0) return -1 ether;
uint256 utilization = market.totalBorrowAssets.wDivDown(market.totalSupplyAssets);

int256 err;
int256 utilization = int256(market.totalBorrowAssets.wDivDown(market.totalSupplyAssets));

if (utilization > TARGET_UTILIZATION) {
// Safe "unchecked" cast because |err| <= WAD.
err = int256((utilization - TARGET_UTILIZATION).wDivDown(WAD - TARGET_UTILIZATION));
err = (utilization - TARGET_UTILIZATION).wDivDown(WAD_INT - TARGET_UTILIZATION);
} else {
// Safe "unchecked" casts because utilization <= WAD and TARGET_UTILIZATION <= WAD.
err = (int256(utilization) - int256(TARGET_UTILIZATION)).wDivDown(int256(TARGET_UTILIZATION));
err = (utilization - TARGET_UTILIZATION).wDivDown(TARGET_UTILIZATION);
}
return err;
}
}

0 comments on commit 9614561

Please sign in to comment.