Skip to content

Commit

Permalink
Finish implementing HTTP request
Browse files Browse the repository at this point in the history
  • Loading branch information
drogus committed Mar 10, 2024
1 parent ce3238e commit 47f64e1
Show file tree
Hide file tree
Showing 8 changed files with 556 additions and 67 deletions.
404 changes: 404 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.108"
futures = "0.3.21"
uuid = { version = "1.4.1", features = ["serde", "v4"] }
wasmtime = { git = "https://github.com/bytecodealliance/wasmtime.git" }
wasmtime = { git = "https://github.com/bytecodealliance/wasmtime.git", features = ["async"] }
wasmtime-wasi = { git = "https://github.com/bytecodealliance/wasmtime.git" }
wasi-common = { git = "https://github.com/bytecodealliance/wasmtime.git" }
wiggle = { git = "https://github.com/bytecodealliance/wasmtime.git" }
Expand Down
27 changes: 13 additions & 14 deletions bindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use borsh::{from_slice, to_vec, BorshDeserialize, BorshSchema, BorshSerialize};
use std::{cell::RefCell, collections::HashMap, mem::MaybeUninit};
use borsh::{BorshSerialize, BorshDeserialize, from_slice, to_vec};

#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
pub enum HTTPMethod {
Expand All @@ -13,14 +13,17 @@ pub enum HTTPMethod {

#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
pub struct HTTPRequest {
url: String,
method: HTTPMethod,
headers: HashMap<String, String>,
body: Option<String>,
// TODO: these should not be public I think, I'd prefer to do a public interface for them
pub url: String,
pub method: HTTPMethod,
pub headers: HashMap<String, String>,
pub body: Option<String>,
}

#[derive(Debug, BorshDeserialize, BorshSerialize)]
pub struct HTTPError {}
pub struct HTTPError {
pub message: String
}

#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
pub struct HTTPResponse {
Expand Down Expand Up @@ -49,7 +52,7 @@ mod bindings {

fn with_buffer<R>(f: impl FnOnce(&mut Vec<u8>) -> R) -> R {
thread_local! {
static BUFFER: RefCell<Vec<u8>> = RefCell::new(Vec::new());
static BUFFER: RefCell<Vec<u8>> = RefCell::new(Vec::with_capacity(1024));
}

BUFFER.with(|r| {
Expand Down Expand Up @@ -86,21 +89,17 @@ where
{
let mut encoded = to_vec(arguments).unwrap();

println!("encoded length: {}", encoded.len());
let (status, length, index) = with_buffer(|mut buf| {
buf.append(&mut encoded);
let response = f(&mut buf);

extract_from_return_value(response)
});

println!("response: {status}, {length}, {index}");
with_buffer(|buf| {
let capacity = buf.capacity();
if capacity < length as usize {
let additional = length as usize - buf.capacity();
buf.reserve_exact(additional);
}
// when using reserve_exact it guarantees capacity to be vector.len() + additional long,
// thus we can just use length for reserving
buf.reserve_exact(length as usize);

unsafe {
bindings::consume_buffer(index, buf.as_mut_ptr(), length as usize);
Expand Down
6 changes: 2 additions & 4 deletions rust-example/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use crows_bindings::{http_request, HTTPMethod::*};

#[export_name="test"]
pub fn test() {
let response = http_request("foo".into(), GET, HashMap::new(), "".into());
println!("response: {response:?}");
let response = http_request("https://example.com".into(), GET, HashMap::new(), "".into());
println!("response: {:?}", response.unwrap());
}


2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly"
channel = "nightly-2024-01-27"
targets = ["wasm32-wasi", "x86_64-unknown-linux-gnu"]
4 changes: 3 additions & 1 deletion wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ edition = "2021"
crows-utils = { path = "../utils" }
slab = "0.4"
crows-bindings = { path = "../bindings" }
async-trait = "0.1"
reqwest = "0.11"

borsh.workspace = true
tokio.workspace = true
anyhow.workspace = true
thiserror.workspace = true
wasmtime.workspace = true
wasmtime-wasi.workspace = true
wasi-common.workspace = true
wiggle.workspace = true
borsh.workspace = true
176 changes: 131 additions & 45 deletions wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,35 @@
use borsh::to_vec;
use crows_bindings::{HTTPResponse, HTTPError};
use crows_utils::{services::RunId, ModuleId};
use std::mem::MaybeUninit;
use std::{any::Any, collections::HashMap, io::IoSlice, sync::Arc};
use tokio::time::{Duration, Instant};
use anyhow::anyhow;
use crows_bindings::{HTTPError, HTTPMethod, HTTPRequest, HTTPResponse};
use crows_utils::{services::RunId};
use reqwest::header::{HeaderName, HeaderValue};
use reqwest::{Body, Request, Url};
use std::str::FromStr;
use std::{any::Any, collections::HashMap, io::IoSlice};
use wasi_common::WasiFile;
use wasi_common::{
file::{FdFlags, FileType},
pipe::WritePipe,
Table,
};
use wasmtime::{
Caller, Config, Engine, Linker, Memory, MemoryType, Module, Store, StoreContextMut, Val,
ValType,
Caller, Config, Engine, Linker, Memory, MemoryType, Module, Store
};
use wasmtime_wasi::{StdoutStream, WasiCtxBuilder};
use wasmtime_wasi::{StdoutStream};
use borsh::{BorshSerialize, BorshDeserialize, from_slice, to_vec};

#[macro_export]
macro_rules! ok_or_return {
($expr:expr, $store:expr, $err_handler:expr) => {
match $expr {
Ok(value) => value,
Err(err) => {
let err = $err_handler(err);
let encoded = to_vec(&err)?;
let length = encoded.len();
let index = $store.buffers.insert(encoded.into_boxed_slice());
return Ok(create_return_value(1, length as u32, index as u32));
}
}
};
}

#[derive(thiserror::Error, Debug)]
pub enum Error {
Expand Down Expand Up @@ -63,6 +78,7 @@ struct WasiHostCtx {
preview1_adapter: wasmtime_wasi::preview1::WasiPreview1Adapter,
memory: Option<Memory>,
buffers: slab::Slab<Box<[u8]>>,
client: reqwest::Client,
}

fn create_return_value(status: u8, length: u32, ptr: u32) -> u64 {
Expand All @@ -78,56 +94,119 @@ impl WasiHostCtx {
self.memory = Some(mem);
}

pub fn log(mut caller: Caller<'_, Self>, ptr: u32, len: u32) -> anyhow::Result<()> {
let memory = get_memory(&mut caller)?;

let row_str = memory
.data(&caller)
.get(ptr as usize..(ptr + len) as usize)
.unwrap();

Ok(())
}
pub async fn http(mut caller: Caller<'_, Self>, ptr: u32, len: u32) -> anyhow::Result<u64> {
let request: HTTPRequest = Self::fetch_arg(&mut caller, ptr, len)?;

pub fn http(mut caller: Caller<'_, Self>, ptr: u32, len: u32) -> u64 {
let memory = get_memory(&mut caller).unwrap();
let (_, store) = memory.data_and_store_mut(&mut caller);

let str = memory
.data(&caller)
.get(ptr as usize..(ptr + len) as usize)
.unwrap();
let client = &store.client;

let response = HTTPResponse {
headers: HashMap::new(),
body: "foo bar".into(),
status: 200,
let method = match request.method {
HTTPMethod::HEAD => reqwest::Method::HEAD,
HTTPMethod::GET => reqwest::Method::GET,
HTTPMethod::POST => reqwest::Method::POST,
HTTPMethod::PUT => reqwest::Method::PUT,
HTTPMethod::DELETE => reqwest::Method::DELETE,
HTTPMethod::OPTIONS => reqwest::Method::OPTIONS,
};
let err = HTTPError {
let url = ok_or_return!(Url::parse(&request.url), store, |err| HTTPError {
message: format!("Error when parsing the URL: {err:?}"),
});

let mut reqw_req = Request::new(method, url);

for (key, value) in request.headers {
let name = ok_or_return!(HeaderName::from_str(&key), store, |err| HTTPError {
message: format!("Invalid header name: {key}: {err:?}"),
});
let value = ok_or_return!(HeaderValue::from_str(&value), store, |err| HTTPError {
message: format!("Invalid header value: {value}: {err:?}"),
});
reqw_req.headers_mut().insert(name, value);
}

};
let encoded = to_vec(&err).unwrap();
*reqw_req.body_mut() = request.body.map(|b| Body::from(b));

let length = encoded.len();
println!("return len: {length}, {:?}", encoded.to_vec());
let (_, store) = memory.data_and_store_mut(&mut caller);
let index = store.buffers.insert(encoded.into_boxed_slice());
let response = ok_or_return!(client.execute(reqw_req).await, store, |err| HTTPError {
message: format!("Error when sending a request: {err:?}"),
});

let mut headers = HashMap::new();
for (name, value) in response.headers().iter() {
let value = ok_or_return!(value.to_str(), store, |err| HTTPError {
message: format!("Could not parse response header {value:?}: {err:?}"),
});
headers.insert(name.to_string(), value.to_string());
}

println!("returning 0, {length}, {index}");
create_return_value(1, length as u32, index as u32)
let status = response.status().as_u16();
let body = ok_or_return!(response.text().await, store, |err| HTTPError {
message: format!("Problem with fetching the body: {err:?}"),
});

Self::return_result(
&mut caller,
HTTPResponse {
headers,
body,
status,
},
)
}

pub fn consume_buffer(mut caller: Caller<'_, Self>, index: u32, ptr: u32, len: u32) -> anyhow::Result<()> {
let memory = get_memory(&mut caller).unwrap();
pub fn consume_buffer(
mut caller: Caller<'_, Self>,
index: u32,
ptr: u32,
len: u32,
) -> anyhow::Result<()> {
let memory = get_memory(&mut caller)?;
let (mut slice, store) = memory.data_and_store_mut(&mut caller);

let buffer = store.buffers.try_remove(index as usize).unwrap();
anyhow::ensure!(
len as usize == buffer.len(),
"bad length passed to consume_buffer"
);
slice.get_mut((ptr as usize)..((ptr+len) as usize)).unwrap().copy_from_slice(&buffer);
slice
.get_mut((ptr as usize)..((ptr + len) as usize))
.unwrap()
.copy_from_slice(&buffer);

Ok(())
}

pub fn fetch_arg<T>(mut caller: &mut Caller<'_, Self>, ptr: u32, len: u32) -> anyhow::Result<T>
where
T: BorshDeserialize,
{
let memory = get_memory(&mut caller)?;

let slice = memory
.data(&caller)
.get(ptr as usize..(ptr + len) as usize)
.ok_or(anyhow!("Could not get memory slice"))?;

let arg = from_slice(slice)?;

return Ok(arg);
}

pub fn return_result<T>(mut caller: &mut Caller<'_, Self>, ret: T) -> anyhow::Result<u64>
where
T: BorshSerialize,
{
let memory = get_memory(&mut caller)?;
let (_, store) = memory.data_and_store_mut(&mut caller);

let encoded = to_vec(&ret)?;

let length = encoded.len();
let index = store.buffers.insert(encoded.into_boxed_slice());

Ok(create_return_value(0, length as u32, index as u32))
}
}

impl wasmtime_wasi::WasiView for WasiHostCtx {
Expand Down Expand Up @@ -173,10 +252,16 @@ impl Instance {

let mut linker = Linker::new(&engine);

linker.func_wrap("crows", "log", WasiHostCtx::log).unwrap();
linker.func_wrap("crows", "consume_buffer", WasiHostCtx::consume_buffer).unwrap();
// linker.func_wrap("crows", "log", WasiHostCtx::log).unwrap();
linker
.func_wrap("crows", "consume_buffer", WasiHostCtx::consume_buffer)
.unwrap();
linker
.func_wrap("crows", "http", WasiHostCtx::http)
.func_wrap2_async("crows", "http", |caller, ptr, len| {
Box::new(async move {
WasiHostCtx::http(caller, ptr, len).await
})
})
.unwrap();
// let _ = linker.func_new_async(
// "crows",
Expand Down Expand Up @@ -235,6 +320,7 @@ pub async fn run_wasm(instance: &Instance) -> anyhow::Result<()> {
preview1_adapter: wasmtime_wasi::preview1::WasiPreview1Adapter::new(),
buffers: slab::Slab::default(),
memory: None,
client: reqwest::Client::new(),
};
let mut store: Store<WasiHostCtx> = Store::new(&instance.engine, host_ctx);

Expand Down
2 changes: 1 addition & 1 deletion wasm/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::time::Instant;

use crows_wasm::{run_wasm, Instance};

#[tokio::main]
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), anyhow::Error> {
let path = std::env::var("MODULE_PATH").expect("MODULE_PATH env var is not set");
let content = std::fs::read(path).unwrap();
Expand Down

0 comments on commit 47f64e1

Please sign in to comment.