Skip to content

Commit

Permalink
[feat] Make selector compression optional (#212)
Browse files Browse the repository at this point in the history
* feat: add `keygen_vk_custom` function to turn off selector compression

* feat(vkey): add `version` and `compress_selectors` bytes to serialization

* feat: do not serialize selectors when `compress_selectors = false`

* feat: make `keygen_vk` have `compress_selectors = true` by default

* chore: fix clippy

* fix: pass empty poly to fake_selectors

* fix: byte_length calculation
  • Loading branch information
jonathanpwang authored Nov 8, 2023
1 parent 9b200c8 commit 9899a5c
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 33 deletions.
8 changes: 8 additions & 0 deletions halo2_proofs/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ pub trait SerdeCurveAffine: CurveAffine + SerdeObject {
_ => self.write_raw(writer),
}
}

/// Byte length of an affine curve element according to `format`.
fn byte_length(format: SerdeFormat) -> usize {
match format {
SerdeFormat::Processed => Self::default().to_bytes().as_ref().len(),
_ => Self::Repr::default().as_ref().len() * 2,
}
}
}
impl<C: CurveAffine + SerdeObject> SerdeCurveAffine for C {}

Expand Down
90 changes: 66 additions & 24 deletions halo2_proofs/src/plonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ pub struct VerifyingKey<C: CurveAffine> {
/// The representative of this `VerifyingKey` in transcripts.
transcript_repr: C::Scalar,
selectors: Vec<Vec<bool>>,
/// Whether selector compression is turned on or not.
compress_selectors: bool,
}

impl<C: SerdeCurveAffine> VerifyingKey<C>
Expand All @@ -72,13 +74,19 @@ where
/// Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
pub fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
writer.write_all(&self.domain.k().to_be_bytes())?;
writer.write_all(&(self.fixed_commitments.len() as u32).to_be_bytes())?;
// Version byte that will be checked on read.
writer.write_all(&[0x02])?;
writer.write_all(&self.domain.k().to_le_bytes())?;
writer.write_all(&[self.compress_selectors as u8])?;
writer.write_all(&(self.fixed_commitments.len() as u32).to_le_bytes())?;
for commitment in &self.fixed_commitments {
commitment.write(writer, format)?;
}
self.permutation.write(writer, format)?;

if !self.compress_selectors {
assert!(self.selectors.is_empty());
}
// write self.selectors
for selector in &self.selectors {
// since `selector` is filled with `bool`, we pack them 8 at a time into bytes and then write
Expand All @@ -104,50 +112,76 @@ where
format: SerdeFormat,
#[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params,
) -> io::Result<Self> {
let mut version_byte = [0u8; 1];
reader.read_exact(&mut version_byte)?;
if 0x02 != version_byte[0] {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"unexpected version byte",
));
}
let mut k = [0u8; 4];
reader.read_exact(&mut k)?;
let k = u32::from_be_bytes(k);
let k = u32::from_le_bytes(k);
let mut compress_selectors = [0u8; 1];
reader.read_exact(&mut compress_selectors)?;
if compress_selectors[0] != 0 && compress_selectors[0] != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"unexpected compress_selectors not boolean",
));
}
let compress_selectors = compress_selectors[0] == 1;
let (domain, cs, _) = keygen::create_domain::<C, ConcreteCircuit>(
k,
#[cfg(feature = "circuit-params")]
params,
);
let mut num_fixed_columns = [0u8; 4];
reader.read_exact(&mut num_fixed_columns)?;
let num_fixed_columns = u32::from_be_bytes(num_fixed_columns);
let num_fixed_columns = u32::from_le_bytes(num_fixed_columns);

let fixed_commitments: Vec<_> = (0..num_fixed_columns)
.map(|_| C::read(reader, format))
.collect::<Result<_, _>>()?;

let permutation = permutation::VerifyingKey::read(reader, &cs.permutation, format)?;

// read selectors
let selectors: Vec<Vec<bool>> = vec![vec![false; 1 << k]; cs.num_selectors]
.into_iter()
.map(|mut selector| {
let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8];
reader.read_exact(&mut selector_bytes)?;
for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) {
crate::helpers::unpack(byte, bits);
}
Ok(selector)
})
.collect::<io::Result<_>>()?;
let (cs, _) = cs.compress_selectors(selectors.clone());
let (cs, selectors) = if compress_selectors {
// read selectors
let selectors: Vec<Vec<bool>> = vec![vec![false; 1 << k]; cs.num_selectors]
.into_iter()
.map(|mut selector| {
let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8];
reader.read_exact(&mut selector_bytes)?;
for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) {
crate::helpers::unpack(byte, bits);
}
Ok(selector)
})
.collect::<io::Result<_>>()?;
let (cs, _) = cs.compress_selectors(selectors.clone());
(cs, selectors)
} else {
// we still need to replace selectors with fixed Expressions in `cs`
let fake_selectors = vec![vec![]; cs.num_selectors];
let (cs, _) = cs.directly_convert_selectors_to_fixed(fake_selectors);
(cs, vec![])
};

Ok(Self::from_parts(
domain,
fixed_commitments,
permutation,
cs,
selectors,
compress_selectors,
))
}

/// Writes a verifying key to a vector of bytes using [`Self::write`].
pub fn to_bytes(&self, format: SerdeFormat) -> Vec<u8> {
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length());
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length(format));
Self::write(self, &mut bytes, format).expect("Writing to vector should not fail");
bytes
}
Expand All @@ -168,9 +202,12 @@ where
}

impl<C: CurveAffine> VerifyingKey<C> {
fn bytes_length(&self) -> usize {
8 + (self.fixed_commitments.len() * C::default().to_bytes().as_ref().len())
+ self.permutation.bytes_length()
fn bytes_length(&self, format: SerdeFormat) -> usize
where
C: SerdeCurveAffine,
{
10 + (self.fixed_commitments.len() * C::byte_length(format))
+ self.permutation.bytes_length(format)
+ self.selectors.len()
* (self
.selectors
Expand All @@ -185,6 +222,7 @@ impl<C: CurveAffine> VerifyingKey<C> {
permutation: permutation::VerifyingKey<C>,
cs: ConstraintSystem<C::Scalar>,
selectors: Vec<Vec<bool>>,
compress_selectors: bool,
) -> Self
where
C::ScalarExt: FromUniformBytes<64>,
Expand All @@ -201,6 +239,7 @@ impl<C: CurveAffine> VerifyingKey<C> {
// Temporary, this is not pinned.
transcript_repr: C::Scalar::ZERO,
selectors,
compress_selectors,
};

let mut hasher = Blake2bParams::new()
Expand Down Expand Up @@ -300,9 +339,12 @@ where
}

/// Gets the total number of bytes in the serialization of `self`
fn bytes_length(&self) -> usize {
fn bytes_length(&self, format: SerdeFormat) -> usize
where
C: SerdeCurveAffine,
{
let scalar_len = C::Scalar::default().to_repr().as_ref().len();
self.vk.bytes_length()
self.vk.bytes_length(format)
+ 12
+ scalar_len * (self.l0.len() + self.l_last.len() + self.l_active_row.len())
+ polynomial_slice_byte_length(&self.fixed_values)
Expand Down Expand Up @@ -383,7 +425,7 @@ where

/// Writes a proving key to a vector of bytes using [`Self::write`].
pub fn to_bytes(&self, format: SerdeFormat) -> Vec<u8> {
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length());
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length(format));
Self::write(self, &mut bytes, format).expect("Writing to vector should not fail");
bytes
}
Expand Down
46 changes: 41 additions & 5 deletions halo2_proofs/src/plonk/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2023,7 +2023,45 @@ impl<F: Field> ConstraintSystem<F> {
.into_iter()
.map(|a| a.unwrap())
.collect::<Vec<_>>();
self.replace_selectors_with_fixed(&selector_replacements);

(self, polys)
}

/// Does not combine selectors and directly replaces them everywhere with fixed columns.
pub fn directly_convert_selectors_to_fixed(
mut self,
selectors: Vec<Vec<bool>>,
) -> (Self, Vec<Vec<F>>) {
// The number of provided selector assignments must be the number we
// counted for this constraint system.
assert_eq!(selectors.len(), self.num_selectors);

let (polys, selector_replacements): (Vec<_>, Vec<_>) = selectors
.into_iter()
.map(|selector| {
let poly = selector
.iter()
.map(|b| if *b { F::ONE } else { F::ZERO })
.collect::<Vec<_>>();
let column = self.fixed_column();
let rotation = Rotation::cur();
let expr = Expression::Fixed(FixedQuery {
index: Some(self.query_fixed_index(column, rotation)),
column_index: column.index,
rotation,
});
(poly, expr)
})
.unzip();

self.replace_selectors_with_fixed(&selector_replacements);
self.num_selectors = 0;

(self, polys)
}

fn replace_selectors_with_fixed(&mut self, selector_replacements: &[Expression<F>]) {
fn replace_selectors<F: Field>(
expr: &mut Expression<F>,
selector_replacements: &[Expression<F>],
Expand Down Expand Up @@ -2054,7 +2092,7 @@ impl<F: Field> ConstraintSystem<F> {

// Substitute selectors for the real fixed columns in all gates
for expr in self.gates.iter_mut().flat_map(|gate| gate.polys.iter_mut()) {
replace_selectors(expr, &selector_replacements, false);
replace_selectors(expr, selector_replacements, false);
}

// Substitute non-simple selectors for the real fixed columns in all
Expand All @@ -2065,7 +2103,7 @@ impl<F: Field> ConstraintSystem<F> {
.iter_mut()
.chain(lookup.table_expressions.iter_mut())
}) {
replace_selectors(expr, &selector_replacements, true);
replace_selectors(expr, selector_replacements, true);
}

for expr in self.shuffles.iter_mut().flat_map(|shuffle| {
Expand All @@ -2074,10 +2112,8 @@ impl<F: Field> ConstraintSystem<F> {
.iter_mut()
.chain(shuffle.shuffle_expressions.iter_mut())
}) {
replace_selectors(expr, &selector_replacements, true);
replace_selectors(expr, selector_replacements, true);
}

(self, polys)
}

/// Allocate a new (simple) selector. Simple selectors cannot be added to
Expand Down
33 changes: 31 additions & 2 deletions halo2_proofs/src/plonk/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,28 @@ impl<F: Field> Assignment<F> for Assembly<F> {
}

/// Generate a `VerifyingKey` from an instance of `Circuit`.
/// By default, selector compression is turned **off**.
pub fn keygen_vk<'params, C, P, ConcreteCircuit>(
params: &P,
circuit: &ConcreteCircuit,
) -> Result<VerifyingKey<C>, Error>
where
C: CurveAffine,
P: Params<'params, C>,
ConcreteCircuit: Circuit<C::Scalar>,
C::Scalar: FromUniformBytes<64>,
{
keygen_vk_custom(params, circuit, true)
}

/// Generate a `VerifyingKey` from an instance of `Circuit`.
///
/// The selector compression optimization is turned on only if `compress_selectors` is `true`.
pub fn keygen_vk_custom<'params, C, P, ConcreteCircuit>(
params: &P,
circuit: &ConcreteCircuit,
compress_selectors: bool,
) -> Result<VerifyingKey<C>, Error>
where
C: CurveAffine,
P: Params<'params, C>,
Expand Down Expand Up @@ -241,7 +259,13 @@ where
)?;

let mut fixed = batch_invert_assigned(assembly.fixed);
let (cs, selector_polys) = cs.compress_selectors(assembly.selectors.clone());
let (cs, selector_polys) = if compress_selectors {
cs.compress_selectors(assembly.selectors.clone())
} else {
// After this, the ConstraintSystem should not have any selectors: `verify` does not need them, and `keygen_pk` regenerates `cs` from scratch anyways.
let selectors = std::mem::take(&mut assembly.selectors);
cs.directly_convert_selectors_to_fixed(selectors)
};
fixed.extend(
selector_polys
.into_iter()
Expand All @@ -263,6 +287,7 @@ where
permutation_vk,
cs,
assembly.selectors,
compress_selectors,
))
}

Expand Down Expand Up @@ -307,7 +332,11 @@ where
)?;

let mut fixed = batch_invert_assigned(assembly.fixed);
let (cs, selector_polys) = cs.compress_selectors(assembly.selectors);
let (cs, selector_polys) = if vk.compress_selectors {
cs.compress_selectors(assembly.selectors)
} else {
cs.directly_convert_selectors_to_fixed(assembly.selectors)
};
fixed.extend(
selector_polys
.into_iter()
Expand Down
7 changes: 5 additions & 2 deletions halo2_proofs/src/plonk/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@ impl<C: CurveAffine> VerifyingKey<C> {
Ok(VerifyingKey { commitments })
}

pub(crate) fn bytes_length(&self) -> usize {
self.commitments.len() * C::default().to_bytes().as_ref().len()
pub(crate) fn bytes_length(&self, format: SerdeFormat) -> usize
where
C: SerdeCurveAffine,
{
self.commitments.len() * C::byte_length(format)
}
}

Expand Down

0 comments on commit 9899a5c

Please sign in to comment.