Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add update method to ItemAccess #83

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions packages/mocks/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,18 @@ use storey_encoding::{Cover, DecodableWithImpl, EncodableWithImpl, Encoding};

pub struct TestEncoding;

#[derive(Debug, PartialEq)]
pub struct MockError;

impl std::fmt::Display for MockError {
fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}

impl Encoding for TestEncoding {
type DecodeError = ();
type EncodeError = ();
type DecodeError = MockError;
type EncodeError = MockError;
}

// This is how we would implement `EncodableWith` and `DecodableWith` for
Expand Down Expand Up @@ -39,16 +48,16 @@ where
// Imagine `MyTestEncoding` is a third-party trait that we don't control.

trait MyTestEncoding: Sized {
fn my_encode(&self) -> Result<Vec<u8>, ()>;
fn my_decode(data: &[u8]) -> Result<Self, ()>;
fn my_encode(&self) -> Result<Vec<u8>, MockError>;
fn my_decode(data: &[u8]) -> Result<Self, MockError>;
}

impl MyTestEncoding for u64 {
fn my_encode(&self) -> Result<Vec<u8>, ()> {
fn my_encode(&self) -> Result<Vec<u8>, MockError> {
Ok(self.to_le_bytes().to_vec())
}

fn my_decode(data: &[u8]) -> Result<Self, ()> {
fn my_decode(data: &[u8]) -> Result<Self, MockError> {
let mut bytes = [0u8; 8];
bytes.copy_from_slice(data);
Ok(u64::from_le_bytes(bytes))
Expand Down
4 changes: 2 additions & 2 deletions packages/storey-encoding/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub trait Encoding {
/// The error type returned when encoding fails.
type EncodeError;
type EncodeError: std::fmt::Display;

/// The error type returned when decoding fails.
type DecodeError;
type DecodeError: std::fmt::Display;
}

pub trait EncodableWith<E: Encoding>: sealed::SealedE<E> {
Expand Down
88 changes: 77 additions & 11 deletions packages/storey/src/containers/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ where
Ok(id)
}

/// Update the value associated with the given ID.
/// Set the value associated with the given ID.
///
/// # Example
/// ```
Expand All @@ -380,13 +380,11 @@ where
/// access.push(&1337).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1337));
///
/// access.update(1, &9001).unwrap();
/// access.set(1, &9001).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(9001));
/// ```
pub fn update(&mut self, id: u32, value: &T) -> Result<(), UpdateError<E::EncodeError>> {
self.storage
.get(&encode_id(id))
.ok_or(UpdateError::NotFound)?;
pub fn set(&mut self, id: u32, value: &T) -> Result<(), SetError<E::EncodeError>> {
self.storage.get(&encode_id(id)).ok_or(SetError::NotFound)?;

let bytes = value.encode()?;

Expand All @@ -395,6 +393,44 @@ where
Ok(())
}

/// Update the value associated with the given ID by applying a function to it.
///
/// The provided function is called with the current value, if it exists, and should return the
/// new value. If the function returns `None`, the value is removed.
///
/// # Example
/// ```
/// # use mocks::encoding::TestEncoding;
/// # use mocks::backend::TestStorage;
/// use storey::containers::Column;
///
/// let mut storage = TestStorage::new();
/// let column = Column::<u64, TestEncoding>::new(0);
/// let mut access = column.access(&mut storage);
///
/// access.push(&1337).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1337));
///
/// access.update(1, |value| value.map(|v| v + 1)).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1338));
/// ```
pub fn update<F>(
&mut self,
id: u32,
f: F,
) -> Result<(), UpdateError<E::DecodeError, E::EncodeError>>
where
F: FnOnce(Option<T>) -> Option<T>,
{
let new_value = f(self.get(id).map_err(UpdateError::Decode)?);
match new_value {
Some(value) => self.set(id, &value).map_err(UpdateError::Set),
None => self
.remove(id)
.map_err(|_| UpdateError::Set(SetError::NotFound)),
}
}

/// Remove the value associated with the given ID.
///
/// This operation leaves behind an empty slot in the column. The ID is not reused.
Expand Down Expand Up @@ -445,19 +481,27 @@ impl<E> From<E> for PushError<E> {
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum UpdateError<E> {
pub enum SetError<E> {
#[error("not found")]
NotFound,
#[error("{0}")]
EncodingError(E),
}

impl<E> From<E> for UpdateError<E> {
impl<E> From<E> for SetError<E> {
fn from(e: E) -> Self {
UpdateError::EncodingError(e)
SetError::EncodingError(e)
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum UpdateError<D, E> {
#[error("decode error: {0}")]
Decode(D),
#[error("set error: {0}")]
Set(SetError<E>),
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum RemoveError {
#[error("inconsistent state")]
Expand Down Expand Up @@ -497,8 +541,8 @@ mod tests {
assert_eq!(access.len().unwrap(), 2);

access.remove(1).unwrap();
assert_eq!(access.update(1, &9001), Err(UpdateError::NotFound));
access.update(2, &9001).unwrap();
assert_eq!(access.set(1, &9001), Err(SetError::NotFound));
access.set(2, &9001).unwrap();

assert_eq!(access.get(1).unwrap(), None);
assert_eq!(access.get(2).unwrap(), Some(9001));
Expand Down Expand Up @@ -535,6 +579,28 @@ mod tests {
assert_eq!(access.len().unwrap(), 1);
}

#[test]
fn update() {
let mut storage = TestStorage::new();

let column = Column::<u64, TestEncoding>::new(0);
let mut access = column.access(&mut storage);

access.push(&1337).unwrap();
access.push(&42).unwrap();
access.push(&9001).unwrap();
access.remove(2).unwrap();

access.update(1, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(1).unwrap(), Some(1338));

access.update(2, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(2).unwrap(), None);

access.update(3, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(3).unwrap(), Some(9002));
}

#[test]
fn iteration() {
let mut storage = TestStorage::new();
Expand Down
59 changes: 58 additions & 1 deletion packages/storey/src/containers/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl<E, T, S> ItemAccess<E, T, S>
where
E: Encoding,
T: EncodableWith<E> + DecodableWith<E>,
S: StorageMut,
S: Storage + StorageMut,
{
/// Set the value of the item.
///
Expand All @@ -234,6 +234,39 @@ where
Ok(())
}

/// Update the value of the item.
///
/// The function `f` is called with the current value of the item, if it exists.
/// If the function returns `Some`, the item is set to the new value.
/// If the function returns `None`, the item is removed.
///
/// # Example
/// ```
/// # use mocks::encoding::TestEncoding;
/// # use mocks::backend::TestStorage;
/// use storey::containers::Item;
///
/// let mut storage = TestStorage::new();
/// let item = Item::<u64, TestEncoding>::new(0);
///
/// item.access(&mut storage).set(&42).unwrap();
/// item.access(&mut storage).update(|value| value.map(|v| v + 1)).unwrap();
/// assert_eq!(item.access(&storage).get().unwrap(), Some(43));
/// ```
pub fn update<F>(&mut self, f: F) -> Result<(), UpdateError<E::DecodeError, E::EncodeError>>
where
F: FnOnce(Option<T>) -> Option<T>,
{
let new_value = f(self.get().map_err(UpdateError::Decode)?);
match new_value {
Some(value) => self.set(&value).map_err(UpdateError::Encode),
None => {
self.remove();
Ok(())
}
}
}

/// Remove the value of the item.
///
/// # Example
Expand All @@ -254,6 +287,14 @@ where
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error)]
pub enum UpdateError<D, E> {
#[error("decode error: {0}")]
Decode(D),
#[error("encode error: {0}")]
Encode(E),
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -276,4 +317,20 @@ mod tests {
assert_eq!(access1.get().unwrap(), None);
assert_eq!(storage.get(&[1]), None);
}

#[test]
fn update() {
let mut storage = TestStorage::new();

let item = Item::<u64, TestEncoding>::new(0);
item.access(&mut storage).set(&42).unwrap();

item.access(&mut storage)
.update(|value| value.map(|v| v + 1))
.unwrap();
assert_eq!(item.access(&storage).get().unwrap(), Some(43));

item.access(&mut storage).update(|_| None).unwrap();
assert_eq!(item.access(&storage).get().unwrap(), None);
}
}
17 changes: 9 additions & 8 deletions packages/storey/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@
//! struct DisplayEncoding;
//!
//! impl Encoding for DisplayEncoding {
//! type DecodeError = ();
//! type EncodeError = ();
//! type DecodeError = String;
//! type EncodeError = String;
//! }
//!
//! impl<T> EncodableWithImpl<DisplayEncoding> for Cover<&T,>
//! where
//! T: std::fmt::Display,
//! {
//! fn encode_impl(self) -> Result<Vec<u8>, ()> {
//! fn encode_impl(self) -> Result<Vec<u8>, String> {
//! Ok(format!("{}", self.0).into_bytes())
//! }
//! }
Expand All @@ -67,17 +67,18 @@
//! struct DisplayEncoding;
//!
//! impl Encoding for DisplayEncoding {
//! type DecodeError = ();
//! type EncodeError = ();
//! type DecodeError = String;
//! type EncodeError = String;
//! }
//!
//! impl<T> DecodableWithImpl<DisplayEncoding> for Cover<T>
//! where
//! T: std::str::FromStr,
//! {
//! fn decode_impl(data: &[u8]) -> Result<Self, ()> {
//! let string = String::from_utf8(data.to_vec()).map_err(|_| ())?;
//! let value = string.parse().map_err(|_| ())?;
//! fn decode_impl(data: &[u8]) -> Result<Self, String> {
//! let string =
//! String::from_utf8(data.to_vec()).map_err(|_| "string isn't UTF-8".to_string())?;
//! let value = string.parse().map_err(|_| "parsing failed".to_string())?;
//! Ok(Cover(value))
//! }
//! }
Expand Down
Loading