From ba9f228aa0643b8c8adfb921f408ae0be9c67ccb Mon Sep 17 00:00:00 2001 From: zer0dot Date: Fri, 24 May 2024 00:41:07 +0800 Subject: [PATCH] refactor: make create and create2 standard execution functions --- src/common/BaseLightAccount.sol | 19 ++++++++++---- test/LightAccount.t.sol | 44 ++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/common/BaseLightAccount.sol b/src/common/BaseLightAccount.sol index 7b0c7ca..4db6e4d 100644 --- a/src/common/BaseLightAccount.sol +++ b/src/common/BaseLightAccount.sol @@ -80,18 +80,27 @@ abstract contract BaseLightAccount is BaseAccount, TokenCallbackHandler, UUPSUpg /// @notice Creates a contract, this can only be called by this account. /// @param initCode The initCode to deploy. NOTE: This could be replaced with transient storage in the near future, /// depending on gas savings, if any. - function create(bytes calldata initCode) external payable virtual { + function create(bytes calldata initCode, uint256 value) external payable virtual onlyAuthorized { assembly ("memory-safe") { - // Check that the caller is this account, this compiles to the same as inverting the condition - if iszero(eq(caller(), address())) { - mstore(0, 0x913e98f1) // OnlyCallableBySelf() + // Copy the initCode to memory, then deploy the contract + let len := initCode.length + calldatacopy(0, initCode.offset, len) + let succ := create(value, 0, len) + + // If the creation fails, revert + if iszero(succ) { + mstore(0, 0x7e16b8cd) // CreateFailed() revert(28, 4) } + } + } + function create2(bytes calldata initCode, bytes32 salt, uint256 value) external payable virtual onlyAuthorized { + assembly ("memory-safe") { // Copy the initCode to memory, then deploy the contract let len := initCode.length calldatacopy(0, initCode.offset, len) - let succ := create(callvalue(), 0, len) + let succ := create2(value, 0, len, salt) // If the creation fails, revert if iszero(succ) { diff --git a/test/LightAccount.t.sol b/test/LightAccount.t.sol index 5468183..008e268 100644 --- a/test/LightAccount.t.sol +++ b/test/LightAccount.t.sol @@ -477,12 +477,12 @@ contract LightAccountTest is Test { assertEq(initialized, 1); } - function testRevertCreateContract_IncorrectCaller() public { - vm.expectRevert(BaseLightAccount.OnlyCallableBySelf.selector); - account.create(hex"1234"); + function testRevertCreate_IncorrectCaller() public { + vm.expectRevert(abi.encodeWithSelector(BaseLightAccount.NotAuthorized.selector, address(this))); + account.create(hex"1234", 0); } - function testRevertCreateContract_CreateFailed() public { + function testRevertCreate_CreateFailed() public { vm.prank(eoaAddress); vm.expectRevert(BaseLightAccount.CreateFailed.selector); account.execute( @@ -490,22 +490,38 @@ contract LightAccountTest is Test { 0, abi.encodeCall( account.create, - (hex"01") // Attempt to deploy a contract with a single "ADD" opcode as the whole initcode, which will revert. + (hex"01", 0) // Attempt to deploy a contract with a single "ADD" opcode as the whole initcode, which will revert. ) ); } - function testCreateContract() public { + function testRevertCreate2_IncorrectCaller() public { + vm.expectRevert(abi.encodeWithSelector(BaseLightAccount.NotAuthorized.selector, address(this))); + account.create2(hex"1234", bytes32(0), 0); + } + + function testRevertCreate2_CreateFailed() public { + vm.prank(eoaAddress); + vm.expectRevert(BaseLightAccount.CreateFailed.selector); + account.create2(hex"01", bytes32(0), 0); + } + + function testCreate() public { vm.prank(eoaAddress); address expected = vm.computeCreateAddress(address(account), vm.getNonce(address(account))); - account.execute( - address(account), - 0, - abi.encodeCall( - account.create, (abi.encodePacked(type(LightAccount).creationCode, abi.encode(address(0x4546b)))) - ) - ); - assertEq(address(LightAccount(payable(expected)).entryPoint()), address(0x4546b)); + account.create(abi.encodePacked(type(LightAccount).creationCode, abi.encode(address(entryPoint))), 0); + assertEq(address(LightAccount(payable(expected)).entryPoint()), address(entryPoint)); + } + + function testCreate2() public { + vm.prank(eoaAddress); + bytes memory initCode = abi.encodePacked(type(LightAccount).creationCode, abi.encode(address(entryPoint))); + bytes32 initCodeHash = keccak256(initCode); + bytes32 salt = bytes32(hex"04546b"); + address expected = vm.computeCreate2Address(salt, initCodeHash, address(account)); + + account.create2(abi.encodePacked(type(LightAccount).creationCode, abi.encode(address(entryPoint))), salt, 0); + assertEq(address(LightAccount(payable(expected)).entryPoint()), address(entryPoint)); } function _useContractOwner() internal {