From 510e6ee365958e8c7f72cfca6cea75d15a89d28b Mon Sep 17 00:00:00 2001 From: ctrlc03 <93448202+ctrlc03@users.noreply.github.com> Date: Wed, 7 Feb 2024 11:21:54 +0000 Subject: [PATCH 1/2] fix(circuits): enforce use of stateIndex from message --- circuits/circom/processMessages.circom | 41 ++++++-- circuits/circom/processMessagesNonQv.circom | 30 ++++-- .../stateLeafAndBallotTransformer.circom | 2 +- circuits/ts/__tests__/ProcessMessages.test.ts | 31 +++++- cli/tests/e2e/e2e.test.ts | 48 ++++++--- cli/ts/commands/genProofs.ts | 4 +- core/ts/Poll.ts | 98 +++++++++++++++---- crypto/ts/index.ts | 2 +- domainobjs/ts/commands/PCommand.ts | 10 +- 9 files changed, 210 insertions(+), 56 deletions(-) diff --git a/circuits/circom/processMessages.circom b/circuits/circom/processMessages.circom index 9022a59121..d1c2f2db2c 100644 --- a/circuits/circom/processMessages.circom +++ b/circuits/circom/processMessages.circom @@ -48,6 +48,8 @@ template ProcessMessages( var STATE_LEAF_PUB_Y_IDX = 1; var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + + var N_BITS = 252; // Note that we sha256 hash some values from the contract, pass in the hash // as a public input, and pass in said values as private inputs. This saves @@ -277,6 +279,7 @@ template ProcessMessages( component processors[batchSize]; // topup type processor component processors2[batchSize]; + for (var i = batchSize - 1; i >= 0; i --) { // process it as vote type message processors[i] = ProcessOne(stateTreeDepth, voteOptionTreeDepth); @@ -349,6 +352,7 @@ template ProcessMessages( <== currentStateLeavesPathElements[i][j][k]; } } + // pick the correct result by msg type tmpStateRoot1[i] <== processors[i].newStateRoot * (2 - msgs[i][0]); tmpStateRoot2[i] <== processors2[i].newStateRoot * (msgs[i][0] - 1); @@ -378,6 +382,8 @@ template ProcessTopup(stateTreeDepth) { var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + var N_BITS = 252; + signal input msgType; signal input stateTreeIndex; signal input amount; @@ -395,9 +401,9 @@ template ProcessTopup(stateTreeDepth) { // msgType of topup command is 2 amt <== amount * (msgType - 1); index <== stateTreeIndex * (msgType - 1); - component validCreditBalance = LessEqThan(252); + component validCreditBalance = LessEqThan(N_BITS); // check stateIndex, if invalid index, set index and amount to zero - component validStateLeafIndex = LessEqThan(252); + component validStateLeafIndex = LessEqThan(N_BITS); validStateLeafIndex.in[0] <== index; validStateLeafIndex.in[1] <== numSignUps; @@ -462,6 +468,8 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + var N_BITS = 252; + signal input msgType; signal input numSignUps; signal input maxVoteOptions; @@ -526,7 +534,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { // ----------------------------------------------------------------------- // 2. If msgType = 0 and isValid is 0, generate indices for leaf 0 - // Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType + // Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType signal indexByType; signal tmpIndex1; signal tmpIndex2; @@ -534,14 +542,26 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { tmpIndex2 <== topupStateIndex * (msgType - 1); indexByType <== tmpIndex1 + tmpIndex2; - component stateIndexMux = Mux1(); - stateIndexMux.s <== transformer.isValid + msgType - 1; - stateIndexMux.c[0] <== 0; - stateIndexMux.c[1] <== indexByType; + // we can validate if the state index is within the numSignups + // if not, we use 0 + // this is because decryption of an invalid message + // might result in random packed vals + component validStateLeafIndex = SafeLessThan(N_BITS); + validStateLeafIndex.in[0] <== indexByType; + validStateLeafIndex.in[1] <== numSignUps; - component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth); - stateLeafPathIndices.in <== stateIndexMux.out; + // use a mux to pick the correct index + component indexMux = Mux1(); + indexMux.s <== validStateLeafIndex.out; + indexMux.c[0] <== 0; + indexMux.c[1] <== indexByType; + // @note that we expect a coordinator to send the state leaf corresponding to a message + // which specifies a valid state index. If this is not the case, the + // proof will fail to generate. + component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth); + stateLeafPathIndices.in <== indexMux.out; + // ----------------------------------------------------------------------- // 3. Verify that the original state leaf exists in the given state root component stateLeafQip = QuinTreeInclusionProof(stateTreeDepth); @@ -572,6 +592,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { ballotQip.path_elements[i][j] <== ballotPathElements[i][j]; } } + ballotQip.root === currentBallotRoot; // ----------------------------------------------------------------------- @@ -583,7 +604,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { b <== currentVoteWeight * currentVoteWeight; c <== cmdNewVoteWeight * cmdNewVoteWeight; - component enoughVoiceCredits = SafeGreaterEqThan(252); + component enoughVoiceCredits = SafeGreaterEqThan(N_BITS); enoughVoiceCredits.in[0] <== stateLeaf[STATE_LEAF_VOICE_CREDIT_BALANCE_IDX] + b; enoughVoiceCredits.in[1] <== c; diff --git a/circuits/circom/processMessagesNonQv.circom b/circuits/circom/processMessagesNonQv.circom index 82d114c43b..25fe5e36b9 100644 --- a/circuits/circom/processMessagesNonQv.circom +++ b/circuits/circom/processMessagesNonQv.circom @@ -49,6 +49,8 @@ template ProcessMessagesNonQv( var STATE_LEAF_PUB_Y_IDX = 1; var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + + var N_BITS = 252; // Note that we sha256 hash some values from the contract, pass in the hash // as a public input, and pass in said values as private inputs. This saves @@ -389,6 +391,8 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) { var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + var N_BITS = 252; + signal input msgType; signal input numSignUps; signal input maxVoteOptions; @@ -461,13 +465,25 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) { tmpIndex2 <== topupStateIndex * (msgType - 1); indexByType <== tmpIndex1 + tmpIndex2; - component stateIndexMux = Mux1(); - stateIndexMux.s <== transformer.isValid + msgType - 1; - stateIndexMux.c[0] <== 0; - stateIndexMux.c[1] <== indexByType; - + // we can validate if the state index is within the numSignups + // if not, we use 0 + // this is because decryption of an invalid message + // might result in random packed vals + component validStateLeafIndex = SafeLessThan(N_BITS); + validStateLeafIndex.in[0] <== indexByType; + validStateLeafIndex.in[1] <== numSignUps; + + // use a mux to pick the correct index + component indexMux = Mux1(); + indexMux.s <== validStateLeafIndex.out; + indexMux.c[0] <== 0; + indexMux.c[1] <== indexByType; + + // @note that we expect a coordinator to send the state leaf corresponding to a message + // which specifies a valid state index. If this is not the case, the + // proof will fail to generate. component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth); - stateLeafPathIndices.in <== stateIndexMux.out; + stateLeafPathIndices.in <== indexMux.out; // ----------------------------------------------------------------------- // 3. Verify that the original state leaf exists in the given state root @@ -510,7 +526,7 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) { b <== currentVoteWeight; c <== cmdNewVoteWeight; - component enoughVoiceCredits = SafeGreaterEqThan(252); + component enoughVoiceCredits = SafeGreaterEqThan(N_BITS); enoughVoiceCredits.in[0] <== stateLeaf[STATE_LEAF_VOICE_CREDIT_BALANCE_IDX] + b; enoughVoiceCredits.in[1] <== c; diff --git a/circuits/circom/stateLeafAndBallotTransformer.circom b/circuits/circom/stateLeafAndBallotTransformer.circom index 609c688f1f..79dde80d33 100644 --- a/circuits/circom/stateLeafAndBallotTransformer.circom +++ b/circuits/circom/stateLeafAndBallotTransformer.circom @@ -86,7 +86,7 @@ template StateLeafAndBallotTransformer() { messageValidator.voteWeight <== cmdNewVoteWeight; // if the message is valid then we swap out the public key - // we have to do this in two Mux one for pucKey[0] + // we have to do this in two Mux one for pubKey[0] // and one for pubKey[1] component newSlPubKey0Mux = Mux1(); newSlPubKey0Mux.s <== messageValidator.isValid; diff --git a/circuits/ts/__tests__/ProcessMessages.test.ts b/circuits/ts/__tests__/ProcessMessages.test.ts index 6959fe5a0b..f5ed57ce5d 100644 --- a/circuits/ts/__tests__/ProcessMessages.test.ts +++ b/circuits/ts/__tests__/ProcessMessages.test.ts @@ -2,7 +2,7 @@ import { expect } from "chai"; import { type WitnessTester } from "circomkit"; import { MaciState, Poll, packProcessMessageSmallVals, STATE_TREE_ARITY } from "maci-core"; import { hash5, IncrementalQuinTree, NOTHING_UP_MY_SLEEVE, AccQueue } from "maci-crypto"; -import { PrivKey, Keypair, PCommand, Message, Ballot } from "maci-domainobjs"; +import { PrivKey, Keypair, PCommand, Message, Ballot, PubKey } from "maci-domainobjs"; import { IProcessMessagesInputs } from "../types"; @@ -76,7 +76,7 @@ describe("ProcessMessage circuit", function test() { describe("1 user, 2 messages", () => { const maciState = new MaciState(STATE_TREE_DEPTH); const voteWeight = BigInt(9); - const voteOptionIndex = BigInt(0); + const voteOptionIndex = BigInt(1); let stateIndex: bigint; let pollId: bigint; let poll: Poll; @@ -101,6 +101,26 @@ describe("ProcessMessage circuit", function test() { poll = maciState.polls.get(pollId)!; poll.updatePoll(BigInt(maciState.stateLeaves.length)); + const nothing = new Message(1n, [ + 8370432830353022751713833565135785980866757267633941821328460903436894336785n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + ]); + + const encP = new PubKey([ + 10457101036533406547632367118273992217979173478358440826365724437999023779287n, + 19824078218392094440610104313265183977899662750282163392862422243483260492317n, + ]); + + poll.publishMessage(nothing, encP); + // First command (valid) const command = new PCommand( stateIndex, // BigInt(1), @@ -144,6 +164,7 @@ describe("ProcessMessage circuit", function test() { STATE_TREE_ARITY, NOTHING_UP_MY_SLEEVE, ); + accumulatorQueue.enqueue(nothing.hash(encP)); accumulatorQueue.enqueue(message.hash(ecdhKeypair.pubKey)); accumulatorQueue.enqueue(message2.hash(ecdhKeypair2.pubKey)); accumulatorQueue.mergeSubRoots(0); @@ -187,7 +208,7 @@ describe("ProcessMessage circuit", function test() { BigInt(maxValues.maxVoteOptions), BigInt(poll.maciStateRef.numSignUps), 0, - 2, + 3, ); // Test the ProcessMessagesInputHasher circuit @@ -554,7 +575,7 @@ describe("ProcessMessage circuit", function test() { // Second batch is not a full batch const numMessages = messageBatchSize * NUM_BATCHES - 1; - for (let i = 0; i < numMessages; i += 1) { + for (let i = 0; i < 6; i += 1) { const command = new PCommand( BigInt(index), userKeypair.pubKey, @@ -572,7 +593,7 @@ describe("ProcessMessage circuit", function test() { selectedPoll?.publishMessage(message, ecdhKeypair.pubKey); } - for (let i = 0; i < NUM_BATCHES; i += 1) { + for (let i = 0; i < 2; i += 1) { const inputs = selectedPoll?.processMessages(id) as unknown as IProcessMessagesInputs; // eslint-disable-next-line no-await-in-loop const witness = await circuit.calculateWitness(inputs); diff --git a/cli/tests/e2e/e2e.test.ts b/cli/tests/e2e/e2e.test.ts index 3cb22d2e87..dae2712f6f 100644 --- a/cli/tests/e2e/e2e.test.ts +++ b/cli/tests/e2e/e2e.test.ts @@ -54,7 +54,7 @@ import { cleanVanilla, isArm } from "../utils"; /** Test scenarios: 1 signup, 1 message - 4 signups, 6 messages + 4 signups, 8 messages 5 signups, 1 message 8 signups, 10 messages 4 signups, 4 messages @@ -195,7 +195,7 @@ describe("e2e tests", function test() { }); }); - describe("4 signups, 6 messages", () => { + describe("4 signups, 8 messages", () => { after(() => { cleanVanilla(); }); @@ -217,7 +217,31 @@ describe("e2e tests", function test() { } }); - it("should publish six messages", async () => { + it("should publish eight messages", async () => { + await publish({ + pubkey: users[0].pubKey.serialize(), + stateIndex: 1n, + voteOptionIndex: 0n, + nonce: 2n, + pollId: 0n, + newVoteWeight: 4n, + maciContractAddress: maciAddresses.maciAddress, + salt: genRandomSalt(), + privateKey: users[0].privKey.serialize(), + signer, + }); + await publish({ + pubkey: users[0].pubKey.serialize(), + stateIndex: 1n, + voteOptionIndex: 0n, + nonce: 2n, + pollId: 0n, + newVoteWeight: 3n, + maciContractAddress: maciAddresses.maciAddress, + salt: genRandomSalt(), + privateKey: users[0].privKey.serialize(), + signer, + }); await publish({ pubkey: users[0].pubKey.serialize(), stateIndex: 1n, @@ -233,7 +257,7 @@ describe("e2e tests", function test() { await publish({ pubkey: users[1].pubKey.serialize(), stateIndex: 2n, - voteOptionIndex: 0n, + voteOptionIndex: 2n, nonce: 1n, pollId: 0n, newVoteWeight: 9n, @@ -245,7 +269,7 @@ describe("e2e tests", function test() { await publish({ pubkey: users[2].pubKey.serialize(), stateIndex: 3n, - voteOptionIndex: 0n, + voteOptionIndex: 2n, nonce: 1n, pollId: 0n, newVoteWeight: 9n, @@ -257,10 +281,10 @@ describe("e2e tests", function test() { await publish({ pubkey: users[3].pubKey.serialize(), stateIndex: 4n, - voteOptionIndex: 0n, - nonce: 1n, + voteOptionIndex: 2n, + nonce: 3n, pollId: 0n, - newVoteWeight: 9n, + newVoteWeight: 3n, maciContractAddress: maciAddresses.maciAddress, salt: genRandomSalt(), privateKey: users[3].privKey.serialize(), @@ -269,10 +293,10 @@ describe("e2e tests", function test() { await publish({ pubkey: users[3].pubKey.serialize(), stateIndex: 4n, - voteOptionIndex: 0n, - nonce: 1n, + voteOptionIndex: 2n, + nonce: 2n, pollId: 0n, - newVoteWeight: 9n, + newVoteWeight: 2n, maciContractAddress: maciAddresses.maciAddress, salt: genRandomSalt(), privateKey: users[3].privKey.serialize(), @@ -281,7 +305,7 @@ describe("e2e tests", function test() { await publish({ pubkey: users[3].pubKey.serialize(), stateIndex: 4n, - voteOptionIndex: 0n, + voteOptionIndex: 1n, nonce: 1n, pollId: 0n, newVoteWeight: 9n, diff --git a/cli/ts/commands/genProofs.ts b/cli/ts/commands/genProofs.ts index d92ea3f982..603b507ede 100644 --- a/cli/ts/commands/genProofs.ts +++ b/cli/ts/commands/genProofs.ts @@ -279,6 +279,7 @@ export const genProofs = async ({ while (poll.hasUnprocessedMessages()) { // process messages in batches const circuitInputs = poll.processMessages(pollId, useQuadraticVoting, quiet) as unknown as CircuitInputs; + try { // generate the proof for this batch // eslint-disable-next-line no-await-in-loop @@ -290,11 +291,12 @@ export const genProofs = async ({ witnessExePath: processWitgen, wasmPath: processWasm, }); + // verify it // eslint-disable-next-line no-await-in-loop const isValid = await verifyProof(r.publicSignals, r.proof, processVk); if (!isValid) { - logError("Error: generated an invalid proof"); + throw new Error("Generated an invalid proof"); } const thisProof = { diff --git a/core/ts/Poll.ts b/core/ts/Poll.ts index 7d8dfddf69..526d85d7d9 100644 --- a/core/ts/Poll.ts +++ b/core/ts/Poll.ts @@ -554,21 +554,62 @@ export class Poll implements IPoll { console.log(`Error at message index ${idx} - ${e.message}`); } - // Since the command is invalid, use a blank state leaf - currentStateLeaves.unshift(this.stateLeaves[0].copy()); - currentStateLeavesPathElements.unshift(this.stateTree!.genProof(0).pathElements); - // since the command is invalid we use the blank ballot - currentBallots.unshift(this.ballots[0].copy()); - currentBallotsPathElements.unshift(this.ballotTree!.genProof(0).pathElements); - - // Since the command is invalid, we use a zero vote weight - currentVoteWeights.unshift(this.ballots[0].votes[0]); - - // create a new quinary tree and add an empty vote - const vt = new IncrementalQuinTree(this.treeDepths.voteOptionTreeDepth, 0n, STATE_TREE_ARITY, hash5); - vt.insert(this.ballots[0].votes[0]); - // get the path elements for this empty vote weight leaf - currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); + // @note we want to send the correct state leaf to the circuit + // even if a message is invalid + // this way if a message is invalid we can still generate a proof of processing + // we also want to prevent a DoS attack by a voter + // which sends a message that when force decrypted on the circuit + // results in a valid state index thus forcing the circuit to look + // for a valid state leaf, and failing to generate a proof + + // gen shared key + const sharedKey = Keypair.genEcdhSharedKey(this.coordinatorKeypair.privKey, encPubKey); + + // force decrypt it + const { command } = PCommand.decrypt(message, sharedKey, true); + + // cache state leaf index + const stateLeafIndex = command.stateIndex; + + // if the state leaf index is valid then use it + if (stateLeafIndex < this.stateLeaves.length) { + currentStateLeaves.unshift(this.stateLeaves[Number(stateLeafIndex)].copy()); + currentStateLeavesPathElements.unshift(this.stateTree!.genProof(Number(stateLeafIndex)).pathElements); + + // copy the ballot + const ballot = this.ballots[Number(stateLeafIndex)].copy(); + currentBallots.unshift(ballot); + currentBallotsPathElements.unshift(this.ballotTree!.genProof(Number(stateLeafIndex)).pathElements); + + // add the first vote of this ballot + currentVoteWeights.unshift(ballot.votes[0]); + + // create a new quinary tree and add all votes we have so far + const vt = new IncrementalQuinTree(this.treeDepths.voteOptionTreeDepth, 0n, STATE_TREE_ARITY, hash5); + + // fill the vote option tree with the votes we have so far + for (let j = 0; j < this.ballots[0].votes.length; j += 1) { + vt.insert(ballot.votes[j]); + } + + // get the path elements for the first vote leaf + currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); + } else { + // just use state leaf index 0 + currentStateLeaves.unshift(this.stateLeaves[0].copy()); + currentStateLeavesPathElements.unshift(this.stateTree!.genProof(0).pathElements); + currentBallots.unshift(this.ballots[0].copy()); + currentBallotsPathElements.unshift(this.ballotTree!.genProof(0).pathElements); + + // Since the command is invalid, we use a zero vote weight + currentVoteWeights.unshift(this.ballots[0].votes[0]); + + // create a new quinary tree and add an empty vote + const vt = new IncrementalQuinTree(this.treeDepths.voteOptionTreeDepth, 0n, STATE_TREE_ARITY, hash5); + vt.insert(this.ballots[0].votes[0]); + // get the path elements for this empty vote weight leaf + currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); + } } else { throw e; } @@ -704,13 +745,33 @@ export class Poll implements IPoll { // fill the msgs array with a copy of the messages we have // plus empty messages to fill the batch + + // @note create a message with state index 0 to add as padding + // this way the message will look for state leaf 0 + // and no effect will take place + + // create a random key + const key = new Keypair(); + // gen ecdh key + const ecdh = Keypair.genEcdhSharedKey(key.privKey, this.coordinatorKeypair.pubKey); + // create an empty command with state index 0n + const emptyCommand = new PCommand(0n, key.pubKey, 0n, 0n, 0n, 0n, 0n); + + // encrypt it + const msg = emptyCommand.encrypt(emptyCommand.sign(key.privKey), ecdh); + + // copy the messages to a new array let msgs = this.messages.map((x) => x.asCircuitInputs()); + + // pad with our state index 0 message while (msgs.length % messageBatchSize > 0) { - msgs.push(msgs[msgs.length - 1]); + msgs.push(msg.asCircuitInputs()); } + // we only take the messages we need for this batch msgs = msgs.slice(index, index + messageBatchSize); + // insert zero value in the message tree as padding while (this.messageTree.nextIndex < index + messageBatchSize) { this.messageTree.insert(this.messageTree.zeroValue); } @@ -718,6 +779,7 @@ export class Poll implements IPoll { // generate the path to the subroot of the message tree for this batch const messageSubrootPath = this.messageTree.genSubrootProof(index, index + messageBatchSize); + // verify it assert(this.messageTree.verifyProof(messageSubrootPath), "The message subroot path is invalid"); // validate that the batch index is correct, if not fix it @@ -730,11 +792,13 @@ export class Poll implements IPoll { // copy the public keys, pad the array with the last keys if needed let encPubKeys = this.encPubKeys.map((x) => x.copy()); while (encPubKeys.length % messageBatchSize > 0) { - encPubKeys.push(encPubKeys[encPubKeys.length - 1]); + // pad with the public key used to encrypt the message with state index 0 (padding) + encPubKeys.push(key.pubKey.copy()); } // then take the ones part of this batch encPubKeys = encPubKeys.slice(index, index + messageBatchSize); + // cache tree roots const msgRoot = this.messageTree.root; const currentStateRoot = this.stateTree!.root; const currentBallotRoot = this.ballotTree!.root; diff --git a/crypto/ts/index.ts b/crypto/ts/index.ts index 060b5f44e9..45b1f5fa04 100644 --- a/crypto/ts/index.ts +++ b/crypto/ts/index.ts @@ -23,7 +23,7 @@ export { G1Point, G2Point, genRandomBabyJubValue } from "./babyjub"; export { sha256Hash, hashLeftRight, hashN, hash2, hash3, hash4, hash5, hash13, hashOne } from "./hashing"; -export { poseidonDecrypt, poseidonEncrypt } from "@zk-kit/poseidon-cipher"; +export { poseidonDecrypt, poseidonDecryptWithoutCheck, poseidonEncrypt } from "@zk-kit/poseidon-cipher"; export { verifySignature, signMessage as sign } from "@zk-kit/eddsa-poseidon"; diff --git a/domainobjs/ts/commands/PCommand.ts b/domainobjs/ts/commands/PCommand.ts index b7798596f2..99194fcfb2 100644 --- a/domainobjs/ts/commands/PCommand.ts +++ b/domainobjs/ts/commands/PCommand.ts @@ -9,6 +9,7 @@ import { type Ciphertext, type EcdhSharedKey, type Point, + poseidonDecryptWithoutCheck, } from "maci-crypto"; import assert from "assert"; @@ -172,11 +173,16 @@ export class PCommand implements ICommand { /** * Decrypts a Message to produce a Command. + * @dev You can force decrypt the message by setting `force` to true. + * This is useful in case you don't want an invalid message to throw an error. * @param {Message} message - the message to decrypt * @param {EcdhSharedKey} sharedKey - the shared key to use for decryption + * @param {boolean} force - whether to force decryption or not */ - static decrypt = (message: Message, sharedKey: EcdhSharedKey): IDecryptMessage => { - const decrypted = poseidonDecrypt(message.data, sharedKey, BigInt(0), 7); + static decrypt = (message: Message, sharedKey: EcdhSharedKey, force = false): IDecryptMessage => { + const decrypted = force + ? poseidonDecryptWithoutCheck(message.data, sharedKey, BigInt(0), 7) + : poseidonDecrypt(message.data, sharedKey, BigInt(0), 7); const p = BigInt(decrypted[0].toString()); From 22e091d2941b5e86ecd7df2f228bd0e10b6c4c47 Mon Sep 17 00:00:00 2001 From: ctrlc03 <93448202+ctrlc03@users.noreply.github.com> Date: Thu, 15 Feb 2024 11:46:59 +0000 Subject: [PATCH 2/2] fix(circuits): fix coordinator censoring by passing currentVoteWeight = 0 Prevent coordinator censoring a valid second message by passing the currentVoteWeight equal to a number which would result in not enough voice credits in the circuit --- circuits/circom/messageValidator.circom | 1 + circuits/circom/processMessages.circom | 10 +- circuits/circom/processMessagesNonQv.circom | 10 +- .../stateLeafAndBallotTransformerNonQv.circom | 2 +- circuits/ts/__tests__/ProcessMessages.test.ts | 388 +++++++++++++++++- core/ts/Poll.ts | 52 ++- 6 files changed, 447 insertions(+), 16 deletions(-) diff --git a/circuits/circom/messageValidator.circom b/circuits/circom/messageValidator.circom index fae47b6a87..45fa31f3f9 100644 --- a/circuits/circom/messageValidator.circom +++ b/circuits/circom/messageValidator.circom @@ -20,6 +20,7 @@ template MessageValidator() { validStateLeafIndex.in[0] <== stateTreeIndex; validStateLeafIndex.in[1] <== numSignUps; + // @todo check if we need this if we do the check inside processOne // b) Whether the max vote option tree index is correct signal input voteOptionIndex; signal input maxVoteOptions; diff --git a/circuits/circom/processMessages.circom b/circuits/circom/processMessages.circom index d1c2f2db2c..f2db72eca6 100644 --- a/circuits/circom/processMessages.circom +++ b/circuits/circom/processMessages.circom @@ -613,8 +613,16 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { isMessageValid.in[0] <== bothValid; isMessageValid.in[1] <== transformer.isValid + enoughVoiceCredits.out; + // check that the vote option index is < maxVoteOptions (0-indexed) + component validVoteOptionIndex = SafeLessThan(N_BITS); + validVoteOptionIndex.in[0] <== cmdVoteOptionIndex; + validVoteOptionIndex.in[1] <== maxVoteOptions; + + // @note pick the correct vote option index based on whether the index is < max vote options + // @todo can probably add one output to messageValidator and take from there + // or maybe we can remove altogther from messageValidator so we don't double check this component cmdVoteOptionIndexMux = Mux1(); - cmdVoteOptionIndexMux.s <== isMessageValid.out; + cmdVoteOptionIndexMux.s <== validVoteOptionIndex.out; cmdVoteOptionIndexMux.c[0] <== 0; cmdVoteOptionIndexMux.c[1] <== cmdVoteOptionIndex; diff --git a/circuits/circom/processMessagesNonQv.circom b/circuits/circom/processMessagesNonQv.circom index 25fe5e36b9..56d44d8e28 100644 --- a/circuits/circom/processMessagesNonQv.circom +++ b/circuits/circom/processMessagesNonQv.circom @@ -535,8 +535,16 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) { isMessageValid.in[0] <== bothValid; isMessageValid.in[1] <== transformer.isValid + enoughVoiceCredits.out; + // check that the vote option index is < maxVoteOptions (0-indexed) + component validVoteOptionIndex = SafeLessThan(N_BITS); + validVoteOptionIndex.in[0] <== cmdVoteOptionIndex; + validVoteOptionIndex.in[1] <== maxVoteOptions; + + // @note pick the correct vote option index based on whether the index is < max vote options + // @todo can probably add one output to messageValidator and take from there + // or maybe we can remove altogther from messageValidator so we don't double check this component cmdVoteOptionIndexMux = Mux1(); - cmdVoteOptionIndexMux.s <== isMessageValid.out; + cmdVoteOptionIndexMux.s <== validVoteOptionIndex.out; cmdVoteOptionIndexMux.c[0] <== 0; cmdVoteOptionIndexMux.c[1] <== cmdVoteOptionIndex; diff --git a/circuits/circom/stateLeafAndBallotTransformerNonQv.circom b/circuits/circom/stateLeafAndBallotTransformerNonQv.circom index 3a3a8b3e6f..0055c84a14 100644 --- a/circuits/circom/stateLeafAndBallotTransformerNonQv.circom +++ b/circuits/circom/stateLeafAndBallotTransformerNonQv.circom @@ -86,7 +86,7 @@ template StateLeafAndBallotTransformerNonQv() { messageValidator.voteWeight <== cmdNewVoteWeight; // if the message is valid then we swap out the public key - // we have to do this in two Mux one for pucKey[0] + // we have to do this in two Mux one for pubKey[0] // and one for pubKey[1] component newSlPubKey0Mux = Mux1(); newSlPubKey0Mux.s <== messageValidator.isValid; diff --git a/circuits/ts/__tests__/ProcessMessages.test.ts b/circuits/ts/__tests__/ProcessMessages.test.ts index f5ed57ce5d..e1b9f6b440 100644 --- a/circuits/ts/__tests__/ProcessMessages.test.ts +++ b/circuits/ts/__tests__/ProcessMessages.test.ts @@ -575,7 +575,7 @@ describe("ProcessMessage circuit", function test() { // Second batch is not a full batch const numMessages = messageBatchSize * NUM_BATCHES - 1; - for (let i = 0; i < 6; i += 1) { + for (let i = 0; i < numMessages; i += 1) { const command = new PCommand( BigInt(index), userKeypair.pubKey, @@ -674,4 +674,390 @@ describe("ProcessMessage circuit", function test() { await circuit.expectConstraintPass(witness); }); }); + + describe("1 user, 2 messages", () => { + const maciState = new MaciState(STATE_TREE_DEPTH); + const voteOptionIndex = 1n; + let stateIndex: bigint; + let pollId: bigint; + let poll: Poll; + const messages: Message[] = []; + const commands: PCommand[] = []; + + before(() => { + // Sign up and publish + const userKeypair = new Keypair(new PrivKey(BigInt(1))); + stateIndex = BigInt( + maciState.signUp(userKeypair.pubKey, voiceCreditBalance, BigInt(Math.floor(Date.now() / 1000))), + ); + + pollId = maciState.deployPoll( + BigInt(Math.floor(Date.now() / 1000) + duration), + maxValues, + treeDepths, + messageBatchSize, + coordinatorKeypair, + ); + + poll = maciState.polls.get(pollId)!; + poll.updatePoll(BigInt(maciState.stateLeaves.length)); + + const nothing = new Message(1n, [ + 8370432830353022751713833565135785980866757267633941821328460903436894336785n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + ]); + + const encP = new PubKey([ + 10457101036533406547632367118273992217979173478358440826365724437999023779287n, + 19824078218392094440610104313265183977899662750282163392862422243483260492317n, + ]); + + poll.publishMessage(nothing, encP); + + // First command (valid) + const command = new PCommand( + stateIndex, // BigInt(1), + userKeypair.pubKey, + 1n, // voteOptionIndex, + 2n, // vote weight + 2n, // nonce + pollId, + ); + + const signature = command.sign(userKeypair.privKey); + + const ecdhKeypair = new Keypair(); + const sharedKey = Keypair.genEcdhSharedKey(ecdhKeypair.privKey, coordinatorKeypair.pubKey); + const message = command.encrypt(signature, sharedKey); + messages.push(message); + commands.push(command); + + poll.publishMessage(message, ecdhKeypair.pubKey); + + // Second command (valid) + const command2 = new PCommand( + stateIndex, + userKeypair.pubKey, + voteOptionIndex, // voteOptionIndex, + 9n, // vote weight 9 ** 2 = 81 + 1n, // nonce + pollId, + ); + const signature2 = command2.sign(userKeypair.privKey); + + const ecdhKeypair2 = new Keypair(); + const sharedKey2 = Keypair.genEcdhSharedKey(ecdhKeypair2.privKey, coordinatorKeypair.pubKey); + const message2 = command2.encrypt(signature2, sharedKey2); + messages.push(message2); + commands.push(command2); + poll.publishMessage(message2, ecdhKeypair2.pubKey); + }); + + it("should produce the correct state root and ballot root", async () => { + // The current roots + const emptyBallot = new Ballot(poll.maxValues.maxVoteOptions, poll.treeDepths.voteOptionTreeDepth); + const emptyBallotHash = emptyBallot.hash(); + const ballotTree = new IncrementalQuinTree(STATE_TREE_DEPTH, emptyBallot.hash(), STATE_TREE_ARITY, hash5); + + ballotTree.insert(emptyBallot.hash()); + + poll.stateLeaves.forEach(() => { + ballotTree.insert(emptyBallotHash); + }); + + const currentStateRoot = poll.stateTree?.root; + const currentBallotRoot = ballotTree.root; + + const inputs = poll.processMessages(pollId) as unknown as IProcessMessagesInputs; + + // Calculate the witness + const witness = await circuit.calculateWitness(inputs); + await circuit.expectConstraintPass(witness); + + // The new roots, which should differ, since at least one of the + // messages modified a Ballot or State Leaf + const newStateRoot = poll.stateTree?.root; + const newBallotRoot = poll.ballotTree?.root; + + expect(newStateRoot?.toString()).not.to.be.eq(currentStateRoot?.toString()); + expect(newBallotRoot?.toString()).not.to.be.eq(currentBallotRoot.toString()); + }); + }); + + describe("1 user, 2 messages in different batches", () => { + const maciState = new MaciState(STATE_TREE_DEPTH); + const voteOptionIndex = 1n; + let stateIndex: bigint; + let pollId: bigint; + let poll: Poll; + const messages: Message[] = []; + const commands: PCommand[] = []; + + before(() => { + // Sign up and publish + const userKeypair = new Keypair(new PrivKey(BigInt(1))); + stateIndex = BigInt( + maciState.signUp(userKeypair.pubKey, voiceCreditBalance, BigInt(Math.floor(Date.now() / 1000))), + ); + + pollId = maciState.deployPoll( + BigInt(Math.floor(Date.now() / 1000) + duration), + maxValues, + treeDepths, + messageBatchSize, + coordinatorKeypair, + ); + + poll = maciState.polls.get(pollId)!; + poll.updatePoll(BigInt(maciState.stateLeaves.length)); + + const nothing = new Message(1n, [ + 8370432830353022751713833565135785980866757267633941821328460903436894336785n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + ]); + + const encP = new PubKey([ + 10457101036533406547632367118273992217979173478358440826365724437999023779287n, + 19824078218392094440610104313265183977899662750282163392862422243483260492317n, + ]); + + poll.publishMessage(nothing, encP); + + // First command (valid) + const command = new PCommand( + stateIndex, // BigInt(1), + userKeypair.pubKey, + 1n, // voteOptionIndex, + 2n, // vote weight + 2n, // nonce + pollId, + ); + + const signature = command.sign(userKeypair.privKey); + + const ecdhKeypair = new Keypair(); + const sharedKey = Keypair.genEcdhSharedKey(ecdhKeypair.privKey, coordinatorKeypair.pubKey); + const message = command.encrypt(signature, sharedKey); + messages.push(message); + commands.push(command); + + poll.publishMessage(message, ecdhKeypair.pubKey); + + // fill the batch with nothing messages + for (let i = 0; i < messageBatchSize - 1; i += 1) { + poll.publishMessage(nothing, encP); + } + + // Second command (valid) in second batch (which is first due to reverse processing) + const command2 = new PCommand( + stateIndex, + userKeypair.pubKey, + voteOptionIndex, // voteOptionIndex, + 9n, // vote weight 9 ** 2 = 81 + 1n, // nonce + pollId, + ); + const signature2 = command2.sign(userKeypair.privKey); + + const ecdhKeypair2 = new Keypair(); + const sharedKey2 = Keypair.genEcdhSharedKey(ecdhKeypair2.privKey, coordinatorKeypair.pubKey); + const message2 = command2.encrypt(signature2, sharedKey2); + messages.push(message2); + commands.push(command2); + poll.publishMessage(message2, ecdhKeypair2.pubKey); + }); + + it("should produce the correct state root and ballot root", async () => { + // The current roots + const emptyBallot = new Ballot(poll.maxValues.maxVoteOptions, poll.treeDepths.voteOptionTreeDepth); + const emptyBallotHash = emptyBallot.hash(); + const ballotTree = new IncrementalQuinTree(STATE_TREE_DEPTH, emptyBallot.hash(), STATE_TREE_ARITY, hash5); + + ballotTree.insert(emptyBallot.hash()); + + poll.stateLeaves.forEach(() => { + ballotTree.insert(emptyBallotHash); + }); + + while (poll.hasUnprocessedMessages()) { + const currentStateRoot = poll.stateTree?.root; + const currentBallotRoot = ballotTree.root; + const inputs = poll.processMessages(pollId) as unknown as IProcessMessagesInputs; + + // Calculate the witness + // eslint-disable-next-line no-await-in-loop + const witness = await circuit.calculateWitness(inputs); + // eslint-disable-next-line no-await-in-loop + await circuit.expectConstraintPass(witness); + + // The new roots, which should differ, since at least one of the + // messages modified a Ballot or State Leaf + const newStateRoot = poll.stateTree?.root; + const newBallotRoot = poll.ballotTree?.root; + + expect(newStateRoot?.toString()).not.to.be.eq(currentStateRoot?.toString()); + expect(newBallotRoot?.toString()).not.to.be.eq(currentBallotRoot.toString()); + } + }); + }); + + describe("1 user, 3 messages in different batches", () => { + const maciState = new MaciState(STATE_TREE_DEPTH); + const voteOptionIndex = 1n; + let stateIndex: bigint; + let pollId: bigint; + let poll: Poll; + const messages: Message[] = []; + const commands: PCommand[] = []; + + before(() => { + // Sign up and publish + const userKeypair = new Keypair(new PrivKey(BigInt(1))); + stateIndex = BigInt( + maciState.signUp(userKeypair.pubKey, voiceCreditBalance, BigInt(Math.floor(Date.now() / 1000))), + ); + + pollId = maciState.deployPoll( + BigInt(Math.floor(Date.now() / 1000) + duration), + maxValues, + treeDepths, + messageBatchSize, + coordinatorKeypair, + ); + + poll = maciState.polls.get(pollId)!; + poll.updatePoll(BigInt(maciState.stateLeaves.length)); + + const nothing = new Message(1n, [ + 8370432830353022751713833565135785980866757267633941821328460903436894336785n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + ]); + + const encP = new PubKey([ + 10457101036533406547632367118273992217979173478358440826365724437999023779287n, + 19824078218392094440610104313265183977899662750282163392862422243483260492317n, + ]); + + poll.publishMessage(nothing, encP); + + const commandFinal = new PCommand( + stateIndex, // BigInt(1), + userKeypair.pubKey, + 1n, // voteOptionIndex, + 1n, // vote weight + 3n, // nonce + pollId, + ); + + const signatureFinal = commandFinal.sign(userKeypair.privKey); + + const ecdhKeypairFinal = new Keypair(); + const sharedKeyFinal = Keypair.genEcdhSharedKey(ecdhKeypairFinal.privKey, coordinatorKeypair.pubKey); + const messageFinal = commandFinal.encrypt(signatureFinal, sharedKeyFinal); + messages.push(messageFinal); + commands.push(commandFinal); + + poll.publishMessage(messageFinal, ecdhKeypairFinal.pubKey); + + // First command (valid) + const command = new PCommand( + stateIndex, // BigInt(1), + userKeypair.pubKey, + 1n, // voteOptionIndex, + 2n, // vote weight + 2n, // nonce + pollId, + ); + + const signature = command.sign(userKeypair.privKey); + + const ecdhKeypair = new Keypair(); + const sharedKey = Keypair.genEcdhSharedKey(ecdhKeypair.privKey, coordinatorKeypair.pubKey); + const message = command.encrypt(signature, sharedKey); + messages.push(message); + commands.push(command); + + poll.publishMessage(message, ecdhKeypair.pubKey); + + // fill the batch with nothing messages + for (let i = 0; i < messageBatchSize - 1; i += 1) { + poll.publishMessage(nothing, encP); + } + + // Second command (valid) in second batch (which is first due to reverse processing) + const command2 = new PCommand( + stateIndex, + userKeypair.pubKey, + voteOptionIndex, // voteOptionIndex, + 9n, // vote weight 9 ** 2 = 81 + 1n, // nonce + pollId, + ); + const signature2 = command2.sign(userKeypair.privKey); + + const ecdhKeypair2 = new Keypair(); + const sharedKey2 = Keypair.genEcdhSharedKey(ecdhKeypair2.privKey, coordinatorKeypair.pubKey); + const message2 = command2.encrypt(signature2, sharedKey2); + messages.push(message2); + commands.push(command2); + poll.publishMessage(message2, ecdhKeypair2.pubKey); + }); + + it("should produce the correct state root and ballot root", async () => { + // The current roots + const emptyBallot = new Ballot(poll.maxValues.maxVoteOptions, poll.treeDepths.voteOptionTreeDepth); + const emptyBallotHash = emptyBallot.hash(); + const ballotTree = new IncrementalQuinTree(STATE_TREE_DEPTH, emptyBallot.hash(), STATE_TREE_ARITY, hash5); + + ballotTree.insert(emptyBallot.hash()); + + poll.stateLeaves.forEach(() => { + ballotTree.insert(emptyBallotHash); + }); + + while (poll.hasUnprocessedMessages()) { + const currentStateRoot = poll.stateTree?.root; + const currentBallotRoot = ballotTree.root; + const inputs = poll.processMessages(pollId) as unknown as IProcessMessagesInputs; + + // Calculate the witness + // eslint-disable-next-line no-await-in-loop + const witness = await circuit.calculateWitness(inputs); + // eslint-disable-next-line no-await-in-loop + await circuit.expectConstraintPass(witness); + + // The new roots, which should differ, since at least one of the + // messages modified a Ballot or State Leaf + const newStateRoot = poll.stateTree?.root; + const newBallotRoot = poll.ballotTree?.root; + + expect(newStateRoot?.toString()).not.to.be.eq(currentStateRoot?.toString()); + expect(newBallotRoot?.toString()).not.to.be.eq(currentBallotRoot.toString()); + } + }); + }); }); diff --git a/core/ts/Poll.ts b/core/ts/Poll.ts index 526d85d7d9..a63aee8998 100644 --- a/core/ts/Poll.ts +++ b/core/ts/Poll.ts @@ -581,19 +581,47 @@ export class Poll implements IPoll { currentBallots.unshift(ballot); currentBallotsPathElements.unshift(this.ballotTree!.genProof(Number(stateLeafIndex)).pathElements); - // add the first vote of this ballot - currentVoteWeights.unshift(ballot.votes[0]); - - // create a new quinary tree and add all votes we have so far - const vt = new IncrementalQuinTree(this.treeDepths.voteOptionTreeDepth, 0n, STATE_TREE_ARITY, hash5); - - // fill the vote option tree with the votes we have so far - for (let j = 0; j < this.ballots[0].votes.length; j += 1) { - vt.insert(ballot.votes[j]); + // @note we check that command.voteOptionIndex is valid so < maxVoteOptions + // this might be unnecessary but we do it to prevent a possible DoS attack + // from voters who could potentially encrypt a message in such as way that + // when decrypted it results in a valid state leaf index but an invalid vote option index + if (command.voteOptionIndex < this.maxValues.maxVoteOptions) { + currentVoteWeights.unshift(ballot.votes[Number(command.voteOptionIndex)]); + + // create a new quinary tree and add all votes we have so far + const vt = new IncrementalQuinTree( + this.treeDepths.voteOptionTreeDepth, + 0n, + STATE_TREE_ARITY, + hash5, + ); + + // fill the vote option tree with the votes we have so far + for (let j = 0; j < this.ballots[0].votes.length; j += 1) { + vt.insert(ballot.votes[j]); + } + + // get the path elements for the first vote leaf + currentVoteWeightsPathElements.unshift(vt.genProof(Number(command.voteOptionIndex)).pathElements); + } else { + currentVoteWeights.unshift(ballot.votes[0]); + + // create a new quinary tree and add all votes we have so far + const vt = new IncrementalQuinTree( + this.treeDepths.voteOptionTreeDepth, + 0n, + STATE_TREE_ARITY, + hash5, + ); + + // fill the vote option tree with the votes we have so far + for (let j = 0; j < this.ballots[0].votes.length; j += 1) { + vt.insert(ballot.votes[j]); + } + + // get the path elements for the first vote leaf + currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); } - - // get the path elements for the first vote leaf - currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); } else { // just use state leaf index 0 currentStateLeaves.unshift(this.stateLeaves[0].copy());