diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 05b905a5d..a235fa67c 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -7,7 +7,7 @@ use std::{ use futures::{Stream, StreamExt}; use prost::Message; use tokio::{ - sync::mpsc::{self, UnboundedSender}, + sync::mpsc::self, task::JoinHandle, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -19,7 +19,7 @@ use crate::{ groups::{extract_group_id, GroupError, MlsGroup}, retry::Retry, retry_async, - storage::group_message::StoredGroupMessage, + storage::{group_message::StoredGroupMessage, group::StoredGroup}, Client, XmtpApi, }; @@ -29,6 +29,17 @@ pub(crate) struct MessagesStreamInfo { pub cursor: u64, } +impl From for (Vec, MessagesStreamInfo) { + fn from(group: StoredGroup) -> (Vec, MessagesStreamInfo) { + ( + group.id, + MessagesStreamInfo { + convo_created_at_ns: group.created_at_ns, + cursor: 0 + }) + } +} + impl Client where ApiClient: XmtpApi, @@ -215,21 +226,7 @@ where client.sync_welcomes().await?; - let current_groups = client.store().conn()?.find_groups(None, None, None, None)?; - - let mut group_id_to_info: HashMap, MessagesStreamInfo> = current_groups - - .into_iter() - .map(|group| { - ( - group.id.clone(), - MessagesStreamInfo { - convo_created_at_ns: group.created_at_ns, - cursor: 0, - }, - ) - }) - .collect(); + let mut group_id_to_info = client.store().conn()?.find_groups(None, None, None, None)?.into_iter().map(Into::into).collect::, MessagesStreamInfo>>(); tokio::spawn(async move { let client = client.clone(); @@ -405,15 +402,16 @@ mod tests { .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); let messages_clone = messages.clone(); + let notify = Arc::new(tokio::sync::Notify::new()); + let notify_pointer = notify.clone(); let handle = Client::::stream_all_messages_with_callback(caro.clone(), move |message| { let text = String::from_utf8(message.decrypted_message_bytes.clone()) .unwrap_or("".to_string()); println!("Received: {}", text); + notify_pointer.notify_one(); (*messages_clone.lock().unwrap()).push(message); }); @@ -421,7 +419,8 @@ mod tests { .send_message("first".as_bytes(), &alix) .await .unwrap(); - + notify.notified().await; + let bo_group = bo .create_group(None, GroupMetadataOptions::default()) .unwrap(); @@ -434,11 +433,13 @@ mod tests { .send_message("second".as_bytes(), &bo) .await .unwrap(); + notify.notified().await; alix_group .send_message("third".as_bytes(), &alix) .await .unwrap(); + notify.notified().await; let alix_group_2 = alix .create_group(None, GroupMetadataOptions::default()) @@ -447,19 +448,18 @@ mod tests { .add_members_by_inbox_id(&alix, vec![caro.inbox_id()]) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(300)).await; alix_group .send_message("fourth".as_bytes(), &alix) .await .unwrap(); + notify.notified().await; alix_group_2 .send_message("fifth".as_bytes(), &alix) .await .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + notify.notified().await; { let messages = messages.lock().unwrap(); @@ -468,7 +468,7 @@ mod tests { let a = handle.abort_handle(); a.abort(); - handle.await.unwrap(); + let _ = handle.await; assert!(a.is_finished()); alix_group