diff --git a/xmtp_id/src/associations/mod.rs b/xmtp_id/src/associations/mod.rs index 579e86ca7..387296f45 100644 --- a/xmtp_id/src/associations/mod.rs +++ b/xmtp_id/src/associations/mod.rs @@ -14,7 +14,7 @@ pub use self::hashes::generate_inbox_id; pub use self::member::{Member, MemberIdentifier, MemberKind}; pub use self::serialization::DeserializationError; pub use self::signature::{Signature, SignatureError, SignatureKind}; -pub use self::state::AssociationState; +pub use self::state::{AssociationState, AssociationStateDiff}; // Apply a single IdentityUpdate to an existing AssociationState pub fn apply_update( diff --git a/xmtp_id/src/associations/state.rs b/xmtp_id/src/associations/state.rs index b0d824877..47b62ec14 100644 --- a/xmtp_id/src/associations/state.rs +++ b/xmtp_id/src/associations/state.rs @@ -110,6 +110,15 @@ impl AssociationState { } } + /// Converts the AssociationState to a diff that represents all members + /// of the inbox at the current state. + pub fn as_diff(&self) -> AssociationStateDiff { + AssociationStateDiff { + new_members: self.members.keys().cloned().collect(), + removed_members: vec![], + } + } + pub fn new(account_address: String, nonce: u64) -> Self { let inbox_id = generate_inbox_id(&account_address, &nonce); let identifier = MemberIdentifier::Address(account_address.clone()); diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index c34c12afa..e098a5d31 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -217,11 +217,22 @@ where permissions, self.account_address(), ) - .map_err(|e| ClientError::Generic(format!("group create error {}", e)))?; + .map_err(|e| { + ClientError::Storage(StorageError::Store(format!("group create error {}", e))) + })?; Ok(group) } + pub fn create_sync_group(&self) -> Result, ClientError> { + log::info!("creating sync group"); + let sync_group = MlsGroup::create_and_insert_sync_group(self).map_err(|e| { + ClientError::Storage(StorageError::Store(format!("group create error {}", e))) + })?; + + Ok(sync_group) + } + /// Look up a group by its ID /// Returns a [`MlsGroup`] if the group exists, or an error if it does not pub fn group(&self, group_id: Vec) -> Result, ClientError> { @@ -229,7 +240,7 @@ where let stored_group: Option = conn.fetch(&group_id)?; match stored_group { Some(group) => Ok(MlsGroup::new(self, group.id, group.created_at_ns)), - None => Err(ClientError::Generic("group not found".to_string())), + None => Err(ClientError::Storage(StorageError::NotFound)), } } diff --git a/xmtp_mls/src/groups/group_membership.rs b/xmtp_mls/src/groups/group_membership.rs index f48a83c6d..c3fd84a62 100644 --- a/xmtp_mls/src/groups/group_membership.rs +++ b/xmtp_mls/src/groups/group_membership.rs @@ -105,17 +105,17 @@ mod tests { use super::GroupMembership; #[test] - fn equality_works() { + fn test_equality_works() { let inbox_id_1 = "inbox_1".to_string(); let sequence_id_1: u64 = 1; let mut member_map_1 = GroupMembership::new(); let mut member_map_2 = GroupMembership::new(); - member_map_1.add(inbox_id_1.clone(), sequence_id_1.clone()); + member_map_1.add(inbox_id_1.clone(), sequence_id_1); assert!(member_map_1.ne(&member_map_2)); - member_map_2.add(inbox_id_1.clone(), sequence_id_1.clone()); + member_map_2.add(inbox_id_1.clone(), sequence_id_1); assert!(member_map_1.eq(&member_map_2)); // Now change the sequence ID and make sure it is not equal again @@ -124,7 +124,7 @@ mod tests { } #[test] - fn diff() { + fn test_diff() { let mut initial_members = GroupMembership::new(); initial_members.add("inbox_1".into(), 1); initial_members.add("inbox_2".into(), 1); diff --git a/xmtp_mls/src/groups/message_history.rs b/xmtp_mls/src/groups/message_history.rs index 503c2f37b..ca8544ece 100644 --- a/xmtp_mls/src/groups/message_history.rs +++ b/xmtp_mls/src/groups/message_history.rs @@ -97,7 +97,7 @@ mod tests { async fn test_send_mesage_history_request() { let wallet = generate_local_wallet(); let client = ClientBuilder::new_test_client(&wallet).await; - let group = client.create_group(None).expect("create group"); + let group = client.create_sync_group().expect("create group"); let result = group.send_message_history_request().await; assert_ok!(result); @@ -107,14 +107,28 @@ mod tests { async fn test_send_mesage_history_reply() { let wallet = generate_local_wallet(); let client = ClientBuilder::new_test_client(&wallet).await; - let group = client.create_group(None).expect("create group"); - let expiry = now_ns() + 1_000; + let group = client.create_sync_group().expect("create sync group"); let request_id = new_request_id(); let url = "https://test.com/abc-123"; let backup_hash = b"ABC123".into(); + let expiry = now_ns() + 10_000; let reply = new_message_history_reply(&request_id, url, backup_hash, expiry); let result = group.send_message_history_reply(reply).await; assert_ok!(result); } + + #[tokio::test] + async fn test_request_reply_roundtrip() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + let group = amal_a.create_group(None).expect("create group"); + let add_members_result = group + .add_members_by_installation_id(vec![amal_b.installation_public_key()]) + .await; + assert_ok!(add_members_result); + + let _ = group.send_message_history_request().await; + } } diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 3581953d3..07c698fbc 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -297,13 +297,38 @@ where Self::create_from_welcome(client, provider, welcome, account_address) } - fn into_envelope(encoded_msg: &[u8], idempotency_key: &str) -> PlaintextEnvelope { - PlaintextEnvelope { - content: Some(Content::V1(V1 { - content: encoded_msg.to_vec(), - idempotency_key: idempotency_key.into(), - })), - } + pub(crate) fn create_and_insert_sync_group( + client: &'c Client, + ) -> Result, GroupError> { + let conn = client.store.conn()?; + let provider = XmtpOpenMlsProvider::new(&conn); + let protected_metadata = build_protected_metadata_extension( + &client.identity, + PreconfiguredPolicies::default().to_policy_set(), + )?; + let mutable_metadata = build_mutable_metadata_extension(DEFAULT_GROUP_NAME.to_string())?; + let group_config = build_group_config(protected_metadata, mutable_metadata)?; + let mut mls_group = OpenMlsGroup::new( + &provider, + &client.identity.installation_keys, + &group_config, + CredentialWithKey { + credential: client.identity.credential()?, + signature_key: client.identity.installation_keys.to_public_vec().into(), + }, + )?; + mls_group.save(provider.key_store())?; + + let group_id = mls_group.group_id().to_vec(); + let stored_group = + StoredGroup::new_sync_group(group_id.clone(), now_ns(), GroupMembershipState::Allowed); + + stored_group.store(provider.conn())?; + Ok(Self::new( + client, + stored_group.id, + stored_group.created_at_ns, + )) } pub async fn send_message(&self, message: &[u8]) -> Result, GroupError> { @@ -351,6 +376,15 @@ where Ok(message_id) } + fn into_envelope(encoded_msg: &[u8], idempotency_key: &str) -> PlaintextEnvelope { + PlaintextEnvelope { + content: Some(Content::V1(V1 { + content: encoded_msg.to_vec(), + idempotency_key: idempotency_key.into(), + })), + } + } + // Query the database for stored messages. Optionally filtered by time, kind, delivery_status // and limit pub fn find_messages( diff --git a/xmtp_mls/src/identity_updates.rs b/xmtp_mls/src/identity_updates.rs index a57f33501..c31403f8e 100644 --- a/xmtp_mls/src/identity_updates.rs +++ b/xmtp_mls/src/identity_updates.rs @@ -1,5 +1,8 @@ use prost::Message; -use xmtp_id::associations::{get_state, AssociationError, AssociationState, IdentityUpdate}; +use xmtp_id::associations::{ + apply_update, get_state, AssociationError, AssociationState, AssociationStateDiff, + IdentityUpdate, +}; use xmtp_proto::api_client::{XmtpIdentityClient, XmtpMlsClient}; use crate::{ @@ -75,6 +78,33 @@ where Ok(get_state(updates)?) } + + pub fn get_association_state_diff>( + &self, + conn: &'a DbConnection<'a>, + inbox_id: String, + starting_sequence_id: Option, + ending_sequence_id: Option, + ) -> Result { + let initial_state = self.get_association_state(conn, &inbox_id, starting_sequence_id)?; + if starting_sequence_id.is_none() { + return Ok(initial_state.as_diff()); + } + + let incremental_updates = conn + .get_identity_updates(inbox_id, starting_sequence_id, ending_sequence_id)? + .into_iter() + .map(|update| update.try_into()) + .collect::, AssociationError>>()?; + + let final_state = incremental_updates + .into_iter() + .try_fold(initial_state.clone(), |state, update| { + apply_update(state, update) + })?; + + Ok(initial_state.diff(&final_state)) + } } #[cfg(test)] diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index 4fced93f9..c68e9c78b 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -64,7 +64,6 @@ impl StoredGroup { id: ID, created_at_ns: i64, membership_state: GroupMembershipState, - added_by_address: String, ) -> Self { Self { id, @@ -72,7 +71,7 @@ impl StoredGroup { membership_state, installations_last_checked: 0, purpose: Purpose::Sync, - added_by_address, + added_by_address: "".into(), } } } @@ -414,12 +413,7 @@ pub(crate) mod tests { let created_at_ns = now_ns(); let membership_state = GroupMembershipState::Allowed; - let sync_group = StoredGroup::new_sync_group( - id, - created_at_ns, - membership_state, - "placeholder_address".to_string(), - ); + let sync_group = StoredGroup::new_sync_group(id, created_at_ns, membership_state); let purpose = sync_group.purpose; assert_eq!(purpose, Purpose::Sync); }) diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index 1dd8c3632..8a4f82d45 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -170,6 +170,13 @@ impl DbConnection<'_> { })?) } + pub fn get_sync_group_messages(&self) -> Result, StorageError> { + let query = dsl::group_messages + .order(dsl::sent_at_ns.asc()) + .into_boxed(); + Ok(self.raw_query(|conn| query.load::(conn))?) + } + pub fn set_delivery_status_to_published>( &self, msg_id: &MessageId,