Skip to content

Commit

Permalink
Refactor ResBody and add channel support (#445)
Browse files Browse the repository at this point in the history
* Refactor `ResBody` and add channel support

* Format Rust code using rustfmt

* wip

* Format Rust code using rustfmt

* fix ci

* fix ci

---------

Co-authored-by: Christopher Young <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 7, 2023
1 parent 4469241 commit e531d85
Show file tree
Hide file tree
Showing 19 changed files with 306 additions and 148 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ etag = "4"
eyre = "0.6"
fastrand = "2"
form_urlencoded = "1"
futures-channel = "0.3"
futures-util = { version = "0.3", default-features = false }
headers = "0.3"
http = "0.2"
Expand Down
8 changes: 4 additions & 4 deletions crates/compression/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ impl Handler for Compression {
}
match self.negotiate(req, res) {
Some((algo, level)) => {
res.stream(EncodeStream::new(algo, level, Some(bytes))).ok();
res.stream(EncodeStream::new(algo, level, Some(bytes)));
res.headers_mut().append(CONTENT_ENCODING, algo.into());
}
None => {
Expand All @@ -352,7 +352,7 @@ impl Handler for Compression {
}
match self.negotiate(req, res) {
Some((algo, level)) => {
res.stream(EncodeStream::new(algo, level, chunks)).ok();
res.stream(EncodeStream::new(algo, level, chunks));
res.headers_mut().append(CONTENT_ENCODING, algo.into());
}
None => {
Expand All @@ -363,7 +363,7 @@ impl Handler for Compression {
}
ResBody::Hyper(body) => match self.negotiate(req, res) {
Some((algo, level)) => {
res.stream(EncodeStream::new(algo, level, body)).ok();
res.stream(EncodeStream::new(algo, level, body));
res.headers_mut().append(CONTENT_ENCODING, algo.into());
}
None => {
Expand All @@ -375,7 +375,7 @@ impl Handler for Compression {
let body = body.into_inner();
match self.negotiate(req, res) {
Some((algo, level)) => {
res.stream(EncodeStream::new(algo, level, body)).ok();
res.stream(EncodeStream::new(algo, level, body));
res.headers_mut().append(CONTENT_ENCODING, algo.into());
}
None => {
Expand Down
1 change: 1 addition & 0 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ encoding_rs = { workspace = true, optional = true }
enumflags2 = { workspace = true }
eyre = { workspace = true, optional = true }
form_urlencoded = { workspace = true }
futures-channel = { workspace = true }
futures-util = { workspace = true, features = ["io"] }
headers = { workspace = true }
http = { workspace = true }
Expand Down
26 changes: 19 additions & 7 deletions crates/core/src/conn/quinn/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Builder {
&self,
conn: crate::conn::quinn::H3Connection,
hyper_handler: crate::service::HyperHandler,
_server_shutdown_token: CancellationToken, //TODO
_server_shutdown_token: CancellationToken, //TODO
_idle_connection_timeout: Option<Duration>, //TODO
) -> IoResult<()> {
let mut conn = self
Expand Down Expand Up @@ -170,9 +170,15 @@ async fn process_web_transport(
let mut body = Pin::new(&mut body);
while let Some(result) = poll_fn(|cx| body.as_mut().poll_next(cx)).await {
match result {
Ok(bytes) => {
if let Err(e) = stream.send_data(bytes).await {
tracing::error!(error = ?e, "unable to send data to connection peer");
Ok(frame) => {
if frame.is_data() {
if let Err(e) = stream.send_data(frame.into_data().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send data to connection peer");
}
} else {
if let Err(e) = stream.send_trailers(frame.into_trailers().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send trailers to connection peer");
}
}
}
Err(e) => {
Expand Down Expand Up @@ -220,9 +226,15 @@ where
let mut body = Pin::new(&mut body);
while let Some(result) = poll_fn(|cx| body.as_mut().poll_next(cx)).await {
match result {
Ok(bytes) => {
if let Err(e) = tx.send_data(bytes).await {
tracing::error!(error = ?e, "unable to send data to connection peer");
Ok(frame) => {
if frame.is_data() {
if let Err(e) = tx.send_data(frame.into_data().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send data to connection peer");
}
} else {
if let Err(e) = tx.send_trailers(frame.into_trailers().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send trailers to connection peer");
}
}
}
Err(e) => {
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/fs/named_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ impl NamedFile {
buffer_size: self.buffer_size,
};
res.headers_mut().typed_insert(ContentLength(reader.total_size));
res.stream(reader).ok();
res.stream(reader);
} else {
res.status_code(StatusCode::OK);
let reader = ChunkedFile {
Expand All @@ -521,7 +521,7 @@ impl NamedFile {
buffer_size: self.buffer_size,
};
res.headers_mut().typed_insert(ContentLength(length));
res.stream(reader).ok();
res.stream(reader);
}
}
}
Expand Down
80 changes: 80 additions & 0 deletions crates/core/src/http/body/channel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::fmt;
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::task::{Context, Poll};

use bytes::Bytes;
use futures_channel::mpsc;
use futures_channel::oneshot;
use hyper::HeaderMap;

/// A sender half created through [`ResBody::channel()`].
///
/// Useful when wanting to stream chunks from another thread.
///
/// ## Body Closing
///
/// Note that the request body will always be closed normally when the sender is dropped (meaning
/// that the empty terminating chunk will be sent to the remote). If you desire to close the
/// connection with an incomplete response (e.g. in the case of an error during asynchronous
/// processing), call the [`Sender::abort()`] method to abort the body in an abnormal fashion.
///
/// [`Body::channel()`]: struct.Body.html#method.channel
/// [`Sender::abort()`]: struct.Sender.html#method.abort
#[must_use = "Sender does nothing unless sent on"]
pub struct BodySender {
pub(crate) data_tx: mpsc::Sender<Result<Bytes, IoError>>,
pub(crate) trailers_tx: Option<oneshot::Sender<HeaderMap>>,
}
impl BodySender {
/// Check to see if this `Sender` can send more data.
pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
self.data_tx
.poll_ready(cx)
.map_err(|e| IoError::new(ErrorKind::Other, format!("failed to poll ready: {}", e)))
}

async fn ready(&mut self) -> IoResult<()> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await
}

/// Send data on data channel when it is ready.
pub async fn send_data(&mut self, chunk: impl Into<Bytes>) -> IoResult<()> {
self.ready().await?;
self.data_tx
.try_send(Ok(chunk.into()))
.map_err(|e| IoError::new(ErrorKind::Other, format!("failed to send data: {}", e)))
}

/// Send trailers on trailers channel.
pub async fn send_trailers(&mut self, trailers: HeaderMap) -> IoResult<()> {
let tx = match self.trailers_tx.take() {
Some(tx) => tx,
None => return Err(IoError::new(ErrorKind::Other, "failed to send railers")),
};
tx.send(trailers)
.map_err(|_| IoError::new(ErrorKind::Other, "failed to send railers"))
}

/// Send error on data channel.
pub fn send_error(&mut self, err: IoError) {
let _ = self
.data_tx
// clone so the send works even if buffer is full
.clone()
.try_send(Err(err));
}
}

impl fmt::Debug for BodySender {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut builder = f.debug_tuple("BodySender");

builder.finish()
}
}

/// A receiver for [`ResBody`]
pub struct BodyReceiver {
pub(crate) data_rx: mpsc::Receiver<Result<Bytes, IoError>>,
pub(crate) trailers_rx: oneshot::Receiver<HeaderMap>,
}
4 changes: 3 additions & 1 deletion crates/core/src/http/body/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Http body.
pub use hyper::body::{Body, SizeHint};
pub use hyper::body::{Body, Frame, SizeHint};

mod req;
#[cfg(feature = "quinn")]
Expand All @@ -8,3 +8,5 @@ pub use req::ReqBody;
mod res;
pub use hyper::body::Incoming as HyperBody;
pub use res::ResBody;
mod channel;
pub use channel::{BodyReceiver, BodySender};
42 changes: 21 additions & 21 deletions crates/core/src/http/body/req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,48 +64,48 @@ impl Body for ReqBody {
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
match &mut *self {
ReqBody::None => Poll::Ready(None),
ReqBody::Once(bytes) => {
Self::None => Poll::Ready(None),
Self::Once(bytes) => {
if bytes.is_empty() {
Poll::Ready(None)
} else {
let bytes = std::mem::replace(bytes, Bytes::new());
Poll::Ready(Some(Ok(Frame::data(bytes))))
}
}
ReqBody::Hyper(body) => Pin::new(body)
Self::Hyper(body) => Pin::new(body)
.poll_frame(cx)
.map_err(|e| IoError::new(ErrorKind::Other, e)),
ReqBody::Boxed(inner) => Pin::new(inner)
Self::Boxed(inner) => Pin::new(inner)
.poll_frame(cx)
.map_err(|e| IoError::new(ErrorKind::Other, e)),
}
}

fn is_end_stream(&self) -> bool {
match self {
ReqBody::None => true,
ReqBody::Once(bytes) => bytes.is_empty(),
ReqBody::Hyper(body) => body.is_end_stream(),
ReqBody::Boxed(body) => body.is_end_stream(),
Self::None => true,
Self::Once(bytes) => bytes.is_empty(),
Self::Hyper(body) => body.is_end_stream(),
Self::Boxed(body) => body.is_end_stream(),
}
}

fn size_hint(&self) -> SizeHint {
match self {
ReqBody::None => SizeHint::with_exact(0),
ReqBody::Once(bytes) => SizeHint::with_exact(bytes.len() as u64),
ReqBody::Hyper(body) => body.size_hint(),
ReqBody::Boxed(body) => body.size_hint(),
Self::None => SizeHint::with_exact(0),
Self::Once(bytes) => SizeHint::with_exact(bytes.len() as u64),
Self::Hyper(body) => body.size_hint(),
Self::Boxed(body) => body.size_hint(),
}
}
}
impl Stream for ReqBody {
type Item = IoResult<Bytes>;
type Item = IoResult<Frame<Bytes>>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Body::poll_frame(self, cx) {
Poll::Ready(Some(Ok(frame))) => Poll::Ready(frame.into_data().map(Ok).ok()),
Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(IoError::new(ErrorKind::Other, e)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
Expand All @@ -115,42 +115,42 @@ impl Stream for ReqBody {

impl From<Bytes> for ReqBody {
fn from(value: Bytes) -> ReqBody {
ReqBody::Once(value)
Self::Once(value)
}
}
impl From<Incoming> for ReqBody {
fn from(value: Incoming) -> ReqBody {
ReqBody::Hyper(value)
Self::Hyper(value)
}
}
impl From<String> for ReqBody {
#[inline]
fn from(value: String) -> ReqBody {
ReqBody::Once(value.into())
Self::Once(value.into())
}
}

impl From<&'static [u8]> for ReqBody {
fn from(value: &'static [u8]) -> ReqBody {
ReqBody::Once(value.into())
Self::Once(value.into())
}
}

impl From<&'static str> for ReqBody {
fn from(value: &'static str) -> ReqBody {
ReqBody::Once(value.into())
Self::Once(value.into())
}
}

impl From<Vec<u8>> for ReqBody {
fn from(value: Vec<u8>) -> ReqBody {
ReqBody::Once(value.into())
Self::Once(value.into())
}
}

impl From<Box<[u8]>> for ReqBody {
fn from(value: Box<[u8]>) -> ReqBody {
ReqBody::Once(value.into())
Self::Once(value.into())
}
}

Expand Down
Loading

0 comments on commit e531d85

Please sign in to comment.