diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 657b7d306..7e04f95ea 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -1,5 +1,5 @@ use crate::storage::RawDbConnection; -use std::{cell::RefCell, fmt}; +use std::{cell::RefCell, fmt, sync::Mutex}; // Re-implementation of Cow without ToOwned requirement enum RefOrValue<'a, T> { @@ -12,26 +12,37 @@ enum RefOrValue<'a, T> { /// and transaction state can be shared between the OpenMLS Provider and /// native XMTP operations pub struct DbConnection<'a> { - wrapped_conn: RefCell>, + wrapped_conn: Mutex>, } impl<'a> DbConnection<'a> { pub(crate) fn new(conn: &'a mut RawDbConnection) -> Self { Self { - wrapped_conn: RefCell::new(RefOrValue::Ref(conn)), + wrapped_conn: Mutex::new(RefOrValue::Ref(conn)), } } pub(crate) fn held(conn: RawDbConnection) -> Self { Self { - wrapped_conn: RefCell::new(RefOrValue::Value(conn)), + wrapped_conn: Mutex::new(RefOrValue::Value(conn)), } } + // Note: F is a synchronous fn. If it ever becomes async, we need to use + // tokio::sync::mutex instead of std::sync::Mutex pub(crate) fn raw_query(&self, fun: F) -> Result where F: FnOnce(&mut RawDbConnection) -> Result, { - match *self.wrapped_conn.borrow_mut() { + let mut lock = self.wrapped_conn.lock().map_or_else( + |err| { + log::error!( + "Recovering from poisoned mutex - a thread has previously panicked holding this lock" + ); + err.into_inner() + }, + |guard| guard, + ); + match *lock { RefOrValue::Ref(ref mut conn_ref) => fun(conn_ref), RefOrValue::Value(ref mut conn) => fun(conn), }