Skip to content

Commit

Permalink
add per hook data test case and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Jun 12, 2024
1 parent ae9fe4b commit c937581
Show file tree
Hide file tree
Showing 5 changed files with 406 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ contract UpgradeableModularAccount is
// Load the next per-hook data segment
(signatureSegment, signature) = signature.getNextSegment();

if (uint8(signatureSegment[0]) <= i) {
if (signatureSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
Expand Down Expand Up @@ -433,7 +433,7 @@ contract UpgradeableModularAccount is
// Load the next per-hook data segment
(authSegment, authorizationData) = authorizationData.getNextSegment();

if (uint8(authSegment[0]) <= i) {
if (authSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
Expand Down
13 changes: 8 additions & 5 deletions src/helpers/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@ library SparseCalldataSegmentLib {
pure
returns (bytes calldata segment, bytes calldata remainder)
{
// The first 8 bytes hold the length of the segment.
// The first 8 bytes hold the length of the segment, excluding the index.
uint64 length = uint64(bytes8(source[:8]));

// The segment is the next `length` bytes.
// By convention, the first byte of each segmet is the index of the segment, excluding the 1-byte index.
segment = source[8:8 + length + 1];
// The offset of the remainder of the calldata.
uint256 remainderOffset = 8 + length + 1;

// The segment is the next `length` + 1 bytes, to account for the index.
// By convention, the first byte of each segment is the index of the segment.
segment = source[8:remainderOffset];

// The remainder is the rest of the calldata.
remainder = source[8 + length:];
remainder = source[remainderOffset:];
}

/// @notice Extracts the index from a segment
Expand Down
313 changes: 313 additions & 0 deletions test/account/PerHookData.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol";
import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";

import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol";
import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol";
import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol";

import {MockAccessControlHookPlugin} from "../mocks/plugins/MockAccessControlHookPlugin.sol";
import {Counter} from "../mocks/Counter.sol";
import {AccountTestBase} from "../utils/AccountTestBase.sol";

contract PerHookDataTest is AccountTestBase {
using MessageHashUtils for bytes32;

MockAccessControlHookPlugin internal accessControlHookPlugin;

Counter internal counter;

FunctionReference internal ownerValidation;

uint256 public constant CALL_GAS_LIMIT = 50000;
uint256 public constant VERIFICATION_GAS_LIMIT = 1200000;

function setUp() public {
counter = new Counter();

accessControlHookPlugin = new MockAccessControlHookPlugin();

// Write over `account1` with a new account proxy, with different initialization.

address accountImplementation = address(factory.accountImplementation());

account1 = UpgradeableModularAccount(payable(new ERC1967Proxy(accountImplementation, "")));

ownerValidation = FunctionReferenceLib.pack(
address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)
);

FunctionReference accessControlHook = FunctionReferenceLib.pack(
address(accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK)
);

FunctionReference[] memory preValidationHooks = new FunctionReference[](1);
preValidationHooks[0] = accessControlHook;

bytes[] memory preValidationHookData = new bytes[](1);
// Access control is restricted to only the counter
preValidationHookData[0] = abi.encode(counter);

bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData);

vm.prank(address(entryPoint));
account1.installValidation(
ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks
);

vm.deal(address(account1), 100 ether);
}

function test_passAccessControl_userOp() public {
assertEq(counter.number(), 0);

(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();

(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)});

userOp.signature =
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v));

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

entryPoint.handleOps(userOps, beneficiary);

assertEq(counter.number(), 1);
}

function test_failAccessControl_badSigData_userOp() public {
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();

(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({
index: 0,
validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234))
});

userOp.signature =
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v));

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSignature("Error(string)", "Proof doesn't match target")
)
);
entryPoint.handleOps(userOps, beneficiary);
}

function test_failAccessControl_noSigData_userOp() public {
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

userOp.signature = _encodeSignature(ownerValidation, SHARED_VALIDATION, abi.encodePacked(r, s, v));

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSignature("Error(string)", "Proof doesn't match target")
)
);
entryPoint.handleOps(userOps, beneficiary);
}

function test_failAccessControl_badIndexProvided_userOp() public {
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)});
preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(counter)});

userOp.signature =
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v));

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
}

// todo: index out of order failure case with 2 pre hooks

function test_failAccessControl_badTarget_userOp() public {
PackedUserOperation memory userOp = PackedUserOperation({
sender: address(account1),
nonce: 0,
initCode: "",
callData: abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")),
accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT),
preVerificationGas: 0,
gasFees: _encodeGas(1, 1),
paymasterAndData: "",
signature: ""
});

bytes32 userOpHash = entryPoint.getUserOpHash(userOp);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)});

userOp.signature =
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v));

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSignature("Error(string)", "Target not allowed")
)
);
entryPoint.handleOps(userOps, beneficiary);
}

function test_passAccessControl_runtime() public {
assertEq(counter.number(), 0);

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)});

vm.prank(owner1);
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "")
);

assertEq(counter.number(), 1);
}

function test_failAccessControl_badSigData_runtime() public {
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({
index: 0,
validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234))
});

vm.prank(owner1);
vm.expectRevert(
abi.encodeWithSelector(
UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector,
accessControlHookPlugin,
uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK),
abi.encodeWithSignature("Error(string)", "Proof doesn't match target")
)
);
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "")
);
}

function test_failAccessControl_noSigData_runtime() public {
vm.prank(owner1);
vm.expectRevert(
abi.encodeWithSelector(
UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector,
accessControlHookPlugin,
uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK),
abi.encodeWithSignature("Error(string)", "Proof doesn't match target")
)
);
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
_encodeSignature(ownerValidation, SHARED_VALIDATION, "")
);
}

function test_failAccessControl_badIndexProvided_runtime() public {
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(counter)});
preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(counter)});

vm.prank(owner1);
vm.expectRevert(
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
);
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "")
);
}

//todo: index out of order failure case with 2 pre hooks

function test_failAccessControl_badTarget_runtime() public {
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)});

vm.prank(owner1);
vm.expectRevert(
abi.encodeWithSelector(
UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector,
accessControlHookPlugin,
uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK),
abi.encodeWithSignature("Error(string)", "Target not allowed")
)
);
account1.executeWithAuthorization(
abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")),
_encodeSignature(ownerValidation, SHARED_VALIDATION, preValidationHookData, "")
);
}

function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) {
PackedUserOperation memory userOp = PackedUserOperation({
sender: address(account1),
nonce: 0,
initCode: "",
callData: abi.encodeCall(
UpgradeableModularAccount.execute, (address(counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT),
preVerificationGas: 0,
gasFees: _encodeGas(1, 1),
paymasterAndData: "",
signature: ""
});

bytes32 userOpHash = entryPoint.getUserOpHash(userOp);

return (userOp, userOpHash);
}
}
Loading

0 comments on commit c937581

Please sign in to comment.