From ddfa3ee0b27e172ea1a454eedc85e28ee0fb85eb Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski <120587768+chudkowsky@users.noreply.github.com> Date: Tue, 17 Sep 2024 18:48:09 +0200 Subject: [PATCH] Configurable pow_bits and n_queries (#62) --- bin/cairo-prove/src/lib.rs | 4 +++ bin/cairo-prove/src/prove.rs | 4 +++ common/src/prover_input/cairo.rs | 2 ++ common/src/prover_input/cairo0.rs | 2 ++ prover-sdk/tests/prove_test.rs | 6 ++++ prover-sdk/tests/verify_test.rs | 2 ++ prover/src/prove/cairo.rs | 24 ++++++++-------- prover/src/prove/cairo0.rs | 24 ++++++++-------- prover/src/threadpool/mod.rs | 46 +++++++++++++++++++++++-------- prover/src/threadpool/prove.rs | 5 ++-- prover/src/utils/config.rs | 20 ++++++++++++-- scripts/e2e_test.sh | 10 +++++-- 12 files changed, 108 insertions(+), 41 deletions(-) diff --git a/bin/cairo-prove/src/lib.rs b/bin/cairo-prove/src/lib.rs index 4f0ef05..6615a6d 100644 --- a/bin/cairo-prove/src/lib.rs +++ b/bin/cairo-prove/src/lib.rs @@ -55,6 +55,10 @@ pub struct Args { pub wait: bool, #[arg(long, env, default_value = "false")] pub sse: bool, + #[arg(long, env)] + pub n_queries: Option, + #[arg(long, env)] + pub pow_bits: Option, } fn validate_input(input: &str) -> Result, ProveErrors> { diff --git a/bin/cairo-prove/src/prove.rs b/bin/cairo-prove/src/prove.rs index dc6c1e6..d620764 100644 --- a/bin/cairo-prove/src/prove.rs +++ b/bin/cairo-prove/src/prove.rs @@ -22,6 +22,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result { program: program_serialized, layout: args.layout, program_input, + pow_bits: args.pow_bits, + n_queries: args.n_queries, }; sdk.prove_cairo0(data).await? } @@ -38,6 +40,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result { program: program_serialized, layout: args.layout, program_input: input, + pow_bits: args.pow_bits, + n_queries: args.n_queries, }; sdk.prove_cairo(data).await? } diff --git a/common/src/prover_input/cairo.rs b/common/src/prover_input/cairo.rs index 6c7c370..525a740 100644 --- a/common/src/prover_input/cairo.rs +++ b/common/src/prover_input/cairo.rs @@ -6,6 +6,8 @@ pub struct CairoProverInput { pub program: CairoCompiledProgram, pub program_input: Vec, pub layout: String, + pub n_queries: Option, + pub pow_bits: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/common/src/prover_input/cairo0.rs b/common/src/prover_input/cairo0.rs index 1ca399e..bc34619 100644 --- a/common/src/prover_input/cairo0.rs +++ b/common/src/prover_input/cairo0.rs @@ -5,6 +5,8 @@ pub struct Cairo0ProverInput { pub program: Cairo0CompiledProgram, pub program_input: serde_json::Value, pub layout: String, + pub n_queries: Option, + pub pow_bits: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/prover-sdk/tests/prove_test.rs b/prover-sdk/tests/prove_test.rs index fb0c172..279e849 100644 --- a/prover-sdk/tests/prove_test.rs +++ b/prover-sdk/tests/prove_test.rs @@ -26,6 +26,8 @@ async fn test_cairo_prove() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job = sdk.prove_cairo(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; @@ -50,6 +52,8 @@ async fn test_cairo0_prove() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job = sdk.prove_cairo0(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; @@ -77,6 +81,8 @@ async fn test_cairo_multi_prove() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job1 = sdk.prove_cairo(data.clone()).await.unwrap(); let job2 = sdk.prove_cairo(data.clone()).await.unwrap(); diff --git a/prover-sdk/tests/verify_test.rs b/prover-sdk/tests/verify_test.rs index 72d7d2a..9876e5d 100644 --- a/prover-sdk/tests/verify_test.rs +++ b/prover-sdk/tests/verify_test.rs @@ -42,6 +42,8 @@ async fn test_verify_valid_proof() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job = sdk.clone().prove_cairo(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; diff --git a/prover/src/prove/cairo.rs b/prover/src/prove/cairo.rs index 8a2b145..0477896 100644 --- a/prover/src/prove/cairo.rs +++ b/prover/src/prove/cairo.rs @@ -1,7 +1,7 @@ use crate::auth::jwt::Claims; use crate::extractors::workdir::TempDirHandle; use crate::server::AppState; -use crate::threadpool::CairoVersionedInput; +use crate::threadpool::{CairoVersionedInput, ExecuteParams}; use axum::Json; use axum::{extract::State, http::StatusCode, response::IntoResponse}; use common::prover_input::CairoProverInput; @@ -9,7 +9,7 @@ use serde_json::json; pub async fn root( State(app_state): State, - TempDirHandle(path): TempDirHandle, + TempDirHandle(dir): TempDirHandle, _claims: Claims, Json(program_input): Json, ) -> impl IntoResponse { @@ -17,16 +17,16 @@ pub async fn root( let job_store = app_state.job_store.clone(); let job_id = job_store.create_job().await; let thread = thread_pool.lock().await; - thread - .execute( - job_id, - job_store, - path, - CairoVersionedInput::Cairo(program_input), - app_state.sse_tx.clone(), - ) - .await - .into_response(); + let execution_params = ExecuteParams { + job_id, + job_store, + dir, + program_input: CairoVersionedInput::Cairo(program_input.clone()), + sse_tx: app_state.sse_tx.clone(), + n_queries: program_input.clone().n_queries, + pow_bits: program_input.pow_bits, + }; + thread.execute(execution_params).await.into_response(); let body = json!({ "job_id": job_id diff --git a/prover/src/prove/cairo0.rs b/prover/src/prove/cairo0.rs index b25becb..96f9c15 100644 --- a/prover/src/prove/cairo0.rs +++ b/prover/src/prove/cairo0.rs @@ -1,7 +1,7 @@ use crate::auth::jwt::Claims; use crate::extractors::workdir::TempDirHandle; use crate::server::AppState; -use crate::threadpool::CairoVersionedInput; +use crate::threadpool::{CairoVersionedInput, ExecuteParams}; use axum::Json; use axum::{extract::State, http::StatusCode, response::IntoResponse}; use common::prover_input::Cairo0ProverInput; @@ -9,7 +9,7 @@ use serde_json::json; pub async fn root( State(app_state): State, - TempDirHandle(path): TempDirHandle, + TempDirHandle(dir): TempDirHandle, _claims: Claims, Json(program_input): Json, ) -> impl IntoResponse { @@ -17,16 +17,16 @@ pub async fn root( let job_store = app_state.job_store.clone(); let job_id = job_store.create_job().await; let thread = thread_pool.lock().await; - thread - .execute( - job_id, - job_store, - path, - CairoVersionedInput::Cairo0(program_input), - app_state.sse_tx.clone(), - ) - .await - .into_response(); + let execution_params = ExecuteParams { + job_id, + job_store, + dir, + program_input: CairoVersionedInput::Cairo0(program_input.clone()), + sse_tx: app_state.sse_tx.clone(), + n_queries: program_input.clone().n_queries, + pow_bits: program_input.pow_bits, + }; + thread.execute(execution_params).await.into_response(); let body = json!({ "job_id": job_id }); diff --git a/prover/src/threadpool/mod.rs b/prover/src/threadpool/mod.rs index a2d7ceb..8a9e498 100644 --- a/prover/src/threadpool/mod.rs +++ b/prover/src/threadpool/mod.rs @@ -22,6 +22,8 @@ type ReceiverType = Arc< TempDir, CairoVersionedInput, Arc>>, + Option, + Option, )>, >, >; @@ -32,8 +34,19 @@ type SenderType = Option< TempDir, CairoVersionedInput, Arc>>, + Option, + Option, )>, >; +pub struct ExecuteParams { + pub job_id: u64, + pub job_store: JobStore, + pub dir: TempDir, + pub program_input: CairoVersionedInput, + pub sse_tx: Arc>>, + pub n_queries: Option, + pub pow_bits: Option, +} pub struct ThreadPool { workers: Vec, sender: SenderType, @@ -59,20 +72,21 @@ impl ThreadPool { } } - pub async fn execute( - &self, - job_id: u64, - job_store: JobStore, - dir: TempDir, - program_input: CairoVersionedInput, - sse_tx: Arc>>, - ) -> Result<(), ProverError> { + pub async fn execute(&self, params: ExecuteParams) -> Result<(), ProverError> { self.sender .as_ref() .ok_or(ProverError::CustomError( "Thread pool is shutdown".to_string(), ))? - .send((job_id, job_store, dir, program_input, sse_tx)) + .send(( + params.job_id, + params.job_store, + params.dir, + params.program_input, + params.sse_tx, + params.n_queries, + params.pow_bits, + )) .await?; Ok(()) } @@ -107,10 +121,20 @@ impl Worker { loop { let message = receiver.lock().await.recv().await; match message { - Some((job_id, job_store, dir, program_input, sse_tx)) => { + Some((job_id, job_store, dir, program_input, sse_tx, n_queries, pow_bits)) => { trace!("Worker {id} got a job; executing."); - if let Err(e) = prove(job_id, job_store, dir, program_input, sse_tx).await { + if let Err(e) = prove( + job_id, + job_store, + dir, + program_input, + sse_tx, + n_queries, + pow_bits, + ) + .await + { eprintln!("Worker {id} encountered an error: {:?}", e); } diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index 0920f9f..ad475a0 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -19,6 +19,8 @@ pub async fn prove( dir: TempDir, program_input: CairoVersionedInput, sse_tx: Arc>>, + n_queries: Option, + pow_bits: Option, ) -> Result<(), ProverError> { job_store .update_job_status(job_id, JobStatus::Running, None) @@ -29,8 +31,7 @@ pub async fn prove( program_input .prepare_and_run(&RunPaths::from(&paths)) .await?; - - Template::generate_from_public_input_file(&paths.public_input_file)? + Template::generate_from_public_input_file(&paths.public_input_file, n_queries, pow_bits)? .save_to_file(&paths.params_file)?; let prove_status = paths.prove_command().spawn()?.wait().await?; diff --git a/prover/src/utils/config.rs b/prover/src/utils/config.rs index e4bb348..1456cd0 100644 --- a/prover/src/utils/config.rs +++ b/prover/src/utils/config.rs @@ -35,8 +35,16 @@ pub struct Template { } impl Template { - pub fn generate_from_public_input_file(file: &PathBuf) -> Result { - Self::generate_from_public_input(ProgramPublicInputAsNSteps::read_from_file(file)?) + pub fn generate_from_public_input_file( + file: &PathBuf, + n_queries: Option, + pow_bits: Option, + ) -> Result { + Self::generate_from_public_input( + ProgramPublicInputAsNSteps::read_from_file(file)?, + n_queries, + pow_bits, + ) } pub fn save_to_file(&self, file: &PathBuf) -> Result<(), ProverError> { let json_string = serde_json::to_string_pretty(self)?; @@ -46,8 +54,16 @@ impl Template { } fn generate_from_public_input( public_input: ProgramPublicInputAsNSteps, + n_queries: Option, + pow_bits: Option, ) -> Result { let mut template = Self::default(); + if let Some(pow_bits) = pow_bits { + template.stark.fri.proof_of_work_bits = pow_bits; + } + if let Some(n_queries) = n_queries { + template.stark.fri.n_queries = n_queries; + } let fri_step_list = public_input.calculate_fri_step_list(template.stark.fri.last_layer_degree_bound); template.stark.fri.fri_step_list = fri_step_list; diff --git a/scripts/e2e_test.sh b/scripts/e2e_test.sh index e2cb536..f072517 100755 --- a/scripts/e2e_test.sh +++ b/scripts/e2e_test.sh @@ -3,7 +3,6 @@ set -eux IMAGE_NAME="http-prover-test" CONTAINER_ENGINE="${CONTAINER_ENGINE:-docker}" - # Check if the image already exists if $CONTAINER_ENGINE images | grep -q "$IMAGE_NAME"; then echo "Image $IMAGE_NAME already exists. Skipping build step." @@ -46,9 +45,16 @@ $CONTAINER_ENGINE run -d --name http_prover_test $REPLACE_FLAG \ --message-expiration-time 3600 \ --session-expiration-time 3600 \ --authorized-keys $PUBLIC_KEY,$ADMIN_PUBLIC_KEY \ - --admin-key $ADMIN_PUBLIC_KEY + --admin-key $ADMIN_PUBLIC_KEY + +start_time=$(date +%s) PRIVATE_KEY=$PRIVATE_KEY PROVER_URL="http://localhost:3040" ADMIN_PRIVATE_KEY=$ADMIN_PRIVATE_KEY cargo test --no-fail-fast --workspace --verbose +end_time=$(date +%s) + +runtime=$((end_time - start_time)) + +echo "Total time for running tests: $runtime seconds" $CONTAINER_ENGINE stop http_prover_test $CONTAINER_ENGINE rm http_prover_test