Skip to content

Commit

Permalink
fix(circuits): enforce use of stateIndex from message
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrlc03 committed Feb 15, 2024
1 parent e2e1031 commit e37dcb1
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 56 deletions.
41 changes: 31 additions & 10 deletions circuits/circom/processMessages.circom
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ template ProcessMessages(
var STATE_LEAF_PUB_Y_IDX = 1;
var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2;
var STATE_LEAF_TIMESTAMP_IDX = 3;

var N_BITS = 252;

// Note that we sha256 hash some values from the contract, pass in the hash
// as a public input, and pass in said values as private inputs. This saves
Expand Down Expand Up @@ -277,6 +279,7 @@ template ProcessMessages(
component processors[batchSize];
// topup type processor
component processors2[batchSize];

for (var i = batchSize - 1; i >= 0; i --) {
// process it as vote type message
processors[i] = ProcessOne(stateTreeDepth, voteOptionTreeDepth);
Expand Down Expand Up @@ -349,6 +352,7 @@ template ProcessMessages(
<== currentStateLeavesPathElements[i][j][k];
}
}

// pick the correct result by msg type
tmpStateRoot1[i] <== processors[i].newStateRoot * (2 - msgs[i][0]);
tmpStateRoot2[i] <== processors2[i].newStateRoot * (msgs[i][0] - 1);
Expand Down Expand Up @@ -378,6 +382,8 @@ template ProcessTopup(stateTreeDepth) {
var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2;
var STATE_LEAF_TIMESTAMP_IDX = 3;

var N_BITS = 252;

signal input msgType;
signal input stateTreeIndex;
signal input amount;
Expand All @@ -395,9 +401,9 @@ template ProcessTopup(stateTreeDepth) {
// msgType of topup command is 2
amt <== amount * (msgType - 1);
index <== stateTreeIndex * (msgType - 1);
component validCreditBalance = LessEqThan(252);
component validCreditBalance = LessEqThan(N_BITS);
// check stateIndex, if invalid index, set index and amount to zero
component validStateLeafIndex = LessEqThan(252);
component validStateLeafIndex = LessEqThan(N_BITS);
validStateLeafIndex.in[0] <== index;
validStateLeafIndex.in[1] <== numSignUps;

Expand Down Expand Up @@ -462,6 +468,8 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) {
var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2;
var STATE_LEAF_TIMESTAMP_IDX = 3;

var N_BITS = 252;

signal input msgType;
signal input numSignUps;
signal input maxVoteOptions;
Expand Down Expand Up @@ -526,22 +534,34 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) {

// -----------------------------------------------------------------------
// 2. If msgType = 0 and isValid is 0, generate indices for leaf 0
// Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType
// Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType
signal indexByType;
signal tmpIndex1;
signal tmpIndex2;
tmpIndex1 <== cmdStateIndex * (2 - msgType);
tmpIndex2 <== topupStateIndex * (msgType - 1);
indexByType <== tmpIndex1 + tmpIndex2;

component stateIndexMux = Mux1();
stateIndexMux.s <== transformer.isValid + msgType - 1;
stateIndexMux.c[0] <== 0;
stateIndexMux.c[1] <== indexByType;
// we can validate if the state index is within the numSignups
// if not, we use 0
// this is because decryption of an invalid message
// might result in random packed vals
component validStateLeafIndex = SafeLessThan(N_BITS);
validStateLeafIndex.in[0] <== indexByType;
validStateLeafIndex.in[1] <== numSignUps;

component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth);
stateLeafPathIndices.in <== stateIndexMux.out;
// use a mux to pick the correct index
component indexMux = Mux1();
indexMux.s <== validStateLeafIndex.out;
indexMux.c[0] <== 0;
indexMux.c[1] <== indexByType;

// @note that we expect a coordinator to send the state leaf corresponding to a message
// which specifies a valid state index. If this is not the case, the
// proof will fail to generate.
component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth);
stateLeafPathIndices.in <== indexMux.out;

// -----------------------------------------------------------------------
// 3. Verify that the original state leaf exists in the given state root
component stateLeafQip = QuinTreeInclusionProof(stateTreeDepth);
Expand Down Expand Up @@ -572,6 +592,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) {
ballotQip.path_elements[i][j] <== ballotPathElements[i][j];
}
}

ballotQip.root === currentBallotRoot;

// -----------------------------------------------------------------------
Expand All @@ -583,7 +604,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) {
b <== currentVoteWeight * currentVoteWeight;
c <== cmdNewVoteWeight * cmdNewVoteWeight;

component enoughVoiceCredits = SafeGreaterEqThan(252);
component enoughVoiceCredits = SafeGreaterEqThan(N_BITS);
enoughVoiceCredits.in[0] <== stateLeaf[STATE_LEAF_VOICE_CREDIT_BALANCE_IDX] + b;
enoughVoiceCredits.in[1] <== c;

Expand Down
30 changes: 23 additions & 7 deletions circuits/circom/processMessagesNonQv.circom
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ template ProcessMessagesNonQv(
var STATE_LEAF_PUB_Y_IDX = 1;
var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2;
var STATE_LEAF_TIMESTAMP_IDX = 3;

var N_BITS = 252;

// Note that we sha256 hash some values from the contract, pass in the hash
// as a public input, and pass in said values as private inputs. This saves
Expand Down Expand Up @@ -389,6 +391,8 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) {
var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2;
var STATE_LEAF_TIMESTAMP_IDX = 3;

var N_BITS = 252;

signal input msgType;
signal input numSignUps;
signal input maxVoteOptions;
Expand Down Expand Up @@ -461,13 +465,25 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) {
tmpIndex2 <== topupStateIndex * (msgType - 1);
indexByType <== tmpIndex1 + tmpIndex2;

component stateIndexMux = Mux1();
stateIndexMux.s <== transformer.isValid + msgType - 1;
stateIndexMux.c[0] <== 0;
stateIndexMux.c[1] <== indexByType;

// we can validate if the state index is within the numSignups
// if not, we use 0
// this is because decryption of an invalid message
// might result in random packed vals
component validStateLeafIndex = SafeLessThan(N_BITS);
validStateLeafIndex.in[0] <== indexByType;
validStateLeafIndex.in[1] <== numSignUps;

// use a mux to pick the correct index
component indexMux = Mux1();
indexMux.s <== validStateLeafIndex.out;
indexMux.c[0] <== 0;
indexMux.c[1] <== indexByType;

// @note that we expect a coordinator to send the state leaf corresponding to a message
// which specifies a valid state index. If this is not the case, the
// proof will fail to generate.
component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth);
stateLeafPathIndices.in <== stateIndexMux.out;
stateLeafPathIndices.in <== indexMux.out;

// -----------------------------------------------------------------------
// 3. Verify that the original state leaf exists in the given state root
Expand Down Expand Up @@ -510,7 +526,7 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) {
b <== currentVoteWeight;
c <== cmdNewVoteWeight;

component enoughVoiceCredits = SafeGreaterEqThan(252);
component enoughVoiceCredits = SafeGreaterEqThan(N_BITS);
enoughVoiceCredits.in[0] <== stateLeaf[STATE_LEAF_VOICE_CREDIT_BALANCE_IDX] + b;
enoughVoiceCredits.in[1] <== c;

Expand Down
2 changes: 1 addition & 1 deletion circuits/circom/stateLeafAndBallotTransformer.circom
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ template StateLeafAndBallotTransformer() {
messageValidator.voteWeight <== cmdNewVoteWeight;

// if the message is valid then we swap out the public key
// we have to do this in two Mux one for pucKey[0]
// we have to do this in two Mux one for pubKey[0]
// and one for pubKey[1]
component newSlPubKey0Mux = Mux1();
newSlPubKey0Mux.s <== messageValidator.isValid;
Expand Down
31 changes: 26 additions & 5 deletions circuits/ts/__tests__/ProcessMessages.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect } from "chai";
import { type WitnessTester } from "circomkit";
import { MaciState, Poll, packProcessMessageSmallVals, STATE_TREE_ARITY } from "maci-core";
import { hash5, IncrementalQuinTree, NOTHING_UP_MY_SLEEVE, AccQueue } from "maci-crypto";
import { PrivKey, Keypair, PCommand, Message, Ballot } from "maci-domainobjs";
import { PrivKey, Keypair, PCommand, Message, Ballot, PubKey } from "maci-domainobjs";

import { IProcessMessagesInputs } from "../types";

Expand Down Expand Up @@ -76,7 +76,7 @@ describe("ProcessMessage circuit", function test() {
describe("1 user, 2 messages", () => {
const maciState = new MaciState(STATE_TREE_DEPTH);
const voteWeight = BigInt(9);
const voteOptionIndex = BigInt(0);
const voteOptionIndex = BigInt(1);
let stateIndex: bigint;
let pollId: bigint;
let poll: Poll;
Expand All @@ -101,6 +101,26 @@ describe("ProcessMessage circuit", function test() {
poll = maciState.polls.get(pollId)!;
poll.updatePoll(BigInt(maciState.stateLeaves.length));

const nothing = new Message(1n, [
8370432830353022751713833565135785980866757267633941821328460903436894336785n,
0n,
0n,
0n,
0n,
0n,
0n,
0n,
0n,
0n,
]);

const encP = new PubKey([
10457101036533406547632367118273992217979173478358440826365724437999023779287n,
19824078218392094440610104313265183977899662750282163392862422243483260492317n,
]);

poll.publishMessage(nothing, encP);

// First command (valid)
const command = new PCommand(
stateIndex, // BigInt(1),
Expand Down Expand Up @@ -144,6 +164,7 @@ describe("ProcessMessage circuit", function test() {
STATE_TREE_ARITY,
NOTHING_UP_MY_SLEEVE,
);
accumulatorQueue.enqueue(nothing.hash(encP));
accumulatorQueue.enqueue(message.hash(ecdhKeypair.pubKey));
accumulatorQueue.enqueue(message2.hash(ecdhKeypair2.pubKey));
accumulatorQueue.mergeSubRoots(0);
Expand Down Expand Up @@ -187,7 +208,7 @@ describe("ProcessMessage circuit", function test() {
BigInt(maxValues.maxVoteOptions),
BigInt(poll.maciStateRef.numSignUps),
0,
2,
3,
);

// Test the ProcessMessagesInputHasher circuit
Expand Down Expand Up @@ -554,7 +575,7 @@ describe("ProcessMessage circuit", function test() {

// Second batch is not a full batch
const numMessages = messageBatchSize * NUM_BATCHES - 1;
for (let i = 0; i < numMessages; i += 1) {
for (let i = 0; i < 6; i += 1) {
const command = new PCommand(
BigInt(index),
userKeypair.pubKey,
Expand All @@ -572,7 +593,7 @@ describe("ProcessMessage circuit", function test() {
selectedPoll?.publishMessage(message, ecdhKeypair.pubKey);
}

for (let i = 0; i < NUM_BATCHES; i += 1) {
for (let i = 0; i < 2; i += 1) {
const inputs = selectedPoll?.processMessages(id) as unknown as IProcessMessagesInputs;
// eslint-disable-next-line no-await-in-loop
const witness = await circuit.calculateWitness(inputs);
Expand Down
48 changes: 36 additions & 12 deletions cli/tests/e2e/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ import { cleanVanilla, isArm } from "../utils";
/**
Test scenarios:
1 signup, 1 message
4 signups, 6 messages
4 signups, 8 messages
5 signups, 1 message
8 signups, 10 messages
4 signups, 4 messages
Expand Down Expand Up @@ -195,7 +195,7 @@ describe("e2e tests", function test() {
});
});

describe("4 signups, 6 messages", () => {
describe("4 signups, 8 messages", () => {
after(() => {
cleanVanilla();
});
Expand All @@ -217,7 +217,31 @@ describe("e2e tests", function test() {
}
});

it("should publish six messages", async () => {
it("should publish eight messages", async () => {
await publish({
pubkey: users[0].pubKey.serialize(),
stateIndex: 1n,
voteOptionIndex: 0n,
nonce: 2n,
pollId: 0n,
newVoteWeight: 4n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[0].privKey.serialize(),
signer,
});
await publish({
pubkey: users[0].pubKey.serialize(),
stateIndex: 1n,
voteOptionIndex: 0n,
nonce: 2n,
pollId: 0n,
newVoteWeight: 3n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[0].privKey.serialize(),
signer,
});
await publish({
pubkey: users[0].pubKey.serialize(),
stateIndex: 1n,
Expand All @@ -233,7 +257,7 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[1].pubKey.serialize(),
stateIndex: 2n,
voteOptionIndex: 0n,
voteOptionIndex: 2n,
nonce: 1n,
pollId: 0n,
newVoteWeight: 9n,
Expand All @@ -245,7 +269,7 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[2].pubKey.serialize(),
stateIndex: 3n,
voteOptionIndex: 0n,
voteOptionIndex: 2n,
nonce: 1n,
pollId: 0n,
newVoteWeight: 9n,
Expand All @@ -257,10 +281,10 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[3].pubKey.serialize(),
stateIndex: 4n,
voteOptionIndex: 0n,
nonce: 1n,
voteOptionIndex: 2n,
nonce: 3n,
pollId: 0n,
newVoteWeight: 9n,
newVoteWeight: 3n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[3].privKey.serialize(),
Expand All @@ -269,10 +293,10 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[3].pubKey.serialize(),
stateIndex: 4n,
voteOptionIndex: 0n,
nonce: 1n,
voteOptionIndex: 2n,
nonce: 2n,
pollId: 0n,
newVoteWeight: 9n,
newVoteWeight: 2n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[3].privKey.serialize(),
Expand All @@ -281,7 +305,7 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[3].pubKey.serialize(),
stateIndex: 4n,
voteOptionIndex: 0n,
voteOptionIndex: 1n,
nonce: 1n,
pollId: 0n,
newVoteWeight: 9n,
Expand Down
4 changes: 3 additions & 1 deletion cli/ts/commands/genProofs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ export const genProofs = async ({
while (poll.hasUnprocessedMessages()) {
// process messages in batches
const circuitInputs = poll.processMessages(pollId, useQuadraticVoting, quiet) as unknown as CircuitInputs;

try {
// generate the proof for this batch
// eslint-disable-next-line no-await-in-loop
Expand All @@ -290,11 +291,12 @@ export const genProofs = async ({
witnessExePath: processWitgen,
wasmPath: processWasm,
});

// verify it
// eslint-disable-next-line no-await-in-loop
const isValid = await verifyProof(r.publicSignals, r.proof, processVk);
if (!isValid) {
logError("Error: generated an invalid proof");
throw new Error("Generated an invalid proof");
}

const thisProof = {
Expand Down
Loading

0 comments on commit e37dcb1

Please sign in to comment.