From f76179a732ee4ad95a8a1639699397b2a45025f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20BR=C3=89ZOT?= Date: Tue, 21 Jan 2025 13:33:47 +0100 Subject: [PATCH] some neats, and strengthen the concurrency test --- src/adt/test_utils.rs | 214 ++++++++++++++++---------------------- src/memory/mod.rs | 4 +- src/memory/redis_store.rs | 50 ++++----- 3 files changed, 118 insertions(+), 150 deletions(-) diff --git a/src/adt/test_utils.rs b/src/adt/test_utils.rs index 7dc16f49..cc1fe73a 100644 --- a/src/adt/test_utils.rs +++ b/src/adt/test_utils.rs @@ -1,44 +1,40 @@ -// ! This module defines tests any implementation of the MemoryADT interface -// must pass. -use rand::{Rng, SeedableRng, rngs::StdRng}; +//! This module defines tests any implementation of the MemoryADT interface must +//! pass. +//! +//! The given seeds are used to initialize the random generators used in the +//! test, thus allowing for reproducibility. In particular, all addresses are +//! randomly generated, which should guarantee thread-safety provided a +//! random-enough seed is given. +//! +//! Both addresses and words are 16-byte long. use crate::MemoryADT; +use rand::{Rng, RngCore, SeedableRng, rngs::StdRng}; +use std::fmt::Debug; + +fn gen_bytes(rng: &mut impl RngCore) -> [u8; 16] { + let mut bytes = [0; 16]; + rng.fill_bytes(&mut bytes); + bytes +} /// Tests the basic write and read operations of a Memory ADT implementation. /// -/// This function verifies the memory operations by first checking empty addresses, -/// then performing a guarded write, and finally validating the written value. -/// -/// # Arguments -/// -/// * `memory` - Reference to the Memory ADT implementation -/// * `seed` - 32-byte seed for reproducible random generation -/// -/// # Type Parameters -/// -/// * `T` - The Memory ADT implementation being tested -/// -/// # Requirements -/// -/// The type `T` must implement: -/// * `MemoryADT + Send + Sync` -/// * `T::Address: Debug + PartialEq + From<[u8; 16]> + Send` -/// * `T::Word: Debug + PartialEq + From<[u8; 16]> + Send` -/// * `T::Error: std::error::Error + Send` -pub async fn test_single_write_and_read(memory: &T, seed: [u8; 32]) +/// This function first attempts reading empty addresses, then performing a +/// guarded write, and finally validating the written value. +pub async fn test_single_write_and_read(memory: &Memory, seed: [u8; 32]) where - T: MemoryADT + Send + Sync, - T::Address: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send, - T::Word: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send, - T::Error: std::error::Error + Send, + Memory: Send + Sync + MemoryADT, + Memory::Address: Send + Clone + From<[u8; 16]>, + Memory::Word: Send + Debug + Clone + PartialEq + From<[u8; 16]>, + Memory::Error: std::error::Error, { let mut rng = StdRng::from_seed(seed); - let empty_read_result = memory .batch_read(vec![ - T::Address::from(rng.gen::().to_be_bytes()), - T::Address::from(rng.gen::().to_be_bytes()), - T::Address::from(rng.gen::().to_be_bytes()), + Memory::Address::from(gen_bytes(&mut rng)), + Memory::Address::from(gen_bytes(&mut rng)), + Memory::Address::from(gen_bytes(&mut rng)), ]) .await .unwrap(); @@ -49,21 +45,17 @@ where {empty_read_result:?}. Seed : {seed:?}" ); - let random_address = T::Address::from(rng.gen::().to_be_bytes()); - let random_word = T::Word::from(rng.gen::().to_be_bytes()); + let a = Memory::Address::from(gen_bytes(&mut rng)); + let w = Memory::Word::from(gen_bytes(&mut rng)); let write_result = memory - .guarded_write((random_address.clone(), None), vec![( - random_address.clone(), - random_word.clone(), - )]) + .guarded_write((a.clone(), None), vec![(a.clone(), w.clone())]) .await .unwrap(); assert_eq!(write_result, None); - let read_result: Vec::Word>> = - memory.batch_read(vec![random_address]).await.unwrap(); - let expected_result = vec![Some(random_word)]; + let read_result = memory.batch_read(vec![a]).await.unwrap(); + let expected_result = vec![Some(w)]; assert_eq!( read_result, expected_result, "test_single_write_and_read failed.\nExpected result : {expected_result:?}, got : \ @@ -75,66 +67,47 @@ where /// /// Attempts to write with a None guard to an address containing a value. /// Verifies that the original value is preserved and the write fails. -/// -/// # Arguments -/// -/// * `memory` - Reference to the Memory ADT implementation -/// * `seed` - 32-byte seed for reproducible random generation -/// -/// # Type Parameters -/// -/// * `T` - The Memory ADT implementation being tested -/// -/// # Requirements -/// -/// The type `T` must implement: -/// * `MemoryADT + Send + Sync` -/// * `T::Address: Debug + PartialEq + From<[u8; 16]> + Send` -/// * `T::Word: Debug + PartialEq + From<[u8; 16]> + Send` -/// * `T::Error: std::error::Error + Send` -pub async fn test_wrong_guard(memory: &T, seed: [u8; 32]) +pub async fn test_wrong_guard(memory: &Memory, seed: [u8; 32]) where - T: MemoryADT + Send + Sync, - T::Address: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send, - T::Word: std::fmt::Debug + Clone + PartialEq + From<[u8; 16]> + Send, - T::Error: std::error::Error + Send, + Memory: Send + Sync + MemoryADT, + Memory::Address: Send + Clone + From<[u8; 16]>, + Memory::Word: Send + Debug + Clone + PartialEq + From<[u8; 16]>, + Memory::Error: Send + std::error::Error, { let mut rng = StdRng::from_seed(seed); - let random_address = T::Address::from(rng.gen::().to_be_bytes()); - let word_to_write = T::Word::from(rng.gen::().to_be_bytes()); + + let a = Memory::Address::from(gen_bytes(&mut rng)); + let w = Memory::Word::from(gen_bytes(&mut rng)); memory - .guarded_write((random_address.clone(), None), vec![( - random_address.clone(), - word_to_write.clone(), - )]) + .guarded_write((a.clone(), None), vec![(a.clone(), w.clone())]) .await .unwrap(); let conflict_result = memory - .guarded_write((random_address.clone(), None), vec![( - random_address.clone(), - T::Word::from(rng.gen::().to_be_bytes()), + .guarded_write((a.clone(), None), vec![( + a.clone(), + Memory::Word::from(rng.gen::().to_be_bytes()), )]) .await .unwrap(); assert_eq!( conflict_result, - Some(word_to_write.clone()), + Some(w.clone()), "test_wrong_guard failed.\nExpected value {:?} after write. Got : {:?}.\nDebug seed : {:?}", conflict_result, - Some(word_to_write), + Some(w), seed ); - let read_result = memory.batch_read(vec![random_address]).await.unwrap(); + let read_result = memory.batch_read(vec![a]).await.unwrap(); assert_eq!( - vec![Some(word_to_write.clone()),], + vec![Some(w.clone()),], read_result, "test_wrong_guard failed. Value was overwritten, violating the guard. Expected : {:?}, \ got : {:?}. Debug seed : {:?}", - vec![Some(word_to_write),], + vec![Some(w),], read_result, seed ); @@ -145,58 +118,55 @@ where /// Spawns multiple threads to perform concurrent counter increments. /// Uses retries to handle write contention between threads. /// Verifies the final counter matches the total number of threads. -/// -/// # Arguments -/// -/// * `memory` - Reference to the Memory ADT implementation that can be cloned -/// * `seed` - 32-byte seed for reproducible random generation -/// -/// # Type Parameters -/// -/// * `T` - The Memory ADT implementation being tested -/// -/// # Requirements -/// -/// The type `T` must implement: -/// * `MemoryADT + Send + Sync + 'static + Clone` -/// * `T::Address: Debug + PartialEq + From<[u8; 16]> + Send` -/// * `T::Word: Debug + PartialEq + From<[u8; 16]> + Into<[u8; 16]> + Send + Clone + Default` -/// * `T::Error: std::error:: -pub async fn test_guarded_write_concurrent(memory: &T, seed: [u8; 32]) +pub async fn test_guarded_write_concurrent(memory: &Memory, seed: [u8; 32]) where - T: MemoryADT + Send + Sync + 'static + Clone, - T::Address: std::fmt::Debug + PartialEq + From<[u8; 16]> + Send, - T::Word: std::fmt::Debug + PartialEq + From<[u8; 16]> + Into<[u8; 16]> + Send + Clone + Default, - T::Error: std::error::Error, + Memory: 'static + Send + Sync + MemoryADT + Clone, + Memory::Address: Send + From<[u8; 16]>, + Memory::Word: Send + Debug + PartialEq + From<[u8; 16]> + Into<[u8; 16]> + Clone + Default, + Memory::Error: Send + std::error::Error, { { - const N: usize = 1000; + const N: usize = 100; let mut rng = StdRng::from_seed(seed); - let a = rng.gen::().to_be_bytes(); + let a = gen_bytes(&mut rng); + + // A worker increment N times the counter m[a]. + let worker = |m: Memory, a: [u8; 16]| async move { + let mut cnt = 0u128; + for _ in 0..N { + loop { + let guard = if 0 == cnt { + None + } else { + Some(Memory::Word::from(cnt.to_be_bytes())) + }; + + let new_cnt = cnt + 1; + let cur_cnt = m + .guarded_write((a.into(), guard), vec![( + a.into(), + Memory::Word::from(new_cnt.to_be_bytes()), + )]) + .await + .unwrap() + .map(|w| ::from_be_bytes(w.into())) + .unwrap_or_default(); + + if cnt == cur_cnt { + cnt = new_cnt; + break; + } else { + cnt = cur_cnt; + } + } + } + }; + // Spawn N concurrent workers. let handles: Vec<_> = (0..N) .map(|_| { - let mem = memory.clone(); - std::thread::spawn(move || async move { - let mut old_cnt = None; - loop { - let cur_cnt = mem - .guarded_write((a.into(), old_cnt.clone()), vec![( - a.into(), - (u128::from_be_bytes(old_cnt.clone().unwrap_or_default().into()) - + 1) - .to_be_bytes() - .into(), - )]) - .await - .unwrap(); - if cur_cnt == old_cnt { - return; - } else { - old_cnt = cur_cnt; - } - } - }) + let m = memory.clone(); + std::thread::spawn(move || worker(m, a)) }) .collect(); @@ -210,7 +180,7 @@ where assert_eq!( u128::from_be_bytes(final_count.clone().into()), - N as u128, + (N * N) as u128, "test_guarded_write_concurrent failed. Expected the counter to be at {:?}, found \ {:?}.\nDebug seed : {:?}.", N as u128, diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 6b6d94df..414ed524 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -1,8 +1,10 @@ mod encryption_layer; mod in_memory_store; + +pub use encryption_layer::MemoryEncryptionLayer; + #[cfg(feature = "redis-mem")] pub mod redis_store; -pub use encryption_layer::MemoryEncryptionLayer; #[cfg(any(test, feature = "bench"))] pub use in_memory_store::InMemory; diff --git a/src/memory/redis_store.rs b/src/memory/redis_store.rs index 22fb1a8f..6c6fcd5d 100644 --- a/src/memory/redis_store.rs +++ b/src/memory/redis_store.rs @@ -17,7 +17,7 @@ local length = ARGV[3] local value = redis.call('GET',ARGV[1]) -- compare the value of the guard to the currently stored value -if((value==false) or (not(value == false) and (guard_value == value))) then +if ((value == false) or (guard_value == value)) then -- guard passed, loop over bindings and insert them for i = 4,(length*2)+3,2 do @@ -48,7 +48,7 @@ impl From for MemoryError { #[derive(Clone)] pub struct RedisMemory { - pub manager: ConnectionManager, + manager: ConnectionManager, script_hash: String, a: PhantomData
, w: PhantomData, @@ -65,13 +65,15 @@ impl fmt::Debug for RedisMemory { impl RedisMemory { /// Connects to a Redis server with a `ConnectionManager`. pub async fn connect_with_manager(mut manager: ConnectionManager) -> Result { + let script_hash = redis::cmd("SCRIPT") + .arg("LOAD") + .arg(GUARDED_WRITE_LUA_SCRIPT) + .query_async(&mut manager) + .await?; + Ok(Self { - manager: manager.clone(), - script_hash: redis::cmd("SCRIPT") - .arg("LOAD") - .arg(GUARDED_WRITE_LUA_SCRIPT) - .query_async(&mut manager) - .await?, + manager, + script_hash, a: PhantomData, w: PhantomData, }) @@ -89,8 +91,8 @@ impl MemoryADT for RedisMemory, [u8; WORD_LENGTH]> { type Address = Address; - type Error = MemoryError; type Word = [u8; WORD_LENGTH]; + type Error = MemoryError; async fn batch_read( &self, @@ -98,6 +100,7 @@ impl MemoryADT ) -> Result>, Self::Error> { let mut cmd = redis::cmd("MGET"); let cmd = addresses.iter().fold(&mut cmd, |c, a| c.arg(&**a)); + // Cloning the connection manager is cheap since it is an `Arc`. cmd.query_async(&mut self.manager.clone()) .await .map_err(Self::Error::from) @@ -113,18 +116,19 @@ impl MemoryADT let cmd = cmd .arg(self.script_hash.as_str()) .arg(0) - .arg(&*guard_address); - - let cmd = if let Some(byte_array) = guard_value { - cmd.arg(&byte_array) - } else { - cmd.arg("false") - }; + .arg(&*guard_address) + .arg( + guard_value + .as_ref() + .map(|bytes| bytes.as_slice()) + .unwrap_or(b"false".as_slice()), + ); let cmd = bindings .iter() .fold(cmd.arg(bindings.len()), |cmd, (a, w)| cmd.arg(&**a).arg(w)); + // Cloning the connection manager is cheap since it is an `Arc`. cmd.query_async(&mut self.manager.clone()) .await .map_err(Self::Error::from) @@ -146,31 +150,23 @@ mod tests { ) } - const WORD_LENGTH: usize = 16; - #[tokio::test] async fn test_rw_seq() -> Result<(), MemoryError> { - let m = RedisMemory::<_, [u8; WORD_LENGTH]>::connect(&get_redis_url()) - .await - .unwrap(); + let m = RedisMemory::connect(&get_redis_url()).await.unwrap(); test_single_write_and_read(&m, rand::random()).await; Ok(()) } #[tokio::test] async fn test_guard_seq() -> Result<(), MemoryError> { - let m = RedisMemory::<_, [u8; WORD_LENGTH]>::connect(&get_redis_url()) - .await - .unwrap(); + let m = RedisMemory::connect(&get_redis_url()).await.unwrap(); test_wrong_guard(&m, rand::random()).await; Ok(()) } #[tokio::test] async fn test_rw_ccr() -> Result<(), MemoryError> { - let m = RedisMemory::<_, [u8; WORD_LENGTH]>::connect(&get_redis_url()) - .await - .unwrap(); + let m = RedisMemory::connect(&get_redis_url()).await.unwrap(); test_guarded_write_concurrent(&m, rand::random()).await; Ok(()) }