Skip to content

Commit

Permalink
feat: fixed async_wait and moved wait() to take mutable ref
Browse files Browse the repository at this point in the history
  • Loading branch information
hansl committed Apr 21, 2021
1 parent e4cf281 commit f832c37
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 73 deletions.
Binary file added .DS_Store
Binary file not shown.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ include = ["Cargo.toml", "src", "README.md", "LICENSE"]
[dependencies]
futures-util = { version = "0.3.13", optional = true }

[dev-dependencies]
tokio = { version = "1.5.0", features = ["macros", "rt", "sync"] }

[features]
async = ["futures-util"]
no_std = []
4 changes: 2 additions & 2 deletions src/compose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ impl Waiter for DelayComposer {
self.a.start();
self.b.start();
}
fn wait(&self) -> Result<(), WaiterError> {
fn wait(&mut self) -> Result<(), WaiterError> {
self.a.wait()?;
self.b.wait()?;
Ok(())
}

#[cfg(feature = "async")]
fn async_wait(&self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
use futures_util::TryFutureExt;
Box::pin(
futures_util::future::try_join(self.a.async_wait(), self.b.async_wait()).map_ok(|_| ()),
Expand Down
56 changes: 34 additions & 22 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
extern crate alloc;

use alloc::boxed::Box;
use core::cell::RefCell;
use core::sync::atomic::{AtomicU64, Ordering};

#[cfg(not(feature = "no_std"))]
use core::time::Duration;
Expand Down Expand Up @@ -41,7 +41,7 @@ pub enum WaiterError {
/// A waiter trait, that can be used for executing a delay. Waiters need to be
/// multi-threaded and cloneable.
/// A waiter should not be reused twice.
pub trait Waiter: WaiterClone + Send {
pub trait Waiter: WaiterClone + Send + Sync {
/// Restart the wait timer.
/// Call after starting the waiter otherwise returns an error.
fn restart(&mut self) -> Result<(), WaiterError> {
Expand All @@ -53,12 +53,12 @@ pub trait Waiter: WaiterClone + Send {

/// Called at each cycle of the waiting cycle.
/// Call after starting the waiter otherwise returns an error.
fn wait(&self) -> Result<(), WaiterError>;
fn wait(&mut self) -> Result<(), WaiterError>;

/// Async version of [wait]. By default call the blocking wait. Should be implemented
/// to be non-blocking.
#[cfg(feature = "async")]
fn async_wait(&self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
Box::pin(future::ready(self.wait()))
}
}
Expand Down Expand Up @@ -99,15 +99,15 @@ impl Waiter for Box<dyn Waiter> {

/// Called at each cycle of the waiting cycle.
/// Call after starting the waiter otherwise returns an error.
fn wait(&self) -> Result<(), WaiterError> {
self.as_ref().wait()
fn wait(&mut self) -> Result<(), WaiterError> {
self.as_mut().wait()
}

/// Async version of [wait]. By default call the blocking wait. Should be implemented
/// to be non-blocking.
#[cfg(feature = "async")]
fn async_wait(&self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
self.as_ref().async_wait()
fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
self.as_mut().async_wait()
}
}

Expand Down Expand Up @@ -193,12 +193,12 @@ impl Waiter for Delay {
fn start(&mut self) {
self.inner.start()
}
fn wait(&self) -> Result<(), WaiterError> {
fn wait(&mut self) -> Result<(), WaiterError> {
self.inner.wait()
}

#[cfg(feature = "async")]
fn async_wait(&self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
self.inner.async_wait()
}
}
Expand Down Expand Up @@ -253,15 +253,14 @@ impl DelayBuilder {
#[derive(Clone)]
struct InstantWaiter {}
impl Waiter for InstantWaiter {
fn wait(&self) -> Result<(), WaiterError> {
fn wait(&mut self) -> Result<(), WaiterError> {
Ok(())
}
}

#[derive(Clone)]
struct CountTimeoutWaiter {
max_count: u64,
count: Option<RefCell<u64>>,
count: Option<AtomicU64>,
}
impl CountTimeoutWaiter {
pub fn new(max_count: u64) -> Self {
Expand All @@ -271,22 +270,35 @@ impl CountTimeoutWaiter {
}
}
}
impl Clone for CountTimeoutWaiter {
fn clone(&self) -> Self {
Self {
max_count: self.max_count,
count: self
.count
.as_ref()
.map(|count| AtomicU64::new(count.load(Ordering::Relaxed))),
}
}
}
impl Waiter for CountTimeoutWaiter {
fn restart(&mut self) -> Result<(), WaiterError> {
let count = self.count.as_ref().ok_or(WaiterError::NotStarted)?;
count.replace(0);
Ok(())
if self.count.is_none() {
Err(WaiterError::NotStarted)
} else {
self.count = Some(AtomicU64::new(0));
Ok(())
}
}

fn start(&mut self) {
self.count = Some(RefCell::new(0));
self.count = Some(AtomicU64::new(0));
}

fn wait(&self) -> Result<(), WaiterError> {
let count = self.count.as_ref().ok_or(WaiterError::NotStarted)?;
let current = *count.borrow() + 1;
count.replace(current);
fn wait(&mut self) -> Result<(), WaiterError> {
let count = self.count.as_mut().ok_or(WaiterError::NotStarted)?;

let current = count.fetch_add(1, Ordering::Relaxed);
if current >= self.max_count {
Err(WaiterError::Timeout)
} else {
Expand Down Expand Up @@ -315,7 +327,7 @@ impl<F> Waiter for SideEffectWaiter<F>
where
F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
{
fn wait(&self) -> Result<(), WaiterError> {
fn wait(&mut self) -> Result<(), WaiterError> {
(self.function)()
}
}
29 changes: 29 additions & 0 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fn counter_works() {
let mut waiter = Delay::count_timeout(3);
waiter.start();

assert!(waiter.wait().is_ok());
assert!(waiter.wait().is_ok());
assert!(waiter.wait().is_ok());
assert!(waiter.wait().is_err());
Expand All @@ -46,11 +47,13 @@ fn clone_works() {
waiter1.start();
assert!(waiter1.wait().is_ok());
assert!(waiter1.wait().is_ok());
assert!(waiter1.wait().is_ok());
assert!(waiter1.wait().is_err());

waiter2.start();
assert!(waiter2.wait().is_ok());
assert!(waiter2.wait().is_ok());
assert!(waiter2.wait().is_ok());
assert!(waiter2.wait().is_err());
}

Expand Down Expand Up @@ -115,3 +118,29 @@ fn can_send_between_threads() {

rx_end.recv_timeout(Duration::from_millis(100)).unwrap();
}

#[tokio::test]
async fn works_as_async() {
let mut waiter = Delay::count_timeout(5);

let (tx, mut rx) = tokio::sync::mpsc::channel(5);
let (tx_end, mut rx_end) = tokio::sync::mpsc::channel(1);

tokio::task::spawn(async move {
waiter.start();

while let Some(x) = rx.recv().await.unwrap_or(None) {
for _i in 1..x {
waiter.async_wait().await.unwrap();
}
}

tx_end.send(()).await.unwrap();
});

tx.send(Some(4)).await.unwrap();
tx.send(Some(1)).await.unwrap();
tx.send(None).await.unwrap();

rx_end.recv().await.unwrap();
}
102 changes: 54 additions & 48 deletions src/throttle.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#![cfg(not(feature = "no_std"))]
use crate::{Waiter, WaiterError};
use std::cell::RefCell;
use std::time::Duration;

use core::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "async")]
use std::{future::Future, pin::Pin};

Expand Down Expand Up @@ -96,83 +96,89 @@ impl ThrottleWaiter {
}
}
impl Waiter for ThrottleWaiter {
fn wait(&self) -> Result<(), WaiterError> {
fn wait(&mut self) -> Result<(), WaiterError> {
std::thread::sleep(self.throttle);

Ok(())
}

#[cfg(feature = "async")]
fn async_wait(&self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
Box::pin(future::ThrottleTimerFuture::new(self.throttle))
}
}

#[derive(Clone)]
pub struct ExponentialBackoffWaiter {
next: Option<RefCell<Duration>>,
initial: Duration,
multiplier: f32,
cap: Duration,
next_as_micros: Option<AtomicU64>,
initial_as_micros: u64,
multiplier: f64,
cap_as_micros: u64,
}
impl ExponentialBackoffWaiter {
pub fn new(initial: Duration, multiplier: f32, cap: Duration) -> Self {
ExponentialBackoffWaiter {
next: None,
initial,
multiplier,
cap,
next_as_micros: None,
initial_as_micros: initial.as_micros() as u64,
multiplier: multiplier as f64,
cap_as_micros: cap.as_micros() as u64,
}
}

fn increment(&mut self) -> Result<Duration, WaiterError> {
let next = self
.next_as_micros
.as_ref()
.ok_or(WaiterError::NotStarted)?;
let current = next.load(Ordering::Relaxed);

// Find the next throttle.
let next = u64::max(
(current as f64 * self.multiplier) as u64,
self.cap_as_micros,
);
self.next_as_micros
.as_mut()
.unwrap()
.store(next, Ordering::Relaxed);
Ok(Duration::from_micros(current))
}
}
impl Clone for ExponentialBackoffWaiter {
fn clone(&self) -> Self {
Self {
next_as_micros: self
.next_as_micros
.as_ref()
.map(|a| AtomicU64::new(a.load(Ordering::Relaxed))),
..*self
}
}
}
impl Waiter for ExponentialBackoffWaiter {
fn restart(&mut self) -> Result<(), WaiterError> {
let next = self.next.as_ref().ok_or(WaiterError::NotStarted)?;
next.replace(self.initial);
Ok(())
if self.next_as_micros.is_none() {
Err(WaiterError::NotStarted)
} else {
self.next_as_micros = Some(AtomicU64::new(self.initial_as_micros));
Ok(())
}
}

fn start(&mut self) {
self.next = Some(RefCell::new(self.initial));
self.next_as_micros = Some(AtomicU64::new(self.initial_as_micros));
}

fn wait(&self) -> Result<(), WaiterError> {
let next = self.next.as_ref().ok_or(WaiterError::NotStarted)?;
let current = *next.borrow();
let current_nsec = current.as_nanos() as f32;

// Find the next throttle.
let mut next_duration = Duration::from_nanos((current_nsec * self.multiplier) as u64);
if next_duration > self.cap {
next_duration = self.cap;
}

next.replace(next_duration);

fn wait(&mut self) -> Result<(), WaiterError> {
let current = self.increment()?;
std::thread::sleep(current);

Ok(())
}

#[cfg(feature = "async")]
fn async_wait(&self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
let next = if let Some(next) = self.next.as_ref() {
next
} else {
return Box::pin(std::future::ready(Err(WaiterError::NotStarted)));
};

let current = *next.borrow();
let current_nsec = current.as_nanos() as f32;

// Find the next throttle.
let mut next_duration = Duration::from_nanos((current_nsec * self.multiplier) as u64);
if next_duration > self.cap {
next_duration = self.cap;
fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
match self.increment() {
Ok(current) => Box::pin(future::ThrottleTimerFuture::new(current)),
Err(e) => Box::pin(futures_util::future::err(e)),
}

next.replace(next_duration);

Box::pin(future::ThrottleTimerFuture::new(current))
}
}
2 changes: 1 addition & 1 deletion src/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl Waiter for TimeoutWaiter {
fn start(&mut self) {
self.start = Some(Instant::now());
}
fn wait(&self) -> Result<(), WaiterError> {
fn wait(&mut self) -> Result<(), WaiterError> {
let start = self.start.ok_or(WaiterError::NotStarted)?;
if start.elapsed() > self.timeout {
Err(WaiterError::Timeout)
Expand Down

0 comments on commit f832c37

Please sign in to comment.