From e37dcb191b16bf7f859b1fdca69e2af33f676eb7 Mon Sep 17 00:00:00 2001 From: ctrlc03 <93448202+ctrlc03@users.noreply.github.com> Date: Wed, 7 Feb 2024 11:21:54 +0000 Subject: [PATCH] fix(circuits): enforce use of stateIndex from message --- circuits/circom/processMessages.circom | 41 ++++++-- circuits/circom/processMessagesNonQv.circom | 30 ++++-- .../stateLeafAndBallotTransformer.circom | 2 +- circuits/ts/__tests__/ProcessMessages.test.ts | 31 +++++- cli/tests/e2e/e2e.test.ts | 48 ++++++--- cli/ts/commands/genProofs.ts | 4 +- core/ts/Poll.ts | 98 +++++++++++++++---- crypto/ts/index.ts | 2 +- domainobjs/ts/commands/PCommand.ts | 10 +- 9 files changed, 210 insertions(+), 56 deletions(-) diff --git a/circuits/circom/processMessages.circom b/circuits/circom/processMessages.circom index 9022a59121..d1c2f2db2c 100644 --- a/circuits/circom/processMessages.circom +++ b/circuits/circom/processMessages.circom @@ -48,6 +48,8 @@ template ProcessMessages( var STATE_LEAF_PUB_Y_IDX = 1; var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + + var N_BITS = 252; // Note that we sha256 hash some values from the contract, pass in the hash // as a public input, and pass in said values as private inputs. This saves @@ -277,6 +279,7 @@ template ProcessMessages( component processors[batchSize]; // topup type processor component processors2[batchSize]; + for (var i = batchSize - 1; i >= 0; i --) { // process it as vote type message processors[i] = ProcessOne(stateTreeDepth, voteOptionTreeDepth); @@ -349,6 +352,7 @@ template ProcessMessages( <== currentStateLeavesPathElements[i][j][k]; } } + // pick the correct result by msg type tmpStateRoot1[i] <== processors[i].newStateRoot * (2 - msgs[i][0]); tmpStateRoot2[i] <== processors2[i].newStateRoot * (msgs[i][0] - 1); @@ -378,6 +382,8 @@ template ProcessTopup(stateTreeDepth) { var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + var N_BITS = 252; + signal input msgType; signal input stateTreeIndex; signal input amount; @@ -395,9 +401,9 @@ template ProcessTopup(stateTreeDepth) { // msgType of topup command is 2 amt <== amount * (msgType - 1); index <== stateTreeIndex * (msgType - 1); - component validCreditBalance = LessEqThan(252); + component validCreditBalance = LessEqThan(N_BITS); // check stateIndex, if invalid index, set index and amount to zero - component validStateLeafIndex = LessEqThan(252); + component validStateLeafIndex = LessEqThan(N_BITS); validStateLeafIndex.in[0] <== index; validStateLeafIndex.in[1] <== numSignUps; @@ -462,6 +468,8 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + var N_BITS = 252; + signal input msgType; signal input numSignUps; signal input maxVoteOptions; @@ -526,7 +534,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { // ----------------------------------------------------------------------- // 2. If msgType = 0 and isValid is 0, generate indices for leaf 0 - // Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType + // Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType signal indexByType; signal tmpIndex1; signal tmpIndex2; @@ -534,14 +542,26 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { tmpIndex2 <== topupStateIndex * (msgType - 1); indexByType <== tmpIndex1 + tmpIndex2; - component stateIndexMux = Mux1(); - stateIndexMux.s <== transformer.isValid + msgType - 1; - stateIndexMux.c[0] <== 0; - stateIndexMux.c[1] <== indexByType; + // we can validate if the state index is within the numSignups + // if not, we use 0 + // this is because decryption of an invalid message + // might result in random packed vals + component validStateLeafIndex = SafeLessThan(N_BITS); + validStateLeafIndex.in[0] <== indexByType; + validStateLeafIndex.in[1] <== numSignUps; - component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth); - stateLeafPathIndices.in <== stateIndexMux.out; + // use a mux to pick the correct index + component indexMux = Mux1(); + indexMux.s <== validStateLeafIndex.out; + indexMux.c[0] <== 0; + indexMux.c[1] <== indexByType; + // @note that we expect a coordinator to send the state leaf corresponding to a message + // which specifies a valid state index. If this is not the case, the + // proof will fail to generate. + component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth); + stateLeafPathIndices.in <== indexMux.out; + // ----------------------------------------------------------------------- // 3. Verify that the original state leaf exists in the given state root component stateLeafQip = QuinTreeInclusionProof(stateTreeDepth); @@ -572,6 +592,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { ballotQip.path_elements[i][j] <== ballotPathElements[i][j]; } } + ballotQip.root === currentBallotRoot; // ----------------------------------------------------------------------- @@ -583,7 +604,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { b <== currentVoteWeight * currentVoteWeight; c <== cmdNewVoteWeight * cmdNewVoteWeight; - component enoughVoiceCredits = SafeGreaterEqThan(252); + component enoughVoiceCredits = SafeGreaterEqThan(N_BITS); enoughVoiceCredits.in[0] <== stateLeaf[STATE_LEAF_VOICE_CREDIT_BALANCE_IDX] + b; enoughVoiceCredits.in[1] <== c; diff --git a/circuits/circom/processMessagesNonQv.circom b/circuits/circom/processMessagesNonQv.circom index 82d114c43b..25fe5e36b9 100644 --- a/circuits/circom/processMessagesNonQv.circom +++ b/circuits/circom/processMessagesNonQv.circom @@ -49,6 +49,8 @@ template ProcessMessagesNonQv( var STATE_LEAF_PUB_Y_IDX = 1; var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + + var N_BITS = 252; // Note that we sha256 hash some values from the contract, pass in the hash // as a public input, and pass in said values as private inputs. This saves @@ -389,6 +391,8 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) { var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; + var N_BITS = 252; + signal input msgType; signal input numSignUps; signal input maxVoteOptions; @@ -461,13 +465,25 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) { tmpIndex2 <== topupStateIndex * (msgType - 1); indexByType <== tmpIndex1 + tmpIndex2; - component stateIndexMux = Mux1(); - stateIndexMux.s <== transformer.isValid + msgType - 1; - stateIndexMux.c[0] <== 0; - stateIndexMux.c[1] <== indexByType; - + // we can validate if the state index is within the numSignups + // if not, we use 0 + // this is because decryption of an invalid message + // might result in random packed vals + component validStateLeafIndex = SafeLessThan(N_BITS); + validStateLeafIndex.in[0] <== indexByType; + validStateLeafIndex.in[1] <== numSignUps; + + // use a mux to pick the correct index + component indexMux = Mux1(); + indexMux.s <== validStateLeafIndex.out; + indexMux.c[0] <== 0; + indexMux.c[1] <== indexByType; + + // @note that we expect a coordinator to send the state leaf corresponding to a message + // which specifies a valid state index. If this is not the case, the + // proof will fail to generate. component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth); - stateLeafPathIndices.in <== stateIndexMux.out; + stateLeafPathIndices.in <== indexMux.out; // ----------------------------------------------------------------------- // 3. Verify that the original state leaf exists in the given state root @@ -510,7 +526,7 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) { b <== currentVoteWeight; c <== cmdNewVoteWeight; - component enoughVoiceCredits = SafeGreaterEqThan(252); + component enoughVoiceCredits = SafeGreaterEqThan(N_BITS); enoughVoiceCredits.in[0] <== stateLeaf[STATE_LEAF_VOICE_CREDIT_BALANCE_IDX] + b; enoughVoiceCredits.in[1] <== c; diff --git a/circuits/circom/stateLeafAndBallotTransformer.circom b/circuits/circom/stateLeafAndBallotTransformer.circom index 609c688f1f..79dde80d33 100644 --- a/circuits/circom/stateLeafAndBallotTransformer.circom +++ b/circuits/circom/stateLeafAndBallotTransformer.circom @@ -86,7 +86,7 @@ template StateLeafAndBallotTransformer() { messageValidator.voteWeight <== cmdNewVoteWeight; // if the message is valid then we swap out the public key - // we have to do this in two Mux one for pucKey[0] + // we have to do this in two Mux one for pubKey[0] // and one for pubKey[1] component newSlPubKey0Mux = Mux1(); newSlPubKey0Mux.s <== messageValidator.isValid; diff --git a/circuits/ts/__tests__/ProcessMessages.test.ts b/circuits/ts/__tests__/ProcessMessages.test.ts index 6959fe5a0b..f5ed57ce5d 100644 --- a/circuits/ts/__tests__/ProcessMessages.test.ts +++ b/circuits/ts/__tests__/ProcessMessages.test.ts @@ -2,7 +2,7 @@ import { expect } from "chai"; import { type WitnessTester } from "circomkit"; import { MaciState, Poll, packProcessMessageSmallVals, STATE_TREE_ARITY } from "maci-core"; import { hash5, IncrementalQuinTree, NOTHING_UP_MY_SLEEVE, AccQueue } from "maci-crypto"; -import { PrivKey, Keypair, PCommand, Message, Ballot } from "maci-domainobjs"; +import { PrivKey, Keypair, PCommand, Message, Ballot, PubKey } from "maci-domainobjs"; import { IProcessMessagesInputs } from "../types"; @@ -76,7 +76,7 @@ describe("ProcessMessage circuit", function test() { describe("1 user, 2 messages", () => { const maciState = new MaciState(STATE_TREE_DEPTH); const voteWeight = BigInt(9); - const voteOptionIndex = BigInt(0); + const voteOptionIndex = BigInt(1); let stateIndex: bigint; let pollId: bigint; let poll: Poll; @@ -101,6 +101,26 @@ describe("ProcessMessage circuit", function test() { poll = maciState.polls.get(pollId)!; poll.updatePoll(BigInt(maciState.stateLeaves.length)); + const nothing = new Message(1n, [ + 8370432830353022751713833565135785980866757267633941821328460903436894336785n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + 0n, + ]); + + const encP = new PubKey([ + 10457101036533406547632367118273992217979173478358440826365724437999023779287n, + 19824078218392094440610104313265183977899662750282163392862422243483260492317n, + ]); + + poll.publishMessage(nothing, encP); + // First command (valid) const command = new PCommand( stateIndex, // BigInt(1), @@ -144,6 +164,7 @@ describe("ProcessMessage circuit", function test() { STATE_TREE_ARITY, NOTHING_UP_MY_SLEEVE, ); + accumulatorQueue.enqueue(nothing.hash(encP)); accumulatorQueue.enqueue(message.hash(ecdhKeypair.pubKey)); accumulatorQueue.enqueue(message2.hash(ecdhKeypair2.pubKey)); accumulatorQueue.mergeSubRoots(0); @@ -187,7 +208,7 @@ describe("ProcessMessage circuit", function test() { BigInt(maxValues.maxVoteOptions), BigInt(poll.maciStateRef.numSignUps), 0, - 2, + 3, ); // Test the ProcessMessagesInputHasher circuit @@ -554,7 +575,7 @@ describe("ProcessMessage circuit", function test() { // Second batch is not a full batch const numMessages = messageBatchSize * NUM_BATCHES - 1; - for (let i = 0; i < numMessages; i += 1) { + for (let i = 0; i < 6; i += 1) { const command = new PCommand( BigInt(index), userKeypair.pubKey, @@ -572,7 +593,7 @@ describe("ProcessMessage circuit", function test() { selectedPoll?.publishMessage(message, ecdhKeypair.pubKey); } - for (let i = 0; i < NUM_BATCHES; i += 1) { + for (let i = 0; i < 2; i += 1) { const inputs = selectedPoll?.processMessages(id) as unknown as IProcessMessagesInputs; // eslint-disable-next-line no-await-in-loop const witness = await circuit.calculateWitness(inputs); diff --git a/cli/tests/e2e/e2e.test.ts b/cli/tests/e2e/e2e.test.ts index 3cb22d2e87..dae2712f6f 100644 --- a/cli/tests/e2e/e2e.test.ts +++ b/cli/tests/e2e/e2e.test.ts @@ -54,7 +54,7 @@ import { cleanVanilla, isArm } from "../utils"; /** Test scenarios: 1 signup, 1 message - 4 signups, 6 messages + 4 signups, 8 messages 5 signups, 1 message 8 signups, 10 messages 4 signups, 4 messages @@ -195,7 +195,7 @@ describe("e2e tests", function test() { }); }); - describe("4 signups, 6 messages", () => { + describe("4 signups, 8 messages", () => { after(() => { cleanVanilla(); }); @@ -217,7 +217,31 @@ describe("e2e tests", function test() { } }); - it("should publish six messages", async () => { + it("should publish eight messages", async () => { + await publish({ + pubkey: users[0].pubKey.serialize(), + stateIndex: 1n, + voteOptionIndex: 0n, + nonce: 2n, + pollId: 0n, + newVoteWeight: 4n, + maciContractAddress: maciAddresses.maciAddress, + salt: genRandomSalt(), + privateKey: users[0].privKey.serialize(), + signer, + }); + await publish({ + pubkey: users[0].pubKey.serialize(), + stateIndex: 1n, + voteOptionIndex: 0n, + nonce: 2n, + pollId: 0n, + newVoteWeight: 3n, + maciContractAddress: maciAddresses.maciAddress, + salt: genRandomSalt(), + privateKey: users[0].privKey.serialize(), + signer, + }); await publish({ pubkey: users[0].pubKey.serialize(), stateIndex: 1n, @@ -233,7 +257,7 @@ describe("e2e tests", function test() { await publish({ pubkey: users[1].pubKey.serialize(), stateIndex: 2n, - voteOptionIndex: 0n, + voteOptionIndex: 2n, nonce: 1n, pollId: 0n, newVoteWeight: 9n, @@ -245,7 +269,7 @@ describe("e2e tests", function test() { await publish({ pubkey: users[2].pubKey.serialize(), stateIndex: 3n, - voteOptionIndex: 0n, + voteOptionIndex: 2n, nonce: 1n, pollId: 0n, newVoteWeight: 9n, @@ -257,10 +281,10 @@ describe("e2e tests", function test() { await publish({ pubkey: users[3].pubKey.serialize(), stateIndex: 4n, - voteOptionIndex: 0n, - nonce: 1n, + voteOptionIndex: 2n, + nonce: 3n, pollId: 0n, - newVoteWeight: 9n, + newVoteWeight: 3n, maciContractAddress: maciAddresses.maciAddress, salt: genRandomSalt(), privateKey: users[3].privKey.serialize(), @@ -269,10 +293,10 @@ describe("e2e tests", function test() { await publish({ pubkey: users[3].pubKey.serialize(), stateIndex: 4n, - voteOptionIndex: 0n, - nonce: 1n, + voteOptionIndex: 2n, + nonce: 2n, pollId: 0n, - newVoteWeight: 9n, + newVoteWeight: 2n, maciContractAddress: maciAddresses.maciAddress, salt: genRandomSalt(), privateKey: users[3].privKey.serialize(), @@ -281,7 +305,7 @@ describe("e2e tests", function test() { await publish({ pubkey: users[3].pubKey.serialize(), stateIndex: 4n, - voteOptionIndex: 0n, + voteOptionIndex: 1n, nonce: 1n, pollId: 0n, newVoteWeight: 9n, diff --git a/cli/ts/commands/genProofs.ts b/cli/ts/commands/genProofs.ts index d92ea3f982..603b507ede 100644 --- a/cli/ts/commands/genProofs.ts +++ b/cli/ts/commands/genProofs.ts @@ -279,6 +279,7 @@ export const genProofs = async ({ while (poll.hasUnprocessedMessages()) { // process messages in batches const circuitInputs = poll.processMessages(pollId, useQuadraticVoting, quiet) as unknown as CircuitInputs; + try { // generate the proof for this batch // eslint-disable-next-line no-await-in-loop @@ -290,11 +291,12 @@ export const genProofs = async ({ witnessExePath: processWitgen, wasmPath: processWasm, }); + // verify it // eslint-disable-next-line no-await-in-loop const isValid = await verifyProof(r.publicSignals, r.proof, processVk); if (!isValid) { - logError("Error: generated an invalid proof"); + throw new Error("Generated an invalid proof"); } const thisProof = { diff --git a/core/ts/Poll.ts b/core/ts/Poll.ts index 7d8dfddf69..526d85d7d9 100644 --- a/core/ts/Poll.ts +++ b/core/ts/Poll.ts @@ -554,21 +554,62 @@ export class Poll implements IPoll { console.log(`Error at message index ${idx} - ${e.message}`); } - // Since the command is invalid, use a blank state leaf - currentStateLeaves.unshift(this.stateLeaves[0].copy()); - currentStateLeavesPathElements.unshift(this.stateTree!.genProof(0).pathElements); - // since the command is invalid we use the blank ballot - currentBallots.unshift(this.ballots[0].copy()); - currentBallotsPathElements.unshift(this.ballotTree!.genProof(0).pathElements); - - // Since the command is invalid, we use a zero vote weight - currentVoteWeights.unshift(this.ballots[0].votes[0]); - - // create a new quinary tree and add an empty vote - const vt = new IncrementalQuinTree(this.treeDepths.voteOptionTreeDepth, 0n, STATE_TREE_ARITY, hash5); - vt.insert(this.ballots[0].votes[0]); - // get the path elements for this empty vote weight leaf - currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); + // @note we want to send the correct state leaf to the circuit + // even if a message is invalid + // this way if a message is invalid we can still generate a proof of processing + // we also want to prevent a DoS attack by a voter + // which sends a message that when force decrypted on the circuit + // results in a valid state index thus forcing the circuit to look + // for a valid state leaf, and failing to generate a proof + + // gen shared key + const sharedKey = Keypair.genEcdhSharedKey(this.coordinatorKeypair.privKey, encPubKey); + + // force decrypt it + const { command } = PCommand.decrypt(message, sharedKey, true); + + // cache state leaf index + const stateLeafIndex = command.stateIndex; + + // if the state leaf index is valid then use it + if (stateLeafIndex < this.stateLeaves.length) { + currentStateLeaves.unshift(this.stateLeaves[Number(stateLeafIndex)].copy()); + currentStateLeavesPathElements.unshift(this.stateTree!.genProof(Number(stateLeafIndex)).pathElements); + + // copy the ballot + const ballot = this.ballots[Number(stateLeafIndex)].copy(); + currentBallots.unshift(ballot); + currentBallotsPathElements.unshift(this.ballotTree!.genProof(Number(stateLeafIndex)).pathElements); + + // add the first vote of this ballot + currentVoteWeights.unshift(ballot.votes[0]); + + // create a new quinary tree and add all votes we have so far + const vt = new IncrementalQuinTree(this.treeDepths.voteOptionTreeDepth, 0n, STATE_TREE_ARITY, hash5); + + // fill the vote option tree with the votes we have so far + for (let j = 0; j < this.ballots[0].votes.length; j += 1) { + vt.insert(ballot.votes[j]); + } + + // get the path elements for the first vote leaf + currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); + } else { + // just use state leaf index 0 + currentStateLeaves.unshift(this.stateLeaves[0].copy()); + currentStateLeavesPathElements.unshift(this.stateTree!.genProof(0).pathElements); + currentBallots.unshift(this.ballots[0].copy()); + currentBallotsPathElements.unshift(this.ballotTree!.genProof(0).pathElements); + + // Since the command is invalid, we use a zero vote weight + currentVoteWeights.unshift(this.ballots[0].votes[0]); + + // create a new quinary tree and add an empty vote + const vt = new IncrementalQuinTree(this.treeDepths.voteOptionTreeDepth, 0n, STATE_TREE_ARITY, hash5); + vt.insert(this.ballots[0].votes[0]); + // get the path elements for this empty vote weight leaf + currentVoteWeightsPathElements.unshift(vt.genProof(0).pathElements); + } } else { throw e; } @@ -704,13 +745,33 @@ export class Poll implements IPoll { // fill the msgs array with a copy of the messages we have // plus empty messages to fill the batch + + // @note create a message with state index 0 to add as padding + // this way the message will look for state leaf 0 + // and no effect will take place + + // create a random key + const key = new Keypair(); + // gen ecdh key + const ecdh = Keypair.genEcdhSharedKey(key.privKey, this.coordinatorKeypair.pubKey); + // create an empty command with state index 0n + const emptyCommand = new PCommand(0n, key.pubKey, 0n, 0n, 0n, 0n, 0n); + + // encrypt it + const msg = emptyCommand.encrypt(emptyCommand.sign(key.privKey), ecdh); + + // copy the messages to a new array let msgs = this.messages.map((x) => x.asCircuitInputs()); + + // pad with our state index 0 message while (msgs.length % messageBatchSize > 0) { - msgs.push(msgs[msgs.length - 1]); + msgs.push(msg.asCircuitInputs()); } + // we only take the messages we need for this batch msgs = msgs.slice(index, index + messageBatchSize); + // insert zero value in the message tree as padding while (this.messageTree.nextIndex < index + messageBatchSize) { this.messageTree.insert(this.messageTree.zeroValue); } @@ -718,6 +779,7 @@ export class Poll implements IPoll { // generate the path to the subroot of the message tree for this batch const messageSubrootPath = this.messageTree.genSubrootProof(index, index + messageBatchSize); + // verify it assert(this.messageTree.verifyProof(messageSubrootPath), "The message subroot path is invalid"); // validate that the batch index is correct, if not fix it @@ -730,11 +792,13 @@ export class Poll implements IPoll { // copy the public keys, pad the array with the last keys if needed let encPubKeys = this.encPubKeys.map((x) => x.copy()); while (encPubKeys.length % messageBatchSize > 0) { - encPubKeys.push(encPubKeys[encPubKeys.length - 1]); + // pad with the public key used to encrypt the message with state index 0 (padding) + encPubKeys.push(key.pubKey.copy()); } // then take the ones part of this batch encPubKeys = encPubKeys.slice(index, index + messageBatchSize); + // cache tree roots const msgRoot = this.messageTree.root; const currentStateRoot = this.stateTree!.root; const currentBallotRoot = this.ballotTree!.root; diff --git a/crypto/ts/index.ts b/crypto/ts/index.ts index 060b5f44e9..45b1f5fa04 100644 --- a/crypto/ts/index.ts +++ b/crypto/ts/index.ts @@ -23,7 +23,7 @@ export { G1Point, G2Point, genRandomBabyJubValue } from "./babyjub"; export { sha256Hash, hashLeftRight, hashN, hash2, hash3, hash4, hash5, hash13, hashOne } from "./hashing"; -export { poseidonDecrypt, poseidonEncrypt } from "@zk-kit/poseidon-cipher"; +export { poseidonDecrypt, poseidonDecryptWithoutCheck, poseidonEncrypt } from "@zk-kit/poseidon-cipher"; export { verifySignature, signMessage as sign } from "@zk-kit/eddsa-poseidon"; diff --git a/domainobjs/ts/commands/PCommand.ts b/domainobjs/ts/commands/PCommand.ts index b7798596f2..99194fcfb2 100644 --- a/domainobjs/ts/commands/PCommand.ts +++ b/domainobjs/ts/commands/PCommand.ts @@ -9,6 +9,7 @@ import { type Ciphertext, type EcdhSharedKey, type Point, + poseidonDecryptWithoutCheck, } from "maci-crypto"; import assert from "assert"; @@ -172,11 +173,16 @@ export class PCommand implements ICommand { /** * Decrypts a Message to produce a Command. + * @dev You can force decrypt the message by setting `force` to true. + * This is useful in case you don't want an invalid message to throw an error. * @param {Message} message - the message to decrypt * @param {EcdhSharedKey} sharedKey - the shared key to use for decryption + * @param {boolean} force - whether to force decryption or not */ - static decrypt = (message: Message, sharedKey: EcdhSharedKey): IDecryptMessage => { - const decrypted = poseidonDecrypt(message.data, sharedKey, BigInt(0), 7); + static decrypt = (message: Message, sharedKey: EcdhSharedKey, force = false): IDecryptMessage => { + const decrypted = force + ? poseidonDecryptWithoutCheck(message.data, sharedKey, BigInt(0), 7) + : poseidonDecrypt(message.data, sharedKey, BigInt(0), 7); const p = BigInt(decrypted[0].toString());