diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..9d85feb Binary files /dev/null and b/.DS_Store differ diff --git a/Cargo.toml b/Cargo.toml index e684aea..7845c4e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,9 @@ include = ["Cargo.toml", "src", "README.md", "LICENSE"] [dependencies] futures-util = { version = "0.3.13", optional = true } +[dev-dependencies] +tokio = { version = "1.5.0", features = ["macros", "rt", "sync"] } + [features] async = ["futures-util"] no_std = [] diff --git a/src/compose.rs b/src/compose.rs index 38b891e..be6c4fc 100644 --- a/src/compose.rs +++ b/src/compose.rs @@ -26,14 +26,14 @@ impl Waiter for DelayComposer { self.a.start(); self.b.start(); } - fn wait(&self) -> Result<(), WaiterError> { + fn wait(&mut self) -> Result<(), WaiterError> { self.a.wait()?; self.b.wait()?; Ok(()) } #[cfg(feature = "async")] - fn async_wait(&self) -> Pin> + Send>> { + fn async_wait(&mut self) -> Pin> + Send>> { use futures_util::TryFutureExt; Box::pin( futures_util::future::try_join(self.a.async_wait(), self.b.async_wait()).map_ok(|_| ()), diff --git a/src/lib.rs b/src/lib.rs index e2adc5d..46f97b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ extern crate alloc; use alloc::boxed::Box; -use core::cell::RefCell; +use core::sync::atomic::{AtomicU64, Ordering}; #[cfg(not(feature = "no_std"))] use core::time::Duration; @@ -41,7 +41,7 @@ pub enum WaiterError { /// A waiter trait, that can be used for executing a delay. Waiters need to be /// multi-threaded and cloneable. /// A waiter should not be reused twice. -pub trait Waiter: WaiterClone + Send { +pub trait Waiter: WaiterClone + Send + Sync { /// Restart the wait timer. /// Call after starting the waiter otherwise returns an error. fn restart(&mut self) -> Result<(), WaiterError> { @@ -53,12 +53,12 @@ pub trait Waiter: WaiterClone + Send { /// Called at each cycle of the waiting cycle. /// Call after starting the waiter otherwise returns an error. - fn wait(&self) -> Result<(), WaiterError>; + fn wait(&mut self) -> Result<(), WaiterError>; /// Async version of [wait]. By default call the blocking wait. Should be implemented /// to be non-blocking. #[cfg(feature = "async")] - fn async_wait(&self) -> Pin> + Send>> { + fn async_wait(&mut self) -> Pin> + Send>> { Box::pin(future::ready(self.wait())) } } @@ -99,15 +99,15 @@ impl Waiter for Box { /// Called at each cycle of the waiting cycle. /// Call after starting the waiter otherwise returns an error. - fn wait(&self) -> Result<(), WaiterError> { - self.as_ref().wait() + fn wait(&mut self) -> Result<(), WaiterError> { + self.as_mut().wait() } /// Async version of [wait]. By default call the blocking wait. Should be implemented /// to be non-blocking. #[cfg(feature = "async")] - fn async_wait(&self) -> Pin> + Send>> { - self.as_ref().async_wait() + fn async_wait(&mut self) -> Pin> + Send>> { + self.as_mut().async_wait() } } @@ -193,12 +193,12 @@ impl Waiter for Delay { fn start(&mut self) { self.inner.start() } - fn wait(&self) -> Result<(), WaiterError> { + fn wait(&mut self) -> Result<(), WaiterError> { self.inner.wait() } #[cfg(feature = "async")] - fn async_wait(&self) -> Pin> + Send>> { + fn async_wait(&mut self) -> Pin> + Send>> { self.inner.async_wait() } } @@ -253,15 +253,14 @@ impl DelayBuilder { #[derive(Clone)] struct InstantWaiter {} impl Waiter for InstantWaiter { - fn wait(&self) -> Result<(), WaiterError> { + fn wait(&mut self) -> Result<(), WaiterError> { Ok(()) } } -#[derive(Clone)] struct CountTimeoutWaiter { max_count: u64, - count: Option>, + count: Option, } impl CountTimeoutWaiter { pub fn new(max_count: u64) -> Self { @@ -271,22 +270,35 @@ impl CountTimeoutWaiter { } } } +impl Clone for CountTimeoutWaiter { + fn clone(&self) -> Self { + Self { + max_count: self.max_count, + count: self + .count + .as_ref() + .map(|count| AtomicU64::new(count.load(Ordering::Relaxed))), + } + } +} impl Waiter for CountTimeoutWaiter { fn restart(&mut self) -> Result<(), WaiterError> { - let count = self.count.as_ref().ok_or(WaiterError::NotStarted)?; - count.replace(0); - Ok(()) + if self.count.is_none() { + Err(WaiterError::NotStarted) + } else { + self.count = Some(AtomicU64::new(0)); + Ok(()) + } } fn start(&mut self) { - self.count = Some(RefCell::new(0)); + self.count = Some(AtomicU64::new(0)); } - fn wait(&self) -> Result<(), WaiterError> { - let count = self.count.as_ref().ok_or(WaiterError::NotStarted)?; - let current = *count.borrow() + 1; - count.replace(current); + fn wait(&mut self) -> Result<(), WaiterError> { + let count = self.count.as_mut().ok_or(WaiterError::NotStarted)?; + let current = count.fetch_add(1, Ordering::Relaxed); if current >= self.max_count { Err(WaiterError::Timeout) } else { @@ -315,7 +327,7 @@ impl Waiter for SideEffectWaiter where F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>, { - fn wait(&self) -> Result<(), WaiterError> { + fn wait(&mut self) -> Result<(), WaiterError> { (self.function)() } } diff --git a/src/tests.rs b/src/tests.rs index fe1fadb..8322127 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -32,6 +32,7 @@ fn counter_works() { let mut waiter = Delay::count_timeout(3); waiter.start(); + assert!(waiter.wait().is_ok()); assert!(waiter.wait().is_ok()); assert!(waiter.wait().is_ok()); assert!(waiter.wait().is_err()); @@ -46,11 +47,13 @@ fn clone_works() { waiter1.start(); assert!(waiter1.wait().is_ok()); assert!(waiter1.wait().is_ok()); + assert!(waiter1.wait().is_ok()); assert!(waiter1.wait().is_err()); waiter2.start(); assert!(waiter2.wait().is_ok()); assert!(waiter2.wait().is_ok()); + assert!(waiter2.wait().is_ok()); assert!(waiter2.wait().is_err()); } @@ -115,3 +118,29 @@ fn can_send_between_threads() { rx_end.recv_timeout(Duration::from_millis(100)).unwrap(); } + +#[tokio::test] +async fn works_as_async() { + let mut waiter = Delay::count_timeout(5); + + let (tx, mut rx) = tokio::sync::mpsc::channel(5); + let (tx_end, mut rx_end) = tokio::sync::mpsc::channel(1); + + tokio::task::spawn(async move { + waiter.start(); + + while let Some(x) = rx.recv().await.unwrap_or(None) { + for _i in 1..x { + waiter.async_wait().await.unwrap(); + } + } + + tx_end.send(()).await.unwrap(); + }); + + tx.send(Some(4)).await.unwrap(); + tx.send(Some(1)).await.unwrap(); + tx.send(None).await.unwrap(); + + rx_end.recv().await.unwrap(); +} diff --git a/src/throttle.rs b/src/throttle.rs index 7037629..10af9e1 100644 --- a/src/throttle.rs +++ b/src/throttle.rs @@ -1,8 +1,8 @@ #![cfg(not(feature = "no_std"))] use crate::{Waiter, WaiterError}; -use std::cell::RefCell; use std::time::Duration; +use core::sync::atomic::{AtomicU64, Ordering}; #[cfg(feature = "async")] use std::{future::Future, pin::Pin}; @@ -96,83 +96,89 @@ impl ThrottleWaiter { } } impl Waiter for ThrottleWaiter { - fn wait(&self) -> Result<(), WaiterError> { + fn wait(&mut self) -> Result<(), WaiterError> { std::thread::sleep(self.throttle); Ok(()) } #[cfg(feature = "async")] - fn async_wait(&self) -> Pin> + Send>> { + fn async_wait(&mut self) -> Pin> + Send>> { Box::pin(future::ThrottleTimerFuture::new(self.throttle)) } } -#[derive(Clone)] pub struct ExponentialBackoffWaiter { - next: Option>, - initial: Duration, - multiplier: f32, - cap: Duration, + next_as_micros: Option, + initial_as_micros: u64, + multiplier: f64, + cap_as_micros: u64, } impl ExponentialBackoffWaiter { pub fn new(initial: Duration, multiplier: f32, cap: Duration) -> Self { ExponentialBackoffWaiter { - next: None, - initial, - multiplier, - cap, + next_as_micros: None, + initial_as_micros: initial.as_micros() as u64, + multiplier: multiplier as f64, + cap_as_micros: cap.as_micros() as u64, + } + } + + fn increment(&mut self) -> Result { + let next = self + .next_as_micros + .as_ref() + .ok_or(WaiterError::NotStarted)?; + let current = next.load(Ordering::Relaxed); + + // Find the next throttle. + let next = u64::max( + (current as f64 * self.multiplier) as u64, + self.cap_as_micros, + ); + self.next_as_micros + .as_mut() + .unwrap() + .store(next, Ordering::Relaxed); + Ok(Duration::from_micros(current)) + } +} +impl Clone for ExponentialBackoffWaiter { + fn clone(&self) -> Self { + Self { + next_as_micros: self + .next_as_micros + .as_ref() + .map(|a| AtomicU64::new(a.load(Ordering::Relaxed))), + ..*self } } } impl Waiter for ExponentialBackoffWaiter { fn restart(&mut self) -> Result<(), WaiterError> { - let next = self.next.as_ref().ok_or(WaiterError::NotStarted)?; - next.replace(self.initial); - Ok(()) + if self.next_as_micros.is_none() { + Err(WaiterError::NotStarted) + } else { + self.next_as_micros = Some(AtomicU64::new(self.initial_as_micros)); + Ok(()) + } } fn start(&mut self) { - self.next = Some(RefCell::new(self.initial)); + self.next_as_micros = Some(AtomicU64::new(self.initial_as_micros)); } - fn wait(&self) -> Result<(), WaiterError> { - let next = self.next.as_ref().ok_or(WaiterError::NotStarted)?; - let current = *next.borrow(); - let current_nsec = current.as_nanos() as f32; - - // Find the next throttle. - let mut next_duration = Duration::from_nanos((current_nsec * self.multiplier) as u64); - if next_duration > self.cap { - next_duration = self.cap; - } - - next.replace(next_duration); - + fn wait(&mut self) -> Result<(), WaiterError> { + let current = self.increment()?; std::thread::sleep(current); - Ok(()) } #[cfg(feature = "async")] - fn async_wait(&self) -> Pin> + Send>> { - let next = if let Some(next) = self.next.as_ref() { - next - } else { - return Box::pin(std::future::ready(Err(WaiterError::NotStarted))); - }; - - let current = *next.borrow(); - let current_nsec = current.as_nanos() as f32; - - // Find the next throttle. - let mut next_duration = Duration::from_nanos((current_nsec * self.multiplier) as u64); - if next_duration > self.cap { - next_duration = self.cap; + fn async_wait(&mut self) -> Pin> + Send>> { + match self.increment() { + Ok(current) => Box::pin(future::ThrottleTimerFuture::new(current)), + Err(e) => Box::pin(futures_util::future::err(e)), } - - next.replace(next_duration); - - Box::pin(future::ThrottleTimerFuture::new(current)) } } diff --git a/src/timeout.rs b/src/timeout.rs index 2dff7cd..360d754 100644 --- a/src/timeout.rs +++ b/src/timeout.rs @@ -24,7 +24,7 @@ impl Waiter for TimeoutWaiter { fn start(&mut self) { self.start = Some(Instant::now()); } - fn wait(&self) -> Result<(), WaiterError> { + fn wait(&mut self) -> Result<(), WaiterError> { let start = self.start.ok_or(WaiterError::NotStarted)?; if start.elapsed() > self.timeout { Err(WaiterError::Timeout)