Skip to content

Commit

Permalink
Check if inputs are from field (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmikolajczyk41 authored Dec 27, 2024
1 parent 1189ddf commit d142b83
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 4 deletions.
24 changes: 23 additions & 1 deletion contracts/Shielder.sol
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ contract Shielder is
/// so we control the sum of balances instead.
uint256 public constant MAX_CONTRACT_BALANCE = MAX_TRANSACTION_AMOUNT;

/// The modulus of the field used in the circuits.
uint256 private constant FIELD_MODULUS =
21888242871839275222246405745257275088548364400416034343698204186575808495617;

// -- Events --
event NewAccountNative(
bytes3 contractVersion,
Expand Down Expand Up @@ -85,6 +89,7 @@ contract Shielder is
error AmountTooHigh();
error ContractBalanceLimitReached();
error WrongContractVersion(bytes3 actual, bytes3 expectedByCaller);
error NotAFieldElement();

modifier restrictContractVersion(bytes3 expectedByCaller) {
if (expectedByCaller != CONTRACT_VERSION) {
Expand Down Expand Up @@ -139,6 +144,8 @@ contract Shielder is
whenNotPaused
withinDepositLimit
restrictContractVersion(expectedContractVersion)
fieldElement(newNote)
fieldElement(idHash)
{
uint256 amount = msg.value;
if (nullifiers(idHash) != 0) revert DuplicatedNullifier();
Expand Down Expand Up @@ -179,6 +186,9 @@ contract Shielder is
whenNotPaused
withinDepositLimit
restrictContractVersion(expectedContractVersion)
fieldElement(idHiding)
fieldElement(oldNullifierHash)
fieldElement(newNote)
{
uint256 amount = msg.value;
if (amount == 0) revert ZeroAmount();
Expand Down Expand Up @@ -221,7 +231,14 @@ contract Shielder is
bytes calldata proof,
address relayerAddress,
uint256 relayerFee
) external whenNotPaused restrictContractVersion(expectedContractVersion) {
)
external
whenNotPaused
restrictContractVersion(expectedContractVersion)
fieldElement(idHiding)
fieldElement(oldNullifierHash)
fieldElement(newNote)
{
if (amount == 0) revert ZeroAmount();
if (amount <= relayerFee) revert FeeHigherThanAmount();
if (amount > MAX_TRANSACTION_AMOUNT) revert AmountTooHigh();
Expand Down Expand Up @@ -283,6 +300,11 @@ contract Shielder is
return uint256(uint160(addr));
}

modifier fieldElement(uint256 x) {
require(x < FIELD_MODULUS, NotAFieldElement());
_;
}

// -- Setters ---

/*
Expand Down
38 changes: 37 additions & 1 deletion crates/integration-tests/src/shielder/calls/deposit_native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ pub fn invoke_call(
#[cfg(test)]
mod tests {

use std::assert_matches::assert_matches;
use std::{assert_matches::assert_matches, mem, str::FromStr};

use alloy_primitives::{Bytes, FixedBytes, U256};
use evm_utils::SuccessResult;
use halo2_proofs::halo2curves::ff::PrimeField;
use rstest::rstest;
use shielder_circuits::F;
use shielder_rust_sdk::{
account::ShielderAccount,
contract::ShielderContract::{
Expand All @@ -75,6 +77,7 @@ mod tests {

use crate::{
calls::deposit_native::{invoke_call, prepare_call},
recipient_balance_increased_by, relayer_balance_increased_by,
shielder::{
actor_balance_decreased_by,
calls::new_account_native,
Expand Down Expand Up @@ -249,6 +252,39 @@ mod tests {
assert!(actor_balance_decreased_by(&deployment, U256::from(15)))
}

#[rstest]
fn cannot_use_input_greater_than_field_modulus(mut deployment: Deployment) {
let mut shielder_account = new_account_native::create_account_and_call(
&mut deployment,
U256::from(1),
U256::from(10),
)
.unwrap();

let amount = U256::from(5);
let (mut calldata, _) = prepare_call(&mut deployment, &mut shielder_account, amount);
let mut swap_value = U256::from_str(F::MODULUS).unwrap();

mem::swap(&mut calldata.oldNullifierHash, &mut swap_value);
let result = invoke_call(&mut deployment, &mut shielder_account, amount, &calldata);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.oldNullifierHash, &mut swap_value);

mem::swap(&mut calldata.newNote, &mut swap_value);
let result = invoke_call(&mut deployment, &mut shielder_account, amount, &calldata);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.newNote, &mut swap_value);

mem::swap(&mut calldata.idHiding, &mut swap_value);
let result = invoke_call(&mut deployment, &mut shielder_account, amount, &calldata);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.idHiding, &mut swap_value);

assert!(actor_balance_decreased_by(&deployment, U256::from(10)));
assert!(recipient_balance_increased_by(&deployment, U256::from(0)));
assert!(relayer_balance_increased_by(&deployment, U256::from(0)))
}

#[rstest]
fn fails_if_merkle_root_does_not_exist(mut deployment: Deployment) {
let mut shielder_account = ShielderAccount::default();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ pub fn create_account_and_call(
#[cfg(test)]
mod tests {

use std::assert_matches::assert_matches;
use std::{assert_matches::assert_matches, mem, str::FromStr};

use alloy_primitives::{FixedBytes, U256};
use evm_utils::SuccessResult;
use halo2_proofs::halo2curves::ff::PrimeField;
use rstest::rstest;
use shielder_circuits::F;
use shielder_rust_sdk::{
account::ShielderAccount,
contract::ShielderContract::{
Expand All @@ -68,6 +70,7 @@ mod tests {
use crate::{
calls::new_account_native::{create_account_and_call, invoke_call, prepare_call},
deploy::deployment,
recipient_balance_increased_by, relayer_balance_increased_by,
shielder::{
actor_balance_decreased_by,
limits::{get_deposit_limit, set_deposit_limit},
Expand Down Expand Up @@ -147,6 +150,39 @@ mod tests {
assert!(actor_balance_decreased_by(&deployment, U256::from(10)))
}

#[rstest]
fn cannot_use_input_greater_than_field_modulus(mut deployment: Deployment) {
let mut shielder_account = ShielderAccount::new(U256::from(1));

let initial_amount = U256::from(10);
let mut calldata = prepare_call(&mut deployment, &mut shielder_account, initial_amount);
let mut swap_value = U256::from_str(F::MODULUS).unwrap();

mem::swap(&mut calldata.idHash, &mut swap_value);
let result = invoke_call(
&mut deployment,
&mut shielder_account,
initial_amount,
&calldata,
);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.idHash, &mut swap_value);

mem::swap(&mut calldata.newNote, &mut swap_value);
let result = invoke_call(
&mut deployment,
&mut shielder_account,
initial_amount,
&calldata,
);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.newNote, &mut swap_value);

assert!(actor_balance_decreased_by(&deployment, U256::from(0)));
assert!(recipient_balance_increased_by(&deployment, U256::from(0)));
assert!(relayer_balance_increased_by(&deployment, U256::from(0)))
}

#[rstest]
fn can_consume_entire_contract_balance_limit(mut deployment: Deployment) {
let mut shielder_account = ShielderAccount::default();
Expand Down
40 changes: 39 additions & 1 deletion crates/integration-tests/src/shielder/calls/withdraw_native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ pub fn invoke_call(
#[cfg(test)]
mod tests {

use std::{assert_matches::assert_matches, str::FromStr};
use std::{assert_matches::assert_matches, mem, str::FromStr};

use alloy_primitives::{Address, Bytes, FixedBytes, U256};
use evm_utils::SuccessResult;
use halo2_proofs::halo2curves::ff::PrimeField;
use rstest::rstest;
use shielder_circuits::F;
use shielder_rust_sdk::{
account::ShielderAccount,
contract::ShielderContract::{
Expand Down Expand Up @@ -428,6 +430,42 @@ mod tests {
assert!(relayer_balance_increased_by(&deployment, U256::from(1)))
}

#[rstest]
fn cannot_use_input_greater_than_field_modulus(mut deployment: Deployment) {
let mut shielder_account = new_account_native::create_account_and_call(
&mut deployment,
U256::from(1),
U256::from(20),
)
.unwrap();

let (mut calldata, _) = prepare_call(
&mut deployment,
&mut shielder_account,
prepare_args(U256::from(5), U256::from(1)),
);
let mut swap_value = U256::from_str(F::MODULUS).unwrap();

mem::swap(&mut calldata.oldNullifierHash, &mut swap_value);
let result = invoke_call(&mut deployment, &mut shielder_account, &calldata);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.oldNullifierHash, &mut swap_value);

mem::swap(&mut calldata.newNote, &mut swap_value);
let result = invoke_call(&mut deployment, &mut shielder_account, &calldata);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.newNote, &mut swap_value);

mem::swap(&mut calldata.idHiding, &mut swap_value);
let result = invoke_call(&mut deployment, &mut shielder_account, &calldata);
assert_matches!(result, Err(ShielderContractErrors::NotAFieldElement(_)));
mem::swap(&mut calldata.idHiding, &mut swap_value);

assert!(actor_balance_decreased_by(&deployment, U256::from(20)));
assert!(recipient_balance_increased_by(&deployment, U256::from(0)));
assert!(relayer_balance_increased_by(&deployment, U256::from(0)))
}

#[rstest]
fn handles_withdraw_transfer_failure(mut deployment: Deployment) {
let mut shielder_account = new_account_native::create_account_and_call(
Expand Down
1 change: 1 addition & 0 deletions crates/shielder-rust-sdk/src/contract/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ sol! {
error LeafIsNotInTheTree();
error PrecompileCallFailed();
error WrongContractVersion(bytes3 actual, bytes3 expectedByCaller);
error NotAFieldElement();

function depositLimit() external view returns (uint256);

Expand Down

0 comments on commit d142b83

Please sign in to comment.