diff --git a/bindings/src/lib.rs b/bindings/src/lib.rs index f0fc287..0ab25f5 100644 --- a/bindings/src/lib.rs +++ b/bindings/src/lib.rs @@ -46,7 +46,6 @@ mod bindings { pub fn log(content: *mut u8, content_len: usize); pub fn http(content: *mut u8, content_len: usize) -> u64; pub fn consume_buffer(index: u32, content: *mut u8, content_len: usize); - } } diff --git a/rust-example/Cargo.toml b/rust-example/Cargo.toml index f794927..77ff949 100644 --- a/rust-example/Cargo.toml +++ b/rust-example/Cargo.toml @@ -10,3 +10,6 @@ crate-type = ["cdylib"] [dependencies] crows-bindings = { path = "../bindings" } + +[profile.release] +lto = true diff --git a/wasm/src/lib.rs b/wasm/src/lib.rs index addf80d..435473d 100644 --- a/wasm/src/lib.rs +++ b/wasm/src/lib.rs @@ -2,8 +2,10 @@ use anyhow::anyhow; use borsh::{from_slice, to_vec, BorshDeserialize, BorshSerialize}; use crows_bindings::{HTTPError, HTTPMethod, HTTPRequest, HTTPResponse}; use crows_utils::services::RunId; +use futures::Future; use reqwest::header::{HeaderName, HeaderValue}; use reqwest::{Body, Request, Url}; +use std::pin::Pin; use std::str::FromStr; use std::{any::Any, collections::HashMap, io::IoSlice}; use tokio::sync::mpsc::UnboundedReceiver; @@ -12,22 +14,6 @@ use wasi_common::WasiFile; use wasmtime::{Caller, Config, Engine, Linker, Memory, MemoryType, Module, Store}; use wasmtime_wasi::{StdoutStream, StreamResult}; -#[macro_export] -macro_rules! ok_or_return { - ($expr:expr, $store:expr, $err_handler:expr) => { - match $expr { - Ok(value) => value, - Err(err) => { - let err = $err_handler(err); - let encoded = to_vec(&err)?; - let length = encoded.len(); - let index = $store.buffers.insert(encoded.into_boxed_slice()); - return Ok(create_return_value(1, length as u32, index as u32)); - } - } - }; -} - #[derive(thiserror::Error, Debug)] pub enum Error { #[error("the module with a given name couldn't be found")] @@ -93,65 +79,109 @@ impl WasiHostCtx { self.memory = Some(mem); } - pub async fn http(mut caller: Caller<'_, Self>, ptr: u32, len: u32) -> anyhow::Result { - let request: HTTPRequest = Self::fetch_arg(&mut caller, ptr, len)?; + pub async fn wrap_async<'a, U, F, E>( + mut caller: Caller<'a, Self>, + ptr: u32, + len: u32, + f: F, + ) -> anyhow::Result + where + F: for<'b> FnOnce(&'b mut Caller<'_, Self>, HTTPRequest) -> Pin> + 'b + Send>>, + U: BorshSerialize, + E: BorshSerialize, + { + let memory = get_memory(&mut caller)?; - let memory = get_memory(&mut caller).unwrap(); - let (_, store) = memory.data_and_store_mut(&mut caller); + let slice = memory + .data(&caller) + .get(ptr as usize..(ptr + len) as usize) + .ok_or(anyhow!("Could not get memory slice"))?; - let client = &store.client; + let arg = from_slice(slice)?; - let method = match request.method { - HTTPMethod::HEAD => reqwest::Method::HEAD, - HTTPMethod::GET => reqwest::Method::GET, - HTTPMethod::POST => reqwest::Method::POST, - HTTPMethod::PUT => reqwest::Method::PUT, - HTTPMethod::DELETE => reqwest::Method::DELETE, - HTTPMethod::OPTIONS => reqwest::Method::OPTIONS, - }; - let url = ok_or_return!(Url::parse(&request.url), store, |err| HTTPError { - message: format!("Error when parsing the URL: {err:?}"), - }); - - let mut reqw_req = Request::new(method, url); - - for (key, value) in request.headers { - let name = ok_or_return!(HeaderName::from_str(&key), store, |err| HTTPError { - message: format!("Invalid header name: {key}: {err:?}"), - }); - let value = ok_or_return!(HeaderValue::from_str(&value), store, |err| HTTPError { - message: format!("Invalid header value: {value}: {err:?}"), - }); - reqw_req.headers_mut().insert(name, value); - } + let result = f(&mut caller, arg).await; + + let (_, store) = { memory.data_and_store_mut(&mut caller) }; + + match result { + Ok(ret) => { + let encoded = to_vec(&ret).unwrap(); + + let length = encoded.len(); + let index = store.buffers.insert(encoded.into_boxed_slice()); - *reqw_req.body_mut() = request.body.map(|b| Body::from(b)); + Ok(create_return_value(0, length as u32, index as u32)) + } + Err(err) => { + let encoded = to_vec(&err).unwrap(); - let response = ok_or_return!(client.execute(reqw_req).await, store, |err| HTTPError { - message: format!("Error when sending a request: {err:?}"), - }); + let length = encoded.len(); + let index = store.buffers.insert(encoded.into_boxed_slice()); - let mut headers = HashMap::new(); - for (name, value) in response.headers().iter() { - let value = ok_or_return!(value.to_str(), store, |err| HTTPError { - message: format!("Could not parse response header {value:?}: {err:?}"), - }); - headers.insert(name.to_string(), value.to_string()); + Ok(create_return_value(0, length as u32, index as u32)) + } } + } + + pub fn http<'a>( + mut caller: &'a mut Caller<'_, Self>, + request: HTTPRequest, + ) -> Pin> + 'a + Send>> { + Box::pin(async move { + let memory = get_memory(&mut caller).unwrap(); + let (_, store) = memory.data_and_store_mut(&mut caller); + + let client = &store.client; + + let method = match request.method { + HTTPMethod::HEAD => reqwest::Method::HEAD, + HTTPMethod::GET => reqwest::Method::GET, + HTTPMethod::POST => reqwest::Method::POST, + HTTPMethod::PUT => reqwest::Method::PUT, + HTTPMethod::DELETE => reqwest::Method::DELETE, + HTTPMethod::OPTIONS => reqwest::Method::OPTIONS, + }; + let url = Url::parse(&request.url).map_err(|err| HTTPError { + message: format!("Error when parsing the URL: {err:?}"), + })?; + + let mut reqw_req = Request::new(method, url); + + for (key, value) in request.headers { + let name = HeaderName::from_str(&key).map_err(|err| HTTPError { + message: format!("Invalid header name: {key}: {err:?}"), + })?; + let value = HeaderValue::from_str(&value).map_err(|err| HTTPError { + message: format!("Invalid header value: {value}: {err:?}"), + })?; + reqw_req.headers_mut().insert(name, value); + } - let status = response.status().as_u16(); - let body = ok_or_return!(response.text().await, store, |err| HTTPError { - message: format!("Problem with fetching the body: {err:?}"), - }); + *reqw_req.body_mut() = request.body.map(|b| Body::from(b)); - Self::return_result( - &mut caller, - HTTPResponse { + let response = client.execute(reqw_req).await.map_err(|err| HTTPError { + message: format!("Error when sending a request: {err:?}"), + })?; + + let mut headers = HashMap::new(); + for (name, value) in response.headers().iter() { + let value = value.to_str().map_err(|err| HTTPError { + message: format!("Could not parse response header {value:?}: {err:?}"), + })?; + headers.insert(name.to_string(), value.to_string()); + } + + let status = response.status().as_u16(); + let body = response.text().await.map_err(|err| HTTPError { + message: format!("Problem with fetching the body: {err:?}"), + })?; + + Ok(HTTPResponse { headers, body, status, - }, - ) + }) + }) } pub fn consume_buffer( @@ -175,37 +205,6 @@ impl WasiHostCtx { Ok(()) } - - pub fn fetch_arg(mut caller: &mut Caller<'_, Self>, ptr: u32, len: u32) -> anyhow::Result - where - T: BorshDeserialize, - { - let memory = get_memory(&mut caller)?; - - let slice = memory - .data(&caller) - .get(ptr as usize..(ptr + len) as usize) - .ok_or(anyhow!("Could not get memory slice"))?; - - let arg = from_slice(slice)?; - - return Ok(arg); - } - - pub fn return_result(mut caller: &mut Caller<'_, Self>, ret: T) -> anyhow::Result - where - T: BorshSerialize, - { - let memory = get_memory(&mut caller)?; - let (_, store) = memory.data_and_store_mut(&mut caller); - - let encoded = to_vec(&ret)?; - - let length = encoded.len(); - let index = store.buffers.insert(encoded.into_boxed_slice()); - - Ok(create_return_value(0, length as u32, index as u32)) - } } impl wasmtime_wasi::WasiView for WasiHostCtx { @@ -256,7 +255,9 @@ impl Environment { .unwrap(); linker .func_wrap2_async("crows", "http", |caller, ptr, len| { - Box::new(async move { WasiHostCtx::http(caller, ptr, len).await }) + Box::new(async move { + WasiHostCtx::wrap_async(caller, ptr, len, WasiHostCtx::http).await + }) }) .unwrap();