From 415c3da5c65247b55b5078cc89c8f9605d8be77a Mon Sep 17 00:00:00 2001 From: ctrlc03 <93448202+ctrlc03@users.noreply.github.com> Date: Mon, 19 Feb 2024 09:52:49 +0000 Subject: [PATCH] refactor(contracts): add audit auggestions Add coordinator pub key check, optimize padAndHashMessage, use >= instead of == for signup/message max value checks --- contracts/contracts/MACI.sol | 5 +---- contracts/contracts/Poll.sol | 18 ++++++++++++++---- contracts/contracts/utilities/Utilities.sol | 6 ++++-- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/contracts/contracts/MACI.sol b/contracts/contracts/MACI.sol index 06213e1937..df0c61ac3c 100644 --- a/contracts/contracts/MACI.sol +++ b/contracts/contracts/MACI.sol @@ -41,9 +41,6 @@ contract MACI is IMACI, Params, Utilities, Ownable { /// @notice The number of signups uint256 public numSignUps; - /// @notice A mapping of block timestamps to the number of state leaves - mapping(uint256 => uint256) public numStateLeaves; - /// @notice ERC20 contract that hold topup credits TopupCredit public immutable topupCredit; @@ -157,7 +154,7 @@ contract MACI is IMACI, Params, Utilities, Ownable { bytes memory _initialVoiceCreditProxyData ) public virtual { // ensure we do not have more signups than what the circuits support - if (numSignUps == uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups(); + if (numSignUps >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups(); if (_pubKey.x >= SNARK_SCALAR_FIELD || _pubKey.y >= SNARK_SCALAR_FIELD) { revert MaciPubKeyLargerThanSnarkFieldSize(); diff --git a/contracts/contracts/Poll.sol b/contracts/contracts/Poll.sol index d7f163df14..dc0e05a9b2 100644 --- a/contracts/contracts/Poll.sol +++ b/contracts/contracts/Poll.sol @@ -93,12 +93,22 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol PubKey memory _coordinatorPubKey, ExtContracts memory _extContracts ) payable { - extContracts = _extContracts; + // check that the coordinator public key is valid + if (_coordinatorPubKey.x >= SNARK_SCALAR_FIELD || _coordinatorPubKey.y >= SNARK_SCALAR_FIELD) { + revert MaciPubKeyLargerThanSnarkFieldSize(); + } + + // store the pub key as object then calculate the hash coordinatorPubKey = _coordinatorPubKey; - // we hash it ourselves to ensure we record the correct value + // we hash it ourselves to ensure we store the correct value coordinatorPubKeyHash = hashLeftRight(_coordinatorPubKey.x, _coordinatorPubKey.y); + // store the external contracts to interact with + extContracts = _extContracts; + // store duration of the poll duration = _duration; + // store max values maxValues = _maxValues; + // store tree depth treeDepths = _treeDepths; // Record the current timestamp deployTime = block.timestamp; @@ -144,7 +154,7 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol /// @inheritdoc IPoll function topup(uint256 stateIndex, uint256 amount) public virtual isWithinVotingDeadline { // we check that we do not exceed the max number of messages - if (numMessages == maxValues.maxMessages) revert TooManyMessages(); + if (numMessages >= maxValues.maxMessages) revert TooManyMessages(); // cannot realistically overflow unchecked { @@ -165,7 +175,7 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol /// @inheritdoc IPoll function publishMessage(Message memory _message, PubKey calldata _encPubKey) public virtual isWithinVotingDeadline { // we check that we do not exceed the max number of messages - if (numMessages == maxValues.maxMessages) revert TooManyMessages(); + if (numMessages >= maxValues.maxMessages) revert TooManyMessages(); // validate that the public key is valid if (_encPubKey.x >= SNARK_SCALAR_FIELD || _encPubKey.y >= SNARK_SCALAR_FIELD) { diff --git a/contracts/contracts/utilities/Utilities.sol b/contracts/contracts/utilities/Utilities.sol index b0c920ad1e..488bf39896 100644 --- a/contracts/contracts/utilities/Utilities.sol +++ b/contracts/contracts/utilities/Utilities.sol @@ -36,8 +36,10 @@ contract Utilities is SnarkConstants, DomainObjs, Hasher { uint256[2] memory dataToPad, uint256 msgType ) public pure returns (Message memory message, PubKey memory padKey, uint256 msgHash) { - // add data and pad it - uint256[10] memory dat = [dataToPad[0], dataToPad[1], 0, 0, 0, 0, 0, 0, 0, 0]; + // add data and pad it to 10 elements (automatically cause it's the default value) + uint256[10] memory dat; + dat[0] = dataToPad[0]; + dat[1] = dataToPad[1]; padKey = PubKey(PAD_PUBKEY_X, PAD_PUBKEY_Y); message = Message({ msgType: msgType, data: dat });