diff --git a/Makefile b/Makefile index 7cf2399..403b28f 100755 --- a/Makefile +++ b/Makefile @@ -27,4 +27,8 @@ pprof: .PHONY: cloc cloc: - cloc --exclude-dir=target . \ No newline at end of file + cloc --exclude-dir=target . + +.PHONY: test +test: + cargo test --all \ No newline at end of file diff --git a/darpi-web/src/json.rs b/darpi-web/src/json.rs index 8e54ff6..42f7d63 100755 --- a/darpi-web/src/json.rs +++ b/darpi-web/src/json.rs @@ -3,6 +3,7 @@ use crate::response::{Responder, ResponderError}; use crate::Response; use async_trait::async_trait; use derive_more::Display; +use http::header::HeaderName; use http::{header, HeaderMap, HeaderValue}; use hyper::Body; use serde::de::DeserializeOwned; @@ -10,11 +11,22 @@ use serde::{Deserialize, Deserializer, Serialize}; use serde_json::Error; use std::{fmt, ops}; -pub struct Json(pub T); +pub struct Json { + t: T, + hm: HeaderMap, +} impl Json { - pub fn into_inner(self) -> T { - self.0 + pub fn new(t: T) -> Self { + Self { + t, + hm: Default::default(), + } + } + + pub fn header(mut self, key: HeaderName, value: HeaderValue) -> Self { + self.hm.append(key, value); + self } async fn deserialize_future(b: Body) -> Result, JsonErr> @@ -23,7 +35,7 @@ impl Json { { let full_body = hyper::body::to_bytes(b).await?; let ser: T = serde_json::from_slice(&full_body)?; - Ok(Json(ser)) + Ok(Json::new(ser)) } } @@ -55,7 +67,7 @@ where D: Deserializer<'de>, { let deser = T::deserialize(deserializer)?.into(); - Ok(Json(deser)) + Ok(Json::new(deser)) } } @@ -63,13 +75,13 @@ impl ops::Deref for Json { type Target = T; fn deref(&self) -> &T { - &self.0 + &self.t } } impl ops::DerefMut for Json { fn deref_mut(&mut self) -> &mut T { - &mut self.0 + &mut self.t } } @@ -78,7 +90,7 @@ where T: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Json: {:?}", self.0) + write!(f, "Json: {:?}", self.t) } } @@ -87,7 +99,7 @@ where T: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, f) + fmt::Display::fmt(&self.t, f) } } @@ -96,12 +108,17 @@ where T: Serialize, { fn respond(self) -> Response { - match serde_json::to_string(&self.0) { - Ok(body) => Response::builder() - .header(header::CONTENT_TYPE, "application/json") - .status(self.status_code()) - .body(Body::from(body)) - .expect("this cannot happen"), + match serde_json::to_string(&self.t) { + Ok(body) => { + let mut rb = Response::builder() + .header(header::CONTENT_TYPE, "application/json") + .status(self.status_code()); + + for (hk, hv) in self.hm.iter() { + rb = rb.header(hk, hv); + } + rb.body(Body::from(body)).expect("this cannot happen") + } Err(e) => e.respond_err(), } } diff --git a/darpi-web/src/xml.rs b/darpi-web/src/xml.rs index 6c7eadc..4d46fb8 100755 --- a/darpi-web/src/xml.rs +++ b/darpi-web/src/xml.rs @@ -4,6 +4,7 @@ use crate::Response; use async_trait::async_trait; use bytes::Buf; use derive_more::Display; +use http::header::HeaderName; use http::{header, HeaderMap, HeaderValue}; use hyper::Body; use serde::de::DeserializeOwned; @@ -11,16 +12,31 @@ use serde::{Deserialize, Deserializer, Serialize}; use serde_xml_rs::Error; use std::{fmt, ops}; -pub struct Xml(pub T); +pub struct Xml { + t: T, + hm: HeaderMap, +} impl Xml { + pub fn new(t: T) -> Self { + Self { + t, + hm: Default::default(), + } + } + + pub fn header(mut self, key: HeaderName, value: HeaderValue) -> Self { + self.hm.append(key, value); + self + } + async fn deserialize_future(b: Body) -> Result, XmlErr> where T: DeserializeOwned, { let full_body = hyper::body::to_bytes(b).await?; let ser: T = serde_xml_rs::from_reader(full_body.reader())?; - Ok(Xml(ser)) + Ok(Xml::new(ser)) } } @@ -52,7 +68,7 @@ where D: Deserializer<'de>, { let deser = T::deserialize(deserializer)?.into(); - Ok(Xml(deser)) + Ok(Xml::new(deser)) } } @@ -60,13 +76,13 @@ impl ops::Deref for Xml { type Target = T; fn deref(&self) -> &T { - &self.0 + &self.t } } impl ops::DerefMut for Xml { fn deref_mut(&mut self) -> &mut T { - &mut self.0 + &mut self.t } } @@ -75,7 +91,7 @@ where T: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Xml: {:?}", self.0) + write!(f, "Xml: {:?}", self.t) } } @@ -84,7 +100,7 @@ where T: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, f) + fmt::Display::fmt(&self.t, f) } } @@ -93,12 +109,17 @@ where T: Serialize, { fn respond(self) -> Response { - match serde_xml_rs::to_string(&self.0) { - Ok(body) => Response::builder() - .header(header::CONTENT_TYPE, "application/xml") - .status(self.status_code()) - .body(Body::from(body)) - .expect("this cannot happen"), + match serde_xml_rs::to_string(&self.t) { + Ok(body) => { + let mut rb = Response::builder() + .header(header::CONTENT_TYPE, "application/xml") + .status(self.status_code()); + + for (hk, hv) in self.hm.iter() { + rb = rb.header(hk, hv); + } + rb.body(Body::from(body)).expect("this cannot happen") + } Err(e) => e.respond_err(), } } diff --git a/darpi-web/src/yaml.rs b/darpi-web/src/yaml.rs index 3d9ecca..e3bc7b6 100755 --- a/darpi-web/src/yaml.rs +++ b/darpi-web/src/yaml.rs @@ -3,6 +3,7 @@ use crate::response::{Responder, ResponderError}; use crate::Response; use async_trait::async_trait; use derive_more::Display; +use http::header::HeaderName; use http::{header, HeaderMap, HeaderValue}; use hyper::Body; use serde::de::DeserializeOwned; @@ -10,16 +11,31 @@ use serde::{Deserialize, Deserializer, Serialize}; use serde_yaml::Error; use std::{fmt, ops}; -pub struct Yaml(pub T); +pub struct Yaml { + t: T, + hm: HeaderMap, +} impl Yaml { + pub fn new(t: T) -> Self { + Self { + t, + hm: Default::default(), + } + } + + pub fn header(mut self, key: HeaderName, value: HeaderValue) -> Self { + self.hm.append(key, value); + self + } + async fn deserialize_future(b: Body) -> Result, YamlErr> where T: DeserializeOwned, { let full_body = hyper::body::to_bytes(b).await?; let ser: T = serde_yaml::from_slice(&full_body)?; - Ok(Yaml(ser)) + Ok(Yaml::new(ser)) } } @@ -51,7 +67,7 @@ where D: Deserializer<'de>, { let deser = T::deserialize(deserializer)?.into(); - Ok(Yaml(deser)) + Ok(Yaml::new(deser)) } } @@ -59,13 +75,13 @@ impl ops::Deref for Yaml { type Target = T; fn deref(&self) -> &T { - &self.0 + &self.t } } impl ops::DerefMut for Yaml { fn deref_mut(&mut self) -> &mut T { - &mut self.0 + &mut self.t } } @@ -74,7 +90,7 @@ where T: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Yaml: {:?}", self.0) + write!(f, "Yaml: {:?}", self.t) } } @@ -83,7 +99,7 @@ where T: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, f) + fmt::Display::fmt(&self.t, f) } } @@ -92,12 +108,18 @@ where T: Serialize, { fn respond(self) -> Response { - match serde_yaml::to_string(&self.0) { - Ok(body) => Response::builder() - .header(header::CONTENT_TYPE, "application/Yaml") - .status(self.status_code()) - .body(Body::from(body)) - .expect("this cannot happen"), + match serde_yaml::to_string(&self.t) { + Ok(body) => { + let mut rb = Response::builder() + .header(header::CONTENT_TYPE, "application/Yaml") + .status(self.status_code()); + + for (hk, hv) in self.hm.iter() { + rb = rb.header(hk, hv); + } + + rb.body(Body::from(body)).expect("this cannot happen") + } Err(e) => e.respond_err(), } } diff --git a/examples/responses.rs b/examples/responses.rs index e555bfd..1a3b73e 100644 --- a/examples/responses.rs +++ b/examples/responses.rs @@ -1,6 +1,9 @@ use darpi::{app, handler, App, Body, Json, Responder, Response, StatusCode}; use env_logger; +use http::header::HeaderName; +use http::HeaderValue; use serde::Serialize; +use std::str::FromStr; pub struct HelloWorldResp; @@ -35,9 +38,13 @@ pub struct Resp { #[handler] async fn json() -> Json { - Json(Resp { + Json::new(Resp { name: "John".to_string(), }) + .header( + HeaderName::from_str("Keep-Alive").unwrap(), + HeaderValue::from_str("timeout=5").unwrap(), + ) } #[darpi::main] diff --git a/tests/test_handler.rs b/tests/test_handler.rs index ddf458d..107709b 100644 --- a/tests/test_handler.rs +++ b/tests/test_handler.rs @@ -2,10 +2,14 @@ use darpi::response::ResponderError; use darpi::{handler, Path}; #[cfg(test)] use darpi::{Args, Body, Handler, Request, StatusCode}; +use darpi_web::Json; use derive_more::Display; +use http::header::HeaderName; +use http::HeaderValue; use serde::{Deserialize, Serialize}; use std::convert::TryInto; use std::num::TryFromIntError; +use std::str::FromStr; #[cfg(test)] use std::sync::Arc; @@ -49,6 +53,46 @@ async fn increment_byte_not_ok() { assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, resp.status()); } +#[tokio::test] +async fn set_correct_header() { + let req = Request::get("http://127.0.0.1:3000/json") + .body(Body::empty()) + .unwrap(); + + let resp = Handler::call( + json, + Args { + request: req, + container: Arc::new(()), + route_args: (), + }, + ) + .await + .unwrap(); + + assert_eq!(StatusCode::OK, resp.status()); + assert_eq!( + b"timeout=5", + resp.headers().get("Keep-Alive").unwrap().as_bytes() + ) +} + +#[derive(Serialize)] +pub struct Resp { + name: String, +} + +#[handler] +async fn json() -> Json { + Json::new(Resp { + name: "John".to_string(), + }) + .header( + HeaderName::from_str("Keep-Alive").unwrap(), + HeaderValue::from_str("timeout=5").unwrap(), + ) +} + #[derive(Display, Debug)] pub enum IncrementError { Overflow,