From 929cd00c2485aa53da37a8cde28a63c667f496d3 Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski Date: Tue, 17 Sep 2024 11:11:46 +0200 Subject: [PATCH 1/6] Configurable pow_bits and n_queries --- bin/cairo-prove/src/lib.rs | 4 +++ bin/cairo-prove/src/prove.rs | 4 +++ common/src/prover_input/cairo.rs | 2 ++ common/src/prover_input/cairo0.rs | 2 ++ prover-sdk/tests/prove_test.rs | 6 ++++ prover-sdk/tests/verify_test.rs | 2 ++ prover/src/prove/cairo.rs | 24 ++++++++-------- prover/src/prove/cairo0.rs | 24 ++++++++-------- prover/src/threadpool/mod.rs | 46 +++++++++++++++++++++++-------- prover/src/threadpool/prove.rs | 5 ++-- prover/src/utils/config.rs | 20 ++++++++++++-- scripts/e2e_test.sh | 10 +++++-- 12 files changed, 108 insertions(+), 41 deletions(-) diff --git a/bin/cairo-prove/src/lib.rs b/bin/cairo-prove/src/lib.rs index 4f0ef05..6615a6d 100644 --- a/bin/cairo-prove/src/lib.rs +++ b/bin/cairo-prove/src/lib.rs @@ -55,6 +55,10 @@ pub struct Args { pub wait: bool, #[arg(long, env, default_value = "false")] pub sse: bool, + #[arg(long, env)] + pub n_queries: Option, + #[arg(long, env)] + pub pow_bits: Option, } fn validate_input(input: &str) -> Result, ProveErrors> { diff --git a/bin/cairo-prove/src/prove.rs b/bin/cairo-prove/src/prove.rs index dc6c1e6..d620764 100644 --- a/bin/cairo-prove/src/prove.rs +++ b/bin/cairo-prove/src/prove.rs @@ -22,6 +22,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result { program: program_serialized, layout: args.layout, program_input, + pow_bits: args.pow_bits, + n_queries: args.n_queries, }; sdk.prove_cairo0(data).await? } @@ -38,6 +40,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result { program: program_serialized, layout: args.layout, program_input: input, + pow_bits: args.pow_bits, + n_queries: args.n_queries, }; sdk.prove_cairo(data).await? } diff --git a/common/src/prover_input/cairo.rs b/common/src/prover_input/cairo.rs index 6c7c370..525a740 100644 --- a/common/src/prover_input/cairo.rs +++ b/common/src/prover_input/cairo.rs @@ -6,6 +6,8 @@ pub struct CairoProverInput { pub program: CairoCompiledProgram, pub program_input: Vec, pub layout: String, + pub n_queries: Option, + pub pow_bits: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/common/src/prover_input/cairo0.rs b/common/src/prover_input/cairo0.rs index 1ca399e..bc34619 100644 --- a/common/src/prover_input/cairo0.rs +++ b/common/src/prover_input/cairo0.rs @@ -5,6 +5,8 @@ pub struct Cairo0ProverInput { pub program: Cairo0CompiledProgram, pub program_input: serde_json::Value, pub layout: String, + pub n_queries: Option, + pub pow_bits: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/prover-sdk/tests/prove_test.rs b/prover-sdk/tests/prove_test.rs index fb0c172..279e849 100644 --- a/prover-sdk/tests/prove_test.rs +++ b/prover-sdk/tests/prove_test.rs @@ -26,6 +26,8 @@ async fn test_cairo_prove() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job = sdk.prove_cairo(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; @@ -50,6 +52,8 @@ async fn test_cairo0_prove() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job = sdk.prove_cairo0(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; @@ -77,6 +81,8 @@ async fn test_cairo_multi_prove() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job1 = sdk.prove_cairo(data.clone()).await.unwrap(); let job2 = sdk.prove_cairo(data.clone()).await.unwrap(); diff --git a/prover-sdk/tests/verify_test.rs b/prover-sdk/tests/verify_test.rs index 72d7d2a..9876e5d 100644 --- a/prover-sdk/tests/verify_test.rs +++ b/prover-sdk/tests/verify_test.rs @@ -42,6 +42,8 @@ async fn test_verify_valid_proof() { program, layout, program_input, + n_queries: Some(16), + pow_bits: Some(20), }; let job = sdk.clone().prove_cairo(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; diff --git a/prover/src/prove/cairo.rs b/prover/src/prove/cairo.rs index 8a2b145..0477896 100644 --- a/prover/src/prove/cairo.rs +++ b/prover/src/prove/cairo.rs @@ -1,7 +1,7 @@ use crate::auth::jwt::Claims; use crate::extractors::workdir::TempDirHandle; use crate::server::AppState; -use crate::threadpool::CairoVersionedInput; +use crate::threadpool::{CairoVersionedInput, ExecuteParams}; use axum::Json; use axum::{extract::State, http::StatusCode, response::IntoResponse}; use common::prover_input::CairoProverInput; @@ -9,7 +9,7 @@ use serde_json::json; pub async fn root( State(app_state): State, - TempDirHandle(path): TempDirHandle, + TempDirHandle(dir): TempDirHandle, _claims: Claims, Json(program_input): Json, ) -> impl IntoResponse { @@ -17,16 +17,16 @@ pub async fn root( let job_store = app_state.job_store.clone(); let job_id = job_store.create_job().await; let thread = thread_pool.lock().await; - thread - .execute( - job_id, - job_store, - path, - CairoVersionedInput::Cairo(program_input), - app_state.sse_tx.clone(), - ) - .await - .into_response(); + let execution_params = ExecuteParams { + job_id, + job_store, + dir, + program_input: CairoVersionedInput::Cairo(program_input.clone()), + sse_tx: app_state.sse_tx.clone(), + n_queries: program_input.clone().n_queries, + pow_bits: program_input.pow_bits, + }; + thread.execute(execution_params).await.into_response(); let body = json!({ "job_id": job_id diff --git a/prover/src/prove/cairo0.rs b/prover/src/prove/cairo0.rs index b25becb..96f9c15 100644 --- a/prover/src/prove/cairo0.rs +++ b/prover/src/prove/cairo0.rs @@ -1,7 +1,7 @@ use crate::auth::jwt::Claims; use crate::extractors::workdir::TempDirHandle; use crate::server::AppState; -use crate::threadpool::CairoVersionedInput; +use crate::threadpool::{CairoVersionedInput, ExecuteParams}; use axum::Json; use axum::{extract::State, http::StatusCode, response::IntoResponse}; use common::prover_input::Cairo0ProverInput; @@ -9,7 +9,7 @@ use serde_json::json; pub async fn root( State(app_state): State, - TempDirHandle(path): TempDirHandle, + TempDirHandle(dir): TempDirHandle, _claims: Claims, Json(program_input): Json, ) -> impl IntoResponse { @@ -17,16 +17,16 @@ pub async fn root( let job_store = app_state.job_store.clone(); let job_id = job_store.create_job().await; let thread = thread_pool.lock().await; - thread - .execute( - job_id, - job_store, - path, - CairoVersionedInput::Cairo0(program_input), - app_state.sse_tx.clone(), - ) - .await - .into_response(); + let execution_params = ExecuteParams { + job_id, + job_store, + dir, + program_input: CairoVersionedInput::Cairo0(program_input.clone()), + sse_tx: app_state.sse_tx.clone(), + n_queries: program_input.clone().n_queries, + pow_bits: program_input.pow_bits, + }; + thread.execute(execution_params).await.into_response(); let body = json!({ "job_id": job_id }); diff --git a/prover/src/threadpool/mod.rs b/prover/src/threadpool/mod.rs index a2d7ceb..8a9e498 100644 --- a/prover/src/threadpool/mod.rs +++ b/prover/src/threadpool/mod.rs @@ -22,6 +22,8 @@ type ReceiverType = Arc< TempDir, CairoVersionedInput, Arc>>, + Option, + Option, )>, >, >; @@ -32,8 +34,19 @@ type SenderType = Option< TempDir, CairoVersionedInput, Arc>>, + Option, + Option, )>, >; +pub struct ExecuteParams { + pub job_id: u64, + pub job_store: JobStore, + pub dir: TempDir, + pub program_input: CairoVersionedInput, + pub sse_tx: Arc>>, + pub n_queries: Option, + pub pow_bits: Option, +} pub struct ThreadPool { workers: Vec, sender: SenderType, @@ -59,20 +72,21 @@ impl ThreadPool { } } - pub async fn execute( - &self, - job_id: u64, - job_store: JobStore, - dir: TempDir, - program_input: CairoVersionedInput, - sse_tx: Arc>>, - ) -> Result<(), ProverError> { + pub async fn execute(&self, params: ExecuteParams) -> Result<(), ProverError> { self.sender .as_ref() .ok_or(ProverError::CustomError( "Thread pool is shutdown".to_string(), ))? - .send((job_id, job_store, dir, program_input, sse_tx)) + .send(( + params.job_id, + params.job_store, + params.dir, + params.program_input, + params.sse_tx, + params.n_queries, + params.pow_bits, + )) .await?; Ok(()) } @@ -107,10 +121,20 @@ impl Worker { loop { let message = receiver.lock().await.recv().await; match message { - Some((job_id, job_store, dir, program_input, sse_tx)) => { + Some((job_id, job_store, dir, program_input, sse_tx, n_queries, pow_bits)) => { trace!("Worker {id} got a job; executing."); - if let Err(e) = prove(job_id, job_store, dir, program_input, sse_tx).await { + if let Err(e) = prove( + job_id, + job_store, + dir, + program_input, + sse_tx, + n_queries, + pow_bits, + ) + .await + { eprintln!("Worker {id} encountered an error: {:?}", e); } diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index 0920f9f..ad475a0 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -19,6 +19,8 @@ pub async fn prove( dir: TempDir, program_input: CairoVersionedInput, sse_tx: Arc>>, + n_queries: Option, + pow_bits: Option, ) -> Result<(), ProverError> { job_store .update_job_status(job_id, JobStatus::Running, None) @@ -29,8 +31,7 @@ pub async fn prove( program_input .prepare_and_run(&RunPaths::from(&paths)) .await?; - - Template::generate_from_public_input_file(&paths.public_input_file)? + Template::generate_from_public_input_file(&paths.public_input_file, n_queries, pow_bits)? .save_to_file(&paths.params_file)?; let prove_status = paths.prove_command().spawn()?.wait().await?; diff --git a/prover/src/utils/config.rs b/prover/src/utils/config.rs index e4bb348..1456cd0 100644 --- a/prover/src/utils/config.rs +++ b/prover/src/utils/config.rs @@ -35,8 +35,16 @@ pub struct Template { } impl Template { - pub fn generate_from_public_input_file(file: &PathBuf) -> Result { - Self::generate_from_public_input(ProgramPublicInputAsNSteps::read_from_file(file)?) + pub fn generate_from_public_input_file( + file: &PathBuf, + n_queries: Option, + pow_bits: Option, + ) -> Result { + Self::generate_from_public_input( + ProgramPublicInputAsNSteps::read_from_file(file)?, + n_queries, + pow_bits, + ) } pub fn save_to_file(&self, file: &PathBuf) -> Result<(), ProverError> { let json_string = serde_json::to_string_pretty(self)?; @@ -46,8 +54,16 @@ impl Template { } fn generate_from_public_input( public_input: ProgramPublicInputAsNSteps, + n_queries: Option, + pow_bits: Option, ) -> Result { let mut template = Self::default(); + if let Some(pow_bits) = pow_bits { + template.stark.fri.proof_of_work_bits = pow_bits; + } + if let Some(n_queries) = n_queries { + template.stark.fri.n_queries = n_queries; + } let fri_step_list = public_input.calculate_fri_step_list(template.stark.fri.last_layer_degree_bound); template.stark.fri.fri_step_list = fri_step_list; diff --git a/scripts/e2e_test.sh b/scripts/e2e_test.sh index e2cb536..f072517 100755 --- a/scripts/e2e_test.sh +++ b/scripts/e2e_test.sh @@ -3,7 +3,6 @@ set -eux IMAGE_NAME="http-prover-test" CONTAINER_ENGINE="${CONTAINER_ENGINE:-docker}" - # Check if the image already exists if $CONTAINER_ENGINE images | grep -q "$IMAGE_NAME"; then echo "Image $IMAGE_NAME already exists. Skipping build step." @@ -46,9 +45,16 @@ $CONTAINER_ENGINE run -d --name http_prover_test $REPLACE_FLAG \ --message-expiration-time 3600 \ --session-expiration-time 3600 \ --authorized-keys $PUBLIC_KEY,$ADMIN_PUBLIC_KEY \ - --admin-key $ADMIN_PUBLIC_KEY + --admin-key $ADMIN_PUBLIC_KEY + +start_time=$(date +%s) PRIVATE_KEY=$PRIVATE_KEY PROVER_URL="http://localhost:3040" ADMIN_PRIVATE_KEY=$ADMIN_PRIVATE_KEY cargo test --no-fail-fast --workspace --verbose +end_time=$(date +%s) + +runtime=$((end_time - start_time)) + +echo "Total time for running tests: $runtime seconds" $CONTAINER_ENGINE stop http_prover_test $CONTAINER_ENGINE rm http_prover_test From bc13c9cd47de046307d9dd96278c906e6a46665b Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski Date: Tue, 17 Sep 2024 18:25:28 +0200 Subject: [PATCH 2/6] working version of proof parser --- Cargo.lock | 96 +++++++++++++++++++++ Cargo.toml | 4 +- common/src/models.rs | 18 +++- prover-sdk/src/sdk.rs | 11 +-- prover-sdk/tests/helpers/mod.rs | 18 ++-- prover-sdk/tests/prove_test.rs | 140 ++++++++++++++++--------------- prover-sdk/tests/verify_test.rs | 18 ++-- prover/Cargo.toml | 4 +- prover/src/server.rs | 4 +- prover/src/threadpool/prove.rs | 18 +++- prover/src/utils/job.rs | 12 +-- prover/src/utils/mod.rs | 1 + prover/src/utils/proof_parser.rs | 52 ++++++++++++ prover/src/verifier.rs | 88 +++++-------------- 14 files changed, 309 insertions(+), 175 deletions(-) create mode 100644 prover/src/utils/proof_parser.rs diff --git a/Cargo.lock b/Cargo.lock index fbb1140..93c47b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -90,6 +90,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "anyhow" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" + [[package]] name = "async-stream" version = "0.3.5" @@ -479,6 +485,17 @@ dependencies = [ "libc", ] +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -585,6 +602,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -862,6 +880,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.1.0" @@ -1235,6 +1262,7 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", + "serde", ] [[package]] @@ -1417,7 +1445,9 @@ dependencies = [ "serde", "serde_json", "serde_with", + "starknet-crypto", "starknet-types-core", + "swiftness_proof_parser", "tempfile", "thiserror", "tokio", @@ -1673,6 +1703,16 @@ dependencies = [ "url", ] +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + [[package]] name = "ring" version = "0.17.8" @@ -1989,6 +2029,46 @@ dependencies = [ "der", ] +[[package]] +name = "starknet-crypto" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2a821ad8d98c6c3e4d0e5097f3fe6e2ed120ada9d32be87cd1330c7923a2f0" +dependencies = [ + "crypto-bigint", + "hex", + "hmac", + "num-bigint", + "num-integer", + "num-traits", + "rfc6979", + "sha2", + "starknet-crypto-codegen", + "starknet-curve", + "starknet-types-core", + "zeroize", +] + +[[package]] +name = "starknet-crypto-codegen" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e179dedc3fa6da064e56811d3e05d446aa2f7459e4eb0e3e49378a337235437" +dependencies = [ + "starknet-curve", + "starknet-types-core", + "syn", +] + +[[package]] +name = "starknet-curve" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56935b306dcf0b8f14bb2a1257164b8478bb8be4801dfae0923f5b266d1b457c" +dependencies = [ + "starknet-types-core", +] + [[package]] name = "starknet-types-core" version = "0.1.5" @@ -2015,6 +2095,22 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "swiftness_proof_parser" +version = "0.0.9" +source = "git+https://github.com/cartridge-gg/swiftness?rev=1d46e21#1d46e218513350ff5e0bda99f153d0eaaa432e3f" +dependencies = [ + "anyhow", + "clap", + "num-bigint", + "regex", + "serde", + "serde_json", + "starknet-crypto", + "starknet-types-core", + "thiserror", +] + [[package]] name = "syn" version = "2.0.75" diff --git a/Cargo.toml b/Cargo.toml index bbe4f45..ef9ffa2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,4 +43,6 @@ chrono = "0.4.38" base64 = "0.22.1" starknet-types-core = "~0.1.4" futures = "0.3.30" -async-stream = "0.3.5" \ No newline at end of file +async-stream = "0.3.5" +swiftness_proof_parser = { git = "https://github.com/cartridge-gg/swiftness", rev = "1d46e21"} +starknet-crypto = "0.7.0" diff --git a/common/src/models.rs b/common/src/models.rs index 8227c53..7d1c229 100644 --- a/common/src/models.rs +++ b/common/src/models.rs @@ -1,6 +1,7 @@ use ed25519_dalek::VerifyingKey; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; +use starknet_types_core::felt::Felt; #[serde_as] #[derive(Debug, Serialize, Deserialize)] @@ -10,7 +11,7 @@ pub struct JWTResponse { pub expiration: u64, pub session_key: Option, } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone,Debug)] pub enum JobStatus { Pending, Running, @@ -18,3 +19,18 @@ pub enum JobStatus { Failed, Unknown, } + +#[derive(Clone,Serialize,Deserialize)] +pub struct ProverResult { + pub proof: String, + pub program_hash: Felt, + pub program_output: Vec, + pub program_output_hash: Felt, +} +#[derive(Serialize,Deserialize)] +#[serde(untagged)] +pub enum JobResponse { + InProgress { id: u64, status: JobStatus }, + Completed { result: ProverResult, status: JobStatus }, + Failed { error: String }, +} \ No newline at end of file diff --git a/prover-sdk/src/sdk.rs b/prover-sdk/src/sdk.rs index cd2a992..4db25be 100644 --- a/prover-sdk/src/sdk.rs +++ b/prover-sdk/src/sdk.rs @@ -69,22 +69,15 @@ impl ProverSDK { let job = serde_json::from_str::(&response_data)?; Ok(job.job_id) } - pub async fn verify(self, proof: String) -> Result { + pub async fn verify(self, proof: String) -> Result { let response = self .client .post(self.verify.clone()) .json(&proof) .send() .await?; - if !response.status().is_success() { - let response_data: String = response.text().await?; - tracing::error!("{}", response_data); - return Err(SdkErrors::VerifyResponseError(response_data)); - } let response_data = response.text().await?; - - let job = serde_json::from_str::(&response_data)?; - Ok(job.job_id) + Ok(response_data) } pub async fn get_job(&self, job_id: u64) -> Result { let url = format!("{}/{}", self.get_job.clone().as_str(), job_id); diff --git a/prover-sdk/tests/helpers/mod.rs b/prover-sdk/tests/helpers/mod.rs index 0143812..d9cad2f 100644 --- a/prover-sdk/tests/helpers/mod.rs +++ b/prover-sdk/tests/helpers/mod.rs @@ -1,15 +1,15 @@ +use common::models::{JobResponse, ProverResult}; use prover_sdk::sdk::ProverSDK; -use serde_json::Value; -pub async fn fetch_job(sdk: ProverSDK, job: u64) -> String { +pub async fn fetch_job(sdk: ProverSDK, job: u64) -> Option { println!("Job ID: {}", job); sdk.sse(job).await.unwrap(); let response = sdk.get_job(job).await.unwrap(); let response = response.text().await.unwrap(); - let json_response: Value = serde_json::from_str(&response).unwrap(); - return json_response - .get("result") - .and_then(Value::as_str) - .unwrap_or("No result found") - .to_string(); -} + let json_response: JobResponse = serde_json::from_str(&response).unwrap(); + + if let JobResponse::Completed { result, .. } = json_response { + return Some(result); + } + None +} \ No newline at end of file diff --git a/prover-sdk/tests/prove_test.rs b/prover-sdk/tests/prove_test.rs index 279e849..fbfb0bd 100644 --- a/prover-sdk/tests/prove_test.rs +++ b/prover-sdk/tests/prove_test.rs @@ -1,8 +1,12 @@ +use std::fs; + use common::prover_input::*; use helpers::fetch_job; use prover_sdk::{access_key::ProverAccessKey, sdk::ProverSDK}; use serde_json::Value; + use starknet_types_core::felt::Felt; + use url::Url; mod helpers; @@ -31,72 +35,74 @@ async fn test_cairo_prove() { }; let job = sdk.prove_cairo(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; - let job = sdk.clone().verify(result).await.unwrap(); - let result = fetch_job(sdk.clone(), job).await; - assert_eq!("true", result); + assert!(result.is_some()); + let result = result.unwrap(); + let result = sdk.clone().verify(result.proof).await; + assert!(result.is_ok(), "Failed to verify proof"); + assert_eq!("true", result.unwrap()); } -#[tokio::test] -async fn test_cairo0_prove() { - let private_key = std::env::var("PRIVATE_KEY").unwrap(); - let url = std::env::var("PROVER_URL").unwrap(); - let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); - let url = Url::parse(&url).unwrap(); - let sdk = ProverSDK::new(url, access_key).await.unwrap(); - let program = std::fs::read_to_string("../examples/cairo0/fibonacci_compiled.json").unwrap(); - let program: Cairo0CompiledProgram = serde_json::from_str(&program).unwrap(); - let program_input_string = std::fs::read_to_string("../examples/cairo0/input.json").unwrap(); - let program_input: Value = serde_json::from_str(&program_input_string).unwrap(); - let layout = "recursive".to_string(); - let data = Cairo0ProverInput { - program, - layout, - program_input, - n_queries: Some(16), - pow_bits: Some(20), - }; - let job = sdk.prove_cairo0(data).await.unwrap(); - let result = fetch_job(sdk.clone(), job).await; - let job = sdk.clone().verify(result).await.unwrap(); - let result = fetch_job(sdk.clone(), job).await; - assert_eq!("true", result); -} -#[tokio::test] -async fn test_cairo_multi_prove() { - let private_key = std::env::var("PRIVATE_KEY").unwrap(); - let url = std::env::var("PROVER_URL").unwrap(); - let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); - let url = Url::parse(&url).unwrap(); - let sdk = ProverSDK::new(url, access_key).await.unwrap(); - let program = std::fs::read_to_string("../examples/cairo/fibonacci_compiled.json").unwrap(); - let program: CairoCompiledProgram = serde_json::from_str(&program).unwrap(); - let program_input_string = std::fs::read_to_string("../examples/cairo/input.json").unwrap(); - let mut program_input: Vec = Vec::new(); - for part in program_input_string.split(',') { - let felt = Felt::from_dec_str(part).unwrap(); - program_input.push(felt); - } - let layout = "recursive".to_string(); - let data = CairoProverInput { - program, - layout, - program_input, - n_queries: Some(16), - pow_bits: Some(20), - }; - let job1 = sdk.prove_cairo(data.clone()).await.unwrap(); - let job2 = sdk.prove_cairo(data.clone()).await.unwrap(); - let job3 = sdk.prove_cairo(data.clone()).await.unwrap(); - let result = fetch_job(sdk.clone(), job1).await; - let job = sdk.clone().verify(result).await.unwrap(); - let result = fetch_job(sdk.clone(), job).await; - assert_eq!("true", result); - let result = fetch_job(sdk.clone(), job2).await; - let job = sdk.clone().verify(result).await.unwrap(); - let result = fetch_job(sdk.clone(), job).await; - assert_eq!("true", result); - let result = fetch_job(sdk.clone(), job3).await; - let job = sdk.clone().verify(result).await.unwrap(); - let result = fetch_job(sdk.clone(), job).await; - assert_eq!("true", result); -} +// #[tokio::test] +// async fn test_cairo0_prove() { +// let private_key = std::env::var("PRIVATE_KEY").unwrap(); +// let url = std::env::var("PROVER_URL").unwrap(); +// let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); +// let url = Url::parse(&url).unwrap(); +// let sdk = ProverSDK::new(url, access_key).await.unwrap(); +// let program = std::fs::read_to_string("../examples/cairo0/fibonacci_compiled.json").unwrap(); +// let program: Cairo0CompiledProgram = serde_json::from_str(&program).unwrap(); +// let program_input_string = std::fs::read_to_string("../examples/cairo0/input.json").unwrap(); +// let program_input: Value = serde_json::from_str(&program_input_string).unwrap(); +// let layout = "recursive".to_string(); +// let data = Cairo0ProverInput { +// program, +// layout, +// program_input, +// n_queries: Some(16), +// pow_bits: Some(20), +// }; +// let job = sdk.prove_cairo0(data).await.unwrap(); +// let result = fetch_job(sdk.clone(), job).await; +// let job = sdk.clone().verify(result).await.unwrap(); +// let result = fetch_job(sdk.clone(), job).await; +// assert_eq!("true", result); +// } +// #[tokio::test] +// async fn test_cairo_multi_prove() { +// let private_key = std::env::var("PRIVATE_KEY").unwrap(); +// let url = std::env::var("PROVER_URL").unwrap(); +// let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); +// let url = Url::parse(&url).unwrap(); +// let sdk = ProverSDK::new(url, access_key).await.unwrap(); +// let program = std::fs::read_to_string("../examples/cairo/fibonacci_compiled.json").unwrap(); +// let program: CairoCompiledProgram = serde_json::from_str(&program).unwrap(); +// let program_input_string = std::fs::read_to_string("../examples/cairo/input.json").unwrap(); +// let mut program_input: Vec = Vec::new(); +// for part in program_input_string.split(',') { +// let felt = Felt::from_dec_str(part).unwrap(); +// program_input.push(felt); +// } +// let layout = "recursive".to_string(); +// let data = CairoProverInput { +// program, +// layout, +// program_input, +// n_queries: Some(16), +// pow_bits: Some(20), +// }; +// let job1 = sdk.prove_cairo(data.clone()).await.unwrap(); +// let job2 = sdk.prove_cairo(data.clone()).await.unwrap(); +// let job3 = sdk.prove_cairo(data.clone()).await.unwrap(); +// let result = fetch_job(sdk.clone(), job1).await; +// let job = sdk.clone().verify(result).await.unwrap(); +// let result = fetch_job(sdk.clone(), job).await; +// assert_eq!("true", result); +// let result = fetch_job(sdk.clone(), job2).await; +// let job = sdk.clone().verify(result).await.unwrap(); +// let result = fetch_job(sdk.clone(), job).await; +// assert_eq!("true", result); +// let result = fetch_job(sdk.clone(), job3).await; +// let job = sdk.clone().verify(result).await.unwrap(); +// let result = fetch_job(sdk.clone(), job).await; +// assert_eq!("true", result); +// } diff --git a/prover-sdk/tests/verify_test.rs b/prover-sdk/tests/verify_test.rs index 9876e5d..0a59f38 100644 --- a/prover-sdk/tests/verify_test.rs +++ b/prover-sdk/tests/verify_test.rs @@ -13,13 +13,9 @@ async fn test_verify_invalid_proof() { let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); let url = Url::parse(&url).unwrap(); let sdk = ProverSDK::new(url, access_key).await.unwrap(); - let job = sdk - .clone() - .verify("invalid_proof".to_string()) - .await - .unwrap(); - let result = fetch_job(sdk.clone(), job).await; - assert_eq!("false", result); + let result = sdk.clone().verify("wrong proof".to_string()).await; + assert!(result.is_ok(),"Failed to verify proof"); + assert_eq!("false", result.unwrap()); } #[tokio::test] @@ -47,7 +43,9 @@ async fn test_verify_valid_proof() { }; let job = sdk.clone().prove_cairo(data).await.unwrap(); let result = fetch_job(sdk.clone(), job).await; - let job = sdk.clone().verify(result).await.unwrap(); - let result = fetch_job(sdk.clone(), job).await; - assert_eq!("true", result); + assert!(result.is_some()); + let result = result.unwrap(); + let result = sdk.clone().verify(result.proof).await; + assert!(result.is_ok(),"Failed to verify proof"); + assert_eq!("true", result.unwrap()); } diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 1b609ed..ffe1076 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -26,4 +26,6 @@ chrono.workspace = true base64.workspace = true starknet-types-core.workspace = true futures.workspace = true -async-stream.workspace = true \ No newline at end of file +async-stream.workspace = true +swiftness_proof_parser.workspace = true +starknet-crypto.workspace = true \ No newline at end of file diff --git a/prover/src/server.rs b/prover/src/server.rs index 9510802..3122b85 100644 --- a/prover/src/server.rs +++ b/prover/src/server.rs @@ -7,7 +7,7 @@ use crate::sse::sse_handler; use crate::threadpool::ThreadPool; use crate::utils::job::{get_job, JobStore}; use crate::utils::shutdown::shutdown_signal; -use crate::verifier::root; +use crate::verifier::verify_proof; use crate::{prove, Args}; use axum::{ middleware, @@ -80,7 +80,7 @@ pub async fn start(args: Args) -> Result<(), ProverError> { let app = Router::new() .route("/", get(ok_handler)) - .route("/verify", post(root)) + .route("/verify", post(verify_proof)) .route("/get-job/:id", get(get_job)) .route("/sse", get(sse_handler)) .with_state(app_state.clone()) diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index ad475a0..a2f56b2 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -1,13 +1,15 @@ use super::run::RunPaths; use super::CairoVersionedInput; use crate::errors::ProverError; +use crate::utils::proof_parser::{extract_program_hash, extract_program_output, program_output_hash}; use crate::utils::{config::Template, job::JobStore}; -use common::models::JobStatus; +use common::models::{JobStatus, ProverResult}; use serde_json::Value; use std::fs; use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; +use swiftness_proof_parser::{json_parser, stark_proof}; use tempfile::TempDir; use tokio::process::Command; use tokio::sync::broadcast::Sender; @@ -41,8 +43,20 @@ pub async fn prove( let sender = sse_tx.lock().await; if prove_status.success() { + let proof_json = serde_json::from_str::(&final_result).unwrap(); + let stark_proof = stark_proof::StarkProof::try_from(proof_json).unwrap(); + let program_hash = extract_program_hash(stark_proof.clone()); + let program_output = extract_program_output(stark_proof.clone()); + let program_output_hash = program_output_hash(program_output.clone()); + let prover_result = ProverResult{ + proof: final_result, + program_hash, + program_output, + program_output_hash, + }; + job_store - .update_job_status(job_id, JobStatus::Completed, Some(final_result)) + .update_job_status(job_id, JobStatus::Completed,serde_json::to_string_pretty(&prover_result).ok()) .await; if sender.receiver_count() > 0 { sender diff --git a/prover/src/utils/job.rs b/prover/src/utils/job.rs index 112e033..b80ce61 100644 --- a/prover/src/utils/job.rs +++ b/prover/src/utils/job.rs @@ -4,8 +4,8 @@ use axum::{ response::IntoResponse, Json, }; -use common::models::JobStatus; -use serde::Serialize; +use common::models::{JobStatus, ProverResult}; +use serde::{Deserialize, Serialize}; use std::{ collections::BTreeMap, sync::Arc, @@ -23,11 +23,11 @@ pub struct Job { pub created: Instant, } -#[derive(Serialize)] +#[derive(Serialize,Deserialize)] #[serde(untagged)] pub enum JobResponse { InProgress { id: u64, status: JobStatus }, - Completed { result: String, status: JobStatus }, + Completed { result: ProverResult, status: JobStatus }, Failed { error: String }, } @@ -113,10 +113,10 @@ pub async fn get_job( StatusCode::OK, Json(JobResponse::Completed { status: job.status.clone(), - result: job + result: serde_json::from_str(&job .result .clone() - .unwrap_or_else(|| "No result available".to_string()), + .unwrap()).unwrap(), }), ), JobStatus::Failed => ( diff --git a/prover/src/utils/mod.rs b/prover/src/utils/mod.rs index e484f01..d3cad6f 100644 --- a/prover/src/utils/mod.rs +++ b/prover/src/utils/mod.rs @@ -1,3 +1,4 @@ pub mod config; pub mod job; pub mod shutdown; +pub mod proof_parser; \ No newline at end of file diff --git a/prover/src/utils/proof_parser.rs b/prover/src/utils/proof_parser.rs new file mode 100644 index 0000000..f1e7c90 --- /dev/null +++ b/prover/src/utils/proof_parser.rs @@ -0,0 +1,52 @@ +use starknet_crypto::poseidon_hash_many; +use starknet_types_core::felt::Felt; +use swiftness_proof_parser::StarkProof; + +pub fn extract_program_hash(stark_proof: StarkProof) -> Felt { + let program_output_range = &stark_proof.public_input.segments[2]; + let main_page_len = stark_proof.public_input.main_page.len(); + let output_len = (program_output_range.stop_ptr - program_output_range.begin_addr) as usize; + let program = stark_proof.public_input.main_page[0..main_page_len - output_len].to_vec(); + + let values: Vec = program + .iter() + .map(|el| { + let number = &el.value; + + let mut padded_bytes = [0u8; 32]; + let bytes = number.to_bytes_be(); + + let bytes_len = bytes.len(); + + padded_bytes[32 - bytes_len..].copy_from_slice(&bytes); + + Felt::from_bytes_be(&padded_bytes) + }) + .collect(); + poseidon_hash_many(&values) +} +pub fn extract_program_output(stark_proof: StarkProof) -> Vec { + let program_output_range = &stark_proof.public_input.segments[2]; + let main_page_len = stark_proof.public_input.main_page.len(); + let output_len = (program_output_range.stop_ptr - program_output_range.begin_addr) as usize; + let program_output = stark_proof.public_input.main_page[main_page_len - output_len..].to_vec(); + let values: Vec = program_output + .iter() + .map(|el| { + let number = &el.value; + + let mut padded_bytes = [0u8; 32]; + let bytes = number.to_bytes_be(); + + let bytes_len = bytes.len(); + + padded_bytes[32 - bytes_len..].copy_from_slice(&bytes); + + Felt::from_bytes_be(&padded_bytes) + }) + .collect(); + values +} +pub fn program_output_hash(felts: Vec) -> Felt { + poseidon_hash_many(&felts) +} diff --git a/prover/src/verifier.rs b/prover/src/verifier.rs index e93765e..a1d7deb 100644 --- a/prover/src/verifier.rs +++ b/prover/src/verifier.rs @@ -1,82 +1,36 @@ -use crate::{ - auth::jwt::Claims, errors::ProverError, extractors::workdir::TempDirHandle, server::AppState, - utils::job::JobStore, -}; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; -use common::models::JobStatus; -use serde_json::json; -use std::{process::Command, sync::Arc}; -use tempfile::TempDir; -use tokio::sync::broadcast::Sender; -use tokio::sync::Mutex; +use crate::{auth::jwt::Claims, errors::ProverError, extractors::workdir::TempDirHandle}; +use axum::{response::IntoResponse, Json}; -pub async fn root( - State(app_state): State, - TempDirHandle(dir): TempDirHandle, - _claims: Claims, - Json(proof): Json, -) -> impl IntoResponse { - let job_store = app_state.job_store.clone(); - let job_id = job_store.create_job().await; - - tokio::spawn({ - async move { - if let Err(e) = - verify_proof(job_id, job_store.clone(), dir, proof, app_state.sse_tx).await - { - job_store - .update_job_status(job_id, JobStatus::Failed, Some(e.to_string())) - .await; - } - } - }); - - let body = json!({ - "job_id": job_id - }); - (StatusCode::ACCEPTED, body.to_string()) -} - -pub async fn verify_proof( - job_id: u64, - job_store: JobStore, - dir: TempDir, - proof: String, - sender: Arc>>, -) -> Result<(), ProverError> { - job_store - .update_job_status(job_id, JobStatus::Running, None) - .await; +use std::process::Command; +pub async fn verify_proof(TempDirHandle(dir):TempDirHandle,_claims:Claims,Json(proof): Json) -> Json { // Define the path for the proof file - let path = dir.into_path(); - let file = path.join("proof"); + let file = dir.into_path().join("proof"); // Write the proof string to the file - std::fs::write(&file, &proof)?; + if let Err(e) = std::fs::write(&file, proof) { + eprintln!("Failed to write proof to file: {}", e); + return Json(false); + } // Create the command to run the verifier let mut command = Command::new("cpu_air_verifier"); command.arg("--in_file").arg(&file); // Execute the command and capture the status - let status = command.status()?; + let status = command.status(); + // Remove the proof file - std::fs::remove_file(&file)?; - // Check if the command was successful + if let Err(e) = std::fs::remove_file(&file) { + eprintln!("Failed to remove proof file: {}", e); + } - job_store - .update_job_status( - job_id, - JobStatus::Completed, - Some(status.success().to_string()), - ) - .await; - let sender = sender.lock().await; - if sender.receiver_count() > 0 { - sender - .send(serde_json::to_string(&(JobStatus::Completed, job_id))?) - .map_err(|e| ProverError::SseError(e.to_string()))?; + // Check if the command was successful + match status { + Ok(exit_status) => Json(exit_status.success()), + Err(e) => { + eprintln!("Failed to execute verifier: {}", e); + Json(false) + } } - Ok(()) } From 3d46b39e4862890588a433932f7691ad2845f656 Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski Date: Tue, 17 Sep 2024 18:30:34 +0200 Subject: [PATCH 3/6] clippy and fmt --- common/src/models.rs | 22 ++++-- prover-sdk/tests/helpers/mod.rs | 6 +- prover-sdk/tests/prove_test.rs | 126 +++++++++++++++----------------- prover-sdk/tests/verify_test.rs | 6 +- prover/src/threadpool/prove.rs | 12 ++- prover/src/utils/job.rs | 21 ++++-- prover/src/utils/mod.rs | 2 +- prover/src/verifier.rs | 10 ++- 8 files changed, 111 insertions(+), 94 deletions(-) diff --git a/common/src/models.rs b/common/src/models.rs index 7d1c229..750f81e 100644 --- a/common/src/models.rs +++ b/common/src/models.rs @@ -11,7 +11,7 @@ pub struct JWTResponse { pub expiration: u64, pub session_key: Option, } -#[derive(Serialize, Deserialize, Clone,Debug)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub enum JobStatus { Pending, Running, @@ -20,17 +20,25 @@ pub enum JobStatus { Unknown, } -#[derive(Clone,Serialize,Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct ProverResult { pub proof: String, pub program_hash: Felt, pub program_output: Vec, pub program_output_hash: Felt, } -#[derive(Serialize,Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(untagged)] pub enum JobResponse { - InProgress { id: u64, status: JobStatus }, - Completed { result: ProverResult, status: JobStatus }, - Failed { error: String }, -} \ No newline at end of file + InProgress { + id: u64, + status: JobStatus, + }, + Completed { + result: ProverResult, + status: JobStatus, + }, + Failed { + error: String, + }, +} diff --git a/prover-sdk/tests/helpers/mod.rs b/prover-sdk/tests/helpers/mod.rs index d9cad2f..5e6a30f 100644 --- a/prover-sdk/tests/helpers/mod.rs +++ b/prover-sdk/tests/helpers/mod.rs @@ -7,9 +7,9 @@ pub async fn fetch_job(sdk: ProverSDK, job: u64) -> Option { let response = sdk.get_job(job).await.unwrap(); let response = response.text().await.unwrap(); let json_response: JobResponse = serde_json::from_str(&response).unwrap(); - + if let JobResponse::Completed { result, .. } = json_response { return Some(result); } - None -} \ No newline at end of file + None +} diff --git a/prover-sdk/tests/prove_test.rs b/prover-sdk/tests/prove_test.rs index fbfb0bd..a97eeeb 100644 --- a/prover-sdk/tests/prove_test.rs +++ b/prover-sdk/tests/prove_test.rs @@ -1,5 +1,3 @@ -use std::fs; - use common::prover_input::*; use helpers::fetch_job; use prover_sdk::{access_key::ProverAccessKey, sdk::ProverSDK}; @@ -42,67 +40,63 @@ async fn test_cairo_prove() { assert_eq!("true", result.unwrap()); } -// #[tokio::test] -// async fn test_cairo0_prove() { -// let private_key = std::env::var("PRIVATE_KEY").unwrap(); -// let url = std::env::var("PROVER_URL").unwrap(); -// let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); -// let url = Url::parse(&url).unwrap(); -// let sdk = ProverSDK::new(url, access_key).await.unwrap(); -// let program = std::fs::read_to_string("../examples/cairo0/fibonacci_compiled.json").unwrap(); -// let program: Cairo0CompiledProgram = serde_json::from_str(&program).unwrap(); -// let program_input_string = std::fs::read_to_string("../examples/cairo0/input.json").unwrap(); -// let program_input: Value = serde_json::from_str(&program_input_string).unwrap(); -// let layout = "recursive".to_string(); -// let data = Cairo0ProverInput { -// program, -// layout, -// program_input, -// n_queries: Some(16), -// pow_bits: Some(20), -// }; -// let job = sdk.prove_cairo0(data).await.unwrap(); -// let result = fetch_job(sdk.clone(), job).await; -// let job = sdk.clone().verify(result).await.unwrap(); -// let result = fetch_job(sdk.clone(), job).await; -// assert_eq!("true", result); -// } -// #[tokio::test] -// async fn test_cairo_multi_prove() { -// let private_key = std::env::var("PRIVATE_KEY").unwrap(); -// let url = std::env::var("PROVER_URL").unwrap(); -// let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); -// let url = Url::parse(&url).unwrap(); -// let sdk = ProverSDK::new(url, access_key).await.unwrap(); -// let program = std::fs::read_to_string("../examples/cairo/fibonacci_compiled.json").unwrap(); -// let program: CairoCompiledProgram = serde_json::from_str(&program).unwrap(); -// let program_input_string = std::fs::read_to_string("../examples/cairo/input.json").unwrap(); -// let mut program_input: Vec = Vec::new(); -// for part in program_input_string.split(',') { -// let felt = Felt::from_dec_str(part).unwrap(); -// program_input.push(felt); -// } -// let layout = "recursive".to_string(); -// let data = CairoProverInput { -// program, -// layout, -// program_input, -// n_queries: Some(16), -// pow_bits: Some(20), -// }; -// let job1 = sdk.prove_cairo(data.clone()).await.unwrap(); -// let job2 = sdk.prove_cairo(data.clone()).await.unwrap(); -// let job3 = sdk.prove_cairo(data.clone()).await.unwrap(); -// let result = fetch_job(sdk.clone(), job1).await; -// let job = sdk.clone().verify(result).await.unwrap(); -// let result = fetch_job(sdk.clone(), job).await; -// assert_eq!("true", result); -// let result = fetch_job(sdk.clone(), job2).await; -// let job = sdk.clone().verify(result).await.unwrap(); -// let result = fetch_job(sdk.clone(), job).await; -// assert_eq!("true", result); -// let result = fetch_job(sdk.clone(), job3).await; -// let job = sdk.clone().verify(result).await.unwrap(); -// let result = fetch_job(sdk.clone(), job).await; -// assert_eq!("true", result); -// } +#[tokio::test] +async fn test_cairo0_prove() { + let private_key = std::env::var("PRIVATE_KEY").unwrap(); + let url = std::env::var("PROVER_URL").unwrap(); + let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); + let url = Url::parse(&url).unwrap(); + let sdk = ProverSDK::new(url, access_key).await.unwrap(); + let program = std::fs::read_to_string("../examples/cairo0/fibonacci_compiled.json").unwrap(); + let program: Cairo0CompiledProgram = serde_json::from_str(&program).unwrap(); + let program_input_string = std::fs::read_to_string("../examples/cairo0/input.json").unwrap(); + let program_input: Value = serde_json::from_str(&program_input_string).unwrap(); + let layout = "recursive".to_string(); + let data = Cairo0ProverInput { + program, + layout, + program_input, + n_queries: Some(16), + pow_bits: Some(20), + }; + let job = sdk.prove_cairo0(data).await.unwrap(); + let result = fetch_job(sdk.clone(), job).await; + let result = sdk.clone().verify(result.unwrap().proof).await.unwrap(); + assert_eq!("true", result); +} +#[tokio::test] +async fn test_cairo_multi_prove() { + let private_key = std::env::var("PRIVATE_KEY").unwrap(); + let url = std::env::var("PROVER_URL").unwrap(); + let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap(); + let url = Url::parse(&url).unwrap(); + let sdk = ProverSDK::new(url, access_key).await.unwrap(); + let program = std::fs::read_to_string("../examples/cairo/fibonacci_compiled.json").unwrap(); + let program: CairoCompiledProgram = serde_json::from_str(&program).unwrap(); + let program_input_string = std::fs::read_to_string("../examples/cairo/input.json").unwrap(); + let mut program_input: Vec = Vec::new(); + for part in program_input_string.split(',') { + let felt = Felt::from_dec_str(part).unwrap(); + program_input.push(felt); + } + let layout = "recursive".to_string(); + let data = CairoProverInput { + program, + layout, + program_input, + n_queries: Some(16), + pow_bits: Some(20), + }; + let job1 = sdk.prove_cairo(data.clone()).await.unwrap(); + let job2 = sdk.prove_cairo(data.clone()).await.unwrap(); + let job3 = sdk.prove_cairo(data.clone()).await.unwrap(); + let result = fetch_job(sdk.clone(), job1).await; + let result = sdk.clone().verify(result.unwrap().proof).await.unwrap(); + assert_eq!("true", result); + let result = fetch_job(sdk.clone(), job2).await; + let result = sdk.clone().verify(result.unwrap().proof).await.unwrap(); + assert_eq!("true", result); + let result = fetch_job(sdk.clone(), job3).await; + let result = sdk.clone().verify(result.unwrap().proof).await.unwrap(); + assert_eq!("true", result); +} diff --git a/prover-sdk/tests/verify_test.rs b/prover-sdk/tests/verify_test.rs index 0a59f38..f0717b6 100644 --- a/prover-sdk/tests/verify_test.rs +++ b/prover-sdk/tests/verify_test.rs @@ -14,8 +14,8 @@ async fn test_verify_invalid_proof() { let url = Url::parse(&url).unwrap(); let sdk = ProverSDK::new(url, access_key).await.unwrap(); let result = sdk.clone().verify("wrong proof".to_string()).await; - assert!(result.is_ok(),"Failed to verify proof"); - assert_eq!("false", result.unwrap()); + assert!(result.is_ok(), "Failed to verify proof"); + assert_eq!("false", result.unwrap()); } #[tokio::test] @@ -46,6 +46,6 @@ async fn test_verify_valid_proof() { assert!(result.is_some()); let result = result.unwrap(); let result = sdk.clone().verify(result.proof).await; - assert!(result.is_ok(),"Failed to verify proof"); + assert!(result.is_ok(), "Failed to verify proof"); assert_eq!("true", result.unwrap()); } diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index a2f56b2..3231bd0 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -1,7 +1,9 @@ use super::run::RunPaths; use super::CairoVersionedInput; use crate::errors::ProverError; -use crate::utils::proof_parser::{extract_program_hash, extract_program_output, program_output_hash}; +use crate::utils::proof_parser::{ + extract_program_hash, extract_program_output, program_output_hash, +}; use crate::utils::{config::Template, job::JobStore}; use common::models::{JobStatus, ProverResult}; use serde_json::Value; @@ -48,7 +50,7 @@ pub async fn prove( let program_hash = extract_program_hash(stark_proof.clone()); let program_output = extract_program_output(stark_proof.clone()); let program_output_hash = program_output_hash(program_output.clone()); - let prover_result = ProverResult{ + let prover_result = ProverResult { proof: final_result, program_hash, program_output, @@ -56,7 +58,11 @@ pub async fn prove( }; job_store - .update_job_status(job_id, JobStatus::Completed,serde_json::to_string_pretty(&prover_result).ok()) + .update_job_status( + job_id, + JobStatus::Completed, + serde_json::to_string_pretty(&prover_result).ok(), + ) .await; if sender.receiver_count() > 0 { sender diff --git a/prover/src/utils/job.rs b/prover/src/utils/job.rs index b80ce61..92d373d 100644 --- a/prover/src/utils/job.rs +++ b/prover/src/utils/job.rs @@ -23,12 +23,20 @@ pub struct Job { pub created: Instant, } -#[derive(Serialize,Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(untagged)] pub enum JobResponse { - InProgress { id: u64, status: JobStatus }, - Completed { result: ProverResult, status: JobStatus }, - Failed { error: String }, + InProgress { + id: u64, + status: JobStatus, + }, + Completed { + result: ProverResult, + status: JobStatus, + }, + Failed { + error: String, + }, } #[derive(Default, Clone)] @@ -113,10 +121,7 @@ pub async fn get_job( StatusCode::OK, Json(JobResponse::Completed { status: job.status.clone(), - result: serde_json::from_str(&job - .result - .clone() - .unwrap()).unwrap(), + result: serde_json::from_str(&job.result.clone().unwrap()).unwrap(), }), ), JobStatus::Failed => ( diff --git a/prover/src/utils/mod.rs b/prover/src/utils/mod.rs index d3cad6f..3da6868 100644 --- a/prover/src/utils/mod.rs +++ b/prover/src/utils/mod.rs @@ -1,4 +1,4 @@ pub mod config; pub mod job; +pub mod proof_parser; pub mod shutdown; -pub mod proof_parser; \ No newline at end of file diff --git a/prover/src/verifier.rs b/prover/src/verifier.rs index a1d7deb..f66bf26 100644 --- a/prover/src/verifier.rs +++ b/prover/src/verifier.rs @@ -1,9 +1,13 @@ -use crate::{auth::jwt::Claims, errors::ProverError, extractors::workdir::TempDirHandle}; -use axum::{response::IntoResponse, Json}; +use crate::{auth::jwt::Claims, extractors::workdir::TempDirHandle}; +use axum::Json; use std::process::Command; -pub async fn verify_proof(TempDirHandle(dir):TempDirHandle,_claims:Claims,Json(proof): Json) -> Json { +pub async fn verify_proof( + TempDirHandle(dir): TempDirHandle, + _claims: Claims, + Json(proof): Json, +) -> Json { // Define the path for the proof file let file = dir.into_path().join("proof"); From ddeb11932f07080be525408384c0bed4b0e331a7 Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski Date: Tue, 17 Sep 2024 19:19:53 +0200 Subject: [PATCH 4/6] error handling and binary update --- Cargo.lock | 1 + Cargo.toml | 1 + bin/cairo-prove/src/fetch.rs | 32 +++++++++++--------------------- bin/cairo-prove/src/main.rs | 3 +-- prover-sdk/src/lib.rs | 1 + prover/Cargo.toml | 3 ++- prover/src/errors.rs | 4 ++++ prover/src/threadpool/prove.rs | 4 ++-- prover/src/utils/job.rs | 14 +++++--------- 9 files changed, 28 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 93c47b1..5959954 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1428,6 +1428,7 @@ dependencies = [ name = "prover" version = "0.1.0" dependencies = [ + "anyhow", "async-stream", "axum", "axum-extra", diff --git a/Cargo.toml b/Cargo.toml index ef9ffa2..8d5ab68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,3 +46,4 @@ futures = "0.3.30" async-stream = "0.3.5" swiftness_proof_parser = { git = "https://github.com/cartridge-gg/swiftness", rev = "1d46e21"} starknet-crypto = "0.7.0" +anyhow = "1.0.89" \ No newline at end of file diff --git a/bin/cairo-prove/src/fetch.rs b/bin/cairo-prove/src/fetch.rs index fc0883b..9ca24c6 100644 --- a/bin/cairo-prove/src/fetch.rs +++ b/bin/cairo-prove/src/fetch.rs @@ -1,34 +1,25 @@ use std::time::Duration; -use prover_sdk::sdk::ProverSDK; +use prover_sdk::{sdk::ProverSDK, JobResponse, ProverResult}; use serde_json::Value; use tokio::time::sleep; use tracing::info; use crate::errors::ProveErrors; -pub async fn fetch_job_sse(sdk: ProverSDK, job: u64) -> Result { +pub async fn fetch_job_sse(sdk: ProverSDK, job: u64) -> Result { info!("Job ID: {}", job); sdk.sse(job).await?; info!("Job completed"); let response = sdk.get_job(job).await?; let response = response.text().await?; - let json_response: Value = serde_json::from_str(&response)?; - if let Some(status) = json_response.get("status").and_then(Value::as_str) { - if status == "Completed" { - return Ok(json_response - .get("result") - .and_then(Value::as_str) - .unwrap_or("No result found") - .to_string()); - } else { - Err(ProveErrors::Custom(json_response.to_string())) - } - } else { - Err(ProveErrors::Custom(json_response.to_string())) + let json_response: JobResponse = serde_json::from_str(&response).unwrap(); + if let JobResponse::Completed { result, .. } = json_response { + return Ok(result); } + Err(ProveErrors::Custom("Job failed".to_string())) } -pub async fn fetch_job_polling(sdk: ProverSDK, job: u64) -> Result { +pub async fn fetch_job_polling(sdk: ProverSDK, job: u64) -> Result { info!("Fetching job: {}", job); let mut counter = 0; loop { @@ -38,11 +29,10 @@ pub async fn fetch_job_polling(sdk: ProverSDK, job: u64) -> Result { - return Ok(json_response - .get("result") - .and_then(Value::as_str) - .unwrap_or("No result found") - .to_string()); + let json_response: JobResponse = serde_json::from_str(&response).unwrap(); + if let JobResponse::Completed { result, .. } = json_response { + return Ok(result); + } } "Pending" | "Running" => { info!("Job is still in progress. Status: {}", status); diff --git a/bin/cairo-prove/src/main.rs b/bin/cairo-prove/src/main.rs index 91c5c02..4d1e4ec 100644 --- a/bin/cairo-prove/src/main.rs +++ b/bin/cairo-prove/src/main.rs @@ -21,8 +21,7 @@ pub async fn main() -> Result<(), ProveErrors> { fetch_job_polling(sdk, job).await? }; let path: std::path::PathBuf = args.program_output; - std::fs::write(path, job)?; + std::fs::write(path, serde_json::to_string_pretty(&job)?)?; } - Ok(()) } diff --git a/prover-sdk/src/lib.rs b/prover-sdk/src/lib.rs index 56fc763..b18ea01 100644 --- a/prover-sdk/src/lib.rs +++ b/prover-sdk/src/lib.rs @@ -3,4 +3,5 @@ pub mod errors; pub mod sdk; pub mod sdk_builder; +pub use common::models::{JobResponse, ProverResult}; pub use common::prover_input::*; diff --git a/prover/Cargo.toml b/prover/Cargo.toml index ffe1076..d5a7aaf 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -28,4 +28,5 @@ starknet-types-core.workspace = true futures.workspace = true async-stream.workspace = true swiftness_proof_parser.workspace = true -starknet-crypto.workspace = true \ No newline at end of file +starknet-crypto.workspace = true +anyhow.workspace = true \ No newline at end of file diff --git a/prover/src/errors.rs b/prover/src/errors.rs index ca245cd..81e3163 100644 --- a/prover/src/errors.rs +++ b/prover/src/errors.rs @@ -1,3 +1,4 @@ +use anyhow::Error as AnyhowError; use axum::{ http::StatusCode, response::{IntoResponse, Response}, @@ -34,6 +35,8 @@ pub enum ProverError { KeyError(#[from] ed25519_dalek::SignatureError), #[error("Failed to send message via SSE{0}")] SseError(String), + #[error(transparent)] + ParserError(#[from] AnyhowError), } impl From> for ProverError { fn from(err: SendError) -> ProverError { @@ -81,6 +84,7 @@ impl IntoResponse for ProverError { ProverError::AddressParse(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), ProverError::KeyError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), ProverError::SseError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + ProverError::ParserError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), }; let body = Json(json!({ "error": error_message })); diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index 3231bd0..a7ee521 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -45,8 +45,8 @@ pub async fn prove( let sender = sse_tx.lock().await; if prove_status.success() { - let proof_json = serde_json::from_str::(&final_result).unwrap(); - let stark_proof = stark_proof::StarkProof::try_from(proof_json).unwrap(); + let proof_json = serde_json::from_str::(&final_result)?; + let stark_proof = stark_proof::StarkProof::try_from(proof_json)?; let program_hash = extract_program_hash(stark_proof.clone()); let program_output = extract_program_output(stark_proof.clone()); let program_output_hash = program_output_hash(program_output.clone()); diff --git a/prover/src/utils/job.rs b/prover/src/utils/job.rs index 92d373d..7528584 100644 --- a/prover/src/utils/job.rs +++ b/prover/src/utils/job.rs @@ -13,7 +13,7 @@ use std::{ }; use tokio::sync::Mutex; -use crate::{auth::jwt::Claims, server::AppState}; +use crate::{auth::jwt::Claims, errors::ProverError, server::AppState}; #[derive(Clone)] pub struct Job { @@ -107,7 +107,7 @@ pub async fn get_job( Path(id): Path, State(app_state): State, _claims: Claims, -) -> impl IntoResponse { +) -> Result { if let Some(job) = app_state.job_store.get_job(id).await { let (status, response) = match job.status { JobStatus::Pending | JobStatus::Running => ( @@ -121,7 +121,7 @@ pub async fn get_job( StatusCode::OK, Json(JobResponse::Completed { status: job.status.clone(), - result: serde_json::from_str(&job.result.clone().unwrap()).unwrap(), + result: serde_json::from_str(&job.result.clone().unwrap_or_default())?, }), ), JobStatus::Failed => ( @@ -140,12 +140,8 @@ pub async fn get_job( }), ), }; - (status, response).into_response() + Ok((status, response).into_response()) } else { - ( - StatusCode::NOT_FOUND, - Json(format!("Job with id {} not found", id)), - ) - .into_response() + Err(ProverError::CustomError("Job not found".to_string())) } } From c046e7aadd5cff73361160b8a8374ab52280819a Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski Date: Tue, 17 Sep 2024 19:38:24 +0200 Subject: [PATCH 5/6] multiple admin keys --- prover-sdk/tests/register_test.rs | 12 ++++++++++-- prover/src/auth/register.rs | 2 +- prover/src/auth/validation.rs | 8 ++++---- prover/src/lib.rs | 4 ++-- prover/src/server.rs | 17 ++++++++++------- scripts/e2e_test.sh | 15 ++++++++++----- 6 files changed, 37 insertions(+), 21 deletions(-) diff --git a/prover-sdk/tests/register_test.rs b/prover-sdk/tests/register_test.rs index 873078e..21f592f 100644 --- a/prover-sdk/tests/register_test.rs +++ b/prover-sdk/tests/register_test.rs @@ -4,12 +4,20 @@ use url::Url; #[tokio::test] async fn test_register_authorized() { let url = std::env::var("PROVER_URL").unwrap(); - let admin_key = std::env::var("ADMIN_PRIVATE_KEY").unwrap(); - let admin_key = ProverAccessKey::from_hex_string(&admin_key).unwrap(); + let admin_key1 = std::env::var("ADMIN_PRIVATE_KEY_1").unwrap(); + let admin_key2 = std::env::var("ADMIN_PRIVATE_KEY_2").unwrap(); + + let admin_key = ProverAccessKey::from_hex_string(&admin_key1).unwrap(); let random_key = ProverAccessKey::generate(); let url = Url::parse(&url).unwrap(); let mut sdk = ProverSDK::new(url.clone(), admin_key).await.unwrap(); sdk.register(random_key.0.verifying_key()).await.unwrap(); + let new_sdk = ProverSDK::new(url.clone(), random_key).await; + assert!(new_sdk.is_ok()); + let admin_key = ProverAccessKey::from_hex_string(&admin_key2).unwrap(); + let random_key = ProverAccessKey::generate(); + let mut sdk = ProverSDK::new(url.clone(), admin_key).await.unwrap(); + sdk.register(random_key.0.verifying_key()).await.unwrap(); let new_sdk = ProverSDK::new(url, random_key).await; assert!(new_sdk.is_ok()); } diff --git a/prover/src/auth/register.rs b/prover/src/auth/register.rs index 0989e6f..1ed9d0b 100644 --- a/prover/src/auth/register.rs +++ b/prover/src/auth/register.rs @@ -10,7 +10,7 @@ pub async fn register( _claims: Claims, Json(payload): Json, ) -> Result { - if state.admin_key != payload.authority { + if !state.admins_keys.contains(&payload.authority) { return Err(ProverError::Auth(AuthError::Unauthorized)); } payload diff --git a/prover/src/auth/validation.rs b/prover/src/auth/validation.rs index 9f093a7..9b7801d 100644 --- a/prover/src/auth/validation.rs +++ b/prover/src/auth/validation.rs @@ -119,7 +119,7 @@ mod tests { thread_pool: Arc::new(Mutex::new(ThreadPool::new(1))), nonces, authorizer: Authorizer::Open, - admin_key: generate_verifying_key(&generate_signing_key()), + admins_keys: vec![generate_verifying_key(&generate_signing_key())], sse_tx: Arc::new(Mutex::new(tokio::sync::broadcast::channel(100).0)), }; @@ -162,7 +162,7 @@ mod tests { thread_pool: Arc::new(Mutex::new(ThreadPool::new(1))), nonces, authorizer: Authorizer::Open, - admin_key: generate_verifying_key(&generate_signing_key()), + admins_keys: vec![generate_verifying_key(&generate_signing_key())], sse_tx: Arc::new(Mutex::new(tokio::sync::broadcast::channel(100).0)), }; @@ -202,7 +202,7 @@ mod tests { thread_pool: Arc::new(Mutex::new(ThreadPool::new(1))), nonces, authorizer: Authorizer::Open, - admin_key: generate_verifying_key(&generate_signing_key()), + admins_keys: vec![generate_verifying_key(&generate_signing_key())], sse_tx: Arc::new(Mutex::new(tokio::sync::broadcast::channel(100).0)), }; @@ -243,7 +243,7 @@ mod tests { thread_pool: Arc::new(Mutex::new(ThreadPool::new(1))), nonces, authorizer: Authorizer::Open, - admin_key: generate_verifying_key(&generate_signing_key()), + admins_keys: vec![generate_verifying_key(&generate_signing_key())], sse_tx: Arc::new(Mutex::new(tokio::sync::broadcast::channel(100).0)), }; diff --git a/prover/src/lib.rs b/prover/src/lib.rs index b71ecb2..ba4b1c6 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -30,6 +30,6 @@ pub struct Args { pub authorized_keys: Vec, #[arg(long, env, default_value = "4")] pub num_workers: usize, - #[arg(long, env)] - pub admin_key: String, + #[arg(long, env, value_delimiter = ',')] + pub admins_keys: Vec, } diff --git a/prover/src/server.rs b/prover/src/server.rs index 3122b85..199e867 100644 --- a/prover/src/server.rs +++ b/prover/src/server.rs @@ -34,7 +34,7 @@ pub struct AppState { pub jwt_secret_key: String, pub nonces: Arc>>, pub authorizer: Authorizer, - pub admin_key: VerifyingKey, + pub admins_keys: Vec, pub sse_tx: Arc>>, } @@ -49,12 +49,15 @@ pub async fn start(args: Args) -> Result<(), ProverError> { let authorizer = Authorizer::Persistent(FileAuthorizer::new(args.authorized_keys_path.clone()).await?); + let mut admins_keys = Vec::new(); + for key in args.admins_keys { + let verifying_key_bytes = prefix_hex::decode::>(key) + .map_err(|e| AuthorizerError::PrefixHexConversionError(e.to_string()))?; + let verifying_key = VerifyingKey::from_bytes(&verifying_key_bytes.try_into()?)?; + admins_keys.push(verifying_key); + authorizer.authorize(verifying_key).await?; + } - let admin_key_bytes = prefix_hex::decode::>(args.admin_key) - .map_err(|e| AuthorizerError::PrefixHexConversionError(e.to_string()))?; - let admin_key = VerifyingKey::from_bytes(&admin_key_bytes.try_into()?)?; - - authorizer.authorize(admin_key).await?; for key in args.authorized_keys.iter() { let verifying_key_bytes = prefix_hex::decode::>(key) .map_err(|e| AuthorizerError::PrefixHexConversionError(e.to_string()))?; @@ -70,7 +73,7 @@ pub async fn start(args: Args) -> Result<(), ProverError> { authorizer, job_store: JobStore::default(), thread_pool: Arc::new(Mutex::new(ThreadPool::new(args.num_workers))), - admin_key, + admins_keys, sse_tx: Arc::new(Mutex::new(sse_tx)), }; diff --git a/scripts/e2e_test.sh b/scripts/e2e_test.sh index f072517..2f0d10e 100755 --- a/scripts/e2e_test.sh +++ b/scripts/e2e_test.sh @@ -32,8 +32,13 @@ PRIVATE_KEY=$(echo "$KEYGEN_OUTPUT" | grep "Private key" | awk '{print $3}' | tr KEYGEN_OUTPUT=$(cargo run -p keygen) -ADMIN_PUBLIC_KEY=$(echo "$KEYGEN_OUTPUT" | grep "Public key" | awk '{print $3}' | tr -d ',' | tr -d '[:space:]') -ADMIN_PRIVATE_KEY=$(echo "$KEYGEN_OUTPUT" | grep "Private key" | awk '{print $3}' | tr -d ',' | tr -d '[:space:]') +ADMIN_PUBLIC_KEY1=$(echo "$KEYGEN_OUTPUT" | grep "Public key" | awk '{print $3}' | tr -d ',' | tr -d '[:space:]') +ADMIN_PRIVATE_KEY1=$(echo "$KEYGEN_OUTPUT" | grep "Private key" | awk '{print $3}' | tr -d ',' | tr -d '[:space:]') + +KEYGEN_OUTPUT=$(cargo run -p keygen) + +ADMIN_PUBLIC_KEY2=$(echo "$KEYGEN_OUTPUT" | grep "Public key" | awk '{print $3}' | tr -d ',' | tr -d '[:space:]') +ADMIN_PRIVATE_KEY2=$(echo "$KEYGEN_OUTPUT" | grep "Private key" | awk '{print $3}' | tr -d ',' | tr -d '[:space:]') REPLACE_FLAG="" if [ "$CONTAINER_ENGINE" == "podman" ]; then @@ -44,12 +49,12 @@ $CONTAINER_ENGINE run -d --name http_prover_test $REPLACE_FLAG \ --jwt-secret-key "secret" \ --message-expiration-time 3600 \ --session-expiration-time 3600 \ - --authorized-keys $PUBLIC_KEY,$ADMIN_PUBLIC_KEY \ - --admin-key $ADMIN_PUBLIC_KEY + --authorized-keys $PUBLIC_KEY,$ADMIN_PUBLIC_KEY1,$ADMIN_PUBLIC_KEY2 \ + --admins-keys $ADMIN_PUBLIC_KEY1,$ADMIN_PUBLIC_KEY2 start_time=$(date +%s) -PRIVATE_KEY=$PRIVATE_KEY PROVER_URL="http://localhost:3040" ADMIN_PRIVATE_KEY=$ADMIN_PRIVATE_KEY cargo test --no-fail-fast --workspace --verbose +PRIVATE_KEY=$PRIVATE_KEY PROVER_URL="http://localhost:3040" ADMIN_PRIVATE_KEY_1=$ADMIN_PRIVATE_KEY1 ADMIN_PRIVATE_KEY_2=$ADMIN_PRIVATE_KEY2 cargo test --no-fail-fast --workspace --verbose end_time=$(date +%s) From 10d9dbdaebd0cf788aab8536a5181ecdecd26f35 Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski Date: Wed, 18 Sep 2024 13:15:47 +0200 Subject: [PATCH 6/6] function to calculate prover result --- prover/src/threadpool/prove.rs | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index a7ee521..daf3b69 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -45,18 +45,7 @@ pub async fn prove( let sender = sse_tx.lock().await; if prove_status.success() { - let proof_json = serde_json::from_str::(&final_result)?; - let stark_proof = stark_proof::StarkProof::try_from(proof_json)?; - let program_hash = extract_program_hash(stark_proof.clone()); - let program_output = extract_program_output(stark_proof.clone()); - let program_output_hash = program_output_hash(program_output.clone()); - let prover_result = ProverResult { - proof: final_result, - program_hash, - program_output, - program_output_hash, - }; - + let prover_result = prover_result(final_result)?; job_store .update_job_status( job_id, @@ -82,6 +71,21 @@ pub async fn prove( Ok(()) } +fn prover_result(final_result: String) -> Result { + let proof_json = serde_json::from_str::(&final_result)?; + let stark_proof = stark_proof::StarkProof::try_from(proof_json)?; + let program_hash = extract_program_hash(stark_proof.clone()); + let program_output = extract_program_output(stark_proof.clone()); + let program_output_hash = program_output_hash(program_output.clone()); + let prover_result = ProverResult { + proof: final_result.clone(), + program_hash, + program_output, + program_output_hash, + }; + Ok(prover_result) +} + #[derive(Debug, Clone)] pub(super) struct ProvePaths { pub(super) program_input: PathBuf,