Merge pull request #78 from Brechtpd/eval-mem

Reduce memory use in h evaluation
This commit is contained in:
Han 2022-08-09 00:40:06 +08:00 committed by GitHub
commit ad425ed3d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 378 additions and 378 deletions

View File

@ -46,6 +46,22 @@ pub enum ValueSource {
Advice(usize, usize),
/// This is an instance (external) column
Instance(usize, usize),
/// beta
Beta(),
/// gamma
Gamma(),
/// theta
Theta(),
/// y
Y(),
/// Previous value
PreviousValue(),
}
impl Default for ValueSource {
fn default() -> Self {
ValueSource::Constant(0)
}
}
impl ValueSource {
@ -58,6 +74,11 @@ impl ValueSource {
fixed_values: &[Polynomial<F, B>],
advice_values: &[Polynomial<F, B>],
instance_values: &[Polynomial<F, B>],
beta: &F,
gamma: &F,
theta: &F,
y: &F,
previous_value: &F,
) -> F {
match self {
ValueSource::Constant(idx) => constants[*idx],
@ -71,6 +92,11 @@ impl ValueSource {
ValueSource::Instance(column_index, rotation) => {
instance_values[*column_index][rotations[*rotation]]
}
ValueSource::Beta() => *beta,
ValueSource::Gamma() => *gamma,
ValueSource::Theta() => *theta,
ValueSource::Y() => *y,
ValueSource::PreviousValue() => *previous_value,
}
}
}
@ -84,14 +110,14 @@ pub enum Calculation {
Sub(ValueSource, ValueSource),
/// This is a product
Mul(ValueSource, ValueSource),
/// This is a square
Square(ValueSource),
/// This is a double
Double(ValueSource),
/// This is a negation
Negate(ValueSource),
/// This is `(a + beta) * b`
LcBeta(ValueSource, ValueSource),
/// This is `a * theta + b`
LcTheta(ValueSource, ValueSource),
/// This is `a + gamma`
AddGamma(ValueSource),
/// This is Horner's rule: `val = a; val = val * c + b[]`
Horner(ValueSource, Vec<ValueSource>, ValueSource),
/// This is a simple assignment
Store(ValueSource),
}
@ -109,146 +135,73 @@ impl Calculation {
beta: &F,
gamma: &F,
theta: &F,
y: &F,
previous_value: &F,
) -> F {
let get_value = |value: &ValueSource| {
value.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
beta,
gamma,
theta,
y,
previous_value,
)
};
match self {
Calculation::Add(a, b) => {
let a = a.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
let b = b.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
a + b
Calculation::Add(a, b) => get_value(a) + get_value(b),
Calculation::Sub(a, b) => get_value(a) - get_value(b),
Calculation::Mul(a, b) => get_value(a) * get_value(b),
Calculation::Square(v) => get_value(v).square(),
Calculation::Double(v) => get_value(v).double(),
Calculation::Negate(v) => -get_value(v),
Calculation::Horner(start_value, parts, factor) => {
let factor = get_value(factor);
let mut value = get_value(start_value);
for part in parts.iter() {
value = value * factor + get_value(part);
}
value
}
Calculation::Sub(a, b) => {
let a = a.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
let b = b.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
a - b
}
Calculation::Mul(a, b) => {
let a = a.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
let b = b.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
a * b
}
Calculation::Negate(v) => -v.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
),
Calculation::LcBeta(a, b) => {
let a = a.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
let b = b.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
(a + beta) * b
}
Calculation::LcTheta(a, b) => {
let a = a.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
let b = b.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
);
a * theta + b
}
Calculation::AddGamma(v) => {
v.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
) + gamma
}
Calculation::Store(v) => v.get(
rotations,
constants,
intermediates,
fixed_values,
advice_values,
instance_values,
),
Calculation::Store(v) => get_value(v),
}
}
}
/// EvaluationData
/// Evaluator
#[derive(Default, Debug)]
pub struct Evaluator<C: CurveAffine> {
/// Custom gates evalution
pub custom_gates: GraphEvaluator<C>,
/// Lookups evalution
pub lookups: Vec<GraphEvaluator<C>>,
}
/// GraphEvaluator
#[derive(Debug)]
pub struct GraphEvaluator<C: CurveAffine> {
/// Constants
pub constants: Vec<C::ScalarExt>,
/// Rotations
pub rotations: Vec<i32>,
/// Calculations
pub calculations: Vec<CalculationInfo>,
/// Value parts
pub value_parts: Vec<ValueSource>,
/// Lookup results
pub lookup_results: Vec<Calculation>,
/// Number of intermediates
pub num_intermediates: usize,
}
/// EvaluationData
#[derive(Default, Debug)]
pub struct EvaluationData<C: CurveAffine> {
/// Intermediates
pub intermediates: Vec<C::ScalarExt>,
/// Rotations
pub rotations: Vec<usize>,
}
/// CaluclationInfo
@ -256,205 +209,67 @@ pub struct Evaluator<C: CurveAffine> {
pub struct CalculationInfo {
/// Calculation
pub calculation: Calculation,
/// How many times this calculation is used
pub counter: usize,
/// Target
pub target: usize,
}
impl<C: CurveAffine> Evaluator<C> {
/// Creates a new evaluation structure
pub fn new(cs: &ConstraintSystem<C::ScalarExt>) -> Self {
let mut ev = Evaluator::default();
ev.add_constant(&C::ScalarExt::zero());
ev.add_constant(&C::ScalarExt::one());
// Custom gates
let mut parts = Vec::new();
for gate in cs.gates.iter() {
for poly in gate.polynomials().iter() {
let vs = ev.add_expression(poly);
ev.value_parts.push(vs);
}
parts.extend(
gate.polynomials()
.iter()
.map(|poly| ev.custom_gates.add_expression(poly)),
);
}
ev.custom_gates.add_calculation(Calculation::Horner(
ValueSource::PreviousValue(),
parts,
ValueSource::Y(),
));
// Lookups
for lookup in cs.lookups.iter() {
let evaluate_lc = |ev: &mut Evaluator<_>, expressions: &Vec<Expression<_>>| {
let mut graph = GraphEvaluator::default();
let mut evaluate_lc = |expressions: &Vec<Expression<_>>| {
let parts = expressions
.iter()
.map(|expr| ev.add_expression(expr))
.collect::<Vec<_>>();
let mut lc = parts[0];
for part in parts.iter().skip(1) {
lc = ev.add_calculation(Calculation::LcTheta(lc, *part));
}
lc
.map(|expr| graph.add_expression(expr))
.collect();
graph.add_calculation(Calculation::Horner(
ValueSource::Constant(0),
parts,
ValueSource::Theta(),
))
};
// Input coset
let compressed_input_coset = evaluate_lc(&mut ev, &lookup.input_expressions);
let compressed_input_coset = evaluate_lc(&lookup.input_expressions);
// table coset
let compressed_table_coset = evaluate_lc(&mut ev, &lookup.table_expressions);
let compressed_table_coset = evaluate_lc(&lookup.table_expressions);
// z(\omega X) (a'(X) + \beta) (s'(X) + \gamma)
let right_gamma = ev.add_calculation(Calculation::AddGamma(compressed_table_coset));
ev.lookup_results
.push(Calculation::LcBeta(compressed_input_coset, right_gamma));
let right_gamma = graph.add_calculation(Calculation::Add(
compressed_table_coset,
ValueSource::Gamma(),
));
let lc = graph.add_calculation(Calculation::Add(
compressed_input_coset,
ValueSource::Beta(),
));
graph.add_calculation(Calculation::Mul(lc, right_gamma));
ev.lookups.push(graph);
}
ev
}
/// Adds a rotation
fn add_rotation(&mut self, rotation: &Rotation) -> usize {
let position = self.rotations.iter().position(|&c| c == rotation.0);
match position {
Some(pos) => pos,
None => {
self.rotations.push(rotation.0);
self.rotations.len() - 1
}
}
}
/// Adds a constant
fn add_constant(&mut self, constant: &C::ScalarExt) -> ValueSource {
let position = self.constants.iter().position(|&c| c == *constant);
ValueSource::Constant(match position {
Some(pos) => pos,
None => {
self.constants.push(*constant);
self.constants.len() - 1
}
})
}
/// Adds a calculation.
/// Currently does the simplest thing possible: just stores the
/// resulting value so the result can be reused when that calculation
/// is done multiple times.
fn add_calculation(&mut self, calculation: Calculation) -> ValueSource {
let position = self
.calculations
.iter()
.position(|c| c.calculation == calculation);
match position {
Some(pos) => {
self.calculations[pos].counter += 1;
ValueSource::Intermediate(pos)
}
None => {
self.calculations.push(CalculationInfo {
counter: 1,
calculation,
});
ValueSource::Intermediate(self.calculations.len() - 1)
}
}
}
/// Generates an optimized evaluation for the expression
fn add_expression(&mut self, expr: &Expression<C::ScalarExt>) -> ValueSource {
match expr {
Expression::Constant(scalar) => self.add_constant(scalar),
Expression::Selector(_selector) => unreachable!(),
Expression::Fixed {
query_index: _,
column_index,
rotation,
} => {
let rot_idx = self.add_rotation(rotation);
self.add_calculation(Calculation::Store(ValueSource::Fixed(
*column_index,
rot_idx,
)))
}
Expression::Advice {
query_index: _,
column_index,
rotation,
} => {
let rot_idx = self.add_rotation(rotation);
self.add_calculation(Calculation::Store(ValueSource::Advice(
*column_index,
rot_idx,
)))
}
Expression::Instance {
query_index: _,
column_index,
rotation,
} => {
let rot_idx = self.add_rotation(rotation);
self.add_calculation(Calculation::Store(ValueSource::Instance(
*column_index,
rot_idx,
)))
}
Expression::Negated(a) => match **a {
Expression::Constant(scalar) => self.add_constant(&-scalar),
_ => {
let result_a = self.add_expression(a);
match result_a {
ValueSource::Constant(0) => result_a,
_ => self.add_calculation(Calculation::Negate(result_a)),
}
}
},
Expression::Sum(a, b) => {
// Undo subtraction stored as a + (-b) in expressions
match &**b {
Expression::Negated(b_int) => {
let result_a = self.add_expression(a);
let result_b = self.add_expression(b_int);
if result_a == ValueSource::Constant(0) {
self.add_calculation(Calculation::Negate(result_b))
} else if result_b == ValueSource::Constant(0) {
result_a
} else {
self.add_calculation(Calculation::Sub(result_a, result_b))
}
}
_ => {
let result_a = self.add_expression(a);
let result_b = self.add_expression(b);
if result_a == ValueSource::Constant(0) {
result_b
} else if result_b == ValueSource::Constant(0) {
result_a
} else if result_a <= result_b {
self.add_calculation(Calculation::Add(result_a, result_b))
} else {
self.add_calculation(Calculation::Add(result_b, result_a))
}
}
}
}
Expression::Product(a, b) => {
let result_a = self.add_expression(a);
let result_b = self.add_expression(b);
if result_a == ValueSource::Constant(0) || result_b == ValueSource::Constant(0) {
ValueSource::Constant(0)
} else if result_a == ValueSource::Constant(1) {
result_b
} else if result_b == ValueSource::Constant(1) {
result_a
} else if result_a <= result_b {
self.add_calculation(Calculation::Mul(result_a, result_b))
} else {
self.add_calculation(Calculation::Mul(result_b, result_a))
}
}
Expression::Scaled(a, f) => {
if *f == C::ScalarExt::zero() {
ValueSource::Constant(0)
} else if *f == C::ScalarExt::one() {
self.add_expression(a)
} else {
let cst = self.add_constant(f);
let result_a = self.add_expression(a);
self.add_calculation(Calculation::Mul(result_a, cst))
}
}
}
}
/// Evaluate h poly
pub(in crate::plonk) fn evaluate_h(
&self,
@ -473,7 +288,6 @@ impl<C: CurveAffine> Evaluator<C> {
let rot_scale = 1 << (domain.extended_k() - domain.k());
let fixed = &pk.fixed_cosets[..];
let extended_omega = domain.get_extended_omega();
let num_lookups = pk.vk.cs.lookups.len();
let isize = size as i32;
let one = C::ScalarExt::one();
let l0 = &pk.l0;
@ -482,76 +296,38 @@ impl<C: CurveAffine> Evaluator<C> {
let p = &pk.vk.cs.permutation;
let mut values = domain.empty_extended();
let mut lookup_values = vec![C::Scalar::zero(); size * num_lookups];
// Core expression evaluations
let num_threads = multicore::current_num_threads();
let mut table_values_box = ThreadBox::wrap(&mut lookup_values);
for (((advice, instance), lookups), permutation) in advice
.iter()
.zip(instance.iter())
.zip(lookups.iter())
.zip(permutations.iter())
{
// Custom gates
multicore::scope(|scope| {
let chunk_size = (size + num_threads - 1) / num_threads;
for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() {
let start = thread_idx * chunk_size;
scope.spawn(move |_| {
let table_values = table_values_box.unwrap();
let mut rotations = vec![0usize; self.rotations.len()];
let mut intermediates: Vec<C::ScalarExt> =
vec![C::ScalarExt::zero(); self.calculations.len()];
let mut eval_data = self.custom_gates.instance();
for (i, value) in values.iter_mut().enumerate() {
let idx = start + i;
// All rotation index values
for (rot_idx, rot) in self.rotations.iter().enumerate() {
rotations[rot_idx] = get_rotation_idx(idx, *rot, rot_scale, isize);
}
// All calculations, with cached intermediate results
for (i_idx, calc) in self.calculations.iter().enumerate() {
intermediates[i_idx] = calc.calculation.evaluate(
&rotations,
&self.constants,
&intermediates,
fixed,
advice,
instance,
&beta,
&gamma,
&theta,
);
}
// Accumulate value parts
for value_part in self.value_parts.iter() {
*value = *value * y
+ value_part.get(
&rotations,
&self.constants,
&intermediates,
fixed,
advice,
instance,
);
}
// Values required for the lookups
for (t, table_result) in self.lookup_results.iter().enumerate() {
table_values[t * size + idx] = table_result.evaluate(
&rotations,
&self.constants,
&intermediates,
fixed,
advice,
instance,
&beta,
&gamma,
&theta,
);
}
*value = self.custom_gates.evaluate(
&mut eval_data,
fixed,
advice,
instance,
&beta,
&gamma,
&theta,
&y,
value,
idx,
rot_scale,
isize,
);
}
});
}
@ -655,11 +431,27 @@ impl<C: CurveAffine> Evaluator<C> {
.coeff_to_extended(lookup.permuted_table_poly.clone());
// Lookup constraints
let table = &lookup_values[n * size..(n + 1) * size];
parallelize(&mut values, |values, start| {
let lookup_evaluator = &self.lookups[n];
let mut eval_data = lookup_evaluator.instance();
for (i, value) in values.iter_mut().enumerate() {
let idx = start + i;
let table_value = lookup_evaluator.evaluate(
&mut eval_data,
fixed,
advice,
instance,
&beta,
&gamma,
&theta,
&y,
&C::ScalarExt::zero(),
idx,
rot_scale,
isize,
);
let r_next = get_rotation_idx(idx, 1, rot_scale, isize);
let r_prev = get_rotation_idx(idx, -1, rot_scale, isize);
@ -679,7 +471,7 @@ impl<C: CurveAffine> Evaluator<C> {
+ ((product_coset[r_next]
* (permuted_input_coset[idx] + beta)
* (permuted_table_coset[idx] + gamma)
- product_coset[idx] * table[idx])
- product_coset[idx] * table_value)
* l_active_row[idx]);
// Check that the first values in the permuted input expression and permuted
// fixed expression are the same.
@ -701,24 +493,232 @@ impl<C: CurveAffine> Evaluator<C> {
}
}
#[derive(Clone, Copy)]
struct ThreadBox<T>(*mut T, usize);
#[allow(unsafe_code)]
unsafe impl<T> Send for ThreadBox<T> {}
#[allow(unsafe_code)]
unsafe impl<T> Sync for ThreadBox<T> {}
impl<C: CurveAffine> Default for GraphEvaluator<C> {
fn default() -> Self {
Self {
// Fixed positions to allow easy access
constants: vec![
C::ScalarExt::zero(),
C::ScalarExt::one(),
C::ScalarExt::from(2u64),
],
rotations: Vec::new(),
calculations: Vec::new(),
num_intermediates: 0,
}
}
}
/// Wraps a mutable slice so it can be passed into a thread without
/// hard to fix borrow checks caused by difficult data access patterns.
impl<T> ThreadBox<T> {
fn wrap(data: &mut [T]) -> Self {
Self(data.as_mut_ptr(), data.len())
impl<C: CurveAffine> GraphEvaluator<C> {
/// Adds a rotation
fn add_rotation(&mut self, rotation: &Rotation) -> usize {
let position = self.rotations.iter().position(|&c| c == rotation.0);
match position {
Some(pos) => pos,
None => {
self.rotations.push(rotation.0);
self.rotations.len() - 1
}
}
}
fn unwrap(&mut self) -> &mut [T] {
#[allow(unsafe_code)]
unsafe {
slice::from_raw_parts_mut(self.0, self.1)
/// Adds a constant
fn add_constant(&mut self, constant: &C::ScalarExt) -> ValueSource {
let position = self.constants.iter().position(|&c| c == *constant);
ValueSource::Constant(match position {
Some(pos) => pos,
None => {
self.constants.push(*constant);
self.constants.len() - 1
}
})
}
/// Adds a calculation.
/// Currently does the simplest thing possible: just stores the
/// resulting value so the result can be reused when that calculation
/// is done multiple times.
fn add_calculation(&mut self, calculation: Calculation) -> ValueSource {
let existing_calculation = self
.calculations
.iter()
.find(|c| c.calculation == calculation);
match existing_calculation {
Some(existing_calculation) => ValueSource::Intermediate(existing_calculation.target),
None => {
let target = self.num_intermediates;
self.calculations.push(CalculationInfo {
calculation,
target,
});
self.num_intermediates += 1;
ValueSource::Intermediate(target)
}
}
}
/// Generates an optimized evaluation for the expression
fn add_expression(&mut self, expr: &Expression<C::ScalarExt>) -> ValueSource {
match expr {
Expression::Constant(scalar) => self.add_constant(scalar),
Expression::Selector(_selector) => unreachable!(),
Expression::Fixed {
query_index: _,
column_index,
rotation,
} => {
let rot_idx = self.add_rotation(rotation);
self.add_calculation(Calculation::Store(ValueSource::Fixed(
*column_index,
rot_idx,
)))
}
Expression::Advice {
query_index: _,
column_index,
rotation,
} => {
let rot_idx = self.add_rotation(rotation);
self.add_calculation(Calculation::Store(ValueSource::Advice(
*column_index,
rot_idx,
)))
}
Expression::Instance {
query_index: _,
column_index,
rotation,
} => {
let rot_idx = self.add_rotation(rotation);
self.add_calculation(Calculation::Store(ValueSource::Instance(
*column_index,
rot_idx,
)))
}
Expression::Negated(a) => match **a {
Expression::Constant(scalar) => self.add_constant(&-scalar),
_ => {
let result_a = self.add_expression(a);
match result_a {
ValueSource::Constant(0) => result_a,
_ => self.add_calculation(Calculation::Negate(result_a)),
}
}
},
Expression::Sum(a, b) => {
// Undo subtraction stored as a + (-b) in expressions
match &**b {
Expression::Negated(b_int) => {
let result_a = self.add_expression(a);
let result_b = self.add_expression(b_int);
if result_a == ValueSource::Constant(0) {
self.add_calculation(Calculation::Negate(result_b))
} else if result_b == ValueSource::Constant(0) {
result_a
} else {
self.add_calculation(Calculation::Sub(result_a, result_b))
}
}
_ => {
let result_a = self.add_expression(a);
let result_b = self.add_expression(b);
if result_a == ValueSource::Constant(0) {
result_b
} else if result_b == ValueSource::Constant(0) {
result_a
} else if result_a <= result_b {
self.add_calculation(Calculation::Add(result_a, result_b))
} else {
self.add_calculation(Calculation::Add(result_b, result_a))
}
}
}
}
Expression::Product(a, b) => {
let result_a = self.add_expression(a);
let result_b = self.add_expression(b);
if result_a == ValueSource::Constant(0) || result_b == ValueSource::Constant(0) {
ValueSource::Constant(0)
} else if result_a == ValueSource::Constant(1) {
result_b
} else if result_b == ValueSource::Constant(1) {
result_a
} else if result_a == ValueSource::Constant(2) {
self.add_calculation(Calculation::Double(result_b))
} else if result_b == ValueSource::Constant(2) {
self.add_calculation(Calculation::Double(result_a))
} else if result_a == result_b {
self.add_calculation(Calculation::Square(result_a))
} else if result_a <= result_b {
self.add_calculation(Calculation::Mul(result_a, result_b))
} else {
self.add_calculation(Calculation::Mul(result_b, result_a))
}
}
Expression::Scaled(a, f) => {
if *f == C::ScalarExt::zero() {
ValueSource::Constant(0)
} else if *f == C::ScalarExt::one() {
self.add_expression(a)
} else {
let cst = self.add_constant(f);
let result_a = self.add_expression(a);
self.add_calculation(Calculation::Mul(result_a, cst))
}
}
}
}
/// Creates a new evaluation structure
pub fn instance(&self) -> EvaluationData<C> {
EvaluationData {
intermediates: vec![C::ScalarExt::zero(); self.num_intermediates],
rotations: vec![0usize; self.rotations.len()],
}
}
pub fn evaluate<B: Basis>(
&self,
data: &mut EvaluationData<C>,
fixed: &[Polynomial<C::ScalarExt, B>],
advice: &[Polynomial<C::ScalarExt, B>],
instance: &[Polynomial<C::ScalarExt, B>],
beta: &C::ScalarExt,
gamma: &C::ScalarExt,
theta: &C::ScalarExt,
y: &C::ScalarExt,
previous_value: &C::ScalarExt,
idx: usize,
rot_scale: i32,
isize: i32,
) -> C::ScalarExt {
// All rotation index values
for (rot_idx, rot) in self.rotations.iter().enumerate() {
data.rotations[rot_idx] = get_rotation_idx(idx, *rot, rot_scale, isize);
}
// All calculations, with cached intermediate results
for calc in self.calculations.iter() {
data.intermediates[calc.target] = calc.calculation.evaluate(
&data.rotations,
&self.constants,
&data.intermediates,
fixed,
advice,
instance,
beta,
gamma,
theta,
y,
previous_value,
);
}
// Return the result of the last calculation (if any)
if let Some(calc) = self.calculations.last() {
data.intermediates[calc.target]
} else {
C::ScalarExt::zero()
}
}
}