Skip to content

Commit

Permalink
Make Tower compat body more generic (#473)
Browse files Browse the repository at this point in the history
* wip

* Make Tower compat body more generic

* Format Rust code using rustfmt

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
chrislearn and github-actions[bot] authored Oct 31, 2023
1 parent 65d9fbb commit 05c907e
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 33 deletions.
11 changes: 11 additions & 0 deletions crates/core/src/http/body/req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ impl From<String> for ReqBody {
Self::Once(value.into())
}
}
impl TryFrom<ReqBody> for Incoming {
type Error = crate::Error;
fn try_from(body: ReqBody) -> Result<Self, Self::Error> {
match body {
ReqBody::None => Err(crate::Error::other("ReqBody::None cannot convert to Incoming")),
ReqBody::Once(_) => Err(crate::Error::other("ReqBody::Bytes cannot convert to Incoming")),
ReqBody::Hyper(body) => Ok(body),
ReqBody::Boxed(_) => Err(crate::Error::other("ReqBody::Boxed cannot convert to Incoming")),
}
}
}

impl From<&'static [u8]> for ReqBody {
fn from(value: &'static [u8]) -> Self {
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/http/errors/parse_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ pub type ParseResult<T> = Result<T, ParseError>;
#[non_exhaustive]
pub enum ParseError {
/// The Hyper request did not have a valid Content-Type header.
#[error("The Hyper request did not have a valid Content-Type header.")]
#[error("The request did not have a valid Content-Type header.")]
InvalidContentType,

/// The Hyper request's body is empty.
#[error("The Hyper request's body is empty.")]
#[error("The request's body is empty.")]
EmptyBody,

/// Parse error when parse from str.
Expand Down
13 changes: 11 additions & 2 deletions crates/core/src/http/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Http request.
use std::error::Error as StdError;
use std::fmt::{self, Formatter};

use bytes::Bytes;
Expand Down Expand Up @@ -175,7 +176,11 @@ impl Request {
/// Strip the request to [`hyper::Request`].
#[doc(hidden)]
#[inline]
pub fn strip_to_hyper(&mut self) -> Result<hyper::Request<ReqBody>, http::Error> {
pub fn strip_to_hyper<QB>(&mut self) -> Result<hyper::Request<QB>, crate::Error>
where
QB: TryFrom<ReqBody>,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
{
let mut builder = http::request::Builder::new()
.method(self.method.clone())
.uri(self.uri.clone())
Expand All @@ -186,7 +191,11 @@ impl Request {
if let Some(extensions) = builder.extensions_mut() {
*extensions = std::mem::take(&mut self.extensions);
}
builder.body(std::mem::take(&mut self.body))

std::mem::take(&mut self.body)
.try_into()
.map_err(crate::Error::other)
.and_then(|body| builder.body(body).map_err(crate::Error::other))
}

/// Merge data from [`hyper::Request`].
Expand Down
66 changes: 37 additions & 29 deletions crates/core/src/tower_compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use std::io::{Error as IoError, ErrorKind};
use std::marker::PhantomData;
use std::task::{Context, Poll};

use futures_util::future::{BoxFuture, FutureExt};
Expand All @@ -15,39 +16,42 @@ use crate::http::{ReqBody, ResBody, StatusError};
use crate::{async_trait, Depot, FlowCtrl, Handler, Request, Response};

/// Trait for tower service compat.
pub trait TowerServiceCompat<B, E, Fut> {
pub trait TowerServiceCompat<QB, SB, E, Fut> {
/// Converts a tower service to a salvo handler.
fn compat(self) -> TowerServiceHandler<Self>
fn compat(self) -> TowerServiceHandler<Self, QB>
where
Self: Sized,
{
TowerServiceHandler(self)
TowerServiceHandler(self, PhantomData)
}
}

impl<T, B, E, Fut> TowerServiceCompat<B, E, Fut> for T
impl<T, QB, SB, E, Fut> TowerServiceCompat<QB, SB, E, Fut> for T
where
B: Body + Send + Sync + 'static,
B::Data: Into<Bytes> + Send + fmt::Debug + 'static,
B::Error: StdError + Send + Sync + 'static,
QB: From<ReqBody> + Send + Sync + 'static,
SB: Body + Send + Sync + 'static,
SB::Data: Into<Bytes> + Send + fmt::Debug + 'static,
SB::Error: StdError + Send + Sync + 'static,
E: StdError + Send + Sync + 'static,
T: Service<hyper::Request<ReqBody>, Response = hyper::Response<B>, Future = Fut> + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<hyper::Response<B>, E>> + Send + 'static,
T: Service<hyper::Request<ReqBody>, Response = hyper::Response<SB>, Future = Fut> + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<hyper::Response<SB>, E>> + Send + 'static,
{
}

/// Tower service compat handler.
pub struct TowerServiceHandler<Svc>(Svc);
pub struct TowerServiceHandler<Svc, QB>(Svc, PhantomData<QB>);

#[async_trait]
impl<Svc, B, E, Fut> Handler for TowerServiceHandler<Svc>
impl<Svc, QB, SB, E, Fut> Handler for TowerServiceHandler<Svc, QB>
where
B: Body + Send + Sync + 'static,
B::Data: Into<Bytes> + Send + fmt::Debug + 'static,
B::Error: StdError + Send + Sync + 'static,
QB: TryFrom<ReqBody> + Body + Send + Sync + 'static,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
SB: Body + Send + Sync + 'static,
SB::Data: Into<Bytes> + Send + fmt::Debug + 'static,
SB::Error: StdError + Send + Sync + 'static,
E: StdError + Send + Sync + 'static,
Svc: Service<hyper::Request<ReqBody>, Response = hyper::Response<B>, Future = Fut> + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<hyper::Response<B>, E>> + Send + 'static,
Svc: Service<hyper::Request<QB>, Response = hyper::Response<SB>, Future = Fut> + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<hyper::Response<SB>, E>> + Send + 'static,
{
async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) {
let mut svc = self.0.clone();
Expand All @@ -56,7 +60,7 @@ where
res.render(StatusError::internal_server_error().cause("tower service not ready."));
return;
}
let hyper_req = match req.strip_to_hyper() {
let hyper_req = match req.strip_to_hyper::<QB>() {
Ok(hyper_req) => hyper_req,
Err(_) => {
tracing::error!("strip request to hyper failed.");
Expand Down Expand Up @@ -155,12 +159,14 @@ impl Service<hyper::Request<ReqBody>> for FlowCtrlService {
/// Trait for tower layer compat.
pub trait TowerLayerCompat {
/// Converts a tower layer to a salvo handler.
fn compat(self) -> TowerLayerHandler<Self::Service>
fn compat<QB>(self) -> TowerLayerHandler<Self::Service, QB>
where
QB: TryFrom<ReqBody> + Body + Send + Sync + 'static,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
Self: Layer<FlowCtrlService> + Sized,
Self::Service: tower::Service<hyper::Request<ReqBody>> + Sync + Send + 'static,
<Self::Service as Service<hyper::Request<ReqBody>>>::Future: Send,
<Self::Service as Service<hyper::Request<ReqBody>>>::Error: StdError + Send + Sync,
Self::Service: tower::Service<hyper::Request<QB>> + Sync + Send + 'static,
<Self::Service as Service<hyper::Request<QB>>>::Future: Send,
<Self::Service as Service<hyper::Request<QB>>>::Error: StdError + Send + Sync,
{
TowerLayerHandler(Buffer::new(self.layer(FlowCtrlService), 32))
}
Expand All @@ -169,17 +175,19 @@ pub trait TowerLayerCompat {
impl<T> TowerLayerCompat for T where T: Layer<FlowCtrlService> + Send + Sync + Sized + 'static {}

/// Tower service compat handler.
pub struct TowerLayerHandler<Svc: Service<hyper::Request<ReqBody>>>(Buffer<Svc, hyper::Request<ReqBody>>);
pub struct TowerLayerHandler<Svc: Service<hyper::Request<QB>>, QB>(Buffer<Svc, hyper::Request<QB>>);

#[async_trait]
impl<Svc, B, E> Handler for TowerLayerHandler<Svc>
impl<Svc, QB, SB, E> Handler for TowerLayerHandler<Svc, QB>
where
B: Body + Send + Sync + 'static,
B::Data: Into<Bytes> + Send + fmt::Debug + 'static,
B::Error: StdError + Send + Sync + 'static,
QB: TryFrom<ReqBody> + Body + Send + Sync + 'static,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
SB: Body + Send + Sync + 'static,
SB::Data: Into<Bytes> + Send + fmt::Debug + 'static,
SB::Error: StdError + Send + Sync + 'static,
E: StdError + Send + Sync + 'static,
Svc: Service<hyper::Request<ReqBody>, Response = hyper::Response<B>> + Send + 'static,
Svc::Future: Future<Output = Result<hyper::Response<B>, E>> + Send + 'static,
Svc: Service<hyper::Request<QB>, Response = hyper::Response<SB>> + Send + 'static,
Svc::Future: Future<Output = Result<hyper::Response<SB>, E>> + Send + 'static,
Svc::Error: StdError + Send + Sync,
{
async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
Expand All @@ -190,7 +198,7 @@ where
return;
}

let mut hyper_req = match req.strip_to_hyper() {
let mut hyper_req = match req.strip_to_hyper::<QB>() {
Ok(hyper_req) => hyper_req,
Err(_) => {
tracing::error!("strip request to hyper failed.");
Expand Down

0 comments on commit 05c907e

Please sign in to comment.