diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 39963f27b..de7833cb7 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -94,6 +94,25 @@ where Ok(group) } + pub fn find_groups( + &self, + allowed_states: Option>, + created_at_ns_gt: Option, + limit: Option, + ) -> Result>, ClientError> { + Ok(self + .store + .find_groups( + &mut self.store.conn()?, + allowed_states, + created_at_ns_gt, + limit, + )? + .into_iter() + .map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns)) + .collect()) + } + pub async fn register_identity(&self) -> Result<(), ClientError> { // TODO: Mark key package as last_resort in creation let last_resort_kp = self.identity.new_key_package(&self.mls_provider())?; @@ -253,4 +272,16 @@ mod tests { assert_eq!(key_package_2.wallet_address, wallet_address); assert!(!(key_package_2.eq(key_package))); } + + #[tokio::test] + async fn test_find_groups() { + let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + let group_1 = client.create_group().unwrap(); + let group_2 = client.create_group().unwrap(); + + let groups = client.find_groups(None, None, None).unwrap(); + assert_eq!(groups.len(), 2); + assert_eq!(groups[0].group_id, group_1.group_id); + assert_eq!(groups[1].group_id, group_2.group_id); + } } diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 45a1fe5b3..9b3827331 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -57,6 +57,7 @@ pub enum GroupError { pub struct MlsGroup<'c, ApiClient> { pub group_id: Vec, + pub created_at_ns: i64, client: &'c Client, } @@ -65,8 +66,12 @@ where ApiClient: XmtpApiClient + XmtpMlsClient, { // Creates a new group instance. Does not validate that the group exists in the DB - pub fn new(group_id: Vec, client: &'c Client) -> Self { - Self { client, group_id } + pub fn new(client: &'c Client, group_id: Vec, created_at_ns: i64) -> Self { + Self { + client, + group_id, + created_at_ns, + } } pub fn load_mls_group( @@ -102,7 +107,7 @@ where let stored_group = StoredGroup::new(group_id.clone(), now_ns(), membership_state); stored_group.store(&mut conn)?; - Ok(Self::new(group_id, client)) + Ok(Self::new(client, group_id, stored_group.created_at_ns)) } pub async fn send_message(&self, message: &[u8]) -> Result<(), GroupError> { diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index a54a475dc..e8fa37cd0 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -10,7 +10,10 @@ use diesel::{ sqlite::Sqlite, }; -use super::{schema::groups, DbConnection, EncryptedMessageStore}; +use super::{ + schema::{groups, groups::dsl}, + DbConnection, EncryptedMessageStore, +}; use crate::{impl_fetch, impl_store, StorageError}; /// The Group ID type. @@ -43,6 +46,29 @@ impl StoredGroup { } impl EncryptedMessageStore { + pub fn find_groups( + &self, + conn: &mut DbConnection, + allowed_states: Option>, + created_at_ns_gt: Option, + limit: Option, + ) -> Result, StorageError> { + let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed(); + + if let Some(allowed_states) = allowed_states { + query = query.filter(dsl::membership_state.eq_any(allowed_states)); + } + + if let Some(created_at_ns_gt) = created_at_ns_gt { + query = query.filter(dsl::created_at_ns.gt(created_at_ns_gt)); + } + + if let Some(limit) = limit { + query = query.limit(limit); + } + + Ok(query.load(conn)?) + } /// Updates group membership state pub fn update_group_membership>( &self, @@ -50,8 +76,6 @@ impl EncryptedMessageStore { id: GroupId, state: GroupMembershipState, ) -> Result<(), StorageError> { - use super::schema::groups::dsl; - diesel::update(dsl::groups.find(id.as_ref())) .set(dsl::membership_state.eq(state)) .execute(conn)?; @@ -99,11 +123,12 @@ where #[cfg(test)] pub(crate) mod tests { + use super::*; use crate::{ assert_ok, storage::encrypted_store::{schema::groups::dsl::groups, tests::with_store}, - utils::test::{rand_time, rand_vec}, + utils::{test::rand_vec, time::now_ns}, Fetch, Store, }; @@ -111,7 +136,7 @@ pub(crate) mod tests { pub fn generate_group(state: Option) -> StoredGroup { StoredGroup { id: rand_vec(), - created_at_ns: rand_time(), + created_at_ns: now_ns(), membership_state: state.unwrap_or(GroupMembershipState::Allowed), } } @@ -161,4 +186,39 @@ pub(crate) mod tests { ); }) } + + #[test] + fn test_find_groups() { + with_store(|store, mut conn| { + let test_group_1 = generate_group(Some(GroupMembershipState::Pending)); + test_group_1.store(&mut conn).unwrap(); + let test_group_2 = generate_group(Some(GroupMembershipState::Allowed)); + test_group_2.store(&mut conn).unwrap(); + + let all_results = store.find_groups(&mut conn, None, None, None).unwrap(); + assert_eq!(all_results.len(), 2); + + let pending_results = store + .find_groups( + &mut conn, + Some(vec![GroupMembershipState::Pending]), + None, + None, + ) + .unwrap(); + assert_eq!(pending_results[0].id, test_group_1.id); + assert_eq!(pending_results.len(), 1); + + // Offset and limit + let results_with_limit = store.find_groups(&mut conn, None, None, Some(1)).unwrap(); + assert_eq!(results_with_limit.len(), 1); + assert_eq!(results_with_limit[0].id, test_group_1.id); + + let results_with_created_at_ns_after = store + .find_groups(&mut conn, None, Some(test_group_1.created_at_ns), Some(1)) + .unwrap(); + assert_eq!(results_with_created_at_ns_after.len(), 1); + assert_eq!(results_with_created_at_ns_after[0].id, test_group_2.id); + }) + } }