From 9899a5c6fbba293f563dd162eba0468ed3d01c54 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 7 Nov 2023 22:58:25 -0800 Subject: [PATCH] [feat] Make selector compression optional (#212) * feat: add `keygen_vk_custom` function to turn off selector compression * feat(vkey): add `version` and `compress_selectors` bytes to serialization * feat: do not serialize selectors when `compress_selectors = false` * feat: make `keygen_vk` have `compress_selectors = true` by default * chore: fix clippy * fix: pass empty poly to fake_selectors * fix: byte_length calculation --- halo2_proofs/src/helpers.rs | 8 +++ halo2_proofs/src/plonk.rs | 90 ++++++++++++++++++++------- halo2_proofs/src/plonk/circuit.rs | 46 ++++++++++++-- halo2_proofs/src/plonk/keygen.rs | 33 +++++++++- halo2_proofs/src/plonk/permutation.rs | 7 ++- 5 files changed, 151 insertions(+), 33 deletions(-) diff --git a/halo2_proofs/src/helpers.rs b/halo2_proofs/src/helpers.rs index b3f47a2059..faf7351a3e 100644 --- a/halo2_proofs/src/helpers.rs +++ b/halo2_proofs/src/helpers.rs @@ -55,6 +55,14 @@ pub trait SerdeCurveAffine: CurveAffine + SerdeObject { _ => self.write_raw(writer), } } + + /// Byte length of an affine curve element according to `format`. + fn byte_length(format: SerdeFormat) -> usize { + match format { + SerdeFormat::Processed => Self::default().to_bytes().as_ref().len(), + _ => Self::Repr::default().as_ref().len() * 2, + } + } } impl SerdeCurveAffine for C {} diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index d05e7a4005..5506f94a68 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -56,6 +56,8 @@ pub struct VerifyingKey { /// The representative of this `VerifyingKey` in transcripts. transcript_repr: C::Scalar, selectors: Vec>, + /// Whether selector compression is turned on or not. + compress_selectors: bool, } impl VerifyingKey @@ -72,13 +74,19 @@ where /// Writes a field element into raw bytes in its internal Montgomery representation, /// WITHOUT performing the expensive Montgomery reduction. pub fn write(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> { - writer.write_all(&self.domain.k().to_be_bytes())?; - writer.write_all(&(self.fixed_commitments.len() as u32).to_be_bytes())?; + // Version byte that will be checked on read. + writer.write_all(&[0x02])?; + writer.write_all(&self.domain.k().to_le_bytes())?; + writer.write_all(&[self.compress_selectors as u8])?; + writer.write_all(&(self.fixed_commitments.len() as u32).to_le_bytes())?; for commitment in &self.fixed_commitments { commitment.write(writer, format)?; } self.permutation.write(writer, format)?; + if !self.compress_selectors { + assert!(self.selectors.is_empty()); + } // write self.selectors for selector in &self.selectors { // since `selector` is filled with `bool`, we pack them 8 at a time into bytes and then write @@ -104,9 +112,26 @@ where format: SerdeFormat, #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, ) -> io::Result { + let mut version_byte = [0u8; 1]; + reader.read_exact(&mut version_byte)?; + if 0x02 != version_byte[0] { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected version byte", + )); + } let mut k = [0u8; 4]; reader.read_exact(&mut k)?; - let k = u32::from_be_bytes(k); + let k = u32::from_le_bytes(k); + let mut compress_selectors = [0u8; 1]; + reader.read_exact(&mut compress_selectors)?; + if compress_selectors[0] != 0 && compress_selectors[0] != 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected compress_selectors not boolean", + )); + } + let compress_selectors = compress_selectors[0] == 1; let (domain, cs, _) = keygen::create_domain::( k, #[cfg(feature = "circuit-params")] @@ -114,7 +139,7 @@ where ); let mut num_fixed_columns = [0u8; 4]; reader.read_exact(&mut num_fixed_columns)?; - let num_fixed_columns = u32::from_be_bytes(num_fixed_columns); + let num_fixed_columns = u32::from_le_bytes(num_fixed_columns); let fixed_commitments: Vec<_> = (0..num_fixed_columns) .map(|_| C::read(reader, format)) @@ -122,19 +147,27 @@ where let permutation = permutation::VerifyingKey::read(reader, &cs.permutation, format)?; - // read selectors - let selectors: Vec> = vec![vec![false; 1 << k]; cs.num_selectors] - .into_iter() - .map(|mut selector| { - let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8]; - reader.read_exact(&mut selector_bytes)?; - for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) { - crate::helpers::unpack(byte, bits); - } - Ok(selector) - }) - .collect::>()?; - let (cs, _) = cs.compress_selectors(selectors.clone()); + let (cs, selectors) = if compress_selectors { + // read selectors + let selectors: Vec> = vec![vec![false; 1 << k]; cs.num_selectors] + .into_iter() + .map(|mut selector| { + let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8]; + reader.read_exact(&mut selector_bytes)?; + for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) { + crate::helpers::unpack(byte, bits); + } + Ok(selector) + }) + .collect::>()?; + let (cs, _) = cs.compress_selectors(selectors.clone()); + (cs, selectors) + } else { + // we still need to replace selectors with fixed Expressions in `cs` + let fake_selectors = vec![vec![]; cs.num_selectors]; + let (cs, _) = cs.directly_convert_selectors_to_fixed(fake_selectors); + (cs, vec![]) + }; Ok(Self::from_parts( domain, @@ -142,12 +175,13 @@ where permutation, cs, selectors, + compress_selectors, )) } /// Writes a verifying key to a vector of bytes using [`Self::write`]. pub fn to_bytes(&self, format: SerdeFormat) -> Vec { - let mut bytes = Vec::::with_capacity(self.bytes_length()); + let mut bytes = Vec::::with_capacity(self.bytes_length(format)); Self::write(self, &mut bytes, format).expect("Writing to vector should not fail"); bytes } @@ -168,9 +202,12 @@ where } impl VerifyingKey { - fn bytes_length(&self) -> usize { - 8 + (self.fixed_commitments.len() * C::default().to_bytes().as_ref().len()) - + self.permutation.bytes_length() + fn bytes_length(&self, format: SerdeFormat) -> usize + where + C: SerdeCurveAffine, + { + 10 + (self.fixed_commitments.len() * C::byte_length(format)) + + self.permutation.bytes_length(format) + self.selectors.len() * (self .selectors @@ -185,6 +222,7 @@ impl VerifyingKey { permutation: permutation::VerifyingKey, cs: ConstraintSystem, selectors: Vec>, + compress_selectors: bool, ) -> Self where C::ScalarExt: FromUniformBytes<64>, @@ -201,6 +239,7 @@ impl VerifyingKey { // Temporary, this is not pinned. transcript_repr: C::Scalar::ZERO, selectors, + compress_selectors, }; let mut hasher = Blake2bParams::new() @@ -300,9 +339,12 @@ where } /// Gets the total number of bytes in the serialization of `self` - fn bytes_length(&self) -> usize { + fn bytes_length(&self, format: SerdeFormat) -> usize + where + C: SerdeCurveAffine, + { let scalar_len = C::Scalar::default().to_repr().as_ref().len(); - self.vk.bytes_length() + self.vk.bytes_length(format) + 12 + scalar_len * (self.l0.len() + self.l_last.len() + self.l_active_row.len()) + polynomial_slice_byte_length(&self.fixed_values) @@ -383,7 +425,7 @@ where /// Writes a proving key to a vector of bytes using [`Self::write`]. pub fn to_bytes(&self, format: SerdeFormat) -> Vec { - let mut bytes = Vec::::with_capacity(self.bytes_length()); + let mut bytes = Vec::::with_capacity(self.bytes_length(format)); Self::write(self, &mut bytes, format).expect("Writing to vector should not fail"); bytes } diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index cce39c8115..d4d877d9b1 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -2023,7 +2023,45 @@ impl ConstraintSystem { .into_iter() .map(|a| a.unwrap()) .collect::>(); + self.replace_selectors_with_fixed(&selector_replacements); + (self, polys) + } + + /// Does not combine selectors and directly replaces them everywhere with fixed columns. + pub fn directly_convert_selectors_to_fixed( + mut self, + selectors: Vec>, + ) -> (Self, Vec>) { + // The number of provided selector assignments must be the number we + // counted for this constraint system. + assert_eq!(selectors.len(), self.num_selectors); + + let (polys, selector_replacements): (Vec<_>, Vec<_>) = selectors + .into_iter() + .map(|selector| { + let poly = selector + .iter() + .map(|b| if *b { F::ONE } else { F::ZERO }) + .collect::>(); + let column = self.fixed_column(); + let rotation = Rotation::cur(); + let expr = Expression::Fixed(FixedQuery { + index: Some(self.query_fixed_index(column, rotation)), + column_index: column.index, + rotation, + }); + (poly, expr) + }) + .unzip(); + + self.replace_selectors_with_fixed(&selector_replacements); + self.num_selectors = 0; + + (self, polys) + } + + fn replace_selectors_with_fixed(&mut self, selector_replacements: &[Expression]) { fn replace_selectors( expr: &mut Expression, selector_replacements: &[Expression], @@ -2054,7 +2092,7 @@ impl ConstraintSystem { // Substitute selectors for the real fixed columns in all gates for expr in self.gates.iter_mut().flat_map(|gate| gate.polys.iter_mut()) { - replace_selectors(expr, &selector_replacements, false); + replace_selectors(expr, selector_replacements, false); } // Substitute non-simple selectors for the real fixed columns in all @@ -2065,7 +2103,7 @@ impl ConstraintSystem { .iter_mut() .chain(lookup.table_expressions.iter_mut()) }) { - replace_selectors(expr, &selector_replacements, true); + replace_selectors(expr, selector_replacements, true); } for expr in self.shuffles.iter_mut().flat_map(|shuffle| { @@ -2074,10 +2112,8 @@ impl ConstraintSystem { .iter_mut() .chain(shuffle.shuffle_expressions.iter_mut()) }) { - replace_selectors(expr, &selector_replacements, true); + replace_selectors(expr, selector_replacements, true); } - - (self, polys) } /// Allocate a new (simple) selector. Simple selectors cannot be added to diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index bd48b7c96a..984eecb9e8 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -203,10 +203,28 @@ impl Assignment for Assembly { } /// Generate a `VerifyingKey` from an instance of `Circuit`. +/// By default, selector compression is turned **off**. pub fn keygen_vk<'params, C, P, ConcreteCircuit>( params: &P, circuit: &ConcreteCircuit, ) -> Result, Error> +where + C: CurveAffine, + P: Params<'params, C>, + ConcreteCircuit: Circuit, + C::Scalar: FromUniformBytes<64>, +{ + keygen_vk_custom(params, circuit, true) +} + +/// Generate a `VerifyingKey` from an instance of `Circuit`. +/// +/// The selector compression optimization is turned on only if `compress_selectors` is `true`. +pub fn keygen_vk_custom<'params, C, P, ConcreteCircuit>( + params: &P, + circuit: &ConcreteCircuit, + compress_selectors: bool, +) -> Result, Error> where C: CurveAffine, P: Params<'params, C>, @@ -241,7 +259,13 @@ where )?; let mut fixed = batch_invert_assigned(assembly.fixed); - let (cs, selector_polys) = cs.compress_selectors(assembly.selectors.clone()); + let (cs, selector_polys) = if compress_selectors { + cs.compress_selectors(assembly.selectors.clone()) + } else { + // After this, the ConstraintSystem should not have any selectors: `verify` does not need them, and `keygen_pk` regenerates `cs` from scratch anyways. + let selectors = std::mem::take(&mut assembly.selectors); + cs.directly_convert_selectors_to_fixed(selectors) + }; fixed.extend( selector_polys .into_iter() @@ -263,6 +287,7 @@ where permutation_vk, cs, assembly.selectors, + compress_selectors, )) } @@ -307,7 +332,11 @@ where )?; let mut fixed = batch_invert_assigned(assembly.fixed); - let (cs, selector_polys) = cs.compress_selectors(assembly.selectors); + let (cs, selector_polys) = if vk.compress_selectors { + cs.compress_selectors(assembly.selectors) + } else { + cs.directly_convert_selectors_to_fixed(assembly.selectors) + }; fixed.extend( selector_polys .into_iter() diff --git a/halo2_proofs/src/plonk/permutation.rs b/halo2_proofs/src/plonk/permutation.rs index 3c54f51943..22c1fad6c3 100644 --- a/halo2_proofs/src/plonk/permutation.rs +++ b/halo2_proofs/src/plonk/permutation.rs @@ -117,8 +117,11 @@ impl VerifyingKey { Ok(VerifyingKey { commitments }) } - pub(crate) fn bytes_length(&self) -> usize { - self.commitments.len() * C::default().to_bytes().as_ref().len() + pub(crate) fn bytes_length(&self, format: SerdeFormat) -> usize + where + C: SerdeCurveAffine, + { + self.commitments.len() * C::byte_length(format) } }