Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLS Dms - Dual Sending Pre-requisites #1076

Merged
merged 19 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 69 additions & 10 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use xmtp_id::{
associations::{builder::SignatureRequest, generate_inbox_id as xmtp_id_generate_inbox_id},
InboxId,
};
use xmtp_mls::client::FindGroupParams;
use xmtp_mls::groups::group_mutable_metadata::MetadataField;
use xmtp_mls::groups::group_permissions::BasePolicies;
use xmtp_mls::groups::group_permissions::GroupMutablePermissionsError;
Expand Down Expand Up @@ -780,6 +781,20 @@ impl FfiConversations {
Ok(out)
}

pub async fn create_dm(&self, account_address: String) -> Result<Arc<FfiGroup>, GenericError> {
log::info!("creating dm with target address: {}", account_address);

let convo = self.inner_client.create_dm(account_address).await?;

let out = Arc::new(FfiGroup {
inner_client: self.inner_client.clone(),
group_id: convo.group_id,
created_at_ns: convo.created_at_ns,
});

Ok(out)
}

pub async fn process_streamed_welcome_message(
&self,
envelope_bytes: Vec<u8>,
Expand All @@ -804,7 +819,16 @@ impl FfiConversations {

pub async fn sync_all_groups(&self) -> Result<u32, GenericError> {
let inner = self.inner_client.as_ref();
let groups = inner.find_groups(None, None, None, None)?;
let groups = inner.find_groups(FindGroupParams {
include_dm_groups: true,
..FindGroupParams::default()
})?;

log::info!(
"groups for client inbox id {:?}: {:?}",
self.inner_client.inbox_id(),
groups.len()
);

let num_groups_synced: usize = inner.sync_all_groups(groups).await?;
// Uniffi does not work with usize, so we need to convert to u32
Expand All @@ -824,12 +848,13 @@ impl FfiConversations {
) -> Result<Vec<Arc<FfiGroup>>, GenericError> {
let inner = self.inner_client.as_ref();
let convo_list: Vec<Arc<FfiGroup>> = inner
.find_groups(
None,
opts.created_after_ns,
opts.created_before_ns,
opts.limit,
)?
.find_groups(FindGroupParams {
allowed_states: None,
created_after_ns: opts.created_after_ns,
created_before_ns: opts.created_before_ns,
limit: opts.limit,
include_dm_groups: false,
})?
.into_iter()
.map(|group| {
Arc::new(FfiGroup {
Expand All @@ -845,14 +870,17 @@ impl FfiConversations {

pub async fn stream(&self, callback: Box<dyn FfiConversationCallback>) -> FfiStreamCloser {
let client = self.inner_client.clone();
let handle =
RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| {
let handle = RustXmtpClient::stream_conversations_with_callback(
client.clone(),
move |convo| {
callback.on_conversation(Arc::new(FfiGroup {
inner_client: client.clone(),
group_id: convo.group_id,
created_at_ns: convo.created_at_ns,
}))
});
},
false,
);

FfiStreamCloser::new(handle)
}
Expand Down Expand Up @@ -3700,6 +3728,37 @@ mod tests {
);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_dms_sync_but_do_not_list() {
let alix = new_test_client().await;
let bola = new_test_client().await;

let alix_conversations = alix.conversations();
let bola_conversations = bola.conversations();

let _alix_group = alix_conversations
.create_dm(bola.account_address.clone())
.await
.unwrap();
let alix_num_sync = alix_conversations.sync_all_groups().await.unwrap();
bola_conversations.sync().await.unwrap();
let bola_num_sync = bola_conversations.sync_all_groups().await.unwrap();
assert_eq!(alix_num_sync, 1);
assert_eq!(bola_num_sync, 1);

let alix_groups = alix_conversations
.list(FfiListConversationsOptions::default())
.await
.unwrap();
assert_eq!(alix_groups.len(), 0);

let bola_groups = bola_conversations
.list(FfiListConversationsOptions::default())
.await
.unwrap();
assert_eq!(bola_groups.len(), 0);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_set_and_get_group_consent() {
let alix = new_test_client().await;
Expand Down
22 changes: 13 additions & 9 deletions bindings_node/src/conversations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use napi::bindgen_prelude::{Error, Result, Uint8Array};
use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode};
use napi::JsFunction;
use napi_derive::napi;
use xmtp_mls::client::FindGroupParams;
use xmtp_mls::groups::{GroupMetadataOptions, PreconfiguredPolicies};

use crate::messages::NapiMessage;
Expand Down Expand Up @@ -171,12 +172,12 @@ impl NapiConversations {
};
let convo_list: Vec<NapiGroup> = self
.inner_client
.find_groups(
None,
opts.created_after_ns,
opts.created_before_ns,
opts.limit,
)
.find_groups(FindGroupParams {
created_after_ns: opts.created_after_ns,
created_before_ns: opts.created_before_ns,
limit: opts.limit,
..FindGroupParams::default()
})
.map_err(|e| Error::from_reason(format!("{}", e)))?
.into_iter()
.map(|group| {
Expand All @@ -196,8 +197,9 @@ impl NapiConversations {
let tsfn: ThreadsafeFunction<NapiGroup, ErrorStrategy::CalleeHandled> =
callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?;
let client = self.inner_client.clone();
let stream_closer =
RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| {
let stream_closer = RustXmtpClient::stream_conversations_with_callback(
client.clone(),
move |convo| {
tsfn.call(
Ok(NapiGroup::new(
client.clone(),
Expand All @@ -206,7 +208,9 @@ impl NapiConversations {
)),
ThreadsafeFunctionCallMode::Blocking,
);
});
},
false,
);

Ok(NapiStreamCloser::new(stream_closer))
}
Expand Down
3 changes: 2 additions & 1 deletion examples/cli/cli-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use futures::future::join_all;
use kv_log_macro::{error, info};
use prost::Message;
use xmtp_id::associations::unverified::{UnverifiedRecoverableEcdsaSignature, UnverifiedSignature};
use xmtp_mls::client::FindGroupParams;
use xmtp_mls::groups::message_history::MessageHistoryContent;
use xmtp_mls::storage::group_message::GroupMessageKind;

Expand Down Expand Up @@ -209,7 +210,7 @@ async fn main() {

// recv(&client).await.unwrap();
let group_list = client
.find_groups(None, None, None, None)
.find_groups(FindGroupParams::default())
.expect("failed to list groups");
for group in group_list.iter() {
group.sync(&client).await.expect("error syncing group");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- This file should undo anything in `up.sql`
ALTER TABLE groups DROP COLUMN dm_inbox_id;
DROP INDEX idx_dm_target;
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Your SQL goes here
ALTER TABLE groups ADD COLUMN dm_inbox_id text;
CREATE INDEX idx_dm_target ON groups(dm_inbox_id);
77 changes: 65 additions & 12 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ impl From<&str> for ClientError {
}
}

#[derive(Debug, Default)]
pub struct FindGroupParams {
pub allowed_states: Option<Vec<GroupMembershipState>>,
pub created_after_ns: Option<i64>,
pub created_before_ns: Option<i64>,
pub limit: Option<i64>,
pub include_dm_groups: bool,
}

/// Clients manage access to the network, identity, and data store
#[derive(Debug)]
pub struct Client<ApiClient> {
Expand Down Expand Up @@ -495,6 +504,49 @@ where
Ok(group)
}

/// Create a new Direct Message with the default settings
pub async fn create_dm(&self, account_address: String) -> Result<MlsGroup, ClientError> {
tracing::info!("creating dm with address: {}", account_address);

let inbox_id = match self
.find_inbox_id_from_address(account_address.clone())
.await?
{
Some(id) => id,
None => {
return Err(ClientError::Storage(StorageError::NotFound(format!(
"inbox id for address {} not found",
account_address
))))
}
};

self.create_dm_by_inbox_id(inbox_id).await
}

/// Create a new Direct Message with the default settings
pub async fn create_dm_by_inbox_id(
&self,
dm_target_inbox_id: InboxId,
) -> Result<MlsGroup, ClientError> {
tracing::info!("creating dm with {}", dm_target_inbox_id);

let group = MlsGroup::create_dm_and_insert(
self.context.clone(),
GroupMembershipState::Allowed,
dm_target_inbox_id.clone(),
)?;

group
.add_members_by_inbox_id(self, vec![dm_target_inbox_id])
.await?;

// notify any streams of the new group
let _ = self.local_events.send(LocalEvents::NewGroup(group.clone()));

Ok(group)
}

#[cfg(feature = "message-history")]
pub(crate) fn create_sync_group(&self) -> Result<MlsGroup, ClientError> {
tracing::info!("creating sync group");
Expand Down Expand Up @@ -542,17 +594,17 @@ where
/// - created_after_ns: only return groups created after the given timestamp (in nanoseconds)
/// - created_before_ns: only return groups created before the given timestamp (in nanoseconds)
/// - limit: only return the first `limit` groups
pub fn find_groups(
&self,
allowed_states: Option<Vec<GroupMembershipState>>,
created_after_ns: Option<i64>,
created_before_ns: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<MlsGroup>, ClientError> {
pub fn find_groups(&self, params: FindGroupParams) -> Result<Vec<MlsGroup>, ClientError> {
Ok(self
.store()
.conn()?
.find_groups(allowed_states, created_after_ns, created_before_ns, limit)?
.find_groups(
params.allowed_states,
params.created_after_ns,
params.created_before_ns,
params.limit,
params.include_dm_groups,
)?
.into_iter()
.map(|stored_group| {
MlsGroup::new(
Expand Down Expand Up @@ -857,6 +909,7 @@ mod tests {

use crate::{
builder::ClientBuilder,
client::FindGroupParams,
groups::GroupMetadataOptions,
hpke::{decrypt_welcome, encrypt_welcome},
identity::serialize_key_package_hash_ref,
Expand Down Expand Up @@ -958,7 +1011,7 @@ mod tests {
.create_group(None, GroupMetadataOptions::default())
.unwrap();

let groups = client.find_groups(None, None, None, None).unwrap();
let groups = client.find_groups(FindGroupParams::default()).unwrap();
assert_eq!(groups.len(), 2);
assert_eq!(groups[0].group_id, group_1.group_id);
assert_eq!(groups[1].group_id, group_2.group_id);
Expand Down Expand Up @@ -1024,7 +1077,7 @@ mod tests {
let bob_received_groups = bo.sync_welcomes().await.unwrap();
assert_eq!(bob_received_groups.len(), 2);

let bo_groups = bo.find_groups(None, None, None, None).unwrap();
let bo_groups = bo.find_groups(FindGroupParams::default()).unwrap();
let bo_group1 = bo.group(alix_bo_group1.clone().group_id).unwrap();
let bo_messages1 = bo_group1
.find_messages(None, None, None, None, None)
Expand Down Expand Up @@ -1129,7 +1182,7 @@ mod tests {
tracing::info!("Syncing bolas welcomes");
// See if Bola can see that they were added to the group
bola.sync_welcomes().await.unwrap();
let bola_groups = bola.find_groups(None, None, None, None).unwrap();
let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap();
assert_eq!(bola_groups.len(), 1);
let bola_group = bola_groups.first().unwrap();
tracing::info!("Syncing bolas messages");
Expand Down Expand Up @@ -1262,7 +1315,7 @@ mod tests {
bo.sync_welcomes().await.unwrap();

// Bo should have two groups now
let bo_groups = bo.find_groups(None, None, None, None).unwrap();
let bo_groups = bo.find_groups(FindGroupParams::default()).unwrap();
assert_eq!(bo_groups.len(), 2);

// Bo's original key should be deleted
Expand Down
Loading