Skip to content

Commit

Permalink
checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Dec 6, 2023
1 parent 13fe808 commit 6fa878d
Show file tree
Hide file tree
Showing 12 changed files with 7,905 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ tempfile = "3.6.0"
camino = "1.1.6"
thiserror = "1.0.44"
tracing = "0.1.37"
tracing-texray = "0.2.0"
tracing-texray = { git = "https://github.com/winston-h-zhang/tracing-texray", branch = "shim" }
tracing-subscriber = "0.3.17"

[[bin]]
Expand Down
1,691 changes: 1,691 additions & 0 deletions benches/dev/600.txt

Large diffs are not rendered by default.

3,132 changes: 3,132 additions & 0 deletions benches/dev/900.txt

Large diffs are not rendered by default.

33 changes: 17 additions & 16 deletions benches/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ use lurk::{
field::LurkField,
lem::{eval::evaluate, multiframe::MultiFrame, pointers::Ptr, store::Store},
proof::nova::NovaProver,
proof::Prover,
public_parameters::{
instance::{Instance, Kind},
public_params,
},
proof::{nova::public_params, Prover},
state::State,
};

use tracing_subscriber::{fmt, prelude::*, EnvFilter, Registry};
use tracing_texray::TeXRayLayer;

mod common;
use common::set_bench_config;

Expand Down Expand Up @@ -111,14 +110,8 @@ fn fibonacci_prove<M: measurement::Measurement>(
let lang_pallas = Lang::<pallas::Scalar, Coproc<pallas::Scalar>>::new();
let lang_rc = Arc::new(lang_pallas.clone());

// use cached public params
let instance = Instance::new(
prove_params.reduction_count,
lang_rc.clone(),
true,
Kind::NovaPublicParams,
);
let pp = public_params::<_, _, MultiFrame<'_, _, _>>(&instance).unwrap();
let pp =
public_params::<_, _, MultiFrame<'_, _, _>>(prove_params.reduction_count, lang_rc.clone());

// Track the number of `Lurk frames / sec`
let rc = prove_params.reduction_count as u64;
Expand Down Expand Up @@ -148,7 +141,8 @@ fn fibonacci_prove<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let result = prover.prove(&pp, frames, &store);
let result = tracing_texray::examine(tracing::info_span!("bang!"))
.in_scope(|| prover.prove(&pp, frames, &store));
let _ = black_box(result);
},
BatchSize::LargeInput,
Expand All @@ -159,12 +153,17 @@ fn fibonacci_prove<M: measurement::Measurement>(

fn fibonacci_benchmark(c: &mut Criterion) {
// Uncomment to record the logs. May negatively impact performance
//tracing_subscriber::fmt::init();
let subscriber = Registry::default()
.with(fmt::layer().pretty())
.with(EnvFilter::from_default_env())
.with(TeXRayLayer::new().width(120));
tracing::subscriber::set_global_default(subscriber).unwrap();

set_bench_config();
tracing::debug!("{:?}", lurk::config::LURK_CONFIG);

let reduction_counts = rc_env().unwrap_or_else(|_| vec![100]);
let batch_sizes = [100, 200];
let batch_sizes = [249, 374, 499];

let state = State::init_lurk_state().rccell();

Expand All @@ -187,6 +186,8 @@ fn fibonacci_benchmark(c: &mut Criterion) {
}
}

// RUST_LOG=info LURK_RC=600 LURK_PERF=max-parallel-simple cargo criterion --bench fibonacci --features "cuda" 2> ./benches/gpu-spmvm/600.txt
// RUST_LOG=info LURK_RC=900 LURK_PERF=max-parallel-simple cargo criterion --bench fibonacci --features "cuda" 2> ./benches/dev/900.txt
cfg_if::cfg_if! {
if #[cfg(feature = "flamegraph")] {
criterion_group! {
Expand Down
2,540 changes: 2,540 additions & 0 deletions benches/gpu-spmvm/1200.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions benches/gpu-spmvm/600.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Compiling sppark v0.1.5
Compiling pasta-msm v0.1.4 (https://github.com/lurk-lab/pasta-msm?branch=dev#182b971d)
423 changes: 423 additions & 0 deletions benches/gpu-spmvm/900.txt

Large diffs are not rendered by default.

126 changes: 91 additions & 35 deletions examples/fibonacci.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,120 @@
use std::{cell::RefCell, rc::Rc, sync::Arc, time::Duration};

use anyhow::anyhow;

use pasta_curves::pallas;

use lurk::{
eval::lang::{Coproc, Lang},
field::LurkField,
lem::{eval::evaluate_simple, pointers::Ptr, store::Store},
{eval::lang::Coproc, state::State},
lem::{eval::evaluate, multiframe::MultiFrame, pointers::Ptr, store::Store},
proof::Prover,
proof::{nova::NovaProver, RecursiveSNARKTrait},
public_parameters::{
instance::{Instance, Kind},
public_params,
},
state::State,
};
use pasta_curves::Fq;

fn fib_expr<F: LurkField>(store: &Store<F>) -> Ptr {
use tracing_subscriber::{fmt, prelude::*, EnvFilter, Registry};
use tracing_texray::TeXRayLayer;

fn fib<F: LurkField>(store: &Store<F>, state: Rc<RefCell<State>>, _a: u64) -> Ptr {
let program = r#"
(letrec ((next (lambda (a b) (next b (+ a b))))
(fib (next 0 1)))
(fib))
"#;

store.read_with_default_state(program).unwrap()
store.read(state, program).unwrap()
}

// The env output in the `fib_frame`th frame of the above, infinite Fibonacci computation contains a binding of the
// The env output in the `fib_frame`th frame of the above, infinite Fibonacci computation will contain a binding of the
// nth Fibonacci number to `a`.
// means of computing it.]
fn fib_frame(n: usize) -> usize {
11 + 16 * n
}

// Set the limit so the last step will be filled exactly, since Lurk currently only pads terminal/error continuations.
#[allow(dead_code)]
fn fib_limit(n: usize, rc: usize) -> usize {
let frame = fib_frame(n);
rc * (frame / rc + usize::from(frame % rc != 0))
}

fn lurk_fib(store: &Store<Fq>, n: usize, _rc: usize) -> Ptr {
let frame_idx = fib_frame(n);
// let limit = fib_limit(n, rc);
let limit = frame_idx;
let fib_expr = fib_expr(store);

let (output, ..) = evaluate_simple::<Fq, Coproc<Fq>>(None, fib_expr, store, limit).unwrap();

let target_env = &output[1];

// The result is the value of the second binding (of `A`), in the target env.
// See relevant excerpt of execution trace below:
//
// INFO lurk::eval > Frame: 11
// Expr: (NEXT B (+ A B))
// Env: ((B . 1) (A . 0) ((NEXT . <FUNCTION (A) (LAMBDA (B) (NEXT B (+ A B)))>)))
// Cont: Tail{ saved_env: (((NEXT . <FUNCTION (A) (LAMBDA (B) (NEXT B (+ A B)))>))), continuation: LetRec{var: FIB,
// saved_env: (((NEXT . <FUNCTION (A) (LAMBDA (B) (NEXT B (+ A B)))>))), body: (FIB), continuation: Tail{ saved_env:
// NIL, continuation: Outermost } } }

let (_, rest_bindings) = store.car_cdr(target_env).unwrap();
let (second_binding, _) = store.car_cdr(&rest_bindings).unwrap();
store.car_cdr(&second_binding).unwrap().1
#[derive(Clone, Debug, Copy)]
struct ProveParams {
fib_n: usize,
rc: usize,
}

fn rc_env() -> anyhow::Result<Vec<usize>> {
std::env::var("LURK_RC")
.map_err(|e| anyhow!("Reduction count env var isn't set: {e}"))
.and_then(|rc| {
let vec: anyhow::Result<Vec<usize>> = rc
.split(',')
.map(|rc| {
rc.parse::<usize>()
.map_err(|e| anyhow!("Failed to parse RC: {e}"))
})
.collect();
vec
})
}

fn fibonacci_prove(prove_params: ProveParams, state: &Rc<RefCell<State>>) {
let limit = fib_limit(prove_params.fib_n, prove_params.rc);
let lang_pallas = Lang::<pallas::Scalar, Coproc<pallas::Scalar>>::new();
let lang_rc = Arc::new(lang_pallas.clone());

// use cached public params
let instance = Instance::new(
prove_params.rc,
lang_rc.clone(),
true,
Kind::NovaPublicParams,
);
let pp = public_params::<_, _, MultiFrame<'_, _, _>>(&instance).unwrap();

let store = Store::default();

let ptr = fib::<pasta_curves::Fq>(&store, state.clone(), prove_params.fib_n as u64);
let prover = NovaProver::new(prove_params.rc, lang_rc.clone());

let frames = &evaluate::<pasta_curves::Fq, Coproc<pasta_curves::Fq>>(None, ptr, &store, limit)
.unwrap()
.0;
let (proof, z0, zi, num_steps) = tracing_texray::examine(tracing::info_span!("bang!"))
.in_scope(|| prover.prove(&pp, frames, &store).unwrap());

let res = proof.verify(&pp, &z0, &zi, num_steps).unwrap();
assert!(res);
}

/// RUST_LOG=info LURK_RC=900 LURK_PERF=max-parallel-simple cargo run --release --example fibonacci --features "cuda"
fn main() {
let store = &Store::<Fq>::default();
let n: usize = std::env::args().collect::<Vec<_>>()[1].parse().unwrap();
let state = State::init_lurk_state();
let subscriber = Registry::default()
.with(fmt::layer().pretty())
.with(EnvFilter::from_default_env())
.with(TeXRayLayer::new().width(120));
tracing::subscriber::set_global_default(subscriber).unwrap();

let rcs = rc_env().unwrap_or_else(|_| vec![100]);
let batch_sizes = [249];

let state = State::init_lurk_state().rccell();

let fib = lurk_fib(store, n, 100);
for rc in rcs.iter() {
for fib_n in batch_sizes.iter() {
let prove_params = ProveParams {
fib_n: *fib_n,
rc: *rc,
};
fibonacci_prove(prove_params, &state);
}
}

println!("Fib({n}) = {}", fib.fmt_to_string(store, &state));
println!("success");
}
4 changes: 2 additions & 2 deletions src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ fn build_frames<
let mut pc = 0;
let mut frames = vec![];
let mut iterations = 0;
tracing::info!("{}", &log_fmt(0, &input, &[], store));
tracing::debug!("{}", &log_fmt(0, &input, &[], store));
for _ in 0..limit {
let mut emitted = vec![];
let (frame, must_break) =
compute_frame(lurk_step, cprocs_run, &input, store, lang, &mut emitted, pc)?;

iterations += 1;
input = frame.output.clone();
tracing::info!("{}", &log_fmt(iterations, &input, &emitted, store));
tracing::debug!("{}", &log_fmt(iterations, &input, &emitted, store));
let expr = frame.output[0];
frames.push(frame);

Expand Down
1 change: 1 addition & 0 deletions src/lem/multiframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ impl<'a, F: LurkField, C: Coprocessor<F>> nova::traits::circuit::StepCircuit<F>
2 * self.lurk_step.input_params.len()
}

#[tracing::instrument(skip_all, name = "synthesize")]
fn synthesize<CS>(
&self,
cs: &mut CS,
Expand Down
3 changes: 1 addition & 2 deletions src/lem/var_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::collections::hash_map::Entry;

use anyhow::{bail, Result};
use fxhash::FxHashMap;
use tracing::info;

use super::Var;

Expand All @@ -29,7 +28,7 @@ impl<V> VarMap<V> {
}
Entry::Occupied(mut o) => {
let v = o.insert(v);
info!("Variable {} has been overwritten", o.key());
tracing::debug!("Variable {} has been overwritten", o.key());
Some(v)
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/proof/nova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,16 @@ where
assert_eq!(reduction_count, circuit_primary.frames().unwrap().len());

let mut r_snark = recursive_snark.unwrap_or_else(|| {
RecursiveSNARK::new(
let recursive_snark = RecursiveSNARK::new(
&pp.pp,
&circuit_primary,
&circuit_secondary,
z0_primary,
&z0_secondary,
)
.expect("Failed to construct initial recursive snark")
.expect("Failed to construct initial recursive snark");
recursive_snark.write_abomonated(&pp.pp).unwrap();
recursive_snark
});
r_snark
.prove_step(&pp.pp, &circuit_primary, &circuit_secondary)
Expand Down

0 comments on commit 6fa878d

Please sign in to comment.