Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add find_groups method #324

Merged
merged 1 commit into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,25 @@ where
Ok(group)
}

pub fn find_groups(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put this directly on the Group struct, with &Client as an argument? Same with create_group(). Would prefer not to have every DB-related method replicated on the Client class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the FFI world it's going to make life easier to have all the group administration (list, create) live on the Client that we return to the native SDK. That way they don't have to deal with a bunch of different foreign object types and passing around references.

We could still do that mapping in the bindings, but it's messier. How about we leave it for now and we can always move things around later if we decide that it isn't as ergonomic?

&self,
allowed_states: Option<Vec<GroupMembershipState>>,
created_at_ns_gt: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
Ok(self
.store
.find_groups(
&mut self.store.conn()?,
allowed_states,
created_at_ns_gt,
limit,
)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
}

pub async fn register_identity(&self) -> Result<(), ClientError> {
// TODO: Mark key package as last_resort in creation
let last_resort_kp = self.identity.new_key_package(&self.mls_provider())?;
Expand Down Expand Up @@ -253,4 +272,16 @@ mod tests {
assert_eq!(key_package_2.wallet_address, wallet_address);
assert!(!(key_package_2.eq(key_package)));
}

#[tokio::test]
async fn test_find_groups() {
let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
let group_1 = client.create_group().unwrap();
let group_2 = client.create_group().unwrap();

let groups = client.find_groups(None, None, None).unwrap();
assert_eq!(groups.len(), 2);
assert_eq!(groups[0].group_id, group_1.group_id);
assert_eq!(groups[1].group_id, group_2.group_id);
}
}
11 changes: 8 additions & 3 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub enum GroupError {

pub struct MlsGroup<'c, ApiClient> {
pub group_id: Vec<u8>,
pub created_at_ns: i64,
neekolas marked this conversation as resolved.
Show resolved Hide resolved
client: &'c Client<ApiClient>,
}

Expand All @@ -65,8 +66,12 @@ where
ApiClient: XmtpApiClient + XmtpMlsClient,
{
// Creates a new group instance. Does not validate that the group exists in the DB
pub fn new(group_id: Vec<u8>, client: &'c Client<ApiClient>) -> Self {
Self { client, group_id }
pub fn new(client: &'c Client<ApiClient>, group_id: Vec<u8>, created_at_ns: i64) -> Self {
Self {
client,
group_id,
created_at_ns,
}
}

pub fn load_mls_group(
Expand Down Expand Up @@ -102,7 +107,7 @@ where
let stored_group = StoredGroup::new(group_id.clone(), now_ns(), membership_state);
stored_group.store(&mut conn)?;

Ok(Self::new(group_id, client))
Ok(Self::new(client, group_id, stored_group.created_at_ns))
}

pub async fn send_message(&self, message: &[u8]) -> Result<(), GroupError> {
Expand Down
70 changes: 65 additions & 5 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use diesel::{
sqlite::Sqlite,
};

use super::{schema::groups, DbConnection, EncryptedMessageStore};
use super::{
schema::{groups, groups::dsl},
DbConnection, EncryptedMessageStore,
};
use crate::{impl_fetch, impl_store, StorageError};

/// The Group ID type.
Expand Down Expand Up @@ -43,15 +46,36 @@ impl StoredGroup {
}

impl EncryptedMessageStore {
pub fn find_groups(
&self,
conn: &mut DbConnection,
allowed_states: Option<Vec<GroupMembershipState>>,
created_at_ns_gt: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<StoredGroup>, StorageError> {
let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed();

if let Some(allowed_states) = allowed_states {
query = query.filter(dsl::membership_state.eq_any(allowed_states));
}

if let Some(created_at_ns_gt) = created_at_ns_gt {
query = query.filter(dsl::created_at_ns.gt(created_at_ns_gt));
}

if let Some(limit) = limit {
query = query.limit(limit);
}

Ok(query.load(conn)?)
}
/// Updates group membership state
pub fn update_group_membership<GroupId: AsRef<[u8]>>(
&self,
conn: &mut DbConnection,
id: GroupId,
state: GroupMembershipState,
) -> Result<(), StorageError> {
use super::schema::groups::dsl;

diesel::update(dsl::groups.find(id.as_ref()))
.set(dsl::membership_state.eq(state))
.execute(conn)?;
Expand Down Expand Up @@ -99,19 +123,20 @@ where

#[cfg(test)]
pub(crate) mod tests {

use super::*;
use crate::{
assert_ok,
storage::encrypted_store::{schema::groups::dsl::groups, tests::with_store},
utils::test::{rand_time, rand_vec},
utils::{test::rand_vec, time::now_ns},
Fetch, Store,
};

/// Generate a test group
pub fn generate_group(state: Option<GroupMembershipState>) -> StoredGroup {
StoredGroup {
id: rand_vec(),
created_at_ns: rand_time(),
created_at_ns: now_ns(),
membership_state: state.unwrap_or(GroupMembershipState::Allowed),
}
}
Expand Down Expand Up @@ -161,4 +186,39 @@ pub(crate) mod tests {
);
})
}

#[test]
fn test_find_groups() {
with_store(|store, mut conn| {
let test_group_1 = generate_group(Some(GroupMembershipState::Pending));
test_group_1.store(&mut conn).unwrap();
let test_group_2 = generate_group(Some(GroupMembershipState::Allowed));
test_group_2.store(&mut conn).unwrap();

let all_results = store.find_groups(&mut conn, None, None, None).unwrap();
assert_eq!(all_results.len(), 2);

let pending_results = store
.find_groups(
&mut conn,
Some(vec![GroupMembershipState::Pending]),
None,
None,
)
.unwrap();
assert_eq!(pending_results[0].id, test_group_1.id);
assert_eq!(pending_results.len(), 1);

// Offset and limit
let results_with_limit = store.find_groups(&mut conn, None, None, Some(1)).unwrap();
assert_eq!(results_with_limit.len(), 1);
assert_eq!(results_with_limit[0].id, test_group_1.id);

let results_with_created_at_ns_after = store
.find_groups(&mut conn, None, Some(test_group_1.created_at_ns), Some(1))
.unwrap();
assert_eq!(results_with_created_at_ns_after.len(), 1);
assert_eq!(results_with_created_at_ns_after[0].id, test_group_2.id);
})
}
}
Loading