diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 22906211f..c959f7471 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -301,28 +301,12 @@ mod tests { let bob = ClientBuilder::new_test_client(generate_local_wallet().into()).await; bob.register_identity().await.unwrap(); - let conn = &mut alice.store.conn().unwrap(); let alice_bob_group = alice.create_group().unwrap(); alice_bob_group .add_members_by_installation_id(vec![bob.installation_public_key()]) .await .unwrap(); - // Manually mark as committed - // TODO: Replace with working synchronization once we can add members end to end - let intents = alice - .store - .find_group_intents(conn, alice_bob_group.group_id.clone(), None, None) - .unwrap(); - let intent = intents.first().unwrap(); - // Set the intent to committed manually - alice - .store - .set_group_intent_committed(conn, intent.id) - .unwrap(); - - alice_bob_group.post_commit(conn).await.unwrap(); - let bob_received_groups = bob.sync_welcomes().await.unwrap(); assert_eq!(bob_received_groups.len(), 1); assert_eq!( diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 7d4bea6b9..c987dc82e 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -7,6 +7,8 @@ use intents::SendMessageIntentData; #[cfg(not(test))] use log::debug; use openmls::{ + framing::ProtocolMessage, + group::{GroupEpoch, MergePendingCommitError}, prelude::{ CredentialWithKey, CryptoConfig, GroupId, LeafNodeIndex, MlsGroup as OpenMlsGroup, MlsGroupConfig, MlsMessageIn, MlsMessageInBody, PrivateMessageIn, ProcessedMessage, @@ -78,6 +80,19 @@ pub enum MessageProcessingError { }, #[error("openmls process message error: {0}")] OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError), + #[error("merge pending commit: {0}")] + MergePendingCommit(#[from] openmls::group::MergePendingCommitError), + #[error("merge staged commit: {0}")] + MergeStagedCommit(#[from] openmls::group::MergeCommitError), + #[error( + "no pending commit to merge. group epoch is {group_epoch:?} and got {message_epoch:?}" + )] + NoPendingCommit { + message_epoch: GroupEpoch, + group_epoch: GroupEpoch, + }, + #[error("intent error: {0}")] + Intent(#[from] IntentError), #[error("storage error: {0}")] Storage(#[from] crate::storage::StorageError), #[error("tls deserialization: {0}")] @@ -211,13 +226,94 @@ where )) } + fn process_own_message( + &self, + conn: &mut DbConnection, + intent: StoredGroupIntent, + openmls_group: &mut OpenMlsGroup, + provider: &XmtpOpenMlsProvider, + message: ProtocolMessage, + envelope_timestamp_ns: u64, + ) -> Result<(), MessageProcessingError> { + if intent.state == IntentState::Committed { + return Ok(()); + } + debug!( + "[{}]processing own message for intent {} / {:?}", + self.client.account_address(), + intent.id, + intent.kind + ); + + match intent.kind { + IntentKind::AddMembers | IntentKind::RemoveMembers | IntentKind::KeyUpdate => { + // We don't get errors with merge_pending_commit when there are no commits to merge + if openmls_group.pending_commit().is_none() { + let message_epoch = message.epoch(); + let group_epoch = openmls_group.epoch(); + debug!( + "no pending commit to merge. Group epoch: {}. Message epoch: {}", + group_epoch, message_epoch + ); + self.client + .store + .set_group_intent_to_publish(conn, intent.id)?; + + return Err(MessageProcessingError::NoPendingCommit { + message_epoch, + group_epoch, + }); + } + debug!("[{}] merging pending commit", self.client.account_address()); + match openmls_group.merge_pending_commit(provider) { + Err(MergePendingCommitError::MlsGroupStateError(err)) => { + debug!("error merging commit: {}", err); + openmls_group.clear_pending_commit(); + self.client + .store + .set_group_intent_to_publish(conn, intent.id)?; + } + _ => (), + }; + // TOOD: Handle writing transcript messages for adding/removing members + } + IntentKind::SendMessage => { + let intent_data = SendMessageIntentData::from_bytes(intent.data.as_slice())?; + let group_id = openmls_group.group_id().as_slice(); + let decrypted_message_data = intent_data.message.as_slice(); + + StoredGroupMessage { + id: get_message_id(decrypted_message_data, group_id, envelope_timestamp_ns), + group_id: group_id.to_vec(), + decrypted_message_bytes: intent_data.message, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id: self.client.installation_public_key(), + sender_wallet_address: self.client.account_address(), + } + .store(conn)?; + } + }; + + self.client + .store + .set_group_intent_committed(conn, intent.id)?; + + Ok(()) + } + fn process_private_message( &self, + conn: &mut DbConnection, openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: PrivateMessageIn, envelope_timestamp_ns: u64, ) -> Result<(), MessageProcessingError> { + debug!( + "[{}] processing private message", + self.client.account_address() + ); let decrypted_message = openmls_group.process_message(provider, message)?; let (sender_account_address, sender_installation_id) = self.validate_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?; @@ -236,7 +332,7 @@ where sender_installation_id, sender_wallet_address: sender_account_address, }; - message.store(&mut self.client.store.conn()?)?; + message.store(conn)?; } ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { // intentionally left blank. @@ -244,32 +340,67 @@ where ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { // intentionally left blank. } - ProcessedMessageContent::StagedCommitMessage(_commit_ptr) => { - // intentionally left blank. + ProcessedMessageContent::StagedCommitMessage(staged_commit) => { + debug!( + "[{}] received staged commit. Merging and clearing any pending commits", + self.client.account_address() + ); + openmls_group.merge_staged_commit(provider, *staged_commit)?; } } Ok(()) } + fn process_message( + &self, + conn: &mut DbConnection, + openmls_group: &mut OpenMlsGroup, + provider: &XmtpOpenMlsProvider, + envelope: &Envelope, + ) -> Result<(), MessageProcessingError> { + let mls_message_in = MlsMessageIn::tls_deserialize_exact(&envelope.message)?; + + let message = match mls_message_in.extract() { + MlsMessageInBody::PrivateMessage(message) => Ok(message), + other => Err(MessageProcessingError::UnsupportedMessageType( + discriminant(&other), + )), + }?; + + match self + .client + .store + .find_group_intent_by_payload_hash(conn, sha256(envelope.message.as_slice())) + { + // Intent with the payload hash matches + Ok(Some(intent)) => self.process_own_message( + conn, + intent, + openmls_group, + provider, + message.into(), + envelope.timestamp_ns, + ), + Err(err) => Err(MessageProcessingError::Storage(err)), + // No matching intent found + Ok(None) => self.process_private_message( + conn, + openmls_group, + provider, + message, + envelope.timestamp_ns, + ), + } + } + pub fn process_messages(&self, envelopes: Vec) -> Result<(), GroupError> { let provider = self.client.mls_provider(); let mut openmls_group = self.load_mls_group(&provider)?; + let conn = &mut self.client.store.conn()?; let receive_errors: Vec = envelopes .into_iter() .map(|envelope| -> Result<(), MessageProcessingError> { - let mls_message_in = MlsMessageIn::tls_deserialize_exact(&envelope.message)?; - - match mls_message_in.extract() { - MlsMessageInBody::PrivateMessage(message) => self.process_private_message( - &mut openmls_group, - &provider, - message, - envelope.timestamp_ns, - ), - other => Err(MessageProcessingError::UnsupportedMessageType( - discriminant(&other), - )), - } + self.process_message(conn, &mut openmls_group, &provider, &envelope) }) .filter(|result| result.is_err()) .map(|result| result.unwrap_err()) @@ -297,13 +428,14 @@ where } pub async fn send_message(&self, message: &[u8]) -> Result<(), GroupError> { - let mut conn = self.client.store.conn()?; + let conn = &mut self.client.store.conn()?; let intent_data: Vec = SendMessageIntentData::new(message.to_vec()).into(); let intent = NewGroupIntent::new(IntentKind::SendMessage, self.group_id.clone(), intent_data); - intent.store(&mut conn)?; + intent.store(conn)?; - self.publish_intents(&mut conn).await?; + // Skipping a full sync here and instead just firing and forgetting + self.publish_intents(conn).await?; Ok(()) } @@ -311,7 +443,7 @@ where &self, installation_ids: Vec>, ) -> Result<(), GroupError> { - let mut conn = self.client.store.conn()?; + let conn = &mut self.client.store.conn()?; let key_packages = self .client .get_key_packages_for_installation_ids(installation_ids) @@ -319,39 +451,41 @@ where let intent_data: Vec = AddMembersIntentData::new(key_packages).try_into()?; let intent = NewGroupIntent::new(IntentKind::AddMembers, self.group_id.clone(), intent_data); - intent.store(&mut conn)?; + intent.store(conn)?; - self.publish_intents(&mut conn).await?; - // ... sync state with the network - self.post_commit(&mut conn).await?; - - Ok(()) + self.sync(conn).await } pub async fn remove_members_by_installation_id( &self, installation_ids: Vec>, ) -> Result<(), GroupError> { - let mut conn = self.client.store.conn()?; + let conn = &mut self.client.store.conn()?; let intent_data: Vec = RemoveMembersIntentData::new(installation_ids).into(); let intent = NewGroupIntent::new( IntentKind::RemoveMembers, self.group_id.clone(), intent_data, ); - intent.store(&mut conn)?; - - self.publish_intents(&mut conn).await?; + intent.store(conn)?; - Ok(()) + self.sync(conn).await } pub async fn key_update(&self) -> Result<(), GroupError> { - let mut conn = self.client.store.conn()?; + let conn = &mut self.client.store.conn()?; let intent = NewGroupIntent::new(IntentKind::KeyUpdate, self.group_id.clone(), vec![]); - intent.store(&mut conn)?; + intent.store(conn)?; - self.publish_intents(&mut conn).await?; + self.sync(conn).await + } + + pub async fn sync(&self, conn: &mut DbConnection) -> Result<(), GroupError> { + self.publish_intents(conn).await?; + if let Err(e) = self.receive().await { + log::warn!("receive error {:?}", e); + } + self.post_commit(conn).await?; Ok(()) } @@ -538,10 +672,12 @@ fn build_group_config() -> MlsGroupConfig { #[cfg(test)] mod tests { - use openmls_traits::OpenMlsProvider; + use openmls::prelude::Member; use xmtp_cryptography::utils::generate_local_wallet; - use crate::{builder::ClientBuilder, groups::GroupError, utils::topic::get_welcome_topic}; + use crate::{ + builder::ClientBuilder, storage::group_intent::IntentState, utils::topic::get_welcome_topic, + }; #[tokio::test] async fn test_send_message() { @@ -566,14 +702,86 @@ mod tests { let wallet = generate_local_wallet(); let client = ClientBuilder::new_test_client(wallet.into()).await; let group = client.create_group().expect("create group"); - group.send_message(b"hello").await.expect("send message"); + let msg = b"hello"; + group.send_message(msg).await.expect("send message"); - let result = group.receive().await; - if let GroupError::ReceiveError(errors) = result.err().unwrap() { - assert_eq!(errors.len(), 1); - } else { - panic!("expected GroupError::ReceiveError") - } + group.receive().await.unwrap(); + // Check for messages + let messages = group.find_messages(None, None, None, None).unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages.first().unwrap().decrypted_message_bytes, msg); + } + + // Amal and Bola will both try and add Charlie from the same epoch. + // The group should resolve to a consistent state + #[tokio::test] + async fn test_add_member_conflict() { + let amal = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + let bola = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + let charlie = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + futures::future::join_all(vec![ + amal.register_identity(), + bola.register_identity(), + charlie.register_identity(), + ]) + .await; + + let amal_group = amal.create_group().unwrap(); + // Add bola + amal_group + .add_members_by_installation_id(vec![bola.installation_public_key()]) + .await + .unwrap(); + + // Get bola's version of the same group + let bola_groups = bola.sync_welcomes().await.unwrap(); + let bola_group = bola_groups.first().unwrap(); + + // Have amal and bola both invite charlie. + amal_group + .add_members_by_installation_id(vec![charlie.installation_public_key()]) + .await + .expect("failed to add charlie"); + bola_group + .add_members_by_installation_id(vec![charlie.installation_public_key()]) + .await + .unwrap(); + + amal_group.receive().await.expect_err("expected error"); + bola_group.receive().await.expect_err("expected error"); + + // Check Amal's MLS group state. + let amal_mls_group = amal_group.load_mls_group(&amal.mls_provider()).unwrap(); + let amal_members: Vec = amal_mls_group.members().collect(); + assert_eq!(amal_members.len(), 3); + + // Check Bola's MLS group state. + let bola_mls_group = bola_group.load_mls_group(&bola.mls_provider()).unwrap(); + let bola_members: Vec = bola_mls_group.members().collect(); + assert_eq!(bola_members.len(), 3); + + let amal_uncommitted_intents = amal + .store + .find_group_intents( + &mut amal.store.conn().unwrap(), + amal_group.group_id.clone(), + Some(vec![IntentState::ToPublish, IntentState::Published]), + None, + ) + .unwrap(); + assert_eq!(amal_uncommitted_intents.len(), 0); + + let bola_uncommitted_intents = bola + .store + .find_group_intents( + &mut bola.store.conn().unwrap(), + bola_group.group_id.clone(), + Some(vec![IntentState::ToPublish, IntentState::Published]), + None, + ) + .unwrap(); + // Bola should have one uncommitted intent for the failed attempt at adding Charlie, who is already in the group + assert_eq!(bola_uncommitted_intents.len(), 1); } #[tokio::test] @@ -621,7 +829,6 @@ mod tests { let client_2 = ClientBuilder::new_test_client(generate_local_wallet().into()).await; client_2.register_identity().await.unwrap(); - let provider = client_1.mls_provider(); let group = client_1.create_group().expect("create group"); group .add_members_by_installation_id(vec![client_2 @@ -649,24 +856,7 @@ mod tests { .await .expect("read topic"); - assert_eq!(messages.len(), 1); - // Now merge the commit and try again - let mut mls_group = group.load_mls_group(&provider).unwrap(); - mls_group.merge_pending_commit(&provider).unwrap(); - mls_group.save(provider.key_store()).unwrap(); - - group - .publish_intents(&mut client_1.store.conn().unwrap()) - .await - .unwrap(); - - let messages_after_second_try = client_1 - .api_client - .read_topic(topic.as_str(), 0) - .await - .expect("read topic"); - - assert_eq!(messages_after_second_try.len(), 2) + assert_eq!(messages.len(), 2); } #[tokio::test] @@ -685,7 +875,7 @@ mod tests { let mls_group = group.load_mls_group(&client.mls_provider()).unwrap(); let pending_commit = mls_group.pending_commit(); - assert!(pending_commit.is_some()); + assert!(pending_commit.is_none()); } #[tokio::test] @@ -694,7 +884,6 @@ mod tests { let client_2 = ClientBuilder::new_test_client(generate_local_wallet().into()).await; client_2.register_identity().await.unwrap(); let group = client.create_group().expect("create group"); - let conn = &mut client.store.conn().unwrap(); group .add_members_by_installation_id(vec![client_2 @@ -704,19 +893,6 @@ mod tests { .await .unwrap(); - let intents = client - .store - .find_group_intents(conn, group.group_id.clone(), None, None) - .unwrap(); - let intent = intents.first().unwrap(); - // Set the intent to committed manually - // TODO: Replace with working synchronization once we can add members end to end - client - .store - .set_group_intent_committed(conn, intent.id) - .unwrap(); - group.post_commit(conn).await.unwrap(); - // Check if the welcome was actually sent let welcome_topic = get_welcome_topic(&client_2.identity.installation_keys.to_public_vec()); let welcome_messages = client