Skip to content

Commit

Permalink
Add find_groups method (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas authored Nov 13, 2023
1 parent b24dcff commit 3294846
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 8 deletions.
31 changes: 31 additions & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,25 @@ where
Ok(group)
}

pub fn find_groups(
&self,
allowed_states: Option<Vec<GroupMembershipState>>,
created_at_ns_gt: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
Ok(self
.store
.find_groups(
&mut self.store.conn()?,
allowed_states,
created_at_ns_gt,
limit,
)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
}

pub async fn register_identity(&self) -> Result<(), ClientError> {
// TODO: Mark key package as last_resort in creation
let last_resort_kp = self.identity.new_key_package(&self.mls_provider())?;
Expand Down Expand Up @@ -253,4 +272,16 @@ mod tests {
assert_eq!(key_package_2.wallet_address, wallet_address);
assert!(!(key_package_2.eq(key_package)));
}

#[tokio::test]
async fn test_find_groups() {
let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
let group_1 = client.create_group().unwrap();
let group_2 = client.create_group().unwrap();

let groups = client.find_groups(None, None, None).unwrap();
assert_eq!(groups.len(), 2);
assert_eq!(groups[0].group_id, group_1.group_id);
assert_eq!(groups[1].group_id, group_2.group_id);
}
}
11 changes: 8 additions & 3 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub enum GroupError {

pub struct MlsGroup<'c, ApiClient> {
pub group_id: Vec<u8>,
pub created_at_ns: i64,
client: &'c Client<ApiClient>,
}

Expand All @@ -66,8 +67,12 @@ where
ApiClient: XmtpApiClient + XmtpMlsClient,
{
// Creates a new group instance. Does not validate that the group exists in the DB
pub fn new(group_id: Vec<u8>, client: &'c Client<ApiClient>) -> Self {
Self { client, group_id }
pub fn new(client: &'c Client<ApiClient>, group_id: Vec<u8>, created_at_ns: i64) -> Self {
Self {
client,
group_id,
created_at_ns,
}
}

pub fn load_mls_group(
Expand Down Expand Up @@ -103,7 +108,7 @@ where
let stored_group = StoredGroup::new(group_id.clone(), now_ns(), membership_state);
stored_group.store(&mut conn)?;

Ok(Self::new(group_id, client))
Ok(Self::new(client, group_id, stored_group.created_at_ns))
}

pub async fn send_message(&self, message: &[u8]) -> Result<(), GroupError> {
Expand Down
70 changes: 65 additions & 5 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use diesel::{
sqlite::Sqlite,
};

use super::{schema::groups, DbConnection, EncryptedMessageStore};
use super::{
schema::{groups, groups::dsl},
DbConnection, EncryptedMessageStore,
};
use crate::{impl_fetch, impl_store, StorageError};

/// The Group ID type.
Expand Down Expand Up @@ -43,15 +46,36 @@ impl StoredGroup {
}

impl EncryptedMessageStore {
pub fn find_groups(
&self,
conn: &mut DbConnection,
allowed_states: Option<Vec<GroupMembershipState>>,
created_at_ns_gt: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<StoredGroup>, StorageError> {
let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed();

if let Some(allowed_states) = allowed_states {
query = query.filter(dsl::membership_state.eq_any(allowed_states));
}

if let Some(created_at_ns_gt) = created_at_ns_gt {
query = query.filter(dsl::created_at_ns.gt(created_at_ns_gt));
}

if let Some(limit) = limit {
query = query.limit(limit);
}

Ok(query.load(conn)?)
}
/// Updates group membership state
pub fn update_group_membership<GroupId: AsRef<[u8]>>(
&self,
conn: &mut DbConnection,
id: GroupId,
state: GroupMembershipState,
) -> Result<(), StorageError> {
use super::schema::groups::dsl;

diesel::update(dsl::groups.find(id.as_ref()))
.set(dsl::membership_state.eq(state))
.execute(conn)?;
Expand Down Expand Up @@ -99,19 +123,20 @@ where

#[cfg(test)]
pub(crate) mod tests {

use super::*;
use crate::{
assert_ok,
storage::encrypted_store::{schema::groups::dsl::groups, tests::with_store},
utils::test::{rand_time, rand_vec},
utils::{test::rand_vec, time::now_ns},
Fetch, Store,
};

/// Generate a test group
pub fn generate_group(state: Option<GroupMembershipState>) -> StoredGroup {
StoredGroup {
id: rand_vec(),
created_at_ns: rand_time(),
created_at_ns: now_ns(),
membership_state: state.unwrap_or(GroupMembershipState::Allowed),
}
}
Expand Down Expand Up @@ -161,4 +186,39 @@ pub(crate) mod tests {
);
})
}

#[test]
fn test_find_groups() {
with_store(|store, mut conn| {
let test_group_1 = generate_group(Some(GroupMembershipState::Pending));
test_group_1.store(&mut conn).unwrap();
let test_group_2 = generate_group(Some(GroupMembershipState::Allowed));
test_group_2.store(&mut conn).unwrap();

let all_results = store.find_groups(&mut conn, None, None, None).unwrap();
assert_eq!(all_results.len(), 2);

let pending_results = store
.find_groups(
&mut conn,
Some(vec![GroupMembershipState::Pending]),
None,
None,
)
.unwrap();
assert_eq!(pending_results[0].id, test_group_1.id);
assert_eq!(pending_results.len(), 1);

// Offset and limit
let results_with_limit = store.find_groups(&mut conn, None, None, Some(1)).unwrap();
assert_eq!(results_with_limit.len(), 1);
assert_eq!(results_with_limit[0].id, test_group_1.id);

let results_with_created_at_ns_after = store
.find_groups(&mut conn, None, Some(test_group_1.created_at_ns), Some(1))
.unwrap();
assert_eq!(results_with_created_at_ns_after.len(), 1);
assert_eq!(results_with_created_at_ns_after[0].id, test_group_2.id);
})
}
}

0 comments on commit 3294846

Please sign in to comment.