diff --git a/external-crates/move/crates/move-vm-runtime/src/natives/extensions.rs b/external-crates/move/crates/move-vm-runtime/src/natives/extensions.rs index 254eed928c6a3..77c4a6b067ffb 100644 --- a/external-crates/move/crates/move-vm-runtime/src/natives/extensions.rs +++ b/external-crates/move/crates/move-vm-runtime/src/natives/extensions.rs @@ -3,7 +3,56 @@ // SPDX-License-Identifier: Apache-2.0 use better_any::{Tid, TidAble, TidExt}; -use std::{any::TypeId, collections::HashMap}; +use std::{ + any::TypeId, + cell::{Ref, RefCell, RefMut}, + collections::HashMap, + ops::{Deref, DerefMut}, + rc::Rc, +}; + +/// A helper wrapper around a `Tid`able type that encapsulates interior mutability in a single-threaded +/// manner. +/// +/// Note that this is _not_ threadsafe. If you need threadsafe access to the `T` you will need to +/// handle that within `T'`s type (just like in the previous implementation of the +/// `NativeContextExtensions`). +#[derive(Tid)] +pub struct NativeContextMut<'a, T: Tid<'a>>(pub RefCell, std::marker::PhantomData<&'a ()>); + +impl<'a, T: Tid<'a>> NativeContextMut<'a, T> { + /// Create a new `NativeContextMut` value with the given value. + pub fn new(t: T) -> Self { + NativeContextMut(RefCell::new(t), std::marker::PhantomData) + } + + /// Get the inner value by `&mut`. + pub fn get_mut(&self) -> RefMut { + self.0.borrow_mut() + } + + /// Get the inner value by `&`. + pub fn get(&self) -> Ref { + self.0.borrow() + } + + pub fn into_inner(self) -> T { + self.0.into_inner() + } +} + +impl<'a, T: Tid<'a>> Deref for NativeContextMut<'a, T> { + type Target = RefCell; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T: Tid<'a>> DerefMut for NativeContextMut<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} /// A data type to represent a heterogeneous collection of extensions which are available to /// native functions. A value to this is passed into the session function execution. @@ -13,15 +62,15 @@ use std::{any::TypeId, collections::HashMap}; /// avoids that extensions need to have `'static` lifetime, which `Any` requires. In order to make a /// struct suitable to be a 'Tid', use `#[derive(Tid)]` in the struct declaration. (See also /// tests at the end of this module.) -#[derive(Default)] +#[derive(Default, Clone)] pub struct NativeContextExtensions<'a> { - map: HashMap>>, + map: HashMap>>, } impl<'a> NativeContextExtensions<'a> { pub fn add>(&mut self, ext: T) { assert!( - self.map.insert(T::id(), Box::new(ext)).is_none(), + self.map.insert(T::id(), Rc::new(ext)).is_none(), "multiple extensions of the same type not allowed" ) } @@ -35,24 +84,15 @@ impl<'a> NativeContextExtensions<'a> { .unwrap() } - pub fn get_mut>(&mut self) -> &mut T { - self.map - .get_mut(&T::id()) - .expect("extension unknown") - .as_mut() - .downcast_mut::() - .unwrap() - } - - pub fn remove>(&mut self) -> T { + pub fn remove>(&mut self) -> Rc { // can't use expect below because it requires `T: Debug`. match self .map .remove(&T::id()) .expect("extension unknown") - .downcast_box::() + .downcast_rc::() { - Ok(val) => *val, + Ok(val) => val, Err(_) => panic!("downcast error"), } } @@ -60,7 +100,7 @@ impl<'a> NativeContextExtensions<'a> { #[cfg(test)] mod tests { - use crate::natives::extensions::NativeContextExtensions; + use super::*; use better_any::{Tid, TidAble}; #[derive(Tid)] @@ -73,11 +113,11 @@ mod tests { let mut v: u64 = 23; let e = Ext { a: &mut v }; let mut exts = NativeContextExtensions::default(); - exts.add(e); - *exts.get_mut::().a += 1; - assert_eq!(*exts.get_mut::().a, 24); - *exts.get_mut::().a += 1; - let e1 = exts.remove::(); - assert_eq!(*e1.a, 25) + exts.add(NativeContextMut::new(e)); + *exts.get::>().get_mut().a += 1; + assert_eq!(*exts.get::>().get_mut().a, 24); + *exts.get::>().get_mut().a += 1; + let e1 = exts.get::>(); + assert_eq!(*e1.get().a, 25); } }