From d70a5011a217f3cb98caa22cf8f24ade5515ae18 Mon Sep 17 00:00:00 2001 From: Christopher Young Date: Sat, 7 Oct 2023 22:32:38 +0800 Subject: [PATCH] Refactor `ResBody` and add channel support --- Cargo.toml | 1 + crates/compression/src/lib.rs | 8 +- crates/core/Cargo.toml | 1 + crates/core/src/conn/quinn/builder.rs | 26 ++-- crates/core/src/fs/named_file.rs | 4 +- crates/core/src/http/body/channel.rs | 82 ++++++++++++ crates/core/src/http/body/mod.rs | 4 +- crates/core/src/http/body/req.rs | 42 +++---- crates/core/src/http/body/res.rs | 174 +++++++++++++++++--------- crates/core/src/http/form.rs | 2 + crates/core/src/http/response.rs | 34 ++--- crates/core/src/test/response.rs | 20 +-- crates/core/src/writing/seek.rs | 4 +- crates/extra/src/sse.rs | 23 ++-- crates/proxy/src/lib.rs | 4 +- examples/body-channel/Cargo.toml | 11 ++ examples/body-channel/src/main.rs | 19 +++ 17 files changed, 313 insertions(+), 146 deletions(-) create mode 100644 crates/core/src/http/body/channel.rs create mode 100644 examples/body-channel/Cargo.toml create mode 100644 examples/body-channel/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 3369412d1..672f6ff67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ etag = "4" eyre = "0.6" fastrand = "2" form_urlencoded = "1" +futures-channel = "0.3" futures-util = { version = "0.3", default-features = false } headers = "0.3" http = "0.2" diff --git a/crates/compression/src/lib.rs b/crates/compression/src/lib.rs index 5de3a4f43..cdb317c4c 100644 --- a/crates/compression/src/lib.rs +++ b/crates/compression/src/lib.rs @@ -333,7 +333,7 @@ impl Handler for Compression { } match self.negotiate(req, res) { Some((algo, level)) => { - res.stream(EncodeStream::new(algo, level, Some(bytes))).ok(); + res.stream(EncodeStream::new(algo, level, Some(bytes))); res.headers_mut().append(CONTENT_ENCODING, algo.into()); } None => { @@ -352,7 +352,7 @@ impl Handler for Compression { } match self.negotiate(req, res) { Some((algo, level)) => { - res.stream(EncodeStream::new(algo, level, chunks)).ok(); + res.stream(EncodeStream::new(algo, level, chunks)); res.headers_mut().append(CONTENT_ENCODING, algo.into()); } None => { @@ -363,7 +363,7 @@ impl Handler for Compression { } ResBody::Hyper(body) => match self.negotiate(req, res) { Some((algo, level)) => { - res.stream(EncodeStream::new(algo, level, body)).ok(); + res.stream(EncodeStream::new(algo, level, body)); res.headers_mut().append(CONTENT_ENCODING, algo.into()); } None => { @@ -375,7 +375,7 @@ impl Handler for Compression { let body = body.into_inner(); match self.negotiate(req, res) { Some((algo, level)) => { - res.stream(EncodeStream::new(algo, level, body)).ok(); + res.stream(EncodeStream::new(algo, level, body)); res.headers_mut().append(CONTENT_ENCODING, algo.into()); } None => { diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index edac4eb13..4d489159c 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -43,6 +43,7 @@ encoding_rs = { workspace = true, optional = true } enumflags2 = { workspace = true } eyre = { workspace = true, optional = true } form_urlencoded = { workspace = true } +futures-channel = { workspace = true } futures-util = { workspace = true, features = ["io"] } headers = { workspace = true } http = { workspace = true } diff --git a/crates/core/src/conn/quinn/builder.rs b/crates/core/src/conn/quinn/builder.rs index 8cde9cd76..9acd62a96 100644 --- a/crates/core/src/conn/quinn/builder.rs +++ b/crates/core/src/conn/quinn/builder.rs @@ -55,7 +55,7 @@ impl Builder { &self, conn: crate::conn::quinn::H3Connection, hyper_handler: crate::service::HyperHandler, - _server_shutdown_token: CancellationToken, //TODO + _server_shutdown_token: CancellationToken, //TODO _idle_connection_timeout: Option, //TODO ) -> IoResult<()> { let mut conn = self @@ -170,9 +170,15 @@ async fn process_web_transport( let mut body = Pin::new(&mut body); while let Some(result) = poll_fn(|cx| body.as_mut().poll_next(cx)).await { match result { - Ok(bytes) => { - if let Err(e) = stream.send_data(bytes).await { - tracing::error!(error = ?e, "unable to send data to connection peer"); + Ok(frame) => { + if frame.is_data() { + if let Err(e) = stream.send_data(frame.into_data().unwrap_or_default()).await { + tracing::error!(error = ?e, "unable to send data to connection peer"); + } + } else { + if let Err(e) = stream.send_trailers(frame.into_trailers().unwrap_or_default()).await { + tracing::error!(error = ?e, "unable to send trailers to connection peer"); + } } } Err(e) => { @@ -220,9 +226,15 @@ where let mut body = Pin::new(&mut body); while let Some(result) = poll_fn(|cx| body.as_mut().poll_next(cx)).await { match result { - Ok(bytes) => { - if let Err(e) = tx.send_data(bytes).await { - tracing::error!(error = ?e, "unable to send data to connection peer"); + Ok(frame) => { + if frame.is_data() { + if let Err(e) = tx.send_data(frame.into_data().unwrap_or_default()).await { + tracing::error!(error = ?e, "unable to send data to connection peer"); + } + } else { + if let Err(e) = tx.send_trailers(frame.into_trailers().unwrap_or_default()).await { + tracing::error!(error = ?e, "unable to send trailers to connection peer"); + } } } Err(e) => { diff --git a/crates/core/src/fs/named_file.rs b/crates/core/src/fs/named_file.rs index 89689cf98..fada16d60 100644 --- a/crates/core/src/fs/named_file.rs +++ b/crates/core/src/fs/named_file.rs @@ -510,7 +510,7 @@ impl NamedFile { buffer_size: self.buffer_size, }; res.headers_mut().typed_insert(ContentLength(reader.total_size)); - res.stream(reader).ok(); + res.stream(reader); } else { res.status_code(StatusCode::OK); let reader = ChunkedFile { @@ -521,7 +521,7 @@ impl NamedFile { buffer_size: self.buffer_size, }; res.headers_mut().typed_insert(ContentLength(length)); - res.stream(reader).ok(); + res.stream(reader); } } } diff --git a/crates/core/src/http/body/channel.rs b/crates/core/src/http/body/channel.rs new file mode 100644 index 000000000..3900f803a --- /dev/null +++ b/crates/core/src/http/body/channel.rs @@ -0,0 +1,82 @@ +use std::fmt; +use std::io::{Error as IoError, ErrorKind, Result as IoResult}; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use futures_channel::mpsc; +use futures_channel::oneshot; +use hyper::HeaderMap; + +/// A sender half created through [`ResBody::channel()`]. +/// +/// Useful when wanting to stream chunks from another thread. +/// +/// ## Body Closing +/// +/// Note that the request body will always be closed normally when the sender is dropped (meaning +/// that the empty terminating chunk will be sent to the remote). If you desire to close the +/// connection with an incomplete response (e.g. in the case of an error during asynchronous +/// processing), call the [`Sender::abort()`] method to abort the body in an abnormal fashion. +/// +/// [`Body::channel()`]: struct.Body.html#method.channel +/// [`Sender::abort()`]: struct.Sender.html#method.abort +#[must_use = "Sender does nothing unless sent on"] +pub struct BodySender { + pub(crate) data_tx: mpsc::Sender>, + pub(crate) trailers_tx: Option>, +} +impl BodySender { + /// Check to see if this `Sender` can send more data. + pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.data_tx + .poll_ready(cx) + .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to poll ready: {}", e))) + } + + async fn ready(&mut self) -> IoResult<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + + /// Send data on data channel when it is ready. + pub async fn send_data(&mut self, chunk: impl Into) -> IoResult<()> { + self.ready().await?; + self.data_tx + .try_send(Ok(chunk.into())) + .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to send data: {}", e))) + } + + /// Send trailers on trailers channel. + pub async fn send_trailers(&mut self, trailers: HeaderMap) -> IoResult<()> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(IoError::new(ErrorKind::Other, "failed to send railers")), + }; + tx.send(trailers) + .map_err(|_| IoError::new(ErrorKind::Other, "failed to send railers")) + } + + /// Send error on data channel. + pub fn send_error(&mut self, err: IoError) { + let _ = self + .data_tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(err)); + } +} + +impl fmt::Debug for BodySender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + + let mut builder = f.debug_tuple("BodySender"); + + builder.finish() + } +} + +/// A receiver for [`ResBody`] +pub struct BodyReceiver { + pub(crate) content_length: u64, + pub(crate) data_rx: mpsc::Receiver>, + pub(crate) trailers_rx: oneshot::Receiver, +} \ No newline at end of file diff --git a/crates/core/src/http/body/mod.rs b/crates/core/src/http/body/mod.rs index 7abb1b306..0ec8d2672 100644 --- a/crates/core/src/http/body/mod.rs +++ b/crates/core/src/http/body/mod.rs @@ -1,5 +1,5 @@ //! Http body. -pub use hyper::body::{Body, SizeHint}; +pub use hyper::body::{Body, SizeHint, Frame}; mod req; #[cfg(feature = "quinn")] @@ -8,3 +8,5 @@ pub use req::ReqBody; mod res; pub use hyper::body::Incoming as HyperBody; pub use res::ResBody; +mod channel; +pub use channel::{BodySender, BodyReceiver}; diff --git a/crates/core/src/http/body/req.rs b/crates/core/src/http/body/req.rs index ad242e282..7e1b49901 100644 --- a/crates/core/src/http/body/req.rs +++ b/crates/core/src/http/body/req.rs @@ -64,8 +64,8 @@ impl Body for ReqBody { cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { match &mut *self { - ReqBody::None => Poll::Ready(None), - ReqBody::Once(bytes) => { + Self::None => Poll::Ready(None), + Self::Once(bytes) => { if bytes.is_empty() { Poll::Ready(None) } else { @@ -73,10 +73,10 @@ impl Body for ReqBody { Poll::Ready(Some(Ok(Frame::data(bytes)))) } } - ReqBody::Hyper(body) => Pin::new(body) + Self::Hyper(body) => Pin::new(body) .poll_frame(cx) .map_err(|e| IoError::new(ErrorKind::Other, e)), - ReqBody::Boxed(inner) => Pin::new(inner) + Self::Boxed(inner) => Pin::new(inner) .poll_frame(cx) .map_err(|e| IoError::new(ErrorKind::Other, e)), } @@ -84,28 +84,28 @@ impl Body for ReqBody { fn is_end_stream(&self) -> bool { match self { - ReqBody::None => true, - ReqBody::Once(bytes) => bytes.is_empty(), - ReqBody::Hyper(body) => body.is_end_stream(), - ReqBody::Boxed(body) => body.is_end_stream(), + Self::None => true, + Self::Once(bytes) => bytes.is_empty(), + Self::Hyper(body) => body.is_end_stream(), + Self::Boxed(body) => body.is_end_stream(), } } fn size_hint(&self) -> SizeHint { match self { - ReqBody::None => SizeHint::with_exact(0), - ReqBody::Once(bytes) => SizeHint::with_exact(bytes.len() as u64), - ReqBody::Hyper(body) => body.size_hint(), - ReqBody::Boxed(body) => body.size_hint(), + Self::None => SizeHint::with_exact(0), + Self::Once(bytes) => SizeHint::with_exact(bytes.len() as u64), + Self::Hyper(body) => body.size_hint(), + Self::Boxed(body) => body.size_hint(), } } } impl Stream for ReqBody { - type Item = IoResult; + type Item = IoResult>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match Body::poll_frame(self, cx) { - Poll::Ready(Some(Ok(frame))) => Poll::Ready(frame.into_data().map(Ok).ok()), + Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))), Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(IoError::new(ErrorKind::Other, e)))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, @@ -115,42 +115,42 @@ impl Stream for ReqBody { impl From for ReqBody { fn from(value: Bytes) -> ReqBody { - ReqBody::Once(value) + Self::Once(value) } } impl From for ReqBody { fn from(value: Incoming) -> ReqBody { - ReqBody::Hyper(value) + Self::Hyper(value) } } impl From for ReqBody { #[inline] fn from(value: String) -> ReqBody { - ReqBody::Once(value.into()) + Self::Once(value.into()) } } impl From<&'static [u8]> for ReqBody { fn from(value: &'static [u8]) -> ReqBody { - ReqBody::Once(value.into()) + Self::Once(value.into()) } } impl From<&'static str> for ReqBody { fn from(value: &'static str) -> ReqBody { - ReqBody::Once(value.into()) + Self::Once(value.into()) } } impl From> for ReqBody { fn from(value: Vec) -> ReqBody { - ReqBody::Once(value.into()) + Self::Once(value.into()) } } impl From> for ReqBody { fn from(value: Box<[u8]>) -> ReqBody { - ReqBody::Once(value.into()) + Self::Once(value.into()) } } diff --git a/crates/core/src/http/body/res.rs b/crates/core/src/http/body/res.rs index 01ae8273b..822d0436d 100644 --- a/crates/core/src/http/body/res.rs +++ b/crates/core/src/http/body/res.rs @@ -3,19 +3,26 @@ use std::boxed::Box; use std::collections::VecDeque; use std::fmt::Debug; +use std::future::Future; use std::io::{Error as IoError, ErrorKind, Result as IoResult}; use std::pin::Pin; -use std::task::{self, Context, Poll}; +use std::task::{self, ready, Context, Poll}; -use futures_util::stream::{BoxStream, Stream, TryStreamExt}; +use futures_channel::mpsc; +use futures_channel::oneshot; +use futures_util::stream::{BoxStream, FusedStream, Stream, TryStreamExt}; use hyper::body::{Body, Frame, Incoming, SizeHint}; use sync_wrapper::SyncWrapper; use bytes::Bytes; use crate::error::BoxedError; +use crate::http::body::{BodyReceiver, BodySender}; use crate::prelude::StatusError; +const CHUNKED_LENGTH: u64 = ::std::u64::MAX - 1; +const ZERO_LENGTH: u64 = 0; + /// Response body type. #[allow(clippy::type_complexity)] #[non_exhaustive] @@ -34,6 +41,8 @@ pub enum ResBody { Boxed(Pin + Send + Sync + 'static>>), /// Stream body. Stream(SyncWrapper>>), + /// Channel body. + Channel(BodyReceiver), /// Error body will be process in catcher. Error(StatusError), } @@ -68,6 +77,11 @@ impl ResBody { pub fn is_stream(&self) -> bool { matches!(*self, Self::Stream(_)) } + /// Check is that body is stream. + #[inline] + pub fn is_channel(&self) -> bool { + matches!(*self, Self::Channel { .. }) + } /// Check is that body is error will be process in catcher. pub fn is_error(&self) -> bool { matches!(*self, Self::Error(_)) @@ -84,6 +98,26 @@ impl ResBody { Self::Stream(SyncWrapper::new(Box::pin(mapped))) } + /// Create a `Body` stream with an associated sender half. + /// + /// Useful when wanting to stream chunks from another thread. + pub fn channel() -> (BodySender, Self) { + let (data_tx, data_rx) = mpsc::channel(0); + let (trailers_tx, trailers_rx) = oneshot::channel(); + + let tx = BodySender { + data_tx, + trailers_tx: Some(trailers_tx), + }; + let rx = ResBody::Channel(BodyReceiver { + content_length: CHUNKED_LENGTH, + data_rx, + trailers_rx, + }); + + (tx, rx) + } + /// Get body's size. #[inline] pub fn size(&self) -> Option { @@ -94,6 +128,7 @@ impl ResBody { Self::Hyper(_) => None, Self::Boxed(_) => None, Self::Stream(_) => None, + Self::Channel { .. } => None, Self::Error(_) => None, } } @@ -105,144 +140,161 @@ impl ResBody { } } -impl Stream for ResBody { - type Item = IoResult; +impl Body for ResBody { + type Data = Bytes; + type Error = IoError; - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, ::Error>>> { match self.get_mut() { - ResBody::None => Poll::Ready(None), - ResBody::Once(bytes) => { + Self::None => Poll::Ready(None), + Self::Once(bytes) => { if bytes.is_empty() { Poll::Ready(None) } else { let bytes = std::mem::replace(bytes, Bytes::new()); - Poll::Ready(Some(Ok(bytes))) + Poll::Ready(Some(Ok(Frame::data(bytes)))) } } - ResBody::Chunks(chunks) => Poll::Ready(chunks.pop_front().map(Ok)), - ResBody::Hyper(body) => match Body::poll_frame(Pin::new(body), cx) { - Poll::Ready(Some(Ok(frame))) => Poll::Ready(frame.into_data().map(Ok).ok()), + Self::Chunks(chunks) => Poll::Ready(chunks.pop_front().map(|bytes| Ok(Frame::data(bytes)))), + Self::Hyper(body) => match Body::poll_frame(Pin::new(body), cx) { + Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))), Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(IoError::new(ErrorKind::Other, e)))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, }, - ResBody::Boxed(body) => match Body::poll_frame(Pin::new(body), cx) { - Poll::Ready(Some(Ok(frame))) => Poll::Ready(frame.into_data().map(Ok).ok()), + Self::Boxed(body) => match Body::poll_frame(Pin::new(body), cx) { + Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))), Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(IoError::new(ErrorKind::Other, e)))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, }, - ResBody::Stream(stream) => stream + Self::Stream(stream) => stream .get_mut() .as_mut() .poll_next(cx) + .map_ok(Frame::data) .map_err(|e| IoError::new(ErrorKind::Other, e)), - ResBody::Error(_) => Poll::Ready(None), - } - } -} - -impl Body for ResBody { - type Data = Bytes; - type Error = IoError; + Self::Channel(rx) => { + if !rx.data_rx.is_terminated() { + if let Some(chunk) = ready!(Pin::new(&mut rx.data_rx).poll_next(cx)?) { + return Poll::Ready(Some(Ok(Frame::data(chunk)))); + } + } - fn poll_frame( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, ::Error>>> { - match self.poll_next(_cx) { - Poll::Ready(Some(Ok(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, + // check trailers after data is terminated + match ready!(Pin::new(&mut rx.trailers_rx).poll(cx)) { + Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))), + Err(_) => Poll::Ready(None), + } + } + ResBody::Error(_) => Poll::Ready(None), } } fn is_end_stream(&self) -> bool { match self { - ResBody::None => true, - ResBody::Once(bytes) => bytes.is_empty(), - ResBody::Chunks(chunks) => chunks.is_empty(), - ResBody::Hyper(body) => body.is_end_stream(), - ResBody::Boxed(body) => body.is_end_stream(), - ResBody::Stream(_) => false, - ResBody::Error(_) => true, + Self::None => true, + Self::Once(bytes) => bytes.is_empty(), + Self::Chunks(chunks) => chunks.is_empty(), + Self::Hyper(body) => body.is_end_stream(), + Self::Boxed(body) => body.is_end_stream(), + Self::Stream(_) => false, + Self::Channel(rx) => rx.content_length == ZERO_LENGTH, + Self::Error(_) => true, } } fn size_hint(&self) -> SizeHint { match self { - ResBody::None => SizeHint::with_exact(0), - ResBody::Once(bytes) => SizeHint::with_exact(bytes.len() as u64), - ResBody::Chunks(chunks) => { + Self::None => SizeHint::with_exact(0), + Self::Once(bytes) => SizeHint::with_exact(bytes.len() as u64), + Self::Chunks(chunks) => { let size = chunks.iter().map(|bytes| bytes.len() as u64).sum(); SizeHint::with_exact(size) } - ResBody::Hyper(recv) => recv.size_hint(), - ResBody::Boxed(recv) => recv.size_hint(), - ResBody::Stream(_) => SizeHint::default(), - ResBody::Error(_) => SizeHint::with_exact(0), + Self::Hyper(recv) => recv.size_hint(), + Self::Boxed(recv) => recv.size_hint(), + Self::Stream(_) => SizeHint::default(), + Self::Channel { .. } => SizeHint::default(), + Self::Error(_) => SizeHint::with_exact(0), + } + } +} + +impl Stream for ResBody { + type Item = IoResult>; + + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match Body::poll_frame(self, cx) { + Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(IoError::new(ErrorKind::Other, e)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } } impl From<()> for ResBody { fn from(_value: ()) -> ResBody { - ResBody::None + Self::None } } impl From for ResBody { fn from(value: Bytes) -> ResBody { - ResBody::Once(value) + Self::Once(value) } } impl From for ResBody { fn from(value: Incoming) -> ResBody { - ResBody::Hyper(value) + Self::Hyper(value) } } impl From for ResBody { #[inline] fn from(value: String) -> ResBody { - ResBody::Once(value.into()) + Self::Once(value.into()) } } impl From<&'static [u8]> for ResBody { fn from(value: &'static [u8]) -> ResBody { - ResBody::Once(value.into()) + Self::Once(value.into()) } } impl From<&'static str> for ResBody { fn from(value: &'static str) -> ResBody { - ResBody::Once(value.into()) + Self::Once(value.into()) } } impl From> for ResBody { fn from(value: Vec) -> ResBody { - ResBody::Once(value.into()) + Self::Once(value.into()) } } impl From> for ResBody { fn from(value: Box<[u8]>) -> ResBody { - ResBody::Once(value.into()) + Self::Once(value.into()) } } impl Debug for ResBody { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ResBody::None => write!(f, "ResBody::None"), - ResBody::Once(bytes) => write!(f, "ResBody::Once({:?})", bytes), - ResBody::Chunks(chunks) => write!(f, "ResBody::Chunks({:?})", chunks), - ResBody::Hyper(_) => write!(f, "ResBody::Hyper(_)"), - ResBody::Boxed(_) => write!(f, "ResBody::Boxed(_)"), - ResBody::Stream(_) => write!(f, "ResBody::Stream(_)"), - ResBody::Error(_) => write!(f, "ResBody::Error(_)"), + Self::None => write!(f, "ResBody::None"), + Self::Once(bytes) => write!(f, "ResBody::Once({:?})", bytes), + Self::Chunks(chunks) => write!(f, "ResBody::Chunks({:?})", chunks), + Self::Hyper(_) => write!(f, "ResBody::Hyper(_)"), + Self::Boxed(_) => write!(f, "ResBody::Boxed(_)"), + Self::Stream(_) => write!(f, "ResBody::Stream(_)"), + Self::Channel { .. } => write!(f, "ResBody::Channel{{..}}"), + Self::Error(_) => write!(f, "ResBody::Error(_)"), } } } diff --git a/crates/core/src/http/form.rs b/crates/core/src/http/form.rs index d133c303b..d894050ae 100644 --- a/crates/core/src/http/form.rs +++ b/crates/core/src/http/form.rs @@ -2,6 +2,7 @@ use std::ffi::OsStr; use std::path::{Path, PathBuf}; +use futures_util::StreamExt; use http_body_util::BodyExt; use mime::Mime; use multer::{Field, Multipart}; @@ -57,6 +58,7 @@ impl FormData { .and_then(|ct| ct.to_str().ok()) .and_then(|ct| multer::parse_boundary(ct).ok()) { + let body = body.map(|f| f.map(|f| f.into_data().unwrap_or_default())); let mut multipart = Multipart::new(body, boundary); while let Some(mut field) = multipart.next_field().await? { if let Some(name) = field.name().map(|s| s.to_owned()) { diff --git a/crates/core/src/http/response.rs b/crates/core/src/http/response.rs index 7d4e7b053..18dcfea8f 100644 --- a/crates/core/src/http/response.rs +++ b/crates/core/src/http/response.rs @@ -16,7 +16,7 @@ use crate::http::{StatusCode, StatusError}; use crate::{BoxedError, Error, Scribe}; use bytes::Bytes; -pub use crate::http::body::ResBody; +pub use crate::http::body::{BodySender, ResBody}; /// Represents an HTTP response #[non_exhaustive] @@ -432,6 +432,12 @@ impl Response { "current body's kind is `ResBody::Stream`, it is not allowed to write bytes", )); } + ResBody::Channel { .. } => { + tracing::error!("current body's kind is `ResBody::Channel`, it is not allowed to write bytes"); + return Err(Error::other( + "current body's kind is `ResBody::Channel`, it is not allowed to write bytes", + )); + } ResBody::Error(_) => { self.body = ResBody::Once(data.into()); } @@ -441,26 +447,20 @@ impl Response { /// Set response's body to stream. #[inline] - pub fn stream(&mut self, stream: S) -> crate::Result<()> + pub fn stream(&mut self, stream: S) where S: Stream> + Send + 'static, O: Into + 'static, E: Into + 'static, { - match &self.body { - ResBody::Once(_) => { - return Err(Error::other("current body kind is `ResBody::Once` already")); - } - ResBody::Chunks(_) => { - return Err(Error::other("current body kind is `ResBody::Chunks` already")); - } - ResBody::Stream(_) => { - return Err(Error::other("current body kind is `ResBody::Stream` already")); - } - _ => {} - } self.body = ResBody::stream(stream); - Ok(()) + } + /// Set response's body to channel. + #[inline] + pub fn channel(&mut self) -> BodySender { + let (sender, body) = ResBody::channel(); + self.body = body; + sender } } @@ -506,7 +506,7 @@ mod test { let mut result = bytes::BytesMut::new(); while let Some(Ok(data)) = body.next().await { - result.extend_from_slice(&data) + result.extend_from_slice(&data.into_data().unwrap_or_default()) } assert_eq!("hello", &result) @@ -521,7 +521,7 @@ mod test { let mut result = bytes::BytesMut::new(); while let Some(Ok(data)) = body.next().await { - result.extend_from_slice(&data) + result.extend_from_slice(&data.into_data().unwrap_or_default()) } assert_eq!("Hello World", &result) diff --git a/crates/core/src/test/response.rs b/crates/core/src/test/response.rs index 177e3646c..31f63c804 100644 --- a/crates/core/src/test/response.rs +++ b/crates/core/src/test/response.rs @@ -1,10 +1,9 @@ use std::borrow::Cow; -use std::io::{self, Write, Result as IoResult}; +use std::io::{self, Result as IoResult, Write}; use bytes::{Bytes, BytesMut}; use encoding_rs::{Encoding, UTF_8}; use flate2::write::{GzDecoder, ZlibDecoder}; -use futures_util::stream::StreamExt; use http_body_util::BodyExt; use mime::Mime; use serde::de::DeserializeOwned; @@ -135,22 +134,6 @@ impl ResponseExt for Response { let bytes = match body { ResBody::None => Bytes::new(), ResBody::Once(bytes) => bytes, - ResBody::Chunks(chunks) => { - let mut bytes = BytesMut::new(); - for chunk in chunks { - bytes.extend(chunk); - } - bytes.freeze() - } - ResBody::Hyper(body) => body.collect().await?.to_bytes(), - ResBody::Boxed(body) => body.collect().await?.to_bytes(), - ResBody::Stream(mut stream) => { - let mut bytes = BytesMut::new(); - while let Some(chunk) = stream.get_mut().next().await { - bytes.extend(chunk?); - } - bytes.freeze() - } ResBody::Error(e) => { if let Some(content_type) = content_type { status_error_bytes(&e, content_type, None).1 @@ -158,6 +141,7 @@ impl ResponseExt for Response { status_error_bytes(&e, &"text/html".parse().unwrap(), None).1 } } + _ => BodyExt::collect(body).await?.to_bytes(), }; Ok(bytes) } diff --git a/crates/core/src/writing/seek.rs b/crates/core/src/writing/seek.rs index a090d4532..9bba2c3f2 100644 --- a/crates/core/src/writing/seek.rs +++ b/crates/core/src/writing/seek.rs @@ -137,11 +137,11 @@ where } res.headers_mut() .typed_insert(ContentLength(cmp::min(length, self.length))); - res.stream(ReaderStream::new(self.reader)).ok(); + res.stream(ReaderStream::new(self.reader)); } else { res.status_code(StatusCode::OK); res.headers_mut().typed_insert(ContentLength(self.length)); - res.stream(ReaderStream::new(self.reader)).ok(); + res.stream(ReaderStream::new(self.reader)); } } } diff --git a/crates/extra/src/sse.rs b/crates/extra/src/sse.rs index 4d3d41243..5b140ac89 100644 --- a/crates/extra/src/sse.rs +++ b/crates/extra/src/sse.rs @@ -254,7 +254,7 @@ where /// Send stream. #[inline] - pub fn stream(self, res: &mut Response) -> salvo_core::Result<()> { + pub fn stream(self, res: &mut Response) { stream(res, self) } } @@ -268,9 +268,9 @@ fn write_request_headers(res: &mut Response) { .insert(CACHE_CONTROL, HeaderValue::from_static("no-cache")); } -/// Send event stream +/// Send event stream. #[inline] -pub fn stream(res: &mut Response, event_stream: S) -> salvo_core::Result<()> +pub fn stream(res: &mut Response, event_stream: S) where S: TryStream + Send + 'static, S::Error: StdError + Send + Sync + 'static, @@ -338,7 +338,7 @@ mod tests { Ok::<_, Infallible>(SseEvent::default().text("2")), ]); let mut res = Response::new(); - super::stream(&mut res, event_stream).unwrap(); + super::stream(&mut res, event_stream); let text = res.take_string().await.unwrap(); assert!(text.contains("data:1") && text.contains("data:2")); } @@ -350,8 +350,7 @@ mod tests { SseKeepAlive::new(event_stream) .comment("love you") .max_interval(Duration::from_secs(1)) - .stream(&mut res) - .unwrap(); + .stream(&mut res); let text = res.take_string().await.unwrap(); assert!(text.contains("data:1")); } @@ -367,7 +366,7 @@ mod tests { name: "jobs".to_owned(), })]); let mut res = Response::new(); - super::stream(&mut res, event_stream).unwrap(); + super::stream(&mut res, event_stream); let text = res.take_string().await.unwrap(); assert!(text.contains(r#"data:{"name":"jobs"}"#)); } @@ -376,7 +375,7 @@ mod tests { async fn test_sse_comment() { let event_stream = tokio_stream::iter(vec![Ok::<_, Infallible>(SseEvent::default().comment("comment"))]); let mut res = Response::new(); - super::stream(&mut res, event_stream).unwrap(); + super::stream(&mut res, event_stream); let text = res.take_string().await.unwrap(); assert!(text.contains(":comment")); } @@ -385,7 +384,7 @@ mod tests { async fn test_sse_name() { let event_stream = tokio_stream::iter(vec![Ok::<_, Infallible>(SseEvent::default().name("evt2"))]); let mut res = Response::new(); - super::stream(&mut res, event_stream).unwrap(); + super::stream(&mut res, event_stream); let text = res.take_string().await.unwrap(); assert!(text.contains("event:evt2")); } @@ -396,7 +395,7 @@ mod tests { SseEvent::default().retry(std::time::Duration::from_secs_f32(1.0)), )]); let mut res = Response::new(); - super::stream(&mut res, event_stream).unwrap(); + super::stream(&mut res, event_stream); let text = res.take_string().await.unwrap(); assert!(text.contains("retry:1000")); @@ -404,7 +403,7 @@ mod tests { SseEvent::default().retry(std::time::Duration::from_secs_f32(1.001)), )]); let mut res = Response::new(); - super::stream(&mut res, event_stream).unwrap(); + super::stream(&mut res, event_stream); let text = res.take_string().await.unwrap(); assert!(text.contains("retry:1001")); } @@ -413,7 +412,7 @@ mod tests { async fn test_sse_id() { let event_stream = tokio_stream::iter(vec![Ok::<_, Infallible>(SseEvent::default().id("jobs"))]); let mut res = Response::new(); - super::stream(&mut res, event_stream).unwrap(); + super::stream(&mut res, event_stream); let text = res.take_string().await.unwrap(); assert!(text.contains("id:jobs")); } diff --git a/crates/proxy/src/lib.rs b/crates/proxy/src/lib.rs index 236d57d5c..0a9ddf7af 100644 --- a/crates/proxy/src/lib.rs +++ b/crates/proxy/src/lib.rs @@ -12,6 +12,7 @@ use std::convert::{Infallible, TryFrom}; +use futures_util::TryStreamExt; use hyper::upgrade::OnUpgrade; use percent_encoding::{utf8_percent_encode, CONTROLS}; use reqwest::Client; @@ -244,7 +245,8 @@ where ) -> Result { let request_upgrade_type = get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned()); - let proxied_request = proxied_request.map(reqwest::Body::wrap_stream); + let proxied_request = + proxied_request.map(|s| reqwest::Body::wrap_stream(s.map_ok(|s| s.into_data().unwrap_or_default()))); let response = self .client .execute(proxied_request.try_into().map_err(Error::other)?) diff --git a/examples/body-channel/Cargo.toml b/examples/body-channel/Cargo.toml new file mode 100644 index 000000000..1bf3724ad --- /dev/null +++ b/examples/body-channel/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-body-channel" +version.workspace = true +edition.workspace = true +publish.workspace = true + +[dependencies] +salvo = { workspace = true } +tokio = { workspace = true, features = ["macros"] } +tracing.workspace = true +tracing-subscriber.workspace = true diff --git a/examples/body-channel/src/main.rs b/examples/body-channel/src/main.rs new file mode 100644 index 000000000..ce4e94648 --- /dev/null +++ b/examples/body-channel/src/main.rs @@ -0,0 +1,19 @@ +use salvo::prelude::*; + +#[handler] +async fn hello(res: &mut Response) { + res.add_header("content-type", "text/plain", true).unwrap(); + let mut tx = res.channel(); + tokio::spawn(async move { + tx.send_data("Hello world").await.unwrap(); + }); +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().init(); + + let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; + let router = Router::new().get(hello); + Server::new(acceptor).serve(router).await; +}