Skip to content

Commit

Permalink
dont execute callbacks when dm group welcomes are streamed
Browse files Browse the repository at this point in the history
  • Loading branch information
cameronvoell committed Sep 19, 2024
1 parent f4fd05e commit b81ebf0
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 11 deletions.
9 changes: 6 additions & 3 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,14 +855,17 @@ impl FfiConversations {

pub async fn stream(&self, callback: Box<dyn FfiConversationCallback>) -> FfiStreamCloser {
let client = self.inner_client.clone();
let handle =
RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| {
let handle = RustXmtpClient::stream_conversations_with_callback(
client.clone(),
move |convo| {
callback.on_conversation(Arc::new(FfiGroup {
inner_client: client.clone(),
group_id: convo.group_id,
created_at_ns: convo.created_at_ns,
}))
});
},
false,
);

FfiStreamCloser::new(handle)
}
Expand Down
9 changes: 6 additions & 3 deletions bindings_node/src/conversations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ impl NapiConversations {
let tsfn: ThreadsafeFunction<NapiGroup, ErrorStrategy::CalleeHandled> =
callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?;
let client = self.inner_client.clone();
let stream_closer =
RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| {
let stream_closer = RustXmtpClient::stream_conversations_with_callback(
client.clone(),
move |convo| {
tsfn.call(
Ok(NapiGroup::new(
client.clone(),
Expand All @@ -207,7 +208,9 @@ impl NapiConversations {
)),
ThreadsafeFunctionCallMode::Blocking,
);
});
},
false,
);

Ok(NapiStreamCloser::new(stream_closer))
}
Expand Down
82 changes: 77 additions & 5 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage;
use crate::{
api::GroupFilter,
client::{extract_welcome_message, ClientError},
groups::{extract_group_id, GroupError, MlsGroup},
groups::{extract_group_id, group_metadata::ConversationType, GroupError, MlsGroup},
retry::Retry,
retry_async,
storage::{group::StoredGroup, group_message::StoredGroupMessage},
Expand Down Expand Up @@ -234,6 +234,7 @@ where
pub fn stream_conversations_with_callback(
client: Arc<Client<ApiClient>>,
mut convo_callback: impl FnMut(MlsGroup) + Send + 'static,
include_dm: bool,
) -> StreamHandle<Result<(), ClientError>> {
let (tx, rx) = oneshot::channel();

Expand All @@ -242,7 +243,12 @@ where
futures::pin_mut!(stream);
let _ = tx.send(());
while let Some(convo) = stream.next().await {
convo_callback(convo)
let provider = client.context.mls_provider()?;
// Don't execute callback for dms unless include_dm is true
if include_dm || convo.metadata(provider)?.conversation_type != ConversationType::Dm
{
convo_callback(convo)
}
}
log::debug!("`stream_conversations` stream ended, dropping stream");
Ok(())
Expand Down Expand Up @@ -726,12 +732,15 @@ mod tests {
let notify = Delivery::new(None);
let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone());

let closer =
Client::<TestClient>::stream_conversations_with_callback(alix.clone(), move |g| {
let closer = Client::<TestClient>::stream_conversations_with_callback(
alix.clone(),
move |g| {
let mut groups = groups_pointer.lock();
groups.push(g);
notify_pointer.notify_one();
});
},
false,
);

alix.create_group(None, GroupMetadataOptions::default())
.unwrap();
Expand Down Expand Up @@ -763,4 +772,67 @@ mod tests {

closer.handle.abort();
}

#[tokio::test(flavor = "multi_thread")]
async fn test_dm_creation() {
let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);

let groups = Arc::new(Mutex::new(Vec::new()));
// Wait for 2 seconds for the group creation to be streamed
let notify = Delivery::new(Some(std::time::Duration::from_secs(1)));
let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone());

// Start a stream with enableDm set to false
let closer = Client::<TestClient>::stream_conversations_with_callback(
alix.clone(),
move |g| {
let mut groups = groups_pointer.lock();
groups.push(g);
notify_pointer.notify_one();
},
false,
);

alix.create_dm_by_inbox_id(bo.inbox_id()).await.unwrap();

let result = notify.wait_for_delivery().await;
assert!(result.is_err(), "Stream unexpectedly received a DM group");

closer.handle.abort();

// Start a stream with enableDm set to true
let groups = Arc::new(Mutex::new(Vec::new()));
// Wait for 2 seconds for the group creation to be streamed
let notify = Delivery::new(Some(std::time::Duration::from_secs(1)));
let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone());
let closer = Client::<TestClient>::stream_conversations_with_callback(
alix.clone(),
move |g| {
let mut groups = groups_pointer.lock();
groups.push(g);
notify_pointer.notify_one();
},
true,
);

alix.create_dm_by_inbox_id(bo.inbox_id()).await.unwrap();
notify.wait_for_delivery().await.unwrap();
{
let grps = groups.lock();
assert_eq!(grps.len(), 1);
}

let dm = bo.create_dm_by_inbox_id(alix.inbox_id()).await.unwrap();
dm.add_members_by_inbox_id(&bo, vec![alix.inbox_id()])
.await
.unwrap();
notify.wait_for_delivery().await.unwrap();
{
let grps = groups.lock();
assert_eq!(grps.len(), 2);
}

closer.handle.abort();
}
}

0 comments on commit b81ebf0

Please sign in to comment.