From 14612082bb61217593bbd9ffe6ad6a6a7b3fbc3a Mon Sep 17 00:00:00 2001 From: tottoto Date: Sat, 28 Sep 2024 08:15:51 +0900 Subject: [PATCH] feat(web): Add CorsGrpcWebResponseFuture to make tower-http internal dependency --- tonic-web/Cargo.toml | 1 - tonic-web/src/lib.rs | 42 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index 85e5ccb9c..c8cf054d4 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -43,7 +43,6 @@ allowed_external_types = [ # not major released "futures_core::stream::Stream", "http_body_util::combinators::box_body::UnsyncBoxBody", - "tower_http::cors::Cors", "tower_layer::Layer", "tower_service::Service", ] diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index d768a7005..2648e06a5 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -108,7 +108,14 @@ mod layer; mod service; use http::header::HeaderName; -use std::time::Duration; +use pin_project::pin_project; +use std::{ + fmt, + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; use tonic::{body::BoxBody, server::NamedService, Status}; use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_layer::Layer; @@ -155,18 +162,37 @@ where { type Response = S::Response; type Error = S::Error; - type Future = - > as Service>>::Future; + type Future = CorsGrpcWebResponseFuture; - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } fn call(&mut self, req: http::Request) -> Self::Future { - self.0.call(req) + CorsGrpcWebResponseFuture(self.0.call(req)) + } +} + +/// Response Future for the [`CorsGrpcWeb`]. +#[pin_project] +pub struct CorsGrpcWebResponseFuture( + #[pin] tower_http::cors::ResponseFuture>, +); + +impl Future for CorsGrpcWebResponseFuture +where + F: Future, E>>, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().0.poll(cx) + } +} + +impl fmt::Debug for CorsGrpcWebResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("CorsGrpcWebResponseFuture").finish() } }