Skip to content

Commit

Permalink
Find DM conversation from target's inbox_id (#1137)
Browse files Browse the repository at this point in the history
* find dm group wip

* better

* minor cleanup

* cleaner approach

* lint

* bump

* bump

* napi too

* test

* feedback
  • Loading branch information
codabrink authored Oct 15, 2024
1 parent 6edeed8 commit f4d22fb
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 24 deletions.
16 changes: 16 additions & 0 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,22 @@ impl FfiXmtpClient {

pub fn conversation(&self, conversation_id: Vec<u8>) -> Result<FfiConversation, GenericError> {
let convo = self.inner_client.group(conversation_id)?;

Ok(FfiConversation {
inner_client: self.inner_client.clone(),
conversation_id: convo.group_id,
created_at_ns: convo.created_at_ns,
})
}

pub fn dm_conversation(
&self,
target_inbox_id: String,
) -> Result<FfiConversation, GenericError> {
let convo = self
.inner_client
.dm_group_from_target_inbox(target_inbox_id)?;

Ok(FfiConversation {
inner_client: self.inner_client.clone(),
conversation_id: convo.group_id,
Expand Down
14 changes: 14 additions & 0 deletions bindings_node/src/conversations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ impl NapiConversations {
))
}

#[napi]
pub fn find_dm_by_target_inbox_id(&self, target_inbox_id: String) -> Result<NapiGroup> {
let convo = self
.inner_client
.dm_group_from_target_inbox(target_inbox_id)
.map_err(ErrorWrapper::from)?;

Ok(NapiGroup::new(
self.inner_client.clone(),
convo.group_id,
convo.created_at_ns,
))
}

#[napi]
pub fn find_message_by_id(&self, message_id: String) -> Result<NapiMessage> {
let message_id = hex::decode(message_id).map_err(ErrorWrapper::from)?;
Expand Down
11 changes: 4 additions & 7 deletions xmtp_id/src/associations/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,8 @@ impl SignatureRequestBuilder {
new_member_identifier: new_member_identifier.clone(),
}),
pending_signatures: HashMap::from([
(
SignatureField::ExistingMember,
existing_member_identifier.clone(),
),
(SignatureField::NewMember, new_member_identifier.clone()),
(SignatureField::ExistingMember, existing_member_identifier),
(SignatureField::NewMember, new_member_identifier),
]),
});

Expand All @@ -106,7 +103,7 @@ impl SignatureRequestBuilder {
self.actions.push(PendingIdentityAction {
pending_signatures: HashMap::from([(
SignatureField::RecoveryAddress,
recovery_address_identifier.clone(),
recovery_address_identifier,
)]),
unsigned_action: UnsignedAction::RevokeAssociation(UnsignedRevokeAssociation {
revoked_member,
Expand All @@ -124,7 +121,7 @@ impl SignatureRequestBuilder {
self.actions.push(PendingIdentityAction {
pending_signatures: HashMap::from([(
SignatureField::RecoveryAddress,
recovery_address_identifier.clone(),
recovery_address_identifier,
)]),
unsigned_action: UnsignedAction::ChangeRecoveryAddress(UnsignedChangeRecoveryAddress {
new_recovery_address,
Expand Down
2 changes: 1 addition & 1 deletion xmtp_id/src/associations/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ pub async fn sign_with_legacy_key(
))?;
let legacy_private_key = secp256k1.bytes;
let wallet: LocalWallet = hex::encode(legacy_private_key).parse::<LocalWallet>()?;
let signature = wallet.sign_message(signature_text.clone()).await?;
let signature = wallet.sign_message(signature_text).await?;

let legacy_signed_public_key_proto =
legacy_signed_private_key_proto
Expand Down
34 changes: 29 additions & 5 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,7 @@ where
&self,
address: String,
) -> Result<Option<String>, ClientError> {
let results = self
.find_inbox_ids_from_addresses(vec![address.clone()])
.await?;
let results = self.find_inbox_ids_from_addresses(vec![address]).await?;
if let Some(first_result) = results.into_iter().next() {
Ok(first_result)
} else {
Expand Down Expand Up @@ -589,8 +587,11 @@ where
Ok(sync_group)
}

/// Look up a group by its ID
/// Returns a [`MlsGroup`] if the group exists, or an error if it does not
/**
* Look up a group by its ID
*
* Returns a [`MlsGroup`] if the group exists, or an error if it does not
*/
pub fn group(&self, group_id: Vec<u8>) -> Result<MlsGroup, ClientError> {
let conn = &mut self.store().conn()?;
let stored_group: Option<StoredGroup> = conn.fetch(&group_id)?;
Expand All @@ -607,6 +608,29 @@ where
}
}

/**
* Look up a DM group by the target's inbox_id.
*
* Returns a [`MlsGroup`] if the group exists, or an error if it does not
*/
pub fn dm_group_from_target_inbox(
&self,
target_inbox_id: String,
) -> Result<MlsGroup, ClientError> {
let conn = self.store().conn()?;
match conn.find_dm_group(&target_inbox_id)? {
Some(dm_group) => Ok(MlsGroup::new(
self.context.clone(),
dm_group.id,
dm_group.created_at_ns,
)),
None => Err(ClientError::Storage(StorageError::NotFound(format!(
"dm_target_inbox_id {}",
hex::encode(target_inbox_id)
)))),
}
}

/// Look up a message by its ID
/// Returns a [`StoredGroupMessage`] if the message exists, or an error if it does not
pub fn message(&self, message_id: Vec<u8>) -> Result<StoredGroupMessage, ClientError> {
Expand Down
12 changes: 4 additions & 8 deletions xmtp_mls/src/identity_updates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,25 +486,21 @@ pub async fn load_identity_updates<ApiClient: XmtpApi>(
let filters: Vec<GetIdentityUpdatesV2Filter> = inbox_ids
.into_iter()
.map(|inbox_id| GetIdentityUpdatesV2Filter {
sequence_id: existing_sequence_ids
.get(&inbox_id)
.copied()
.map(|i| i as u64),
sequence_id: existing_sequence_ids.get(&inbox_id).map(|i| *i as u64),
inbox_id,
})
.collect();

let updates = api_client.get_identity_updates_v2(filters).await?;

let to_store = updates
.clone()
.into_iter()
.iter()
.flat_map(|(inbox_id, updates)| {
updates.into_iter().map(move |update| StoredIdentityUpdate {
updates.iter().map(|update| StoredIdentityUpdate {
inbox_id: inbox_id.clone(),
sequence_id: update.sequence_id as i64,
server_timestamp_ns: update.server_timestamp_ns as i64,
payload: update.update.into(),
payload: update.update.clone().into(),
})
})
.collect::<Vec<StoredIdentityUpdate>>();
Expand Down
31 changes: 28 additions & 3 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,34 @@ impl DbConnection {
&self,
welcome_id: i64,
) -> Result<Option<StoredGroup>, StorageError> {
let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed();
query = query.filter(dsl::welcome_id.eq(welcome_id));
let query = dsl::groups
.order(dsl::created_at_ns.asc())
.filter(dsl::welcome_id.eq(welcome_id));

let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?;
if groups.len() > 1 {
tracing::error!("More than one group found for welcome_id {}", welcome_id);
}
// Manually extract the first element

Ok(groups.into_iter().next())
}

pub fn find_dm_group(
&self,
dm_target_inbox_id: &str,
) -> Result<Option<StoredGroup>, StorageError> {
let query = dsl::groups
.order(dsl::created_at_ns.asc())
.filter(dsl::dm_inbox_id.eq(Some(&dm_target_inbox_id)));

let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?;
if groups.len() > 1 {
tracing::info!(
"More than one group found for dm_inbox_id {}",
dm_target_inbox_id
);
}

Ok(groups.into_iter().next())
}

Expand Down Expand Up @@ -515,6 +536,10 @@ pub(crate) mod tests {
assert_eq!(dm_results.len(), 3);
assert_eq!(dm_results[2].id, test_group_3.id);

// test find_dm_group
let dm_result = conn.find_dm_group("placeholder_inbox_id").unwrap();
assert!(dm_result.is_some());

// test only dms are returned
let dm_results = conn
.find_groups(None, None, None, None, Some(ConversationType::Dm))
Expand Down

0 comments on commit f4d22fb

Please sign in to comment.