diff --git a/Cargo.lock b/Cargo.lock index 650cd2cb6..cb96d59af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7246,10 +7246,10 @@ dependencies = [ name = "xmtp_api_http" version = "0.1.0" dependencies = [ - "async-stream", "async-trait", "bytes", "futures", + "pin-project-lite", "reqwest 0.12.9", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index b46b42352..26a7dfd73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,7 @@ ctor = "0.2" ed25519 = "2.2.3" ed25519-dalek = { version = "2.1.1", features = ["zeroize"] } ethers = { version = "2.0", default-features = false } -futures = "0.3.30" -futures-core = "0.3.30" +futures = { version = "0.3.30", default-features = false } getrandom = { version = "0.2", default-features = false } hex = "0.4.3" hkdf = "0.12.3" @@ -61,16 +60,7 @@ thiserror = "2.0" tls_codec = "0.4.1" tokio = { version = "1.35.1", default-features = false } uuid = "1.10" -wasm-timer = "0.2" web-time = "1.1" -# Changing this version and rustls may potentially break the android build. Use Caution. -# Test with Android and Swift first. -# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier -# Until then, always test agains iOS/Android after updating these dependencies & making a PR -# Related Issues: -# - https://github.com/seanmonstar/reqwest/issues/2159 -# - https://github.com/hyperium/tonic/pull/1974 -# - https://github.com/rustls/rustls-platform-verifier/issues/58 bincode = "1.3" console_error_panic_hook = "0.1" const_format = "0.2" @@ -87,6 +77,14 @@ openssl = { version = "0.10", features = ["vendored"] } openssl-sys = { version = "0.9", features = ["vendored"] } parking_lot = "0.12.3" sqlite-web = "0.0.1" +# Changing this version and rustls may potentially break the android build. Use Caution. +# Test with Android and Swift first. +# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier +# Until then, always test agains iOS/Android after updating these dependencies & making a PR +# Related Issues: +# - https://github.com/seanmonstar/reqwest/issues/2159 +# - https://github.com/hyperium/tonic/pull/1974 +# - https://github.com/rustls/rustls-platform-verifier/issues/58 tonic = { version = "0.12", default-features = false } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false } @@ -101,7 +99,7 @@ criterion = { version = "0.5", features = [ "html_reports", "async_tokio", ]} - once_cell = "1.2" +once_cell = "1.2" # Internal Crate Dependencies xmtp_api_grpc = { path = "xmtp_api_grpc" } diff --git a/xmtp_api_grpc/Cargo.toml b/xmtp_api_grpc/Cargo.toml index b69a0d0c4..67ea6fb16 100644 --- a/xmtp_api_grpc/Cargo.toml +++ b/xmtp_api_grpc/Cargo.toml @@ -8,7 +8,7 @@ version.workspace = true async-stream.workspace = true async-trait = "0.1" base64.workspace = true -futures.workspace = true +futures = { workspace = true, features = ["alloc"] } hex.workspace = true prost = { workspace = true, features = ["prost-derive"] } tokio = { workspace = true, features = ["macros", "time"] } diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index dae2a490c..09a6a9214 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -8,17 +8,17 @@ license.workspace = true crate-type = ["cdylib", "rlib"] [dependencies] -async-stream.workspace = true futures = { workspace = true } tracing.workspace = true reqwest = { version = "0.12.5", features = ["json", "stream"] } serde = { workspace = true } serde_json = { workspace = true } -thiserror = "2.0" +thiserror.workspace = true tokio = { workspace = true, features = ["sync", "rt", "macros"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } async-trait = "0.1" bytes = "1.9" +pin-project-lite = "0.2.15" [dev-dependencies] xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] } diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs index cdfe80bd2..f15e045e1 100644 --- a/xmtp_api_http/src/http_stream.rs +++ b/xmtp_api_http/src/http_stream.rs @@ -3,7 +3,7 @@ use crate::util::GrpcResponse; use futures::{ stream::{self, Stream, StreamExt}, - Future, FutureExt, + Future, }; use reqwest::Response; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -16,90 +16,113 @@ pub(crate) struct SubscriptionItem { pub result: T, } -enum HttpPostStream -where - F: Future>, -{ - NotStarted(F), - // NotStarted(Box>>), - Started(Pin> + Unpin + Send>>), +#[cfg(target_arch = "wasm32")] +pub type BytesStream = stream::LocalBoxStream<'static, Result>; + +// #[cfg(not(target_arch = "wasm32"))] +// pub type BytesStream = Pin> + Send>>; + +#[cfg(not(target_arch = "wasm32"))] +pub type BytesStream = stream::BoxStream<'static, Result>; + +pin_project_lite::pin_project! { + #[project = PostStreamProject] + enum HttpPostStream { + NotStarted{#[pin] fut: F}, + // `Reqwest::bytes_stream` returns `impl Stream` rather than a type generic, + // so we can't use a type generic here + // this makes wasm a bit tricky. + Started { + #[pin] http: BytesStream, + remaining: Vec, + _marker: PhantomData, + }, + } } -impl Stream for HttpPostStream +impl Stream for HttpPostStream where - F: Future> + Unpin, + F: Future>, + for<'de> R: Send + Deserialize<'de>, { - type Item = Result; + type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { use futures::task::Poll::*; - use HttpPostStream::*; - match self.as_mut().get_mut() { - NotStarted(ref mut f) => match f.poll_unpin(cx) { + match self.as_mut().project() { + PostStreamProject::NotStarted { fut } => match fut.poll(cx) { Ready(response) => { let s = response.unwrap().bytes_stream(); - self.set(Self::Started(Box::pin(s.boxed()))); - self.poll_next(cx) + self.set(Self::started(s)); + self.as_mut().poll_next(cx) + } + Pending => { + cx.waker().wake_by_ref(); + Pending } - Pending => Pending, }, - Started(s) => s.poll_next_unpin(cx), + PostStreamProject::Started { + ref mut http, + ref mut remaining, + .. + } => { + let mut pinned = std::pin::pin!(http); + let next = pinned.as_mut().poll_next(cx); + Self::on_bytes(next, remaining, cx) + } } } } -struct GrpcHttpStream +impl HttpPostStream where - F: Future>, + R: Send, { - http: HttpPostStream, - remaining: Vec, - _marker: PhantomData, -} + #[cfg(not(target_arch = "wasm32"))] + fn started( + http: impl Stream> + Send + 'static, + ) -> Self { + Self::Started { + http: http.boxed(), + remaining: Vec::new(), + _marker: PhantomData, + } + } -impl GrpcHttpStream -where - F: Future> + Unpin, - R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, -{ - fn new(request: F) -> Self - where - F: Future>, - { - let mut http = HttpPostStream::NotStarted(request); - // we need to poll the future once to establish the initial POST request - // it will almost always be pending - let _ = http.next().now_or_never(); - Self { - http, - remaining: vec![], - _marker: PhantomData::, + #[cfg(target_arch = "wasm32")] + fn started(http: impl Stream> + 'static) -> Self { + Self::Started { + http: http.boxed_local(), + remaining: Vec::new(), + _marker: PhantomData, } } } -impl Stream for GrpcHttpStream +impl HttpPostStream where - F: Future> + Unpin, - R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, + F: Future>, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, { - type Item = Result; + fn new(request: F) -> Self { + Self::NotStarted { fut: request } + } - fn poll_next( - self: std::pin::Pin<&mut Self>, + fn on_bytes( + p: Poll>>, + remaining: &mut Vec, cx: &mut std::task::Context<'_>, - ) -> Poll> { + ) -> Poll::Item>> { use futures::task::Poll::*; - let this = self.get_mut(); - match this.http.poll_next_unpin(cx) { + match p { Ready(Some(bytes)) => { let bytes = bytes.map_err(|e| { Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()) })?; - let bytes = &[this.remaining.as_ref(), bytes.as_ref()].concat(); + let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); let de = Deserializer::from_slice(bytes); let mut stream = de.into_iter::>(); 'messages: loop { @@ -113,8 +136,7 @@ where } Some(Err(e)) => { if e.is_eof() { - this.remaining = (&**bytes)[stream.byte_offset()..].to_vec(); - tracing::info!("PENDING"); + *remaining = (&**bytes)[stream.byte_offset()..].to_vec(); return Pending; } else { Err(Error::new(ErrorKind::MlsError).with(e.to_string())) @@ -133,13 +155,45 @@ where } } } + /* + fn on_request( + self: &mut Pin<&mut Self>, + p: Poll>, + cx: &mut std::task::Context<'_>, + ) -> Poll::Item>> { + use futures::task::Poll::*; + match p { + Ready(response) => { + let s = response.unwrap().bytes_stream(); + self.set(Self::started(s)); + self.as_mut().poll_next(cx) + } + Pending => Pending, + } + } + */ +} + +impl HttpPostStream +where + F: Future> + Unpin, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + /// Establish the initial HTTP Stream connection + fn establish(&mut self) -> () { + // we need to poll the future once to progress the future state & + // establish the initial POST request. + // It should always be pending + let noop_waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&noop_waker); + // let mut this = Pin::new(self); + let mut this = Pin::new(self); + let _ = this.poll_next_unpin(&mut cx); + } } #[cfg(target_arch = "wasm32")] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( +pub fn create_grpc_stream( request: T, endpoint: String, http_client: reqwest::Client, @@ -148,25 +202,29 @@ pub fn create_grpc_stream< } #[cfg(not(target_arch = "wasm32"))] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, ->( +pub fn create_grpc_stream( request: T, endpoint: String, http_client: reqwest::Client, -) -> stream::BoxStream<'static, Result> { +) -> stream::BoxStream<'static, Result> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ create_grpc_stream_inner(request, endpoint, http_client).boxed() } -pub fn create_grpc_stream_inner< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + Unpin + 'static, ->( +fn create_grpc_stream_inner( request: T, endpoint: String, http_client: reqwest::Client, -) -> impl Stream> { +) -> impl Stream> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ let request = http_client.post(endpoint).json(&request).send(); - GrpcHttpStream::new(request) + let mut http = HttpPostStream::new(request); + http.establish(); + http } diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index 0e160e485..9ea13f2dc 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -49,7 +49,7 @@ async-stream.workspace = true async-trait.workspace = true bincode.workspace = true diesel_migrations.workspace = true -futures.workspace = true +futures = { workspace = true } hex.workspace = true hkdf.workspace = true openmls_rust_crypto = { workspace = true }