From cfcbc7b26d1cf5f1ac922a447a2843ed0117f2f3 Mon Sep 17 00:00:00 2001 From: tottoto Date: Sat, 19 Oct 2024 14:31:47 +0900 Subject: [PATCH] feat(web): Relax GrpcWebService request body type --- tonic-web/src/lib.rs | 16 +++++++++++----- tonic-web/src/service.rs | 21 ++++++++++++++++----- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index aca8f5f2f..52c73f23c 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -153,23 +153,27 @@ where #[derive(Debug, Clone)] pub struct CorsGrpcWeb(tower_http::cors::Cors>); -impl Service> for CorsGrpcWeb +impl Service> for CorsGrpcWeb where S: Service, Response = http::Response>, + B: http_body::Body + Send + 'static, + B::Error: Into + std::fmt::Display, { type Response = S::Response; type Error = S::Error; - type Future = - > as Service>>::Future; + type Future = > as Service>>::Future; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.0.poll_ready(cx) + > as Service>>::poll_ready( + &mut self.0, + cx, + ) } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { self.0.call(req) } } @@ -181,6 +185,8 @@ where const NAME: &'static str = S::NAME; } +type BoxError = Box; + pub(crate) mod util { pub(crate) mod base64 { use base64::{ diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index 199dc4103..820746ac3 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -1,3 +1,4 @@ +use core::fmt; use std::future::Future; use std::pin::Pin; use std::task::{ready, Context, Poll}; @@ -61,9 +62,11 @@ where } } -impl Service> for GrpcWebService +impl Service> for GrpcWebService where S: Service, Response = Response>, + B: http_body::Body + Send + 'static, + B::Error: Into + fmt::Display, { type Response = S::Response; type Error = S::Error; @@ -73,7 +76,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { match RequestKind::new(req.headers(), req.method(), req.version()) { // A valid grpc-web request, regardless of HTTP version. // @@ -113,7 +116,7 @@ where debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE)); ResponseFuture { case: Case::Other { - future: self.inner.call(req), + future: self.inner.call(req.map(tonic::body::boxed)), }, } } @@ -194,7 +197,11 @@ impl<'a> RequestKind<'a> { // Mutating request headers to conform to a gRPC request is not really // necessary for us at this point. We could remove most of these except // maybe for inserting `header::TE`, which tonic should check? -fn coerce_request(mut req: Request, encoding: Encoding) -> Request { +fn coerce_request(mut req: Request, encoding: Encoding) -> Request +where + B: http_body::Body + Send + 'static, + B::Error: Into + fmt::Display, +{ req.headers_mut().remove(header::CONTENT_LENGTH); req.headers_mut() @@ -211,7 +218,11 @@ fn coerce_request(mut req: Request, encoding: Encoding) -> Request, encoding: Encoding) -> Response { +fn coerce_response(res: Response, encoding: Encoding) -> Response +where + B: http_body::Body + Send + 'static, + B::Error: Into + fmt::Display, +{ let mut res = res .map(|b| GrpcWebCall::response(b, encoding)) .map(BoxBody::new);