From ddeb11932f07080be525408384c0bed4b0e331a7 Mon Sep 17 00:00:00 2001 From: Mateusz Chudkowski Date: Tue, 17 Sep 2024 19:19:53 +0200 Subject: [PATCH] error handling and binary update --- Cargo.lock | 1 + Cargo.toml | 1 + bin/cairo-prove/src/fetch.rs | 32 +++++++++++--------------------- bin/cairo-prove/src/main.rs | 3 +-- prover-sdk/src/lib.rs | 1 + prover/Cargo.toml | 3 ++- prover/src/errors.rs | 4 ++++ prover/src/threadpool/prove.rs | 4 ++-- prover/src/utils/job.rs | 14 +++++--------- 9 files changed, 28 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 93c47b1..5959954 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1428,6 +1428,7 @@ dependencies = [ name = "prover" version = "0.1.0" dependencies = [ + "anyhow", "async-stream", "axum", "axum-extra", diff --git a/Cargo.toml b/Cargo.toml index ef9ffa2..8d5ab68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,3 +46,4 @@ futures = "0.3.30" async-stream = "0.3.5" swiftness_proof_parser = { git = "https://github.com/cartridge-gg/swiftness", rev = "1d46e21"} starknet-crypto = "0.7.0" +anyhow = "1.0.89" \ No newline at end of file diff --git a/bin/cairo-prove/src/fetch.rs b/bin/cairo-prove/src/fetch.rs index fc0883b..9ca24c6 100644 --- a/bin/cairo-prove/src/fetch.rs +++ b/bin/cairo-prove/src/fetch.rs @@ -1,34 +1,25 @@ use std::time::Duration; -use prover_sdk::sdk::ProverSDK; +use prover_sdk::{sdk::ProverSDK, JobResponse, ProverResult}; use serde_json::Value; use tokio::time::sleep; use tracing::info; use crate::errors::ProveErrors; -pub async fn fetch_job_sse(sdk: ProverSDK, job: u64) -> Result { +pub async fn fetch_job_sse(sdk: ProverSDK, job: u64) -> Result { info!("Job ID: {}", job); sdk.sse(job).await?; info!("Job completed"); let response = sdk.get_job(job).await?; let response = response.text().await?; - let json_response: Value = serde_json::from_str(&response)?; - if let Some(status) = json_response.get("status").and_then(Value::as_str) { - if status == "Completed" { - return Ok(json_response - .get("result") - .and_then(Value::as_str) - .unwrap_or("No result found") - .to_string()); - } else { - Err(ProveErrors::Custom(json_response.to_string())) - } - } else { - Err(ProveErrors::Custom(json_response.to_string())) + let json_response: JobResponse = serde_json::from_str(&response).unwrap(); + if let JobResponse::Completed { result, .. } = json_response { + return Ok(result); } + Err(ProveErrors::Custom("Job failed".to_string())) } -pub async fn fetch_job_polling(sdk: ProverSDK, job: u64) -> Result { +pub async fn fetch_job_polling(sdk: ProverSDK, job: u64) -> Result { info!("Fetching job: {}", job); let mut counter = 0; loop { @@ -38,11 +29,10 @@ pub async fn fetch_job_polling(sdk: ProverSDK, job: u64) -> Result { - return Ok(json_response - .get("result") - .and_then(Value::as_str) - .unwrap_or("No result found") - .to_string()); + let json_response: JobResponse = serde_json::from_str(&response).unwrap(); + if let JobResponse::Completed { result, .. } = json_response { + return Ok(result); + } } "Pending" | "Running" => { info!("Job is still in progress. Status: {}", status); diff --git a/bin/cairo-prove/src/main.rs b/bin/cairo-prove/src/main.rs index 91c5c02..4d1e4ec 100644 --- a/bin/cairo-prove/src/main.rs +++ b/bin/cairo-prove/src/main.rs @@ -21,8 +21,7 @@ pub async fn main() -> Result<(), ProveErrors> { fetch_job_polling(sdk, job).await? }; let path: std::path::PathBuf = args.program_output; - std::fs::write(path, job)?; + std::fs::write(path, serde_json::to_string_pretty(&job)?)?; } - Ok(()) } diff --git a/prover-sdk/src/lib.rs b/prover-sdk/src/lib.rs index 56fc763..b18ea01 100644 --- a/prover-sdk/src/lib.rs +++ b/prover-sdk/src/lib.rs @@ -3,4 +3,5 @@ pub mod errors; pub mod sdk; pub mod sdk_builder; +pub use common::models::{JobResponse, ProverResult}; pub use common::prover_input::*; diff --git a/prover/Cargo.toml b/prover/Cargo.toml index ffe1076..d5a7aaf 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -28,4 +28,5 @@ starknet-types-core.workspace = true futures.workspace = true async-stream.workspace = true swiftness_proof_parser.workspace = true -starknet-crypto.workspace = true \ No newline at end of file +starknet-crypto.workspace = true +anyhow.workspace = true \ No newline at end of file diff --git a/prover/src/errors.rs b/prover/src/errors.rs index ca245cd..81e3163 100644 --- a/prover/src/errors.rs +++ b/prover/src/errors.rs @@ -1,3 +1,4 @@ +use anyhow::Error as AnyhowError; use axum::{ http::StatusCode, response::{IntoResponse, Response}, @@ -34,6 +35,8 @@ pub enum ProverError { KeyError(#[from] ed25519_dalek::SignatureError), #[error("Failed to send message via SSE{0}")] SseError(String), + #[error(transparent)] + ParserError(#[from] AnyhowError), } impl From> for ProverError { fn from(err: SendError) -> ProverError { @@ -81,6 +84,7 @@ impl IntoResponse for ProverError { ProverError::AddressParse(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), ProverError::KeyError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), ProverError::SseError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + ProverError::ParserError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), }; let body = Json(json!({ "error": error_message })); diff --git a/prover/src/threadpool/prove.rs b/prover/src/threadpool/prove.rs index 3231bd0..a7ee521 100644 --- a/prover/src/threadpool/prove.rs +++ b/prover/src/threadpool/prove.rs @@ -45,8 +45,8 @@ pub async fn prove( let sender = sse_tx.lock().await; if prove_status.success() { - let proof_json = serde_json::from_str::(&final_result).unwrap(); - let stark_proof = stark_proof::StarkProof::try_from(proof_json).unwrap(); + let proof_json = serde_json::from_str::(&final_result)?; + let stark_proof = stark_proof::StarkProof::try_from(proof_json)?; let program_hash = extract_program_hash(stark_proof.clone()); let program_output = extract_program_output(stark_proof.clone()); let program_output_hash = program_output_hash(program_output.clone()); diff --git a/prover/src/utils/job.rs b/prover/src/utils/job.rs index 92d373d..7528584 100644 --- a/prover/src/utils/job.rs +++ b/prover/src/utils/job.rs @@ -13,7 +13,7 @@ use std::{ }; use tokio::sync::Mutex; -use crate::{auth::jwt::Claims, server::AppState}; +use crate::{auth::jwt::Claims, errors::ProverError, server::AppState}; #[derive(Clone)] pub struct Job { @@ -107,7 +107,7 @@ pub async fn get_job( Path(id): Path, State(app_state): State, _claims: Claims, -) -> impl IntoResponse { +) -> Result { if let Some(job) = app_state.job_store.get_job(id).await { let (status, response) = match job.status { JobStatus::Pending | JobStatus::Running => ( @@ -121,7 +121,7 @@ pub async fn get_job( StatusCode::OK, Json(JobResponse::Completed { status: job.status.clone(), - result: serde_json::from_str(&job.result.clone().unwrap()).unwrap(), + result: serde_json::from_str(&job.result.clone().unwrap_or_default())?, }), ), JobStatus::Failed => ( @@ -140,12 +140,8 @@ pub async fn get_job( }), ), }; - (status, response).into_response() + Ok((status, response).into_response()) } else { - ( - StatusCode::NOT_FOUND, - Json(format!("Job with id {} not found", id)), - ) - .into_response() + Err(ProverError::CustomError("Job not found".to_string())) } }