diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index e8b1068e9..339c9c23e 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -1317,13 +1317,13 @@ impl FfiConversation { Ok(()) } - pub fn find_messages( + pub async fn find_messages( &self, opts: FfiListMessagesOptions, ) -> Result, GenericError> { let delivery_status = opts.delivery_status.map(|status| status.into()); let direction = opts.direction.map(|dir| dir.into()); - let kind = match self.conversation_type()? { + let kind = match self.conversation_type().await? { FfiConversationType::Group => None, FfiConversationType::Dm => Some(GroupMessageKind::Application), FfiConversationType::Sync => None, @@ -1445,7 +1445,7 @@ impl FfiConversation { pub fn group_image_url_square(&self) -> Result { let provider = self.inner.mls_provider()?; - Ok(self.inner.group_image_url_square(provider)?) + Ok(self.inner.group_image_url_square(&provider)?) } pub async fn update_group_description( @@ -1461,7 +1461,7 @@ impl FfiConversation { pub fn group_description(&self) -> Result { let provider = self.inner.mls_provider()?; - Ok(self.inner.group_description(provider)?) + Ok(self.inner.group_description(&provider)?) } pub async fn update_group_pinned_frame_url( @@ -1593,9 +1593,9 @@ impl FfiConversation { self.inner.added_by_inbox_id().map_err(Into::into) } - pub fn group_metadata(&self) -> Result, GenericError> { + pub async fn group_metadata(&self) -> Result, GenericError> { let provider = self.inner.mls_provider()?; - let metadata = self.inner.metadata(provider)?; + let metadata = self.inner.metadata(&provider).await?; Ok(Arc::new(FfiConversationMetadata { inner: Arc::new(metadata), })) @@ -1605,9 +1605,9 @@ impl FfiConversation { self.inner.dm_inbox_id().map_err(Into::into) } - pub fn conversation_type(&self) -> Result { + pub async fn conversation_type(&self) -> Result { let provider = self.inner.mls_provider()?; - let conversation_type = self.inner.conversation_type(&provider)?; + let conversation_type = self.inner.conversation_type(&provider).await?; Ok(conversation_type.into()) } } @@ -2104,6 +2104,9 @@ mod tests { .await .unwrap(); + let conn = client.inner_client.context().store().conn().unwrap(); + conn.register_triggers(); + register_client(&ffi_inbox_owner, &client).await; client } @@ -2595,6 +2598,8 @@ mod tests { async fn test_can_stream_group_messages_for_updates() { let alix = new_test_client().await; let bo = new_test_client().await; + let alix_provider = alix.inner_client.mls_provider().unwrap(); + let bo_provider = bo.inner_client.mls_provider().unwrap(); // Stream all group messages let message_callbacks = Arc::new(RustStreamCallback::default()); @@ -2627,14 +2632,21 @@ mod tests { .unwrap(); let bo_group = &bo_groups[0]; bo_group.sync().await.unwrap(); + + // alix published + processed group creation and name update + assert_eq!(alix_provider.conn_ref().intents_published(), 2); + assert_eq!(alix_provider.conn_ref().intents_deleted(), 2); + bo_group .update_group_name("Old Name2".to_string()) .await .unwrap(); message_callbacks.wait_for_delivery(None).await.unwrap(); + assert_eq!(bo_provider.conn_ref().intents_published(), 1); alix_group.send(b"Hello there".to_vec()).await.unwrap(); message_callbacks.wait_for_delivery(None).await.unwrap(); + assert_eq!(alix_provider.conn_ref().intents_published(), 3); let dm = bo .conversations() @@ -2642,6 +2654,7 @@ mod tests { .await .unwrap(); dm.send(b"Hello again".to_vec()).await.unwrap(); + assert_eq!(bo_provider.conn_ref().intents_published(), 3); message_callbacks.wait_for_delivery(None).await.unwrap(); // Uncomment the following lines to add more group name updates @@ -2650,6 +2663,8 @@ mod tests { .await .unwrap(); message_callbacks.wait_for_delivery(None).await.unwrap(); + message_callbacks.wait_for_delivery(None).await.unwrap(); + assert_eq!(bo_provider.conn_ref().intents_published(), 4); assert_eq!(message_callbacks.message_count(), 6); @@ -2693,9 +2708,11 @@ mod tests { let bo_messages1 = bo_group1 .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo_messages5 = bo_group5 .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(bo_messages1.len(), 0); assert_eq!(bo_messages5.len(), 0); @@ -2707,9 +2724,11 @@ mod tests { let bo_messages1 = bo_group1 .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo_messages5 = bo_group5 .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(bo_messages1.len(), 1); assert_eq!(bo_messages5.len(), 1); @@ -2828,11 +2847,13 @@ mod tests { alix_group.sync().await.unwrap(); let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); bo_group.sync().await.unwrap(); let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(bo_messages.len(), 9); assert_eq!(alix_messages.len(), 10); @@ -3016,15 +3037,19 @@ mod tests { // Get the message count for all the clients let caro_messages = caro_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo2_messages = bo2_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(caro_messages.len(), 5); @@ -3080,9 +3105,11 @@ mod tests { let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let alix_can_see_bo_message = alix_messages @@ -3189,6 +3216,7 @@ mod tests { let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(bo_messages.len(), 0); @@ -3204,8 +3232,12 @@ mod tests { let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); - assert!(bo_messages.first().unwrap().kind == FfiConversationMessageKind::MembershipChange); + assert_eq!( + bo_messages.first().unwrap().kind, + FfiConversationMessageKind::MembershipChange + ); assert_eq!(bo_messages.len(), 1); let bo_members = bo_group.list_members().await.unwrap(); @@ -3263,6 +3295,7 @@ mod tests { let bo_messages1 = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(bo_messages1.len(), first_msg_check); @@ -3275,6 +3308,7 @@ mod tests { let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(alix_messages.len(), second_msg_check); @@ -3284,6 +3318,7 @@ mod tests { let bo_messages2 = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(bo_messages2.len(), second_msg_check); assert_eq!(message_callbacks.message_count(), second_msg_check as u32); @@ -4529,15 +4564,19 @@ mod tests { // Get messages for both participants in both conversations let alix_dm_messages = alix_dm .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo_dm_messages = bo_dm .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let alix_group_messages = alix_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo_group_messages = bo_group .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); // Verify DM messages @@ -4658,6 +4697,7 @@ mod tests { .await .unwrap()[0] .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); let bo_dm_messages = client_b .conversations() @@ -4665,6 +4705,7 @@ mod tests { .await .unwrap()[0] .find_messages(FfiListMessagesOptions::default()) + .await .unwrap(); assert_eq!(alix_dm_messages[0].content, "Hello in DM".as_bytes()); assert_eq!(bo_dm_messages[0].content, "Hello in DM".as_bytes()); diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index 3cd1328be..b520af793 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -161,7 +161,7 @@ impl Conversation { } #[napi] - pub fn find_messages(&self, opts: Option) -> Result> { + pub async fn find_messages(&self, opts: Option) -> Result> { let opts = opts.unwrap_or_default(); let group = MlsGroup::new( self.inner_client.clone(), @@ -171,6 +171,7 @@ impl Conversation { let provider = group.mls_provider().map_err(ErrorWrapper::from)?; let conversation_type = group .conversation_type(&provider) + .await .map_err(ErrorWrapper::from)?; let kind = match conversation_type { ConversationType::Group => None, @@ -250,7 +251,7 @@ impl Conversation { ); let admin_list = group - .admin_list(group.mls_provider().map_err(ErrorWrapper::from)?) + .admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(admin_list) @@ -265,7 +266,7 @@ impl Conversation { ); let super_admin_list = group - .super_admin_list(group.mls_provider().map_err(ErrorWrapper::from)?) + .super_admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(super_admin_list) @@ -451,7 +452,7 @@ impl Conversation { ); let group_name = group - .group_name(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_name(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_name) @@ -482,7 +483,7 @@ impl Conversation { ); let group_image_url_square = group - .group_image_url_square(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_image_url_square(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_image_url_square) @@ -513,7 +514,7 @@ impl Conversation { ); let group_description = group - .group_description(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_description(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_description) @@ -544,7 +545,7 @@ impl Conversation { ); let group_pinned_frame_url = group - .group_pinned_frame_url(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_pinned_frame_url(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_pinned_frame_url) @@ -587,7 +588,7 @@ impl Conversation { Ok( group - .is_active(group.mls_provider().map_err(ErrorWrapper::from)?) + .is_active(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?, ) } @@ -604,7 +605,7 @@ impl Conversation { } #[napi] - pub fn group_metadata(&self) -> Result { + pub async fn group_metadata(&self) -> Result { let group = MlsGroup::new( self.inner_client.clone(), self.group_id.clone(), @@ -612,7 +613,8 @@ impl Conversation { ); let metadata = group - .metadata(group.mls_provider().map_err(ErrorWrapper::from)?) + .metadata(&group.mls_provider().map_err(ErrorWrapper::from)?) + .await .map_err(ErrorWrapper::from)?; Ok(GroupMetadata { inner: metadata }) diff --git a/bindings_node/test/Conversations.test.ts b/bindings_node/test/Conversations.test.ts index ee40b431d..c6123bc6d 100644 --- a/bindings_node/test/Conversations.test.ts +++ b/bindings_node/test/Conversations.test.ts @@ -54,14 +54,16 @@ describe('Conversations', () => { updateGroupPinnedFrameUrlPolicy: 0, }) expect(group.addedByInboxId()).toBe(client1.inboxId()) - expect(group.findMessages().length).toBe(1) + expect((await group.findMessages()).length).toBe(1) const members = await group.listMembers() expect(members.length).toBe(2) const memberInboxIds = members.map((member) => member.inboxId) expect(memberInboxIds).toContain(client1.inboxId()) expect(memberInboxIds).toContain(client2.inboxId()) - expect(group.groupMetadata().conversationType()).toBe('group') - expect(group.groupMetadata().creatorInboxId()).toBe(client1.inboxId()) + expect((await group.groupMetadata()).conversationType()).toBe('group') + expect((await group.groupMetadata()).creatorInboxId()).toBe( + client1.inboxId() + ) expect(group.consentState()).toBe(ConsentState.Allowed) @@ -198,14 +200,16 @@ describe('Conversations', () => { updateGroupPinnedFrameUrlPolicy: 0, }) expect(group.addedByInboxId()).toBe(client1.inboxId()) - expect(group.findMessages().length).toBe(0) + expect((await group.findMessages()).length).toBe(0) const members = await group.listMembers() expect(members.length).toBe(2) const memberInboxIds = members.map((member) => member.inboxId) expect(memberInboxIds).toContain(client1.inboxId()) expect(memberInboxIds).toContain(client2.inboxId()) - expect(group.groupMetadata().conversationType()).toBe('dm') - expect(group.groupMetadata().creatorInboxId()).toBe(client1.inboxId()) + expect((await group.groupMetadata()).conversationType()).toBe('dm') + expect((await group.groupMetadata()).creatorInboxId()).toBe( + client1.inboxId() + ) expect(group.consentState()).toBe(ConsentState.Allowed) diff --git a/bindings_wasm/src/conversation.rs b/bindings_wasm/src/conversation.rs index af09d83b2..42ca2f06e 100644 --- a/bindings_wasm/src/conversation.rs +++ b/bindings_wasm/src/conversation.rs @@ -183,7 +183,10 @@ impl Conversation { } #[wasm_bindgen(js_name = findMessages)] - pub fn find_messages(&self, opts: Option) -> Result, JsError> { + pub async fn find_messages( + &self, + opts: Option, + ) -> Result, JsError> { let opts = opts.unwrap_or_default(); let group = self.to_mls_group(); let provider = group @@ -191,6 +194,7 @@ impl Conversation { .map_err(|e| JsError::new(&format!("{e}")))?; let conversation_type = group .conversation_type(&provider) + .await .map_err(|e| JsError::new(&format!("{e}")))?; let kind = match conversation_type { ConversationType::Group => None, @@ -241,7 +245,7 @@ impl Conversation { let group = self.to_mls_group(); let admin_list = group .admin_list( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -255,7 +259,7 @@ impl Conversation { let group = self.to_mls_group(); let super_admin_list = group .super_admin_list( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -401,7 +405,7 @@ impl Conversation { let group_name = group .group_name( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -431,7 +435,7 @@ impl Conversation { let group_image_url_square = group .group_image_url_square( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -458,7 +462,7 @@ impl Conversation { let group_description = group .group_description( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -488,7 +492,7 @@ impl Conversation { let group_pinned_frame_url = group .group_pinned_frame_url( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -508,7 +512,7 @@ impl Conversation { group .is_active( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -525,14 +529,15 @@ impl Conversation { } #[wasm_bindgen(js_name = groupMetadata)] - pub fn group_metadata(&self) -> Result { + pub async fn group_metadata(&self) -> Result { let group = self.to_mls_group(); let metadata = group .metadata( - group + &group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) + .await .map_err(|e| JsError::new(&format!("{e}")))?; Ok(GroupMetadata { inner: metadata }) diff --git a/common/src/test.rs b/common/src/test.rs index c11c692cb..4cfb2442d 100644 --- a/common/src/test.rs +++ b/common/src/test.rs @@ -38,7 +38,7 @@ pub fn logger() { .from_env_lossy() }; - tracing_subscriber::registry() + let _ = tracing_subscriber::registry() // structured JSON logger only if STRUCTURED=true .with(is_structured.then(|| { tracing_subscriber::fmt::layer() @@ -61,7 +61,7 @@ pub fn logger() { }) .with_filter(filter()) })) - .init(); + .try_init(); }); } diff --git a/examples/cli/serializable.rs b/examples/cli/serializable.rs index c6ee793ce..545081638 100644 --- a/examples/cli/serializable.rs +++ b/examples/cli/serializable.rs @@ -31,11 +31,12 @@ impl SerializableGroup { let metadata = group .metadata( - group + &group .mls_provider() .expect("MLS Provider could not be created"), ) - .expect("could not load metadata"); + .await + .unwrap(); let permissions = group.permissions().expect("could not load permissions"); Self { diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 6db9747da..02327f11b 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -873,7 +873,6 @@ where .map(|group| { let active_group_count = Arc::clone(&active_group_count); async move { - let mls_group = group.load_mls_group(provider)?; tracing::info!( inbox_id = self.inbox_id(), "[{}] syncing group", @@ -881,12 +880,15 @@ where ); tracing::info!( inbox_id = self.inbox_id(), - group_epoch = mls_group.epoch().as_u64(), - "current epoch for [{}] in sync_all_groups() is Epoch: [{}]", - self.inbox_id(), - mls_group.epoch() + "[{}] syncing group", + self.inbox_id() ); - if mls_group.is_active() { + let is_active = group + .load_mls_group_with_lock_async(provider, |mls_group| async move { + Ok::(mls_group.is_active()) + }) + .await?; + if is_active { group.maybe_update_installations(provider, None).await?; group.sync_with_conn(provider).await?; diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index d756f204a..5c9863496 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -865,9 +865,10 @@ pub(crate) mod tests { }; let provider = group.client.mls_provider().unwrap(); - let mut openmls_group = group.load_mls_group(&provider).unwrap(); - let decrypted_message = openmls_group - .process_message(&provider, mls_message) + let decrypted_message = group + .load_mls_group_with_lock(&provider, |mut mls_group| { + Ok(mls_group.process_message(&provider, mls_message).unwrap()) + }) .unwrap(); let staged_commit = match decrypted_message.into_content() { diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs index 5ca53a40a..cfdf56e28 100644 --- a/xmtp_mls/src/groups/members.rs +++ b/xmtp_mls/src/groups/members.rs @@ -40,9 +40,9 @@ where &self, provider: &XmtpOpenMlsProvider, ) -> Result, GroupError> { - let openmls_group = self.load_mls_group(provider)?; - // TODO: Replace with try_into from extensions - let group_membership = extract_group_membership(openmls_group.extensions())?; + let group_membership = self.load_mls_group_with_lock(provider, |mls_group| { + Ok(extract_group_membership(mls_group.extensions())?) + })?; let requests = group_membership .members .into_iter() diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 8b243002c..87109e81a 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -2,11 +2,11 @@ use super::{ build_extensions_for_admin_lists_update, build_extensions_for_metadata_update, build_extensions_for_permissions_update, build_group_membership_extension, intents::{ - Installation, PostCommitAction, SendMessageIntentData, SendWelcomesAction, + Installation, IntentError, PostCommitAction, SendMessageIntentData, SendWelcomesAction, UpdateAdminListIntentData, UpdateGroupMembershipIntentData, UpdatePermissionIntentData, }, validated_commit::{extract_group_membership, CommitValidationError}, - GroupError, HmacKey, IntentError, MlsGroup, ScopedGroupClient, + GroupError, HmacKey, MlsGroup, ScopedGroupClient, }; use crate::{ configuration::{ @@ -183,7 +183,6 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "[{}] syncing group", self.client.inbox_id() ); @@ -191,10 +190,8 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), - "current epoch for [{}] in sync() is Epoch: [{}]", + "current epoch for [{}] in sync()", self.client.inbox_id(), - self.load_mls_group(&mls_provider)?.epoch() ); self.maybe_update_installations(&mls_provider, None).await?; @@ -358,265 +355,265 @@ where async fn process_own_message( &self, intent: StoredGroupIntent, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: ProtocolMessage, envelope: &GroupMessageV1, ) -> Result { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; - - if intent.state == IntentState::Committed { - return Ok(IntentState::Committed); - } - let message_epoch = message.epoch(); - let group_epoch = openmls_group.epoch(); - debug!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id, - intent.id, - intent.kind = %intent.kind, - "[{}]-[{}] processing own message for intent {} / {:?}, group epoch: {}, message_epoch: {}", - self.context().inbox_id(), - hex::encode(self.group_id.clone()), - intent.id, - intent.kind, - group_epoch, - message_epoch - ); + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. + } = *envelope; + + if intent.state == IntentState::Committed { + return Ok(IntentState::Committed); + } + let message_epoch = message.epoch(); + let group_epoch = mls_group.epoch(); + debug!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + group_id = hex::encode(&self.group_id), + msg_id, + intent.id, + intent.kind = %intent.kind, + "[{}]-[{}] processing own message for intent {} / {:?}, message_epoch: {}", + self.context().inbox_id(), + hex::encode(self.group_id.clone()), + intent.id, + intent.kind, + message_epoch + ); - let conn = provider.conn_ref(); - match intent.kind { - IntentKind::KeyUpdate - | IntentKind::UpdateGroupMembership - | IntentKind::UpdateAdminList - | IntentKind::MetadataUpdate - | IntentKind::UpdatePermission => { - if let Some(published_in_epoch) = intent.published_in_epoch { - let published_in_epoch_u64 = published_in_epoch as u64; - let group_epoch_u64 = group_epoch.as_u64(); - - if published_in_epoch_u64 != group_epoch_u64 { - tracing::warn!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), + let conn = provider.conn_ref(); + match intent.kind { + IntentKind::KeyUpdate + | IntentKind::UpdateGroupMembership + | IntentKind::UpdateAdminList + | IntentKind::MetadataUpdate + | IntentKind::UpdatePermission => { + if let Some(published_in_epoch) = intent.published_in_epoch { + let published_in_epoch_u64 = published_in_epoch as u64; + let group_epoch_u64 = group_epoch.as_u64(); + + if published_in_epoch_u64 != group_epoch_u64 { + tracing::warn!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id, - intent.id, - intent.kind = %intent.kind, - "Intent was published in epoch {} but group is currently in epoch {}", - published_in_epoch_u64, - group_epoch_u64 - ); - return Ok(IntentState::ToPublish); + msg_id, + intent.id, + intent.kind = %intent.kind, + "Intent was published in epoch {} but group is currently", + published_in_epoch_u64, + ); + return Ok(IntentState::ToPublish); + } } - } - - let pending_commit = if let Some(staged_commit) = intent.staged_commit { - decode_staged_commit(staged_commit)? - } else { - return Err(GroupMessageProcessingError::IntentMissingStagedCommit); - }; - tracing::info!( - "[{}] Validating commit for intent {}. Message timestamp: {}", - self.context().inbox_id(), - intent.id, - envelope_timestamp_ns - ); - - let maybe_validated_commit = ValidatedCommit::from_staged_commit( - self.client.as_ref(), - conn, - &pending_commit, - openmls_group, - ) - .await; + let pending_commit = if let Some(staged_commit) = intent.staged_commit { + decode_staged_commit(staged_commit)? + } else { + return Err(GroupMessageProcessingError::IntentMissingStagedCommit); + }; - if let Err(err) = maybe_validated_commit { - tracing::error!( - "Error validating commit for own message. Intent ID [{}]: {:?}", + tracing::info!( + "[{}] Validating commit for intent {}. Message timestamp: {}", + self.context().inbox_id(), intent.id, - err + envelope_timestamp_ns ); - // Return before merging commit since it does not pass validation - // Return OK so that the group intent update is still written to the DB - return Ok(IntentState::Error); - } - let validated_commit = maybe_validated_commit.expect("Checked for error"); + let maybe_validated_commit = ValidatedCommit::from_staged_commit( + self.client.as_ref(), + conn, + &pending_commit, + &mls_group, + ) + .await; - tracing::info!( - "[{}] merging pending commit for intent {}", - self.context().inbox_id(), - intent.id - ); - if let Err(err) = openmls_group.merge_staged_commit(&provider, pending_commit) { - tracing::error!("error merging commit: {}", err); - return Ok(IntentState::ToPublish); - } else { - // If no error committing the change, write a transcript message - self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?; - } - } - IntentKind::SendMessage => { - if !Self::is_valid_epoch( - self.context().inbox_id(), - intent.id, - group_epoch, - message_epoch, - MAX_PAST_EPOCHS, - ) { - return Ok(IntentState::ToPublish); + if let Err(err) = maybe_validated_commit { + tracing::error!( + "Error validating commit for own message. Intent ID [{}]: {:?}", + intent.id, + err + ); + // Return before merging commit since it does not pass validation + // Return OK so that the group intent update is still written to the DB + return Ok(IntentState::Error); + } + + let validated_commit = maybe_validated_commit.expect("Checked for error"); + + tracing::info!( + "[{}] merging pending commit for intent {}", + self.context().inbox_id(), + intent.id + ); + if let Err(err) = mls_group.merge_staged_commit(&provider, pending_commit) { + tracing::error!("error merging commit: {}", err); + return Ok(IntentState::ToPublish); + } else { + // If no error committing the change, write a transcript message + self.save_transcript_message( + conn, + validated_commit, + envelope_timestamp_ns, + )?; + } } - if let Some(id) = intent.message_id()? { - conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; + IntentKind::SendMessage => { + if !Self::is_valid_epoch( + self.context().inbox_id(), + intent.id, + group_epoch, + message_epoch, + MAX_PAST_EPOCHS, + ) { + return Ok(IntentState::ToPublish); + } + if let Some(id) = intent.message_id()? { + conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; + } } - } - }; + }; - Ok(IntentState::Committed) + Ok(IntentState::Committed) + }) + .await } #[tracing::instrument(level = "trace", skip_all)] async fn process_external_message( &self, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: PrivateMessageIn, envelope: &GroupMessageV1, ) -> Result<(), GroupMessageProcessingError> { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. + } = *envelope; - let decrypted_message = openmls_group.process_message(provider, message)?; - let (sender_inbox_id, sender_installation_id) = - extract_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?; + let decrypted_message = mls_group.process_message(provider, message)?; + let (sender_inbox_id, sender_installation_id) = + extract_message_sender(&mut mls_group, &decrypted_message, envelope_timestamp_ns)?; - tracing::info!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - sender_inbox_id = sender_inbox_id, - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch = decrypted_message.epoch().as_u64(), - msg_group_id = hex::encode(decrypted_message.group_id().as_slice()), - msg_id, - "[{}] extracted sender inbox id: {}", - self.client.inbox_id(), - sender_inbox_id - ); + tracing::info!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(),sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch = decrypted_message.epoch().as_u64(), + msg_group_id = hex::encode(decrypted_message.group_id().as_slice()), + msg_id, + "[{}] extracted sender inbox id: {}", + self.client.inbox_id(), + sender_inbox_id + ); - let (msg_epoch, msg_group_id) = ( - decrypted_message.epoch().as_u64(), - hex::encode(decrypted_message.group_id().as_slice()), - ); - match decrypted_message.into_content() { - ProcessedMessageContent::ApplicationMessage(application_message) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - sender_installation_id = hex::encode(&sender_installation_id), - installation_id = %self.client.installation_id(), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] decoding application message", - self.context().inbox_id() - ); - let message_bytes = application_message.into_bytes(); - - let mut bytes = Bytes::from(message_bytes.clone()); - let envelope = PlaintextEnvelope::decode(&mut bytes)?; - - match envelope.content { - Some(Content::V1(V1 { - idempotency_key, - content, - })) => { - let message_id = - calculate_message_id(&self.group_id, &content, &idempotency_key); - StoredGroupMessage { - id: message_id, - group_id: self.group_id.clone(), - decrypted_message_bytes: content, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id, - delivery_status: DeliveryStatus::Published, + let (msg_epoch, msg_group_id) = ( + decrypted_message.epoch().as_u64(), + hex::encode(decrypted_message.group_id().as_slice()), + ); + match decrypted_message.into_content() { + ProcessedMessageContent::ApplicationMessage(application_message) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + installation_id = %self.client.installation_id(),group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] decoding application message", + self.context().inbox_id() + ); + let message_bytes = application_message.into_bytes(); + + let mut bytes = Bytes::from(message_bytes.clone()); + let envelope = PlaintextEnvelope::decode(&mut bytes)?; + + match envelope.content { + Some(Content::V1(V1 { + idempotency_key, + content, + })) => { + let message_id = + calculate_message_id(&self.group_id, &content, &idempotency_key); + StoredGroupMessage { + id: message_id, + group_id: self.group_id.clone(), + decrypted_message_bytes: content, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())? } - .store_or_ignore(provider.conn_ref())? - } - Some(Content::V2(V2 { - idempotency_key, - message_type, - })) => { - match message_type { - Some(MessageType::DeviceSyncRequest(history_request)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Request(history_request); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the request message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id: sender_inbox_id.clone(), - delivery_status: DeliveryStatus::Published, + Some(Content::V2(V2 { + idempotency_key, + message_type, + })) => { + match message_type { + Some(MessageType::DeviceSyncRequest(history_request)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Request(history_request); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the request message + StoredGroupMessage { + id: message_id.clone(), + group_id: self.group_id.clone(), + decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id: sender_inbox_id.clone(), + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())?; + + tracing::info!("Received a history request."); + let _ = self.client.local_events().send(LocalEvents::SyncMessage( + SyncMessage::Request { message_id }, + )); } - .store_or_ignore(provider.conn_ref())?; - tracing::info!("Received a history request."); - let _ = self.client.local_events().send(LocalEvents::SyncMessage( - SyncMessage::Request { message_id }, - )); - } - - Some(MessageType::DeviceSyncReply(history_reply)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Reply(history_reply); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the reply message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id, - delivery_status: DeliveryStatus::Published, - } - .store_or_ignore(provider.conn_ref())?; + Some(MessageType::DeviceSyncReply(history_reply)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Reply(history_reply); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the reply message + StoredGroupMessage { + id: message_id.clone(), + group_id: self.group_id.clone(), + decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())?; tracing::info!("Received a history reply."); let _ = self.client.local_events().send(LocalEvents::SyncMessage( @@ -641,70 +638,68 @@ where return Err(GroupMessageProcessingError::InvalidPayload); } } + } + None => return Err(GroupMessageProcessingError::InvalidPayload), } - None => return Err(GroupMessageProcessingError::InvalidPayload), } - } - ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::StagedCommitMessage(staged_commit) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - installation_id = %self.client.installation_id(), - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] received staged commit. Merging and clearing any pending commits", - self.context().inbox_id() - ); + ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::StagedCommitMessage(staged_commit) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(),sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] received staged commit. Merging and clearing any pending commits", + self.context().inbox_id() + ); - let sc = *staged_commit; + let sc = *staged_commit; - // Validate the commit - let validated_commit = ValidatedCommit::from_staged_commit( - self.client.as_ref(), - provider.conn_ref(), - &sc, - openmls_group, - ) - .await?; - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - installation_id = %self.client.installation_id(), - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] staged commit is valid, will attempt to merge", - self.context().inbox_id() - ); - openmls_group.merge_staged_commit(provider, sc)?; - self.save_transcript_message( - provider.conn_ref(), - validated_commit, - envelope_timestamp_ns, - )?; - } - }; + // Validate the commit + let validated_commit = ValidatedCommit::from_staged_commit( + self.client.as_ref(), + provider.conn_ref(), + &sc, + &mls_group, + ) + .await?; + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(),sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] staged commit is valid, will attempt to merge", + self.context().inbox_id() + ); + mls_group.merge_staged_commit(provider, sc)?; + self.save_transcript_message( + provider.conn_ref(), + validated_commit, + envelope_timestamp_ns, + )?; + } + }; - Ok(()) + Ok(()) + }).await } #[tracing::instrument(level = "trace", skip_all)] pub(super) async fn process_message( &self, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, envelope: &GroupMessageV1, allow_epoch_increment: bool, @@ -728,7 +723,6 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, "Processing envelope with hash {:?}", hex::encode(sha256(envelope.data.as_slice())) @@ -742,7 +736,6 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, intent_id, intent.kind = %intent.kind, @@ -752,7 +745,7 @@ where intent_id ); match self - .process_own_message(intent, openmls_group, provider, message.into(), envelope) + .process_own_message(intent, provider, message.into(), envelope) .await? { IntentState::ToPublish => { @@ -777,13 +770,12 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, "client [{}] is about to process external envelope [{}]", self.client.inbox_id(), envelope.id ); - self.process_external_message(openmls_group, provider, message, envelope) + self.process_external_message(provider, message, envelope) .await } Err(err) => Err(GroupMessageProcessingError::Storage(err)), @@ -795,7 +787,6 @@ where &self, provider: &XmtpOpenMlsProvider, envelope: &GroupMessage, - openmls_group: &mut OpenMlsGroup, ) -> Result<(), GroupMessageProcessingError> { let msgv1 = match &envelope.version { Some(GroupMessageVersion::V1(value)) => value, @@ -811,7 +802,6 @@ where let last_cursor = provider .conn_ref() .get_last_cursor_for_id(&self.group_id, message_entity_kind)?; - tracing::info!("### last cursor --> [{:?}]", last_cursor); let should_skip_message = last_cursor > msgv1.id as i64; if should_skip_message { tracing::info!( @@ -837,7 +827,7 @@ where if !is_updated { return Err(ProcessIntentError::AlreadyProcessed(*cursor).into()); } - self.process_message(openmls_group, provider, msgv1, true).await?; + self.process_message(provider, msgv1, true).await?; Ok::<_, GroupMessageProcessingError>(()) }).await .inspect(|_| { @@ -865,16 +855,11 @@ where messages: Vec, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { - let mut openmls_group = self.load_mls_group(provider)?; - let mut receive_errors: Vec = vec![]; for message in messages.into_iter() { let result = retry_async!( Retry::default(), - (async { - self.consume_message(provider, &message, &mut openmls_group) - .await - }) + (async { self.consume_message(provider, &message).await }) ); if let Err(e) = result { let is_retryable = e.is_retryable(); @@ -969,103 +954,104 @@ where &self, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { - let mut openmls_group = self.load_mls_group(provider)?; + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let intents = provider.conn_ref().find_group_intents( + self.group_id.clone(), + Some(vec![IntentState::ToPublish]), + None, + )?; + + for intent in intents { + let result = retry_async!( + Retry::default(), + (async { + self.get_publish_intent_data(provider, &mut mls_group, &intent) + .await + }) + ); - let intents = provider.conn_ref().find_group_intents( - self.group_id.clone(), - Some(vec![IntentState::ToPublish]), - None, - )?; + match result { + Err(err) => { + tracing::error!(error = %err, "error getting publish intent data {:?}", err); + if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { + tracing::error!( + intent.id, + intent.kind = %intent.kind, + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(),group_id = hex::encode(&self.group_id), + "intent {} has reached max publish attempts", intent.id); + // TODO: Eventually clean up errored attempts + provider + .conn_ref() + .set_group_intent_error_and_fail_msg(&intent)?; + } else { + provider + .conn_ref() + .increment_intent_publish_attempt_count(intent.id)?; + } - for intent in intents { - let result = retry_async!( - Retry::default(), - (async { - self.get_publish_intent_data(provider, &mut openmls_group, &intent) - .await - }) - ); + return Err(err); + } + Ok(Some(PublishIntentData { + payload_to_publish, + post_commit_action, + staged_commit, + })) => { + let payload_slice = payload_to_publish.as_slice(); + let has_staged_commit = staged_commit.is_some(); + provider.conn_ref().set_group_intent_published( + intent.id, + sha256(payload_slice), + post_commit_action, + staged_commit, + mls_group.epoch().as_u64() as i64, + )?; + tracing::debug!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + intent.id, + intent.kind = %intent.kind, + group_id = hex::encode(&self.group_id), + "client [{}] set stored intent [{}] to state `published`", + self.client.inbox_id(), + intent.id + ); - match result { - Err(err) => { - tracing::error!(error = %err, "error getting publish intent data {:?}", err); - if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { - tracing::error!( + let messages = self.prepare_group_messages(vec![payload_slice])?;self.client + .api() + .send_group_messages(messages) + .await?; + + tracing::info!( intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - "intent {} has reached max publish attempts", intent.id); - // TODO: Eventually clean up errored attempts - provider - .conn_ref() - .set_group_intent_error_and_fail_msg(&intent)?; - } else { - provider - .conn_ref() - .increment_intent_publish_attempt_count(intent.id)?; + "[{}] published intent [{}] of type [{}]", + self.client.inbox_id(), + intent.id, + intent.kind + ); + if has_staged_commit { + tracing::info!("Commit sent. Stopping further publishes for this round"); + return Ok(()); + } } - - return Err(err); - } - Ok(Some(PublishIntentData { - payload_to_publish, - post_commit_action, - staged_commit, - })) => { - let payload_slice = payload_to_publish.as_slice(); - let has_staged_commit = staged_commit.is_some(); - provider.conn_ref().set_group_intent_published( - intent.id, - sha256(payload_slice), - post_commit_action, - staged_commit, - openmls_group.epoch().as_u64() as i64, - )?; - tracing::debug!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - intent.id, - intent.kind = %intent.kind, - group_id = hex::encode(&self.group_id), - "client [{}] set stored intent [{}] to state `published`", - self.client.inbox_id(), - intent.id - ); - - let messages = self.prepare_group_messages(vec![payload_slice])?; - self.client.api().send_group_messages(messages).await?; - - tracing::info!( - intent.id, - intent.kind = %intent.kind, - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - group_id = hex::encode(&self.group_id), - "[{}] published intent [{}] of type [{}]", - self.client.inbox_id(), - intent.id, - intent.kind - ); - if has_staged_commit { - tracing::info!("Commit sent. Stopping further publishes for this round"); - return Ok(()); + Ok(None) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + "Skipping intent because no publish data returned" + ); + let deleter: &dyn Delete = provider.conn_ref(); + deleter.delete(intent.id)?; } } - Ok(None) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - "Skipping intent because no publish data returned" - ); - let deleter: &dyn Delete = provider.conn_ref(); - deleter.delete(intent.id)?; - } } - } - Ok(()) + Ok(()) + }).await } // Takes a StoredGroupIntent and returns the payload and post commit data as a tuple @@ -1222,10 +1208,7 @@ where update_interval_ns: Option, ) -> Result<(), GroupError> { // determine how long of an interval in time to use before updating list - let interval_ns = match update_interval_ns { - Some(val) => val, - None => SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, - }; + let interval_ns = update_interval_ns.unwrap_or(SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); let now_ns = xmtp_common::time::now_ns(); let last_ns = provider @@ -1292,58 +1275,59 @@ where inbox_ids_to_add: &[InboxIdRef<'_>], inbox_ids_to_remove: &[InboxIdRef<'_>], ) -> Result { - let mls_group = self.load_mls_group(provider)?; - let existing_group_membership = extract_group_membership(mls_group.extensions())?; - - // TODO:nm prevent querying for updates on members who are being removed - let mut inbox_ids = existing_group_membership.inbox_ids(); - inbox_ids.extend_from_slice(inbox_ids_to_add); - let conn = provider.conn_ref(); - // Load any missing updates from the network - load_identity_updates(self.client.api(), conn, &inbox_ids).await?; - - let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; - - // Get a list of all inbox IDs that have increased sequence_id for the group - let changed_inbox_ids = - inbox_ids - .iter() - .try_fold(HashMap::new(), |mut updates, inbox_id| { - match ( - latest_sequence_id_map.get(inbox_id as &str), - existing_group_membership.get(inbox_id), - ) { - // This is an update. We have a new sequence ID and an existing one - (Some(latest_sequence_id), Some(current_sequence_id)) => { - let latest_sequence_id_u64 = *latest_sequence_id as u64; - if latest_sequence_id_u64.gt(current_sequence_id) { - updates.insert(inbox_id.to_string(), latest_sequence_id_u64); + self.load_mls_group_with_lock_async(provider, |mls_group| async move { + let existing_group_membership = extract_group_membership(mls_group.extensions())?; + // TODO:nm prevent querying for updates on members who are being removed + let mut inbox_ids = existing_group_membership.inbox_ids(); + inbox_ids.extend_from_slice(inbox_ids_to_add); + let conn = provider.conn_ref(); + // Load any missing updates from the network + load_identity_updates(self.client.api(), conn, &inbox_ids).await?; + + let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; + + // Get a list of all inbox IDs that have increased sequence_id for the group + let changed_inbox_ids = + inbox_ids + .iter() + .try_fold(HashMap::new(), |mut updates, inbox_id| { + match ( + latest_sequence_id_map.get(inbox_id as &str), + existing_group_membership.get(inbox_id), + ) { + // This is an update. We have a new sequence ID and an existing one + (Some(latest_sequence_id), Some(current_sequence_id)) => { + let latest_sequence_id_u64 = *latest_sequence_id as u64; + if latest_sequence_id_u64.gt(current_sequence_id) { + updates.insert(inbox_id.to_string(), latest_sequence_id_u64); + } + } + // This is for new additions to the group + (Some(latest_sequence_id), _) => { + // This is the case for net new members to the group + updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); + } + (_, _) => { + tracing::warn!( + "Could not find existing sequence ID for inbox {}", + inbox_id + ); + return Err(GroupError::MissingSequenceId); } } - // This is for new additions to the group - (Some(latest_sequence_id), _) => { - // This is the case for net new members to the group - updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); - } - (_, _) => { - tracing::warn!( - "Could not find existing sequence ID for inbox {}", - inbox_id - ); - return Err(GroupError::MissingSequenceId); - } - } - - Ok(updates) - })?; - Ok(UpdateGroupMembershipIntentData::new( - changed_inbox_ids, - inbox_ids_to_remove - .iter() - .map(|s| s.to_string()) - .collect::>(), - )) + Ok(updates) + })?; + + Ok(UpdateGroupMembershipIntentData::new( + changed_inbox_ids, + inbox_ids_to_remove + .iter() + .map(|s| s.to_string()) + .collect::>(), + )) + }) + .await } /** diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 9213c38ef..35662bc7d 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -38,7 +38,6 @@ use tokio::sync::Mutex; use self::device_sync::DeviceSyncError; pub use self::group_permissions::PreconfiguredPolicies; -pub use self::intents::{AddressesOrInstallationIds, IntentError}; use self::scoped_client::ScopedGroupClient; use self::{ group_membership::GroupMembership, @@ -56,12 +55,11 @@ use self::{ use self::{ group_metadata::{GroupMetadata, GroupMetadataError}, group_permissions::PolicySet, + intents::IntentError, validated_commit::CommitValidationError, }; -use std::{collections::HashSet, sync::Arc}; +use crate::storage::StorageError; use xmtp_common::time::now_ns; -use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; -use xmtp_id::{InboxId, InboxIdRef}; use xmtp_proto::xmtp::mls::{ api::v1::{ group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, @@ -91,13 +89,18 @@ use crate::{ group::{ConversationType, GroupMembershipState, StoredGroup}, group_intent::IntentKind, group_message::{DeliveryStatus, GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, - sql_key_store, StorageError, + sql_key_store, }, subscriptions::{LocalEventError, LocalEvents}, utils::id::calculate_message_id, xmtp_openmls_provider::XmtpOpenMlsProvider, - Store, + Store, MLS_COMMIT_LOCK, }; +use std::future::Future; +use std::{collections::HashSet, sync::Arc}; +use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; +use xmtp_id::{InboxId, InboxIdRef}; + use xmtp_common::retry::RetryableError; #[derive(Debug, Error)] @@ -203,6 +206,10 @@ pub enum GroupError { IntentNotCommitted, #[error(transparent)] ProcessIntent(#[from] ProcessIntentError), + #[error("Failed to load lock")] + LockUnavailable, + #[error("Failed to acquire semaphore lock")] + LockFailedToAcquire, } impl RetryableError for GroupError { @@ -229,6 +236,8 @@ impl RetryableError for GroupError { Self::MessageHistory(err) => err.is_retryable(), Self::ProcessIntent(err) => err.is_retryable(), Self::LocalEvent(err) => err.is_retryable(), + Self::LockUnavailable => true, + Self::LockFailedToAcquire => true, Self::SyncFailedToWait => true, Self::GroupNotFound | Self::GroupMetadata(_) @@ -332,16 +341,55 @@ impl MlsGroup { // Load the stored OpenMLS group from the OpenMLS provider's keystore #[tracing::instrument(level = "trace", skip_all)] - pub(crate) fn load_mls_group( + pub(crate) fn load_mls_group_with_lock( &self, provider: impl OpenMlsProvider, - ) -> Result { + operation: F, + ) -> Result + where + F: FnOnce(OpenMlsGroup) -> Result, + { + // Get the group ID for locking + let group_id = self.group_id.clone(); + + // Acquire the lock synchronously using blocking_lock + let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone()); + // Load the MLS group let mls_group = OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) .map_err(|_| GroupError::GroupNotFound)? .ok_or(GroupError::GroupNotFound)?; - Ok(mls_group) + // Perform the operation with the MLS group + operation(mls_group) + } + + // Load the stored OpenMLS group from the OpenMLS provider's keystore + #[tracing::instrument(level = "trace", skip_all)] + pub(crate) async fn load_mls_group_with_lock_async( + &self, + provider: &XmtpOpenMlsProvider, + operation: F, + ) -> Result + where + F: FnOnce(OpenMlsGroup) -> Fut, + Fut: Future>, + E: From + From, + { + // Get the group ID for locking + let group_id = self.group_id.clone(); + + // Acquire the lock asynchronously + let _lock = MLS_COMMIT_LOCK.get_lock_async(group_id.clone()).await; + + // Load the MLS group + let mls_group = + OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) + .map_err(crate::StorageError::from)? + .ok_or(crate::StorageError::NotFound("Group Not Found".into()))?; + + // Perform the operation with the MLS group + operation(mls_group).await.map_err(Into::into) } // Create a new group and save it to the DB @@ -848,7 +896,7 @@ impl MlsGroup { /// to perform these updates. pub async fn update_group_name(&self, group_name: String) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider)?.conversation_type == ConversationType::Dm { + if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -866,7 +914,7 @@ impl MlsGroup { metadata_field: Option, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider)?.conversation_type == ConversationType::Dm { + if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } if permission_update_type == PermissionUpdateType::UpdateMetadata @@ -888,7 +936,7 @@ impl MlsGroup { } /// Retrieves the group name from the group's mutable metadata extension. - pub fn group_name(&self, provider: impl OpenMlsProvider) -> Result { + pub fn group_name(&self, provider: &XmtpOpenMlsProvider) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata .attributes @@ -907,7 +955,7 @@ impl MlsGroup { group_description: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider)?.conversation_type == ConversationType::Dm { + if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -917,7 +965,7 @@ impl MlsGroup { self.sync_until_intent_resolved(&provider, intent.id).await } - pub fn group_description(&self, provider: impl OpenMlsProvider) -> Result { + pub fn group_description(&self, provider: &XmtpOpenMlsProvider) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata .attributes @@ -936,7 +984,7 @@ impl MlsGroup { group_image_url_square: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider)?.conversation_type == ConversationType::Dm { + if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -950,7 +998,7 @@ impl MlsGroup { /// Retrieves the image URL (square) of the group from the group's mutable metadata extension. pub fn group_image_url_square( &self, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata @@ -969,7 +1017,7 @@ impl MlsGroup { pinned_frame_url: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider)?.conversation_type == ConversationType::Dm { + if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -981,7 +1029,7 @@ impl MlsGroup { pub fn group_pinned_frame_url( &self, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata @@ -996,7 +1044,7 @@ impl MlsGroup { } /// Retrieves the admin list of the group from the group's mutable metadata extension. - pub fn admin_list(&self, provider: impl OpenMlsProvider) -> Result, GroupError> { + pub fn admin_list(&self, provider: &XmtpOpenMlsProvider) -> Result, GroupError> { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.admin_list) } @@ -1004,7 +1052,7 @@ impl MlsGroup { /// Retrieves the super admin list of the group from the group's mutable metadata extension. pub fn super_admin_list( &self, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, ) -> Result, GroupError> { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.super_admin_list) @@ -1014,7 +1062,7 @@ impl MlsGroup { pub fn is_admin( &self, inbox_id: String, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.admin_list.contains(&inbox_id)) @@ -1024,18 +1072,18 @@ impl MlsGroup { pub fn is_super_admin( &self, inbox_id: String, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.super_admin_list.contains(&inbox_id)) } /// Retrieves the conversation type of the group from the group's metadata extension. - pub fn conversation_type( + pub async fn conversation_type( &self, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, ) -> Result { - let metadata = self.metadata(provider)?; + let metadata = self.metadata(provider).await?; Ok(metadata.conversation_type) } @@ -1046,7 +1094,7 @@ impl MlsGroup { inbox_id: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider)?.conversation_type == ConversationType::Dm { + if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_action_type = match action_type { @@ -1127,35 +1175,40 @@ impl MlsGroup { self.sync_until_intent_resolved(&provider, intent.id).await } - /// Checks if the the current user is active in the group. + /// Checks if the current user is active in the group. /// /// If the current user has been kicked out of the group, `is_active` will return `false` - pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { - let mls_group = self.load_mls_group(provider)?; - Ok(mls_group.is_active()) + pub fn is_active(&self, provider: &XmtpOpenMlsProvider) -> Result { + self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active())) } /// Get the `GroupMetadata` of the group. - pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result { - let mls_group = self.load_mls_group(provider)?; - Ok(extract_group_metadata(&mls_group)?) + pub async fn metadata( + &self, + provider: &XmtpOpenMlsProvider, + ) -> Result { + self.load_mls_group_with_lock_async(provider, |mls_group| { + futures::future::ready(extract_group_metadata(&mls_group).map_err(Into::into)) + }) + .await } /// Get the `GroupMutableMetadata` of the group. pub fn mutable_metadata( &self, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, ) -> Result { - let mls_group = &self.load_mls_group(provider)?; - - Ok(mls_group.try_into()?) + self.load_mls_group_with_lock(provider, |mls_group| { + Ok(GroupMutableMetadata::try_from(&mls_group)?) + }) } pub fn permissions(&self) -> Result { let provider = self.mls_provider()?; - let mls_group = self.load_mls_group(&provider)?; - Ok(extract_group_permissions(&mls_group)?) + self.load_mls_group_with_lock(&provider, |mls_group| { + Ok(extract_group_permissions(&mls_group)?) + }) } /// Used for testing that dm group validation works as expected. @@ -1607,7 +1660,6 @@ pub(crate) mod tests { use diesel::connection::SimpleConnection; use futures::future::join_all; - use openmls::prelude::Member; use prost::Message; use std::sync::Arc; use wasm_bindgen_test::wasm_bindgen_test; @@ -1804,7 +1856,7 @@ pub(crate) mod tests { // Verify bola can see the group name let bola_group_name = bola_group - .group_name(bola_group.mls_provider().unwrap()) + .group_name(&bola_group.mls_provider().unwrap()) .unwrap(); assert_eq!(bola_group_name, ""); @@ -1876,15 +1928,19 @@ pub(crate) mod tests { // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_mls_group = amal_group.load_mls_group(&amal_db).unwrap(); - let amal_members: Vec = amal_mls_group.members().collect(); - assert_eq!(amal_members.len(), 3); + let amal_members_len = amal_group + .load_mls_group_with_lock(&amal_db, |mls_group| Ok(mls_group.members().count())) + .unwrap(); + + assert_eq!(amal_members_len, 3); // Check Bola's MLS group state. let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap()); - let bola_mls_group = bola_group.load_mls_group(&bola_db).unwrap(); - let bola_members: Vec = bola_mls_group.members().collect(); - assert_eq!(bola_members.len(), 3); + let bola_members_len = bola_group + .load_mls_group_with_lock(&bola_db, |mls_group| Ok(mls_group.members().count())) + .unwrap(); + + assert_eq!(bola_members_len, 3); let amal_uncommitted_intents = amal_db .conn_ref() @@ -1944,19 +2000,26 @@ pub(crate) mod tests { .unwrap(); let provider = alix.mls_provider().unwrap(); // Doctor the group membership - let mut mls_group = alix_group.load_mls_group(&provider).unwrap(); - let mut existing_extensions = mls_group.extensions().clone(); - let mut group_membership = GroupMembership::new(); - group_membership.add("deadbeef".to_string(), 1); - existing_extensions.add_or_replace(build_group_membership_extension(&group_membership)); - mls_group - .update_group_context_extensions( - &provider, - existing_extensions.clone(), - &alix.identity().installation_keys, - ) + let mut mls_group = alix_group + .load_mls_group_with_lock(&provider, |mut mls_group| { + let mut existing_extensions = mls_group.extensions().clone(); + let mut group_membership = GroupMembership::new(); + group_membership.add("deadbeef".to_string(), 1); + existing_extensions + .add_or_replace(build_group_membership_extension(&group_membership)); + + mls_group + .update_group_context_extensions( + &provider, + existing_extensions.clone(), + &alix.identity().installation_keys, + ) + .unwrap(); + mls_group.merge_pending_commit(&provider).unwrap(); + + Ok(mls_group) // Return the updated group if necessary + }) .unwrap(); - mls_group.merge_pending_commit(&provider).unwrap(); // Now add bo to the group force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; @@ -2078,9 +2141,13 @@ pub(crate) mod tests { assert_eq!(messages.len(), 2); let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into(); - let mls_group = group.load_mls_group(&provider).unwrap(); - let pending_commit = mls_group.pending_commit(); - assert!(pending_commit.is_none()); + let pending_commit_is_none = group + .load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.pending_commit().is_none()) + }) + .unwrap(); + + assert!(pending_commit_is_none); group.send_message(b"hello").await.expect("send message"); @@ -2162,7 +2229,7 @@ pub(crate) mod tests { let bola_group = receive_group_invite(&bola).await; bola_group.sync().await.unwrap(); assert!(!bola_group - .is_active(bola_group.mls_provider().unwrap()) + .is_active(&bola_group.mls_provider().unwrap()) .unwrap()) } @@ -2255,8 +2322,12 @@ pub(crate) mod tests { assert!(new_installations_were_added.is_ok()); group.sync().await.unwrap(); - let mls_group = group.load_mls_group(&provider).unwrap(); - let num_members = mls_group.members().collect::>().len(); + let num_members = group + .load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.members().collect::>().len()) + }) + .unwrap(); + assert_eq!(num_members, 3); } @@ -2353,7 +2424,7 @@ pub(crate) mod tests { .unwrap(); let binding = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -2416,7 +2487,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata.attributes.len().eq(&4)); assert!(group_mutable_metadata @@ -2439,7 +2510,7 @@ pub(crate) mod tests { let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); let group_mutable_metadata = bola_group - .mutable_metadata(bola_group.mls_provider().unwrap()) + .mutable_metadata(&bola_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2458,7 +2529,7 @@ pub(crate) mod tests { // Verify amal group sees update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -2469,7 +2540,7 @@ pub(crate) mod tests { // Verify bola group sees update bola_group.sync().await.unwrap(); let binding = bola_group - .mutable_metadata(bola_group.mls_provider().unwrap()) + .mutable_metadata(&bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -2486,7 +2557,7 @@ pub(crate) mod tests { // Verify bola group does not see an update bola_group.sync().await.unwrap(); let binding = bola_group - .mutable_metadata(bola_group.mls_provider().unwrap()) + .mutable_metadata(&bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -2507,7 +2578,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2524,7 +2595,7 @@ pub(crate) mod tests { // Verify amal group sees update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_image_url: &String = binding .attributes @@ -2545,7 +2616,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2562,7 +2633,7 @@ pub(crate) mod tests { // Verify amal group sees update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_pinned_frame_url: &String = binding .attributes @@ -2585,7 +2656,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2606,7 +2677,7 @@ pub(crate) mod tests { let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); let group_mutable_metadata = bola_group - .mutable_metadata(bola_group.mls_provider().unwrap()) + .mutable_metadata(&bola_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2623,7 +2694,7 @@ pub(crate) mod tests { // Verify amal group sees update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .unwrap(); let amal_group_name: &String = binding .attributes @@ -2634,7 +2705,7 @@ pub(crate) mod tests { // Verify bola group sees update bola_group.sync().await.unwrap(); let binding = bola_group - .mutable_metadata(bola_group.mls_provider().unwrap()) + .mutable_metadata(&bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -2651,7 +2722,7 @@ pub(crate) mod tests { // Verify amal group sees an update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -2718,13 +2789,13 @@ pub(crate) mod tests { bola_group.sync().await.unwrap(); assert_eq!( bola_group - .admin_list(bola_group.mls_provider().unwrap()) + .admin_list(&bola_group.mls_provider().unwrap()) .unwrap() .len(), 1 ); assert!(bola_group - .admin_list(bola_group.mls_provider().unwrap()) + .admin_list(&bola_group.mls_provider().unwrap()) .unwrap() .contains(&bola.inbox_id().to_string())); @@ -2755,13 +2826,13 @@ pub(crate) mod tests { bola_group.sync().await.unwrap(); assert_eq!( bola_group - .admin_list(bola_group.mls_provider().unwrap()) + .admin_list(&bola_group.mls_provider().unwrap()) .unwrap() .len(), 0 ); assert!(!bola_group - .admin_list(bola_group.mls_provider().unwrap()) + .admin_list(&bola_group.mls_provider().unwrap()) .unwrap() .contains(&bola.inbox_id().to_string())); @@ -3039,13 +3110,14 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let mutable_metadata = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .unwrap(); assert_eq!(mutable_metadata.super_admin_list.len(), 1); assert_eq!(mutable_metadata.super_admin_list[0], amal.inbox_id()); let protected_metadata: GroupMetadata = amal_group - .metadata(amal_group.mls_provider().unwrap()) + .metadata(&amal_group.mls_provider().unwrap()) + .await .unwrap(); assert_eq!( protected_metadata.conversation_type, @@ -3085,7 +3157,7 @@ pub(crate) mod tests { .unwrap(); amal_group.sync().await.unwrap(); let name = amal_group - .group_name(amal_group.mls_provider().unwrap()) + .group_name(&amal_group.mls_provider().unwrap()) .unwrap(); assert_eq!(name, "Name Update 1"); @@ -3106,7 +3178,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); bola_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(amal_group.mls_provider().unwrap()) + .mutable_metadata(&amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -3114,7 +3186,7 @@ pub(crate) mod tests { .unwrap(); assert_eq!(amal_group_name, "Name Update 2"); let binding = bola_group - .mutable_metadata(bola_group.mls_provider().unwrap()) + .mutable_metadata(&bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -3325,16 +3397,16 @@ pub(crate) mod tests { amal_dm.sync().await.unwrap(); bola_dm.sync().await.unwrap(); let is_amal_admin = amal_dm - .is_admin(amal.inbox_id().to_string(), amal.mls_provider().unwrap()) + .is_admin(amal.inbox_id().to_string(), &amal.mls_provider().unwrap()) .unwrap(); let is_bola_admin = amal_dm - .is_admin(bola.inbox_id().to_string(), bola.mls_provider().unwrap()) + .is_admin(bola.inbox_id().to_string(), &bola.mls_provider().unwrap()) .unwrap(); let is_amal_super_admin = amal_dm - .is_super_admin(amal.inbox_id().to_string(), amal.mls_provider().unwrap()) + .is_super_admin(amal.inbox_id().to_string(), &amal.mls_provider().unwrap()) .unwrap(); let is_bola_super_admin = amal_dm - .is_super_admin(bola.inbox_id().to_string(), bola.mls_provider().unwrap()) + .is_super_admin(bola.inbox_id().to_string(), &bola.mls_provider().unwrap()) .unwrap(); assert!(!is_amal_admin); assert!(!is_bola_admin); @@ -3656,9 +3728,8 @@ pub(crate) mod tests { panic!("wrong message format") }; let provider = client.mls_provider().unwrap(); - let mut openmls_group = group.load_mls_group(&provider).unwrap(); let process_result = group - .process_message(&mut openmls_group, &provider, &first_message, false) + .process_message(&provider, &first_message, false) .await; assert_err!( @@ -3802,14 +3873,11 @@ pub(crate) mod tests { None, ) .unwrap(); - assert!(validate_dm_group( - &client, - &valid_dm_group - .load_mls_group(client.mls_provider().unwrap()) - .unwrap(), - added_by_inbox - ) - .is_ok()); + assert!(valid_dm_group + .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }) + .is_ok()); // Test case 2: Invalid conversation type let invalid_protected_metadata = @@ -3824,10 +3892,11 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, &invalid_type_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| + validate_dm_group(&client, &mls_group, added_by_inbox) + ), Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type") )); - // Test case 3: Missing DmMembers // This case is not easily testable with the current structure, as DmMembers are set in the protected metadata @@ -3845,7 +3914,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, &mismatched_dm_members_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| + validate_dm_group(&client, &mls_group, added_by_inbox) + ), Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes") )); @@ -3865,7 +3936,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, &non_empty_admin_list_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| + validate_dm_group(&client, &mls_group, added_by_inbox) + ), Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists") )); @@ -3884,11 +3957,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group( - &client, - &invalid_permissions_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), - added_by_inbox - ), + invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| + validate_dm_group(&client, &mls_group, added_by_inbox) + ), Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group") )); } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index eb354fe3c..616a0f6a3 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -48,21 +48,14 @@ impl MlsGroup { self.context() .store() .transaction_async(provider, |provider| async move { - let mut openmls_group = self.load_mls_group(provider)?; - - // Attempt processing immediately, but fail if the message is not an Application Message - // Returning an error should roll back the DB tx tracing::info!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), msg_id = msgv1.id, - "current epoch for [{}] in process_stream_entry() is Epoch: [{}]", + "current epoch for [{}] in process_stream_entry()", client_id, - openmls_group.epoch() ); - - self.process_message(&mut openmls_group, provider, msgv1, false) + self.process_message(provider, msgv1, false) .await // NOTE: We want to make sure we retry an error in process_message .map_err(SubscribeError::ReceiveGroup) diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index d287fd676..164e79af6 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -20,12 +20,100 @@ pub mod verified_key_package_v2; mod xmtp_openmls_provider; pub use client::{Client, Network}; +use std::collections::HashMap; +use std::sync::{Arc, LazyLock, Mutex}; use storage::{DuplicateItem, StorageError}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; pub use xmtp_openmls_provider::XmtpOpenMlsProvider; pub use xmtp_id::InboxOwner; pub use xmtp_proto::api_client::trait_impls::*; +/// A manager for group-specific semaphores +#[derive(Debug)] +pub struct GroupCommitLock { + // Storage for group-specific semaphores + locks: Mutex, Arc>>, +} + +impl Default for GroupCommitLock { + fn default() -> Self { + Self::new() + } +} +impl GroupCommitLock { + /// Create a new `GroupCommitLock` + pub fn new() -> Self { + Self { + locks: Mutex::new(HashMap::new()), + } + } + + /// Get or create a semaphore for a specific group and acquire it, returning a guard + pub async fn get_lock_async(&self, group_id: Vec) -> Result { + let semaphore = { + match self.locks.lock() { + Ok(mut locks) => locks + .entry(group_id) + .or_insert_with(|| Arc::new(Semaphore::new(1))) + .clone(), + Err(err) => { + eprintln!("Failed to lock the mutex: {}", err); + return Err(GroupError::LockUnavailable); + } + } + }; + + let semaphore_clone = semaphore.clone(); + let permit = match semaphore.acquire_owned().await { + Ok(permit) => permit, + Err(err) => { + eprintln!("Failed to acquire semaphore permit: {}", err); + return Err(GroupError::LockFailedToAcquire); + } + }; + Ok(SemaphoreGuard { + _permit: permit, + _semaphore: semaphore_clone, + }) + } + + /// Get or create a semaphore for a specific group and acquire it synchronously + pub fn get_lock_sync(&self, group_id: Vec) -> Result { + let semaphore = { + match self.locks.lock() { + Ok(mut locks) => locks + .entry(group_id) + .or_insert_with(|| Arc::new(Semaphore::new(1))) + .clone(), + Err(err) => { + eprintln!("Failed to lock the mutex: {}", err); + return Err(GroupError::LockUnavailable); + } + } + }; + + // Synchronously acquire the permit + let permit = semaphore + .clone() + .try_acquire_owned() + .map_err(|_| GroupError::LockUnavailable)?; + Ok(SemaphoreGuard { + _permit: permit, + _semaphore: semaphore, // semaphore is now valid because we cloned it earlier + }) + } +} + +/// A guard that releases the semaphore when dropped +pub struct SemaphoreGuard { + _permit: OwnedSemaphorePermit, + _semaphore: Arc, +} + +// Static instance of `GroupCommitLock` +pub static MLS_COMMIT_LOCK: LazyLock = LazyLock::new(GroupCommitLock::new); + /// Inserts a model to the underlying data store, erroring if it already exists pub trait Store { fn store(&self, into: &StorageConnection) -> Result<(), StorageError>; @@ -64,6 +152,7 @@ pub trait Delete { fn delete(&self, key: Self::Key) -> Result; } +use crate::groups::GroupError; pub use stream_handles::{ spawn, AbortHandle, GenericStreamHandle, StreamHandle, StreamHandleError, }; diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index bfb3c1399..70edb0956 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -15,7 +15,7 @@ use super::{ Sqlite, }; use crate::{ - groups::{intents::SendMessageIntentData, IntentError}, + groups::intents::{IntentError, SendMessageIntentData}, impl_fetch, impl_store, storage::StorageError, utils::id::calculate_message_id, diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index de25850ab..3cc3df2a8 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -3,7 +3,7 @@ use std::sync::PoisonError; use diesel::result::DatabaseErrorKind; use thiserror::Error; -use super::sql_key_store; +use super::sql_key_store::{self, SqlKeyStoreError}; use crate::groups::intents::IntentError; use xmtp_common::{retryable, RetryableError}; @@ -27,6 +27,7 @@ pub enum StorageError { Serialization(String), #[error("deserialization error")] Deserialization(String), + // TODO:insipx Make NotFound into an enum of possible items that may not be found #[error("{0} not found")] NotFound(String), #[error("lock")] @@ -45,6 +46,8 @@ pub enum StorageError { FromHex(#[from] hex::FromHexError), #[error(transparent)] Duplicate(DuplicateItem), + #[error(transparent)] + OpenMlsStorage(#[from] SqlKeyStoreError), } #[derive(Error, Debug)] diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index f9675d500..97f538504 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -385,7 +385,7 @@ where } WelcomeOrGroup::Group(group) => group?, }; - let metadata = group.metadata(provider)?; + let metadata = group.metadata(&provider).await?; Ok((metadata, group)) } } @@ -658,11 +658,11 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bob.inbox_id()]) .await .unwrap(); - let bob_group = bob + let bob_groups = bob .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); - let bob_group = bob_group.first().unwrap(); + let bob_group = bob_groups.first().unwrap(); let notify = Delivery::new(None); let notify_ptr = notify.clone(); @@ -968,6 +968,7 @@ pub(crate) mod tests { } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread"))] + #[cfg_attr(target_family = "wasm", ignore)] async fn test_dm_streaming() { let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);