Skip to content

Commit

Permalink
Merge pull request iden3#71 from victorcrrd/plonk-custom-gates-rethink
Browse files Browse the repository at this point in the history
plonk's custom gates in circom
  • Loading branch information
clararod9 authored Jun 14, 2022
2 parents bd1a102 + e507d72 commit de876d0
Show file tree
Hide file tree
Showing 16 changed files with 482 additions and 71 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ parser/target/
program_structure/target/
type_analysis/target/
.idea/
.vscode/
.DS_Store
Cargo.lock
8 changes: 4 additions & 4 deletions circom/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ fn start() -> Result<(), ()> {
c_flag: user_input.c_flag(),
wasm_flag: user_input.wasm_flag(),
wat_flag: user_input.wat_flag(),
js_folder: user_input.js_folder().to_string(),
wasm_name: user_input.wasm_name().to_string(),
c_folder: user_input.c_folder().to_string(),
c_run_name: user_input.c_run_name().to_string(),
js_folder: user_input.js_folder().to_string(),
wasm_name: user_input.wasm_name().to_string(),
c_folder: user_input.c_folder().to_string(),
c_run_name: user_input.c_run_name().to_string(),
c_file: user_input.c_file().to_string(),
dat_file: user_input.dat_file().to_string(),
wat_file: user_input.wat_file().to_string(),
Expand Down
5 changes: 4 additions & 1 deletion constraint_generation/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ fn execute_signal_declaration(
) {
use SignalType::*;
if let Option::Some(node) = actual_node {
node.add_ordered_signal(signal_name, dimensions);
match signal_type {
Input => {
environment_shortcut_add_input(environment, signal_name, dimensions);
Expand Down Expand Up @@ -984,6 +985,7 @@ fn execute_template_call(
debug_assert!(runtime.block_type == BlockType::Known);
let is_main = std::mem::replace(&mut runtime.public_inputs, vec![]);
let is_parallel = program_archive.get_template_data(id).is_parallel();
let is_custom_gate = program_archive.get_template_data(id).is_custom_gate();
let args_names = program_archive.get_template_data(id).get_name_of_params();
let template_body = program_archive.get_template_data(id).get_body_as_vec();
let mut args_to_values = BTreeMap::new();
Expand All @@ -1010,7 +1012,8 @@ fn execute_template_call(
instantiation_name,
args_to_values,
code,
is_parallel
is_parallel,
is_custom_gate
));
let ret = execute_sequence_of_statements(
template_body,
Expand Down
49 changes: 46 additions & 3 deletions constraint_generation/src/execution_data/executed_template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ pub struct ExecutedTemplate {
pub report_name: String,
pub inputs: SignalCollector,
pub outputs: SignalCollector,
pub constraints: Vec<Constraint>,
pub intermediates: SignalCollector,
pub ordered_signals: Vec<String>,
pub constraints: Vec<Constraint>,
pub components: ComponentCollector,
pub number_of_components: usize,
pub public_inputs: HashSet<String>,
pub parameter_instances: ParameterContext,
pub is_parallel: bool,
pub has_parallel_sub_cmp: bool,
pub is_custom_gate: bool,
connexions: Vec<Connexion>,
}

Expand All @@ -40,22 +42,25 @@ impl ExecutedTemplate {
instance: ParameterContext,
code: Statement,
is_parallel: bool,
is_custom_gate: bool
) -> ExecutedTemplate {
let public_inputs: HashSet<_> = public.iter().cloned().collect();
ExecutedTemplate {
report_name,
public_inputs,
is_parallel,
has_parallel_sub_cmp: false,
is_custom_gate,
code: code.clone(),
template_name: name,
parameter_instances: instance,
inputs: SignalCollector::new(),
outputs: SignalCollector::new(),
intermediates: SignalCollector::new(),
ordered_signals: Vec::new(),
constraints: Vec::new(),
components: ComponentCollector::new(),
number_of_components: 0,
constraints: Vec::new(),
connexions: Vec::new(),
}
}
Expand All @@ -82,6 +87,27 @@ impl ExecutedTemplate {
self.intermediates.push((intermediate_name.to_string(), dimensions.to_vec()));
}

pub fn add_ordered_signal(&mut self, signal_name: &str, dimensions: &[usize]) {
fn generate_symbols(name: String, current: usize, dimensions: &[usize]) -> Vec<String> {
let symbol_name = name.clone();
if current == dimensions.len() {
vec![name]
} else {
let mut generated_symbols = vec![];
let mut index = 0;
while index < dimensions[current] {
let new_name = format!("{}[{}]", symbol_name, index);
generated_symbols.append(&mut generate_symbols(new_name, current + 1, dimensions));
index += 1;
}
generated_symbols
}
}
for signal in generate_symbols(signal_name.to_string(), 0, dimensions) {
self.ordered_signals.push(signal);
}
}

pub fn add_component(&mut self, component_name: &str, dimensions: &[usize]) {
self.components.push((component_name.to_string(), dimensions.to_vec()));
self.number_of_components += dimensions.iter().fold(1, |p, c| p * (*c));
Expand Down Expand Up @@ -112,7 +138,24 @@ impl ExecutedTemplate {
}

pub fn insert_in_dag(&mut self, dag: &mut DAG) {
dag.add_node(self.report_name.clone(), self.is_parallel);
let parameters = {
let mut parameters = vec![];
for (_, data) in self.parameter_instances.clone() {
let (_, values) = data.destruct();
for value in as_big_int(values) {
parameters.push(value);
}
}
parameters
}; // repeated code from function build_arguments in export_to_circuit

dag.add_node(
self.report_name.clone(),
parameters,
self.ordered_signals.clone(), // pensar si calcularlo en este momento para no hacer clone
self.is_parallel,
self.is_custom_gate
);
self.build_signals(dag);
self.build_connexions(dag);
self.build_constraints(dag);
Expand Down
12 changes: 7 additions & 5 deletions constraint_list/src/constraint_simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
relevant
};

let linear_substitutions = if apply_linear {
let linear_substitutions = if remove_unused {
let now = SystemTime::now();
let (subs, mut cons) = linear_simplification(
&mut substitution_log,
Expand Down Expand Up @@ -563,15 +563,17 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
crate::state_utils::empty_encoding_constraints(&mut smp.dag_encoding);
let _dur = now.elapsed().unwrap().as_millis();
// println!("Storages built in {} ms", dur);
no_rounds -= 1;
if remove_unused {
no_rounds -= 1;
}
(with_linear, storage)
};

let mut round_id = 0;
let _ = round_id;
let mut linear = with_linear;
let mut apply_round = apply_linear && no_rounds > 0 && !linear.is_empty();
let mut non_linear_map = if apply_round || remove_unused{
let mut apply_round = remove_unused && no_rounds > 0 && !linear.is_empty();
let mut non_linear_map = if apply_round || remove_unused {
// println!("Building non-linear map");
let now = SystemTime::now();
let non_linear_map = build_non_linear_signal_map(&constraint_storage);
Expand Down Expand Up @@ -615,7 +617,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
}

for constraint in linear {
if remove_unused{
if remove_unused {
let signals = C::take_cloned_signals(&constraint);
let c_id = constraint_storage.add_constraint(constraint);
for signal in signals {
Expand Down
4 changes: 4 additions & 0 deletions constraint_list/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ pub struct SignalInfo {
}
pub struct EncodingNode {
pub id: usize,
pub name: String,
pub parameters: Vec<BigInt>,
pub signals: Vec<SignalInfo>,
pub ordered_signals: Vec<usize>,
pub non_linear: LinkedList<C>,
pub is_custom_gate: bool,
}

pub struct EncodingEdge {
Expand Down
71 changes: 68 additions & 3 deletions constraint_list/src/r1cs_porting.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{ConstraintList, C};
use constraint_writers::r1cs_writer::{ConstraintSection, HeaderData, R1CSWriter, SignalSection};
use super::{ConstraintList, C, EncodingIterator, SignalMap};
use constraint_writers::r1cs_writer::{ConstraintSection, CustomGatesAppliedData, HeaderData, R1CSWriter, SignalSection};

pub fn port_r1cs(list: &ConstraintList, output: &str) -> Result<(), ()> {
use constraint_writers::log_writer::Log;
Expand Down Expand Up @@ -45,7 +45,72 @@ pub fn port_r1cs(list: &ConstraintList, output: &str) -> Result<(), ()> {
for id in list.get_witness_as_vec() {
SignalSection::write_signal_usize(&mut signal_section, id)?;
}
let _r1cs = signal_section.end_section()?;
let r1cs = signal_section.end_section()?;

let mut custom_gates_used_section = R1CSWriter::start_custom_gates_used_section(r1cs)?;
let (usage_data, occurring_order) = {
let mut usage_data = vec![];
let mut occurring_order = vec![];
for node in &list.dag_encoding.nodes {
if node.is_custom_gate {
let mut name = node.name.clone();
occurring_order.push(name.clone());
while name.pop() != Some('(') {};
usage_data.push((name, node.parameters.clone()));
}
}
(usage_data, occurring_order)
};
custom_gates_used_section.write_custom_gates_usages(usage_data)?;
let r1cs = custom_gates_used_section.end_section()?;

let mut custom_gates_applied_section = R1CSWriter::start_custom_gates_applied_section(r1cs)?;
let application_data = {
fn find_indexes(
occurring_order: Vec<String>,
application_data: Vec<(String, Vec<usize>)>
) -> CustomGatesAppliedData {
let mut new_application_data = vec![];
for (custom_gate_name, signals) in application_data {
let mut index = 0;
while occurring_order[index] != custom_gate_name {
index += 1;
}
new_application_data.push((index, signals));
}
new_application_data
}

fn iterate(
iterator: EncodingIterator,
map: &SignalMap,
application_data: &mut Vec<(String, Vec<usize>)>
) {
let node = &iterator.encoding.nodes[iterator.node_id];
if node.is_custom_gate {
let mut signals = vec![];
for signal in &node.ordered_signals {
let new_signal = signal + iterator.offset;
let signal_numbering = map.get(&new_signal).unwrap();
signals.push(*signal_numbering);
}
application_data.push((node.name.clone(), signals));
} else {
for edge in EncodingIterator::edges(&iterator) {
let next = EncodingIterator::next(&iterator, edge);
iterate(next, map, application_data);
}
}
}

let mut application_data = vec![];
let iterator = EncodingIterator::new(&list.dag_encoding);
iterate(iterator, &list.signal_map, &mut application_data);
find_indexes(occurring_order, application_data)
};
custom_gates_applied_section.write_custom_gates_applications(application_data)?;
let _r1cs = custom_gates_applied_section.end_section()?;

Log::print(&log);
Ok(())
}
Loading

0 comments on commit de876d0

Please sign in to comment.