diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 5d2260926..90a920355 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1072,7 +1072,7 @@ pub(crate) mod tests { .unwrap(); let conn = amal.store().conn().unwrap(); - conn.raw_query(|conn| diesel::delete(identity_updates::table).execute(conn)) + conn.raw_query_write(|conn| diesel::delete(identity_updates::table).execute(conn)) .unwrap(); let members = group.members().await.unwrap(); @@ -1424,6 +1424,7 @@ pub(crate) mod tests { .unwrap(); assert_eq!(amal_group.members().await.unwrap().len(), 1); tracing::info!("Syncing bolas welcomes"); + // See if Bola can see that they were added to the group bola.sync_welcomes(&bola.mls_provider().unwrap()) .await diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index bbba948d4..273d8419d 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -63,10 +63,12 @@ impl MlsGroup { intent_kind: IntentKind, intent_data: Vec, ) -> Result { - provider.transaction(|provider| { + let res = provider.transaction(|provider| { let conn = provider.conn_ref(); self.queue_intent_with_conn(conn, intent_kind, intent_data) - }) + }); + + res } fn queue_intent_with_conn( diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index a5dd9085a..1cd041861 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -2216,7 +2216,7 @@ pub(crate) mod tests { // The dm shows up let alix_groups = alix_conn - .raw_query(|conn| groups::table.load::(conn)) + .raw_query_read(|conn| groups::table.load::(conn)) .unwrap(); assert_eq!(alix_groups.len(), 2); // They should have the same ID @@ -3787,7 +3787,7 @@ pub(crate) mod tests { let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into(); let conn_2 = bo.store().conn().unwrap(); conn_2 - .raw_query(|c| { + .raw_query_write(|c| { c.batch_execute("BEGIN EXCLUSIVE").unwrap(); Ok::<_, diesel::result::Error>(()) }) @@ -3795,10 +3795,10 @@ pub(crate) mod tests { let process_result = bo_group.process_messages(bo_messages, &conn_1).await; if let Some(GroupError::ReceiveErrors(errors)) = process_result.err() { - assert_eq!(errors.len(), 2); - assert!(errors - .iter() - .any(|err| err.to_string().contains("database is locked"))); + assert_eq!(errors.len(), 1); + assert!(errors.iter().any(|err| err + .to_string() + .contains("cannot start a transaction within a transaction"))); } else { panic!("Expected error") } diff --git a/xmtp_mls/src/storage/encrypted_store/association_state.rs b/xmtp_mls/src/storage/encrypted_store/association_state.rs index b09158de6..9f7583b0f 100644 --- a/xmtp_mls/src/storage/encrypted_store/association_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/association_state.rs @@ -109,7 +109,7 @@ impl StoredAssociationState { ); let association_states = - conn.raw_query(|query_conn| query.load::(query_conn))?; + conn.raw_query_read(|query_conn| query.load::(query_conn))?; association_states .into_iter() diff --git a/xmtp_mls/src/storage/encrypted_store/consent_record.rs b/xmtp_mls/src/storage/encrypted_store/consent_record.rs index a0789d9e3..c70359473 100644 --- a/xmtp_mls/src/storage/encrypted_store/consent_record.rs +++ b/xmtp_mls/src/storage/encrypted_store/consent_record.rs @@ -48,7 +48,7 @@ impl DbConnection { entity: String, entity_type: ConsentType, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| -> diesel::QueryResult<_> { + Ok(self.raw_query_read(|conn| -> diesel::QueryResult<_> { dsl::consent_records .filter(dsl::entity.eq(entity)) .filter(dsl::entity_type.eq(entity_type)) @@ -77,7 +77,7 @@ impl DbConnection { ); } - let changed = self.raw_query(|conn| -> diesel::QueryResult<_> { + let changed = self.raw_query_write(|conn| -> diesel::QueryResult<_> { let existing: Vec = query.load(conn)?; let changed: Vec<_> = records .iter() @@ -107,7 +107,7 @@ impl DbConnection { &self, record: &StoredConsentRecord, ) -> Result, StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { let maybe_inserted_consent_record: Option = diesel::insert_into(dsl::consent_records) .values(record) diff --git a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs index d2fc136ba..30ca8d7d0 100644 --- a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs +++ b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs @@ -139,7 +139,7 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } else { // Only include the specified states let query = query @@ -153,11 +153,11 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } } else { // Handle the case where `consent_states` is `None` - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -165,7 +165,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = conversation_list_dsl::conversation_list .filter(conversation_list_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(|conn| query.load(conn))?; + let mut sync_groups = self.raw_query_read(|conn| query.load(conn))?; conversations.append(&mut sync_groups); } diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 1cb3e9ce9..d42f656d9 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -1,8 +1,15 @@ +use crate::storage::{xmtp_openmls_provider::XmtpOpenMlsProvider, StorageError}; +use diesel::connection::TransactionManager; use parking_lot::Mutex; -use std::fmt; -use std::sync::Arc; +use std::{ + fmt, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; -use crate::storage::xmtp_openmls_provider::XmtpOpenMlsProvider; +use super::XmtpDb; #[cfg(not(target_arch = "wasm32"))] pub type DbConnection = DbConnectionPrivate; @@ -19,14 +26,30 @@ pub type DbConnection = DbConnectionPrivate { - inner: Arc>, + // Connection with read-only privileges + read: Option>>, + // Connection with write privileges + write: Arc>, + // Is any connection (possibly this one) currently in a transaction? + global_transaction_lock: Arc>, + // Is this particular connection in a transaction? + in_transaction: Arc, } /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(conn: Arc>) -> Self { - Self { inner: conn } + pub(super) fn from_arc_mutex( + write: Arc>, + read: Option>>, + transaction_lock: Arc>, + ) -> Self { + Self { + read, + write, + global_transaction_lock: transaction_lock, + in_transaction: Arc::new(AtomicBool::new(false)), + } } } @@ -34,22 +57,50 @@ impl DbConnectionPrivate where C: diesel::Connection, { + pub(crate) fn start_transaction>( + &self, + ) -> Result, StorageError> { + let guard = self.global_transaction_lock.lock(); + let mut write = self.write.lock(); + ::TransactionManager::begin_transaction(&mut *write)?; + self.in_transaction.store(true, Ordering::SeqCst); + + Ok(TransactionGuard { + _mutex_guard: guard, + in_transaction: self.in_transaction.clone(), + }) + } + /// Do a scoped query with a mutable [`diesel::Connection`] /// reference - pub(crate) fn raw_query(&self, fun: F) -> Result + pub(crate) fn raw_query_read(&self, fun: F) -> Result where F: FnOnce(&mut C) -> Result, { - let mut lock = self.inner.lock(); + let mut lock = if self.in_transaction.load(Ordering::SeqCst) { + self.write.lock() + } else if let Some(read) = &self.read { + read.lock() + } else { + self.write.lock() + }; + fun(&mut lock) } - /// Internal-only API to get the underlying `diesel::Connection` reference - /// without a scope - /// Must be used with care. holding this reference while calling `raw_query` - /// will cause a deadlock. - pub(super) fn inner_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { - self.inner.lock() + /// Do a scoped query with a mutable [`diesel::Connection`] + /// reference + pub(crate) fn raw_query_write(&self, fun: F) -> Result + where + F: FnOnce(&mut C) -> Result, + { + let _guard; + // If this connection is not in a transaction + if !self.in_transaction.load(Ordering::SeqCst) { + // Make sure another connection isn't + _guard = self.global_transaction_lock.lock(); + } + fun(&mut self.write.lock()) } } @@ -71,3 +122,13 @@ impl fmt::Debug for DbConnectionPrivate { .finish() } } + +pub struct TransactionGuard<'a> { + in_transaction: Arc, + _mutex_guard: parking_lot::MutexGuard<'a, ()>, +} +impl Drop for TransactionGuard<'_> { + fn drop(&mut self) { + self.in_transaction.store(false, Ordering::SeqCst); + } +} diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index ab365d0f6..08c694095 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -302,7 +302,7 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } else { // Only include the specified states let query = query @@ -315,11 +315,11 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } } else { // Handle the case where `consent_states` is `None` - self.raw_query(|conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -327,7 +327,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = groups_dsl::groups.filter(groups_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(|conn| query.load(conn))?; + let mut sync_groups = self.raw_query_read(|conn| query.load(conn))?; groups.append(&mut sync_groups); } @@ -335,7 +335,7 @@ impl DbConnection { } pub fn consent_records(&self) -> Result, StorageError> { - Ok(self.raw_query(|conn| super::schema::consent_records::table.load(conn))?) + Ok(self.raw_query_read(|conn| super::schema::consent_records::table.load(conn))?) } pub fn all_sync_groups(&self) -> Result, StorageError> { @@ -343,7 +343,7 @@ impl DbConnection { .order(dsl::created_at_ns.desc()) .filter(dsl::conversation_type.eq(ConversationType::Sync)); - Ok(self.raw_query(|conn| query.load(conn))?) + Ok(self.raw_query_read(|conn| query.load(conn))?) } pub fn latest_sync_group(&self) -> Result, StorageError> { @@ -352,16 +352,18 @@ impl DbConnection { .filter(dsl::conversation_type.eq(ConversationType::Sync)) .limit(1); - Ok(self.raw_query(|conn| query.load(conn))?.pop()) + Ok(self.raw_query_read(|conn| query.load(conn))?.pop()) } /// Return a single group that matches the given ID + pub fn find_group(&self, id: &[u8]) -> Result, StorageError> { let query = dsl::groups .order(dsl::created_at_ns.asc()) .limit(1) .filter(dsl::id.eq(id)); - let groups = self.raw_query(|conn| query.load(conn))?; + let groups = self.raw_query_read(|conn| query.load(conn))?; + Ok(groups.into_iter().next()) } @@ -374,7 +376,8 @@ impl DbConnection { .order(dsl::created_at_ns.asc()) .filter(dsl::welcome_id.eq(welcome_id)); - let groups = self.raw_query(|conn| query.load(conn))?; + let groups = self.raw_query_read(|conn| query.load(conn))?; + if groups.len() > 1 { tracing::warn!( welcome_id, @@ -394,7 +397,7 @@ impl DbConnection { .filter(dsl::dm_id.eq(Some(dm_id))) .order(dsl::last_message_ns.desc()); - let groups: Vec = self.raw_query(|conn| query.load(conn))?; + let groups: Vec = self.raw_query_read(|conn| query.load(conn))?; if groups.len() > 1 { tracing::info!("More than one group found for dm_inbox_id {members:?}"); } @@ -408,7 +411,7 @@ impl DbConnection { group_id: GroupId, state: GroupMembershipState, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::update(dsl::groups.find(group_id.as_ref())) .set(dsl::membership_state.eq(state)) .execute(conn) @@ -418,7 +421,7 @@ impl DbConnection { } pub fn get_rotated_at_ns(&self, group_id: Vec) -> Result { - let last_ts: Option = self.raw_query(|conn| { + let last_ts: Option = self.raw_query_read(|conn| { let ts = dsl::groups .find(&group_id) .select(dsl::rotated_at_ns) @@ -434,7 +437,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_rotated_at_ns(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::rotated_at_ns.eq(now)) @@ -445,7 +448,7 @@ impl DbConnection { } pub fn get_installations_time_checked(&self, group_id: Vec) -> Result { - let last_ts = self.raw_query(|conn| { + let last_ts = self.raw_query_read(|conn| { let ts = dsl::groups .find(&group_id) .select(dsl::installations_last_checked) @@ -459,7 +462,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_installations_time_checked(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::installations_last_checked.eq(now)) @@ -474,7 +477,7 @@ impl DbConnection { group_id: Vec, from_ns: Option, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::update(dsl::groups.find(&group_id)) .set(dsl::message_disappear_from_ns.eq(from_ns)) .execute(conn) @@ -488,7 +491,7 @@ impl DbConnection { group_id: Vec, in_ns: Option, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::update(dsl::groups.find(&group_id)) .set(dsl::message_disappear_in_ns.eq(in_ns)) .execute(conn) @@ -499,7 +502,7 @@ impl DbConnection { pub fn insert_or_replace_group(&self, group: StoredGroup) -> Result { tracing::info!("Trying to insert group"); - let stored_group = self.raw_query(|conn| { + let stored_group = self.raw_query_write(|conn| { let maybe_inserted_group: Option = diesel::insert_into(dsl::groups) .values(&group) .on_conflict_do_nothing() @@ -533,7 +536,7 @@ impl DbConnection { /// Get all the welcome ids turned into groups pub(crate) fn group_welcome_ids(&self) -> Result, StorageError> { - self.raw_query(|conn| { + self.raw_query_read(|conn| { Ok::<_, StorageError>( dsl::groups .filter(dsl::welcome_id.is_not_null()) @@ -747,7 +750,7 @@ pub(crate) mod tests { test_group.store(conn).unwrap(); assert_eq!( - conn.raw_query(|raw_conn| groups.first::(raw_conn)) + conn.raw_query_read(|raw_conn| groups.first::(raw_conn)) .unwrap(), test_group ); @@ -761,7 +764,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(|raw_conn| { + conn.raw_query_write(|raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) @@ -941,7 +944,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(|raw_conn| { + conn.raw_query_write(|raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 743eccc5e..a02d99324 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -119,8 +119,9 @@ impl_fetch!(StoredGroupIntent, group_intents, ID); impl Delete for DbConnection { type Key = ID; fn delete(&self, key: ID) -> Result { - Ok(self - .raw_query(|raw_conn| diesel::delete(dsl::group_intents.find(key)).execute(raw_conn))?) + Ok(self.raw_query_write(|raw_conn| { + diesel::delete(dsl::group_intents.find(key)).execute(raw_conn) + })?) } } @@ -155,7 +156,7 @@ impl DbConnection { &self, to_save: NewGroupIntent, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write(|conn| { diesel::insert_into(dsl::group_intents) .values(to_save) .get_result(conn) @@ -184,7 +185,7 @@ impl DbConnection { query = query.order(dsl::id.asc()); - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query_read(|conn| query.load::(conn))?) } // Set the intent with the given ID to `Published` and set the payload hash. Optionally add @@ -197,7 +198,7 @@ impl DbConnection { staged_commit: Option>, published_in_epoch: i64, ) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Published is from @@ -214,7 +215,7 @@ impl DbConnection { })?; if rows_changed == 0 { - let already_published = self.raw_query(|conn| { + let already_published = self.raw_query_read(|conn| { dsl::group_intents .filter(dsl::id.eq(intent_id)) .first::(conn) @@ -231,7 +232,7 @@ impl DbConnection { // Set the intent with the given ID to `Committed` pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Committed is from @@ -252,7 +253,7 @@ impl DbConnection { // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to ToPublish is from @@ -278,7 +279,7 @@ impl DbConnection { /// Set the intent with the given ID to `Error` #[tracing::instrument(level = "trace", skip(self))] pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::state.eq(IntentState::Error)) @@ -298,7 +299,7 @@ impl DbConnection { &self, payload_hash: Vec, ) -> Result, StorageError> { - let result = self.raw_query(|conn| { + let result = self.raw_query_read(|conn| { dsl::group_intents .filter(dsl::payload_hash.eq(payload_hash)) .first::(conn) @@ -312,7 +313,7 @@ impl DbConnection { &self, intent_id: ID, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::publish_attempts.eq(dsl::publish_attempts + 1)) @@ -431,7 +432,7 @@ pub(crate) mod tests { } fn find_first_intent(conn: &DbConnection, group_id: group::ID) -> StoredGroupIntent { - conn.raw_query(|raw_conn| { + conn.raw_query_read(|raw_conn| { dsl::group_intents .filter(dsl::group_id.eq(group_id)) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index 1dcda5674..55be5e087 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -292,7 +292,7 @@ impl DbConnection { query = query.limit(limit); } - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query_read(|conn| query.load::(conn))?) } /// Query for group messages with their reactions @@ -343,7 +343,7 @@ impl DbConnection { }; let reactions: Vec = - self.raw_query(|conn| reactions_query.load(conn))?; + self.raw_query_read(|conn| reactions_query.load(conn))?; // Group reactions by parent message id let mut reactions_by_reference: HashMap, Vec> = HashMap::new(); @@ -379,7 +379,7 @@ impl DbConnection { &self, id: MessageId, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_read(|conn| { dsl::group_messages .filter(dsl::id.eq(id.as_ref())) .first(conn) @@ -392,7 +392,7 @@ impl DbConnection { group_id: GroupId, timestamp: i64, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_read(|conn| { dsl::group_messages .filter(dsl::group_id.eq(group_id.as_ref())) .filter(dsl::sent_at_ns.eq(timestamp)) @@ -406,7 +406,7 @@ impl DbConnection { msg_id: &MessageId, timestamp: u64, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write(|conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set(( @@ -421,7 +421,7 @@ impl DbConnection { &self, msg_id: &MessageId, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write(|conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set((dsl::delivery_status.eq(DeliveryStatus::Failed),)) @@ -430,7 +430,7 @@ impl DbConnection { } pub fn delete_expired_messages(&self) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write(|conn| { use diesel::prelude::*; let disappear_from_ns = groups_dsl::message_disappear_from_ns .assume_not_null() @@ -565,7 +565,7 @@ pub(crate) mod tests { } let count: i64 = conn - .raw_query(|raw_conn| { + .raw_query_read(|raw_conn| { dsl::group_messages .select(diesel::dsl::count_star()) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/identity_update.rs b/xmtp_mls/src/storage/encrypted_store/identity_update.rs index 0a0c87b0a..2c1e8749a 100644 --- a/xmtp_mls/src/storage/encrypted_store/identity_update.rs +++ b/xmtp_mls/src/storage/encrypted_store/identity_update.rs @@ -72,7 +72,7 @@ impl DbConnection { query = query.filter(dsl::sequence_id.le(sequence_id)); } - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query_read(|conn| query.load::(conn))?) } /// Batch insert identity updates, ignoring duplicates. @@ -81,7 +81,7 @@ impl DbConnection { &self, updates: &[StoredIdentityUpdate], ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::insert_or_ignore_into(dsl::identity_updates) .values(updates) .execute(conn)?; @@ -98,7 +98,7 @@ impl DbConnection { .filter(dsl::inbox_id.eq(inbox_id)) .into_boxed(); - Ok(self.raw_query(|conn| query.first::(conn))?) + Ok(self.raw_query_read(|conn| query.first::(conn))?) } /// Given a list of inbox_ids return a HashMap of each inbox ID -> highest known sequence ID @@ -115,7 +115,7 @@ impl DbConnection { // Get the results as a Vec of (inbox_id, sequence_id) tuples let result_tuples: Vec<(String, i64)> = self - .raw_query(|conn| query.load::<(String, Option)>(conn))? + .raw_query_read(|conn| query.load::<(String, Option)>(conn))? .into_iter() // Diesel needs an Option type for aggregations like max(sequence_id), so we // unwrap the option here diff --git a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs index b32975370..a56d20f9a 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs @@ -39,7 +39,7 @@ impl DbConnection { &self, hash_ref: Vec, ) -> Result { - let result = self.raw_query(|conn| { + let result = self.raw_query_read(|conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::key_package_hash_ref.eq(hash_ref)) .first::(conn) @@ -52,7 +52,7 @@ impl DbConnection { &self, id: i32, ) -> Result, StorageError> { - let result = self.raw_query(|conn| { + let result = self.raw_query_read(|conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)) .load::(conn) @@ -65,7 +65,7 @@ impl DbConnection { &self, id: i32, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::delete( key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)), @@ -77,7 +77,7 @@ impl DbConnection { } pub fn delete_key_package_entry_with_id(&self, id: i32) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::delete( key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.eq(id)), diff --git a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs index 7fc4deb3d..44f1027c9 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs @@ -18,7 +18,7 @@ impl Delete for DbConnection { type Key = Vec; fn delete(&self, key: Vec) -> Result where { use super::schema::openmls_key_store::dsl::*; - Ok(self.raw_query(|conn| { + Ok(self.raw_query_write(|conn| { diesel::delete(openmls_key_store.filter(key_bytes.eq(key))).execute(conn) })?) } @@ -36,7 +36,7 @@ impl DbConnection { value_bytes: value, }; - self.raw_query(|conn| { + self.raw_query_write(|conn| { diesel::replace_into(openmls_key_store) .values(entry) .execute(conn) diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index f252d96a1..a615ed102 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -178,7 +178,7 @@ pub mod private { #[tracing::instrument(level = "trace", skip_all)] pub(super) fn init_db(&mut self) -> Result<(), StorageError> { self.db.validate(&self.opts)?; - self.db.conn()?.raw_query(|conn| { + self.db.conn()?.raw_query_write(|conn| { conn.batch_execute("PRAGMA journal_mode = WAL;")?; tracing::info!("Running DB migrations"); conn.run_pending_migrations(MIGRATIONS)?; @@ -241,7 +241,7 @@ macro_rules! impl_fetch { type Key = (); fn fetch(&self, _key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(|conn| $table.first(conn).optional())?) + Ok(self.raw_query_read(|conn| $table.first(conn).optional())?) } } }; @@ -253,7 +253,7 @@ macro_rules! impl_fetch { type Key = $key; fn fetch(&self, key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(|conn| $table.find(key.clone()).first(conn).optional())?) + Ok(self.raw_query_read(|conn| $table.find(key.clone()).first(conn).optional())?) } } }; @@ -285,8 +285,9 @@ macro_rules! impl_fetch_list_with_key { keys: &[Self::Key], ) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::{$column, *}; - Ok(self - .raw_query(|conn| $table.filter($column.eq_any(keys)).load::<$model>(conn))?) + Ok(self.raw_query_read(|conn| { + $table.filter($column.eq_any(keys)).load::<$model>(conn) + })?) } } }; @@ -303,7 +304,7 @@ macro_rules! impl_store { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(|conn| { + into.raw_query_write(|conn| { diesel::insert_into($table::table) .values(self) .execute(conn) @@ -325,7 +326,7 @@ macro_rules! impl_store_or_ignore { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(|conn| { + into.raw_query_write(|conn| { diesel::insert_or_ignore_into($table::table) .values(self) .execute(conn) @@ -383,17 +384,13 @@ where E: From + From, { tracing::debug!("Transaction beginning"); - { - let connection = self.conn_ref(); - let mut connection = connection.inner_mut_ref(); - ::TransactionManager::begin_transaction(&mut *connection)?; - } let conn = self.conn_ref(); + let _guard = conn.start_transaction::()?; match fun(self) { Ok(value) => { - conn.raw_query(|conn| { + conn.raw_query_write(|conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; tracing::debug!("Transaction being committed"); @@ -401,7 +398,7 @@ where } Err(err) => { tracing::debug!("Transaction being rolled back"); - match conn.raw_query(|conn| { + match conn.raw_query_write(|conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -551,7 +548,7 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(|conn| { + .raw_query_write(|conn| { for _ in 0..15 { conn.run_next_migration(MIGRATIONS)?; } @@ -597,14 +594,14 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(|conn| { + .raw_query_write(|conn| { conn.run_pending_migrations(MIGRATIONS)?; Ok::<_, StorageError>(()) }) .unwrap(); let groups = conn - .raw_query(|conn| groups::table.load::(conn)) + .raw_query_read(|conn| groups::table.load::(conn)) .unwrap(); assert_eq!(groups.len(), 1); assert_eq!(&**groups[0].dm_id.as_ref().unwrap(), "dm:98765:inbox_id"); @@ -667,78 +664,4 @@ pub(crate) mod tests { } EncryptedMessageStore::remove_db_files(db_path) } - - // get two connections - // start a transaction - // try to write with second connection - // write should fail & rollback - // first thread succeeds - // wasm does not have threads - #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] - #[cfg(not(target_arch = "wasm32"))] - async fn test_transaction_rollback() { - use crate::XmtpOpenMlsProvider; - use std::sync::{Arc, Barrier}; - - let db_path = tmp_path(); - let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path.clone()), - EncryptedMessageStore::generate_enc_key(), - ) - .await - .unwrap(); - - let barrier = Arc::new(Barrier::new(2)); - let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); - let barrier_pointer = barrier.clone(); - let handle = std::thread::spawn(move || { - provider.transaction(|provider| { - let conn1 = provider.conn_ref(); - StoredIdentity::new("correct".to_string(), rand_vec::<24>(), rand_vec::<24>()) - .store(conn1) - .unwrap(); - // wait for second transaction to start - barrier_pointer.wait(); - // wait for second transaction to finish - barrier_pointer.wait(); - Ok::<_, StorageError>(()) - }) - }); - - let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); - let handle2 = std::thread::spawn(move || { - barrier.wait(); - let result = provider.transaction(|provider| -> Result<(), anyhow::Error> { - let connection = provider.conn_ref(); - let group = StoredGroup::new( - b"should not exist".to_vec(), - 0, - GroupMembershipState::Allowed, - "goodbye".to_string(), - None, - ); - group.store(connection)?; - Ok(()) - }); - barrier.wait(); - result - }); - - let result = handle.join().unwrap(); - assert!(result.is_ok()); - - let result = handle2.join().unwrap(); - - // handle 2 errored because the first transaction has precedence - assert_eq!( - result.unwrap_err().to_string(), - "Diesel result error: database is locked" - ); - let groups = store - .conn() - .unwrap() - .find_group(b"should not exist") - .unwrap(); - assert_eq!(groups, None); - } } diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 635b1b4c7..bf3aab9cc 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -7,7 +7,7 @@ use diesel::{ r2d2::{self, CustomizeConnection, PoolTransactionManager, PooledConnection}, Connection, }; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use std::sync::Arc; pub type ConnectionManager = r2d2::ConnectionManager; @@ -64,7 +64,7 @@ impl ValidatedConnection for UnencryptedConnection {} impl CustomizeConnection for UnencryptedConnection { fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), r2d2::Error> { - conn.batch_execute("PRAGMA busy_timeout = 5000;") + conn.batch_execute("PRAGMA query_only = ON; PRAGMA busy_timeout = 5000;") .map_err(r2d2::Error::QueryError)?; Ok(()) } @@ -89,10 +89,12 @@ impl StorageOption { } } -#[derive(Clone, Debug)] +#[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { pub(super) pool: Arc>>, + pub(super) write_conn: Arc>, + transaction_lock: Arc>, customizer: Option>, opts: StorageOption, } @@ -106,9 +108,9 @@ impl NativeDb { let mut builder = Pool::builder(); let customizer = if let Some(key) = enc_key { - let enc_opts = EncryptedConnection::new(key, opts)?; - builder = builder.connection_customizer(Box::new(enc_opts.clone())); - Some(Box::new(enc_opts) as Box) + let enc_connection = EncryptedConnection::new(key, opts)?; + builder = builder.connection_customizer(Box::new(enc_connection.clone())); + Some(Box::new(enc_connection) as Box) } else if matches!(opts, StorageOption::Persistent(_)) { builder = builder.connection_customizer(Box::new(UnencryptedConnection)); Some(Box::new(UnencryptedConnection) as Box) @@ -125,8 +127,14 @@ impl NativeDb { .build(ConnectionManager::new(path))?, }; + // Take one of the connections and use it as the only writer. + let mut write_conn = pool.get()?; + write_conn.batch_execute("PRAGMA query_only = OFF;")?; + Ok(Self { pool: Arc::new(Some(pool).into()), + write_conn: Arc::new(Mutex::new(write_conn)), + transaction_lock: Arc::new(Mutex::new(())), customizer, opts: opts.clone(), }) @@ -155,10 +163,16 @@ impl XmtpDb for NativeDb { /// Returns the Wrapped [`super::db_connection::DbConnection`] Connection implementation for this Database fn conn(&self) -> Result, StorageError> { - let conn = self.raw_conn()?; - Ok(DbConnectionPrivate::from_arc_mutex(Arc::new( - parking_lot::Mutex::new(conn), - ))) + let conn = match self.opts { + StorageOption::Ephemeral => None, + StorageOption::Persistent(_) => Some(self.raw_conn()?), + }; + + Ok(DbConnectionPrivate::from_arc_mutex( + self.write_conn.clone(), + conn.map(|conn| Arc::new(parking_lot::Mutex::new(conn))), + self.transaction_lock.clone(), + )) } fn validate(&self, opts: &StorageOption) -> Result<(), StorageError> { diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index 44f584232..248248c8d 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -74,7 +74,7 @@ impl DbConnection { entity_kind: EntityKind, ) -> Result, StorageError> { use super::schema::refresh_state::dsl; - let res = self.raw_query(|conn| { + let res = self.raw_query_read(|conn| { dsl::refresh_state .find((entity_id.as_ref(), entity_kind)) .first(conn) @@ -115,7 +115,7 @@ impl DbConnection { NotFound::RefreshStateByIdAndKind(entity_id.as_ref().to_vec(), entity_kind), )?; - let num_updated = self.raw_query(|conn| { + let num_updated = self.raw_query_write(|conn| { diesel::update(&state) .filter(dsl::cursor.lt(cursor)) .set(dsl::cursor.eq(cursor)) diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 6723f0df1..184a0c4f9 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -281,6 +281,7 @@ impl diesel::r2d2::CustomizeConnection fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), diesel::r2d2::Error> { conn.batch_execute(&format!( "{} + PRAGMA query_only = ON; PRAGMA busy_timeout = 5000;", self.pragmas() )) diff --git a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs index ef273fe76..4a9be10c9 100644 --- a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs +++ b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs @@ -37,7 +37,7 @@ impl<'a> From<&'a StoredUserPreferences> for NewStoredUserPreferences<'a> { impl Store for StoredUserPreferences { fn store(&self, conn: &DbConnection) -> Result<(), StorageError> { - conn.raw_query(|conn| { + conn.raw_query_write(|conn| { diesel::update(dsl::user_preferences) .set(self) .execute(conn) @@ -50,7 +50,7 @@ impl Store for StoredUserPreferences { impl StoredUserPreferences { pub fn load(conn: &DbConnection) -> Result { let query = dsl::user_preferences.order(dsl::id.desc()).limit(1); - let mut result = conn.raw_query(|conn| query.load::(conn))?; + let mut result = conn.raw_query_read(|conn| query.load::(conn))?; Ok(result.pop().unwrap_or_default()) } @@ -73,7 +73,7 @@ impl StoredUserPreferences { ])); let to_insert: NewStoredUserPreferences = (&preferences).into(); - conn.raw_query(|conn| { + conn.raw_query_write(|conn| { diesel::insert_into(dsl::user_preferences) .values(to_insert) .execute(conn) @@ -115,7 +115,7 @@ mod tests { // check that there are two preferences stored let query = dsl::user_preferences.order(dsl::id.desc()); let result = conn - .raw_query(|conn| query.load::(conn)) + .raw_query_read(|conn| query.load::(conn)) .unwrap(); assert_eq!(result.len(), 1); } diff --git a/xmtp_mls/src/storage/encrypted_store/wasm.rs b/xmtp_mls/src/storage/encrypted_store/wasm.rs index 3cc984fde..558dc88f6 100644 --- a/xmtp_mls/src/storage/encrypted_store/wasm.rs +++ b/xmtp_mls/src/storage/encrypted_store/wasm.rs @@ -11,6 +11,7 @@ use super::{db_connection::DbConnectionPrivate, StorageError, StorageOption, Xmt pub struct WasmDb { conn: Arc>, opts: StorageOption, + transaction_lock: Arc>, } impl std::fmt::Debug for WasmDb { @@ -33,6 +34,7 @@ impl WasmDb { Ok(Self { conn: Arc::new(Mutex::new(conn)), opts: opts.clone(), + transaction_lock: Arc::new(Mutex::new(())), }) } } @@ -42,7 +44,11 @@ impl XmtpDb for WasmDb { type TransactionManager = AnsiTransactionManager; fn conn(&self) -> Result, StorageError> { - Ok(DbConnectionPrivate::from_arc_mutex(self.conn.clone())) + Ok(DbConnectionPrivate::from_arc_mutex( + self.conn.clone(), + None, + self.transaction_lock.clone(), + )) } fn validate(&self, _opts: &StorageOption) -> Result<(), StorageError> { diff --git a/xmtp_mls/src/storage/mod.rs b/xmtp_mls/src/storage/mod.rs index 99c4f1689..9de05523e 100644 --- a/xmtp_mls/src/storage/mod.rs +++ b/xmtp_mls/src/storage/mod.rs @@ -11,13 +11,13 @@ pub use errors::*; impl DbConnection { #[allow(unused)] pub(crate) fn enable_readonly(&self) -> Result<(), StorageError> { - self.raw_query(|conn| conn.batch_execute("PRAGMA query_only = ON;"))?; + self.raw_query_write(|conn| conn.batch_execute("PRAGMA query_only = ON;"))?; Ok(()) } #[allow(unused)] pub(crate) fn disable_readonly(&self) -> Result<(), StorageError> { - self.raw_query(|conn| conn.batch_execute("PRAGMA query_only = OFF;"))?; + self.raw_query_write(|conn| conn.batch_execute("PRAGMA query_only = OFF;"))?; Ok(()) } } @@ -72,12 +72,12 @@ pub mod test_util { for query in queries { let query = diesel::sql_query(query); - let _ = self.raw_query(|conn| query.execute(conn)).unwrap(); + let _ = self.raw_query_write(|conn| query.execute(conn)).unwrap(); } } pub fn intents_published(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query( "SELECT intents_published FROM test_metadata WHERE rowid = 1", @@ -93,7 +93,7 @@ pub mod test_util { } pub fn intents_deleted(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query("SELECT intents_deleted FROM test_metadata")) .unwrap(); @@ -107,7 +107,7 @@ pub mod test_util { } pub fn intents_created(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query("SELECT intents_created FROM test_metadata")) .unwrap(); diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 7fc9ed1c5..c9c2b40df 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -49,7 +49,7 @@ where &self, storage_key: &Vec, ) -> Result, diesel::result::Error> { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query_read(|conn| { sql_query(SELECT_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -62,7 +62,7 @@ where storage_key: &Vec, value: &[u8], ) -> Result { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query_write(|conn| { sql_query(REPLACE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -76,7 +76,7 @@ where storage_key: &Vec, modified_data: &Vec, ) -> Result { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query_write(|conn| { sql_query(UPDATE_QUERY) .bind::(&modified_data) .bind::(&storage_key) @@ -224,7 +224,7 @@ where ) -> Result<(), >::Error> { let storage_key = build_key_from_vec::(label, key.to_vec()); - let _ = self.conn_ref().raw_query(|conn| { + let _ = self.conn_ref().raw_query_write(|conn| { sql_query(DELETE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -809,7 +809,7 @@ where let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; - let data: Vec = self.conn_ref().raw_query(|conn| { + let data: Vec = self.conn_ref().raw_query_read(|conn| { sql_query(query) .bind::(&storage_key) .bind::(CURRENT_VERSION as i32)