Skip to content

Commit

Permalink
fix http stream with stream all messages
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Jan 16, 2025
1 parent 569e297 commit 0e723b2
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 53 deletions.
47 changes: 2 additions & 45 deletions xmtp_mls/src/subscriptions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use futures::{FutureExt, Stream, StreamExt};
use futures::{Stream, StreamExt};
use prost::Message;
use std::{collections::HashSet, future::Future, pin::Pin, sync::Arc, task::Poll};
use std::{collections::HashSet,sync::Arc};
use tokio::{
sync::{broadcast, oneshot},
task::JoinHandle,
Expand Down Expand Up @@ -45,49 +45,6 @@ impl RetryableError for LocalEventError {
}
}

// Wrappers to deal with Send Bounds
#[cfg(not(target_arch = "wasm32"))]
pub struct FutureWrapper<'a, O> {
inner: Pin<Box<dyn Future<Output = O> + Send + 'a>>,
}

#[cfg(target_arch = "wasm32")]
pub struct FutureWrapper<'a, O> {
inner: Pin<Box<dyn Future<Output = O> + 'a>>,
}

impl<'a, O> Future for FutureWrapper<'a, O> {
type Output = O;

fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let inner = &mut self.inner;
futures::pin_mut!(inner);
inner.as_mut().poll(cx)
}
}

impl<'a, O> FutureWrapper<'a, O> {
#[cfg(not(target_arch = "wasm32"))]
pub fn new<F>(future: F) -> Self
where
F: Future<Output = O> + Send + 'a,
{
Self {
inner: future.boxed(),
}
}

#[cfg(target_arch = "wasm32")]
pub fn new<F>(future: F) -> Self
where
F: Future<Output = O> + 'a,
{
Self {
inner: future.boxed_local(),
}
}
}

#[derive(Debug)]
/// Wrapper around a [`tokio::task::JoinHandle`] but with a oneshot receiver
/// which allows waiting for a `with_callback` stream fn to be ready for stream items.
Expand Down
71 changes: 65 additions & 6 deletions xmtp_mls/src/subscriptions/stream_all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ use xmtp_proto::api_client::{trait_impls::XmtpApi, XmtpMlsStreams};
use super::{
stream_conversations::{StreamConversations, WelcomesApiSubscription},
stream_messages::StreamGroupMessages,
FutureWrapper, Result, SubscribeError,
Result, SubscribeError,
};
use xmtp_common::FutureWrapper;
use pin_project_lite::pin_project;

pin_project! {
Expand Down Expand Up @@ -187,6 +188,12 @@ where

let mut conversations = self.messages.group_list().clone();
conversations.insert(new_group.group_id, 1.into());
tracing::info!("------------- switching stream -------------");
for (conversation, cursor_id) in conversations.iter() {
tracing::info!("conversation_id={},cursor_id={}", hex::encode(&conversation), cursor_id);
}
tracing::info!("--------------------------------------------");


let future = StreamGroupMessages::new(self.client, conversations);
let mut this = self.as_mut().project();
Expand Down Expand Up @@ -383,7 +390,7 @@ mod tests {
let alix_group_pointer = alix_group.clone();
crate::spawn(None, async move {
let mut sent = 0;
for _ in 0..50 {
for _ in 0..25 {
alix_group_pointer.send_message(b"spam").await.unwrap();
sent += 1;
xmtp_common::time::sleep(core::time::Duration::from_micros(100)).await;
Expand All @@ -397,7 +404,59 @@ mod tests {
let caro_id = caro.inbox_id().to_string();
crate::spawn(None, async move {
let caro = &caro_id;
for i in 0..50 {
for i in 0..5 {
let new_group = eve
.create_group(None, GroupMetadataOptions::default())
.unwrap();
new_group.add_members_by_inbox_id(&[caro]).await.unwrap();
tracing::info!("\n\n EVE SENDING {i} \n\n");
new_group
.send_message(b"spam from new group")
.await
.unwrap();
}
});

let mut messages = Vec::new();
let _ = tokio::time::timeout(core::time::Duration::from_secs(5), async {
futures::pin_mut!(stream);
loop {
if messages.len() < 30 {
if let Some(Ok(msg)) = stream.next().await {
tracing::info!(
message_id = hex::encode(&msg.id),
sender_inbox_id = msg.sender_inbox_id,
sender_installation_id = hex::encode(&msg.sender_installation_id),
group_id = hex::encode(&msg.group_id),
"GOT MESSAGE {}, text={}",
messages.len(),
String::from_utf8_lossy(msg.decrypted_message_bytes.as_slice())
);
messages.push(msg)
}
} else {
break;
}
}
})
.await;

tracing::info!("Total Messages: {}", messages.len());
assert_eq!(messages.len(), 30);
}


#[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))]
async fn test_stream_all_messages_detached_group_changes() {
let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let eve = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
tracing::info!(inbox_id = eve.inbox_id(), "EVE");
let stream = caro.stream_all_messages(None).await.unwrap();

let caro_id = caro.inbox_id().to_string();
crate::spawn(None, async move {
let caro = &caro_id;
for i in 0..5 {
let new_group = eve
.create_group(None, GroupMetadataOptions::default())
.unwrap();
Expand All @@ -411,10 +470,10 @@ mod tests {
});

let mut messages = Vec::new();
let _ = tokio::time::timeout(core::time::Duration::from_secs(60), async {
let _ = tokio::time::timeout(core::time::Duration::from_secs(5), async {
futures::pin_mut!(stream);
loop {
if messages.len() < 100 {
if messages.len() < 5 {
if let Some(Ok(msg)) = stream.next().await {
tracing::info!(
message_id = hex::encode(&msg.id),
Expand All @@ -435,6 +494,6 @@ mod tests {
.await;

tracing::info!("Total Messages: {}", messages.len());
assert_eq!(messages.len(), 100);
assert_eq!(messages.len(), 5);
}
}
3 changes: 2 additions & 1 deletion xmtp_mls/src/subscriptions/stream_conversations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use xmtp_proto::{
xmtp::mls::api::v1::{welcome_message, WelcomeMessage},
};

use super::{FutureWrapper, LocalEvents, Result, SubscribeError};
use super::{LocalEvents, Result, SubscribeError};
use xmtp_common::FutureWrapper;

#[derive(thiserror::Error, Debug)]
pub enum ConversationStreamError {
Expand Down
3 changes: 2 additions & 1 deletion xmtp_mls/src/subscriptions/stream_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
task::{Context, Poll},
};

use super::{FutureWrapper, Result, SubscribeError};
use super::{Result, SubscribeError};
use crate::{
api::GroupFilter,
groups::{scoped_client::ScopedGroupClient, MlsGroup},
Expand All @@ -15,6 +15,7 @@ use crate::{
},
XmtpOpenMlsProvider,
};
use xmtp_common::FutureWrapper;
use futures::Stream;
use pin_project_lite::pin_project;
use xmtp_common::{retry_async, Retry};
Expand Down

0 comments on commit 0e723b2

Please sign in to comment.