From 36d67fcb609d730dbbd641cd8df801ff310f3ed4 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Tue, 17 Dec 2024 16:45:22 -0500 Subject: [PATCH] perfect new http stream primitives --- xmtp_api_http/src/http_stream.rs | 149 ++++++++++++++++++++----------- xmtp_mls/src/subscriptions.rs | 42 ++------- 2 files changed, 101 insertions(+), 90 deletions(-) diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs index 0a5f83014..cdfe80bd2 100644 --- a/xmtp_api_http/src/http_stream.rs +++ b/xmtp_api_http/src/http_stream.rs @@ -3,12 +3,12 @@ use crate::util::GrpcResponse; use futures::{ stream::{self, Stream, StreamExt}, - Future, + Future, FutureExt, }; use reqwest::Response; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::Deserializer; -use std::pin::Pin; +use std::{marker::PhantomData, pin::Pin, task::Poll}; use xmtp_proto::{Error, ErrorKind}; #[derive(Deserialize, Serialize, Debug)] @@ -38,26 +38,103 @@ where use futures::task::Poll::*; use HttpPostStream::*; match self.as_mut().get_mut() { - NotStarted(ref mut f) => { - tracing::info!("Polling"); - let f = std::pin::pin!(f); - match f.poll(cx) { - Ready(response) => { - let s = response.unwrap().bytes_stream(); - self.set(Self::Started(Box::pin(s.boxed()))); - self.poll_next(cx) - } - Pending => { - // cx.waker().wake_by_ref(); - Pending - } + NotStarted(ref mut f) => match f.poll_unpin(cx) { + Ready(response) => { + let s = response.unwrap().bytes_stream(); + self.set(Self::Started(Box::pin(s.boxed()))); + self.poll_next(cx) } - } + Pending => Pending, + }, Started(s) => s.poll_next_unpin(cx), } } } +struct GrpcHttpStream +where + F: Future>, +{ + http: HttpPostStream, + remaining: Vec, + _marker: PhantomData, +} + +impl GrpcHttpStream +where + F: Future> + Unpin, + R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, +{ + fn new(request: F) -> Self + where + F: Future>, + { + let mut http = HttpPostStream::NotStarted(request); + // we need to poll the future once to establish the initial POST request + // it will almost always be pending + let _ = http.next().now_or_never(); + Self { + http, + remaining: vec![], + _marker: PhantomData::, + } + } +} + +impl Stream for GrpcHttpStream +where + F: Future> + Unpin, + R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, +{ + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use futures::task::Poll::*; + let this = self.get_mut(); + match this.http.poll_next_unpin(cx) { + Ready(Some(bytes)) => { + let bytes = bytes.map_err(|e| { + Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()) + })?; + let bytes = &[this.remaining.as_ref(), bytes.as_ref()].concat(); + let de = Deserializer::from_slice(bytes); + let mut stream = de.into_iter::>(); + 'messages: loop { + tracing::debug!("Waiting on next response ..."); + let response = stream.next(); + let res = match response { + Some(Ok(GrpcResponse::Ok(response))) => Ok(response), + Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), + Some(Ok(GrpcResponse::Err(e))) => { + Err(Error::new(ErrorKind::MlsError).with(e.message)) + } + Some(Err(e)) => { + if e.is_eof() { + this.remaining = (&**bytes)[stream.byte_offset()..].to_vec(); + tracing::info!("PENDING"); + return Pending; + } else { + Err(Error::new(ErrorKind::MlsError).with(e.to_string())) + } + } + Some(Ok(GrpcResponse::Empty {})) => continue 'messages, + None => return Ready(None), + }; + return Ready(Some(res)); + } + } + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } + } +} + #[cfg(target_arch = "wasm32")] pub fn create_grpc_stream< T: Serialize + Send + 'static, @@ -73,7 +150,7 @@ pub fn create_grpc_stream< #[cfg(not(target_arch = "wasm32"))] pub fn create_grpc_stream< T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, + R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, >( request: T, endpoint: String, @@ -84,46 +161,12 @@ pub fn create_grpc_stream< pub fn create_grpc_stream_inner< T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, + R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, >( request: T, endpoint: String, http_client: reqwest::Client, ) -> impl Stream> { let request = http_client.post(endpoint).json(&request).send(); - let http_stream = HttpPostStream::NotStarted(request); - - async_stream::stream! { - tracing::info!("spawning grpc http stream"); - let mut remaining = vec![]; - for await bytes in http_stream { - let bytes = bytes - .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?; - let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); - let de = Deserializer::from_slice(bytes); - let mut stream = de.into_iter::>(); - 'messages: loop { - tracing::debug!("Waiting on next response ..."); - let response = stream.next(); - let res = match response { - Some(Ok(GrpcResponse::Ok(response))) => Ok(response), - Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), - Some(Ok(GrpcResponse::Err(e))) => { - Err(Error::new(ErrorKind::MlsError).with(e.message)) - } - Some(Err(e)) => { - if e.is_eof() { - remaining = (&**bytes)[stream.byte_offset()..].to_vec(); - break 'messages; - } else { - Err(Error::new(ErrorKind::MlsError).with(e.to_string())) - } - } - Some(Ok(GrpcResponse::Empty {})) => continue 'messages, - None => break 'messages, - }; - yield res; - } - } - } + GrpcHttpStream::new(request) } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 22386896e..37b0f7177 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -617,23 +617,8 @@ pub(crate) mod tests { .create_group(None, GroupMetadataOptions::default()) .unwrap(); - // FIXME:insipx we run into an issue where the reqwest::post().send() request - // blocks the executor and we cannot progress the runtime if we dont `tokio::spawn` this. - // A solution might be to use `hyper` instead, and implement a custom connection pool with - // `deadpool`. This is a bit more work but shouldn't be too complicated since - // we're only using `post` requests. It would be nice for all streams to work - // w/o spawning a separate task. - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - let bob_ptr = bob.clone(); - crate::spawn(None, async move { - let bob_stream = bob_ptr.stream_conversations(None).await.unwrap(); - futures::pin_mut!(bob_stream); - while let Some(item) = bob_stream.next().await { - let _ = tx.send(item); - } - }); - + let stream = bob.stream_conversations(None).await.unwrap(); + futures::pin_mut!(stream); let group_id = alice_bob_group.group_id.clone(); alice_bob_group .add_members_by_inbox_id(&[bob.inbox_id()]) @@ -644,7 +629,7 @@ pub(crate) mod tests { assert_eq!(bob_received_groups.group_id, group_id); } - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] async fn test_stream_messages() { xmtp_common::logger(); let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -665,33 +650,16 @@ pub(crate) mod tests { .unwrap(); let bob_group = bob_group.first().unwrap(); - let notify = Delivery::new(None); - let notify_ptr = notify.clone(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - crate::spawn(None, async move { - let stream = alice_group.stream().await.unwrap(); - futures::pin_mut!(stream); - while let Some(item) = stream.next().await { - let _ = tx.send(item); - notify_ptr.notify_one(); - } - }); - - let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - // let stream = alice_group.stream().await.unwrap(); + let stream = alice_group.stream().await.unwrap(); futures::pin_mut!(stream); bob_group.send_message(b"hello").await.unwrap(); - tracing::debug!("Bob Sent Message!, waiting for delivery"); - // notify.wait_for_delivery().await.unwrap(); + let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello"); bob_group.send_message(b"hello2").await.unwrap(); - // notify.wait_for_delivery().await.unwrap(); let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello2"); - - // assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))]