diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index f3b006082..931b64b1f 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -79,7 +79,7 @@ use crate::{ SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS, }, hpke::{decrypt_welcome, HpkeError}, - identity::{parse_credential, Identity, IdentityError}, + identity::{parse_credential, IdentityError}, identity_updates::{load_identity_updates, InstallationDiffError}, retry::RetryableError, storage::{ @@ -295,9 +295,10 @@ impl MlsGroup { ) -> Result { let conn = context.store.conn()?; let provider = XmtpOpenMlsProvider::new(conn); + let creator_inbox_id = context.inbox_id(); let protected_metadata = - build_protected_metadata_extension(&context.identity, Purpose::Conversation)?; - let mutable_metadata = build_mutable_metadata_extension_default(&context.identity, opts)?; + build_protected_metadata_extension(creator_inbox_id.clone(), Purpose::Conversation)?; + let mutable_metadata = build_mutable_metadata_extension_default(creator_inbox_id, opts)?; let group_membership = build_starting_group_membership_extension(context.inbox_id(), 0); let mutable_permissions = build_mutable_permissions_extension(permissions_policy_set)?; let group_config = build_group_config( @@ -342,7 +343,7 @@ impl MlsGroup { let conn = context.store.conn()?; let provider = XmtpOpenMlsProvider::new(conn); let protected_metadata = - build_dm_protected_metadata_extension(&context.identity, dm_target_inbox_id.clone())?; + build_dm_protected_metadata_extension(context.inbox_id(), dm_target_inbox_id.clone())?; let mutable_metadata = build_dm_mutable_metadata_extension_default(context.inbox_id(), dm_target_inbox_id)?; let group_membership = build_starting_group_membership_extension(context.inbox_id(), 0); @@ -400,7 +401,7 @@ impl MlsGroup { let group_type = metadata.conversation_type; let to_store = match group_type { - ConversationType::Group | ConversationType::Dm => StoredGroup::new_from_welcome( + ConversationType::Group => StoredGroup::new_from_welcome( group_id.clone(), now_ns(), GroupMembershipState::Pending, @@ -408,6 +409,17 @@ impl MlsGroup { welcome_id, Purpose::Conversation, ), + ConversationType::Dm => { + validate_dm_group(client, &mls_group, &added_by_inbox)?; + StoredGroup::new_from_welcome( + group_id.clone(), + now_ns(), + GroupMembershipState::Pending, + added_by_inbox, + welcome_id, + Purpose::Conversation, + ) + } ConversationType::Sync => StoredGroup::new_from_welcome( group_id.clone(), now_ns(), @@ -465,11 +477,12 @@ impl MlsGroup { ) -> Result { let conn = context.store.conn()?; // let my_sequence_id = context.inbox_sequence_id(&conn)?; + let creator_inbox_id = context.inbox_id().to_string(); let provider = XmtpOpenMlsProvider::new(conn); let protected_metadata = - build_protected_metadata_extension(&context.identity, Purpose::Sync)?; + build_protected_metadata_extension(creator_inbox_id.clone(), Purpose::Sync)?; let mutable_metadata = build_mutable_metadata_extension_default( - &context.identity, + creator_inbox_id, GroupMetadataOptions::default(), )?; let group_membership = build_starting_group_membership_extension(context.inbox_id(), 0); @@ -1063,6 +1076,70 @@ impl MlsGroup { Ok(extract_group_permissions(&mls_group)?) } + /// Used for testing that dm group validation works as expected. + /// + /// See the `test_validate_dm_group` test function for more details. + #[cfg(test)] + pub fn create_test_dm_group( + context: Arc, + dm_target_inbox_id: InboxId, + custom_protected_metadata: Option, + custom_mutable_metadata: Option, + custom_group_membership: Option, + custom_mutable_permissions: Option, + ) -> Result { + let conn = context.store.conn()?; + let provider = XmtpOpenMlsProvider::new(conn); + + let protected_metadata = custom_protected_metadata.unwrap_or_else(|| { + build_dm_protected_metadata_extension(context.inbox_id(), dm_target_inbox_id.clone()) + .unwrap() + }); + let mutable_metadata = custom_mutable_metadata.unwrap_or_else(|| { + build_dm_mutable_metadata_extension_default( + context.inbox_id(), + dm_target_inbox_id.clone(), + ) + .unwrap() + }); + let group_membership = custom_group_membership + .unwrap_or_else(|| build_starting_group_membership_extension(context.inbox_id(), 0)); + let mutable_permissions = custom_mutable_permissions.unwrap_or_else(PolicySet::new_dm); + let mutable_permission_extension = + build_mutable_permissions_extension(mutable_permissions)?; + + let group_config = build_group_config( + protected_metadata, + mutable_metadata, + group_membership, + mutable_permission_extension, + )?; + + let mls_group = OpenMlsGroup::new( + &provider, + &context.identity.installation_keys, + &group_config, + CredentialWithKey { + credential: context.identity.credential(), + signature_key: context.identity.installation_keys.to_public_vec().into(), + }, + )?; + + let group_id = mls_group.group_id().to_vec(); + let stored_group = StoredGroup::new( + group_id.clone(), + now_ns(), + GroupMembershipState::Allowed, // Use Allowed as default for tests + context.inbox_id(), + ); + + stored_group.store(provider.conn_ref())?; + Ok(Self::new( + context.clone(), + group_id, + stored_group.created_at_ns, + )) + } } fn extract_message_v1(message: GroupMessage) -> Result { @@ -1080,7 +1157,7 @@ pub fn extract_group_id(message: &GroupMessage) -> Result, MessageProces } fn build_protected_metadata_extension( - identity: &Identity, + creator_inbox_id: String, group_purpose: Purpose, ) -> Result { let group_type = match group_purpose { @@ -1088,26 +1165,22 @@ fn build_protected_metadata_extension( Purpose::Sync => ConversationType::Sync, }; - let metadata = GroupMetadata::new(group_type, identity.inbox_id().clone(), None); + let metadata = GroupMetadata::new(group_type, creator_inbox_id, None); let protected_metadata = Metadata::new(metadata.try_into()?); Ok(Extension::ImmutableMetadata(protected_metadata)) } fn build_dm_protected_metadata_extension( - identity: &Identity, + creator_inbox_id: String, dm_inbox_id: InboxId, ) -> Result { let dm_members = Some(DmMembers { - member_one_inbox_id: identity.inbox_id().clone(), + member_one_inbox_id: creator_inbox_id.clone(), member_two_inbox_id: dm_inbox_id, }); - let metadata = GroupMetadata::new( - ConversationType::Dm, - identity.inbox_id().clone(), - dm_members, - ); + let metadata = GroupMetadata::new(ConversationType::Dm, creator_inbox_id, dm_members); let protected_metadata = Metadata::new(metadata.try_into()?); Ok(Extension::ImmutableMetadata(protected_metadata)) @@ -1124,11 +1197,11 @@ fn build_mutable_permissions_extension(policies: PolicySet) -> Result Result { let mutable_metadata: Vec = - GroupMutableMetadata::new_default(identity.inbox_id.clone(), opts).try_into()?; + GroupMutableMetadata::new_default(creator_inbox_id, opts).try_into()?; let unknown_gc_extension = UnknownExtension(mutable_metadata); Ok(Extension::Unknown( @@ -1377,6 +1450,59 @@ async fn validate_initial_group_membership( Ok(()) } +fn validate_dm_group( + client: &Client, + mls_group: &OpenMlsGroup, + added_by_inbox: &str, +) -> Result<(), GroupError> { + let metadata = extract_group_metadata(mls_group)?; + + // Check if the conversation type is DM + if metadata.conversation_type != ConversationType::Dm { + return Err(GroupError::Generic( + "Invalid conversation type for DM group".to_string(), + )); + } + + // Check if DmMembers are set and validate their contents + if let Some(dm_members) = metadata.dm_members { + let our_inbox_id = client.context.identity.inbox_id().clone(); + if !((dm_members.member_one_inbox_id == added_by_inbox + && dm_members.member_two_inbox_id == our_inbox_id) + || (dm_members.member_one_inbox_id == our_inbox_id + && dm_members.member_two_inbox_id == added_by_inbox)) + { + return Err(GroupError::Generic( + "DM members do not match expected inboxes".to_string(), + )); + } + } else { + return Err(GroupError::Generic( + "DM group must have DmMembers set".to_string(), + )); + } + + // Validate mutable metadata + let mutable_metadata: GroupMutableMetadata = mls_group.try_into()?; + + // Check if the admin list and super admin list are empty + if !mutable_metadata.admin_list.is_empty() || !mutable_metadata.super_admin_list.is_empty() { + return Err(GroupError::Generic( + "DM group must have empty admin and super admin lists".to_string(), + )); + } + + // Validate permissions + let permissions = extract_group_permissions(mls_group)?; + if permissions != GroupMutablePermissions::new(PolicySet::new_dm()) { + return Err(GroupError::Generic( + "Invalid permissions for DM group".to_string(), + )); + } + + Ok(()) +} + fn build_group_join_config() -> MlsGroupJoinConfig { MlsGroupJoinConfig::builder() .wire_format_policy(WireFormatPolicy::default()) @@ -1402,16 +1528,19 @@ mod tests { client::MessageProcessingError, codecs::{group_updated::GroupUpdatedCodec, ContentCodec}, groups::{ - build_group_membership_extension, + build_dm_protected_metadata_extension, build_group_membership_extension, + build_mutable_metadata_extension_default, build_protected_metadata_extension, group_membership::GroupMembership, group_metadata::{ConversationType, GroupMetadata}, group_mutable_metadata::MetadataField, intents::{PermissionPolicyOption, PermissionUpdateType}, members::{GroupMember, PermissionLevel}, - DeliveryStatus, GroupMetadataOptions, PreconfiguredPolicies, UpdateAdminListType, + validate_dm_group, DeliveryStatus, GroupMetadataOptions, PreconfiguredPolicies, + UpdateAdminListType, }, storage::{ consent_record::ConsentState, + group::Purpose, group_intent::{IntentKind, IntentState, NewGroupIntent}, group_message::{GroupMessageKind, StoredGroupMessage}, }, @@ -1420,6 +1549,7 @@ mod tests { }; use super::{ + group_permissions::PolicySet, intents::{Installation, SendWelcomesAction}, GroupError, MlsGroup, }; @@ -3419,4 +3549,115 @@ mod tests { assert_eq!(consent, ConsentState::Denied); } + + #[tokio::test] + async fn test_validate_dm_group() { + let client = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let added_by_inbox = "added_by_inbox_id"; + let creator_inbox_id = client.context.identity.inbox_id().clone(); + let dm_target_inbox_id = added_by_inbox.to_string(); + + // Test case 1: Valid DM group + let valid_dm_group = MlsGroup::create_test_dm_group( + client.context.clone(), + dm_target_inbox_id.clone(), + None, + None, + None, + None, + ) + .unwrap(); + assert!(validate_dm_group( + &client, + &valid_dm_group + .load_mls_group(client.mls_provider().unwrap()) + .unwrap(), + added_by_inbox + ) + .is_ok()); + + // Test case 2: Invalid conversation type + let invalid_protected_metadata = + build_protected_metadata_extension(creator_inbox_id.clone(), Purpose::Conversation) + .unwrap(); + let invalid_type_group = MlsGroup::create_test_dm_group( + client.context.clone(), + dm_target_inbox_id.clone(), + Some(invalid_protected_metadata), + None, + None, + None, + ) + .unwrap(); + assert!(matches!( + validate_dm_group(&client, &invalid_type_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type") + )); + + // Test case 3: Missing DmMembers + // This case is not easily testable with the current structure, as DmMembers are set in the protected metadata + + // Test case 4: Mismatched DM members + let mismatched_dm_members = build_dm_protected_metadata_extension( + creator_inbox_id.clone(), + "wrong_inbox_id".to_string(), + ) + .unwrap(); + let mismatched_dm_members_group = MlsGroup::create_test_dm_group( + client.context.clone(), + dm_target_inbox_id.clone(), + Some(mismatched_dm_members), + None, + None, + None, + ) + .unwrap(); + assert!(matches!( + validate_dm_group(&client, &mismatched_dm_members_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes") + )); + + // Test case 5: Non-empty admin list + let non_empty_admin_list = build_mutable_metadata_extension_default( + creator_inbox_id.clone(), + GroupMetadataOptions::default(), + ) + .unwrap(); + let non_empty_admin_list_group = MlsGroup::create_test_dm_group( + client.context.clone(), + dm_target_inbox_id.clone(), + None, + Some(non_empty_admin_list), + None, + None, + ) + .unwrap(); + assert!(matches!( + validate_dm_group(&client, &non_empty_admin_list_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists") + )); + + // Test case 6: Non-empty super admin list + // Similar to test case 5, but with super_admin_list + + // Test case 7: Invalid permissions + let invalid_permissions = PolicySet::default(); + let invalid_permissions_group = MlsGroup::create_test_dm_group( + client.context.clone(), + dm_target_inbox_id.clone(), + None, + None, + None, + Some(invalid_permissions), + ) + .unwrap(); + assert!(matches!( + validate_dm_group( + &client, + &invalid_permissions_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), + added_by_inbox + ), + Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group") + )); + } }