Skip to content

Commit

Permalink
Configurable pow_bits and n_queries (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
chudkowsky authored Sep 17, 2024
1 parent 2a228de commit ddfa3ee
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 41 deletions.
4 changes: 4 additions & 0 deletions bin/cairo-prove/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ pub struct Args {
pub wait: bool,
#[arg(long, env, default_value = "false")]
pub sse: bool,
#[arg(long, env)]
pub n_queries: Option<u32>,
#[arg(long, env)]
pub pow_bits: Option<u32>,
}

fn validate_input(input: &str) -> Result<Vec<Felt>, ProveErrors> {
Expand Down
4 changes: 4 additions & 0 deletions bin/cairo-prove/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result<u64, ProveErrors> {
program: program_serialized,
layout: args.layout,
program_input,
pow_bits: args.pow_bits,
n_queries: args.n_queries,
};
sdk.prove_cairo0(data).await?
}
Expand All @@ -38,6 +40,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result<u64, ProveErrors> {
program: program_serialized,
layout: args.layout,
program_input: input,
pow_bits: args.pow_bits,
n_queries: args.n_queries,
};
sdk.prove_cairo(data).await?
}
Expand Down
2 changes: 2 additions & 0 deletions common/src/prover_input/cairo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub struct CairoProverInput {
pub program: CairoCompiledProgram,
pub program_input: Vec<Felt>,
pub layout: String,
pub n_queries: Option<u32>,
pub pow_bits: Option<u32>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
Expand Down
2 changes: 2 additions & 0 deletions common/src/prover_input/cairo0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pub struct Cairo0ProverInput {
pub program: Cairo0CompiledProgram,
pub program_input: serde_json::Value,
pub layout: String,
pub n_queries: Option<u32>,
pub pow_bits: Option<u32>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
Expand Down
6 changes: 6 additions & 0 deletions prover-sdk/tests/prove_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ async fn test_cairo_prove() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job = sdk.prove_cairo(data).await.unwrap();
let result = fetch_job(sdk.clone(), job).await;
Expand All @@ -50,6 +52,8 @@ async fn test_cairo0_prove() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job = sdk.prove_cairo0(data).await.unwrap();
let result = fetch_job(sdk.clone(), job).await;
Expand Down Expand Up @@ -77,6 +81,8 @@ async fn test_cairo_multi_prove() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job1 = sdk.prove_cairo(data.clone()).await.unwrap();
let job2 = sdk.prove_cairo(data.clone()).await.unwrap();
Expand Down
2 changes: 2 additions & 0 deletions prover-sdk/tests/verify_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ async fn test_verify_valid_proof() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job = sdk.clone().prove_cairo(data).await.unwrap();
let result = fetch_job(sdk.clone(), job).await;
Expand Down
24 changes: 12 additions & 12 deletions prover/src/prove/cairo.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
use crate::auth::jwt::Claims;
use crate::extractors::workdir::TempDirHandle;
use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use crate::threadpool::{CairoVersionedInput, ExecuteParams};
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::prover_input::CairoProverInput;
use serde_json::json;

pub async fn root(
State(app_state): State<AppState>,
TempDirHandle(path): TempDirHandle,
TempDirHandle(dir): TempDirHandle,
_claims: Claims,
Json(program_input): Json<CairoProverInput>,
) -> impl IntoResponse {
let thread_pool = app_state.thread_pool.clone();
let job_store = app_state.job_store.clone();
let job_id = job_store.create_job().await;
let thread = thread_pool.lock().await;
thread
.execute(
job_id,
job_store,
path,
CairoVersionedInput::Cairo(program_input),
app_state.sse_tx.clone(),
)
.await
.into_response();
let execution_params = ExecuteParams {
job_id,
job_store,
dir,
program_input: CairoVersionedInput::Cairo(program_input.clone()),
sse_tx: app_state.sse_tx.clone(),
n_queries: program_input.clone().n_queries,
pow_bits: program_input.pow_bits,
};
thread.execute(execution_params).await.into_response();

let body = json!({
"job_id": job_id
Expand Down
24 changes: 12 additions & 12 deletions prover/src/prove/cairo0.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
use crate::auth::jwt::Claims;
use crate::extractors::workdir::TempDirHandle;
use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use crate::threadpool::{CairoVersionedInput, ExecuteParams};
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::prover_input::Cairo0ProverInput;
use serde_json::json;

pub async fn root(
State(app_state): State<AppState>,
TempDirHandle(path): TempDirHandle,
TempDirHandle(dir): TempDirHandle,
_claims: Claims,
Json(program_input): Json<Cairo0ProverInput>,
) -> impl IntoResponse {
let thread_pool = app_state.thread_pool.clone();
let job_store = app_state.job_store.clone();
let job_id = job_store.create_job().await;
let thread = thread_pool.lock().await;
thread
.execute(
job_id,
job_store,
path,
CairoVersionedInput::Cairo0(program_input),
app_state.sse_tx.clone(),
)
.await
.into_response();
let execution_params = ExecuteParams {
job_id,
job_store,
dir,
program_input: CairoVersionedInput::Cairo0(program_input.clone()),
sse_tx: app_state.sse_tx.clone(),
n_queries: program_input.clone().n_queries,
pow_bits: program_input.pow_bits,
};
thread.execute(execution_params).await.into_response();
let body = json!({
"job_id": job_id
});
Expand Down
46 changes: 35 additions & 11 deletions prover/src/threadpool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type ReceiverType = Arc<
TempDir,
CairoVersionedInput,
Arc<Mutex<Sender<String>>>,
Option<u32>,
Option<u32>,
)>,
>,
>;
Expand All @@ -32,8 +34,19 @@ type SenderType = Option<
TempDir,
CairoVersionedInput,
Arc<Mutex<Sender<String>>>,
Option<u32>,
Option<u32>,
)>,
>;
pub struct ExecuteParams {
pub job_id: u64,
pub job_store: JobStore,
pub dir: TempDir,
pub program_input: CairoVersionedInput,
pub sse_tx: Arc<Mutex<Sender<String>>>,
pub n_queries: Option<u32>,
pub pow_bits: Option<u32>,
}
pub struct ThreadPool {
workers: Vec<Worker>,
sender: SenderType,
Expand All @@ -59,20 +72,21 @@ impl ThreadPool {
}
}

pub async fn execute(
&self,
job_id: u64,
job_store: JobStore,
dir: TempDir,
program_input: CairoVersionedInput,
sse_tx: Arc<Mutex<Sender<String>>>,
) -> Result<(), ProverError> {
pub async fn execute(&self, params: ExecuteParams) -> Result<(), ProverError> {
self.sender
.as_ref()
.ok_or(ProverError::CustomError(
"Thread pool is shutdown".to_string(),
))?
.send((job_id, job_store, dir, program_input, sse_tx))
.send((
params.job_id,
params.job_store,
params.dir,
params.program_input,
params.sse_tx,
params.n_queries,
params.pow_bits,
))
.await?;
Ok(())
}
Expand Down Expand Up @@ -107,10 +121,20 @@ impl Worker {
loop {
let message = receiver.lock().await.recv().await;
match message {
Some((job_id, job_store, dir, program_input, sse_tx)) => {
Some((job_id, job_store, dir, program_input, sse_tx, n_queries, pow_bits)) => {
trace!("Worker {id} got a job; executing.");

if let Err(e) = prove(job_id, job_store, dir, program_input, sse_tx).await {
if let Err(e) = prove(
job_id,
job_store,
dir,
program_input,
sse_tx,
n_queries,
pow_bits,
)
.await
{
eprintln!("Worker {id} encountered an error: {:?}", e);
}

Expand Down
5 changes: 3 additions & 2 deletions prover/src/threadpool/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub async fn prove(
dir: TempDir,
program_input: CairoVersionedInput,
sse_tx: Arc<Mutex<Sender<String>>>,
n_queries: Option<u32>,
pow_bits: Option<u32>,
) -> Result<(), ProverError> {
job_store
.update_job_status(job_id, JobStatus::Running, None)
Expand All @@ -29,8 +31,7 @@ pub async fn prove(
program_input
.prepare_and_run(&RunPaths::from(&paths))
.await?;

Template::generate_from_public_input_file(&paths.public_input_file)?
Template::generate_from_public_input_file(&paths.public_input_file, n_queries, pow_bits)?
.save_to_file(&paths.params_file)?;

let prove_status = paths.prove_command().spawn()?.wait().await?;
Expand Down
20 changes: 18 additions & 2 deletions prover/src/utils/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,16 @@ pub struct Template {
}

impl Template {
pub fn generate_from_public_input_file(file: &PathBuf) -> Result<Self, ProverError> {
Self::generate_from_public_input(ProgramPublicInputAsNSteps::read_from_file(file)?)
pub fn generate_from_public_input_file(
file: &PathBuf,
n_queries: Option<u32>,
pow_bits: Option<u32>,
) -> Result<Self, ProverError> {
Self::generate_from_public_input(
ProgramPublicInputAsNSteps::read_from_file(file)?,
n_queries,
pow_bits,
)
}
pub fn save_to_file(&self, file: &PathBuf) -> Result<(), ProverError> {
let json_string = serde_json::to_string_pretty(self)?;
Expand All @@ -46,8 +54,16 @@ impl Template {
}
fn generate_from_public_input(
public_input: ProgramPublicInputAsNSteps,
n_queries: Option<u32>,
pow_bits: Option<u32>,
) -> Result<Self, ProverError> {
let mut template = Self::default();
if let Some(pow_bits) = pow_bits {
template.stark.fri.proof_of_work_bits = pow_bits;
}
if let Some(n_queries) = n_queries {
template.stark.fri.n_queries = n_queries;
}
let fri_step_list =
public_input.calculate_fri_step_list(template.stark.fri.last_layer_degree_bound);
template.stark.fri.fri_step_list = fri_step_list;
Expand Down
10 changes: 8 additions & 2 deletions scripts/e2e_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
set -eux
IMAGE_NAME="http-prover-test"
CONTAINER_ENGINE="${CONTAINER_ENGINE:-docker}"

# Check if the image already exists
if $CONTAINER_ENGINE images | grep -q "$IMAGE_NAME"; then
echo "Image $IMAGE_NAME already exists. Skipping build step."
Expand Down Expand Up @@ -46,9 +45,16 @@ $CONTAINER_ENGINE run -d --name http_prover_test $REPLACE_FLAG \
--message-expiration-time 3600 \
--session-expiration-time 3600 \
--authorized-keys $PUBLIC_KEY,$ADMIN_PUBLIC_KEY \
--admin-key $ADMIN_PUBLIC_KEY
--admin-key $ADMIN_PUBLIC_KEY

start_time=$(date +%s)

PRIVATE_KEY=$PRIVATE_KEY PROVER_URL="http://localhost:3040" ADMIN_PRIVATE_KEY=$ADMIN_PRIVATE_KEY cargo test --no-fail-fast --workspace --verbose

end_time=$(date +%s)

runtime=$((end_time - start_time))

echo "Total time for running tests: $runtime seconds"
$CONTAINER_ENGINE stop http_prover_test
$CONTAINER_ENGINE rm http_prover_test

0 comments on commit ddfa3ee

Please sign in to comment.