Skip to content

Commit

Permalink
Refactor wrapping host functions
Browse files Browse the repository at this point in the history
The previous version of the code was using fetch_arg and return_result
functions and because it had to return a Result<u64>, it also couldn't
easily use a regular result with the error type we actually want to
return - HTTPError in this case.

The new code allows passing the host function as an argument to the
wrap_async function, which in turn handles all of the heavy lifting of
serializing and deserializing arguments and results

The only downside is that now host functions have to return
`Pin<Box<...>>`, but I'm OK with that as it doesn't complicate the code
too much. A possible solution to allow returning a regular feature would
be to create a helper trait, like pointed out in here:
https://users.rust-lang.org/t/approaches-to-an-issue-with-higher-rank-trait-bounds-on-another-generic-type/81941/2,
but I don't want to waste too much time on it at the moment.

Another solution might be this crate: https://lib.rs/crates/higher-order-closure
  • Loading branch information
drogus committed Mar 11, 2024
1 parent 05094f9 commit d4ad438
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 96 deletions.
1 change: 0 additions & 1 deletion bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ mod bindings {
pub fn log(content: *mut u8, content_len: usize);
pub fn http(content: *mut u8, content_len: usize) -> u64;
pub fn consume_buffer(index: u32, content: *mut u8, content_len: usize);

}
}

Expand Down
3 changes: 3 additions & 0 deletions rust-example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ crate-type = ["cdylib"]

[dependencies]
crows-bindings = { path = "../bindings" }

[profile.release]
lto = true
191 changes: 96 additions & 95 deletions wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use anyhow::anyhow;
use borsh::{from_slice, to_vec, BorshDeserialize, BorshSerialize};
use crows_bindings::{HTTPError, HTTPMethod, HTTPRequest, HTTPResponse};
use crows_utils::services::RunId;
use futures::Future;
use reqwest::header::{HeaderName, HeaderValue};
use reqwest::{Body, Request, Url};
use std::pin::Pin;
use std::str::FromStr;
use std::{any::Any, collections::HashMap, io::IoSlice};
use tokio::sync::mpsc::UnboundedReceiver;
Expand All @@ -12,22 +14,6 @@ use wasi_common::WasiFile;
use wasmtime::{Caller, Config, Engine, Linker, Memory, MemoryType, Module, Store};
use wasmtime_wasi::{StdoutStream, StreamResult};

#[macro_export]
macro_rules! ok_or_return {
($expr:expr, $store:expr, $err_handler:expr) => {
match $expr {
Ok(value) => value,
Err(err) => {
let err = $err_handler(err);
let encoded = to_vec(&err)?;
let length = encoded.len();
let index = $store.buffers.insert(encoded.into_boxed_slice());
return Ok(create_return_value(1, length as u32, index as u32));
}
}
};
}

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("the module with a given name couldn't be found")]
Expand Down Expand Up @@ -93,65 +79,109 @@ impl WasiHostCtx {
self.memory = Some(mem);
}

pub async fn http(mut caller: Caller<'_, Self>, ptr: u32, len: u32) -> anyhow::Result<u64> {
let request: HTTPRequest = Self::fetch_arg(&mut caller, ptr, len)?;
pub async fn wrap_async<'a, U, F, E>(
mut caller: Caller<'a, Self>,
ptr: u32,
len: u32,
f: F,
) -> anyhow::Result<u64>
where
F: for<'b> FnOnce(&'b mut Caller<'_, Self>, HTTPRequest) -> Pin<Box<dyn Future<Output = Result<U, E>> + 'b + Send>>,
U: BorshSerialize,
E: BorshSerialize,
{
let memory = get_memory(&mut caller)?;

let memory = get_memory(&mut caller).unwrap();
let (_, store) = memory.data_and_store_mut(&mut caller);
let slice = memory
.data(&caller)
.get(ptr as usize..(ptr + len) as usize)
.ok_or(anyhow!("Could not get memory slice"))?;

let client = &store.client;
let arg = from_slice(slice)?;

let method = match request.method {
HTTPMethod::HEAD => reqwest::Method::HEAD,
HTTPMethod::GET => reqwest::Method::GET,
HTTPMethod::POST => reqwest::Method::POST,
HTTPMethod::PUT => reqwest::Method::PUT,
HTTPMethod::DELETE => reqwest::Method::DELETE,
HTTPMethod::OPTIONS => reqwest::Method::OPTIONS,
};
let url = ok_or_return!(Url::parse(&request.url), store, |err| HTTPError {
message: format!("Error when parsing the URL: {err:?}"),
});

let mut reqw_req = Request::new(method, url);

for (key, value) in request.headers {
let name = ok_or_return!(HeaderName::from_str(&key), store, |err| HTTPError {
message: format!("Invalid header name: {key}: {err:?}"),
});
let value = ok_or_return!(HeaderValue::from_str(&value), store, |err| HTTPError {
message: format!("Invalid header value: {value}: {err:?}"),
});
reqw_req.headers_mut().insert(name, value);
}
let result = f(&mut caller, arg).await;

let (_, store) = { memory.data_and_store_mut(&mut caller) };

match result {
Ok(ret) => {
let encoded = to_vec(&ret).unwrap();

let length = encoded.len();
let index = store.buffers.insert(encoded.into_boxed_slice());

*reqw_req.body_mut() = request.body.map(|b| Body::from(b));
Ok(create_return_value(0, length as u32, index as u32))
}
Err(err) => {
let encoded = to_vec(&err).unwrap();

let response = ok_or_return!(client.execute(reqw_req).await, store, |err| HTTPError {
message: format!("Error when sending a request: {err:?}"),
});
let length = encoded.len();
let index = store.buffers.insert(encoded.into_boxed_slice());

let mut headers = HashMap::new();
for (name, value) in response.headers().iter() {
let value = ok_or_return!(value.to_str(), store, |err| HTTPError {
message: format!("Could not parse response header {value:?}: {err:?}"),
});
headers.insert(name.to_string(), value.to_string());
Ok(create_return_value(0, length as u32, index as u32))
}
}
}

pub fn http<'a>(
mut caller: &'a mut Caller<'_, Self>,
request: HTTPRequest,
) -> Pin<Box<dyn Future<Output = Result<HTTPResponse, HTTPError>> + 'a + Send>> {
Box::pin(async move {
let memory = get_memory(&mut caller).unwrap();
let (_, store) = memory.data_and_store_mut(&mut caller);

let client = &store.client;

let method = match request.method {
HTTPMethod::HEAD => reqwest::Method::HEAD,
HTTPMethod::GET => reqwest::Method::GET,
HTTPMethod::POST => reqwest::Method::POST,
HTTPMethod::PUT => reqwest::Method::PUT,
HTTPMethod::DELETE => reqwest::Method::DELETE,
HTTPMethod::OPTIONS => reqwest::Method::OPTIONS,
};
let url = Url::parse(&request.url).map_err(|err| HTTPError {
message: format!("Error when parsing the URL: {err:?}"),
})?;

let mut reqw_req = Request::new(method, url);

for (key, value) in request.headers {
let name = HeaderName::from_str(&key).map_err(|err| HTTPError {
message: format!("Invalid header name: {key}: {err:?}"),
})?;
let value = HeaderValue::from_str(&value).map_err(|err| HTTPError {
message: format!("Invalid header value: {value}: {err:?}"),
})?;
reqw_req.headers_mut().insert(name, value);
}

let status = response.status().as_u16();
let body = ok_or_return!(response.text().await, store, |err| HTTPError {
message: format!("Problem with fetching the body: {err:?}"),
});
*reqw_req.body_mut() = request.body.map(|b| Body::from(b));

Self::return_result(
&mut caller,
HTTPResponse {
let response = client.execute(reqw_req).await.map_err(|err| HTTPError {
message: format!("Error when sending a request: {err:?}"),
})?;

let mut headers = HashMap::new();
for (name, value) in response.headers().iter() {
let value = value.to_str().map_err(|err| HTTPError {
message: format!("Could not parse response header {value:?}: {err:?}"),
})?;
headers.insert(name.to_string(), value.to_string());
}

let status = response.status().as_u16();
let body = response.text().await.map_err(|err| HTTPError {
message: format!("Problem with fetching the body: {err:?}"),
})?;

Ok(HTTPResponse {
headers,
body,
status,
},
)
})
})
}

pub fn consume_buffer(
Expand All @@ -175,37 +205,6 @@ impl WasiHostCtx {

Ok(())
}

pub fn fetch_arg<T>(mut caller: &mut Caller<'_, Self>, ptr: u32, len: u32) -> anyhow::Result<T>
where
T: BorshDeserialize,
{
let memory = get_memory(&mut caller)?;

let slice = memory
.data(&caller)
.get(ptr as usize..(ptr + len) as usize)
.ok_or(anyhow!("Could not get memory slice"))?;

let arg = from_slice(slice)?;

return Ok(arg);
}

pub fn return_result<T>(mut caller: &mut Caller<'_, Self>, ret: T) -> anyhow::Result<u64>
where
T: BorshSerialize,
{
let memory = get_memory(&mut caller)?;
let (_, store) = memory.data_and_store_mut(&mut caller);

let encoded = to_vec(&ret)?;

let length = encoded.len();
let index = store.buffers.insert(encoded.into_boxed_slice());

Ok(create_return_value(0, length as u32, index as u32))
}
}

impl wasmtime_wasi::WasiView for WasiHostCtx {
Expand Down Expand Up @@ -256,7 +255,9 @@ impl Environment {
.unwrap();
linker
.func_wrap2_async("crows", "http", |caller, ptr, len| {
Box::new(async move { WasiHostCtx::http(caller, ptr, len).await })
Box::new(async move {
WasiHostCtx::wrap_async(caller, ptr, len, WasiHostCtx::http).await
})
})
.unwrap();

Expand Down

0 comments on commit d4ad438

Please sign in to comment.