From 1873e99c72954038cd0b89e52553e3d395e59962 Mon Sep 17 00:00:00 2001 From: Dakota Brink Date: Wed, 18 Dec 2024 17:08:31 -0500 Subject: [PATCH] use the concrete type --- xmtp_mls/src/client.rs | 9 ++- xmtp_mls/src/groups/group_metadata.rs | 50 ++++++++++++---- xmtp_mls/src/groups/mod.rs | 23 ++++---- xmtp_mls/src/groups/validated_commit.rs | 2 +- xmtp_mls/src/storage/encrypted_store/group.rs | 57 ++++++++----------- 5 files changed, 84 insertions(+), 57 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 6b84be9b0..d68a2c9ca 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -36,8 +36,8 @@ use crate::groups::device_sync::WorkerHandle; use crate::{ api::ApiClientWrapper, groups::{ - device_sync::preference_sync::UserPreferenceUpdate, group_permissions::PolicySet, - GroupError, GroupMetadataOptions, MlsGroup, + device_sync::preference_sync::UserPreferenceUpdate, group_metadata::DmMembers, + group_permissions::PolicySet, GroupError, GroupMetadataOptions, MlsGroup, }, identity::{parse_credential, Identity, IdentityError}, identity_updates::{load_identity_updates, IdentityUpdateError}, @@ -638,7 +638,10 @@ where target_inbox_id: String, ) -> Result, ClientError> { let conn = self.store().conn()?; - match conn.find_dm_group(self.inbox_id(), &target_inbox_id)? { + match conn.find_dm_group(DmMembers { + member_one_inbox_id: self.inbox_id(), + member_two_inbox_id: &target_inbox_id, + })? { Some(dm_group) => Ok(MlsGroup::new( self.clone(), dm_group.id, diff --git a/xmtp_mls/src/groups/group_metadata.rs b/xmtp_mls/src/groups/group_metadata.rs index e06c7e6a1..162475ee6 100644 --- a/xmtp_mls/src/groups/group_metadata.rs +++ b/xmtp_mls/src/groups/group_metadata.rs @@ -2,6 +2,7 @@ use openmls::{extensions::Extensions, group::MlsGroup as OpenMlsGroup}; use prost::Message; use thiserror::Error; +use xmtp_id::InboxId; use xmtp_proto::xmtp::mls::message_contents::{ ConversationType as ConversationTypeProto, DmMembers as DmMembersProto, GroupMetadataV1 as GroupMetadataProto, Inbox as InboxProto, @@ -31,14 +32,14 @@ pub struct GroupMetadata { pub conversation_type: ConversationType, // TODO: Remove this once transition is completed pub creator_inbox_id: String, - pub dm_members: Option, + pub dm_members: Option>, } impl GroupMetadata { pub fn new( conversation_type: ConversationType, creator_inbox_id: String, - dm_members: Option, + dm_members: Option>, ) -> Self { Self { conversation_type, @@ -130,25 +131,54 @@ impl TryFrom for ConversationType { } #[derive(Debug, Clone, PartialEq)] -pub struct DmMembers { - pub member_one_inbox_id: String, - pub member_two_inbox_id: String, +pub struct DmMembers> { + pub member_one_inbox_id: Id, + pub member_two_inbox_id: Id, } -impl From for DmMembersProto { - fn from(value: DmMembers) -> Self { +impl<'a> DmMembers { + pub fn as_ref(&'a self) -> DmMembers<&'a str> { + DmMembers { + member_one_inbox_id: &*self.member_one_inbox_id, + member_two_inbox_id: &*self.member_two_inbox_id, + } + } +} + +impl From> for DmMembersProto +where + Id: AsRef, +{ + fn from(value: DmMembers) -> Self { DmMembersProto { dm_member_one: Some(InboxProto { - inbox_id: value.member_one_inbox_id.clone(), + inbox_id: value.member_one_inbox_id.as_ref().to_string(), }), dm_member_two: Some(InboxProto { - inbox_id: value.member_two_inbox_id.clone(), + inbox_id: value.member_two_inbox_id.as_ref().to_string(), }), } } } -impl TryFrom for DmMembers { +impl From> for String +where + Id: AsRef, +{ + fn from(members: DmMembers) -> Self { + let inbox_ids = [ + members.member_one_inbox_id.as_ref(), + members.member_two_inbox_id.as_ref(), + ] + .into_iter() + .map(str::to_lowercase) + .collect::>(); + + format!("dm:{}", inbox_ids.join(":")) + } +} + +impl TryFrom for DmMembers { type Error = GroupMetadataError; fn try_from(value: DmMembersProto) -> Result { diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 78c0253f2..b21eebad4 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -58,10 +58,7 @@ use self::{ intents::IntentError, validated_commit::CommitValidationError, }; -use crate::storage::{ - group::{DmId, DmIdExt}, - StorageError, -}; +use crate::storage::StorageError; use xmtp_common::time::now_ns; use xmtp_proto::xmtp::mls::{ api::v1::{ @@ -484,7 +481,10 @@ impl MlsGroup { now_ns(), membership_state, context.inbox_id().to_string(), - Some(DmId::from_ids([&dm_target_inbox_id, client.inbox_id()])), + Some(DmMembers { + member_one_inbox_id: dm_target_inbox_id, + member_two_inbox_id: client.inbox_id().to_string(), + }), ); stored_group.store(provider.conn_ref())?; @@ -511,8 +511,6 @@ impl MlsGroup { let group_id = mls_group.group_id().to_vec(); let metadata = extract_group_metadata(&mls_group)?; let dm_members = metadata.dm_members; - let dm_id = - dm_members.map(|m| DmId::from_ids([&m.member_one_inbox_id, &m.member_two_inbox_id])); let conversation_type = metadata.conversation_type; @@ -524,7 +522,7 @@ impl MlsGroup { added_by_inbox, welcome_id, conversation_type, - dm_id, + dm_members, ), ConversationType::Dm => { validate_dm_group(client.as_ref(), &mls_group, &added_by_inbox)?; @@ -535,7 +533,7 @@ impl MlsGroup { added_by_inbox, welcome_id, conversation_type, - dm_id, + dm_members, ) } ConversationType::Sync => StoredGroup::new_from_welcome( @@ -545,7 +543,7 @@ impl MlsGroup { added_by_inbox, welcome_id, conversation_type, - dm_id, + dm_members, ), }; @@ -1274,7 +1272,10 @@ impl MlsGroup { now_ns(), GroupMembershipState::Allowed, // Use Allowed as default for tests context.inbox_id().to_string(), - Some(dm_target_inbox_id), + Some(DmMembers { + member_one_inbox_id: client.inbox_id().to_string(), + member_two_inbox_id: dm_target_inbox_id, + }), ); stored_group.store(provider.conn_ref())?; diff --git a/xmtp_mls/src/groups/validated_commit.rs b/xmtp_mls/src/groups/validated_commit.rs index f13557236..c9dcd8c18 100644 --- a/xmtp_mls/src/groups/validated_commit.rs +++ b/xmtp_mls/src/groups/validated_commit.rs @@ -210,7 +210,7 @@ pub struct ValidatedCommit { pub removed_inboxes: Vec, pub metadata_changes: MutableMetadataChanges, pub permissions_changed: bool, - pub dm_members: Option, + pub dm_members: Option>, } impl ValidatedCommit { diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index a2253b5a6..efc346332 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -5,7 +5,9 @@ use super::{ schema::groups::{self, dsl}, Sqlite, }; -use crate::{impl_fetch, impl_store, DuplicateItem, StorageError}; +use crate::{ + groups::group_metadata::DmMembers, impl_fetch, impl_store, DuplicateItem, StorageError, +}; use diesel::{ backend::Backend, deserialize::{self, FromSql, FromSqlRow}, @@ -17,6 +19,7 @@ use diesel::{ }; use serde::{Deserialize, Serialize}; use xmtp_common::time::now_ns; +use xmtp_id::InboxIdRef; pub type ID = Vec; @@ -42,7 +45,7 @@ pub struct StoredGroup { /// Enum, [`ConversationType`] signifies the group conversation type which extends to who can access it. pub conversation_type: ConversationType, /// The inbox_id of the DM target - pub dm_id: Option, + pub dm_id: Option, /// Timestamp of when the last message was sent for this group (updated automatically in a trigger) pub last_message_ns: i64, } @@ -59,7 +62,7 @@ impl StoredGroup { added_by_inbox_id: String, welcome_id: i64, conversation_type: ConversationType, - dm_id: Option, + dm_members: Option>, ) -> Self { Self { id, @@ -70,7 +73,7 @@ impl StoredGroup { added_by_inbox_id, welcome_id: Some(welcome_id), rotated_at_ns: 0, - dm_id, + dm_id: dm_members.map(String::from), last_message_ns: now_ns(), } } @@ -81,21 +84,21 @@ impl StoredGroup { created_at_ns: i64, membership_state: GroupMembershipState, added_by_inbox_id: String, - dm_id: Option, + dm_members: Option>, ) -> Self { Self { id, created_at_ns, membership_state, installations_last_checked: 0, - conversation_type: match dm_id { + conversation_type: match dm_members { Some(_) => ConversationType::Dm, None => ConversationType::Group, }, added_by_inbox_id, welcome_id: None, rotated_at_ns: 0, - dm_id, + dm_id: dm_members.map(String::from), last_message_ns: now_ns(), } } @@ -343,10 +346,9 @@ impl DbConnection { pub fn find_dm_group( &self, - inbox_id: &str, - target_inbox_id: &str, + members: DmMembers<&str>, ) -> Result, StorageError> { - let dm_id = DmId::from_ids([inbox_id, target_inbox_id]); + let dm_id = String::from(members.clone()); let query = dsl::groups .order(dsl::created_at_ns.asc()) @@ -354,10 +356,7 @@ impl DbConnection { let groups: Vec = self.raw_query(|conn| query.load(conn))?; if groups.len() > 1 { - tracing::info!( - "More than one group found for dm_inbox_id {}", - target_inbox_id - ); + tracing::info!("More than one group found for dm_inbox_id {members:?}"); } Ok(groups.into_iter().next()) @@ -550,23 +549,11 @@ impl std::fmt::Display for ConversationType { } } -pub type DmId = String; - pub trait DmIdExt { - fn from_ids(inbox_ids: [&str; 2]) -> Self; - fn other_id(&self, other: &str) -> String; + fn other_inbox_id(&self, other: &str) -> String; } -impl DmIdExt for DmId { - fn from_ids(inbox_ids: [&str; 2]) -> Self { - let inbox_ids = inbox_ids - .into_iter() - .map(str::to_lowercase) - .collect::>(); - - format!("dm:{}", inbox_ids.join(":")) - } - - fn other_id(&self, id: &str) -> String { +impl DmIdExt for String { + fn other_inbox_id(&self, id: &str) -> String { // drop the "dm:" let dm_id = &self[3..]; @@ -629,13 +616,16 @@ pub(crate) mod tests { let id = rand_vec::<24>(); let created_at_ns = now_ns(); let membership_state = state.unwrap_or(GroupMembershipState::Allowed); - let dm_inbox_id = Some(DmId::from_ids(["placeholder_inbox_id"; 2])); + let members = DmMembers { + member_one_inbox_id: "placeholder_inbox_id_1".to_string(), + member_two_inbox_id: "placeholder_inbox_id_2".to_string(), + }; StoredGroup::new( id, created_at_ns, membership_state, "placeholder_address".to_string(), - dm_inbox_id, + Some(members), ) } @@ -755,7 +745,10 @@ pub(crate) mod tests { // test find_dm_group let dm_result = conn - .find_dm_group("placeholder_inbox_id", "placeholder_inbox_id") + .find_dm_group(DmMembers { + member_one_inbox_id: "placeholder_inbox_id_1", + member_two_inbox_id: "placeholder_inbox_id_2", + }) .unwrap(); assert!(dm_result.is_some());