diff --git a/Cargo.lock b/Cargo.lock index d9d8a7187..59f9f8b3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5410,6 +5410,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] @@ -6512,6 +6513,7 @@ dependencies = [ "thiserror", "tls_codec 0.4.1", "tokio", + "tokio-stream", "toml", "tracing", "tracing-flame", diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index c37cefae1..26ac6b9ae 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -50,6 +50,7 @@ smart-default = "0.7.1" thiserror = { workspace = true } tls_codec = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio-stream = { version = "0.1", features = ["sync"] } toml = "0.8.4" xmtp_cryptography = { workspace = true } xmtp_id = { path = "../xmtp_id" } diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index ff89d4f7d..f201bf2fa 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -14,6 +14,7 @@ use openmls::{ use openmls_traits::OpenMlsProvider; use prost::EncodeError; use thiserror::Error; +use tokio::sync::broadcast; use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; use xmtp_id::{ @@ -46,6 +47,7 @@ use crate::{ refresh_state::EntityKind, sql_key_store, EncryptedMessageStore, StorageError, }, + subscriptions::LocalEvents, verified_key_package_v2::{KeyPackageVerificationError, VerifiedKeyPackageV2}, xmtp_openmls_provider::XmtpOpenMlsProvider, Fetch, XmtpApi, @@ -207,6 +209,7 @@ pub struct Client { pub(crate) api_client: ApiClientWrapper, pub(crate) context: Arc, pub(crate) history_sync_url: Option, + pub(crate) local_events: broadcast::Sender, } /// The local context a XMTP MLS needs to function: @@ -261,10 +264,12 @@ where history_sync_url: Option, ) -> Self { let context = XmtpMlsLocalContext { identity, store }; + let (tx, _) = broadcast::channel(10); Self { api_client, context: Arc::new(context), history_sync_url, + local_events: tx, } } @@ -339,6 +344,9 @@ where ) .map_err(Box::new)?; + // notify any streams of the new group + let _ = self.local_events.send(LocalEvents::NewGroup(group.clone())); + Ok(group) } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 2f338ad81..d30ff458c 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -10,6 +10,7 @@ use std::{ use futures::{Stream, StreamExt}; use prost::Message; use tokio::sync::oneshot::{self, Sender}; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage; use crate::{ @@ -22,6 +23,14 @@ use crate::{ Client, XmtpApi, }; +/// Events local to this client +/// are broadcast across all senders/receivers of streams +#[derive(Clone, Debug)] +pub(crate) enum LocalEvents { + // a new group was created + NewGroup(MlsGroup), +} + // TODO simplify FfiStreamCloser + StreamCloser duplication pub struct StreamCloser { pub close_fn: Arc>>>, @@ -117,6 +126,19 @@ where pub async fn stream_conversations( &self, ) -> Result + Send + '_>>, ClientError> { + let event_queue = + tokio_stream::wrappers::BroadcastStream::new(self.local_events.subscribe()); + + let event_queue = event_queue.filter_map(|event| async move { + match event { + Ok(LocalEvents::NewGroup(g)) => Some(g), + Err(BroadcastStreamRecvError::Lagged(missed)) => { + log::warn!("Missed {missed} messages due to local event queue lagging"); + None + } + } + }); + let installation_key = self.installation_public_key(); let id_cursor = 0; @@ -141,7 +163,7 @@ where } }); - Ok(Box::pin(stream)) + Ok(Box::pin(futures::stream::select(stream, event_queue))) } pub(crate) async fn stream_messages( @@ -365,6 +387,7 @@ mod tests { }; use futures::StreamExt; use std::sync::{Arc, Mutex}; + use tokio::sync::Notify; use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient; use xmtp_cryptography::utils::generate_local_wallet; @@ -546,4 +569,60 @@ mod tests { let messages = messages.lock().unwrap(); assert_eq!(messages.len(), 5); } + + #[tokio::test(flavor = "multi_thread")] + async fn test_self_group_creation() { + let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + + let groups = Arc::new(Mutex::new(Vec::new())); + let notify = Arc::new(Notify::new()); + let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone()); + + let closer = Client::::stream_conversations_with_callback( + alix.clone(), + move |g| { + let mut groups = groups_pointer.lock().unwrap(); + groups.push(g); + notify_pointer.notify_one(); + }, + || {}, + ) + .unwrap(); + + alix.create_group(None, GroupMetadataOptions::default()) + .unwrap(); + + tokio::time::timeout(std::time::Duration::from_secs(60), async { + notify.notified().await + }) + .await + .expect("Stream never received group"); + + { + let grps = groups.lock().unwrap(); + assert_eq!(grps.len(), 1); + } + + let group = bo + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + group + .add_members_by_inbox_id(&bo, vec![alix.inbox_id()]) + .await + .unwrap(); + + tokio::time::timeout(std::time::Duration::from_secs(60), async { + notify.notified().await + }) + .await + .expect("Stream never received group"); + + { + let grps = groups.lock().unwrap(); + assert_eq!(grps.len(), 2); + } + + closer.end(); + } }