diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index 4092ca6c6..4e53061a7 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -116,9 +116,8 @@ impl Service> for InterceptedServ where S: Service, Response = http::Response>, I: Interceptor, - ResBody: Default, { - type Response = http::Response; + type Response = http::Response>; type Error = S::Error; type Future = ResponseFuture; @@ -194,21 +193,79 @@ enum Kind { impl Future for ResponseFuture where F: Future, E>>, - B: Default, { - type Output = Result, E>; + type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { - KindProj::Future(future) => future.poll(cx), + KindProj::Future(future) => future.poll(cx).map_ok(|res| res.map(ResponseBody::wrap)), KindProj::Status(status) => { - let response = status.take().unwrap().into_http(); + let (parts, ()) = status.take().unwrap().into_http::<()>().into_parts(); + let response = http::Response::from_parts(parts, ResponseBody::::empty()); Poll::Ready(Ok(response)) } } } } +/// Response body for [`InterceptedService`]. +#[pin_project] +#[derive(Debug)] +pub struct ResponseBody { + #[pin] + kind: ResponseBodyKind, +} + +#[pin_project(project = ResponseBodyKindProj)] +#[derive(Debug)] +enum ResponseBodyKind { + Empty, + Wrap(#[pin] B), +} + +impl ResponseBody { + fn new(kind: ResponseBodyKind) -> Self { + Self { kind } + } + + fn empty() -> Self { + Self::new(ResponseBodyKind::Empty) + } + + fn wrap(body: B) -> Self { + Self::new(ResponseBodyKind::Wrap(body)) + } +} + +impl http_body::Body for ResponseBody { + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match self.project().kind.project() { + ResponseBodyKindProj::Empty => Poll::Ready(None), + ResponseBodyKindProj::Wrap(body) => body.poll_frame(cx), + } + } + + fn size_hint(&self) -> http_body::SizeHint { + match &self.kind { + ResponseBodyKind::Empty => http_body::SizeHint::with_exact(0), + ResponseBodyKind::Wrap(body) => body.size_hint(), + } + } + + fn is_end_stream(&self) -> bool { + match &self.kind { + ResponseBodyKind::Empty => true, + ResponseBodyKind::Wrap(body) => body.is_end_stream(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 77bf00710..13958c978 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -395,11 +395,8 @@ impl Server { /// route around different services. pub fn add_service(&mut self, svc: S) -> Router where - S: Service, Response = Response, Error = Infallible> - + NamedService - + Clone - + Send - + 'static, + S: Service, Error = Infallible> + NamedService + Clone + Send + 'static, + S::Response: axum::response::IntoResponse, S::Future: Send + 'static, L: Clone, { @@ -416,11 +413,8 @@ impl Server { /// As a result, one cannot use this to toggle between two identically named implementations. pub fn add_optional_service(&mut self, svc: Option) -> Router where - S: Service, Response = Response, Error = Infallible> - + NamedService - + Clone - + Send - + 'static, + S: Service, Error = Infallible> + NamedService + Clone + Send + 'static, + S::Response: axum::response::IntoResponse, S::Future: Send + 'static, L: Clone, { @@ -732,11 +726,8 @@ impl Router { /// Add a new service to this router. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> - + NamedService - + Clone - + Send - + 'static, + S: Service, Error = Infallible> + NamedService + Clone + Send + 'static, + S::Response: axum::response::IntoResponse, S::Future: Send + 'static, { self.routes = self.routes.add_service(svc); @@ -750,11 +741,8 @@ impl Router { /// As a result, one cannot use this to toggle between two identically named implementations. pub fn add_optional_service(mut self, svc: Option) -> Self where - S: Service, Response = Response, Error = Infallible> - + NamedService - + Clone - + Send - + 'static, + S: Service, Error = Infallible> + NamedService + Clone + Send + 'static, + S::Response: axum::response::IntoResponse, S::Future: Send + 'static, { if let Some(svc) = svc {