Skip to content

Commit

Permalink
use the concrete type
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Dec 18, 2024
1 parent 91388a7 commit 1873e99
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 57 deletions.
9 changes: 6 additions & 3 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ use crate::groups::device_sync::WorkerHandle;
use crate::{
api::ApiClientWrapper,
groups::{
device_sync::preference_sync::UserPreferenceUpdate, group_permissions::PolicySet,
GroupError, GroupMetadataOptions, MlsGroup,
device_sync::preference_sync::UserPreferenceUpdate, group_metadata::DmMembers,
group_permissions::PolicySet, GroupError, GroupMetadataOptions, MlsGroup,
},
identity::{parse_credential, Identity, IdentityError},
identity_updates::{load_identity_updates, IdentityUpdateError},
Expand Down Expand Up @@ -638,7 +638,10 @@ where
target_inbox_id: String,
) -> Result<MlsGroup<Self>, ClientError> {
let conn = self.store().conn()?;
match conn.find_dm_group(self.inbox_id(), &target_inbox_id)? {
match conn.find_dm_group(DmMembers {
member_one_inbox_id: self.inbox_id(),
member_two_inbox_id: &target_inbox_id,
})? {
Some(dm_group) => Ok(MlsGroup::new(
self.clone(),
dm_group.id,
Expand Down
50 changes: 40 additions & 10 deletions xmtp_mls/src/groups/group_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use openmls::{extensions::Extensions, group::MlsGroup as OpenMlsGroup};
use prost::Message;
use thiserror::Error;

use xmtp_id::InboxId;
use xmtp_proto::xmtp::mls::message_contents::{
ConversationType as ConversationTypeProto, DmMembers as DmMembersProto,
GroupMetadataV1 as GroupMetadataProto, Inbox as InboxProto,
Expand Down Expand Up @@ -31,14 +32,14 @@ pub struct GroupMetadata {
pub conversation_type: ConversationType,
// TODO: Remove this once transition is completed
pub creator_inbox_id: String,
pub dm_members: Option<DmMembers>,
pub dm_members: Option<DmMembers<InboxId>>,
}

impl GroupMetadata {
pub fn new(
conversation_type: ConversationType,
creator_inbox_id: String,
dm_members: Option<DmMembers>,
dm_members: Option<DmMembers<InboxId>>,
) -> Self {
Self {
conversation_type,
Expand Down Expand Up @@ -130,25 +131,54 @@ impl TryFrom<i32> for ConversationType {
}

#[derive(Debug, Clone, PartialEq)]
pub struct DmMembers {
pub member_one_inbox_id: String,
pub member_two_inbox_id: String,
pub struct DmMembers<Id: AsRef<str>> {
pub member_one_inbox_id: Id,
pub member_two_inbox_id: Id,
}

impl From<DmMembers> for DmMembersProto {
fn from(value: DmMembers) -> Self {
impl<'a> DmMembers<String> {
pub fn as_ref(&'a self) -> DmMembers<&'a str> {
DmMembers {
member_one_inbox_id: &*self.member_one_inbox_id,
member_two_inbox_id: &*self.member_two_inbox_id,
}
}
}

impl<Id> From<DmMembers<Id>> for DmMembersProto
where
Id: AsRef<str>,
{
fn from(value: DmMembers<Id>) -> Self {
DmMembersProto {
dm_member_one: Some(InboxProto {
inbox_id: value.member_one_inbox_id.clone(),
inbox_id: value.member_one_inbox_id.as_ref().to_string(),
}),
dm_member_two: Some(InboxProto {
inbox_id: value.member_two_inbox_id.clone(),
inbox_id: value.member_two_inbox_id.as_ref().to_string(),
}),
}
}
}

impl TryFrom<DmMembersProto> for DmMembers {
impl<Id> From<DmMembers<Id>> for String
where
Id: AsRef<str>,
{
fn from(members: DmMembers<Id>) -> Self {
let inbox_ids = [
members.member_one_inbox_id.as_ref(),
members.member_two_inbox_id.as_ref(),
]
.into_iter()
.map(str::to_lowercase)
.collect::<Vec<_>>();

format!("dm:{}", inbox_ids.join(":"))
}
}

impl TryFrom<DmMembersProto> for DmMembers<InboxId> {
type Error = GroupMetadataError;

fn try_from(value: DmMembersProto) -> Result<Self, Self::Error> {
Expand Down
23 changes: 12 additions & 11 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ use self::{
intents::IntentError,
validated_commit::CommitValidationError,
};
use crate::storage::{
group::{DmId, DmIdExt},
StorageError,
};
use crate::storage::StorageError;
use xmtp_common::time::now_ns;
use xmtp_proto::xmtp::mls::{
api::v1::{
Expand Down Expand Up @@ -484,7 +481,10 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
now_ns(),
membership_state,
context.inbox_id().to_string(),
Some(DmId::from_ids([&dm_target_inbox_id, client.inbox_id()])),
Some(DmMembers {
member_one_inbox_id: dm_target_inbox_id,
member_two_inbox_id: client.inbox_id().to_string(),
}),
);

stored_group.store(provider.conn_ref())?;
Expand All @@ -511,8 +511,6 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
let group_id = mls_group.group_id().to_vec();
let metadata = extract_group_metadata(&mls_group)?;
let dm_members = metadata.dm_members;
let dm_id =
dm_members.map(|m| DmId::from_ids([&m.member_one_inbox_id, &m.member_two_inbox_id]));

let conversation_type = metadata.conversation_type;

Expand All @@ -524,7 +522,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
added_by_inbox,
welcome_id,
conversation_type,
dm_id,
dm_members,
),
ConversationType::Dm => {
validate_dm_group(client.as_ref(), &mls_group, &added_by_inbox)?;
Expand All @@ -535,7 +533,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
added_by_inbox,
welcome_id,
conversation_type,
dm_id,
dm_members,
)
}
ConversationType::Sync => StoredGroup::new_from_welcome(
Expand All @@ -545,7 +543,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
added_by_inbox,
welcome_id,
conversation_type,
dm_id,
dm_members,
),
};

Expand Down Expand Up @@ -1274,7 +1272,10 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
now_ns(),
GroupMembershipState::Allowed, // Use Allowed as default for tests
context.inbox_id().to_string(),
Some(dm_target_inbox_id),
Some(DmMembers {
member_one_inbox_id: client.inbox_id().to_string(),
member_two_inbox_id: dm_target_inbox_id,
}),
);

stored_group.store(provider.conn_ref())?;
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/groups/validated_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ pub struct ValidatedCommit {
pub removed_inboxes: Vec<Inbox>,
pub metadata_changes: MutableMetadataChanges,
pub permissions_changed: bool,
pub dm_members: Option<DmMembers>,
pub dm_members: Option<DmMembers<String>>,
}

impl ValidatedCommit {
Expand Down
57 changes: 25 additions & 32 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use super::{
schema::groups::{self, dsl},
Sqlite,
};
use crate::{impl_fetch, impl_store, DuplicateItem, StorageError};
use crate::{
groups::group_metadata::DmMembers, impl_fetch, impl_store, DuplicateItem, StorageError,
};
use diesel::{
backend::Backend,
deserialize::{self, FromSql, FromSqlRow},
Expand All @@ -17,6 +19,7 @@ use diesel::{
};
use serde::{Deserialize, Serialize};
use xmtp_common::time::now_ns;
use xmtp_id::InboxIdRef;

pub type ID = Vec<u8>;

Expand All @@ -42,7 +45,7 @@ pub struct StoredGroup {
/// Enum, [`ConversationType`] signifies the group conversation type which extends to who can access it.
pub conversation_type: ConversationType,
/// The inbox_id of the DM target
pub dm_id: Option<DmId>,
pub dm_id: Option<String>,
/// Timestamp of when the last message was sent for this group (updated automatically in a trigger)
pub last_message_ns: i64,
}
Expand All @@ -59,7 +62,7 @@ impl StoredGroup {
added_by_inbox_id: String,
welcome_id: i64,
conversation_type: ConversationType,
dm_id: Option<String>,
dm_members: Option<DmMembers<String>>,
) -> Self {
Self {
id,
Expand All @@ -70,7 +73,7 @@ impl StoredGroup {
added_by_inbox_id,
welcome_id: Some(welcome_id),
rotated_at_ns: 0,
dm_id,
dm_id: dm_members.map(String::from),
last_message_ns: now_ns(),
}
}
Expand All @@ -81,21 +84,21 @@ impl StoredGroup {
created_at_ns: i64,
membership_state: GroupMembershipState,
added_by_inbox_id: String,
dm_id: Option<String>,
dm_members: Option<DmMembers<String>>,
) -> Self {
Self {
id,
created_at_ns,
membership_state,
installations_last_checked: 0,
conversation_type: match dm_id {
conversation_type: match dm_members {
Some(_) => ConversationType::Dm,
None => ConversationType::Group,
},
added_by_inbox_id,
welcome_id: None,
rotated_at_ns: 0,
dm_id,
dm_id: dm_members.map(String::from),
last_message_ns: now_ns(),
}
}
Expand Down Expand Up @@ -343,21 +346,17 @@ impl DbConnection {

pub fn find_dm_group(
&self,
inbox_id: &str,
target_inbox_id: &str,
members: DmMembers<&str>,
) -> Result<Option<StoredGroup>, StorageError> {
let dm_id = DmId::from_ids([inbox_id, target_inbox_id]);
let dm_id = String::from(members.clone());

let query = dsl::groups
.order(dsl::created_at_ns.asc())
.filter(dsl::dm_id.eq(Some(dm_id)));

let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?;
if groups.len() > 1 {
tracing::info!(
"More than one group found for dm_inbox_id {}",
target_inbox_id
);
tracing::info!("More than one group found for dm_inbox_id {members:?}");
}

Ok(groups.into_iter().next())
Expand Down Expand Up @@ -550,23 +549,11 @@ impl std::fmt::Display for ConversationType {
}
}

pub type DmId = String;

pub trait DmIdExt {
fn from_ids(inbox_ids: [&str; 2]) -> Self;
fn other_id(&self, other: &str) -> String;
fn other_inbox_id(&self, other: &str) -> String;
}
impl DmIdExt for DmId {
fn from_ids(inbox_ids: [&str; 2]) -> Self {
let inbox_ids = inbox_ids
.into_iter()
.map(str::to_lowercase)
.collect::<Vec<_>>();

format!("dm:{}", inbox_ids.join(":"))
}

fn other_id(&self, id: &str) -> String {
impl DmIdExt for String {
fn other_inbox_id(&self, id: &str) -> String {
// drop the "dm:"
let dm_id = &self[3..];

Expand Down Expand Up @@ -629,13 +616,16 @@ pub(crate) mod tests {
let id = rand_vec::<24>();
let created_at_ns = now_ns();
let membership_state = state.unwrap_or(GroupMembershipState::Allowed);
let dm_inbox_id = Some(DmId::from_ids(["placeholder_inbox_id"; 2]));
let members = DmMembers {
member_one_inbox_id: "placeholder_inbox_id_1".to_string(),
member_two_inbox_id: "placeholder_inbox_id_2".to_string(),
};
StoredGroup::new(
id,
created_at_ns,
membership_state,
"placeholder_address".to_string(),
dm_inbox_id,
Some(members),
)
}

Expand Down Expand Up @@ -755,7 +745,10 @@ pub(crate) mod tests {
// test find_dm_group

let dm_result = conn
.find_dm_group("placeholder_inbox_id", "placeholder_inbox_id")
.find_dm_group(DmMembers {
member_one_inbox_id: "placeholder_inbox_id_1",
member_two_inbox_id: "placeholder_inbox_id_2",
})
.unwrap();
assert!(dm_result.is_some());

Expand Down

0 comments on commit 1873e99

Please sign in to comment.