Skip to content

Commit

Permalink
chore: simplify Service impls (#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes authored Dec 30, 2024
1 parent 262089c commit 89dbbee
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 111 deletions.
4 changes: 1 addition & 3 deletions crates/json-rpc/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ pub type BorrowedRpcResult<'a, E> = RpcResult<&'a RawValue, E, &'a RawValue>;
/// Transform a transport response into an [`RpcResult`], discarding the [`Id`].
///
/// [`Id`]: crate::Id
pub fn transform_response<T, E, ErrResp>(
response: Response<T, ErrResp>,
) -> Result<T, RpcError<E, ErrResp>>
pub fn transform_response<T, E, ErrResp>(response: Response<T, ErrResp>) -> RpcResult<T, E, ErrResp>
where
ErrResp: RpcReturn,
{
Expand Down
112 changes: 51 additions & 61 deletions crates/transport-http/src/hyper_transport.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use alloy_transport::{
utils::guess_local_url, TransportConnect, TransportError, TransportErrorKind, TransportFut,
TransportResult,
};
use http_body_util::{BodyExt, Full};
use hyper::{
Expand Down Expand Up @@ -79,63 +80,46 @@ where
ResBody::Error: std::error::Error + Send + Sync + 'static,
ResBody::Data: Send,
{
/// Make a request to the server using the given service.
fn request_hyper(&self, req: RequestPacket) -> TransportFut<'static> {
let this = self.clone();
let span = debug_span!("HyperClient", url = %this.url);
Box::pin(
async move {
debug!(count = req.len(), "sending request packet to server");
let ser = req.serialize().map_err(TransportError::ser_err)?;
// convert the Box<RawValue> into a hyper request<B>
let body = ser.get().as_bytes().to_owned().into();

let req = hyper::Request::builder()
.method(hyper::Method::POST)
.uri(this.url.as_str())
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.body(body)
.expect("request parts are invalid");

let mut service = this.client.service.clone();
let resp = service.call(req).await.map_err(TransportErrorKind::custom)?;

let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp
.into_body()
.collect()
.await
.map_err(TransportErrorKind::custom)?
.to_bytes();

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != hyper::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body).map_err(|err| {
TransportError::deser_err(err, String::from_utf8_lossy(body.as_ref()))
})
}
.instrument(span),
)
async fn do_hyper(self, req: RequestPacket) -> TransportResult<ResponsePacket> {
debug!(count = req.len(), "sending request packet to server");
let ser = req.serialize().map_err(TransportError::ser_err)?;
// convert the Box<RawValue> into a hyper request<B>
let body = ser.get().as_bytes().to_owned().into();

let req = hyper::Request::builder()
.method(hyper::Method::POST)
.uri(self.url.as_str())
.header(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"))
.body(body)
.expect("request parts are invalid");

let mut service = self.client.service;
let resp = service.call(req).await.map_err(TransportErrorKind::custom)?;

let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp.into_body().collect().await.map_err(TransportErrorKind::custom)?.to_bytes();

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != hyper::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body)
.map_err(|err| TransportError::deser_err(err, String::from_utf8_lossy(body.as_ref())))
}
}

Expand Down Expand Up @@ -168,12 +152,14 @@ where
type Error = TransportError;
type Future = TransportFut<'static>;

fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
task::Poll::Ready(Ok(()))
#[inline]
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
(&*self).poll_ready(cx)
}

#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_hyper(req)
(&*self).call(req)
}
}

Expand All @@ -188,11 +174,15 @@ where
type Error = TransportError;
type Future = TransportFut<'static>;

#[inline]
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
// `hyper` always returns `Ok(())`.
task::Poll::Ready(Ok(()))
}

fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_hyper(req)
let this = self.clone();
let span = debug_span!("HyperTransport", url = %this.url);
Box::pin(this.do_hyper(req).instrument(span))
}
}
87 changes: 40 additions & 47 deletions crates/transport-http/src/reqwest_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{Http, HttpConnect};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use alloy_transport::{
utils::guess_local_url, TransportConnect, TransportError, TransportErrorKind, TransportFut,
TransportResult,
};
use std::task;
use tower::Service;
Expand Down Expand Up @@ -37,46 +38,38 @@ impl Http<Client> {
Self { client: Default::default(), url }
}

/// Make a request.
fn request_reqwest(&self, req: RequestPacket) -> TransportFut<'static> {
let this = self.clone();
let span: tracing::Span = debug_span!("ReqwestTransport", url = %self.url);
Box::pin(
async move {
let resp = this
.client
.post(this.url)
.json(&req)
.send()
.await
.map_err(TransportErrorKind::custom)?;
let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp.bytes().await.map_err(TransportErrorKind::custom)?;

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != reqwest::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body)
.map_err(|err| TransportError::deser_err(err, String::from_utf8_lossy(&body)))
}
.instrument(span),
)
async fn do_reqwest(self, req: RequestPacket) -> TransportResult<ResponsePacket> {
let resp = self
.client
.post(self.url)
.json(&req)
.send()
.await
.map_err(TransportErrorKind::custom)?;
let status = resp.status();

debug!(%status, "received response from server");

// Unpack data from the response body. We do this regardless of
// the status code, as we want to return the error in the body
// if there is one.
let body = resp.bytes().await.map_err(TransportErrorKind::custom)?;

debug!(bytes = body.len(), "retrieved response body. Use `trace` for full body");
trace!(body = %String::from_utf8_lossy(&body), "response body");

if status != reqwest::StatusCode::OK {
return Err(TransportErrorKind::http_error(
status.as_u16(),
String::from_utf8_lossy(&body).into_owned(),
));
}

// Deserialize a Box<RawValue> from the body. If deserialization fails, return
// the body as a string in the error. The conversion to String
// is lossy and may not cover all the bytes in the body.
serde_json::from_slice(&body)
.map_err(|err| TransportError::deser_err(err, String::from_utf8_lossy(&body)))
}
}

Expand All @@ -86,14 +79,13 @@ impl Service<RequestPacket> for Http<reqwest::Client> {
type Future = TransportFut<'static>;

#[inline]
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
// reqwest always returns ok
task::Poll::Ready(Ok(()))
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
(&*self).poll_ready(cx)
}

#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_reqwest(req)
(&*self).call(req)
}
}

Expand All @@ -104,12 +96,13 @@ impl Service<RequestPacket> for &Http<reqwest::Client> {

#[inline]
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
// reqwest always returns ok
// `reqwest` always returns `Ok(())`.
task::Poll::Ready(Ok(()))
}

#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request_reqwest(req)
let this = self.clone();
let span = debug_span!("ReqwestTransport", url = %this.url);
Box::pin(this.do_reqwest(req).instrument(span))
}
}

0 comments on commit 89dbbee

Please sign in to comment.