Skip to content

Commit

Permalink
fix: bring back naive parallelize for mv_lookup prover parallel scan
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanpwang committed Feb 1, 2024
1 parent d94d6f3 commit 45e940c
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 33 deletions.
13 changes: 11 additions & 2 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ crossbeam = "0.8"
ff = "0.13"
group = "0.13"
pairing = "0.23"
halo2curves = { package = "halo2curves-axiom", version = "0.5.0", default-features = false, features = ["bits", "bn256-table", "derive_serde"] }
halo2curves = { package = "halo2curves-axiom", version = "0.5.0", default-features = false, features = [
"bits",
"bn256-table",
"derive_serde",
] }
rand = "0.8"
rand_core = { version = "0.6", default-features = false }
tracing = "0.1"
Expand Down Expand Up @@ -105,7 +109,12 @@ getrandom = { version = "0.2", features = ["js"] }
default = ["batch", "multicore", "circuit-params", "logup_skip_inv"]
multicore = ["maybe-rayon/threads"]
dev-graph = ["plotters", "tabbycat"]
test-dev-graph = ["dev-graph", "plotters/bitmap_backend", "plotters/bitmap_encoder", "plotters/ttf"]
test-dev-graph = [
"dev-graph",
"plotters/bitmap_backend",
"plotters/bitmap_encoder",
"plotters/ttf",
]
gadget-traces = ["backtrace"]
# thread-safe-region = []
sanity-checks = []
Expand Down
31 changes: 31 additions & 0 deletions halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,37 @@ pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mu
scope.spawn(move |_| f(chunk, offset));
}
}
});
}

/// Parallelizes an operation over a mutable slice, splitting it into
/// equal-sized chunks (only the final chunk may be shorter).
///
/// Unlike `parallelize`, the uniform-chunk guarantee here is load-bearing:
/// !! the mv_lookup prover's parallel scan relies on every non-final chunk
/// having exactly `chunk_len` elements. !!
///
/// Returns the starting offset of each chunk, in order.
pub(crate) fn parallelize_naive<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(
    v: &mut [T],
    f: F,
) -> Vec<usize> {
    let len = v.len();
    let num_threads = multicore::current_num_threads();
    // One "full" chunk per thread; if the slice is too small for that,
    // fall back to single-element chunks.
    let chunk_len = match len / num_threads {
        small if small < num_threads => 1,
        c => c,
    };

    multicore::scope(|scope| {
        let mut starts = Vec::new();
        for (idx, piece) in v.chunks_mut(chunk_len).enumerate() {
            // Offset of this chunk within the original slice.
            let offset = idx * chunk_len;
            starts.push(offset);
            let f = f.clone();
            scope.spawn(move |_| f(piece, offset));
        }
        starts
    })
}

Expand Down
14 changes: 2 additions & 12 deletions halo2_proofs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,10 @@
#![feature(associated_type_defaults)]

#[cfg(feature = "counter")]
#[macro_use]
extern crate lazy_static;
use std::{collections::BTreeMap, sync::Mutex};

#[cfg(feature = "counter")]
use lazy_static::lazy_static;

#[cfg(feature = "counter")]
use std::sync::Mutex;

#[cfg(feature = "counter")]
use std::collections::BTreeMap;

#[cfg(feature = "counter")]
lazy_static! {
lazy_static::lazy_static! {
static ref FFT_COUNTER: Mutex<BTreeMap<usize, usize>> = Mutex::new(BTreeMap::new());
static ref MSM_COUNTER: Mutex<BTreeMap<usize, usize>> = Mutex::new(BTreeMap::new());
}
Expand Down
2 changes: 1 addition & 1 deletion halo2_proofs/src/plonk/circuit/compress_selectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct SelectorDescription {
/// This describes the assigned combination of a particular selector as well as
/// the expression it should be substituted with.
#[derive(Debug, Clone)]
pub struct SelectorAssignment<F: Field> {
pub struct SelectorAssignment<F> {
/// The selector that this structure references, by index.
pub selector: usize,

Expand Down
21 changes: 8 additions & 13 deletions halo2_proofs/src/plonk/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

#[cfg(feature = "profile")]
use ark_std::{end_timer, start_timer};
#[cfg(not(feature = "logup_skip_inv"))]
use ff::BatchInvert;
use ff::{Field, PrimeField, WithSmallOrderMulGroup};
#[cfg(not(feature = "logup_skip_inv"))]
use rayon::slice::ParallelSlice;

#[cfg(not(feature = "logup_skip_inv"))]
use crate::arithmetic::par_invert;
use crate::multicore::{self, IntoParallelIterator, ParallelIterator};

Check warning on line 11 in halo2_proofs/src/plonk/evaluation.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-wasi

unused import: `ParallelIterator`

Check warning on line 11 in halo2_proofs/src/plonk/evaluation.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-wasi

unused import: `ParallelIterator`
use crate::{
arithmetic::{parallelize, CurveAffine},
Expand Down Expand Up @@ -562,20 +564,13 @@ impl<C: CurveAffine> Evaluator<C> {
.flatten()
.collect();

parallelize(&mut inputs_values_for_extended_domain, |values, _| {
values.batch_invert();
});
par_invert(&mut inputs_values_for_extended_domain);

let inputs_len = inputs_lookup_evaluator.len();

(0..size)
.into_par_iter()
.map(|i| {
inputs_values_for_extended_domain
[i * inputs_len..(i + 1) * inputs_len]
.iter()
.fold(C::Scalar::ZERO, |acc, x| acc + x)
})
inputs_values_for_extended_domain
.par_chunks_exact(inputs_len)
.map(|values| values.iter().sum())

Check failure on line 573 in halo2_proofs/src/plonk/evaluation.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-wasi

`rayon::slice::ChunksExact<'_, <C as CurveAffine>::ScalarExt>` is not an iterator

Check failure on line 573 in halo2_proofs/src/plonk/evaluation.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-wasi

`rayon::slice::ChunksExact<'_, <C as CurveAffine>::ScalarExt>` is not an iterator
.collect::<Vec<_>>()
})
.collect();
Expand Down
11 changes: 6 additions & 5 deletions halo2_proofs/src/plonk/mv_lookup/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rayon::prelude::{
IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator, ParallelSliceMut,
};

use crate::plonk::evaluation::evaluate;
use crate::{arithmetic::parallelize_naive, plonk::evaluation::evaluate};
use crate::{
arithmetic::{eval_polynomial, par_invert, parallelize, CurveAffine},
poly::{
Expand Down Expand Up @@ -64,7 +64,7 @@ impl<F: PrimeField + WithSmallOrderMulGroup<3> + Ord> Argument<F> {
fixed_values: &'a [Polynomial<C::Scalar, LagrangeCoeff>],
instance_values: &'a [Polynomial<C::Scalar, LagrangeCoeff>],
challenges: &'a [C::Scalar],
rng: R, // in case we want to blind (do we actually need zk?)
#[allow(unused_mut)] mut rng: R, // in case we want to blind (do we actually need zk?)
transcript: &mut T,
) -> Result<Prepared<C>, Error>
where
Expand Down Expand Up @@ -336,17 +336,18 @@ impl<C: CurveAffine> Prepared<C> {
.chain(log_derivatives_diff)
.take(active_size)
.collect::<Vec<_>>();
// TODO: remove the implicit assumption that parallelize() split the grand_sum
// TODO: remove the implicit assumption that parallelize_naive() split the grand_sum
// into segments that each has `chunk` elements except the last.
parallelize(&mut grand_sum, |segment_grand_sum, _| {
// !! Do not use `parallelize()` here because it breaks the above assumption. !!
parallelize_naive(&mut grand_sum, |segment_grand_sum, _| {
for i in 1..segment_grand_sum.len() {
segment_grand_sum[i] += segment_grand_sum[i - 1];
}
});
for i in 1..segment_sum.len() {
segment_sum[i] = segment_sum[i - 1] + grand_sum[i * chunk - 1];
}
parallelize(&mut grand_sum, |grand_sum, start| {
parallelize_naive(&mut grand_sum, |grand_sum, start| {
let prefix_sum = segment_sum[start / chunk];
for v in grand_sum.iter_mut() {
*v += prefix_sum;
Expand Down

0 comments on commit 45e940c

Please sign in to comment.