Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Tower compat body more generic #473

Merged
merged 5 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions crates/core/src/http/body/req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ impl From<String> for ReqBody {
Self::Once(value.into())
}
}
impl TryFrom<ReqBody> for Incoming {
type Error = crate::Error;
fn try_from(body: ReqBody) -> Result<Self, Self::Error> {
match body {
ReqBody::None => Err(crate::Error::other("ReqBody::None cannot convert to Incoming")),
ReqBody::Once(_) => Err(crate::Error::other("ReqBody::Bytes cannot convert to Incoming")),
ReqBody::Hyper(body) => Ok(body),
ReqBody::Boxed(_) => Err(crate::Error::other("ReqBody::Boxed cannot convert to Incoming")),
}
}
}

impl From<&'static [u8]> for ReqBody {
fn from(value: &'static [u8]) -> Self {
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/http/errors/parse_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ pub type ParseResult<T> = Result<T, ParseError>;
#[non_exhaustive]
pub enum ParseError {
/// The Hyper request did not have a valid Content-Type header.
#[error("The Hyper request did not have a valid Content-Type header.")]
#[error("The request did not have a valid Content-Type header.")]
InvalidContentType,

/// The Hyper request's body is empty.
#[error("The Hyper request's body is empty.")]
#[error("The request's body is empty.")]
EmptyBody,

/// Parse error when parse from str.
Expand Down
13 changes: 11 additions & 2 deletions crates/core/src/http/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Http request.

use std::error::Error as StdError;
use std::fmt::{self, Formatter};

use bytes::Bytes;
Expand Down Expand Up @@ -175,7 +176,11 @@ impl Request {
/// Strip the request to [`hyper::Request`].
#[doc(hidden)]
#[inline]
pub fn strip_to_hyper(&mut self) -> Result<hyper::Request<ReqBody>, http::Error> {
pub fn strip_to_hyper<QB>(&mut self) -> Result<hyper::Request<QB>, crate::Error>
where
QB: TryFrom<ReqBody>,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
{
let mut builder = http::request::Builder::new()
.method(self.method.clone())
.uri(self.uri.clone())
Expand All @@ -186,7 +191,11 @@ impl Request {
if let Some(extensions) = builder.extensions_mut() {
*extensions = std::mem::take(&mut self.extensions);
}
builder.body(std::mem::take(&mut self.body))

std::mem::take(&mut self.body)
.try_into()
.map_err(crate::Error::other)
.and_then(|body| builder.body(body).map_err(crate::Error::other))
}

/// Merge data from [`hyper::Request`].
Expand Down
66 changes: 37 additions & 29 deletions crates/core/src/tower_compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use std::io::{Error as IoError, ErrorKind};
use std::marker::PhantomData;
use std::task::{Context, Poll};

use futures_util::future::{BoxFuture, FutureExt};
Expand All @@ -15,39 +16,42 @@ use crate::http::{ReqBody, ResBody, StatusError};
use crate::{async_trait, Depot, FlowCtrl, Handler, Request, Response};

/// Trait for tower service compat.
pub trait TowerServiceCompat<B, E, Fut> {
pub trait TowerServiceCompat<QB, SB, E, Fut> {
/// Converts a tower service to a salvo handler.
fn compat(self) -> TowerServiceHandler<Self>
fn compat(self) -> TowerServiceHandler<Self, QB>
where
Self: Sized,
{
TowerServiceHandler(self)
TowerServiceHandler(self, PhantomData)
}
}

impl<T, B, E, Fut> TowerServiceCompat<B, E, Fut> for T
impl<T, QB, SB, E, Fut> TowerServiceCompat<QB, SB, E, Fut> for T
where
B: Body + Send + Sync + 'static,
B::Data: Into<Bytes> + Send + fmt::Debug + 'static,
B::Error: StdError + Send + Sync + 'static,
QB: From<ReqBody> + Send + Sync + 'static,
SB: Body + Send + Sync + 'static,
SB::Data: Into<Bytes> + Send + fmt::Debug + 'static,
SB::Error: StdError + Send + Sync + 'static,
E: StdError + Send + Sync + 'static,
T: Service<hyper::Request<ReqBody>, Response = hyper::Response<B>, Future = Fut> + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<hyper::Response<B>, E>> + Send + 'static,
T: Service<hyper::Request<ReqBody>, Response = hyper::Response<SB>, Future = Fut> + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<hyper::Response<SB>, E>> + Send + 'static,
{
}

/// Tower service compat handler.
pub struct TowerServiceHandler<Svc>(Svc);
pub struct TowerServiceHandler<Svc, QB>(Svc, PhantomData<QB>);

#[async_trait]
impl<Svc, B, E, Fut> Handler for TowerServiceHandler<Svc>
impl<Svc, QB, SB, E, Fut> Handler for TowerServiceHandler<Svc, QB>
where
B: Body + Send + Sync + 'static,
B::Data: Into<Bytes> + Send + fmt::Debug + 'static,
B::Error: StdError + Send + Sync + 'static,
QB: TryFrom<ReqBody> + Body + Send + Sync + 'static,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
SB: Body + Send + Sync + 'static,
SB::Data: Into<Bytes> + Send + fmt::Debug + 'static,
SB::Error: StdError + Send + Sync + 'static,
E: StdError + Send + Sync + 'static,
Svc: Service<hyper::Request<ReqBody>, Response = hyper::Response<B>, Future = Fut> + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<hyper::Response<B>, E>> + Send + 'static,
Svc: Service<hyper::Request<QB>, Response = hyper::Response<SB>, Future = Fut> + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<hyper::Response<SB>, E>> + Send + 'static,
{
async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) {
let mut svc = self.0.clone();
Expand All @@ -56,7 +60,7 @@ where
res.render(StatusError::internal_server_error().cause("tower service not ready."));
return;
}
let hyper_req = match req.strip_to_hyper() {
let hyper_req = match req.strip_to_hyper::<QB>() {
Ok(hyper_req) => hyper_req,
Err(_) => {
tracing::error!("strip request to hyper failed.");
Expand Down Expand Up @@ -155,12 +159,14 @@ impl Service<hyper::Request<ReqBody>> for FlowCtrlService {
/// Trait for tower layer compat.
pub trait TowerLayerCompat {
/// Converts a tower layer to a salvo handler.
fn compat(self) -> TowerLayerHandler<Self::Service>
fn compat<QB>(self) -> TowerLayerHandler<Self::Service, QB>
where
QB: TryFrom<ReqBody> + Body + Send + Sync + 'static,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
Self: Layer<FlowCtrlService> + Sized,
Self::Service: tower::Service<hyper::Request<ReqBody>> + Sync + Send + 'static,
<Self::Service as Service<hyper::Request<ReqBody>>>::Future: Send,
<Self::Service as Service<hyper::Request<ReqBody>>>::Error: StdError + Send + Sync,
Self::Service: tower::Service<hyper::Request<QB>> + Sync + Send + 'static,
<Self::Service as Service<hyper::Request<QB>>>::Future: Send,
<Self::Service as Service<hyper::Request<QB>>>::Error: StdError + Send + Sync,
{
TowerLayerHandler(Buffer::new(self.layer(FlowCtrlService), 32))
}
Expand All @@ -169,17 +175,19 @@ pub trait TowerLayerCompat {
impl<T> TowerLayerCompat for T where T: Layer<FlowCtrlService> + Send + Sync + Sized + 'static {}

/// Tower service compat handler.
pub struct TowerLayerHandler<Svc: Service<hyper::Request<ReqBody>>>(Buffer<Svc, hyper::Request<ReqBody>>);
pub struct TowerLayerHandler<Svc: Service<hyper::Request<QB>>, QB>(Buffer<Svc, hyper::Request<QB>>);

#[async_trait]
impl<Svc, B, E> Handler for TowerLayerHandler<Svc>
impl<Svc, QB, SB, E> Handler for TowerLayerHandler<Svc, QB>
where
B: Body + Send + Sync + 'static,
B::Data: Into<Bytes> + Send + fmt::Debug + 'static,
B::Error: StdError + Send + Sync + 'static,
QB: TryFrom<ReqBody> + Body + Send + Sync + 'static,
<QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
SB: Body + Send + Sync + 'static,
SB::Data: Into<Bytes> + Send + fmt::Debug + 'static,
SB::Error: StdError + Send + Sync + 'static,
E: StdError + Send + Sync + 'static,
Svc: Service<hyper::Request<ReqBody>, Response = hyper::Response<B>> + Send + 'static,
Svc::Future: Future<Output = Result<hyper::Response<B>, E>> + Send + 'static,
Svc: Service<hyper::Request<QB>, Response = hyper::Response<SB>> + Send + 'static,
Svc::Future: Future<Output = Result<hyper::Response<SB>, E>> + Send + 'static,
Svc::Error: StdError + Send + Sync,
{
async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
Expand All @@ -190,7 +198,7 @@ where
return;
}

let mut hyper_req = match req.strip_to_hyper() {
let mut hyper_req = match req.strip_to_hyper::<QB>() {
Ok(hyper_req) => hyper_req,
Err(_) => {
tracing::error!("strip request to hyper failed.");
Expand Down