Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [v0.8-develop] per validation hook data #66

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .solhint-test.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"immutable-vars-naming": ["error"],
"no-unused-import": ["error"],
"compiler-version": ["error", ">=0.8.19"],
"custom-errors": "off",
"func-visibility": ["error", { "ignoreConstructors": true }],
"max-line-length": ["error", 120],
"max-states-count": ["warn", 30],
Expand Down
3 changes: 1 addition & 2 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ abstract contract AccountLoupe is IAccountLoupe {
override
returns (FunctionReference[] memory preValidationHooks)
{
preValidationHooks =
toFunctionReferenceArray(getAccountStorage().validationData[validationFunction].preValidationHooks);
preValidationHooks = getAccountStorage().validationData[validationFunction].preValidationHooks;
}

/// @inheritdoc IAccountLoupe
Expand Down
2 changes: 1 addition & 1 deletion src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct ValidationData {
// How many execution hooks require the UO context.
uint8 requireUOHookCount;
// The pre validation hooks for this function selector.
EnumerableSet.Bytes32Set preValidationHooks;
FunctionReference[] preValidationHooks;
// Permission hooks for this validation function.
EnumerableSet.Bytes32Set permissionHooks;
}
Expand Down
36 changes: 22 additions & 14 deletions src/account/PluginManager2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@ import {ExecutionHook} from "../interfaces/IAccountLoupe.sol";
abstract contract PluginManager2 {
using EnumerableSet for EnumerableSet.Bytes32Set;

// Index marking the start of the data for the validation function.
uint8 internal constant _RESERVED_VALIDATION_DATA_INDEX = 255;
adamegyed marked this conversation as resolved.
Show resolved Hide resolved

error DefaultValidationAlreadySet(FunctionReference validationFunction);
error PreValidationAlreadySet(FunctionReference validationFunction, FunctionReference preValidationFunction);
error ValidationAlreadySet(bytes4 selector, FunctionReference validationFunction);
error ValidationNotSet(bytes4 selector, FunctionReference validationFunction);
error PermissionAlreadySet(FunctionReference validationFunction, ExecutionHook hook);
error PreValidationHookLimitExceeded();

function _installValidation(
FunctionReference validationFunction,
Expand All @@ -39,19 +43,21 @@ abstract contract PluginManager2 {
for (uint256 i = 0; i < preValidationFunctions.length; ++i) {
FunctionReference preValidationFunction = preValidationFunctions[i];

if (
!_storage.validationData[validationFunction].preValidationHooks.add(
toSetValue(preValidationFunction)
)
) {
revert PreValidationAlreadySet(validationFunction, preValidationFunction);
}
_storage.validationData[validationFunction].preValidationHooks.push(preValidationFunction);
adamegyed marked this conversation as resolved.
Show resolved Hide resolved

if (initDatas[i].length > 0) {
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onInstall(initDatas[i]);
}
}

// Avoid collision between reserved index and actual indices
if (
_storage.validationData[validationFunction].preValidationHooks.length
> _RESERVED_VALIDATION_DATA_INDEX
) {
revert PreValidationHookLimitExceeded();
}
}

if (permissionHooks.length > 0) {
Expand Down Expand Up @@ -110,15 +116,16 @@ abstract contract PluginManager2 {
bytes[] memory preValidationHookUninstallDatas = abi.decode(preValidationHookUninstallData, (bytes[]));

// Clear pre validation hooks
EnumerableSet.Bytes32Set storage preValidationHooks =
FunctionReference[] storage preValidationHooks =
_storage.validationData[validationFunction].preValidationHooks;
uint256 i = 0;
while (preValidationHooks.length() > 0) {
FunctionReference preValidationFunction = toFunctionReference(preValidationHooks.at(0));
preValidationHooks.remove(toSetValue(preValidationFunction));
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[i++]);
for (uint256 i = 0; i < preValidationHooks.length; ++i) {
FunctionReference preValidationFunction = preValidationHooks[i];
if (preValidationHookUninstallDatas[0].length > 0) {
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[0]);
}
}
delete _storage.validationData[validationFunction].preValidationHooks;
}

{
Expand All @@ -135,6 +142,7 @@ abstract contract PluginManager2 {
IPlugin(permissionHookPlugin).onUninstall(permissionHookUninstallDatas[i++]);
}
}
delete _storage.validationData[validationFunction].preValidationHooks;

// Because this function also calls `onUninstall`, and removes the default flag from validation, we must
// assume these selectors passed in to be exhaustive.
Expand Down
113 changes: 78 additions & 35 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol";
import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol";
import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationDataHelpers.sol";
import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol";
import {IValidation} from "../interfaces/IValidation.sol";
Expand All @@ -20,13 +21,7 @@ import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.so
import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol";
import {AccountExecutor} from "./AccountExecutor.sol";
import {AccountLoupe} from "./AccountLoupe.sol";
import {
AccountStorage,
getAccountStorage,
toSetValue,
toFunctionReference,
toExecutionHook
} from "./AccountStorage.sol";
import {AccountStorage, getAccountStorage, toSetValue, toExecutionHook} from "./AccountStorage.sol";
import {AccountStorageInitializable} from "./AccountStorageInitializable.sol";
import {PluginManagerInternals} from "./PluginManagerInternals.sol";
import {PluginManager2} from "./PluginManager2.sol";
Expand All @@ -46,6 +41,7 @@ contract UpgradeableModularAccount is
{
using EnumerableSet for EnumerableSet.Bytes32Set;
using FunctionReferenceLib for FunctionReference;
using SparseCalldataSegmentLib for bytes;

struct PostExecToRun {
bytes preExecHookReturnData;
Expand All @@ -68,6 +64,7 @@ contract UpgradeableModularAccount is
error ExecFromPluginNotPermitted(address plugin, bytes4 selector);
error ExecFromPluginExternalNotPermitted(address plugin, address target, uint256 value, bytes data);
error NativeTokenSpendingNotPermitted(address plugin);
error NonCanonicalEncoding();
error NotEntryPoint();
error PostExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
error PreExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
Expand All @@ -80,6 +77,8 @@ contract UpgradeableModularAccount is
error UnrecognizedFunction(bytes4 selector);
error UserOpValidationFunctionMissing(bytes4 selector);
error ValidationDoesNotApply(bytes4 selector, address plugin, uint8 functionId, bool isDefault);
error ValidationSignatureSegmentMissing();
error SignatureSegmentOutOfOrder();

// Wraps execution of a native function with runtime validation and hooks
// Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin
Expand Down Expand Up @@ -407,38 +406,50 @@ contract UpgradeableModularAccount is
revert RequireUserOperationContext();
}

validationData =
_doUserOpValidation(selector, userOpValidationFunction, userOp, userOp.signature[22:], userOpHash);
validationData = _doUserOpValidation(userOpValidationFunction, userOp, userOp.signature[22:], userOpHash);
}

// To support gas estimation, we don't fail early when the failure is caused by a signature failure
function _doUserOpValidation(
bytes4 selector,
FunctionReference userOpValidationFunction,
PackedUserOperation memory userOp,
bytes calldata signature,
bytes32 userOpHash
) internal returns (uint256 validationData) {
userOp.signature = signature;
) internal returns (uint256) {
// Set up the per-hook data tracking fields
bytes calldata signatureSegment;
(signatureSegment, signature) = signature.getNextSegment();

if (userOpValidationFunction.isEmpty()) {
// If the validation function is empty, then the call cannot proceed.
revert UserOpValidationFunctionMissing(selector);
}

uint256 currentValidationData;
uint256 validationData;

// Do preUserOpValidation hooks
EnumerableSet.Bytes32Set storage preUserOpValidationHooks =
FunctionReference[] memory preUserOpValidationHooks =
getAccountStorage().validationData[userOpValidationFunction].preValidationHooks;

uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length();
for (uint256 i = 0; i < preUserOpValidationHooksLength; ++i) {
bytes32 key = preUserOpValidationHooks.at(i);
FunctionReference preUserOpValidationHook = toFunctionReference(key);
for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mega smol but could we rename "i" to preValidationHookIndex since we use it for a few comparisons?

// Load per-hook data, if any is present
// The segment index is the first byte of the signature
if (signatureSegment.getIndex() == i) {
// Use the current segment
userOp.signature = signatureSegment.getBody();

if (userOp.signature.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(signatureSegment, signature) = signature.getNextSegment();

if (signatureSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
userOp.signature = "";
}

(address plugin, uint8 functionId) = preUserOpValidationHook.unpack();
currentValidationData = IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash);
(address plugin, uint8 functionId) = preUserOpValidationHooks[i].unpack();
uint256 currentValidationData =
IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thoughts on introducing an additional parameter here for both preUserOpValidationHook and validateUserOp similar to the runtime path? That would allow us to keep the signature the same. What are the downsides?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just adds to the gas cost, and I'm not sure we have a good use case for the entire sig at the moment. With the runtime path, there wasn't anywhere else we could put it.


if (uint160(currentValidationData) > 1) {
// If the aggregator is not 0 or 1, it is an unexpected value
Expand All @@ -449,35 +460,63 @@ contract UpgradeableModularAccount is

// Run the user op validationFunction
{
if (signatureSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

userOp.signature = signatureSegment.getBody();

(address plugin, uint8 functionId) = userOpValidationFunction.unpack();
currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash);
uint256 currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash);

if (preUserOpValidationHooksLength != 0) {
if (preUserOpValidationHooks.length != 0) {
// If we have other validation data we need to coalesce with
validationData = _coalesceValidation(validationData, currentValidationData);
} else {
validationData = currentValidationData;
}
}

return validationData;
}

function _doRuntimeValidation(
FunctionReference runtimeValidationFunction,
bytes calldata callData,
bytes calldata authorizationData
) internal {
// Set up the per-hook data tracking fields
bytes calldata authSegment;
(authSegment, authorizationData) = authorizationData.getNextSegment();

// run all preRuntimeValidation hooks
EnumerableSet.Bytes32Set storage preRuntimeValidationHooks =
FunctionReference[] memory preRuntimeValidationHooks =
getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks;

uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length();
for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) {
bytes32 key = preRuntimeValidationHooks.at(i);
FunctionReference preRuntimeValidationHook = toFunctionReference(key);
for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Echoing my previous comment to potentially rename this to preValiationHookIndex or something-- not a hard req by any means.

bytes memory currentAuthData;

if (authSegment.getIndex() == i) {
// Use the current segment
currentAuthData = authSegment.getBody();

if (currentAuthData.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(authSegment, authorizationData) = authorizationData.getNextSegment();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recursion implemented here is neat! This appears to be the the most efficient approach, aside from key-value structure, which is not feasible within internal functions.


(address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHook.unpack();
if (authSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
currentAuthData = "";
}

(address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHooks[i].unpack();
try IValidationHook(hookPlugin).preRuntimeValidationHook(
hookFunctionId, msg.sender, msg.value, callData
hookFunctionId, msg.sender, msg.value, callData, currentAuthData
)
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
Expand All @@ -487,9 +526,13 @@ contract UpgradeableModularAccount is
}
}

if (authSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably move _RESERVED_VALIDATION_DATA_INDEX into the library and expose a helper function, something like hasNext()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would that look? The getIndex function is used to determine if the next segment applies for the current function (pre-validation hook or validation function), or for a later function. I think a hasNext function would also need to return the index, so it seems pretty similar to the existing setup.

revert ValidationSignatureSegmentMissing();
}

(address plugin, uint8 functionId) = runtimeValidationFunction.unpack();

try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authorizationData)
try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authSegment.getBody())
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
{} catch (bytes memory revertReason) {
Expand Down
51 changes: 51 additions & 0 deletions src/helpers/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

/// @title Sparse Calldata Segment Library
/// @notice Library for working with sparsely-packed calldata segments, identified with an index.
/// @dev The first byte of each segment is the index of the segment.
/// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather
/// than inline as part of the tuple returned by `getNextSegment`.
library SparseCalldataSegmentLib {
adamegyed marked this conversation as resolved.
Show resolved Hide resolved
/// @notice Splits out a segment of calldata, sparsely-packed.
/// The expected format is:
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
/// @param source The calldata to extract the segment from.
/// @return segment The extracted segment. Using the above example, this would be segment0.
/// @return remainder The remaining calldata. Using the above example,
/// this would start at uint32(len(segment1)) and continue to the end at segmentN.
function getNextSegment(bytes calldata source)
internal
pure
returns (bytes calldata segment, bytes calldata remainder)
{
// The first 4 bytes hold the length of the segment, excluding the index.
uint32 length = uint32(bytes4(source[:4]));

// The offset of the remainder of the calldata.
uint256 remainderOffset = 4 + length;

// The segment is the next `length` + 1 bytes, to account for the index.
// By convention, the first byte of each segment is the index of the segment.
segment = source[4:remainderOffset];

// The remainder is the rest of the calldata.
remainder = source[remainderOffset:];
}

/// @notice Extracts the index from a segment.
/// @dev The first byte of the segment is the index.
/// @param segment The segment to extract the index from
/// @return The index of the segment
function getIndex(bytes calldata segment) internal pure returns (uint8) {
return uint8(segment[0]);
}

/// @notice Extracts the body from a segment.
/// @dev The body is the segment without the index.
/// @param segment The segment to extract the body from
/// @return The body of the segment.
function getBody(bytes calldata segment) internal pure returns (bytes calldata) {
return segment[1:];
}
}
1 change: 1 addition & 0 deletions src/interfaces/IValidation.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ interface IValidation is IPlugin {
/// @param sender The caller address.
/// @param value The call value.
/// @param data The calldata sent.
/// @param authorization Additional data for the validation function to use.
function validateRuntime(
uint8 functionId,
address sender,
Expand Down
9 changes: 7 additions & 2 deletions src/interfaces/IValidationHook.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ interface IValidationHook is IPlugin {
/// @param sender The caller address.
/// @param value The call value.
/// @param data The calldata sent.
function preRuntimeValidationHook(uint8 functionId, address sender, uint256 value, bytes calldata data)
external;
function preRuntimeValidationHook(
uint8 functionId,
address sender,
uint256 value,
bytes calldata data,
bytes calldata authorization
) external;
adamegyed marked this conversation as resolved.
Show resolved Hide resolved

// TODO: support this hook type within the account & in the manifest

Expand Down
Loading
Loading