From 191b8f374c41e9f8d507096488e2f7218fe7c772 Mon Sep 17 00:00:00 2001 From: Naomi Plasterer Date: Thu, 16 Jan 2025 12:35:10 -0800 Subject: [PATCH] Consent filtering by array (#1512) * allow syncing an array of consent states * allow passing an array of consent for list convos * fix up the logic a bit * cargo clippy --- bindings_ffi/src/mls.rs | 13 +++-- xmtp_mls/src/client.rs | 14 +++-- .../encrypted_store/conversation_list.rs | 37 ++++++++++--- xmtp_mls/src/storage/encrypted_store/group.rs | 55 ++++++++++++++----- 4 files changed, 88 insertions(+), 31 deletions(-) diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 983b81963..405bcf454 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -613,7 +613,7 @@ pub struct FfiListConversationsOptions { pub created_after_ns: Option, pub created_before_ns: Option, pub limit: Option, - pub consent_state: Option, + pub consent_states: Option>, pub include_duplicate_dms: bool, } @@ -623,7 +623,9 @@ impl From for GroupQueryArgs { created_before_ns: opts.created_before_ns, created_after_ns: opts.created_after_ns, limit: opts.limit, - consent_state: opts.consent_state.map(Into::into), + consent_states: opts + .consent_states + .map(|vec| vec.into_iter().map(Into::into).collect()), include_duplicate_dms: opts.include_duplicate_dms, ..Default::default() } @@ -1027,13 +1029,14 @@ impl FfiConversations { pub async fn sync_all_conversations( &self, - consent_state: Option, + consent_states: Option>, ) -> Result { let inner = self.inner_client.as_ref(); let provider = inner.mls_provider()?; - let consent: Option = consent_state.map(|state| state.into()); + let consents: Option> = + consent_states.map(|states| states.into_iter().map(|state| state.into()).collect()); let num_groups_synced: usize = inner - .sync_all_welcomes_and_groups(&provider, consent) + .sync_all_welcomes_and_groups(&provider, consents) .await?; // Convert usize to u32 for compatibility with Uniffi let num_groups_synced: u32 = num_groups_synced diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 09a22d239..e2b9dae25 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -958,11 +958,11 @@ where pub async fn sync_all_welcomes_and_groups( &self, provider: &XmtpOpenMlsProvider, - consent_state: Option, + consent_states: Option>, ) -> Result { self.sync_welcomes(provider).await?; let query_args = GroupQueryArgs { - consent_state, + consent_states, include_sync_groups: true, include_duplicate_dms: true, ..GroupQueryArgs::default() @@ -1346,7 +1346,10 @@ pub(crate) mod tests { // Sync with `Unknown`: Bob should not fetch new messages let bob_received_groups_unknown = bo - .sync_all_welcomes_and_groups(&bo.mls_provider().unwrap(), Some(ConsentState::Allowed)) + .sync_all_welcomes_and_groups( + &bo.mls_provider().unwrap(), + Some([ConsentState::Allowed].to_vec()), + ) .await .unwrap(); assert_eq!(bob_received_groups_unknown, 0); @@ -1379,7 +1382,10 @@ pub(crate) mod tests { // Sync with `None`: Bob should fetch all messages let bob_received_groups_all = bo - .sync_all_welcomes_and_groups(&bo.mls_provider().unwrap(), Some(ConsentState::Unknown)) + .sync_all_welcomes_and_groups( + &bo.mls_provider().unwrap(), + Some([ConsentState::Unknown].to_vec()), + ) .await .unwrap(); assert_eq!(bob_received_groups_all, 2); diff --git a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs index f71dcf5ad..de1b7bd32 100644 --- a/xmtp_mls/src/storage/encrypted_store/conversation_list.rs +++ b/xmtp_mls/src/storage/encrypted_store/conversation_list.rs @@ -70,7 +70,7 @@ impl DbConnection { created_before_ns, limit, conversation_type, - consent_state, + consent_states, include_sync_groups, include_duplicate_dms, } = args.as_ref(); @@ -111,8 +111,12 @@ impl DbConnection { query = query.filter(conversation_list_dsl::conversation_type.eq(conversation_type)); } - let mut conversations = if let Some(consent_state) = consent_state { - if *consent_state == ConsentState::Unknown { + let mut conversations = if let Some(consent_states) = consent_states { + if consent_states + .iter() + .any(|state| *state == ConsentState::Unknown) + { + // Include both `Unknown`, `null`, and other specified states let query = query .left_join( consent_dsl::consent_records.on(sql::( @@ -123,13 +127,21 @@ impl DbConnection { .filter( consent_dsl::state .is_null() - .or(consent_dsl::state.eq(ConsentState::Unknown)), + .or(consent_dsl::state.eq(ConsentState::Unknown)) + .or(consent_dsl::state.eq_any( + consent_states + .iter() + .filter(|state| **state != ConsentState::Unknown) + .cloned() + .collect::>(), + )), ) .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); self.raw_query(|conn| query.load::(conn))? } else { + // Only include the specified states let query = query .inner_join( consent_dsl::consent_records.on(sql::( @@ -137,13 +149,14 @@ impl DbConnection { ) .eq(consent_dsl::entity)), ) - .filter(consent_dsl::state.eq(*consent_state)) + .filter(consent_dsl::state.eq_any(consent_states.clone())) .select(conversation_list::all_columns()) .order(conversation_list_dsl::created_at_ns.asc()); self.raw_query(|conn| query.load::(conn))? } } else { + // Handle the case where `consent_states` is `None` self.raw_query(|conn| query.load::(conn))? }; @@ -338,14 +351,22 @@ pub(crate) mod tests { let allowed_results = conn .fetch_conversation_list( - GroupQueryArgs::default().consent_state(ConsentState::Allowed), + GroupQueryArgs::default().consent_states([ConsentState::Allowed].to_vec()), ) .unwrap(); assert_eq!(allowed_results.len(), 2); + let allowed_unknown_results = conn + .fetch_conversation_list( + GroupQueryArgs::default() + .consent_states([ConsentState::Allowed, ConsentState::Unknown].to_vec()), + ) + .unwrap(); + assert_eq!(allowed_unknown_results.len(), 3); + let denied_results = conn .fetch_conversation_list( - GroupQueryArgs::default().consent_state(ConsentState::Denied), + GroupQueryArgs::default().consent_states([ConsentState::Denied].to_vec()), ) .unwrap(); assert_eq!(denied_results.len(), 1); @@ -353,7 +374,7 @@ pub(crate) mod tests { let unknown_results = conn .fetch_conversation_list( - GroupQueryArgs::default().consent_state(ConsentState::Unknown), + GroupQueryArgs::default().consent_states([ConsentState::Unknown].to_vec()), ) .unwrap(); assert_eq!(unknown_results.len(), 1); diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index e597ad062..89cde7088 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -135,7 +135,7 @@ pub struct GroupQueryArgs { pub created_before_ns: Option, pub limit: Option, pub conversation_type: Option, - pub consent_state: Option, + pub consent_states: Option>, pub include_sync_groups: bool, pub include_duplicate_dms: bool, } @@ -195,11 +195,11 @@ impl GroupQueryArgs { self } - pub fn consent_state(self, consent_state: ConsentState) -> Self { - self.maybe_consent_state(Some(consent_state)) + pub fn consent_states(self, consent_states: Vec) -> Self { + self.maybe_consent_states(Some(consent_states)) } - pub fn maybe_consent_state(mut self, consent_state: Option) -> Self { - self.consent_state = consent_state; + pub fn maybe_consent_states(mut self, consent_states: Option>) -> Self { + self.consent_states = consent_states; self } @@ -223,7 +223,7 @@ impl DbConnection { created_before_ns, limit, conversation_type, - consent_state, + consent_states, include_sync_groups, include_duplicate_dms, } = args.as_ref(); @@ -265,8 +265,12 @@ impl DbConnection { query = query.filter(groups_dsl::conversation_type.eq(conversation_type)); } - let mut groups = if let Some(consent_state) = consent_state { - if *consent_state == ConsentState::Unknown { + let mut groups = if let Some(consent_states) = consent_states { + if consent_states + .iter() + .any(|state| *state == ConsentState::Unknown) + { + // Include both `Unknown`, `null`, and other specified states let query = query .left_join( consent_dsl::consent_records @@ -276,26 +280,35 @@ impl DbConnection { .filter( consent_dsl::state .is_null() - .or(consent_dsl::state.eq(ConsentState::Unknown)), + .or(consent_dsl::state.eq(ConsentState::Unknown)) + .or(consent_dsl::state.eq_any( + consent_states + .iter() + .filter(|state| **state != ConsentState::Unknown) + .cloned() + .collect::>(), + )), ) .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); self.raw_query(|conn| query.load::(conn))? } else { + // Only include the specified states let query = query .inner_join( consent_dsl::consent_records .on(sql::("lower(hex(groups.id))") .eq(consent_dsl::entity)), ) - .filter(consent_dsl::state.eq(*consent_state)) + .filter(consent_dsl::state.eq_any(consent_states.clone())) .select(groups_dsl::groups::all_columns()) .order(groups_dsl::created_at_ns.asc()); self.raw_query(|conn| query.load::(conn))? } } else { + // Handle the case where `consent_states` is `None` self.raw_query(|conn| query.load::(conn))? }; @@ -886,7 +899,7 @@ pub(crate) mod tests { // Load the sync group with a consent filter let allowed_groups = conn .find_groups(&GroupQueryArgs { - consent_state: Some(ConsentState::Allowed), + consent_states: Some([ConsentState::Allowed].to_vec()), include_sync_groups: true, ..Default::default() }) @@ -934,18 +947,32 @@ pub(crate) mod tests { assert_eq!(all_results.len(), 4); let allowed_results = conn - .find_groups(GroupQueryArgs::default().consent_state(ConsentState::Allowed)) + .find_groups( + GroupQueryArgs::default().consent_states([ConsentState::Allowed].to_vec()), + ) .unwrap(); assert_eq!(allowed_results.len(), 2); + let allowed_unknown_results = conn + .find_groups( + GroupQueryArgs::default() + .consent_states([ConsentState::Allowed, ConsentState::Unknown].to_vec()), + ) + .unwrap(); + assert_eq!(allowed_unknown_results.len(), 3); + let denied_results = conn - .find_groups(GroupQueryArgs::default().consent_state(ConsentState::Denied)) + .find_groups( + GroupQueryArgs::default().consent_states([ConsentState::Denied].to_vec()), + ) .unwrap(); assert_eq!(denied_results.len(), 1); assert_eq!(denied_results[0].id, test_group_2.id); let unknown_results = conn - .find_groups(GroupQueryArgs::default().consent_state(ConsentState::Unknown)) + .find_groups( + GroupQueryArgs::default().consent_states([ConsentState::Unknown].to_vec()), + ) .unwrap(); assert_eq!(unknown_results.len(), 1); assert_eq!(unknown_results[0].id, test_group_4.id);