From ab1daa95fb020965648bdcc8bff38ce6d53bbc91 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 17:19:01 +0900 Subject: [PATCH] Add stopper option to server --- Cargo.toml | 1 + src/server/mod.rs | 31 ++++++++++++++++++++++++------- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a16922e..e171178 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ log = "0.4.11" pin-project = "1.0.2" async-channel = "1.5.1" async-dup = "1.2.2" +stopper = "0.2.0" [dev-dependencies] pretty_assertions = "0.6.1" diff --git a/src/server/mod.rs b/src/server/mod.rs index 67af9f0..f96d16f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,6 +6,7 @@ use http_types::headers::{CONNECTION, UPGRADE}; use http_types::upgrade::Connection; use http_types::{Request, Response, StatusCode}; use std::{marker::PhantomData, time::Duration}; +use stopper::Stopper; mod body_reader; mod decode; mod encode; @@ -17,13 +18,16 @@ pub use encode::Encoder; #[derive(Debug, Clone)] pub struct ServerOptions { /// Timeout to handle headers. Defaults to 60s. - headers_timeout: Option, + pub headers_timeout: Option, + /// Stopper to shutdown the server. Defaults to None. + pub stopper: Option, } impl Default for ServerOptions { fn default() -> Self { Self { headers_timeout: Some(Duration::from_secs(60)), + stopper: None, } } } @@ -113,17 +117,30 @@ where // Decode a new request, timing out if this takes longer than the timeout duration. let fut = decode(self.io.clone()); - let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout { - match timeout(timeout_duration, fut).await { + let (req, mut body) = match (self.opts.headers_timeout, &self.opts.stopper) { + (Some(timeout_duration), Some(stopper)) => { + match timeout(timeout_duration, stopper.stop_future(fut)).await { + Ok(Some(Ok(Some(r)))) => r, + Ok(Some(Ok(None))) | Err(TimeoutError { .. }) | Ok(None) => { + return Ok(ConnectionStatus::Close); + } /* EOF, timeout, or stopped by stopper */ + Ok(Some(Err(e))) => return Err(e), + } + } + (Some(timeout_duration), None) => match timeout(timeout_duration, fut).await { Ok(Ok(Some(r))) => r, Ok(Ok(None)) | Err(TimeoutError { .. }) => return Ok(ConnectionStatus::Close), /* EOF or timeout */ Ok(Err(e)) => return Err(e), - } - } else { - match fut.await? { + }, + (None, Some(stopper)) => match stopper.stop_future(fut).await { + Some(Ok(Some(r))) => r, + Some(Ok(None)) | None => return Ok(ConnectionStatus::Close), /* EOF or stopped by stopper */ + Some(Err(e)) => return Err(e), + }, + (None, None) => match fut.await? { Some(r) => r, None => return Ok(ConnectionStatus::Close), /* EOF */ - } + }, }; let has_upgrade_header = req.header(UPGRADE).is_some();