diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 5f4ae7de..3e1917aa 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -370,7 +370,7 @@ contract UpgradeableModularAccount is // Load the next per-hook data segment (signatureSegment, signature) = signature.getNextSegment(); - if (uint8(signatureSegment[0]) <= i) { + if (signatureSegment.getIndex() <= i) { revert SignatureSegmentOutOfOrder(); } } else { @@ -433,7 +433,7 @@ contract UpgradeableModularAccount is // Load the next per-hook data segment (authSegment, authorizationData) = authorizationData.getNextSegment(); - if (uint8(authSegment[0]) <= i) { + if (authSegment.getIndex() <= i) { revert SignatureSegmentOutOfOrder(); } } else { diff --git a/src/helpers/SparseCalldataSegmentLib.sol b/src/helpers/SparseCalldataSegmentLib.sol index cf4a3d0d..e19b2244 100644 --- a/src/helpers/SparseCalldataSegmentLib.sol +++ b/src/helpers/SparseCalldataSegmentLib.sol @@ -11,15 +11,18 @@ library SparseCalldataSegmentLib { pure returns (bytes calldata segment, bytes calldata remainder) { - // The first 8 bytes hold the length of the segment. + // The first 8 bytes hold the length of the segment, excluding the index. uint64 length = uint64(bytes8(source[:8])); - // The segment is the next `length` bytes. - // By convention, the first byte of each segmet is the index of the segment, excluding the 1-byte index. - segment = source[8:8 + length + 1]; + // The offset of the remainder of the calldata. + uint256 remainderOffset = 8 + length + 1; + + // The segment is the next `length` + 1 bytes, to account for the index. + // By convention, the first byte of each segment is the index of the segment. + segment = source[8:remainderOffset]; // The remainder is the rest of the calldata. - remainder = source[8 + length:]; + remainder = source[remainderOffset:]; } /// @notice Extracts the index from a segment diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol new file mode 100644 index 00000000..4b60ca00 --- /dev/null +++ b/test/account/PerHookData.t.sol @@ -0,0 +1,313 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; + +import {MockAccessControlHookPlugin} from "../mocks/plugins/MockAccessControlHookPlugin.sol"; +import {Counter} from "../mocks/Counter.sol"; +import {AccountTestBase} from "../utils/AccountTestBase.sol"; + +contract PerHookDataTest is AccountTestBase { + using MessageHashUtils for bytes32; + + MockAccessControlHookPlugin internal accessControlHookPlugin; + + Counter internal counter; + + FunctionReference internal ownerValidation; + + uint256 public constant CALL_GAS_LIMIT = 50000; + uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; + + function setUp() public { + counter = new Counter(); + + accessControlHookPlugin = new MockAccessControlHookPlugin(); + + // Write over `account1` with a new account proxy, with different initialization. + + address accountImplementation = address(factory.accountImplementation()); + + account1 = UpgradeableModularAccount(payable(new ERC1967Proxy(accountImplementation, ""))); + + ownerValidation = FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ); + + FunctionReference accessControlHook = FunctionReferenceLib.pack( + address(accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK) + ); + + FunctionReference[] memory preValidationHooks = new FunctionReference[](1); + preValidationHooks[0] = accessControlHook; + + bytes[] memory preValidationHookData = new bytes[](1); + // Access control is restricted to only the counter + preValidationHookData[0] = abi.encode(counter); + + bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); + + vm.prank(address(entryPoint)); + account1.installValidation( + ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks + ); + + vm.deal(address(account1), 100 ether); + } + + function test_passAccessControl_userOp() public { + assertEq(counter.number(), 0); + + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)}); + + userOp.signature = + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(counter.number(), 1); + } + + function test_failAccessControl_badSigData_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + userOp.signature = + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failAccessControl_noSigData_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + userOp.signature = _encodeSignature(ownerValidation, SHARED_VALIDATION, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failAccessControl_badIndexProvided_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)}); + preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(counter)}); + + userOp.signature = + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + // todo: index out of order failure case with 2 pre hooks + + function test_failAccessControl_badTarget_userOp() public { + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)}); + + userOp.signature = + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Target not allowed") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_passAccessControl_runtime() public { + assertEq(counter.number(), 0); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)}); + + vm.prank(owner1); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "") + ); + + assertEq(counter.number(), 1); + } + + function test_failAccessControl_badSigData_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "") + ); + } + + function test_failAccessControl_noSigData_runtime() public { + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(ownerValidation, SHARED_VALIDATION, "") + ); + } + + function test_failAccessControl_badIndexProvided_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)}); + preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(counter)}); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "") + ); + } + + //todo: index out of order failure case with 2 pre hooks + + function test_failAccessControl_badTarget_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)}); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Target not allowed") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")), + _encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "") + ); + } + + function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) { + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall( + UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + + return (userOp, userOpHash); + } +} diff --git a/test/mocks/plugins/MockAccessControlHookPlugin.sol b/test/mocks/plugins/MockAccessControlHookPlugin.sol new file mode 100644 index 00000000..8235a223 --- /dev/null +++ b/test/mocks/plugins/MockAccessControlHookPlugin.sol @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {PluginMetadata, PluginManifest} from "../../../src/interfaces/IPlugin.sol"; +import {IValidationHook} from "../../../src/interfaces/IValidationHook.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; + +// A pre validaiton hook plugin that uses per-hook data. +// This example enforces that the target of an `execute` call must only be the previously specified address. +// This is just a mock - it does not enforce this over `executeBatch` and other methods of making calls, and should +// not be used in production.. +// solhint-disable custom-errors +contract MockAccessControlHookPlugin is IValidationHook, BasePlugin { + enum FunctionId { + PRE_VALIDATION_HOOK + } + + mapping(address account => address allowedTarget) public allowedTargets; + + function onInstall(bytes calldata data) external override { + address allowedTarget = abi.decode(data, (address)); + allowedTargets[msg.sender] = allowedTarget; + } + + function onUninstall(bytes calldata) external override { + delete allowedTargets[msg.sender]; + } + + function preUserOpValidationHook(uint8 functionId, PackedUserOperation calldata userOp, bytes32) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + if (bytes4(userOp.callData[:4]) == IStandardExecutor.execute.selector) { + address target = abi.decode(userOp.callData[4:36], (address)); + + // Simulate a merkle proof - require that the target address is also provided in the signature + address proof = address(bytes20(userOp.signature)); + require(proof == target, "Proof doesn't match target"); + require(target == allowedTargets[msg.sender], "Target not allowed"); + return 0; + } + } + revert NotImplemented(); + } + + function preRuntimeValidationHook( + uint8 functionId, + address, + uint256, + bytes calldata data, + bytes calldata authorization + ) external view override { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + if (bytes4(data[:4]) == IStandardExecutor.execute.selector) { + address target = abi.decode(data[4:36], (address)); + + // Simulate a merkle proof - require that the target address is also provided in the authorization + // data + address proof = address(bytes20(authorization)); + require(proof == target, "Proof doesn't match target"); + require(target == allowedTargets[msg.sender], "Target not allowed"); + + return; + } + } + + revert NotImplemented(); + } + + function pluginMetadata() external pure override returns (PluginMetadata memory) {} + + function pluginManifest() external pure override returns (PluginManifest memory) {} +} +// solhint-enable custom-errors diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index d10bb0d2..2d48c6b4 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -82,7 +82,7 @@ abstract contract AccountTestBase is OptimizedTest { ) internal pure returns (bytes memory) { bytes memory sig = abi.encodePacked(validationFunction, defaultOrNot); - for (uint256 i = 0; i < preValidationHookData.length;) { + for (uint256 i = 0; i < preValidationHookData.length; ++i) { sig = abi.encodePacked( sig, _packValidationDataWithIndex( @@ -109,11 +109,11 @@ abstract contract AccountTestBase is OptimizedTest { } // helper function to pack validation data with an index, according to the sparse calldata segment spec. - function _packValidationDataWithIndex(uint256 index, bytes memory validationData) + function _packValidationDataWithIndex(uint8 index, bytes memory validationData) internal pure returns (bytes memory) { - return abi.encodePacked(uint64(validationData.length), uint8(index), validationData); + return abi.encodePacked(uint64(validationData.length), index, validationData); } }