Skip to content

Commit

Permalink
perfect new http stream primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Dec 18, 2024
1 parent f8d93ea commit 36d67fc
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 90 deletions.
149 changes: 96 additions & 53 deletions xmtp_api_http/src/http_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
use crate::util::GrpcResponse;
use futures::{
stream::{self, Stream, StreamExt},
Future,
Future, FutureExt,
};
use reqwest::Response;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Deserializer;
use std::pin::Pin;
use std::{marker::PhantomData, pin::Pin, task::Poll};
use xmtp_proto::{Error, ErrorKind};

#[derive(Deserialize, Serialize, Debug)]
Expand Down Expand Up @@ -38,26 +38,103 @@ where
use futures::task::Poll::*;
use HttpPostStream::*;
match self.as_mut().get_mut() {
NotStarted(ref mut f) => {
tracing::info!("Polling");
let f = std::pin::pin!(f);
match f.poll(cx) {
Ready(response) => {
let s = response.unwrap().bytes_stream();
self.set(Self::Started(Box::pin(s.boxed())));
self.poll_next(cx)
}
Pending => {
// cx.waker().wake_by_ref();
Pending
}
NotStarted(ref mut f) => match f.poll_unpin(cx) {
Ready(response) => {
let s = response.unwrap().bytes_stream();
self.set(Self::Started(Box::pin(s.boxed())));
self.poll_next(cx)
}
}
Pending => Pending,
},
Started(s) => s.poll_next_unpin(cx),
}
}
}

struct GrpcHttpStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>>,
{
http: HttpPostStream<F>,
remaining: Vec<u8>,
_marker: PhantomData<R>,
}

impl<F, R> GrpcHttpStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>> + Unpin,
R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static,
{
fn new(request: F) -> Self
where
F: Future<Output = Result<Response, reqwest::Error>>,
{
let mut http = HttpPostStream::NotStarted(request);
// we need to poll the future once to establish the initial POST request
// it will almost always be pending
let _ = http.next().now_or_never();
Self {
http,
remaining: vec![],
_marker: PhantomData::<R>,
}
}
}

impl<F, R> Stream for GrpcHttpStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>> + Unpin,
R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static,
{
type Item = Result<R, Error>;

fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
use futures::task::Poll::*;
let this = self.get_mut();
match this.http.poll_next_unpin(cx) {
Ready(Some(bytes)) => {
let bytes = bytes.map_err(|e| {
Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string())
})?;
let bytes = &[this.remaining.as_ref(), bytes.as_ref()].concat();
let de = Deserializer::from_slice(bytes);
let mut stream = de.into_iter::<GrpcResponse<R>>();
'messages: loop {
tracing::debug!("Waiting on next response ...");
let response = stream.next();
let res = match response {
Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
Some(Ok(GrpcResponse::Err(e))) => {
Err(Error::new(ErrorKind::MlsError).with(e.message))
}
Some(Err(e)) => {
if e.is_eof() {
this.remaining = (&**bytes)[stream.byte_offset()..].to_vec();
tracing::info!("PENDING");
return Pending;
} else {
Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
}
}
Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
None => return Ready(None),
};
return Ready(Some(res));
}
}
Ready(None) => Ready(None),
Pending => {
cx.waker().wake_by_ref();
Pending
}
}
}
}

#[cfg(target_arch = "wasm32")]
pub fn create_grpc_stream<
T: Serialize + Send + 'static,
Expand All @@ -73,7 +150,7 @@ pub fn create_grpc_stream<
#[cfg(not(target_arch = "wasm32"))]
pub fn create_grpc_stream<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static,
>(
request: T,
endpoint: String,
Expand All @@ -84,46 +161,12 @@ pub fn create_grpc_stream<

pub fn create_grpc_stream_inner<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static,
>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> impl Stream<Item = Result<R, Error>> {
let request = http_client.post(endpoint).json(&request).send();
let http_stream = HttpPostStream::NotStarted(request);

async_stream::stream! {
tracing::info!("spawning grpc http stream");
let mut remaining = vec![];
for await bytes in http_stream {
let bytes = bytes
.map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?;
let bytes = &[remaining.as_ref(), bytes.as_ref()].concat();
let de = Deserializer::from_slice(bytes);
let mut stream = de.into_iter::<GrpcResponse<R>>();
'messages: loop {
tracing::debug!("Waiting on next response ...");
let response = stream.next();
let res = match response {
Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
Some(Ok(GrpcResponse::Err(e))) => {
Err(Error::new(ErrorKind::MlsError).with(e.message))
}
Some(Err(e)) => {
if e.is_eof() {
remaining = (&**bytes)[stream.byte_offset()..].to_vec();
break 'messages;
} else {
Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
}
}
Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
None => break 'messages,
};
yield res;
}
}
}
GrpcHttpStream::new(request)
}
42 changes: 5 additions & 37 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,23 +617,8 @@ pub(crate) mod tests {
.create_group(None, GroupMetadataOptions::default())
.unwrap();

// FIXME:insipx we run into an issue where the reqwest::post().send() request
// blocks the executor and we cannot progress the runtime if we dont `tokio::spawn` this.
// A solution might be to use `hyper` instead, and implement a custom connection pool with
// `deadpool`. This is a bit more work but shouldn't be too complicated since
// we're only using `post` requests. It would be nice for all streams to work
// w/o spawning a separate task.
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
let bob_ptr = bob.clone();
crate::spawn(None, async move {
let bob_stream = bob_ptr.stream_conversations(None).await.unwrap();
futures::pin_mut!(bob_stream);
while let Some(item) = bob_stream.next().await {
let _ = tx.send(item);
}
});

let stream = bob.stream_conversations(None).await.unwrap();
futures::pin_mut!(stream);
let group_id = alice_bob_group.group_id.clone();
alice_bob_group
.add_members_by_inbox_id(&[bob.inbox_id()])
Expand All @@ -644,7 +629,7 @@ pub(crate) mod tests {
assert_eq!(bob_received_groups.group_id, group_id);
}

#[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))]
#[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))]
async fn test_stream_messages() {
xmtp_common::logger();
let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
Expand All @@ -665,33 +650,16 @@ pub(crate) mod tests {
.unwrap();
let bob_group = bob_group.first().unwrap();

let notify = Delivery::new(None);
let notify_ptr = notify.clone();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
crate::spawn(None, async move {
let stream = alice_group.stream().await.unwrap();
futures::pin_mut!(stream);
while let Some(item) = stream.next().await {
let _ = tx.send(item);
notify_ptr.notify_one();
}
});

let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
// let stream = alice_group.stream().await.unwrap();
let stream = alice_group.stream().await.unwrap();
futures::pin_mut!(stream);
bob_group.send_message(b"hello").await.unwrap();
tracing::debug!("Bob Sent Message!, waiting for delivery");
// notify.wait_for_delivery().await.unwrap();

let message = stream.next().await.unwrap().unwrap();
assert_eq!(message.decrypted_message_bytes, b"hello");

bob_group.send_message(b"hello2").await.unwrap();
// notify.wait_for_delivery().await.unwrap();
let message = stream.next().await.unwrap().unwrap();
assert_eq!(message.decrypted_message_bytes, b"hello2");

// assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id);
}

#[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))]
Expand Down

0 comments on commit 36d67fc

Please sign in to comment.