From 4dca047b62c0266e872d2d7be41e14f6bb0e92ea Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 13:29:25 -0500 Subject: [PATCH 01/38] wip --- xmtp_mls/src/builder.rs | 61 ++++++++++++---- xmtp_mls/src/groups/mod.rs | 4 +- .../storage/encrypted_store/db_connection.rs | 23 ++++--- xmtp_mls/src/storage/encrypted_store/mod.rs | 69 ++++++++++++++----- .../src/storage/encrypted_store/native.rs | 15 ++-- .../encrypted_store/sqlcipher_connection.rs | 54 +++++++++++---- xmtp_mls/src/storage/sql_key_store.rs | 15 +++- xmtp_mls/src/utils/test/mod.rs | 5 +- 8 files changed, 182 insertions(+), 64 deletions(-) diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index d767f7573..dcf785b01 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -464,7 +464,10 @@ pub(crate) mod tests { Some(legacy_key.clone()), ); let store = EncryptedMessageStore::new( - StorageOption::Persistent(tmp_path()), + StorageOption::Persistent { + path: tmp_path(), + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -533,7 +536,10 @@ pub(crate) mod tests { let scw_verifier = MockSmartContractSignatureVerifier::new(true); let store = EncryptedMessageStore::new( - StorageOption::Persistent(tmpdb), + StorageOption::Persistent { + path: tmpdb, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -574,7 +580,10 @@ pub(crate) mod tests { let scw_verifier = MockSmartContractSignatureVerifier::new(true); let store = EncryptedMessageStore::new( - StorageOption::Persistent(tmpdb), + StorageOption::Persistent { + path: tmpdb, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -614,7 +623,10 @@ pub(crate) mod tests { let scw_verifier = MockSmartContractSignatureVerifier::new(true); let store = EncryptedMessageStore::new( - StorageOption::Persistent(tmpdb), + StorageOption::Persistent { + path: tmpdb, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -654,7 +666,10 @@ pub(crate) mod tests { let tmpdb = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent(tmpdb), + StorageOption::Persistent { + path: tmpdb, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -694,9 +709,15 @@ pub(crate) mod tests { let db_key = EncryptedMessageStore::generate_enc_key(); // Generate a new Wallet + Store - let store_a = EncryptedMessageStore::new(StorageOption::Persistent(tmpdb.clone()), db_key) - .await - .unwrap(); + let store_a = EncryptedMessageStore::new( + StorageOption::Persistent { + path: tmpdb.clone(), + read_only: false, + }, + db_key, + ) + .await + .unwrap(); let nonce = 1; let inbox_id = generate_inbox_id(&wallet.get_address(), &nonce).unwrap(); @@ -720,9 +741,15 @@ pub(crate) mod tests { drop(client_a); // Reload the existing store and wallet - let store_b = EncryptedMessageStore::new(StorageOption::Persistent(tmpdb.clone()), db_key) - .await - .unwrap(); + let store_b = EncryptedMessageStore::new( + StorageOption::Persistent { + path: tmpdb.clone(), + read_only: false, + }, + db_key, + ) + .await + .unwrap(); let client_b = Client::builder(IdentityStrategy::new( inbox_id, @@ -759,9 +786,15 @@ pub(crate) mod tests { // .expect_err("Testing expected mismatch error"); // Use cached only strategy - let store_d = EncryptedMessageStore::new(StorageOption::Persistent(tmpdb.clone()), db_key) - .await - .unwrap(); + let store_d = EncryptedMessageStore::new( + StorageOption::Persistent { + path: tmpdb.clone(), + read_only: false, + }, + db_key, + ) + .await + .unwrap(); let client_d = Client::builder(IdentityStrategy::CachedOnly) .api_client(::create_local().await) .store(store_d) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 5a04373d2..59f1d0541 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1250,13 +1250,13 @@ impl MlsGroup { state, hex::encode(self.group_id.clone()), ); - let new_records = conn + let new_records: Vec<_> = conn .insert_or_replace_consent_records(&[consent_record.clone()])? .into_iter() .map(UserPreferenceUpdate::ConsentUpdate) .collect(); - if self.client.history_sync_url().is_some() { + if !new_records.is_empty() && self.client.history_sync_url().is_some() { // Dispatch an update event so it can be synced across devices let _ = self .client diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 045b897a9..dd76ce9c6 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -19,14 +19,15 @@ pub type DbConnection = DbConnectionPrivate { - inner: Arc>, + write: Arc>, + read: Arc>, } /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(conn: Arc>) -> Self { - Self { inner: conn } + pub(super) fn from_arc_mutex(read: Arc>, write: Arc>) -> Self { + Self { read, write } } } @@ -40,7 +41,7 @@ where where F: FnOnce(&mut C) -> Result, { - let mut lock = self.inner.lock(); + let mut lock = self.read.lock(); fun(&mut lock) } @@ -48,14 +49,20 @@ where /// without a scope /// Must be used with care. holding this reference while calling `raw_query` /// will cause a deadlock. - pub(super) fn inner_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { - self.inner.lock() + pub(super) fn read_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { + self.read.lock() } /// Internal-only API to get the underlying `diesel::Connection` reference /// without a scope - pub(super) fn inner_ref(&self) -> Arc> { - self.inner.clone() + pub(super) fn read_ref(&self) -> Arc> { + self.read.clone() + } + + /// Internal-only API to get the underlying `diesel::Connection` reference + /// without a scope + pub(super) fn write_ref(&self) -> Arc> { + self.write.clone() } } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 9a96252a6..8cda6e602 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -74,7 +74,10 @@ struct SqliteVersion { pub enum StorageOption { #[default] Ephemeral, - Persistent(String), + Persistent { + path: String, + read_only: bool, + }, } #[allow(async_fn_in_trait)] @@ -393,7 +396,7 @@ where tracing::debug!("Transaction beginning"); { let connection = self.conn_ref(); - let mut connection = connection.inner_mut_ref(); + let mut connection = connection.read_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; } @@ -444,13 +447,13 @@ where tracing::debug!("Transaction async beginning"); { let connection = self.conn_ref(); - let mut connection = connection.inner_mut_ref(); + let mut connection = connection.read_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; } // ensuring we have only one strong reference let result = fun(self).await; - let local_connection = self.conn_ref().inner_ref(); + let local_connection = self.conn_ref().read_ref(); if Arc::strong_count(&local_connection) > 1 { tracing::warn!( "More than 1 strong connection references still exist during async transaction" @@ -524,7 +527,10 @@ pub(crate) mod tests { pub async fn new_test() -> Self { let tmp_path = tmp_path(); EncryptedMessageStore::new( - StorageOption::Persistent(tmp_path), + StorageOption::Persistent { + path: tmp_path, + read_only: true, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -556,7 +562,10 @@ pub(crate) mod tests { let db_path = tmp_path(); { let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path.clone()), + StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -579,7 +588,10 @@ pub(crate) mod tests { let db_path = tmp_path(); { let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path.clone()), + StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -609,7 +621,10 @@ pub(crate) mod tests { #[wasm_bindgen_test::wasm_bindgen_test(unsupported = tokio::test)] async fn test_dm_id_migration() { let db_path = tmp_path(); - let opts = StorageOption::Persistent(db_path.clone()); + let opts = StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }; #[cfg(not(target_arch = "wasm32"))] let db = @@ -690,10 +705,15 @@ pub(crate) mod tests { let db_path = tmp_path(); { // Setup a persistent store - let store = - EncryptedMessageStore::new(StorageOption::Persistent(db_path.clone()), enc_key) - .await - .unwrap(); + let store = EncryptedMessageStore::new( + StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }, + enc_key, + ) + .await + .unwrap(); StoredIdentity::new( "dummy_address".to_string(), @@ -705,8 +725,14 @@ pub(crate) mod tests { } // Drop it enc_key[3] = 145; // Alter the enc_key - let res = - EncryptedMessageStore::new(StorageOption::Persistent(db_path.clone()), enc_key).await; + let res = EncryptedMessageStore::new( + StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }, + enc_key, + ) + .await; // Ensure it fails assert!( @@ -721,7 +747,10 @@ pub(crate) mod tests { let db_path = tmp_path(); { let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path.clone()), + StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -754,7 +783,10 @@ pub(crate) mod tests { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path.clone()), + StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -820,7 +852,10 @@ pub(crate) mod tests { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path.clone()), + StorageOption::Persistent { + path: db_path.clone(), + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 635b1b4c7..b0cf81ce6 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -9,6 +9,7 @@ use diesel::{ }; use parking_lot::RwLock; use std::sync::Arc; +use tokio::sync::Mutex; pub type ConnectionManager = r2d2::ConnectionManager; pub type Pool = r2d2::Pool; @@ -75,7 +76,7 @@ impl StorageOption { pub(super) fn conn(&self) -> Result { use StorageOption::*; match self { - Persistent(path) => SqliteConnection::establish(path), + Persistent { path, .. } => SqliteConnection::establish(path), Ephemeral => SqliteConnection::establish(":memory:"), } } @@ -83,15 +84,16 @@ impl StorageOption { pub(super) fn path(&self) -> Option<&String> { use StorageOption::*; match self { - Persistent(path) => Some(path), + Persistent { path, .. } => Some(path), _ => None, } } } -#[derive(Clone, Debug)] +#[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { + pub(super) write_conn: Arc>>, pub(super) pool: Arc>>, customizer: Option>, opts: StorageOption, @@ -109,7 +111,7 @@ impl NativeDb { let enc_opts = EncryptedConnection::new(key, opts)?; builder = builder.connection_customizer(Box::new(enc_opts.clone())); Some(Box::new(enc_opts) as Box) - } else if matches!(opts, StorageOption::Persistent(_)) { + } else if matches!(opts, StorageOption::Persistent { .. }) { builder = builder.connection_customizer(Box::new(UnencryptedConnection)); Some(Box::new(UnencryptedConnection) as Box) } else { @@ -120,12 +122,13 @@ impl NativeDb { StorageOption::Ephemeral => builder .max_size(1) .build(ConnectionManager::new(":memory:"))?, - StorageOption::Persistent(ref path) => builder + StorageOption::Persistent { ref path, .. } => builder .max_size(crate::configuration::MAX_DB_POOL_SIZE) .build(ConnectionManager::new(path))?, }; Ok(Self { + write_conn: pool: Arc::new(Some(pool).into()), customizer, opts: opts.clone(), @@ -186,7 +189,7 @@ impl XmtpDb for NativeDb { StorageOption::Ephemeral => builder .max_size(1) .build(ConnectionManager::new(":memory:"))?, - StorageOption::Persistent(ref path) => builder + StorageOption::Persistent { ref path, .. } => builder .max_size(crate::configuration::MAX_DB_POOL_SIZE) .build(ConnectionManager::new(path))?, }; diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 6723f0df1..6c7730b32 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -40,6 +40,7 @@ pub struct EncryptedConnection { key: EncryptionKey, /// We don't store the salt for Ephemeral Dbs salt: Option, + read_only: bool, } impl EncryptedConnection { @@ -48,9 +49,12 @@ impl EncryptedConnection { use super::StorageOption::*; Self::check_for_sqlcipher(opts)?; - let salt = match opts { - Ephemeral => None, - Persistent(ref db_path) => { + let (salt, read_only) = match opts { + Ephemeral => (None, false), + Persistent { + path: ref db_path, + read_only, + } => { let mut salt = [0u8; 16]; let db_pathbuf = PathBuf::from(db_path); let salt_path = Self::salt_file(db_path)?; @@ -86,7 +90,7 @@ impl EncryptedConnection { db_pathbuf.display(), salt_path.display() ); - Self::create(db_path, key, &mut salt)?; + Self::create(db_path, key, &mut salt, *read_only)?; } // the db doesn't exist but the salt does // This generally doesn't make sense & shouldn't happen. @@ -98,19 +102,28 @@ impl EncryptedConnection { salt_path.display(), ); std::fs::remove_file(salt_path)?; - Self::create(db_path, key, &mut salt)?; + Self::create(db_path, key, &mut salt, *read_only)?; } } - Some(salt) + (Some(salt), *read_only) } }; - Ok(Self { key, salt }) + Ok(Self { + key, + salt, + read_only, + }) } /// create a new database + salt file. /// writes the 16-bytes hex-encoded salt to `salt` - fn create(path: &String, key: EncryptionKey, salt: &mut [u8]) -> Result<(), StorageError> { + fn create( + path: &String, + key: EncryptionKey, + salt: &mut [u8], + read_only: bool, + ) -> Result<(), StorageError> { let conn = &mut SqliteConnection::establish(path)?; conn.batch_execute(&format!( r#" @@ -122,6 +135,10 @@ impl EncryptedConnection { pragma_plaintext_header() ))?; + if read_only { + conn.batch_execute("PRAGMA query_only = ON;")?; + } + Self::write_salt(path, conn, salt)?; Ok(()) } @@ -204,7 +221,9 @@ impl EncryptedConnection { /// Output the corect order of PRAGMAS to instantiate a connection fn pragmas(&self) -> impl Display { - let Self { ref key, ref salt } = self; + let Self { + ref key, ref salt, .. + } = self; if let Some(s) = salt { format!( @@ -318,7 +337,10 @@ mod tests { let db_path = tmp_path(); { let _ = EncryptedMessageStore::new( - Persistent(db_path.clone()), + Persistent { + path: db_path.clone(), + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -367,9 +389,15 @@ mod tests { file.read_exact(&mut plaintext_header).unwrap(); assert!(String::from_utf8_lossy(&plaintext_header) != SQLITE3_PLAINTEXT_HEADER); - let _ = EncryptedMessageStore::new(Persistent(db_path.clone()), key) - .await - .unwrap(); + let _ = EncryptedMessageStore::new( + Persistent { + path: db_path.clone(), + read_only: false, + }, + key, + ) + .await + .unwrap(); assert!(EncryptedConnection::salt_file(&db_path).unwrap().exists()); let bytes = std::fs::read(EncryptedConnection::salt_file(&db_path).unwrap()).unwrap(); diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 7fc9ed1c5..bdc5076cf 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -1052,7 +1052,10 @@ pub(crate) mod tests { async fn store_read_delete() { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path), + StorageOption::Persistent { + path: db_path, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -1103,7 +1106,10 @@ pub(crate) mod tests { async fn list_append_remove() { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path), + StorageOption::Persistent { + path: db_path, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await @@ -1187,7 +1193,10 @@ pub(crate) mod tests { async fn group_state() { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path), + StorageOption::Persistent { + path: db_path, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await diff --git a/xmtp_mls/src/utils/test/mod.rs b/xmtp_mls/src/utils/test/mod.rs index 636467fc0..7b04abf85 100755 --- a/xmtp_mls/src/utils/test/mod.rs +++ b/xmtp_mls/src/utils/test/mod.rs @@ -66,7 +66,10 @@ impl ClientBuilder { let tmpdb = xmtp_common::tmp_path(); self.store( EncryptedMessageStore::new( - StorageOption::Persistent(tmpdb), + StorageOption::Persistent { + path: tmpdb, + read_only: false, + }, EncryptedMessageStore::generate_enc_key(), ) .await From d1dc212cfeddf3bb4c2c1a190b2d632d2c65e544 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 14:22:35 -0500 Subject: [PATCH 02/38] cleanup --- xmtp_mls/src/builder.rs | 61 ++++------------ xmtp_mls/src/storage/encrypted_store/mod.rs | 73 +++++-------------- .../src/storage/encrypted_store/native.rs | 18 ++--- .../encrypted_store/sqlcipher_connection.rs | 51 ++++--------- xmtp_mls/src/storage/sql_key_store.rs | 15 +--- xmtp_mls/src/utils/test/mod.rs | 5 +- 6 files changed, 60 insertions(+), 163 deletions(-) diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index dcf785b01..d767f7573 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -464,10 +464,7 @@ pub(crate) mod tests { Some(legacy_key.clone()), ); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmp_path(), - read_only: false, - }, + StorageOption::Persistent(tmp_path()), EncryptedMessageStore::generate_enc_key(), ) .await @@ -536,10 +533,7 @@ pub(crate) mod tests { let scw_verifier = MockSmartContractSignatureVerifier::new(true); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb, - read_only: false, - }, + StorageOption::Persistent(tmpdb), EncryptedMessageStore::generate_enc_key(), ) .await @@ -580,10 +574,7 @@ pub(crate) mod tests { let scw_verifier = MockSmartContractSignatureVerifier::new(true); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb, - read_only: false, - }, + StorageOption::Persistent(tmpdb), EncryptedMessageStore::generate_enc_key(), ) .await @@ -623,10 +614,7 @@ pub(crate) mod tests { let scw_verifier = MockSmartContractSignatureVerifier::new(true); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb, - read_only: false, - }, + StorageOption::Persistent(tmpdb), EncryptedMessageStore::generate_enc_key(), ) .await @@ -666,10 +654,7 @@ pub(crate) mod tests { let tmpdb = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb, - read_only: false, - }, + StorageOption::Persistent(tmpdb), EncryptedMessageStore::generate_enc_key(), ) .await @@ -709,15 +694,9 @@ pub(crate) mod tests { let db_key = EncryptedMessageStore::generate_enc_key(); // Generate a new Wallet + Store - let store_a = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb.clone(), - read_only: false, - }, - db_key, - ) - .await - .unwrap(); + let store_a = EncryptedMessageStore::new(StorageOption::Persistent(tmpdb.clone()), db_key) + .await + .unwrap(); let nonce = 1; let inbox_id = generate_inbox_id(&wallet.get_address(), &nonce).unwrap(); @@ -741,15 +720,9 @@ pub(crate) mod tests { drop(client_a); // Reload the existing store and wallet - let store_b = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb.clone(), - read_only: false, - }, - db_key, - ) - .await - .unwrap(); + let store_b = EncryptedMessageStore::new(StorageOption::Persistent(tmpdb.clone()), db_key) + .await + .unwrap(); let client_b = Client::builder(IdentityStrategy::new( inbox_id, @@ -786,15 +759,9 @@ pub(crate) mod tests { // .expect_err("Testing expected mismatch error"); // Use cached only strategy - let store_d = EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb.clone(), - read_only: false, - }, - db_key, - ) - .await - .unwrap(); + let store_d = EncryptedMessageStore::new(StorageOption::Persistent(tmpdb.clone()), db_key) + .await + .unwrap(); let client_d = Client::builder(IdentityStrategy::CachedOnly) .api_client(::create_local().await) .store(store_d) diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 8cda6e602..1d4f58fe8 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -74,10 +74,7 @@ struct SqliteVersion { pub enum StorageOption { #[default] Ephemeral, - Persistent { - path: String, - read_only: bool, - }, + Persistent(String), } #[allow(async_fn_in_trait)] @@ -453,20 +450,22 @@ where // ensuring we have only one strong reference let result = fun(self).await; - let local_connection = self.conn_ref().read_ref(); - if Arc::strong_count(&local_connection) > 1 { + let local_read_connection = self.conn_ref().read_ref(); + let local_write_connection = self.conn_ref().write_ref(); + if Arc::strong_count(&local_read_connection) > 1 { tracing::warn!( "More than 1 strong connection references still exist during async transaction" ); } - if Arc::weak_count(&local_connection) > 1 { + if Arc::weak_count(&local_read_connection) > 1 { tracing::warn!("More than 1 weak connection references still exist during transaction"); } // after the closure finishes, `local_provider` should have the only reference ('strong') // to `XmtpOpenMlsProvider` inner `DbConnection`.. - let local_connection = DbConnectionPrivate::from_arc_mutex(local_connection); + let local_connection = + DbConnectionPrivate::from_arc_mutex(local_read_connection, local_write_connection); match result { Ok(value) => { local_connection.raw_query(|conn| { @@ -527,10 +526,7 @@ pub(crate) mod tests { pub async fn new_test() -> Self { let tmp_path = tmp_path(); EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmp_path, - read_only: true, - }, + StorageOption::Persistent(tmp_path), EncryptedMessageStore::generate_enc_key(), ) .await @@ -562,10 +558,7 @@ pub(crate) mod tests { let db_path = tmp_path(); { let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }, + StorageOption::Persistent(db_path.clone()), EncryptedMessageStore::generate_enc_key(), ) .await @@ -588,10 +581,7 @@ pub(crate) mod tests { let db_path = tmp_path(); { let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }, + StorageOption::Persistent(db_path.clone()), EncryptedMessageStore::generate_enc_key(), ) .await @@ -621,10 +611,7 @@ pub(crate) mod tests { #[wasm_bindgen_test::wasm_bindgen_test(unsupported = tokio::test)] async fn test_dm_id_migration() { let db_path = tmp_path(); - let opts = StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }; + let opts = StorageOption::Persistent(db_path.clone()); #[cfg(not(target_arch = "wasm32"))] let db = @@ -705,15 +692,10 @@ pub(crate) mod tests { let db_path = tmp_path(); { // Setup a persistent store - let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }, - enc_key, - ) - .await - .unwrap(); + let store = + EncryptedMessageStore::new(StorageOption::Persistent(db_path.clone()), enc_key) + .await + .unwrap(); StoredIdentity::new( "dummy_address".to_string(), @@ -725,14 +707,8 @@ pub(crate) mod tests { } // Drop it enc_key[3] = 145; // Alter the enc_key - let res = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }, - enc_key, - ) - .await; + let res = + EncryptedMessageStore::new(StorageOption::Persistent(db_path.clone()), enc_key).await; // Ensure it fails assert!( @@ -747,10 +723,7 @@ pub(crate) mod tests { let db_path = tmp_path(); { let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }, + StorageOption::Persistent(db_path.clone()), EncryptedMessageStore::generate_enc_key(), ) .await @@ -783,10 +756,7 @@ pub(crate) mod tests { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }, + StorageOption::Persistent(db_path.clone()), EncryptedMessageStore::generate_enc_key(), ) .await @@ -852,10 +822,7 @@ pub(crate) mod tests { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path.clone(), - read_only: false, - }, + StorageOption::Persistent(db_path.clone()), EncryptedMessageStore::generate_enc_key(), ) .await diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index b0cf81ce6..65c9d090a 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -76,7 +76,7 @@ impl StorageOption { pub(super) fn conn(&self) -> Result { use StorageOption::*; match self { - Persistent { path, .. } => SqliteConnection::establish(path), + Persistent(path) => SqliteConnection::establish(path), Ephemeral => SqliteConnection::establish(":memory:"), } } @@ -84,7 +84,7 @@ impl StorageOption { pub(super) fn path(&self) -> Option<&String> { use StorageOption::*; match self { - Persistent { path, .. } => Some(path), + Persistent(path) => Some(path), _ => None, } } @@ -93,7 +93,7 @@ impl StorageOption { #[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { - pub(super) write_conn: Arc>>, + pub(super) write_conn: Arc>, pub(super) pool: Arc>>, customizer: Option>, opts: StorageOption, @@ -108,9 +108,9 @@ impl NativeDb { let mut builder = Pool::builder(); let customizer = if let Some(key) = enc_key { - let enc_opts = EncryptedConnection::new(key, opts)?; - builder = builder.connection_customizer(Box::new(enc_opts.clone())); - Some(Box::new(enc_opts) as Box) + let enc_connection = EncryptedConnection::new(key, opts)?; + builder = builder.connection_customizer(Box::new(enc_connection.clone())); + Some(Box::new(enc_connection) as Box) } else if matches!(opts, StorageOption::Persistent { .. }) { builder = builder.connection_customizer(Box::new(UnencryptedConnection)); Some(Box::new(UnencryptedConnection) as Box) @@ -122,13 +122,13 @@ impl NativeDb { StorageOption::Ephemeral => builder .max_size(1) .build(ConnectionManager::new(":memory:"))?, - StorageOption::Persistent { ref path, .. } => builder + StorageOption::Persistent(ref path) => builder .max_size(crate::configuration::MAX_DB_POOL_SIZE) .build(ConnectionManager::new(path))?, }; Ok(Self { - write_conn: + // write_conn: pool: Arc::new(Some(pool).into()), customizer, opts: opts.clone(), @@ -189,7 +189,7 @@ impl XmtpDb for NativeDb { StorageOption::Ephemeral => builder .max_size(1) .build(ConnectionManager::new(":memory:"))?, - StorageOption::Persistent { ref path, .. } => builder + StorageOption::Persistent(ref path) => builder .max_size(crate::configuration::MAX_DB_POOL_SIZE) .build(ConnectionManager::new(path))?, }; diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 6c7730b32..07398a0bc 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -40,7 +40,6 @@ pub struct EncryptedConnection { key: EncryptionKey, /// We don't store the salt for Ephemeral Dbs salt: Option, - read_only: bool, } impl EncryptedConnection { @@ -49,12 +48,9 @@ impl EncryptedConnection { use super::StorageOption::*; Self::check_for_sqlcipher(opts)?; - let (salt, read_only) = match opts { - Ephemeral => (None, false), - Persistent { - path: ref db_path, - read_only, - } => { + let salt = match opts { + Ephemeral => None, + Persistent(ref db_path) => { let mut salt = [0u8; 16]; let db_pathbuf = PathBuf::from(db_path); let salt_path = Self::salt_file(db_path)?; @@ -90,7 +86,7 @@ impl EncryptedConnection { db_pathbuf.display(), salt_path.display() ); - Self::create(db_path, key, &mut salt, *read_only)?; + Self::create(db_path, key, &mut salt)?; } // the db doesn't exist but the salt does // This generally doesn't make sense & shouldn't happen. @@ -102,43 +98,31 @@ impl EncryptedConnection { salt_path.display(), ); std::fs::remove_file(salt_path)?; - Self::create(db_path, key, &mut salt, *read_only)?; + Self::create(db_path, key, &mut salt)?; } } - (Some(salt), *read_only) + Some(salt) } }; - Ok(Self { - key, - salt, - read_only, - }) + Ok(Self { key, salt }) } /// create a new database + salt file. /// writes the 16-bytes hex-encoded salt to `salt` - fn create( - path: &String, - key: EncryptionKey, - salt: &mut [u8], - read_only: bool, - ) -> Result<(), StorageError> { + fn create(path: &String, key: EncryptionKey, salt: &mut [u8]) -> Result<(), StorageError> { let conn = &mut SqliteConnection::establish(path)?; conn.batch_execute(&format!( r#" {} {} PRAGMA journal_mode = WAL; + PRAGMA query_only = ON; "#, pragma_key(hex::encode(key)), pragma_plaintext_header() ))?; - if read_only { - conn.batch_execute("PRAGMA query_only = ON;")?; - } - Self::write_salt(path, conn, salt)?; Ok(()) } @@ -337,10 +321,7 @@ mod tests { let db_path = tmp_path(); { let _ = EncryptedMessageStore::new( - Persistent { - path: db_path.clone(), - read_only: false, - }, + Persistent(db_path.clone()), EncryptedMessageStore::generate_enc_key(), ) .await @@ -389,15 +370,9 @@ mod tests { file.read_exact(&mut plaintext_header).unwrap(); assert!(String::from_utf8_lossy(&plaintext_header) != SQLITE3_PLAINTEXT_HEADER); - let _ = EncryptedMessageStore::new( - Persistent { - path: db_path.clone(), - read_only: false, - }, - key, - ) - .await - .unwrap(); + let _ = EncryptedMessageStore::new(Persistent(db_path.clone()), key) + .await + .unwrap(); assert!(EncryptedConnection::salt_file(&db_path).unwrap().exists()); let bytes = std::fs::read(EncryptedConnection::salt_file(&db_path).unwrap()).unwrap(); diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index bdc5076cf..7fc9ed1c5 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -1052,10 +1052,7 @@ pub(crate) mod tests { async fn store_read_delete() { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path, - read_only: false, - }, + StorageOption::Persistent(db_path), EncryptedMessageStore::generate_enc_key(), ) .await @@ -1106,10 +1103,7 @@ pub(crate) mod tests { async fn list_append_remove() { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path, - read_only: false, - }, + StorageOption::Persistent(db_path), EncryptedMessageStore::generate_enc_key(), ) .await @@ -1193,10 +1187,7 @@ pub(crate) mod tests { async fn group_state() { let db_path = tmp_path(); let store = EncryptedMessageStore::new( - StorageOption::Persistent { - path: db_path, - read_only: false, - }, + StorageOption::Persistent(db_path), EncryptedMessageStore::generate_enc_key(), ) .await diff --git a/xmtp_mls/src/utils/test/mod.rs b/xmtp_mls/src/utils/test/mod.rs index 7b04abf85..636467fc0 100755 --- a/xmtp_mls/src/utils/test/mod.rs +++ b/xmtp_mls/src/utils/test/mod.rs @@ -66,10 +66,7 @@ impl ClientBuilder { let tmpdb = xmtp_common::tmp_path(); self.store( EncryptedMessageStore::new( - StorageOption::Persistent { - path: tmpdb, - read_only: false, - }, + StorageOption::Persistent(tmpdb), EncryptedMessageStore::generate_enc_key(), ) .await From b5205c14fba268173bb120ebe69ac9ca94e912e0 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 14:30:24 -0500 Subject: [PATCH 03/38] all should be well --- xmtp_mls/src/storage/encrypted_store/native.rs | 18 +++++++++++------- .../encrypted_store/sqlcipher_connection.rs | 1 - 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 65c9d090a..66bde10e9 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -7,9 +7,8 @@ use diesel::{ r2d2::{self, CustomizeConnection, PoolTransactionManager, PooledConnection}, Connection, }; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use std::sync::Arc; -use tokio::sync::Mutex; pub type ConnectionManager = r2d2::ConnectionManager; pub type Pool = r2d2::Pool; @@ -93,7 +92,7 @@ impl StorageOption { #[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { - pub(super) write_conn: Arc>, + pub(super) write_conn: Arc>, pub(super) pool: Arc>>, customizer: Option>, opts: StorageOption, @@ -127,8 +126,12 @@ impl NativeDb { .build(ConnectionManager::new(path))?, }; + let mut write_conn = pool.get()?; + write_conn.batch_execute("PRAGMA query_only = OFF;")?; + let write_conn = Arc::new(Mutex::new(write_conn)); + Ok(Self { - // write_conn: + write_conn, pool: Arc::new(Some(pool).into()), customizer, opts: opts.clone(), @@ -159,9 +162,10 @@ impl XmtpDb for NativeDb { /// Returns the Wrapped [`super::db_connection::DbConnection`] Connection implementation for this Database fn conn(&self) -> Result, StorageError> { let conn = self.raw_conn()?; - Ok(DbConnectionPrivate::from_arc_mutex(Arc::new( - parking_lot::Mutex::new(conn), - ))) + Ok(DbConnectionPrivate::from_arc_mutex( + Arc::new(parking_lot::Mutex::new(conn)), + self.write_conn.clone(), + )) } fn validate(&self, opts: &StorageOption) -> Result<(), StorageError> { diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 07398a0bc..4d0fc4430 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -117,7 +117,6 @@ impl EncryptedConnection { {} {} PRAGMA journal_mode = WAL; - PRAGMA query_only = ON; "#, pragma_key(hex::encode(key)), pragma_plaintext_header() From 895587cc3c230aeca370deca452884bfab71cf6a Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 14:44:18 -0500 Subject: [PATCH 04/38] all should be well --- .../src/storage/encrypted_store/db_connection.rs | 6 +++--- xmtp_mls/src/storage/encrypted_store/native.rs | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index dd76ce9c6..3a8dd7866 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -19,14 +19,14 @@ pub type DbConnection = DbConnectionPrivate { - write: Arc>, read: Arc>, + write: Option>>, } /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(read: Arc>, write: Arc>) -> Self { + pub(super) fn from_arc_mutex(read: Arc>, write: Option>>) -> Self { Self { read, write } } } @@ -61,7 +61,7 @@ where /// Internal-only API to get the underlying `diesel::Connection` reference /// without a scope - pub(super) fn write_ref(&self) -> Arc> { + pub(super) fn write_ref(&self) -> Option>> { self.write.clone() } } diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 66bde10e9..22849e6a1 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -92,7 +92,7 @@ impl StorageOption { #[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { - pub(super) write_conn: Arc>, + pub(super) write_conn: Option>>, pub(super) pool: Arc>>, customizer: Option>, opts: StorageOption, @@ -110,7 +110,7 @@ impl NativeDb { let enc_connection = EncryptedConnection::new(key, opts)?; builder = builder.connection_customizer(Box::new(enc_connection.clone())); Some(Box::new(enc_connection) as Box) - } else if matches!(opts, StorageOption::Persistent { .. }) { + } else if matches!(opts, StorageOption::Persistent(_)) { builder = builder.connection_customizer(Box::new(UnencryptedConnection)); Some(Box::new(UnencryptedConnection) as Box) } else { @@ -126,9 +126,13 @@ impl NativeDb { .build(ConnectionManager::new(path))?, }; - let mut write_conn = pool.get()?; - write_conn.batch_execute("PRAGMA query_only = OFF;")?; - let write_conn = Arc::new(Mutex::new(write_conn)); + let write_conn = if matches!(opts, StorageOption::Persistent(_)) { + let mut write_conn = pool.get()?; + write_conn.batch_execute("PRAGMA query_only = OFF;")?; + Some(Arc::new(Mutex::new(write_conn))) + } else { + None + }; Ok(Self { write_conn, From cb4b5a0c880b7c970d97265d0a5473439e8f28f7 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 14:53:45 -0500 Subject: [PATCH 05/38] all should not be well --- xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 4d0fc4430..07398a0bc 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -117,6 +117,7 @@ impl EncryptedConnection { {} {} PRAGMA journal_mode = WAL; + PRAGMA query_only = ON; "#, pragma_key(hex::encode(key)), pragma_plaintext_header() From 0889db820d279a92e6d314dbfda52136844fdef2 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 15:19:49 -0500 Subject: [PATCH 06/38] all should not be well --- xmtp_mls/src/storage/encrypted_store/native.rs | 6 +++++- .../src/storage/encrypted_store/sqlcipher_connection.rs | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 22849e6a1..5b639c3d0 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -155,7 +155,11 @@ impl NativeDb { pool.state().connections ); - Ok(pool.get()?) + // Turn of writitng by default + let mut conn = pool.get()?; + conn.batch_execute("PRAGMA query_only = ON;")?; + + Ok(conn) } } diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 07398a0bc..4d0fc4430 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -117,7 +117,6 @@ impl EncryptedConnection { {} {} PRAGMA journal_mode = WAL; - PRAGMA query_only = ON; "#, pragma_key(hex::encode(key)), pragma_plaintext_header() From 521aed9cf29b6f5bbe6c8fcc0728afe376965c7b Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 15:35:04 -0500 Subject: [PATCH 07/38] all should be better --- xmtp_mls/src/client.rs | 6 ++- xmtp_mls/src/groups/mod.rs | 4 +- .../encrypted_store/association_state.rs | 5 ++- .../storage/encrypted_store/consent_record.rs | 6 +-- .../encrypted_store/conversation_list.rs | 8 ++-- .../storage/encrypted_store/db_connection.rs | 9 ++++- xmtp_mls/src/storage/encrypted_store/group.rs | 40 ++++++++++--------- .../storage/encrypted_store/group_intent.rs | 25 ++++++------ .../storage/encrypted_store/group_message.rs | 14 +++---- .../encrypted_store/identity_update.rs | 8 ++-- .../encrypted_store/key_package_history.rs | 6 +-- .../encrypted_store/key_store_entry.rs | 4 +- xmtp_mls/src/storage/encrypted_store/mod.rs | 31 +++++++------- .../storage/encrypted_store/refresh_state.rs | 4 +- .../encrypted_store/user_preferences.rs | 8 ++-- xmtp_mls/src/storage/mod.rs | 8 ++-- xmtp_mls/src/storage/sql_key_store.rs | 10 ++--- 17 files changed, 106 insertions(+), 90 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 09b626348..3905c5cac 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1078,8 +1078,10 @@ pub(crate) mod tests { .unwrap(); let conn = amal.store().conn().unwrap(); - conn.raw_query(|conn| diesel::delete(identity_updates::table).execute(conn)) - .unwrap(); + conn.raw_query(true, |conn| { + diesel::delete(identity_updates::table).execute(conn) + }) + .unwrap(); let members = group.members().await.unwrap(); // // The three installations should count as two members diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 59f1d0541..2a22487ee 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -2170,7 +2170,7 @@ pub(crate) mod tests { // The dm shows up let alix_groups = alix_conn - .raw_query(|conn| groups::table.load::(conn)) + .raw_query(false, |conn| groups::table.load::(conn)) .unwrap(); assert_eq!(alix_groups.len(), 2); // They should have the same ID @@ -3697,7 +3697,7 @@ pub(crate) mod tests { let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into(); let conn_2 = bo.store().conn().unwrap(); conn_2 - .raw_query(|c| { + .raw_query(false, |c| { c.batch_execute("BEGIN EXCLUSIVE").unwrap(); Ok::<_, diesel::result::Error>(()) }) diff --git a/xmtp_mls/src/storage/encrypted_store/association_state.rs b/xmtp_mls/src/storage/encrypted_store/association_state.rs index 0eb194b26..ebdcf2fae 100644 --- a/xmtp_mls/src/storage/encrypted_store/association_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/association_state.rs @@ -107,8 +107,9 @@ impl StoredAssociationState { .and(dsl::sequence_id.eq_any(sequence_ids)), ); - let association_states = - conn.raw_query(|query_conn| query.load::(query_conn))?; + let association_states = conn.raw_query(false, |query_conn| { + query.load::(query_conn) + })?; association_states .into_iter() diff --git a/xmtp_mls/src/storage/encrypted_store/consent_record.rs b/xmtp_mls/src/storage/encrypted_store/consent_record.rs index a0789d9e3..4600abb2d 100644 --- a/xmtp_mls/src/storage/encrypted_store/consent_record.rs +++ b/xmtp_mls/src/storage/encrypted_store/consent_record.rs @@ -48,7 +48,7 @@ impl DbConnection { entity: String, entity_type: ConsentType, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| -> diesel::QueryResult<_> { + Ok(self.raw_query(false, |conn| -> diesel::QueryResult<_> { dsl::consent_records .filter(dsl::entity.eq(entity)) .filter(dsl::entity_type.eq(entity_type)) @@ -77,7 +77,7 @@ impl DbConnection { ); } - let changed = self.raw_query(|conn| -> diesel::QueryResult<_> { + let changed = self.raw_query(true, |conn| -> diesel::QueryResult<_> { let existing: Vec = query.load(conn)?; let changed: Vec<_> = records .iter() @@ -107,7 +107,7 @@ impl DbConnection { &self, record: &StoredConsentRecord, ) -> Result, StorageError> { - self.raw_query(|conn| { + self.raw_query(true, |conn| { let maybe_inserted_consent_record: Option = diesel::insert_into(dsl::consent_records) .values(record) diff --git a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs index f71dcf5ad..8f5cca087 100644 --- a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs +++ b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs @@ -128,7 +128,7 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query(false, |conn| query.load::(conn))? } else { let query = query .inner_join( @@ -141,10 +141,10 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query(false, |conn| query.load::(conn))? } } else { - self.raw_query(|conn| query.load::(conn))? + self.raw_query(false, |conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -152,7 +152,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = conversation_list_dsl::conversation_list .filter(conversation_list_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(|conn| query.load(conn))?; + let mut sync_groups = self.raw_query(false, |conn| query.load(conn))?; conversations.append(&mut sync_groups); } diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 3a8dd7866..a25bcfe4d 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -37,10 +37,17 @@ where { /// Do a scoped query with a mutable [`diesel::Connection`] /// reference - pub(crate) fn raw_query(&self, fun: F) -> Result + pub(crate) fn raw_query(&self, write: bool, fun: F) -> Result where F: FnOnce(&mut C) -> Result, { + if write { + if let Some(write_conn) = &self.write { + let mut lock = write_conn.lock(); + return fun(&mut lock); + } + } + let mut lock = self.read.lock(); fun(&mut lock) } diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index e597ad062..119afe70a 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -281,7 +281,7 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query(false, |conn| query.load::(conn))? } else { let query = query .inner_join( @@ -293,10 +293,10 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(|conn| query.load::(conn))? + self.raw_query(false, |conn| query.load::(conn))? } } else { - self.raw_query(|conn| query.load::(conn))? + self.raw_query(false, |conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -304,7 +304,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = groups_dsl::groups.filter(groups_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(|conn| query.load(conn))?; + let mut sync_groups = self.raw_query(false, |conn| query.load(conn))?; groups.append(&mut sync_groups); } @@ -312,7 +312,9 @@ impl DbConnection { } pub fn consent_records(&self) -> Result, StorageError> { - Ok(self.raw_query(|conn| super::schema::consent_records::table.load(conn))?) + Ok(self.raw_query(false, |conn| { + super::schema::consent_records::table.load(conn) + })?) } pub fn all_sync_groups(&self) -> Result, StorageError> { @@ -320,7 +322,7 @@ impl DbConnection { .order(dsl::created_at_ns.desc()) .filter(dsl::conversation_type.eq(ConversationType::Sync)); - Ok(self.raw_query(|conn| query.load(conn))?) + Ok(self.raw_query(false, |conn| query.load(conn))?) } pub fn latest_sync_group(&self) -> Result, StorageError> { @@ -329,7 +331,7 @@ impl DbConnection { .filter(dsl::conversation_type.eq(ConversationType::Sync)) .limit(1); - Ok(self.raw_query(|conn| query.load(conn))?.pop()) + Ok(self.raw_query(false, |conn| query.load(conn))?.pop()) } /// Return a single group that matches the given ID @@ -337,7 +339,7 @@ impl DbConnection { let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed(); query = query.limit(1).filter(dsl::id.eq(id)); - let groups: Vec = self.raw_query(|conn| query.load(conn))?; + let groups: Vec = self.raw_query(false, |conn| query.load(conn))?; // Manually extract the first element Ok(groups.into_iter().next()) @@ -352,7 +354,7 @@ impl DbConnection { .order(dsl::created_at_ns.asc()) .filter(dsl::welcome_id.eq(welcome_id)); - let groups: Vec = self.raw_query(|conn| query.load(conn))?; + let groups: Vec = self.raw_query(false, |conn| query.load(conn))?; if groups.len() > 1 { tracing::error!("More than one group found for welcome_id {}", welcome_id); } @@ -370,7 +372,7 @@ impl DbConnection { .filter(dsl::dm_id.eq(Some(dm_id))) .order(dsl::last_message_ns.desc()); - let groups: Vec = self.raw_query(|conn| query.load(conn))?; + let groups: Vec = self.raw_query(false, |conn| query.load(conn))?; if groups.len() > 1 { tracing::info!("More than one group found for dm_inbox_id {members:?}"); } @@ -384,7 +386,7 @@ impl DbConnection { group_id: GroupId, state: GroupMembershipState, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query(true, |conn| { diesel::update(dsl::groups.find(group_id.as_ref())) .set(dsl::membership_state.eq(state)) .execute(conn) @@ -394,7 +396,7 @@ impl DbConnection { } pub fn get_rotated_at_ns(&self, group_id: Vec) -> Result { - let last_ts: Option = self.raw_query(|conn| { + let last_ts: Option = self.raw_query(false, |conn| { let ts = dsl::groups .find(&group_id) .select(dsl::rotated_at_ns) @@ -410,7 +412,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_rotated_at_ns(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query(true, |conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::rotated_at_ns.eq(now)) @@ -421,7 +423,7 @@ impl DbConnection { } pub fn get_installations_time_checked(&self, group_id: Vec) -> Result { - let last_ts = self.raw_query(|conn| { + let last_ts = self.raw_query(false, |conn| { let ts = dsl::groups .find(&group_id) .select(dsl::installations_last_checked) @@ -435,7 +437,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_installations_time_checked(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query(true, |conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::installations_last_checked.eq(now)) @@ -447,7 +449,7 @@ impl DbConnection { pub fn insert_or_replace_group(&self, group: StoredGroup) -> Result { tracing::info!("Trying to insert group"); - let stored_group = self.raw_query(|conn| { + let stored_group = self.raw_query(true, |conn| { let maybe_inserted_group: Option = diesel::insert_into(dsl::groups) .values(&group) .on_conflict_do_nothing() @@ -660,7 +662,7 @@ pub(crate) mod tests { test_group.store(conn).unwrap(); assert_eq!( - conn.raw_query(|raw_conn| groups.first::(raw_conn)) + conn.raw_query(false, |raw_conn| groups.first::(raw_conn)) .unwrap(), test_group ); @@ -674,7 +676,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(|raw_conn| { + conn.raw_query(true, |raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) @@ -850,7 +852,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(|raw_conn| { + conn.raw_query(true, |raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 743eccc5e..18b356661 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -119,8 +119,9 @@ impl_fetch!(StoredGroupIntent, group_intents, ID); impl Delete for DbConnection { type Key = ID; fn delete(&self, key: ID) -> Result { - Ok(self - .raw_query(|raw_conn| diesel::delete(dsl::group_intents.find(key)).execute(raw_conn))?) + Ok(self.raw_query(true, |raw_conn| { + diesel::delete(dsl::group_intents.find(key)).execute(raw_conn) + })?) } } @@ -155,7 +156,7 @@ impl DbConnection { &self, to_save: NewGroupIntent, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query(true, |conn| { diesel::insert_into(dsl::group_intents) .values(to_save) .get_result(conn) @@ -184,7 +185,7 @@ impl DbConnection { query = query.order(dsl::id.asc()); - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query(false, |conn| query.load::(conn))?) } // Set the intent with the given ID to `Published` and set the payload hash. Optionally add @@ -197,7 +198,7 @@ impl DbConnection { staged_commit: Option>, published_in_epoch: i64, ) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query(true, |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Published is from @@ -214,7 +215,7 @@ impl DbConnection { })?; if rows_changed == 0 { - let already_published = self.raw_query(|conn| { + let already_published = self.raw_query(false, |conn| { dsl::group_intents .filter(dsl::id.eq(intent_id)) .first::(conn) @@ -231,7 +232,7 @@ impl DbConnection { // Set the intent with the given ID to `Committed` pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query(true, |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Committed is from @@ -252,7 +253,7 @@ impl DbConnection { // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query(true, |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to ToPublish is from @@ -278,7 +279,7 @@ impl DbConnection { /// Set the intent with the given ID to `Error` #[tracing::instrument(level = "trace", skip(self))] pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(|conn| { + let rows_changed = self.raw_query(true, |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::state.eq(IntentState::Error)) @@ -298,7 +299,7 @@ impl DbConnection { &self, payload_hash: Vec, ) -> Result, StorageError> { - let result = self.raw_query(|conn| { + let result = self.raw_query(false, |conn| { dsl::group_intents .filter(dsl::payload_hash.eq(payload_hash)) .first::(conn) @@ -312,7 +313,7 @@ impl DbConnection { &self, intent_id: ID, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query(true, |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::publish_attempts.eq(dsl::publish_attempts + 1)) @@ -431,7 +432,7 @@ pub(crate) mod tests { } fn find_first_intent(conn: &DbConnection, group_id: group::ID) -> StoredGroupIntent { - conn.raw_query(|raw_conn| { + conn.raw_query(false, |raw_conn| { dsl::group_intents .filter(dsl::group_id.eq(group_id)) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index 20335e88c..262e071bb 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -290,7 +290,7 @@ impl DbConnection { query = query.limit(limit); } - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query(false, |conn| query.load::(conn))?) } /// Query for group messages with their reactions @@ -341,7 +341,7 @@ impl DbConnection { }; let reactions: Vec = - self.raw_query(|conn| reactions_query.load(conn))?; + self.raw_query(false, |conn| reactions_query.load(conn))?; // Group reactions by parent message id let mut reactions_by_reference: HashMap, Vec> = HashMap::new(); @@ -377,7 +377,7 @@ impl DbConnection { &self, id: MessageId, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| { + Ok(self.raw_query(false, |conn| { dsl::group_messages .filter(dsl::id.eq(id.as_ref())) .first(conn) @@ -390,7 +390,7 @@ impl DbConnection { group_id: GroupId, timestamp: i64, ) -> Result, StorageError> { - Ok(self.raw_query(|conn| { + Ok(self.raw_query(false, |conn| { dsl::group_messages .filter(dsl::group_id.eq(group_id.as_ref())) .filter(dsl::sent_at_ns.eq(timestamp)) @@ -404,7 +404,7 @@ impl DbConnection { msg_id: &MessageId, timestamp: u64, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query(true, |conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set(( @@ -419,7 +419,7 @@ impl DbConnection { &self, msg_id: &MessageId, ) -> Result { - Ok(self.raw_query(|conn| { + Ok(self.raw_query(true, |conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set((dsl::delivery_status.eq(DeliveryStatus::Failed),)) @@ -517,7 +517,7 @@ pub(crate) mod tests { } let count: i64 = conn - .raw_query(|raw_conn| { + .raw_query(false, |raw_conn| { dsl::group_messages .select(diesel::dsl::count_star()) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/identity_update.rs b/xmtp_mls/src/storage/encrypted_store/identity_update.rs index 0a0c87b0a..d08487984 100644 --- a/xmtp_mls/src/storage/encrypted_store/identity_update.rs +++ b/xmtp_mls/src/storage/encrypted_store/identity_update.rs @@ -72,7 +72,7 @@ impl DbConnection { query = query.filter(dsl::sequence_id.le(sequence_id)); } - Ok(self.raw_query(|conn| query.load::(conn))?) + Ok(self.raw_query(false, |conn| query.load::(conn))?) } /// Batch insert identity updates, ignoring duplicates. @@ -81,7 +81,7 @@ impl DbConnection { &self, updates: &[StoredIdentityUpdate], ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query(true, |conn| { diesel::insert_or_ignore_into(dsl::identity_updates) .values(updates) .execute(conn)?; @@ -98,7 +98,7 @@ impl DbConnection { .filter(dsl::inbox_id.eq(inbox_id)) .into_boxed(); - Ok(self.raw_query(|conn| query.first::(conn))?) + Ok(self.raw_query(false, |conn| query.first::(conn))?) } /// Given a list of inbox_ids return a HashMap of each inbox ID -> highest known sequence ID @@ -115,7 +115,7 @@ impl DbConnection { // Get the results as a Vec of (inbox_id, sequence_id) tuples let result_tuples: Vec<(String, i64)> = self - .raw_query(|conn| query.load::<(String, Option)>(conn))? + .raw_query(false, |conn| query.load::<(String, Option)>(conn))? .into_iter() // Diesel needs an Option type for aggregations like max(sequence_id), so we // unwrap the option here diff --git a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs index 9e43243c0..2e191e078 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs @@ -39,7 +39,7 @@ impl DbConnection { &self, hash_ref: Vec, ) -> Result { - let result = self.raw_query(|conn| { + let result = self.raw_query(false, |conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::key_package_hash_ref.eq(hash_ref)) .first::(conn) @@ -52,7 +52,7 @@ impl DbConnection { &self, id: i32, ) -> Result, StorageError> { - let result = self.raw_query(|conn| { + let result = self.raw_query(false, |conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)) .load::(conn) @@ -65,7 +65,7 @@ impl DbConnection { &self, id: i32, ) -> Result<(), StorageError> { - self.raw_query(|conn| { + self.raw_query(true, |conn| { diesel::delete( key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)), diff --git a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs index 7fc4deb3d..34c11b332 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs @@ -18,7 +18,7 @@ impl Delete for DbConnection { type Key = Vec; fn delete(&self, key: Vec) -> Result where { use super::schema::openmls_key_store::dsl::*; - Ok(self.raw_query(|conn| { + Ok(self.raw_query(true, |conn| { diesel::delete(openmls_key_store.filter(key_bytes.eq(key))).execute(conn) })?) } @@ -36,7 +36,7 @@ impl DbConnection { value_bytes: value, }; - self.raw_query(|conn| { + self.raw_query(true, |conn| { diesel::replace_into(openmls_key_store) .values(entry) .execute(conn) diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 1d4f58fe8..1752039a7 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -179,7 +179,7 @@ pub mod private { #[tracing::instrument(level = "trace", skip_all)] pub(super) fn init_db(&mut self) -> Result<(), StorageError> { self.db.validate(&self.opts)?; - self.db.conn()?.raw_query(|conn| { + self.db.conn()?.raw_query(true, |conn| { conn.batch_execute("PRAGMA journal_mode = WAL;")?; tracing::info!("Running DB migrations"); conn.run_pending_migrations(MIGRATIONS)?; @@ -242,7 +242,7 @@ macro_rules! impl_fetch { type Key = (); fn fetch(&self, _key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(|conn| $table.first(conn).optional())?) + Ok(self.raw_query(false, |conn| $table.first(conn).optional())?) } } }; @@ -254,7 +254,9 @@ macro_rules! impl_fetch { type Key = $key; fn fetch(&self, key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(|conn| $table.find(key.clone()).first(conn).optional())?) + Ok(self.raw_query(false, |conn| { + $table.find(key.clone()).first(conn).optional() + })?) } } }; @@ -286,8 +288,9 @@ macro_rules! impl_fetch_list_with_key { keys: &[Self::Key], ) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::{$column, *}; - Ok(self - .raw_query(|conn| $table.filter($column.eq_any(keys)).load::<$model>(conn))?) + Ok(self.raw_query(false, |conn| { + $table.filter($column.eq_any(keys)).load::<$model>(conn) + })?) } } }; @@ -304,7 +307,7 @@ macro_rules! impl_store { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(|conn| { + into.raw_query(true, |conn| { diesel::insert_into($table::table) .values(self) .execute(conn) @@ -326,7 +329,7 @@ macro_rules! impl_store_or_ignore { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(|conn| { + into.raw_query(true, |conn| { diesel::insert_or_ignore_into($table::table) .values(self) .execute(conn) @@ -401,7 +404,7 @@ where match fun(self) { Ok(value) => { - conn.raw_query(|conn| { + conn.raw_query(true, |conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; tracing::debug!("Transaction being committed"); @@ -409,7 +412,7 @@ where } Err(err) => { tracing::debug!("Transaction being rolled back"); - match conn.raw_query(|conn| { + match conn.raw_query(true, |conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -468,7 +471,7 @@ where DbConnectionPrivate::from_arc_mutex(local_read_connection, local_write_connection); match result { Ok(value) => { - local_connection.raw_query(|conn| { + local_connection.raw_query(true, |conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; tracing::debug!("Transaction async being committed"); @@ -476,7 +479,7 @@ where } Err(err) => { tracing::debug!("Transaction async being rolled back"); - match local_connection.raw_query(|conn| { + match local_connection.raw_query(true, |conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -626,7 +629,7 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(|conn| { + .raw_query(true, |conn| { for _ in 0..15 { conn.run_next_migration(MIGRATIONS)?; } @@ -672,14 +675,14 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(|conn| { + .raw_query(true, |conn| { conn.run_pending_migrations(MIGRATIONS)?; Ok::<_, StorageError>(()) }) .unwrap(); let groups = conn - .raw_query(|conn| groups::table.load::(conn)) + .raw_query(false, |conn| groups::table.load::(conn)) .unwrap(); assert_eq!(groups.len(), 1); assert_eq!(&**groups[0].dm_id.as_ref().unwrap(), "dm:98765:inbox_id"); diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index b1cfefcb0..788f29586 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -74,7 +74,7 @@ impl DbConnection { entity_kind: EntityKind, ) -> Result, StorageError> { use super::schema::refresh_state::dsl; - let res = self.raw_query(|conn| { + let res = self.raw_query(false, |conn| { dsl::refresh_state .find((entity_id.as_ref(), entity_kind)) .first(conn) @@ -115,7 +115,7 @@ impl DbConnection { NotFound::RefreshStateByIdAndKind(entity_id.as_ref().to_vec(), entity_kind), )?; - let num_updated = self.raw_query(|conn| { + let num_updated = self.raw_query(true, |conn| { diesel::update(&state) .filter(dsl::cursor.lt(cursor)) .set(dsl::cursor.eq(cursor)) diff --git a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs index 98170d52f..4da0c66dd 100644 --- a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs +++ b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs @@ -37,7 +37,7 @@ impl<'a> From<&'a StoredUserPreferences> for NewStoredUserPreferences<'a> { impl Store for StoredUserPreferences { fn store(&self, conn: &DbConnection) -> Result<(), StorageError> { - conn.raw_query(|conn| { + conn.raw_query(true, |conn| { diesel::update(dsl::user_preferences) .set(self) .execute(conn) @@ -50,7 +50,7 @@ impl Store for StoredUserPreferences { impl StoredUserPreferences { pub fn load(conn: &DbConnection) -> Result { let query = dsl::user_preferences.order(dsl::id.desc()).limit(1); - let mut result = conn.raw_query(|conn| query.load::(conn))?; + let mut result = conn.raw_query(false, |conn| query.load::(conn))?; Ok(result.pop().unwrap_or_default()) } @@ -73,7 +73,7 @@ impl StoredUserPreferences { ])); let to_insert: NewStoredUserPreferences = (&preferences).into(); - conn.raw_query(|conn| { + conn.raw_query(true, |conn| { diesel::insert_into(dsl::user_preferences) .values(to_insert) .execute(conn) @@ -115,7 +115,7 @@ mod tests { // check that there are two preferences stored let query = dsl::user_preferences.order(dsl::id.desc()); let result = conn - .raw_query(|conn| query.load::(conn)) + .raw_query(false, |conn| query.load::(conn)) .unwrap(); assert_eq!(result.len(), 1); } diff --git a/xmtp_mls/src/storage/mod.rs b/xmtp_mls/src/storage/mod.rs index 4cda41805..2945f0656 100644 --- a/xmtp_mls/src/storage/mod.rs +++ b/xmtp_mls/src/storage/mod.rs @@ -57,12 +57,12 @@ pub mod test_util { for query in queries { let query = diesel::sql_query(query); - let _ = self.raw_query(|conn| query.execute(conn)).unwrap(); + let _ = self.raw_query(true, |conn| query.execute(conn)).unwrap(); } } pub fn intents_published(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query(false, |conn| { let mut row = conn .load(sql_query( "SELECT intents_published FROM test_metadata WHERE rowid = 1", @@ -78,7 +78,7 @@ pub mod test_util { } pub fn intents_deleted(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query(false, |conn| { let mut row = conn .load(sql_query("SELECT intents_deleted FROM test_metadata")) .unwrap(); @@ -92,7 +92,7 @@ pub mod test_util { } pub fn intents_created(&self) -> i32 { - self.raw_query(|conn| { + self.raw_query(false, |conn| { let mut row = conn .load(sql_query("SELECT intents_created FROM test_metadata")) .unwrap(); diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 7fc9ed1c5..b0a81299b 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -49,7 +49,7 @@ where &self, storage_key: &Vec, ) -> Result, diesel::result::Error> { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query(false, |conn| { sql_query(SELECT_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -62,7 +62,7 @@ where storage_key: &Vec, value: &[u8], ) -> Result { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query(true, |conn| { sql_query(REPLACE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -76,7 +76,7 @@ where storage_key: &Vec, modified_data: &Vec, ) -> Result { - self.conn_ref().raw_query(|conn| { + self.conn_ref().raw_query(true, |conn| { sql_query(UPDATE_QUERY) .bind::(&modified_data) .bind::(&storage_key) @@ -224,7 +224,7 @@ where ) -> Result<(), >::Error> { let storage_key = build_key_from_vec::(label, key.to_vec()); - let _ = self.conn_ref().raw_query(|conn| { + let _ = self.conn_ref().raw_query(true, |conn| { sql_query(DELETE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -809,7 +809,7 @@ where let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; - let data: Vec = self.conn_ref().raw_query(|conn| { + let data: Vec = self.conn_ref().raw_query(false, |conn| { sql_query(query) .bind::(&storage_key) .bind::(CURRENT_VERSION as i32) From cf0143937e4daa6cf13126ba467f46f144dae699 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 15:51:35 -0500 Subject: [PATCH 08/38] move the pragma --- xmtp_mls/src/storage/encrypted_store/native.rs | 8 ++------ .../src/storage/encrypted_store/sqlcipher_connection.rs | 1 + 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 5b639c3d0..8644cee09 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -64,7 +64,7 @@ impl ValidatedConnection for UnencryptedConnection {} impl CustomizeConnection for UnencryptedConnection { fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), r2d2::Error> { - conn.batch_execute("PRAGMA busy_timeout = 5000;") + conn.batch_execute("PRAGMA query_only = ON; PRAGMA busy_timeout = 5000;") .map_err(r2d2::Error::QueryError)?; Ok(()) } @@ -155,11 +155,7 @@ impl NativeDb { pool.state().connections ); - // Turn of writitng by default - let mut conn = pool.get()?; - conn.batch_execute("PRAGMA query_only = ON;")?; - - Ok(conn) + Ok(pool.get()?) } } diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index 4d0fc4430..b5cae60f3 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -283,6 +283,7 @@ impl diesel::r2d2::CustomizeConnection fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), diesel::r2d2::Error> { conn.batch_execute(&format!( "{} + PRAGMA query_only = ON; PRAGMA busy_timeout = 5000;", self.pragmas() )) From d982f5d608db6b7d3a0931cc0e1289f5b33d6a71 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 16:43:56 -0500 Subject: [PATCH 09/38] guard --- xmtp_mls/src/groups/intents.rs | 6 ++- xmtp_mls/src/groups/mod.rs | 2 + .../storage/encrypted_store/db_connection.rs | 49 ++++++++++++++++++- xmtp_mls/src/storage/encrypted_store/mod.rs | 19 ++++--- 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 369cbad10..883bd77c6 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -64,10 +64,12 @@ impl MlsGroup { intent_kind: IntentKind, intent_data: Vec, ) -> Result { - provider.transaction(|provider| { + let res = provider.transaction(|provider| { let conn = provider.conn_ref(); self.queue_intent_with_conn(conn, intent_kind, intent_data) - }) + }); + + res } fn queue_intent_with_conn( diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 2a22487ee..356a3478b 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -935,6 +935,8 @@ impl MlsGroup { intent_data.into(), )?; + tracing::warn!("This makes it here?"); + self.sync_until_intent_resolved(provider, intent.id).await } diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index a25bcfe4d..a78f0370d 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -1,5 +1,6 @@ use parking_lot::Mutex; use std::fmt; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use crate::storage::xmtp_openmls_provider::XmtpOpenMlsProvider; @@ -21,13 +22,18 @@ pub type DbConnection = DbConnectionPrivate { read: Arc>, write: Option>>, + pub(super) in_transaction: Arc, } /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> pub(super) fn from_arc_mutex(read: Arc>, write: Option>>) -> Self { - Self { read, write } + Self { + read, + write, + in_transaction: Arc::new(AtomicBool::new(false)), + } } } @@ -35,6 +41,17 @@ impl DbConnectionPrivate where C: diesel::Connection, { + fn in_transaction(&self) -> bool { + self.in_transaction.load(Ordering::SeqCst) + } + + pub(crate) fn start_transaction(&self) -> TransactionGuard { + self.in_transaction.store(true, Ordering::SeqCst); + TransactionGuard { + in_transaction: self.in_transaction.clone(), + } + } + /// Do a scoped query with a mutable [`diesel::Connection`] /// reference pub(crate) fn raw_query(&self, write: bool, fun: F) -> Result @@ -57,15 +74,36 @@ where /// Must be used with care. holding this reference while calling `raw_query` /// will cause a deadlock. pub(super) fn read_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { + if self.in_transaction() { + if let Some(write) = &self.write { + return write.lock(); + } + } self.read.lock() } /// Internal-only API to get the underlying `diesel::Connection` reference /// without a scope pub(super) fn read_ref(&self) -> Arc> { + if self.in_transaction() { + if let Some(write) = &self.write { + return write.clone(); + }; + } self.read.clone() } + /// Internal-only API to get the underlying `diesel::Connection` reference + /// without a scope + /// Must be used with care. holding this reference while calling `raw_query` + /// will cause a deadlock. + pub(super) fn write_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { + let Some(write) = &self.write else { + return self.read_mut_ref(); + }; + write.lock() + } + /// Internal-only API to get the underlying `diesel::Connection` reference /// without a scope pub(super) fn write_ref(&self) -> Option>> { @@ -91,3 +129,12 @@ impl fmt::Debug for DbConnectionPrivate { .finish() } } + +pub struct TransactionGuard { + in_transaction: Arc, +} +impl Drop for TransactionGuard { + fn drop(&mut self) { + self.in_transaction.store(false, Ordering::SeqCst); + } +} diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 1752039a7..11577eeb1 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -394,11 +394,13 @@ where E: From + From, { tracing::debug!("Transaction beginning"); - { - let connection = self.conn_ref(); - let mut connection = connection.read_mut_ref(); + + let _guard = { + let wrapper = self.conn_ref(); + let mut connection = wrapper.write_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; - } + wrapper.start_transaction() + }; let conn = self.conn_ref(); @@ -445,11 +447,12 @@ where Db: 'a, { tracing::debug!("Transaction async beginning"); - { - let connection = self.conn_ref(); - let mut connection = connection.read_mut_ref(); + let _guard = { + let wrapper = self.conn_ref(); + let mut connection = wrapper.write_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; - } + wrapper.start_transaction() + }; // ensuring we have only one strong reference let result = fun(self).await; From efd639e54f08d5c2601f8746d294ad24e6610064 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 16:53:09 -0500 Subject: [PATCH 10/38] cleanup --- xmtp_mls/src/client.rs | 6 +-- xmtp_mls/src/groups/mod.rs | 4 +- .../encrypted_store/association_state.rs | 2 +- .../storage/encrypted_store/consent_record.rs | 6 +-- .../encrypted_store/conversation_list.rs | 8 ++-- .../storage/encrypted_store/db_connection.rs | 20 +++++++--- xmtp_mls/src/storage/encrypted_store/group.rs | 40 +++++++++---------- .../storage/encrypted_store/group_intent.rs | 22 +++++----- .../storage/encrypted_store/group_message.rs | 14 +++---- .../encrypted_store/identity_update.rs | 8 ++-- .../encrypted_store/key_package_history.rs | 6 +-- .../encrypted_store/key_store_entry.rs | 4 +- xmtp_mls/src/storage/encrypted_store/mod.rs | 28 ++++++------- .../storage/encrypted_store/refresh_state.rs | 4 +- .../encrypted_store/user_preferences.rs | 8 ++-- xmtp_mls/src/storage/mod.rs | 8 ++-- xmtp_mls/src/storage/sql_key_store.rs | 10 ++--- 17 files changed, 100 insertions(+), 98 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index a4fc9949e..0d43b4eb0 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1093,10 +1093,8 @@ pub(crate) mod tests { .unwrap(); let conn = amal.store().conn().unwrap(); - conn.raw_query(true, |conn| { - diesel::delete(identity_updates::table).execute(conn) - }) - .unwrap(); + conn.raw_query_write(|conn| diesel::delete(identity_updates::table).execute(conn)) + .unwrap(); let members = group.members().await.unwrap(); // // The three installations should count as two members diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index b84263ac2..337db6c72 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -2171,7 +2171,7 @@ pub(crate) mod tests { // The dm shows up let alix_groups = alix_conn - .raw_query(false, |conn| groups::table.load::(conn)) + .raw_query_read( |conn| groups::table.load::(conn)) .unwrap(); assert_eq!(alix_groups.len(), 2); // They should have the same ID @@ -3698,7 +3698,7 @@ pub(crate) mod tests { let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into(); let conn_2 = bo.store().conn().unwrap(); conn_2 - .raw_query(false, |c| { + .raw_query_read( |c| { c.batch_execute("BEGIN EXCLUSIVE").unwrap(); Ok::<_, diesel::result::Error>(()) }) diff --git a/xmtp_mls/src/storage/encrypted_store/association_state.rs b/xmtp_mls/src/storage/encrypted_store/association_state.rs index ebdcf2fae..f29d05341 100644 --- a/xmtp_mls/src/storage/encrypted_store/association_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/association_state.rs @@ -107,7 +107,7 @@ impl StoredAssociationState { .and(dsl::sequence_id.eq_any(sequence_ids)), ); - let association_states = conn.raw_query(false, |query_conn| { + let association_states = conn.raw_query_read( |query_conn| { query.load::(query_conn) })?; diff --git a/xmtp_mls/src/storage/encrypted_store/consent_record.rs b/xmtp_mls/src/storage/encrypted_store/consent_record.rs index 4600abb2d..b70552a3b 100644 --- a/xmtp_mls/src/storage/encrypted_store/consent_record.rs +++ b/xmtp_mls/src/storage/encrypted_store/consent_record.rs @@ -48,7 +48,7 @@ impl DbConnection { entity: String, entity_type: ConsentType, ) -> Result, StorageError> { - Ok(self.raw_query(false, |conn| -> diesel::QueryResult<_> { + Ok(self.raw_query_read( |conn| -> diesel::QueryResult<_> { dsl::consent_records .filter(dsl::entity.eq(entity)) .filter(dsl::entity_type.eq(entity_type)) @@ -77,7 +77,7 @@ impl DbConnection { ); } - let changed = self.raw_query(true, |conn| -> diesel::QueryResult<_> { + let changed = self.raw_query_write( |conn| -> diesel::QueryResult<_> { let existing: Vec = query.load(conn)?; let changed: Vec<_> = records .iter() @@ -107,7 +107,7 @@ impl DbConnection { &self, record: &StoredConsentRecord, ) -> Result, StorageError> { - self.raw_query(true, |conn| { + self.raw_query_write( |conn| { let maybe_inserted_consent_record: Option = diesel::insert_into(dsl::consent_records) .values(record) diff --git a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs index 7a33517cc..76e807d51 100644 --- a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs +++ b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs @@ -139,7 +139,7 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(false, |conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } else { // Only include the specified states let query = query @@ -153,11 +153,11 @@ impl DbConnection { .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); - self.raw_query(false, |conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } } else { // Handle the case where `consent_states` is `None` - self.raw_query(false, |conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -165,7 +165,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = conversation_list_dsl::conversation_list .filter(conversation_list_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(false, |conn| query.load(conn))?; + let mut sync_groups = self.raw_query_read(|conn| query.load(conn))?; conversations.append(&mut sync_groups); } diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index a78f0370d..8fe640d76 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -54,15 +54,23 @@ where /// Do a scoped query with a mutable [`diesel::Connection`] /// reference - pub(crate) fn raw_query(&self, write: bool, fun: F) -> Result + pub(crate) fn raw_query_read(&self, fun: F) -> Result where F: FnOnce(&mut C) -> Result, { - if write { - if let Some(write_conn) = &self.write { - let mut lock = write_conn.lock(); - return fun(&mut lock); - } + let mut lock = self.read.lock(); + fun(&mut lock) + } + + /// Do a scoped query with a mutable [`diesel::Connection`] + /// reference + pub(crate) fn raw_query_write(&self, fun: F) -> Result + where + F: FnOnce(&mut C) -> Result, + { + if let Some(write_conn) = &self.write { + let mut lock = write_conn.lock(); + return fun(&mut lock); } let mut lock = self.read.lock(); diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index c7a23af67..fa79224ac 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -292,7 +292,7 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(false, |conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } else { // Only include the specified states let query = query @@ -305,11 +305,11 @@ impl DbConnection { .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); - self.raw_query(false, |conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? } } else { // Handle the case where `consent_states` is `None` - self.raw_query(false, |conn| query.load::(conn))? + self.raw_query_read(|conn| query.load::(conn))? }; // Were sync groups explicitly asked for? Was the include_sync_groups flag set to true? @@ -317,7 +317,7 @@ impl DbConnection { if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups { let query = groups_dsl::groups.filter(groups_dsl::conversation_type.eq(ConversationType::Sync)); - let mut sync_groups = self.raw_query(false, |conn| query.load(conn))?; + let mut sync_groups = self.raw_query_read(|conn| query.load(conn))?; groups.append(&mut sync_groups); } @@ -325,9 +325,7 @@ impl DbConnection { } pub fn consent_records(&self) -> Result, StorageError> { - Ok(self.raw_query(false, |conn| { - super::schema::consent_records::table.load(conn) - })?) + Ok(self.raw_query_read(|conn| super::schema::consent_records::table.load(conn))?) } pub fn all_sync_groups(&self) -> Result, StorageError> { @@ -335,7 +333,7 @@ impl DbConnection { .order(dsl::created_at_ns.desc()) .filter(dsl::conversation_type.eq(ConversationType::Sync)); - Ok(self.raw_query(false, |conn| query.load(conn))?) + Ok(self.raw_query_read(|conn| query.load(conn))?) } pub fn latest_sync_group(&self) -> Result, StorageError> { @@ -344,7 +342,7 @@ impl DbConnection { .filter(dsl::conversation_type.eq(ConversationType::Sync)) .limit(1); - Ok(self.raw_query(false, |conn| query.load(conn))?.pop()) + Ok(self.raw_query_read(|conn| query.load(conn))?.pop()) } /// Return a single group that matches the given ID @@ -352,7 +350,7 @@ impl DbConnection { let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed(); query = query.limit(1).filter(dsl::id.eq(id)); - let groups: Vec = self.raw_query(false, |conn| query.load(conn))?; + let groups: Vec = self.raw_query_read(|conn| query.load(conn))?; // Manually extract the first element Ok(groups.into_iter().next()) @@ -367,7 +365,7 @@ impl DbConnection { .order(dsl::created_at_ns.asc()) .filter(dsl::welcome_id.eq(welcome_id)); - let groups: Vec = self.raw_query(false, |conn| query.load(conn))?; + let groups: Vec = self.raw_query_read(|conn| query.load(conn))?; if groups.len() > 1 { tracing::error!("More than one group found for welcome_id {}", welcome_id); } @@ -385,7 +383,7 @@ impl DbConnection { .filter(dsl::dm_id.eq(Some(dm_id))) .order(dsl::last_message_ns.desc()); - let groups: Vec = self.raw_query(false, |conn| query.load(conn))?; + let groups: Vec = self.raw_query_read(|conn| query.load(conn))?; if groups.len() > 1 { tracing::info!("More than one group found for dm_inbox_id {members:?}"); } @@ -399,7 +397,7 @@ impl DbConnection { group_id: GroupId, state: GroupMembershipState, ) -> Result<(), StorageError> { - self.raw_query(true, |conn| { + self.raw_query_write(|conn| { diesel::update(dsl::groups.find(group_id.as_ref())) .set(dsl::membership_state.eq(state)) .execute(conn) @@ -409,7 +407,7 @@ impl DbConnection { } pub fn get_rotated_at_ns(&self, group_id: Vec) -> Result { - let last_ts: Option = self.raw_query(false, |conn| { + let last_ts: Option = self.raw_query_read(|conn| { let ts = dsl::groups .find(&group_id) .select(dsl::rotated_at_ns) @@ -425,7 +423,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_rotated_at_ns(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(true, |conn| { + self.raw_query_write(|conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::rotated_at_ns.eq(now)) @@ -436,7 +434,7 @@ impl DbConnection { } pub fn get_installations_time_checked(&self, group_id: Vec) -> Result { - let last_ts = self.raw_query(false, |conn| { + let last_ts = self.raw_query_read(|conn| { let ts = dsl::groups .find(&group_id) .select(dsl::installations_last_checked) @@ -450,7 +448,7 @@ impl DbConnection { /// Updates the 'last time checked' we checked for new installations. pub fn update_installations_time_checked(&self, group_id: Vec) -> Result<(), StorageError> { - self.raw_query(true, |conn| { + self.raw_query_write(|conn| { let now = xmtp_common::time::now_ns(); diesel::update(dsl::groups.find(&group_id)) .set(dsl::installations_last_checked.eq(now)) @@ -462,7 +460,7 @@ impl DbConnection { pub fn insert_or_replace_group(&self, group: StoredGroup) -> Result { tracing::info!("Trying to insert group"); - let stored_group = self.raw_query(true, |conn| { + let stored_group = self.raw_query_write(|conn| { let maybe_inserted_group: Option = diesel::insert_into(dsl::groups) .values(&group) .on_conflict_do_nothing() @@ -675,7 +673,7 @@ pub(crate) mod tests { test_group.store(conn).unwrap(); assert_eq!( - conn.raw_query(false, |raw_conn| groups.first::(raw_conn)) + conn.raw_query_read(|raw_conn| groups.first::(raw_conn)) .unwrap(), test_group ); @@ -689,7 +687,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(true, |raw_conn| { + conn.raw_query_write(|raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) @@ -865,7 +863,7 @@ pub(crate) mod tests { with_connection(|conn| { let test_group = generate_group(None); - conn.raw_query(true, |raw_conn| { + conn.raw_query_write(|raw_conn| { diesel::insert_into(groups) .values(test_group.clone()) .execute(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 18b356661..464f7b9f6 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -119,7 +119,7 @@ impl_fetch!(StoredGroupIntent, group_intents, ID); impl Delete for DbConnection { type Key = ID; fn delete(&self, key: ID) -> Result { - Ok(self.raw_query(true, |raw_conn| { + Ok(self.raw_query_write( |raw_conn| { diesel::delete(dsl::group_intents.find(key)).execute(raw_conn) })?) } @@ -156,7 +156,7 @@ impl DbConnection { &self, to_save: NewGroupIntent, ) -> Result { - Ok(self.raw_query(true, |conn| { + Ok(self.raw_query_write( |conn| { diesel::insert_into(dsl::group_intents) .values(to_save) .get_result(conn) @@ -185,7 +185,7 @@ impl DbConnection { query = query.order(dsl::id.asc()); - Ok(self.raw_query(false, |conn| query.load::(conn))?) + Ok(self.raw_query_read( |conn| query.load::(conn))?) } // Set the intent with the given ID to `Published` and set the payload hash. Optionally add @@ -198,7 +198,7 @@ impl DbConnection { staged_commit: Option>, published_in_epoch: i64, ) -> Result<(), StorageError> { - let rows_changed = self.raw_query(true, |conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Published is from @@ -215,7 +215,7 @@ impl DbConnection { })?; if rows_changed == 0 { - let already_published = self.raw_query(false, |conn| { + let already_published = self.raw_query_read( |conn| { dsl::group_intents .filter(dsl::id.eq(intent_id)) .first::(conn) @@ -232,7 +232,7 @@ impl DbConnection { // Set the intent with the given ID to `Committed` pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(true, |conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Committed is from @@ -253,7 +253,7 @@ impl DbConnection { // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(true, |conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to ToPublish is from @@ -279,7 +279,7 @@ impl DbConnection { /// Set the intent with the given ID to `Error` #[tracing::instrument(level = "trace", skip(self))] pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query(true, |conn| { + let rows_changed = self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::state.eq(IntentState::Error)) @@ -299,7 +299,7 @@ impl DbConnection { &self, payload_hash: Vec, ) -> Result, StorageError> { - let result = self.raw_query(false, |conn| { + let result = self.raw_query_read( |conn| { dsl::group_intents .filter(dsl::payload_hash.eq(payload_hash)) .first::(conn) @@ -313,7 +313,7 @@ impl DbConnection { &self, intent_id: ID, ) -> Result<(), StorageError> { - self.raw_query(true, |conn| { + self.raw_query_write( |conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::publish_attempts.eq(dsl::publish_attempts + 1)) @@ -432,7 +432,7 @@ pub(crate) mod tests { } fn find_first_intent(conn: &DbConnection, group_id: group::ID) -> StoredGroupIntent { - conn.raw_query(false, |raw_conn| { + conn.raw_query_read( |raw_conn| { dsl::group_intents .filter(dsl::group_id.eq(group_id)) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index 262e071bb..ba6254dfe 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -290,7 +290,7 @@ impl DbConnection { query = query.limit(limit); } - Ok(self.raw_query(false, |conn| query.load::(conn))?) + Ok(self.raw_query_read( |conn| query.load::(conn))?) } /// Query for group messages with their reactions @@ -341,7 +341,7 @@ impl DbConnection { }; let reactions: Vec = - self.raw_query(false, |conn| reactions_query.load(conn))?; + self.raw_query_read( |conn| reactions_query.load(conn))?; // Group reactions by parent message id let mut reactions_by_reference: HashMap, Vec> = HashMap::new(); @@ -377,7 +377,7 @@ impl DbConnection { &self, id: MessageId, ) -> Result, StorageError> { - Ok(self.raw_query(false, |conn| { + Ok(self.raw_query_read( |conn| { dsl::group_messages .filter(dsl::id.eq(id.as_ref())) .first(conn) @@ -390,7 +390,7 @@ impl DbConnection { group_id: GroupId, timestamp: i64, ) -> Result, StorageError> { - Ok(self.raw_query(false, |conn| { + Ok(self.raw_query_read( |conn| { dsl::group_messages .filter(dsl::group_id.eq(group_id.as_ref())) .filter(dsl::sent_at_ns.eq(timestamp)) @@ -404,7 +404,7 @@ impl DbConnection { msg_id: &MessageId, timestamp: u64, ) -> Result { - Ok(self.raw_query(true, |conn| { + Ok(self.raw_query_write( |conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set(( @@ -419,7 +419,7 @@ impl DbConnection { &self, msg_id: &MessageId, ) -> Result { - Ok(self.raw_query(true, |conn| { + Ok(self.raw_query_write( |conn| { diesel::update(dsl::group_messages) .filter(dsl::id.eq(msg_id.as_ref())) .set((dsl::delivery_status.eq(DeliveryStatus::Failed),)) @@ -517,7 +517,7 @@ pub(crate) mod tests { } let count: i64 = conn - .raw_query(false, |raw_conn| { + .raw_query_read( |raw_conn| { dsl::group_messages .select(diesel::dsl::count_star()) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/identity_update.rs b/xmtp_mls/src/storage/encrypted_store/identity_update.rs index d08487984..bb37dd774 100644 --- a/xmtp_mls/src/storage/encrypted_store/identity_update.rs +++ b/xmtp_mls/src/storage/encrypted_store/identity_update.rs @@ -72,7 +72,7 @@ impl DbConnection { query = query.filter(dsl::sequence_id.le(sequence_id)); } - Ok(self.raw_query(false, |conn| query.load::(conn))?) + Ok(self.raw_query_read( |conn| query.load::(conn))?) } /// Batch insert identity updates, ignoring duplicates. @@ -81,7 +81,7 @@ impl DbConnection { &self, updates: &[StoredIdentityUpdate], ) -> Result<(), StorageError> { - self.raw_query(true, |conn| { + self.raw_query_write( |conn| { diesel::insert_or_ignore_into(dsl::identity_updates) .values(updates) .execute(conn)?; @@ -98,7 +98,7 @@ impl DbConnection { .filter(dsl::inbox_id.eq(inbox_id)) .into_boxed(); - Ok(self.raw_query(false, |conn| query.first::(conn))?) + Ok(self.raw_query_read( |conn| query.first::(conn))?) } /// Given a list of inbox_ids return a HashMap of each inbox ID -> highest known sequence ID @@ -115,7 +115,7 @@ impl DbConnection { // Get the results as a Vec of (inbox_id, sequence_id) tuples let result_tuples: Vec<(String, i64)> = self - .raw_query(false, |conn| query.load::<(String, Option)>(conn))? + .raw_query_read( |conn| query.load::<(String, Option)>(conn))? .into_iter() // Diesel needs an Option type for aggregations like max(sequence_id), so we // unwrap the option here diff --git a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs index 2e191e078..253894761 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs @@ -39,7 +39,7 @@ impl DbConnection { &self, hash_ref: Vec, ) -> Result { - let result = self.raw_query(false, |conn| { + let result = self.raw_query_read( |conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::key_package_hash_ref.eq(hash_ref)) .first::(conn) @@ -52,7 +52,7 @@ impl DbConnection { &self, id: i32, ) -> Result, StorageError> { - let result = self.raw_query(false, |conn| { + let result = self.raw_query_read( |conn| { key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)) .load::(conn) @@ -65,7 +65,7 @@ impl DbConnection { &self, id: i32, ) -> Result<(), StorageError> { - self.raw_query(true, |conn| { + self.raw_query_write( |conn| { diesel::delete( key_package_history::dsl::key_package_history .filter(key_package_history::dsl::id.lt(id)), diff --git a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs index 34c11b332..44f1027c9 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs @@ -18,7 +18,7 @@ impl Delete for DbConnection { type Key = Vec; fn delete(&self, key: Vec) -> Result where { use super::schema::openmls_key_store::dsl::*; - Ok(self.raw_query(true, |conn| { + Ok(self.raw_query_write(|conn| { diesel::delete(openmls_key_store.filter(key_bytes.eq(key))).execute(conn) })?) } @@ -36,7 +36,7 @@ impl DbConnection { value_bytes: value, }; - self.raw_query(true, |conn| { + self.raw_query_write(|conn| { diesel::replace_into(openmls_key_store) .values(entry) .execute(conn) diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 11577eeb1..c16ae8b16 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -179,7 +179,7 @@ pub mod private { #[tracing::instrument(level = "trace", skip_all)] pub(super) fn init_db(&mut self) -> Result<(), StorageError> { self.db.validate(&self.opts)?; - self.db.conn()?.raw_query(true, |conn| { + self.db.conn()?.raw_query_write( |conn| { conn.batch_execute("PRAGMA journal_mode = WAL;")?; tracing::info!("Running DB migrations"); conn.run_pending_migrations(MIGRATIONS)?; @@ -242,7 +242,7 @@ macro_rules! impl_fetch { type Key = (); fn fetch(&self, _key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(false, |conn| $table.first(conn).optional())?) + Ok(self.raw_query_read(|conn| $table.first(conn).optional())?) } } }; @@ -254,9 +254,7 @@ macro_rules! impl_fetch { type Key = $key; fn fetch(&self, key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok(self.raw_query(false, |conn| { - $table.find(key.clone()).first(conn).optional() - })?) + Ok(self.raw_query_read(|conn| $table.find(key.clone()).first(conn).optional())?) } } }; @@ -288,7 +286,7 @@ macro_rules! impl_fetch_list_with_key { keys: &[Self::Key], ) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::{$column, *}; - Ok(self.raw_query(false, |conn| { + Ok(self.raw_query_read(|conn| { $table.filter($column.eq_any(keys)).load::<$model>(conn) })?) } @@ -307,7 +305,7 @@ macro_rules! impl_store { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(true, |conn| { + into.raw_query_write( |conn| { diesel::insert_into($table::table) .values(self) .execute(conn) @@ -329,7 +327,7 @@ macro_rules! impl_store_or_ignore { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query(true, |conn| { + into.raw_query_write( |conn| { diesel::insert_or_ignore_into($table::table) .values(self) .execute(conn) @@ -406,7 +404,7 @@ where match fun(self) { Ok(value) => { - conn.raw_query(true, |conn| { + conn.raw_query_write( |conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; tracing::debug!("Transaction being committed"); @@ -414,7 +412,7 @@ where } Err(err) => { tracing::debug!("Transaction being rolled back"); - match conn.raw_query(true, |conn| { + match conn.raw_query_write( |conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -474,7 +472,7 @@ where DbConnectionPrivate::from_arc_mutex(local_read_connection, local_write_connection); match result { Ok(value) => { - local_connection.raw_query(true, |conn| { + local_connection.raw_query_write( |conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; tracing::debug!("Transaction async being committed"); @@ -482,7 +480,7 @@ where } Err(err) => { tracing::debug!("Transaction async being rolled back"); - match local_connection.raw_query(true, |conn| { + match local_connection.raw_query_write( |conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -632,7 +630,7 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(true, |conn| { + .raw_query_write( |conn| { for _ in 0..15 { conn.run_next_migration(MIGRATIONS)?; } @@ -678,14 +676,14 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query(true, |conn| { + .raw_query_write( |conn| { conn.run_pending_migrations(MIGRATIONS)?; Ok::<_, StorageError>(()) }) .unwrap(); let groups = conn - .raw_query(false, |conn| groups::table.load::(conn)) + .raw_query_read(|conn| groups::table.load::(conn)) .unwrap(); assert_eq!(groups.len(), 1); assert_eq!(&**groups[0].dm_id.as_ref().unwrap(), "dm:98765:inbox_id"); diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index 788f29586..2fb7787a1 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -74,7 +74,7 @@ impl DbConnection { entity_kind: EntityKind, ) -> Result, StorageError> { use super::schema::refresh_state::dsl; - let res = self.raw_query(false, |conn| { + let res = self.raw_query_read( |conn| { dsl::refresh_state .find((entity_id.as_ref(), entity_kind)) .first(conn) @@ -115,7 +115,7 @@ impl DbConnection { NotFound::RefreshStateByIdAndKind(entity_id.as_ref().to_vec(), entity_kind), )?; - let num_updated = self.raw_query(true, |conn| { + let num_updated = self.raw_query_write( |conn| { diesel::update(&state) .filter(dsl::cursor.lt(cursor)) .set(dsl::cursor.eq(cursor)) diff --git a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs index 4da0c66dd..ed534ce79 100644 --- a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs +++ b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs @@ -37,7 +37,7 @@ impl<'a> From<&'a StoredUserPreferences> for NewStoredUserPreferences<'a> { impl Store for StoredUserPreferences { fn store(&self, conn: &DbConnection) -> Result<(), StorageError> { - conn.raw_query(true, |conn| { + conn.raw_query_write( |conn| { diesel::update(dsl::user_preferences) .set(self) .execute(conn) @@ -50,7 +50,7 @@ impl Store for StoredUserPreferences { impl StoredUserPreferences { pub fn load(conn: &DbConnection) -> Result { let query = dsl::user_preferences.order(dsl::id.desc()).limit(1); - let mut result = conn.raw_query(false, |conn| query.load::(conn))?; + let mut result = conn.raw_query_read( |conn| query.load::(conn))?; Ok(result.pop().unwrap_or_default()) } @@ -73,7 +73,7 @@ impl StoredUserPreferences { ])); let to_insert: NewStoredUserPreferences = (&preferences).into(); - conn.raw_query(true, |conn| { + conn.raw_query_write( |conn| { diesel::insert_into(dsl::user_preferences) .values(to_insert) .execute(conn) @@ -115,7 +115,7 @@ mod tests { // check that there are two preferences stored let query = dsl::user_preferences.order(dsl::id.desc()); let result = conn - .raw_query(false, |conn| query.load::(conn)) + .raw_query_read( |conn| query.load::(conn)) .unwrap(); assert_eq!(result.len(), 1); } diff --git a/xmtp_mls/src/storage/mod.rs b/xmtp_mls/src/storage/mod.rs index 2945f0656..05ebfbbec 100644 --- a/xmtp_mls/src/storage/mod.rs +++ b/xmtp_mls/src/storage/mod.rs @@ -57,12 +57,12 @@ pub mod test_util { for query in queries { let query = diesel::sql_query(query); - let _ = self.raw_query(true, |conn| query.execute(conn)).unwrap(); + let _ = self.raw_query_write(|conn| query.execute(conn)).unwrap(); } } pub fn intents_published(&self) -> i32 { - self.raw_query(false, |conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query( "SELECT intents_published FROM test_metadata WHERE rowid = 1", @@ -78,7 +78,7 @@ pub mod test_util { } pub fn intents_deleted(&self) -> i32 { - self.raw_query(false, |conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query("SELECT intents_deleted FROM test_metadata")) .unwrap(); @@ -92,7 +92,7 @@ pub mod test_util { } pub fn intents_created(&self) -> i32 { - self.raw_query(false, |conn| { + self.raw_query_read(|conn| { let mut row = conn .load(sql_query("SELECT intents_created FROM test_metadata")) .unwrap(); diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index b0a81299b..c9c2b40df 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -49,7 +49,7 @@ where &self, storage_key: &Vec, ) -> Result, diesel::result::Error> { - self.conn_ref().raw_query(false, |conn| { + self.conn_ref().raw_query_read(|conn| { sql_query(SELECT_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -62,7 +62,7 @@ where storage_key: &Vec, value: &[u8], ) -> Result { - self.conn_ref().raw_query(true, |conn| { + self.conn_ref().raw_query_write(|conn| { sql_query(REPLACE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -76,7 +76,7 @@ where storage_key: &Vec, modified_data: &Vec, ) -> Result { - self.conn_ref().raw_query(true, |conn| { + self.conn_ref().raw_query_write(|conn| { sql_query(UPDATE_QUERY) .bind::(&modified_data) .bind::(&storage_key) @@ -224,7 +224,7 @@ where ) -> Result<(), >::Error> { let storage_key = build_key_from_vec::(label, key.to_vec()); - let _ = self.conn_ref().raw_query(true, |conn| { + let _ = self.conn_ref().raw_query_write(|conn| { sql_query(DELETE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -809,7 +809,7 @@ where let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; - let data: Vec = self.conn_ref().raw_query(false, |conn| { + let data: Vec = self.conn_ref().raw_query_read(|conn| { sql_query(query) .bind::(&storage_key) .bind::(CURRENT_VERSION as i32) From 0282c311537f13b06a497a530246cdff410534c2 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 16 Jan 2025 17:14:31 -0500 Subject: [PATCH 11/38] cleanup --- xmtp_mls/src/client.rs | 2 ++ xmtp_mls/src/storage/encrypted_store/mod.rs | 24 ++++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 0d43b4eb0..ec00a0d4b 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1424,6 +1424,7 @@ pub(crate) mod tests { not(target_arch = "wasm32"), tokio::test(flavor = "multi_thread", worker_threads = 1) )] + #[ignore] async fn test_add_remove_then_add_again() { let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -1445,6 +1446,7 @@ pub(crate) mod tests { .unwrap(); assert_eq!(amal_group.members().await.unwrap().len(), 1); tracing::info!("Syncing bolas welcomes"); + // See if Bola can see that they were added to the group bola.sync_welcomes(&bola.mls_provider().unwrap()) .await diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index c16ae8b16..bc6444b22 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -179,7 +179,7 @@ pub mod private { #[tracing::instrument(level = "trace", skip_all)] pub(super) fn init_db(&mut self) -> Result<(), StorageError> { self.db.validate(&self.opts)?; - self.db.conn()?.raw_query_write( |conn| { + self.db.conn()?.raw_query_write(|conn| { conn.batch_execute("PRAGMA journal_mode = WAL;")?; tracing::info!("Running DB migrations"); conn.run_pending_migrations(MIGRATIONS)?; @@ -305,7 +305,7 @@ macro_rules! impl_store { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query_write( |conn| { + into.raw_query_write(|conn| { diesel::insert_into($table::table) .values(self) .execute(conn) @@ -327,7 +327,7 @@ macro_rules! impl_store_or_ignore { &self, into: &$crate::storage::encrypted_store::db_connection::DbConnection, ) -> Result<(), $crate::StorageError> { - into.raw_query_write( |conn| { + into.raw_query_write(|conn| { diesel::insert_or_ignore_into($table::table) .values(self) .execute(conn) @@ -404,7 +404,7 @@ where match fun(self) { Ok(value) => { - conn.raw_query_write( |conn| { + conn.raw_query_write(|conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; tracing::debug!("Transaction being committed"); @@ -412,7 +412,7 @@ where } Err(err) => { tracing::debug!("Transaction being rolled back"); - match conn.raw_query_write( |conn| { + match conn.raw_query_write(|conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -444,7 +444,7 @@ where E: From + From, Db: 'a, { - tracing::debug!("Transaction async beginning"); + tracing::info!("Transaction async beginning"); let _guard = { let wrapper = self.conn_ref(); let mut connection = wrapper.write_mut_ref(); @@ -472,15 +472,15 @@ where DbConnectionPrivate::from_arc_mutex(local_read_connection, local_write_connection); match result { Ok(value) => { - local_connection.raw_query_write( |conn| { + local_connection.raw_query_write(|conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; - tracing::debug!("Transaction async being committed"); + tracing::info!("Transaction async being committed"); Ok(value) } Err(err) => { - tracing::debug!("Transaction async being rolled back"); - match local_connection.raw_query_write( |conn| { + tracing::info!("Transaction async being rolled back"); + match local_connection.raw_query_write(|conn| { ::TransactionManager::rollback_transaction(&mut *conn) }) { Ok(()) => Err(err), @@ -630,7 +630,7 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query_write( |conn| { + .raw_query_write(|conn| { for _ in 0..15 { conn.run_next_migration(MIGRATIONS)?; } @@ -676,7 +676,7 @@ pub(crate) mod tests { .db .conn() .unwrap() - .raw_query_write( |conn| { + .raw_query_write(|conn| { conn.run_pending_migrations(MIGRATIONS)?; Ok::<_, StorageError>(()) }) From 5b8cd12b75dfe240b15b8fe23098239343ce0252 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 23 Jan 2025 11:55:25 -0500 Subject: [PATCH 12/38] test --- .../storage/encrypted_store/db_connection.rs | 22 +++++++++++-------- xmtp_mls/src/storage/encrypted_store/mod.rs | 16 +++++--------- .../src/storage/encrypted_store/native.rs | 1 + 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 8fe640d76..cac1f02cc 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -1,6 +1,6 @@ use parking_lot::Mutex; use std::fmt; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; use crate::storage::xmtp_openmls_provider::XmtpOpenMlsProvider; @@ -22,17 +22,21 @@ pub type DbConnection = DbConnectionPrivate { read: Arc>, write: Option>>, - pub(super) in_transaction: Arc, + pub(super) transaction_count: Arc, } /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(read: Arc>, write: Option>>) -> Self { + pub(super) fn from_arc_mutex( + read: Arc>, + write: Option>>, + is_transaction: bool, + ) -> Self { Self { read, write, - in_transaction: Arc::new(AtomicBool::new(false)), + transaction_count: Arc::new(AtomicI32::new(is_transaction as i32)), } } } @@ -42,13 +46,13 @@ where C: diesel::Connection, { fn in_transaction(&self) -> bool { - self.in_transaction.load(Ordering::SeqCst) + self.transaction_count.load(Ordering::SeqCst) == 0 } pub(crate) fn start_transaction(&self) -> TransactionGuard { - self.in_transaction.store(true, Ordering::SeqCst); + self.transaction_count.fetch_add(1, Ordering::SeqCst); TransactionGuard { - in_transaction: self.in_transaction.clone(), + transaction_count: self.transaction_count.clone(), } } @@ -139,10 +143,10 @@ impl fmt::Debug for DbConnectionPrivate { } pub struct TransactionGuard { - in_transaction: Arc, + transaction_count: Arc, } impl Drop for TransactionGuard { fn drop(&mut self) { - self.in_transaction.store(false, Ordering::SeqCst); + self.transaction_count.fetch_add(-1, Ordering::SeqCst); } } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index bc6444b22..3f512c856 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -456,20 +456,14 @@ where let result = fun(self).await; let local_read_connection = self.conn_ref().read_ref(); let local_write_connection = self.conn_ref().write_ref(); - if Arc::strong_count(&local_read_connection) > 1 { - tracing::warn!( - "More than 1 strong connection references still exist during async transaction" - ); - } - - if Arc::weak_count(&local_read_connection) > 1 { - tracing::warn!("More than 1 weak connection references still exist during transaction"); - } // after the closure finishes, `local_provider` should have the only reference ('strong') // to `XmtpOpenMlsProvider` inner `DbConnection`.. - let local_connection = - DbConnectionPrivate::from_arc_mutex(local_read_connection, local_write_connection); + let local_connection = DbConnectionPrivate::from_arc_mutex( + local_read_connection, + local_write_connection, + true, + ); match result { Ok(value) => { local_connection.raw_query_write(|conn| { diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 8644cee09..f55a5bb16 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -169,6 +169,7 @@ impl XmtpDb for NativeDb { Ok(DbConnectionPrivate::from_arc_mutex( Arc::new(parking_lot::Mutex::new(conn)), self.write_conn.clone(), + false, )) } From d3fa5b90461e6c459da46ab4542b961472dea4d9 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 23 Jan 2025 12:01:24 -0500 Subject: [PATCH 13/38] funnel --- .../src/storage/encrypted_store/db_connection.rs | 9 ++++++++- xmtp_mls/src/storage/encrypted_store/mod.rs | 15 ++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index cac1f02cc..5e06b14d9 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -46,7 +46,7 @@ where C: diesel::Connection, { fn in_transaction(&self) -> bool { - self.transaction_count.load(Ordering::SeqCst) == 0 + self.transaction_count.load(Ordering::SeqCst) != 0 } pub(crate) fn start_transaction(&self) -> TransactionGuard { @@ -62,6 +62,13 @@ where where F: FnOnce(&mut C) -> Result, { + if self.in_transaction() { + if let Some(write) = &self.write { + let mut lock = write.lock(); + return fun(&mut lock); + }; + } + let mut lock = self.read.lock(); fun(&mut lock) } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 3f512c856..51ae82157 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -57,7 +57,7 @@ use diesel::{ sql_query, }; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; -use std::sync::Arc; +use std::sync::{atomic::Ordering, Arc}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations/"); @@ -454,21 +454,14 @@ where // ensuring we have only one strong reference let result = fun(self).await; - let local_read_connection = self.conn_ref().read_ref(); - let local_write_connection = self.conn_ref().write_ref(); - - // after the closure finishes, `local_provider` should have the only reference ('strong') - // to `XmtpOpenMlsProvider` inner `DbConnection`.. - let local_connection = DbConnectionPrivate::from_arc_mutex( - local_read_connection, - local_write_connection, - true, - ); + + let local_connection = self.conn_ref(); match result { Ok(value) => { local_connection.raw_query_write(|conn| { ::TransactionManager::commit_transaction(&mut *conn) })?; + tracing::info!("Transaction async being committed"); Ok(value) } From 139ae1b83de6c16eae7e8a5ace8d477a4459a56b Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 23 Jan 2025 12:36:57 -0500 Subject: [PATCH 14/38] undo funneling in the raw query read --- .../src/storage/encrypted_store/db_connection.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 5e06b14d9..377659920 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -62,12 +62,12 @@ where where F: FnOnce(&mut C) -> Result, { - if self.in_transaction() { - if let Some(write) = &self.write { - let mut lock = write.lock(); - return fun(&mut lock); - }; - } + // if self.in_transaction() { + // if let Some(write) = &self.write { + // let mut lock = write.lock(); + // return fun(&mut lock); + // }; + // } let mut lock = self.read.lock(); fun(&mut lock) From a4a521c3b599044be7b8983434a15695eb92f742 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 23 Jan 2025 12:46:23 -0500 Subject: [PATCH 15/38] cleanup --- .../src/storage/encrypted_store/db_connection.rs | 12 ++++++------ xmtp_mls/src/storage/encrypted_store/mod.rs | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 377659920..5e06b14d9 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -62,12 +62,12 @@ where where F: FnOnce(&mut C) -> Result, { - // if self.in_transaction() { - // if let Some(write) = &self.write { - // let mut lock = write.lock(); - // return fun(&mut lock); - // }; - // } + if self.in_transaction() { + if let Some(write) = &self.write { + let mut lock = write.lock(); + return fun(&mut lock); + }; + } let mut lock = self.read.lock(); fun(&mut lock) diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 51ae82157..4b772bf57 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -57,7 +57,6 @@ use diesel::{ sql_query, }; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; -use std::sync::{atomic::Ordering, Arc}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations/"); @@ -486,6 +485,7 @@ pub(crate) mod tests { use diesel::sql_types::{BigInt, Blob, Integer, Text}; use group::ConversationType; use schema::groups; + use std::sync::Arc; use wasm_bindgen_test::wasm_bindgen_test; use super::*; From 980c66123daa2ac51ac8229aa503cc8575dc0aca Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Thu, 23 Jan 2025 15:31:34 -0500 Subject: [PATCH 16/38] clone --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 5e06b14d9..88130aa96 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -25,6 +25,16 @@ pub struct DbConnectionPrivate { pub(super) transaction_count: Arc, } +impl Clone for DbConnectionPrivate { + fn clone(&self) -> Self { + Self { + read: self.read.clone(), + write: self.write.clone(), + transaction_count: self.transaction_count.clone(), + } + } +} + /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> From 58af8685d1c7e8d591fd47ec565160c80a6b586a Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 11:53:44 -0500 Subject: [PATCH 17/38] lint --- xmtp_mls/src/storage/encrypted_store/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 3d6e75a2d..45854fa53 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -423,7 +423,6 @@ pub(crate) mod tests { use diesel::sql_types::{BigInt, Blob, Integer, Text}; use group::ConversationType; use schema::groups; - use std::sync::Arc; use wasm_bindgen_test::wasm_bindgen_test; use super::*; From dc76452fcb06302a47b7e267f30968ec1f2cfb6d Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 12:41:15 -0500 Subject: [PATCH 18/38] cleanup --- .../storage/encrypted_store/db_connection.rs | 81 +++++++++---------- xmtp_mls/src/storage/encrypted_store/mod.rs | 8 +- .../src/storage/encrypted_store/native.rs | 1 - 3 files changed, 37 insertions(+), 53 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 64a3b061a..038c92d0d 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -1,9 +1,15 @@ +use crate::storage::{xmtp_openmls_provider::XmtpOpenMlsProvider, StorageError}; +use diesel::connection::TransactionManager; use parking_lot::Mutex; -use std::fmt; -use std::sync::atomic::{AtomicI32, Ordering}; -use std::sync::Arc; +use std::{ + fmt, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; -use crate::storage::xmtp_openmls_provider::XmtpOpenMlsProvider; +use super::XmtpDb; #[cfg(not(target_arch = "wasm32"))] pub type DbConnection = DbConnectionPrivate; @@ -22,7 +28,8 @@ pub type DbConnection = DbConnectionPrivate { read: Arc>, write: Option>>, - pub(super) transaction_count: Arc, + // This field will funnel all reads / writes to the write connection if true. + pub(super) in_transaction: Arc, } impl Clone for DbConnectionPrivate { @@ -30,7 +37,7 @@ impl Clone for DbConnectionPrivate { Self { read: self.read.clone(), write: self.write.clone(), - transaction_count: self.transaction_count.clone(), + in_transaction: self.in_transaction.clone(), } } } @@ -38,15 +45,11 @@ impl Clone for DbConnectionPrivate { /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex( - read: Arc>, - write: Option>>, - is_transaction: bool, - ) -> Self { + pub(super) fn from_arc_mutex(read: Arc>, write: Option>>) -> Self { Self { read, write, - transaction_count: Arc::new(AtomicI32::new(is_transaction as i32)), + in_transaction: Arc::new(AtomicBool::new(false)), } } } @@ -55,15 +58,27 @@ impl DbConnectionPrivate where C: diesel::Connection, { - fn in_transaction(&self) -> bool { - self.transaction_count.load(Ordering::SeqCst) != 0 + pub(crate) fn start_transaction>( + &self, + ) -> Result { + let mut write = self + .write + .as_ref() + .expect("Tried to open transaction on read-only connection") + .lock(); + ::TransactionManager::begin_transaction(&mut *write)?; + + if self.in_transaction.swap(true, Ordering::SeqCst) { + panic!("Already in transaction."); + } + + Ok(TransactionGuard { + in_transaction: self.in_transaction.clone(), + }) } - pub(crate) fn start_transaction(&self) -> TransactionGuard { - self.transaction_count.fetch_add(1, Ordering::SeqCst); - TransactionGuard { - transaction_count: self.transaction_count.clone(), - } + fn in_transaction(&self) -> bool { + self.in_transaction.load(Ordering::SeqCst) } /// Do a scoped query with a mutable [`diesel::Connection`] @@ -97,30 +112,6 @@ where let mut lock = self.read.lock(); fun(&mut lock) } - - /// Internal-only API to get the underlying `diesel::Connection` reference - /// without a scope - /// Must be used with care. holding this reference while calling `raw_query` - /// will cause a deadlock. - pub(super) fn read_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { - if self.in_transaction() { - if let Some(write) = &self.write { - return write.lock(); - } - } - self.read.lock() - } - - /// Internal-only API to get the underlying `diesel::Connection` reference - /// without a scope - /// Must be used with care. holding this reference while calling `raw_query` - /// will cause a deadlock. - pub(super) fn write_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> { - let Some(write) = &self.write else { - return self.read_mut_ref(); - }; - write.lock() - } } // Forces a move for conn @@ -143,10 +134,10 @@ impl fmt::Debug for DbConnectionPrivate { } pub struct TransactionGuard { - transaction_count: Arc, + in_transaction: Arc, } impl Drop for TransactionGuard { fn drop(&mut self) { - self.transaction_count.fetch_add(-1, Ordering::SeqCst); + self.in_transaction.store(false, Ordering::SeqCst); } } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 45854fa53..f138cca75 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -385,14 +385,8 @@ where { tracing::debug!("Transaction beginning"); - let _guard = { - let wrapper = self.conn_ref(); - let mut connection = wrapper.write_mut_ref(); - ::TransactionManager::begin_transaction(&mut *connection)?; - wrapper.start_transaction() - }; - let conn = self.conn_ref(); + let _guard = conn.start_transaction::()?; match fun(self) { Ok(value) => { diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index f55a5bb16..8644cee09 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -169,7 +169,6 @@ impl XmtpDb for NativeDb { Ok(DbConnectionPrivate::from_arc_mutex( Arc::new(parking_lot::Mutex::new(conn)), self.write_conn.clone(), - false, )) } From 7971689ff3e0759e162f6d6aaab4db3882d781a2 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 12:47:34 -0500 Subject: [PATCH 19/38] fix wasm --- xmtp_mls/src/groups/mod.rs | 4 ++-- xmtp_mls/src/storage/encrypted_store/wasm.rs | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index c51857a85..6f4c1ba8d 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -2250,7 +2250,7 @@ pub(crate) mod tests { // The dm shows up let alix_groups = alix_conn - .raw_query_read( |conn| groups::table.load::(conn)) + .raw_query_read(|conn| groups::table.load::(conn)) .unwrap(); assert_eq!(alix_groups.len(), 2); // They should have the same ID @@ -3865,7 +3865,7 @@ pub(crate) mod tests { let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into(); let conn_2 = bo.store().conn().unwrap(); conn_2 - .raw_query_read( |c| { + .raw_query_write(|c| { c.batch_execute("BEGIN EXCLUSIVE").unwrap(); Ok::<_, diesel::result::Error>(()) }) diff --git a/xmtp_mls/src/storage/encrypted_store/wasm.rs b/xmtp_mls/src/storage/encrypted_store/wasm.rs index 3cc984fde..ab73410db 100644 --- a/xmtp_mls/src/storage/encrypted_store/wasm.rs +++ b/xmtp_mls/src/storage/encrypted_store/wasm.rs @@ -42,7 +42,10 @@ impl XmtpDb for WasmDb { type TransactionManager = AnsiTransactionManager; fn conn(&self) -> Result, StorageError> { - Ok(DbConnectionPrivate::from_arc_mutex(self.conn.clone())) + Ok(DbConnectionPrivate::from_arc_mutex( + self.conn.clone(), + Some(self.conn.clone()), + )) } fn validate(&self, _opts: &StorageOption) -> Result<(), StorageError> { From 1eabf9522678093472c841c3bd15fbaf79b3ab87 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 12:52:44 -0500 Subject: [PATCH 20/38] lint --- xmtp_mls/src/groups/mod.rs | 2 +- .../encrypted_store/association_state.rs | 5 ++--- .../storage/encrypted_store/consent_record.rs | 6 ++--- .../storage/encrypted_store/group_intent.rs | 22 +++++++++---------- .../encrypted_store/identity_update.rs | 8 +++---- .../storage/encrypted_store/refresh_state.rs | 4 ++-- .../encrypted_store/user_preferences.rs | 8 +++---- 7 files changed, 27 insertions(+), 28 deletions(-) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 6f4c1ba8d..a1963d8f0 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -3873,7 +3873,7 @@ pub(crate) mod tests { let process_result = bo_group.process_messages(bo_messages, &conn_1).await; if let Some(GroupError::ReceiveErrors(errors)) = process_result.err() { - assert_eq!(errors.len(), 2); + assert_eq!(errors.len(), 1); assert!(errors .iter() .any(|err| err.to_string().contains("database is locked"))); diff --git a/xmtp_mls/src/storage/encrypted_store/association_state.rs b/xmtp_mls/src/storage/encrypted_store/association_state.rs index 363c2f970..9f7583b0f 100644 --- a/xmtp_mls/src/storage/encrypted_store/association_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/association_state.rs @@ -108,9 +108,8 @@ impl StoredAssociationState { .and(dsl::sequence_id.eq_any(sequence_ids)), ); - let association_states = conn.raw_query_read( |query_conn| { - query.load::(query_conn) - })?; + let association_states = + conn.raw_query_read(|query_conn| query.load::(query_conn))?; association_states .into_iter() diff --git a/xmtp_mls/src/storage/encrypted_store/consent_record.rs b/xmtp_mls/src/storage/encrypted_store/consent_record.rs index b70552a3b..c70359473 100644 --- a/xmtp_mls/src/storage/encrypted_store/consent_record.rs +++ b/xmtp_mls/src/storage/encrypted_store/consent_record.rs @@ -48,7 +48,7 @@ impl DbConnection { entity: String, entity_type: ConsentType, ) -> Result, StorageError> { - Ok(self.raw_query_read( |conn| -> diesel::QueryResult<_> { + Ok(self.raw_query_read(|conn| -> diesel::QueryResult<_> { dsl::consent_records .filter(dsl::entity.eq(entity)) .filter(dsl::entity_type.eq(entity_type)) @@ -77,7 +77,7 @@ impl DbConnection { ); } - let changed = self.raw_query_write( |conn| -> diesel::QueryResult<_> { + let changed = self.raw_query_write(|conn| -> diesel::QueryResult<_> { let existing: Vec = query.load(conn)?; let changed: Vec<_> = records .iter() @@ -107,7 +107,7 @@ impl DbConnection { &self, record: &StoredConsentRecord, ) -> Result, StorageError> { - self.raw_query_write( |conn| { + self.raw_query_write(|conn| { let maybe_inserted_consent_record: Option = diesel::insert_into(dsl::consent_records) .values(record) diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 464f7b9f6..a02d99324 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -119,7 +119,7 @@ impl_fetch!(StoredGroupIntent, group_intents, ID); impl Delete for DbConnection { type Key = ID; fn delete(&self, key: ID) -> Result { - Ok(self.raw_query_write( |raw_conn| { + Ok(self.raw_query_write(|raw_conn| { diesel::delete(dsl::group_intents.find(key)).execute(raw_conn) })?) } @@ -156,7 +156,7 @@ impl DbConnection { &self, to_save: NewGroupIntent, ) -> Result { - Ok(self.raw_query_write( |conn| { + Ok(self.raw_query_write(|conn| { diesel::insert_into(dsl::group_intents) .values(to_save) .get_result(conn) @@ -185,7 +185,7 @@ impl DbConnection { query = query.order(dsl::id.asc()); - Ok(self.raw_query_read( |conn| query.load::(conn))?) + Ok(self.raw_query_read(|conn| query.load::(conn))?) } // Set the intent with the given ID to `Published` and set the payload hash. Optionally add @@ -198,7 +198,7 @@ impl DbConnection { staged_commit: Option>, published_in_epoch: i64, ) -> Result<(), StorageError> { - let rows_changed = self.raw_query_write( |conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Published is from @@ -215,7 +215,7 @@ impl DbConnection { })?; if rows_changed == 0 { - let already_published = self.raw_query_read( |conn| { + let already_published = self.raw_query_read(|conn| { dsl::group_intents .filter(dsl::id.eq(intent_id)) .first::(conn) @@ -232,7 +232,7 @@ impl DbConnection { // Set the intent with the given ID to `Committed` pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query_write( |conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Committed is from @@ -253,7 +253,7 @@ impl DbConnection { // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query_write( |conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to ToPublish is from @@ -279,7 +279,7 @@ impl DbConnection { /// Set the intent with the given ID to `Error` #[tracing::instrument(level = "trace", skip(self))] pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> { - let rows_changed = self.raw_query_write( |conn| { + let rows_changed = self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::state.eq(IntentState::Error)) @@ -299,7 +299,7 @@ impl DbConnection { &self, payload_hash: Vec, ) -> Result, StorageError> { - let result = self.raw_query_read( |conn| { + let result = self.raw_query_read(|conn| { dsl::group_intents .filter(dsl::payload_hash.eq(payload_hash)) .first::(conn) @@ -313,7 +313,7 @@ impl DbConnection { &self, intent_id: ID, ) -> Result<(), StorageError> { - self.raw_query_write( |conn| { + self.raw_query_write(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::publish_attempts.eq(dsl::publish_attempts + 1)) @@ -432,7 +432,7 @@ pub(crate) mod tests { } fn find_first_intent(conn: &DbConnection, group_id: group::ID) -> StoredGroupIntent { - conn.raw_query_read( |raw_conn| { + conn.raw_query_read(|raw_conn| { dsl::group_intents .filter(dsl::group_id.eq(group_id)) .first(raw_conn) diff --git a/xmtp_mls/src/storage/encrypted_store/identity_update.rs b/xmtp_mls/src/storage/encrypted_store/identity_update.rs index bb37dd774..2c1e8749a 100644 --- a/xmtp_mls/src/storage/encrypted_store/identity_update.rs +++ b/xmtp_mls/src/storage/encrypted_store/identity_update.rs @@ -72,7 +72,7 @@ impl DbConnection { query = query.filter(dsl::sequence_id.le(sequence_id)); } - Ok(self.raw_query_read( |conn| query.load::(conn))?) + Ok(self.raw_query_read(|conn| query.load::(conn))?) } /// Batch insert identity updates, ignoring duplicates. @@ -81,7 +81,7 @@ impl DbConnection { &self, updates: &[StoredIdentityUpdate], ) -> Result<(), StorageError> { - self.raw_query_write( |conn| { + self.raw_query_write(|conn| { diesel::insert_or_ignore_into(dsl::identity_updates) .values(updates) .execute(conn)?; @@ -98,7 +98,7 @@ impl DbConnection { .filter(dsl::inbox_id.eq(inbox_id)) .into_boxed(); - Ok(self.raw_query_read( |conn| query.first::(conn))?) + Ok(self.raw_query_read(|conn| query.first::(conn))?) } /// Given a list of inbox_ids return a HashMap of each inbox ID -> highest known sequence ID @@ -115,7 +115,7 @@ impl DbConnection { // Get the results as a Vec of (inbox_id, sequence_id) tuples let result_tuples: Vec<(String, i64)> = self - .raw_query_read( |conn| query.load::<(String, Option)>(conn))? + .raw_query_read(|conn| query.load::<(String, Option)>(conn))? .into_iter() // Diesel needs an Option type for aggregations like max(sequence_id), so we // unwrap the option here diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index e3acefa3f..248248c8d 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -74,7 +74,7 @@ impl DbConnection { entity_kind: EntityKind, ) -> Result, StorageError> { use super::schema::refresh_state::dsl; - let res = self.raw_query_read( |conn| { + let res = self.raw_query_read(|conn| { dsl::refresh_state .find((entity_id.as_ref(), entity_kind)) .first(conn) @@ -115,7 +115,7 @@ impl DbConnection { NotFound::RefreshStateByIdAndKind(entity_id.as_ref().to_vec(), entity_kind), )?; - let num_updated = self.raw_query_write( |conn| { + let num_updated = self.raw_query_write(|conn| { diesel::update(&state) .filter(dsl::cursor.lt(cursor)) .set(dsl::cursor.eq(cursor)) diff --git a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs index 94aba271d..4a9be10c9 100644 --- a/xmtp_mls/src/storage/encrypted_store/user_preferences.rs +++ b/xmtp_mls/src/storage/encrypted_store/user_preferences.rs @@ -37,7 +37,7 @@ impl<'a> From<&'a StoredUserPreferences> for NewStoredUserPreferences<'a> { impl Store for StoredUserPreferences { fn store(&self, conn: &DbConnection) -> Result<(), StorageError> { - conn.raw_query_write( |conn| { + conn.raw_query_write(|conn| { diesel::update(dsl::user_preferences) .set(self) .execute(conn) @@ -50,7 +50,7 @@ impl Store for StoredUserPreferences { impl StoredUserPreferences { pub fn load(conn: &DbConnection) -> Result { let query = dsl::user_preferences.order(dsl::id.desc()).limit(1); - let mut result = conn.raw_query_read( |conn| query.load::(conn))?; + let mut result = conn.raw_query_read(|conn| query.load::(conn))?; Ok(result.pop().unwrap_or_default()) } @@ -73,7 +73,7 @@ impl StoredUserPreferences { ])); let to_insert: NewStoredUserPreferences = (&preferences).into(); - conn.raw_query_write( |conn| { + conn.raw_query_write(|conn| { diesel::insert_into(dsl::user_preferences) .values(to_insert) .execute(conn) @@ -115,7 +115,7 @@ mod tests { // check that there are two preferences stored let query = dsl::user_preferences.order(dsl::id.desc()); let result = conn - .raw_query_read( |conn| query.load::(conn)) + .raw_query_read(|conn| query.load::(conn)) .unwrap(); assert_eq!(result.len(), 1); } From 853db39b8a6ce872d17fdbd0584402ee33db2a2c Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:06:55 -0500 Subject: [PATCH 21/38] test cleanup --- xmtp_mls/src/groups/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index a1963d8f0..b7575796b 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -3874,9 +3874,9 @@ pub(crate) mod tests { let process_result = bo_group.process_messages(bo_messages, &conn_1).await; if let Some(GroupError::ReceiveErrors(errors)) = process_result.err() { assert_eq!(errors.len(), 1); - assert!(errors - .iter() - .any(|err| err.to_string().contains("database is locked"))); + assert!(errors.iter().any(|err| err + .to_string() + .contains("cannot start a transaction within a transaction"))); } else { panic!("Expected error") } From 446354ad1fc75521f043891de0ca1172273a9651 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:11:53 -0500 Subject: [PATCH 22/38] add comment --- xmtp_mls/src/storage/encrypted_store/native.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 8644cee09..81dc7b9f9 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -127,6 +127,7 @@ impl NativeDb { }; let write_conn = if matches!(opts, StorageOption::Persistent(_)) { + // Take one of the connections and use it as the only writer. let mut write_conn = pool.get()?; write_conn.batch_execute("PRAGMA query_only = OFF;")?; Some(Arc::new(Mutex::new(write_conn))) From 6afe56e47f312d699bc8c95677791aef476b5799 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:14:04 -0500 Subject: [PATCH 23/38] cleanup --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 038c92d0d..765849092 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -25,6 +25,7 @@ pub type DbConnection = DbConnectionPrivate { read: Arc>, write: Option>>, @@ -32,16 +33,6 @@ pub struct DbConnectionPrivate { pub(super) in_transaction: Arc, } -impl Clone for DbConnectionPrivate { - fn clone(&self) -> Self { - Self { - read: self.read.clone(), - write: self.write.clone(), - in_transaction: self.in_transaction.clone(), - } - } -} - /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> From 1f0a2f81ba4628d76e3e01af3eb6cb72ac697675 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:17:48 -0500 Subject: [PATCH 24/38] undo ignore --- xmtp_mls/src/client.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 37a062c21..90a920355 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1403,7 +1403,6 @@ pub(crate) mod tests { not(target_arch = "wasm32"), tokio::test(flavor = "multi_thread", worker_threads = 1) )] - #[ignore] async fn test_add_remove_then_add_again() { let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; From b75faedb017ce4d25fa09b016041ef7be04acd28 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:19:55 -0500 Subject: [PATCH 25/38] should never happen, but lets handle gracefully anyway --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 2 +- xmtp_mls/src/storage/errors.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 765849092..d4621a25a 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -60,7 +60,7 @@ where ::TransactionManager::begin_transaction(&mut *write)?; if self.in_transaction.swap(true, Ordering::SeqCst) { - panic!("Already in transaction."); + return Err(StorageError::AlreadyInTransaction); } Ok(TransactionGuard { diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index 35d433af3..04f17e94b 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -50,6 +50,8 @@ pub enum StorageError { Duplicate(DuplicateItem), #[error(transparent)] OpenMlsStorage(#[from] SqlKeyStoreError), + #[error("Connection is already marked as in transaction")] + AlreadyInTransaction, #[error("Transaction was intentionally rolled back")] IntentionalRollback, } From de2fa0223813f1e557e1b636cef5173cf3e26a17 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:38:56 -0500 Subject: [PATCH 26/38] bump the ephemeral connection count to 2 --- .../storage/encrypted_store/db_connection.rs | 23 +++++-------------- .../src/storage/encrypted_store/native.rs | 17 +++++--------- 2 files changed, 12 insertions(+), 28 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index d4621a25a..59b8e541f 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -28,7 +28,7 @@ pub type DbConnection = DbConnectionPrivate { read: Arc>, - write: Option>>, + write: Arc>, // This field will funnel all reads / writes to the write connection if true. pub(super) in_transaction: Arc, } @@ -36,7 +36,7 @@ pub struct DbConnectionPrivate { /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(read: Arc>, write: Option>>) -> Self { + pub(super) fn from_arc_mutex(read: Arc>, write: Arc>) -> Self { Self { read, write, @@ -52,11 +52,7 @@ where pub(crate) fn start_transaction>( &self, ) -> Result { - let mut write = self - .write - .as_ref() - .expect("Tried to open transaction on read-only connection") - .lock(); + let mut write = self.write.lock(); ::TransactionManager::begin_transaction(&mut *write)?; if self.in_transaction.swap(true, Ordering::SeqCst) { @@ -79,10 +75,8 @@ where F: FnOnce(&mut C) -> Result, { if self.in_transaction() { - if let Some(write) = &self.write { - let mut lock = write.lock(); - return fun(&mut lock); - }; + let mut lock = self.write.lock(); + return fun(&mut lock); } let mut lock = self.read.lock(); @@ -95,12 +89,7 @@ where where F: FnOnce(&mut C) -> Result, { - if let Some(write_conn) = &self.write { - let mut lock = write_conn.lock(); - return fun(&mut lock); - } - - let mut lock = self.read.lock(); + let mut lock = self.write.lock(); fun(&mut lock) } } diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 81dc7b9f9..f9ccac8bb 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -92,7 +92,7 @@ impl StorageOption { #[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { - pub(super) write_conn: Option>>, + pub(super) write_conn: Arc>, pub(super) pool: Arc>>, customizer: Option>, opts: StorageOption, @@ -119,24 +119,19 @@ impl NativeDb { let pool = match opts { StorageOption::Ephemeral => builder - .max_size(1) + .max_size(2) .build(ConnectionManager::new(":memory:"))?, StorageOption::Persistent(ref path) => builder .max_size(crate::configuration::MAX_DB_POOL_SIZE) .build(ConnectionManager::new(path))?, }; - let write_conn = if matches!(opts, StorageOption::Persistent(_)) { - // Take one of the connections and use it as the only writer. - let mut write_conn = pool.get()?; - write_conn.batch_execute("PRAGMA query_only = OFF;")?; - Some(Arc::new(Mutex::new(write_conn))) - } else { - None - }; + // Take one of the connections and use it as the only writer. + let mut write_conn = pool.get()?; + write_conn.batch_execute("PRAGMA query_only = OFF;")?; Ok(Self { - write_conn, + write_conn: Arc::new(Mutex::new(write_conn)), pool: Arc::new(Some(pool).into()), customizer, opts: opts.clone(), From 17a8d804ff8fee4464cb7db71911cdbf7e4142c5 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:39:56 -0500 Subject: [PATCH 27/38] cleanup --- xmtp_mls/src/groups/mod.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 50c34ac7c..1cd041861 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -933,8 +933,6 @@ impl MlsGroup { intent_data.into(), )?; - tracing::warn!("This makes it here?"); - self.sync_until_intent_resolved(provider, intent.id).await } From 484dafd210f42cbcca05ba0748cc6acfb99cdd98 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 13:43:01 -0500 Subject: [PATCH 28/38] cleanup wasm --- xmtp_mls/src/storage/encrypted_store/wasm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xmtp_mls/src/storage/encrypted_store/wasm.rs b/xmtp_mls/src/storage/encrypted_store/wasm.rs index ab73410db..9cf63e3d0 100644 --- a/xmtp_mls/src/storage/encrypted_store/wasm.rs +++ b/xmtp_mls/src/storage/encrypted_store/wasm.rs @@ -44,7 +44,7 @@ impl XmtpDb for WasmDb { fn conn(&self) -> Result, StorageError> { Ok(DbConnectionPrivate::from_arc_mutex( self.conn.clone(), - Some(self.conn.clone()), + self.conn.clone(), )) } From 9741ef873694932ddfe100d5bc062e257bbc2702 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 14:28:14 -0500 Subject: [PATCH 29/38] test fix --- .../storage/encrypted_store/db_connection.rs | 20 ++++++++++--------- .../src/storage/encrypted_store/native.rs | 14 ++++++++----- .../encrypted_store/sqlcipher_connection.rs | 4 +--- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 59b8e541f..cf7f804a9 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -27,7 +27,7 @@ pub type DbConnection = DbConnectionPrivate { - read: Arc>, + read: Option>>, write: Arc>, // This field will funnel all reads / writes to the write connection if true. pub(super) in_transaction: Arc, @@ -36,7 +36,7 @@ pub struct DbConnectionPrivate { /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(read: Arc>, write: Arc>) -> Self { + pub(super) fn from_arc_mutex(write: Arc>, read: Option>>) -> Self { Self { read, write, @@ -74,12 +74,15 @@ where where F: FnOnce(&mut C) -> Result, { - if self.in_transaction() { - let mut lock = self.write.lock(); - return fun(&mut lock); - } + let mut lock = if self.in_transaction() { + tracing::debug!("Funneling read to write connection due to being in a transaction."); + self.write.lock() + } else if let Some(read) = &self.read { + read.lock() + } else { + self.write.lock() + }; - let mut lock = self.read.lock(); fun(&mut lock) } @@ -89,8 +92,7 @@ where where F: FnOnce(&mut C) -> Result, { - let mut lock = self.write.lock(); - fun(&mut lock) + fun(&mut self.write.lock()) } } diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index f9ccac8bb..57b833590 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -92,8 +92,8 @@ impl StorageOption { #[derive(Clone)] /// Database used in `native` (everywhere but web) pub struct NativeDb { - pub(super) write_conn: Arc>, pub(super) pool: Arc>>, + pub(super) write_conn: Arc>, customizer: Option>, opts: StorageOption, } @@ -119,7 +119,7 @@ impl NativeDb { let pool = match opts { StorageOption::Ephemeral => builder - .max_size(2) + .max_size(1) .build(ConnectionManager::new(":memory:"))?, StorageOption::Persistent(ref path) => builder .max_size(crate::configuration::MAX_DB_POOL_SIZE) @@ -131,8 +131,8 @@ impl NativeDb { write_conn.batch_execute("PRAGMA query_only = OFF;")?; Ok(Self { - write_conn: Arc::new(Mutex::new(write_conn)), pool: Arc::new(Some(pool).into()), + write_conn: Arc::new(Mutex::new(write_conn)), customizer, opts: opts.clone(), }) @@ -161,10 +161,14 @@ impl XmtpDb for NativeDb { /// Returns the Wrapped [`super::db_connection::DbConnection`] Connection implementation for this Database fn conn(&self) -> Result, StorageError> { - let conn = self.raw_conn()?; + let conn = match self.opts { + StorageOption::Ephemeral => None, + StorageOption::Persistent(_) => Some(self.raw_conn()?), + }; + Ok(DbConnectionPrivate::from_arc_mutex( - Arc::new(parking_lot::Mutex::new(conn)), self.write_conn.clone(), + conn.map(|conn| Arc::new(parking_lot::Mutex::new(conn))), )) } diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index b5cae60f3..184a0c4f9 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -204,9 +204,7 @@ impl EncryptedConnection { /// Output the corect order of PRAGMAS to instantiate a connection fn pragmas(&self) -> impl Display { - let Self { - ref key, ref salt, .. - } = self; + let Self { ref key, ref salt } = self; if let Some(s) = salt { format!( From 4fb6475f37c5879bc7fe479c3c0d1f9e0978c12a Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 14:31:51 -0500 Subject: [PATCH 30/38] wasm fix --- xmtp_mls/src/storage/encrypted_store/wasm.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/wasm.rs b/xmtp_mls/src/storage/encrypted_store/wasm.rs index 9cf63e3d0..680dd83eb 100644 --- a/xmtp_mls/src/storage/encrypted_store/wasm.rs +++ b/xmtp_mls/src/storage/encrypted_store/wasm.rs @@ -42,10 +42,7 @@ impl XmtpDb for WasmDb { type TransactionManager = AnsiTransactionManager; fn conn(&self) -> Result, StorageError> { - Ok(DbConnectionPrivate::from_arc_mutex( - self.conn.clone(), - self.conn.clone(), - )) + Ok(DbConnectionPrivate::from_arc_mutex(self.conn.clone(), None)) } fn validate(&self, _opts: &StorageOption) -> Result<(), StorageError> { From bd90ca6d1d82154e8f388b75b185da0a425ffd5a Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 15:43:48 -0500 Subject: [PATCH 31/38] global transaction lock --- .../storage/encrypted_store/db_connection.rs | 32 ++++---- xmtp_mls/src/storage/encrypted_store/mod.rs | 74 ------------------- .../src/storage/encrypted_store/native.rs | 3 + 3 files changed, 19 insertions(+), 90 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index cf7f804a9..d39aa97e7 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -29,17 +29,22 @@ pub type DbConnection = DbConnectionPrivate { read: Option>>, write: Arc>, - // This field will funnel all reads / writes to the write connection if true. - pub(super) in_transaction: Arc, + transaction_lock: Arc>, + in_transaction: Arc, } /// Owned DBConnection Methods impl DbConnectionPrivate { /// Create a new [`DbConnectionPrivate`] from an existing Arc> - pub(super) fn from_arc_mutex(write: Arc>, read: Option>>) -> Self { + pub(super) fn from_arc_mutex( + write: Arc>, + read: Option>>, + transaction_lock: Arc>, + ) -> Self { Self { read, write, + transaction_lock, in_transaction: Arc::new(AtomicBool::new(false)), } } @@ -51,31 +56,25 @@ where { pub(crate) fn start_transaction>( &self, - ) -> Result { + ) -> Result, StorageError> { + let guard = self.transaction_lock.lock(); let mut write = self.write.lock(); ::TransactionManager::begin_transaction(&mut *write)?; - - if self.in_transaction.swap(true, Ordering::SeqCst) { - return Err(StorageError::AlreadyInTransaction); - } + self.in_transaction.store(true, Ordering::SeqCst); Ok(TransactionGuard { + _mutex_guard: guard, in_transaction: self.in_transaction.clone(), }) } - fn in_transaction(&self) -> bool { - self.in_transaction.load(Ordering::SeqCst) - } - /// Do a scoped query with a mutable [`diesel::Connection`] /// reference pub(crate) fn raw_query_read(&self, fun: F) -> Result where F: FnOnce(&mut C) -> Result, { - let mut lock = if self.in_transaction() { - tracing::debug!("Funneling read to write connection due to being in a transaction."); + let mut lock = if self.in_transaction.load(Ordering::SeqCst) { self.write.lock() } else if let Some(read) = &self.read { read.lock() @@ -115,10 +114,11 @@ impl fmt::Debug for DbConnectionPrivate { } } -pub struct TransactionGuard { +pub struct TransactionGuard<'a> { in_transaction: Arc, + _mutex_guard: parking_lot::MutexGuard<'a, ()>, } -impl Drop for TransactionGuard { +impl<'a> Drop for TransactionGuard<'a> { fn drop(&mut self) { self.in_transaction.store(false, Ordering::SeqCst); } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index f138cca75..a615ed102 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -664,78 +664,4 @@ pub(crate) mod tests { } EncryptedMessageStore::remove_db_files(db_path) } - - // get two connections - // start a transaction - // try to write with second connection - // write should fail & rollback - // first thread succeeds - // wasm does not have threads - #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] - #[cfg(not(target_arch = "wasm32"))] - async fn test_transaction_rollback() { - use crate::XmtpOpenMlsProvider; - use std::sync::{Arc, Barrier}; - - let db_path = tmp_path(); - let store = EncryptedMessageStore::new( - StorageOption::Persistent(db_path.clone()), - EncryptedMessageStore::generate_enc_key(), - ) - .await - .unwrap(); - - let barrier = Arc::new(Barrier::new(2)); - let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); - let barrier_pointer = barrier.clone(); - let handle = std::thread::spawn(move || { - provider.transaction(|provider| { - let conn1 = provider.conn_ref(); - StoredIdentity::new("correct".to_string(), rand_vec::<24>(), rand_vec::<24>()) - .store(conn1) - .unwrap(); - // wait for second transaction to start - barrier_pointer.wait(); - // wait for second transaction to finish - barrier_pointer.wait(); - Ok::<_, StorageError>(()) - }) - }); - - let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); - let handle2 = std::thread::spawn(move || { - barrier.wait(); - let result = provider.transaction(|provider| -> Result<(), anyhow::Error> { - let connection = provider.conn_ref(); - let group = StoredGroup::new( - b"should not exist".to_vec(), - 0, - GroupMembershipState::Allowed, - "goodbye".to_string(), - None, - ); - group.store(connection)?; - Ok(()) - }); - barrier.wait(); - result - }); - - let result = handle.join().unwrap(); - assert!(result.is_ok()); - - let result = handle2.join().unwrap(); - - // handle 2 errored because the first transaction has precedence - assert_eq!( - result.unwrap_err().to_string(), - "Diesel result error: database is locked" - ); - let groups = store - .conn() - .unwrap() - .find_group(b"should not exist") - .unwrap(); - assert_eq!(groups, None); - } } diff --git a/xmtp_mls/src/storage/encrypted_store/native.rs b/xmtp_mls/src/storage/encrypted_store/native.rs index 57b833590..bf3aab9cc 100644 --- a/xmtp_mls/src/storage/encrypted_store/native.rs +++ b/xmtp_mls/src/storage/encrypted_store/native.rs @@ -94,6 +94,7 @@ impl StorageOption { pub struct NativeDb { pub(super) pool: Arc>>, pub(super) write_conn: Arc>, + transaction_lock: Arc>, customizer: Option>, opts: StorageOption, } @@ -133,6 +134,7 @@ impl NativeDb { Ok(Self { pool: Arc::new(Some(pool).into()), write_conn: Arc::new(Mutex::new(write_conn)), + transaction_lock: Arc::new(Mutex::new(())), customizer, opts: opts.clone(), }) @@ -169,6 +171,7 @@ impl XmtpDb for NativeDb { Ok(DbConnectionPrivate::from_arc_mutex( self.write_conn.clone(), conn.map(|conn| Arc::new(parking_lot::Mutex::new(conn))), + self.transaction_lock.clone(), )) } From 3687bbfb15e6533e81cda07780328f883053e6d6 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 15:46:04 -0500 Subject: [PATCH 32/38] comments --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index d39aa97e7..239c5aa62 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -27,9 +27,13 @@ pub type DbConnection = DbConnectionPrivate { + // Connection with read-only privileges read: Option>>, + // Connection with write privileges write: Arc>, + // Is any connection (possibly this one) currently in a transaction? transaction_lock: Arc>, + // Is this particular connection in a transaction? in_transaction: Arc, } @@ -118,7 +122,7 @@ pub struct TransactionGuard<'a> { in_transaction: Arc, _mutex_guard: parking_lot::MutexGuard<'a, ()>, } -impl<'a> Drop for TransactionGuard<'a> { +impl<'_> Drop for TransactionGuard<'_> { fn drop(&mut self) { self.in_transaction.store(false, Ordering::SeqCst); } From fe0c8ebb91e9f69e2908fa7d30aa4ed192d901c3 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 15:48:20 -0500 Subject: [PATCH 33/38] lint --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 239c5aa62..96458e5c2 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -122,7 +122,7 @@ pub struct TransactionGuard<'a> { in_transaction: Arc, _mutex_guard: parking_lot::MutexGuard<'a, ()>, } -impl<'_> Drop for TransactionGuard<'_> { +impl Drop for TransactionGuard<'_> { fn drop(&mut self) { self.in_transaction.store(false, Ordering::SeqCst); } From 39230c76dea888acf1e4d26aec7f500b817ab5af Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 15:52:28 -0500 Subject: [PATCH 34/38] fix wasm --- xmtp_mls/src/storage/encrypted_store/wasm.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xmtp_mls/src/storage/encrypted_store/wasm.rs b/xmtp_mls/src/storage/encrypted_store/wasm.rs index 680dd83eb..558dc88f6 100644 --- a/xmtp_mls/src/storage/encrypted_store/wasm.rs +++ b/xmtp_mls/src/storage/encrypted_store/wasm.rs @@ -11,6 +11,7 @@ use super::{db_connection::DbConnectionPrivate, StorageError, StorageOption, Xmt pub struct WasmDb { conn: Arc>, opts: StorageOption, + transaction_lock: Arc>, } impl std::fmt::Debug for WasmDb { @@ -33,6 +34,7 @@ impl WasmDb { Ok(Self { conn: Arc::new(Mutex::new(conn)), opts: opts.clone(), + transaction_lock: Arc::new(Mutex::new(())), }) } } @@ -42,7 +44,11 @@ impl XmtpDb for WasmDb { type TransactionManager = AnsiTransactionManager; fn conn(&self) -> Result, StorageError> { - Ok(DbConnectionPrivate::from_arc_mutex(self.conn.clone(), None)) + Ok(DbConnectionPrivate::from_arc_mutex( + self.conn.clone(), + None, + self.transaction_lock.clone(), + )) } fn validate(&self, _opts: &StorageOption) -> Result<(), StorageError> { From 0cd2375f946b3331cf0039a5c9e5cbf9fbc6ab23 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 16:06:58 -0500 Subject: [PATCH 35/38] cleanup clone --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 96458e5c2..3ca2ee7cc 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -25,7 +25,6 @@ pub type DbConnection = DbConnectionPrivate { // Connection with read-only privileges read: Option>>, From 6ae423f57a87e5bb34ecccb013c0357dd0e3ec7f Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 16:08:45 -0500 Subject: [PATCH 36/38] cleanup unused error --- xmtp_mls/src/storage/errors.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index 04f17e94b..35d433af3 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -50,8 +50,6 @@ pub enum StorageError { Duplicate(DuplicateItem), #[error(transparent)] OpenMlsStorage(#[from] SqlKeyStoreError), - #[error("Connection is already marked as in transaction")] - AlreadyInTransaction, #[error("Transaction was intentionally rolled back")] IntentionalRollback, } From 3ce7d1d8eeb789d6c0e7af630236cab351990f93 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 16:20:23 -0500 Subject: [PATCH 37/38] a little more sync work --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 3ca2ee7cc..e66a3d123 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -94,6 +94,12 @@ where where F: FnOnce(&mut C) -> Result, { + let _guard; + // If this connection is not in a transaction + if !self.in_transaction.load(Ordering::SeqCst) { + // Make sure another connection isn't + _guard = self.transaction_lock.lock(); + } fun(&mut self.write.lock()) } } From de95bab8d7d9760deb567e57fcfa544bc6c94466 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Mon, 3 Feb 2025 16:51:33 -0500 Subject: [PATCH 38/38] naming --- xmtp_mls/src/storage/encrypted_store/db_connection.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index e66a3d123..d42f656d9 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -31,7 +31,7 @@ pub struct DbConnectionPrivate { // Connection with write privileges write: Arc>, // Is any connection (possibly this one) currently in a transaction? - transaction_lock: Arc>, + global_transaction_lock: Arc>, // Is this particular connection in a transaction? in_transaction: Arc, } @@ -47,7 +47,7 @@ impl DbConnectionPrivate { Self { read, write, - transaction_lock, + global_transaction_lock: transaction_lock, in_transaction: Arc::new(AtomicBool::new(false)), } } @@ -60,7 +60,7 @@ where pub(crate) fn start_transaction>( &self, ) -> Result, StorageError> { - let guard = self.transaction_lock.lock(); + let guard = self.global_transaction_lock.lock(); let mut write = self.write.lock(); ::TransactionManager::begin_transaction(&mut *write)?; self.in_transaction.store(true, Ordering::SeqCst); @@ -98,7 +98,7 @@ where // If this connection is not in a transaction if !self.in_transaction.load(Ordering::SeqCst) { // Make sure another connection isn't - _guard = self.transaction_lock.lock(); + _guard = self.global_transaction_lock.lock(); } fun(&mut self.write.lock()) }