Skip to content

Commit

Permalink
some neats, and strengthen the concurrency test
Browse files Browse the repository at this point in the history
  • Loading branch information
tbrezot committed Jan 21, 2025
1 parent 641388e commit f76179a
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 150 deletions.
214 changes: 92 additions & 122 deletions src/adt/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,40 @@
// ! This module defines tests any implementation of the MemoryADT interface
// must pass.
use rand::{Rng, SeedableRng, rngs::StdRng};
//! This module defines tests any implementation of the MemoryADT interface must
//! pass.
//!
//! The given seeds are used to initialize the random generators used in the
//! test, thus allowing for reproducibility. In particular, all addresses are
//! randomly generated, which should guarantee thread-safety provided a
//! random-enough seed is given.
//!
//! Both addresses and words are 16-byte long.
use crate::MemoryADT;
use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};
use std::fmt::Debug;

fn gen_bytes(rng: &mut impl RngCore) -> [u8; 16] {
let mut bytes = [0; 16];
rng.fill_bytes(&mut bytes);
bytes
}

/// Tests the basic write and read operations of a Memory ADT implementation.
///
/// This function verifies the memory operations by first checking empty addresses,
/// then performing a guarded write, and finally validating the written value.
///
/// # Arguments
///
/// * `memory` - Reference to the Memory ADT implementation
/// * `seed` - 32-byte seed for reproducible random generation
///
/// # Type Parameters
///
/// * `T` - The Memory ADT implementation being tested
///
/// # Requirements
///
/// The type `T` must implement:
/// * `MemoryADT + Send + Sync`
/// * `T::Address: Debug + PartialEq + From<[u8; 16]> + Send`
/// * `T::Word: Debug + PartialEq + From<[u8; 16]> + Send`
/// * `T::Error: std::error::Error + Send`
pub async fn test_single_write_and_read<T>(memory: &T, seed: [u8; 32])
/// This function first attempts reading empty addresses, then performing a
/// guarded write, and finally validating the written value.
pub async fn test_single_write_and_read<Memory>(memory: &Memory, seed: [u8; 32])
where
T: MemoryADT + Send + Sync,
T::Address: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send,
T::Word: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send,
T::Error: std::error::Error + Send,
Memory: Send + Sync + MemoryADT,
Memory::Address: Send + Clone + From<[u8; 16]>,
Memory::Word: Send + Debug + Clone + PartialEq + From<[u8; 16]>,
Memory::Error: std::error::Error,
{
let mut rng = StdRng::from_seed(seed);

let empty_read_result = memory
.batch_read(vec![
T::Address::from(rng.gen::<u128>().to_be_bytes()),
T::Address::from(rng.gen::<u128>().to_be_bytes()),
T::Address::from(rng.gen::<u128>().to_be_bytes()),
Memory::Address::from(gen_bytes(&mut rng)),
Memory::Address::from(gen_bytes(&mut rng)),
Memory::Address::from(gen_bytes(&mut rng)),
])
.await
.unwrap();
Expand All @@ -49,21 +45,17 @@ where
{empty_read_result:?}. Seed : {seed:?}"
);

let random_address = T::Address::from(rng.gen::<u128>().to_be_bytes());
let random_word = T::Word::from(rng.gen::<u128>().to_be_bytes());
let a = Memory::Address::from(gen_bytes(&mut rng));
let w = Memory::Word::from(gen_bytes(&mut rng));

let write_result = memory
.guarded_write((random_address.clone(), None), vec![(
random_address.clone(),
random_word.clone(),
)])
.guarded_write((a.clone(), None), vec![(a.clone(), w.clone())])
.await
.unwrap();
assert_eq!(write_result, None);

let read_result: Vec<Option<<T as MemoryADT>::Word>> =
memory.batch_read(vec![random_address]).await.unwrap();
let expected_result = vec![Some(random_word)];
let read_result = memory.batch_read(vec![a]).await.unwrap();
let expected_result = vec![Some(w)];
assert_eq!(
read_result, expected_result,
"test_single_write_and_read failed.\nExpected result : {expected_result:?}, got : \
Expand All @@ -75,66 +67,47 @@ where
///
/// Attempts to write with a None guard to an address containing a value.
/// Verifies that the original value is preserved and the write fails.
///
/// # Arguments
///
/// * `memory` - Reference to the Memory ADT implementation
/// * `seed` - 32-byte seed for reproducible random generation
///
/// # Type Parameters
///
/// * `T` - The Memory ADT implementation being tested
///
/// # Requirements
///
/// The type `T` must implement:
/// * `MemoryADT + Send + Sync`
/// * `T::Address: Debug + PartialEq + From<[u8; 16]> + Send`
/// * `T::Word: Debug + PartialEq + From<[u8; 16]> + Send`
/// * `T::Error: std::error::Error + Send`
pub async fn test_wrong_guard<T>(memory: &T, seed: [u8; 32])
pub async fn test_wrong_guard<Memory>(memory: &Memory, seed: [u8; 32])
where
T: MemoryADT + Send + Sync,
T::Address: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send,
T::Word: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send,
T::Error: std::error::Error + Send,
Memory: Send + Sync + MemoryADT,
Memory::Address: Send + Clone + From<[u8; 16]>,
Memory::Word: Send + Debug + Clone + PartialEq + From<[u8; 16]>,
Memory::Error: Send + std::error::Error,
{
let mut rng = StdRng::from_seed(seed);
let random_address = T::Address::from(rng.gen::<u128>().to_be_bytes());
let word_to_write = T::Word::from(rng.gen::<u128>().to_be_bytes());

let a = Memory::Address::from(gen_bytes(&mut rng));
let w = Memory::Word::from(gen_bytes(&mut rng));

memory
.guarded_write((random_address.clone(), None), vec![(
random_address.clone(),
word_to_write.clone(),
)])
.guarded_write((a.clone(), None), vec![(a.clone(), w.clone())])
.await
.unwrap();

let conflict_result = memory
.guarded_write((random_address.clone(), None), vec![(
random_address.clone(),
T::Word::from(rng.gen::<u128>().to_be_bytes()),
.guarded_write((a.clone(), None), vec![(
a.clone(),
Memory::Word::from(rng.gen::<u128>().to_be_bytes()),
)])
.await
.unwrap();

assert_eq!(
conflict_result,
Some(word_to_write.clone()),
Some(w.clone()),
"test_wrong_guard failed.\nExpected value {:?} after write. Got : {:?}.\nDebug seed : {:?}",
conflict_result,
Some(word_to_write),
Some(w),
seed
);

let read_result = memory.batch_read(vec![random_address]).await.unwrap();
let read_result = memory.batch_read(vec![a]).await.unwrap();
assert_eq!(
vec![Some(word_to_write.clone()),],
vec![Some(w.clone()),],
read_result,
"test_wrong_guard failed. Value was overwritten, violating the guard. Expected : {:?}, \
got : {:?}. Debug seed : {:?}",
vec![Some(word_to_write),],
vec![Some(w),],
read_result,
seed
);
Expand All @@ -145,58 +118,55 @@ where
/// Spawns multiple threads to perform concurrent counter increments.
/// Uses retries to handle write contention between threads.
/// Verifies the final counter matches the total number of threads.
///
/// # Arguments
///
/// * `memory` - Reference to the Memory ADT implementation that can be cloned
/// * `seed` - 32-byte seed for reproducible random generation
///
/// # Type Parameters
///
/// * `T` - The Memory ADT implementation being tested
///
/// # Requirements
///
/// The type `T` must implement:
/// * `MemoryADT + Send + Sync + 'static + Clone`
/// * `T::Address: Debug + PartialEq + From<[u8; 16]> + Send`
/// * `T::Word: Debug + PartialEq + From<[u8; 16]> + Into<[u8; 16]> + Send + Clone + Default`
/// * `T::Error: std::error::
pub async fn test_guarded_write_concurrent<T>(memory: &T, seed: [u8; 32])
pub async fn test_guarded_write_concurrent<Memory>(memory: &Memory, seed: [u8; 32])
where
T: MemoryADT + Send + Sync + 'static + Clone,
T::Address: std::fmt::Debug + PartialEq + From<[u8; 16]> + Send,
T::Word: std::fmt::Debug + PartialEq + From<[u8; 16]> + Into<[u8; 16]> + Send + Clone + Default,
T::Error: std::error::Error,
Memory: 'static + Send + Sync + MemoryADT + Clone,
Memory::Address: Send + From<[u8; 16]>,
Memory::Word: Send + Debug + PartialEq + From<[u8; 16]> + Into<[u8; 16]> + Clone + Default,
Memory::Error: Send + std::error::Error,
{
{
const N: usize = 1000;
const N: usize = 100;
let mut rng = StdRng::from_seed(seed);
let a = rng.gen::<u128>().to_be_bytes();
let a = gen_bytes(&mut rng);

// A worker increment N times the counter m[a].
let worker = |m: Memory, a: [u8; 16]| async move {
let mut cnt = 0u128;
for _ in 0..N {
loop {
let guard = if 0 == cnt {
None
} else {
Some(Memory::Word::from(cnt.to_be_bytes()))
};

let new_cnt = cnt + 1;
let cur_cnt = m
.guarded_write((a.into(), guard), vec![(
a.into(),
Memory::Word::from(new_cnt.to_be_bytes()),
)])
.await
.unwrap()
.map(|w| <u128>::from_be_bytes(w.into()))
.unwrap_or_default();

if cnt == cur_cnt {
cnt = new_cnt;
break;
} else {
cnt = cur_cnt;
}
}
}
};

// Spawn N concurrent workers.
let handles: Vec<_> = (0..N)
.map(|_| {
let mem = memory.clone();
std::thread::spawn(move || async move {
let mut old_cnt = None;
loop {
let cur_cnt = mem
.guarded_write((a.into(), old_cnt.clone()), vec![(
a.into(),
(u128::from_be_bytes(old_cnt.clone().unwrap_or_default().into())
+ 1)
.to_be_bytes()
.into(),
)])
.await
.unwrap();
if cur_cnt == old_cnt {
return;
} else {
old_cnt = cur_cnt;
}
}
})
let m = memory.clone();
std::thread::spawn(move || worker(m, a))
})
.collect();

Expand All @@ -210,7 +180,7 @@ where

assert_eq!(
u128::from_be_bytes(final_count.clone().into()),
N as u128,
(N * N) as u128,
"test_guarded_write_concurrent failed. Expected the counter to be at {:?}, found \
{:?}.\nDebug seed : {:?}.",
N as u128,
Expand Down
4 changes: 3 additions & 1 deletion src/memory/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod encryption_layer;
mod in_memory_store;

pub use encryption_layer::MemoryEncryptionLayer;

#[cfg(feature = "redis-mem")]
pub mod redis_store;

pub use encryption_layer::MemoryEncryptionLayer;
#[cfg(any(test, feature = "bench"))]
pub use in_memory_store::InMemory;
Loading

0 comments on commit f76179a

Please sign in to comment.