diff --git a/contracts/contracts/MACI.sol b/contracts/contracts/MACI.sol index c00223c5a1..4b1d31b70f 100644 --- a/contracts/contracts/MACI.sol +++ b/contracts/contracts/MACI.sol @@ -38,6 +38,9 @@ contract MACI is IMACI, Params, Utilities, Ownable { /// @notice A mapping of poll IDs to Poll contracts. mapping(uint256 => address) public polls; + /// @notice Whether the subtrees have been merged (can merge root before new signup) + bool public subtreesMerged; + /// @notice The number of signups uint256 public numSignUps; @@ -93,6 +96,7 @@ contract MACI is IMACI, Params, Utilities, Ownable { error MaciPubKeyLargerThanSnarkFieldSize(); error PreviousPollNotCompleted(uint256 pollId); error PollDoesNotExist(uint256 pollId); + error SignupTemporaryBlocked(); /// @notice Create a new instance of the MACI contract. /// @param _pollFactory The PollFactory contract @@ -153,6 +157,9 @@ contract MACI is IMACI, Params, Utilities, Ownable { bytes memory _signUpGatekeeperData, bytes memory _initialVoiceCreditProxyData ) public virtual { + // prevent new signups until we merge the roots (possible DoS) + if (subtreesMerged) revert SignupTemporaryBlocked(); + // ensure we do not have more signups than what the circuits support if (numSignUps >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups(); @@ -247,10 +254,18 @@ contract MACI is IMACI, Params, Utilities, Ownable { /// @inheritdoc IMACI function mergeStateAqSubRoots(uint256 _numSrQueueOps, uint256 _pollId) public onlyPoll(_pollId) { stateAq.mergeSubRoots(_numSrQueueOps); + + // if we have merged all subtrees then put a block + if (stateAq.subTreesMerged()) { + subtreesMerged = true; + } } /// @inheritdoc IMACI function mergeStateAq(uint256 _pollId) public onlyPoll(_pollId) returns (uint256 root) { + // remove block + subtreesMerged = false; + root = stateAq.merge(stateTreeDepth); } diff --git a/contracts/contracts/Poll.sol b/contracts/contracts/Poll.sol index dc0e05a9b2..a6560f9da7 100644 --- a/contracts/contracts/Poll.sol +++ b/contracts/contracts/Poll.sol @@ -214,6 +214,7 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol // deadline if (stateAqMerged) revert StateAqAlreadyMerged(); + // set merged to true so it cannot be called again stateAqMerged = true; // the subtrees must have been merged first diff --git a/contracts/tests/MACI.test.ts b/contracts/tests/MACI.test.ts index 6343576dc6..854fba5c69 100644 --- a/contracts/tests/MACI.test.ts +++ b/contracts/tests/MACI.test.ts @@ -248,18 +248,30 @@ describe("MACI", () => { ).to.be.revertedWithCustomError(maciContract, "CallerMustBePoll"); }); - it("should allow a Poll contract to merge the signUp AccQueue", async () => { + it("should prevent a new user from signin up after the accQueue subtrees have been merged", async () => { await timeTravel(signer.provider as unknown as EthereumProvider, Number(duration) + 1); - let tx = await pollContract.mergeMaciStateAqSubRoots(0, pollId, { + + const tx = await pollContract.mergeMaciStateAqSubRoots(0, pollId, { gasLimit: 3000000, }); - let receipt = await tx.wait(); + const receipt = await tx.wait(); expect(receipt?.status).to.eq(1); - tx = await pollContract.mergeMaciStateAq(pollId, { + await expect( + maciContract.signUp( + users[0].pubKey.asContractParam(), + AbiCoder.defaultAbiCoder().encode(["uint256"], [1]), + AbiCoder.defaultAbiCoder().encode(["uint256"], [0]), + signUpTxOpts, + ), + ).to.be.revertedWithCustomError(maciContract, "SignupTemporaryBlocked"); + }); + + it("should allow a Poll contract to merge the signUp AccQueue", async () => { + const tx = await pollContract.mergeMaciStateAq(pollId, { gasLimit: 3000000, }); - receipt = await tx.wait(); + const receipt = await tx.wait(); expect(receipt?.status).to.eq(1); }); @@ -268,13 +280,40 @@ describe("MACI", () => { maciState.polls.get(pollId)?.updatePoll(await pollContract.numSignups()); expect(onChainStateRoot.toString()).to.eq(maciState.polls.get(pollId)?.stateTree?.root.toString()); }); - }); - describe("getStateAqRoot", () => { - it("should return the correct state root", async () => { + it("should get the correct state root with getStateAqRoot", async () => { const onChainStateRoot = await maciContract.getStateAqRoot(); expect(onChainStateRoot.toString()).to.eq(maciState.polls.get(pollId)?.stateTree?.root.toString()); }); + + it("should allow a user to signup after the signUp AccQueue was merged", async () => { + const tx = await maciContract.signUp( + users[0].pubKey.asContractParam(), + AbiCoder.defaultAbiCoder().encode(["uint256"], [1]), + AbiCoder.defaultAbiCoder().encode(["uint256"], [0]), + signUpTxOpts, + ); + const receipt = await tx.wait(); + expect(receipt?.status).to.eq(1); + + const iface = maciContract.interface; + + // Store the state index + const log = receipt!.logs[receipt!.logs.length - 1]; + const event = iface.parseLog(log as unknown as { topics: string[]; data: string }) as unknown as { + args: { + _stateIndex: BigNumberish; + _voiceCreditBalance: BigNumberish; + _timestamp: BigNumberish; + }; + }; + + maciState.signUp( + users[0].pubKey, + BigInt(event.args._voiceCreditBalance.toString()), + BigInt(event.args._timestamp.toString()), + ); + }); }); describe("getPoll", () => {