From 35c7e50a69039739416193e6a43b01260ca8386c Mon Sep 17 00:00:00 2001 From: Tomasz Kurcz Date: Tue, 7 May 2024 18:08:48 +0200 Subject: [PATCH] feat: provide typed bounds for iteration --- packages/storey/src/containers/column.rs | 58 +++++++++++++-- packages/storey/src/containers/map.rs | 45 ++++++----- packages/storey/src/containers/mod.rs | 95 +++++++++++++++++++----- packages/storey/tests/composition.rs | 5 +- packages/storey/tests/iteration.rs | 7 +- 5 files changed, 151 insertions(+), 59 deletions(-) diff --git a/packages/storey/src/containers/column.rs b/packages/storey/src/containers/column.rs index ae04c8a..f015243 100644 --- a/packages/storey/src/containers/column.rs +++ b/packages/storey/src/containers/column.rs @@ -7,7 +7,7 @@ use crate::encoding::{DecodableWith, EncodableWith}; use crate::storage::{IterableStorage, StorageBranch}; use crate::storage::{Storage, StorageMut}; -use super::{IterableAccessor, Storable}; +use super::{BoundFor, BoundedIterableAccessor, IterableAccessor, Storable}; const META_LAST_IX: &[u8] = &[0]; const META_LEN: &[u8] = &[1]; @@ -135,6 +135,20 @@ where } } +impl BoundedIterableAccessor for ColumnAccess +where + E: Encoding, + T: EncodableWith + DecodableWith, + S: IterableStorage, +{ +} + +impl BoundFor> for u32 { + fn into_bytes(self) -> Vec { + self.to_be_bytes().to_vec() + } +} + impl ColumnAccess where E: Encoding, @@ -428,28 +442,58 @@ mod tests { access.push(&9001).unwrap(); access.remove(1).unwrap(); + assert_eq!( + access.pairs().collect::, _>>().unwrap(), + vec![(0, 1337), (2, 9001)] + ); + + assert_eq!( + access.keys().collect::, _>>().unwrap(), + vec![0, 2] + ); + + assert_eq!( + access.values().collect::, _>>().unwrap(), + vec![1337, 9001] + ); + } + + #[test] + fn bounded_iteration() { + let mut storage = TestStorage::new(); + + let column = Column::::new(0); + let mut access = column.access(&mut storage); + + access.push(&1337).unwrap(); + access.push(&42).unwrap(); + access.push(&9001).unwrap(); + access.push(&1).unwrap(); + access.push(&2).unwrap(); + access.remove(2).unwrap(); + assert_eq!( access - .pairs(None, None) + .bounded_pairs(Some(1), Some(4)) .collect::, _>>() .unwrap(), - vec![(0, 1337), (2, 9001)] + vec![(1, 42), (3, 1)] ); assert_eq!( access - .keys(None, None) + .bounded_keys(Some(1), Some(4)) .collect::, _>>() .unwrap(), - vec![0, 2] + vec![1, 3] ); assert_eq!( access - .values(None, None) + .bounded_values(Some(1), Some(4)) .collect::, _>>() .unwrap(), - vec![1337, 9001] + vec![42, 1] ); } } diff --git a/packages/storey/src/containers/map.rs b/packages/storey/src/containers/map.rs index 67eb36e..d72c0ac 100644 --- a/packages/storey/src/containers/map.rs +++ b/packages/storey/src/containers/map.rs @@ -188,12 +188,7 @@ where K: Borrow, Q: Key + ?Sized, { - let len = key.bytes().len(); - let bytes = key.bytes(); - let mut key = Vec::with_capacity(len + 1); - - key.push(len as u8); - key.extend_from_slice(bytes); + let key = length_prefixed_key(key); V::access_impl(StorageBranch::new(&self.storage, key)) } @@ -243,6 +238,17 @@ where } } +fn length_prefixed_key(key: &K) -> Vec { + let len = key.bytes().len(); + let bytes = key.bytes(); + let mut key = Vec::with_capacity(len + 1); + + key.push(len as u8); + key.extend_from_slice(bytes); + + key +} + impl IterableAccessor for MapAccess where K: OwnedKey, @@ -276,6 +282,12 @@ impl Key for String { } } +impl Key for str { + fn bytes(&self) -> &[u8] { + self.as_bytes() + } +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error)] #[error("invalid UTF8")] pub struct InvalidUtf8; @@ -293,12 +305,6 @@ impl OwnedKey for String { } } -impl Key for str { - fn bytes(&self) -> &[u8] { - self.as_bytes() - } -} - #[cfg(test)] mod tests { use super::*; @@ -338,10 +344,7 @@ mod tests { access.entry_mut("foo").set(&1337).unwrap(); access.entry_mut("bar").set(&42).unwrap(); - let items = access - .pairs(None, None) - .collect::, _>>() - .unwrap(); + let items = access.pairs().collect::, _>>().unwrap(); assert_eq!( items, vec![ @@ -361,10 +364,7 @@ mod tests { access.entry_mut("foo").set(&1337).unwrap(); access.entry_mut("bar").set(&42).unwrap(); - let keys = access - .keys(None, None) - .collect::, _>>() - .unwrap(); + let keys = access.keys().collect::, _>>().unwrap(); assert_eq!(keys, vec![("bar".to_string(), ()), ("foo".to_string(), ())]) } @@ -378,10 +378,7 @@ mod tests { access.entry_mut("foo").set(&1337).unwrap(); access.entry_mut("bar").set(&42).unwrap(); - let values = access - .values(None, None) - .collect::, _>>() - .unwrap(); + let values = access.values().collect::, _>>().unwrap(); assert_eq!(values, vec![42, 1337]) } } diff --git a/packages/storey/src/containers/mod.rs b/packages/storey/src/containers/mod.rs index 9576d3d..d62d40e 100644 --- a/packages/storey/src/containers/mod.rs +++ b/packages/storey/src/containers/mod.rs @@ -68,7 +68,7 @@ pub enum KVDecodeError { /// A trait for collection accessors (see [`Storable::AccessorT`]) that provide iteration over /// their contents. -pub trait IterableAccessor { +pub trait IterableAccessor: Sized { /// The [`Storable`] type this accessor is associated with. type StorableT: Storable; @@ -81,42 +81,99 @@ pub trait IterableAccessor { fn storage(&self) -> &Self::StorageT; /// Iterate over key-value pairs in this collection. - fn pairs<'s>( - &'s self, - start: Option<&[u8]>, - end: Option<&[u8]>, - ) -> StorableIter<'s, Self::StorableT, Self::StorageT> { + fn pairs(&self) -> StorableIter<'_, Self::StorableT, Self::StorageT> { StorableIter { - inner: self.storage().pairs(start, end), + inner: self.storage().pairs(None, None), phantom: PhantomData, } } /// Iterate over keys in this collection. - fn keys<'s>( - &'s self, - start: Option<&[u8]>, - end: Option<&[u8]>, - ) -> StorableKeys<'s, Self::StorableT, Self::StorageT> { + fn keys(&self) -> StorableKeys<'_, Self::StorableT, Self::StorageT> { StorableKeys { - inner: self.storage().keys(start, end), + inner: self.storage().keys(None, None), phantom: PhantomData, } } /// Iterate over values in this collection. - fn values<'s>( - &'s self, - start: Option<&[u8]>, - end: Option<&[u8]>, - ) -> StorableValues<'s, Self::StorableT, Self::StorageT> { + fn values(&self) -> StorableValues<'_, Self::StorableT, Self::StorageT> { StorableValues { - inner: self.storage().values(start, end), + inner: self.storage().values(None, None), phantom: PhantomData, } } } +pub trait BoundedIterableAccessor: IterableAccessor { + /// Iterate over key-value pairs in this collection, respecting the given bounds. + fn bounded_pairs( + &self, + start: Option, + end: Option, + ) -> StorableIter<'_, Self::StorableT, Self::StorageT> + where + S: BoundFor, + E: BoundFor, + { + let start = start.map(|b| b.into_bytes()); + let end = end.map(|b| b.into_bytes()); + + StorableIter { + inner: self.storage().pairs(start.as_deref(), end.as_deref()), + phantom: PhantomData, + } + } + + /// Iterate over keys in this collection, respecting the given bounds. + fn bounded_keys( + &self, + start: Option, + end: Option, + ) -> StorableKeys<'_, Self::StorableT, Self::StorageT> + where + S: BoundFor, + E: BoundFor, + { + let start = start.map(|b| b.into_bytes()); + let end = end.map(|b| b.into_bytes()); + + StorableKeys { + inner: self.storage().keys(start.as_deref(), end.as_deref()), + phantom: PhantomData, + } + } + + /// Iterate over values in this collection, respecting the given bounds. + fn bounded_values( + &self, + start: Option, + end: Option, + ) -> StorableValues<'_, Self::StorableT, Self::StorageT> + where + S: BoundFor, + E: BoundFor, + { + let start = start.map(|b| b.into_bytes()); + let end = end.map(|b| b.into_bytes()); + + StorableValues { + inner: self.storage().values(start.as_deref(), end.as_deref()), + phantom: PhantomData, + } + } +} + +/// A type that can be used as bounds for iteration over a given collection. +/// +/// As an example, a collection `Foo` with string-y keys can accept both `String` and +/// `&str` bounds by providing these impls: +/// - `impl BoundFor for &str` +/// - `impl BoundFor for String` +pub trait BoundFor { + fn into_bytes(self) -> Vec; +} + /// The iterator over key-value pairs in a collection. pub struct StorableIter<'i, S, B> where diff --git a/packages/storey/tests/composition.rs b/packages/storey/tests/composition.rs index 221df8c..59dd251 100644 --- a/packages/storey/tests/composition.rs +++ b/packages/storey/tests/composition.rs @@ -57,10 +57,7 @@ fn map_of_column() { assert_eq!(access.entry("bar").get(0).unwrap(), Some(9001)); assert_eq!(access.entry("bar").len().unwrap(), 1); - let all = access - .pairs(None, None) - .collect::, _>>() - .unwrap(); + let all = access.pairs().collect::, _>>().unwrap(); assert_eq!( all, vec![ diff --git a/packages/storey/tests/iteration.rs b/packages/storey/tests/iteration.rs index 215f690..4c43d90 100644 --- a/packages/storey/tests/iteration.rs +++ b/packages/storey/tests/iteration.rs @@ -20,10 +20,7 @@ fn map_of_map_iteration() { .unwrap(); // iterate over all items - let items = access - .pairs(None, None) - .collect::, _>>() - .unwrap(); + let items = access.pairs().collect::, _>>().unwrap(); assert_eq!( items, vec![ @@ -36,7 +33,7 @@ fn map_of_map_iteration() { // iterate over items under "foo" let items = access .entry("foo") - .pairs(None, None) + .pairs() .collect::, _>>() .unwrap(); assert_eq!(