Skip to content

Commit

Permalink
Consent filtering by array (#1512)
Browse files Browse the repository at this point in the history
* allow syncing an array of consent states

* allow passing an array of consent for list convos

* fix up the logic a bit

* cargo clippy
  • Loading branch information
nplasterer authored Jan 16, 2025
1 parent dad64d5 commit 191b8f3
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 31 deletions.
13 changes: 8 additions & 5 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ pub struct FfiListConversationsOptions {
pub created_after_ns: Option<i64>,
pub created_before_ns: Option<i64>,
pub limit: Option<i64>,
pub consent_state: Option<FfiConsentState>,
pub consent_states: Option<Vec<FfiConsentState>>,
pub include_duplicate_dms: bool,
}

Expand All @@ -623,7 +623,9 @@ impl From<FfiListConversationsOptions> for GroupQueryArgs {
created_before_ns: opts.created_before_ns,
created_after_ns: opts.created_after_ns,
limit: opts.limit,
consent_state: opts.consent_state.map(Into::into),
consent_states: opts
.consent_states
.map(|vec| vec.into_iter().map(Into::into).collect()),
include_duplicate_dms: opts.include_duplicate_dms,
..Default::default()
}
Expand Down Expand Up @@ -1027,13 +1029,14 @@ impl FfiConversations {

pub async fn sync_all_conversations(
&self,
consent_state: Option<FfiConsentState>,
consent_states: Option<Vec<FfiConsentState>>,
) -> Result<u32, GenericError> {
let inner = self.inner_client.as_ref();
let provider = inner.mls_provider()?;
let consent: Option<ConsentState> = consent_state.map(|state| state.into());
let consents: Option<Vec<ConsentState>> =
consent_states.map(|states| states.into_iter().map(|state| state.into()).collect());
let num_groups_synced: usize = inner
.sync_all_welcomes_and_groups(&provider, consent)
.sync_all_welcomes_and_groups(&provider, consents)
.await?;
// Convert usize to u32 for compatibility with Uniffi
let num_groups_synced: u32 = num_groups_synced
Expand Down
14 changes: 10 additions & 4 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,11 +958,11 @@ where
pub async fn sync_all_welcomes_and_groups(
&self,
provider: &XmtpOpenMlsProvider,
consent_state: Option<ConsentState>,
consent_states: Option<Vec<ConsentState>>,
) -> Result<usize, ClientError> {
self.sync_welcomes(provider).await?;
let query_args = GroupQueryArgs {
consent_state,
consent_states,
include_sync_groups: true,
include_duplicate_dms: true,
..GroupQueryArgs::default()
Expand Down Expand Up @@ -1346,7 +1346,10 @@ pub(crate) mod tests {

// Sync with `Unknown`: Bob should not fetch new messages
let bob_received_groups_unknown = bo
.sync_all_welcomes_and_groups(&bo.mls_provider().unwrap(), Some(ConsentState::Allowed))
.sync_all_welcomes_and_groups(
&bo.mls_provider().unwrap(),
Some([ConsentState::Allowed].to_vec()),
)
.await
.unwrap();
assert_eq!(bob_received_groups_unknown, 0);
Expand Down Expand Up @@ -1379,7 +1382,10 @@ pub(crate) mod tests {

// Sync with `None`: Bob should fetch all messages
let bob_received_groups_all = bo
.sync_all_welcomes_and_groups(&bo.mls_provider().unwrap(), Some(ConsentState::Unknown))
.sync_all_welcomes_and_groups(
&bo.mls_provider().unwrap(),
Some([ConsentState::Unknown].to_vec()),
)
.await
.unwrap();
assert_eq!(bob_received_groups_all, 2);
Expand Down
37 changes: 29 additions & 8 deletions xmtp_mls/src/storage/encrypted_store/conversation_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl DbConnection {
created_before_ns,
limit,
conversation_type,
consent_state,
consent_states,
include_sync_groups,
include_duplicate_dms,
} = args.as_ref();
Expand Down Expand Up @@ -111,8 +111,12 @@ impl DbConnection {
query = query.filter(conversation_list_dsl::conversation_type.eq(conversation_type));
}

let mut conversations = if let Some(consent_state) = consent_state {
if *consent_state == ConsentState::Unknown {
let mut conversations = if let Some(consent_states) = consent_states {
if consent_states
.iter()
.any(|state| *state == ConsentState::Unknown)
{
// Include both `Unknown`, `null`, and other specified states
let query = query
.left_join(
consent_dsl::consent_records.on(sql::<diesel::sql_types::Text>(
Expand All @@ -123,27 +127,36 @@ impl DbConnection {
.filter(
consent_dsl::state
.is_null()
.or(consent_dsl::state.eq(ConsentState::Unknown)),
.or(consent_dsl::state.eq(ConsentState::Unknown))
.or(consent_dsl::state.eq_any(
consent_states
.iter()
.filter(|state| **state != ConsentState::Unknown)
.cloned()
.collect::<Vec<_>>(),
)),
)
.select(conversation_list::all_columns())
.order(conversation_list_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
} else {
// Only include the specified states
let query = query
.inner_join(
consent_dsl::consent_records.on(sql::<diesel::sql_types::Text>(
"lower(hex(conversation_list.id))",
)
.eq(consent_dsl::entity)),
)
.filter(consent_dsl::state.eq(*consent_state))
.filter(consent_dsl::state.eq_any(consent_states.clone()))
.select(conversation_list::all_columns())
.order(conversation_list_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
}
} else {
// Handle the case where `consent_states` is `None`
self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
};

Expand Down Expand Up @@ -338,22 +351,30 @@ pub(crate) mod tests {

let allowed_results = conn
.fetch_conversation_list(
GroupQueryArgs::default().consent_state(ConsentState::Allowed),
GroupQueryArgs::default().consent_states([ConsentState::Allowed].to_vec()),
)
.unwrap();
assert_eq!(allowed_results.len(), 2);

let allowed_unknown_results = conn
.fetch_conversation_list(
GroupQueryArgs::default()
.consent_states([ConsentState::Allowed, ConsentState::Unknown].to_vec()),
)
.unwrap();
assert_eq!(allowed_unknown_results.len(), 3);

let denied_results = conn
.fetch_conversation_list(
GroupQueryArgs::default().consent_state(ConsentState::Denied),
GroupQueryArgs::default().consent_states([ConsentState::Denied].to_vec()),
)
.unwrap();
assert_eq!(denied_results.len(), 1);
assert_eq!(denied_results[0].id, test_group_2.id);

let unknown_results = conn
.fetch_conversation_list(
GroupQueryArgs::default().consent_state(ConsentState::Unknown),
GroupQueryArgs::default().consent_states([ConsentState::Unknown].to_vec()),
)
.unwrap();
assert_eq!(unknown_results.len(), 1);
Expand Down
55 changes: 41 additions & 14 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ pub struct GroupQueryArgs {
pub created_before_ns: Option<i64>,
pub limit: Option<i64>,
pub conversation_type: Option<ConversationType>,
pub consent_state: Option<ConsentState>,
pub consent_states: Option<Vec<ConsentState>>,
pub include_sync_groups: bool,
pub include_duplicate_dms: bool,
}
Expand Down Expand Up @@ -195,11 +195,11 @@ impl GroupQueryArgs {
self
}

pub fn consent_state(self, consent_state: ConsentState) -> Self {
self.maybe_consent_state(Some(consent_state))
pub fn consent_states(self, consent_states: Vec<ConsentState>) -> Self {
self.maybe_consent_states(Some(consent_states))
}
pub fn maybe_consent_state(mut self, consent_state: Option<ConsentState>) -> Self {
self.consent_state = consent_state;
pub fn maybe_consent_states(mut self, consent_states: Option<Vec<ConsentState>>) -> Self {
self.consent_states = consent_states;
self
}

Expand All @@ -223,7 +223,7 @@ impl DbConnection {
created_before_ns,
limit,
conversation_type,
consent_state,
consent_states,
include_sync_groups,
include_duplicate_dms,
} = args.as_ref();
Expand Down Expand Up @@ -265,8 +265,12 @@ impl DbConnection {
query = query.filter(groups_dsl::conversation_type.eq(conversation_type));
}

let mut groups = if let Some(consent_state) = consent_state {
if *consent_state == ConsentState::Unknown {
let mut groups = if let Some(consent_states) = consent_states {
if consent_states
.iter()
.any(|state| *state == ConsentState::Unknown)
{
// Include both `Unknown`, `null`, and other specified states
let query = query
.left_join(
consent_dsl::consent_records
Expand All @@ -276,26 +280,35 @@ impl DbConnection {
.filter(
consent_dsl::state
.is_null()
.or(consent_dsl::state.eq(ConsentState::Unknown)),
.or(consent_dsl::state.eq(ConsentState::Unknown))
.or(consent_dsl::state.eq_any(
consent_states
.iter()
.filter(|state| **state != ConsentState::Unknown)
.cloned()
.collect::<Vec<_>>(),
)),
)
.select(groups_dsl::groups::all_columns())
.order(groups_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<StoredGroup>(conn))?
} else {
// Only include the specified states
let query = query
.inner_join(
consent_dsl::consent_records
.on(sql::<diesel::sql_types::Text>("lower(hex(groups.id))")
.eq(consent_dsl::entity)),
)
.filter(consent_dsl::state.eq(*consent_state))
.filter(consent_dsl::state.eq_any(consent_states.clone()))
.select(groups_dsl::groups::all_columns())
.order(groups_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<StoredGroup>(conn))?
}
} else {
// Handle the case where `consent_states` is `None`
self.raw_query(|conn| query.load::<StoredGroup>(conn))?
};

Expand Down Expand Up @@ -886,7 +899,7 @@ pub(crate) mod tests {
// Load the sync group with a consent filter
let allowed_groups = conn
.find_groups(&GroupQueryArgs {
consent_state: Some(ConsentState::Allowed),
consent_states: Some([ConsentState::Allowed].to_vec()),
include_sync_groups: true,
..Default::default()
})
Expand Down Expand Up @@ -934,18 +947,32 @@ pub(crate) mod tests {
assert_eq!(all_results.len(), 4);

let allowed_results = conn
.find_groups(GroupQueryArgs::default().consent_state(ConsentState::Allowed))
.find_groups(
GroupQueryArgs::default().consent_states([ConsentState::Allowed].to_vec()),
)
.unwrap();
assert_eq!(allowed_results.len(), 2);

let allowed_unknown_results = conn
.find_groups(
GroupQueryArgs::default()
.consent_states([ConsentState::Allowed, ConsentState::Unknown].to_vec()),
)
.unwrap();
assert_eq!(allowed_unknown_results.len(), 3);

let denied_results = conn
.find_groups(GroupQueryArgs::default().consent_state(ConsentState::Denied))
.find_groups(
GroupQueryArgs::default().consent_states([ConsentState::Denied].to_vec()),
)
.unwrap();
assert_eq!(denied_results.len(), 1);
assert_eq!(denied_results[0].id, test_group_2.id);

let unknown_results = conn
.find_groups(GroupQueryArgs::default().consent_state(ConsentState::Unknown))
.find_groups(
GroupQueryArgs::default().consent_states([ConsentState::Unknown].to_vec()),
)
.unwrap();
assert_eq!(unknown_results.len(), 1);
assert_eq!(unknown_results[0].id, test_group_4.id);
Expand Down

0 comments on commit 191b8f3

Please sign in to comment.