From 20e0442f89f9b7697451a4d6705d6f62cf639d34 Mon Sep 17 00:00:00 2001 From: Chrislearn Young Date: Mon, 30 Oct 2023 21:35:14 +0800 Subject: [PATCH] Make Tower compat body more generic --- crates/core/src/http/body/req.rs | 11 ++++++ crates/core/src/http/errors/parse_error.rs | 4 +- crates/core/src/http/request.rs | 13 +++++- crates/core/src/tower_compat.rs | 46 ++++++++++++---------- 4 files changed, 50 insertions(+), 24 deletions(-) diff --git a/crates/core/src/http/body/req.rs b/crates/core/src/http/body/req.rs index 4d6cf1f27..4d102ce33 100644 --- a/crates/core/src/http/body/req.rs +++ b/crates/core/src/http/body/req.rs @@ -129,6 +129,17 @@ impl From for ReqBody { Self::Once(value.into()) } } +impl TryFrom for Incoming { + type Error = crate::Error; + fn try_from(body: ReqBody) -> Result { + match body { + ReqBody::None => Err(crate::Error::other("ReqBody::None cannot convert to Incoming")), + ReqBody::Once(_) => Err(crate::Error::other("ReqBody::Bytes cannot convert to Incoming")), + ReqBody::Hyper(body) => Ok(body), + ReqBody::Boxed(_) => Err(crate::Error::other("ReqBody::Boxed cannot convert to Incoming")), + } + } +} impl From<&'static [u8]> for ReqBody { fn from(value: &'static [u8]) -> Self { diff --git a/crates/core/src/http/errors/parse_error.rs b/crates/core/src/http/errors/parse_error.rs index 73d9fc175..3888914cd 100644 --- a/crates/core/src/http/errors/parse_error.rs +++ b/crates/core/src/http/errors/parse_error.rs @@ -15,11 +15,11 @@ pub type ParseResult = Result; #[non_exhaustive] pub enum ParseError { /// The Hyper request did not have a valid Content-Type header. - #[error("The Hyper request did not have a valid Content-Type header.")] + #[error("The request did not have a valid Content-Type header.")] InvalidContentType, /// The Hyper request's body is empty. - #[error("The Hyper request's body is empty.")] + #[error("The request's body is empty.")] EmptyBody, /// Parse error when parse from str. diff --git a/crates/core/src/http/request.rs b/crates/core/src/http/request.rs index fee7246d2..c666f4494 100644 --- a/crates/core/src/http/request.rs +++ b/crates/core/src/http/request.rs @@ -1,6 +1,7 @@ //! Http request. use std::fmt::{self, Formatter}; +use std::error::Error as StdError; use bytes::Bytes; #[cfg(feature = "cookie")] @@ -175,7 +176,11 @@ impl Request { /// Strip the request to [`hyper::Request`]. #[doc(hidden)] #[inline] - pub fn strip_to_hyper(&mut self) -> Result, http::Error> { + pub fn strip_to_hyper(&mut self) -> Result, crate::Error> + where + QB: TryFrom, + >::Error: StdError + Send + Sync + 'static, + { let mut builder = http::request::Builder::new() .method(self.method.clone()) .uri(self.uri.clone()) @@ -186,7 +191,11 @@ impl Request { if let Some(extensions) = builder.extensions_mut() { *extensions = std::mem::take(&mut self.extensions); } - builder.body(std::mem::take(&mut self.body)) + + std::mem::take(&mut self.body) + .try_into() + .map_err(crate::Error::other) + .and_then(|body| builder.body(body).map_err(crate::Error::other)) } /// Merge data from [`hyper::Request`]. diff --git a/crates/core/src/tower_compat.rs b/crates/core/src/tower_compat.rs index 91d7bda45..f60e143c9 100644 --- a/crates/core/src/tower_compat.rs +++ b/crates/core/src/tower_compat.rs @@ -3,6 +3,7 @@ use std::error::Error as StdError; use std::fmt; use std::future::Future; use std::io::{Error as IoError, ErrorKind}; +use std::marker::PhantomData; use std::task::{Context, Poll}; use futures_util::future::{BoxFuture, FutureExt}; @@ -17,33 +18,34 @@ use crate::{async_trait, Depot, FlowCtrl, Handler, Request, Response}; /// Trait for tower service compat. pub trait TowerServiceCompat { /// Converts a tower service to a salvo handler. - fn compat(self) -> TowerServiceHandler + fn compat(self) -> TowerServiceHandler where Self: Sized, { - TowerServiceHandler(self) + TowerServiceHandler(self, PhantomData) } } impl TowerServiceCompat for T where - QB: Into + Send + Sync + 'static, + QB: From + Send + Sync + 'static, SB: Body + Send + Sync + 'static, SB::Data: Into + Send + fmt::Debug + 'static, SB::Error: StdError + Send + Sync + 'static, E: StdError + Send + Sync + 'static, - T: Service, Response = hyper::Response, Future = Fut> + Clone + Send + Sync + 'static, + T: Service, Response = hyper::Response, Future = Fut> + Clone + Send + Sync + 'static, Fut: Future, E>> + Send + 'static, { } /// Tower service compat handler. -pub struct TowerServiceHandler(Svc); +pub struct TowerServiceHandler(Svc, PhantomData); #[async_trait] -impl Handler for TowerServiceHandler +impl Handler for TowerServiceHandler where - QB: Into + Send + Sync + 'static, + QB: TryFrom + Body + Send + Sync + 'static, + >::Error: StdError + Send + Sync + 'static, SB: Body + Send + Sync + 'static, SB::Data: Into + Send + fmt::Debug + 'static, SB::Error: StdError + Send + Sync + 'static, @@ -58,7 +60,7 @@ where res.render(StatusError::internal_server_error().cause("tower service not ready.")); return; } - let hyper_req = match req.strip_to_hyper() { + let hyper_req = match req.strip_to_hyper::() { Ok(hyper_req) => hyper_req, Err(_) => { tracing::error!("strip request to hyper failed."); @@ -157,12 +159,14 @@ impl Service> for FlowCtrlService { /// Trait for tower layer compat. pub trait TowerLayerCompat { /// Converts a tower layer to a salvo handler. - fn compat(self) -> TowerLayerHandler + fn compat(self) -> TowerLayerHandler where + QB: TryFrom + Body + Send + Sync + 'static, + >::Error: StdError + Send + Sync + 'static, Self: Layer + Sized, - Self::Service: tower::Service> + Sync + Send + 'static, - >>::Future: Send, - >>::Error: StdError + Send + Sync, + Self::Service: tower::Service> + Sync + Send + 'static, + >>::Future: Send, + >>::Error: StdError + Send + Sync, { TowerLayerHandler(Buffer::new(self.layer(FlowCtrlService), 32)) } @@ -171,17 +175,19 @@ pub trait TowerLayerCompat { impl TowerLayerCompat for T where T: Layer + Send + Sync + Sized + 'static {} /// Tower service compat handler. -pub struct TowerLayerHandler>>(Buffer>); +pub struct TowerLayerHandler>, QB>(Buffer>); #[async_trait] -impl Handler for TowerLayerHandler +impl Handler for TowerLayerHandler where - B: Body + Send + Sync + 'static, - B::Data: Into + Send + fmt::Debug + 'static, - B::Error: StdError + Send + Sync + 'static, + QB: TryFrom + Body + Send + Sync + 'static, + >::Error: StdError + Send + Sync + 'static, + SB: Body + Send + Sync + 'static, + SB::Data: Into + Send + fmt::Debug + 'static, + SB::Error: StdError + Send + Sync + 'static, E: StdError + Send + Sync + 'static, - Svc: Service, Response = hyper::Response> + Send + 'static, - Svc::Future: Future, E>> + Send + 'static, + Svc: Service, Response = hyper::Response> + Send + 'static, + Svc::Future: Future, E>> + Send + 'static, Svc::Error: StdError + Send + Sync, { async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { @@ -192,7 +198,7 @@ where return; } - let mut hyper_req = match req.strip_to_hyper() { + let mut hyper_req = match req.strip_to_hyper::() { Ok(hyper_req) => hyper_req, Err(_) => { tracing::error!("strip request to hyper failed.");