From f23c0b5cde4b7f3300f395c2eeff1b8bae219081 Mon Sep 17 00:00:00 2001 From: ctrlc03 <93448202+ctrlc03@users.noreply.github.com> Date: Fri, 2 Feb 2024 16:05:14 +0000 Subject: [PATCH] feat(tally): remove ballotsTallied event and add view function --- contracts/contracts/Tally.sol | 28 +++++++++++++++------------- contracts/tests/Tally.test.ts | 22 +++++++++++++--------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/contracts/contracts/Tally.sol b/contracts/contracts/Tally.sol index e0dfa86359..df0999e771 100644 --- a/contracts/contracts/Tally.sol +++ b/contracts/contracts/Tally.sol @@ -49,9 +49,6 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher { error BatchStartIndexTooLarge(); error TallyBatchSizeTooLarge(); - /// @notice events - event BallotsTallied(address poll); - /// @notice Create a new Tally contract /// @param _verifier The Verifier contract /// @param _vkRegistry The VkRegistry contract @@ -81,6 +78,16 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher { result = (_batchStartIndex / _tallyBatchSize) + (_numSignUps << uint256(50)); } + /// @notice Check if all ballots are tallied + /// @return tallied whether all ballots are tallied + function isTallied() external view returns (bool tallied) { + (uint8 intStateTreeDepth, , , ) = poll.treeDepths(); + (uint256 numSignUps, ) = poll.numSignUpsAndMessages(); + + // Require that there are untalied ballots left + tallied = tallyBatchNum * (TREE_ARITY ** intStateTreeDepth) >= numSignUps; + } + /// @notice generate hash of public inputs for tally circuit /// @param _numSignUps: number of signups /// @param _batchStartIndex: the start index of given batch @@ -121,17 +128,16 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher { _votingPeriodOver(poll); updateSbCommitment(); - uint256 cachedBatchNum = tallyBatchNum; + // get the batch size and start index + (uint8 intStateTreeDepth, , , ) = poll.treeDepths(); + uint256 tallyBatchSize = TREE_ARITY ** intStateTreeDepth; + uint256 batchStartIndex = tallyBatchNum * tallyBatchSize; + // save some gas because we won't overflow uint256 unchecked { tallyBatchNum++; } - // get the batch size and start index - (uint8 intStateTreeDepth, , , ) = poll.treeDepths(); - uint256 tallyBatchSize = TREE_ARITY ** intStateTreeDepth; - uint256 batchStartIndex = cachedBatchNum * tallyBatchSize; - (uint256 numSignUps, ) = poll.numSignUpsAndMessages(); // Require that there are untalied ballots left @@ -147,10 +153,6 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher { // Update the tally commitment and the tally batch num tallyCommitment = _newTallyCommitment; - - if ((cachedBatchNum + 1) * tallyBatchSize >= numSignUps) { - emit BallotsTallied(address(poll)); - } } /// @notice Verify the tally proof using the verifying key diff --git a/contracts/tests/Tally.test.ts b/contracts/tests/Tally.test.ts index 8acc905f9a..1e5aa412d2 100644 --- a/contracts/tests/Tally.test.ts +++ b/contracts/tests/Tally.test.ts @@ -184,18 +184,26 @@ describe("TallyVotes", () => { tallyGeneratedInputs = poll.tallyVotes(); }); + it("isTallied should return false", async () => { + const isTallied = await tallyContract.isTallied(); + expect(isTallied).to.eq(false); + }); + it("tallyVotes() should update the tally commitment", async () => { // do the processing on the message processor contract await mpContract.processMessages(generatedInputs.newSbCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); - await expect(tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0])) - .to.emit(tallyContract, "BallotsTallied") - .withArgs(await pollContract.getAddress()); + await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); const onChainNewTallyCommitment = await tallyContract.tallyCommitment(); expect(tallyGeneratedInputs.newTallyCommitment).to.eq(onChainNewTallyCommitment.toString()); }); + it("isTallied should return true", async () => { + const isTallied = await tallyContract.isTallied(); + expect(isTallied).to.eq(true); + }); + it("tallyVotes() should revert when votes have already been tallied", async () => { await expect( tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]), @@ -334,9 +342,7 @@ describe("TallyVotes", () => { it("should tally votes correctly", async () => { const tallyGeneratedInputs = poll.tallyVotes(); - await expect(tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0])) - .to.emit(tallyContract, "BallotsTallied") - .withArgs(await pollContract.getAddress()); + await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); const onChainNewTallyCommitment = await tallyContract.tallyCommitment(); expect(tallyGeneratedInputs.newTallyCommitment).to.eq(onChainNewTallyCommitment.toString()); @@ -488,9 +494,7 @@ describe("TallyVotes", () => { // tally second batch tallyGeneratedInputs = poll.tallyVotes(); - await expect(tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0])) - .to.emit(tallyContract, "BallotsTallied") - .withArgs(await pollContract.getAddress()); + await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); // check that it fails to tally again await expect(