From 1fe8cb0d9749a8b45aae65743a78fb90465a7b5b Mon Sep 17 00:00:00 2001 From: tottoto Date: Fri, 9 Feb 2024 07:20:49 +0900 Subject: [PATCH] refactor(transport): Move channel feature to channel module --- tonic/src/transport/channel/endpoint.rs | 15 +-- tonic/src/transport/channel/mod.rs | 3 +- .../{ => channel}/service/add_origin.rs | 0 .../{ => channel}/service/connection.rs | 3 +- .../{ => channel}/service/connector.rs | 9 +- .../{ => channel}/service/discover.rs | 0 .../{ => channel}/service/executor.rs | 0 tonic/src/transport/channel/service/io.rs | 69 ++++++++++++++ tonic/src/transport/channel/service/mod.rs | 26 ++++++ .../{ => channel}/service/reconnect.rs | 0 tonic/src/transport/channel/service/tls.rs | 92 +++++++++++++++++++ .../{ => channel}/service/user_agent.rs | 0 tonic/src/transport/channel/tls.rs | 2 +- tonic/src/transport/mod.rs | 2 +- tonic/src/transport/service/io.rs | 81 ---------------- tonic/src/transport/service/mod.rs | 25 +---- tonic/src/transport/service/tls.rs | 88 +----------------- 17 files changed, 213 insertions(+), 202 deletions(-) rename tonic/src/transport/{ => channel}/service/add_origin.rs (100%) rename tonic/src/transport/{ => channel}/service/connection.rs (97%) rename tonic/src/transport/{ => channel}/service/connector.rs (98%) rename tonic/src/transport/{ => channel}/service/discover.rs (100%) rename tonic/src/transport/{ => channel}/service/executor.rs (100%) create mode 100644 tonic/src/transport/channel/service/io.rs create mode 100644 tonic/src/transport/channel/service/mod.rs rename tonic/src/transport/{ => channel}/service/reconnect.rs (100%) create mode 100644 tonic/src/transport/channel/service/tls.rs rename tonic/src/transport/{ => channel}/service/user_agent.rs (100%) diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 598e89b70..a4358d27b 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -1,10 +1,11 @@ -use super::super::service; +use super::service::Connector; +#[cfg(feature = "tls")] +use super::service::TlsConnector; +use super::service::{Executor, SharedExec}; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; -#[cfg(feature = "tls")] -use crate::transport::service::TlsConnector; -use crate::transport::{service::SharedExec, Error, Executor}; +use crate::transport::Error; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; @@ -301,12 +302,12 @@ impl Endpoint { self } - pub(crate) fn connector(&self, c: C) -> service::Connector { + pub(crate) fn connector(&self, c: C) -> Connector { #[cfg(feature = "tls")] - let connector = service::Connector::new(c, self.tls.clone()); + let connector = Connector::new(c, self.tls.clone()); #[cfg(not(feature = "tls"))] - let connector = service::Connector::new(c); + let connector = Connector::new(c); connector } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b510a6980..8a9f1b8c4 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -1,6 +1,7 @@ //! Client implementation and builder. mod endpoint; +pub(crate) mod service; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; @@ -9,7 +10,7 @@ pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; -use super::service::{Connection, DynamicServiceStream, SharedExec}; +use self::service::{Connection, DynamicServiceStream, SharedExec}; use crate::body::BoxBody; use crate::transport::Executor; use bytes::Bytes; diff --git a/tonic/src/transport/service/add_origin.rs b/tonic/src/transport/channel/service/add_origin.rs similarity index 100% rename from tonic/src/transport/service/add_origin.rs rename to tonic/src/transport/channel/service/add_origin.rs diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/channel/service/connection.rs similarity index 97% rename from tonic/src/transport/service/connection.rs rename to tonic/src/transport/channel/service/connection.rs index 46a88dda5..ece496489 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/channel/service/connection.rs @@ -1,4 +1,5 @@ -use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; +use super::{reconnect::Reconnect, AddOrigin, UserAgent}; +use crate::transport::service::GrpcTimeout; use crate::{ body::BoxBody, transport::{BoxFuture, Endpoint}, diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/channel/service/connector.rs similarity index 98% rename from tonic/src/transport/service/connector.rs rename to tonic/src/transport/channel/service/connector.rs index 6645696ea..df6aba3cc 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/channel/service/connector.rs @@ -1,13 +1,14 @@ -use super::super::BoxFuture; -use super::io::BoxedIo; -#[cfg(feature = "tls")] -use super::tls::TlsConnector; use http::Uri; use std::fmt; use std::task::{Context, Poll}; use tower::make::MakeConnection; use tower_service::Service; +use super::io::BoxedIo; +#[cfg(feature = "tls")] +use super::tls::TlsConnector; +use crate::transport::BoxFuture; + pub(crate) struct Connector { inner: C, #[cfg(feature = "tls")] diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/channel/service/discover.rs similarity index 100% rename from tonic/src/transport/service/discover.rs rename to tonic/src/transport/channel/service/discover.rs diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/channel/service/executor.rs similarity index 100% rename from tonic/src/transport/service/executor.rs rename to tonic/src/transport/channel/service/executor.rs diff --git a/tonic/src/transport/channel/service/io.rs b/tonic/src/transport/channel/service/io.rs new file mode 100644 index 000000000..f419388a6 --- /dev/null +++ b/tonic/src/transport/channel/service/io.rs @@ -0,0 +1,69 @@ +use std::io::{self, IoSlice}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use hyper::client::connect::{Connected as HyperConnected, Connection}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(in crate::transport) trait Io: + AsyncRead + AsyncWrite + Send + 'static +{ +} + +impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} + +pub(crate) struct BoxedIo(Pin>); + +impl BoxedIo { + pub(in crate::transport) fn new(io: I) -> Self { + BoxedIo(Box::pin(io)) + } +} + +impl Connection for BoxedIo { + fn connected(&self) -> HyperConnected { + HyperConnected::new() + } +} + +#[cfg(feature = "channel")] +impl AsyncRead for BoxedIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +#[cfg(feature = "channel")] +impl AsyncWrite for BoxedIo { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } +} diff --git a/tonic/src/transport/channel/service/mod.rs b/tonic/src/transport/channel/service/mod.rs new file mode 100644 index 000000000..bb40ae592 --- /dev/null +++ b/tonic/src/transport/channel/service/mod.rs @@ -0,0 +1,26 @@ +mod add_origin; +pub(crate) use self::add_origin::AddOrigin; + +mod connector; +pub(crate) use self::connector::Connector; + +mod connection; +pub(crate) use self::connection::Connection; + +mod discover; +pub(crate) use self::discover::DynamicServiceStream; + +pub(crate) mod executor; +pub(crate) use self::executor::{Executor, SharedExec}; + +pub(crate) mod io; + +mod reconnect; + +mod user_agent; +pub(crate) use self::user_agent::UserAgent; + +#[cfg(feature = "tls")] +mod tls; +#[cfg(feature = "tls")] +pub(crate) use self::tls::TlsConnector; diff --git a/tonic/src/transport/service/reconnect.rs b/tonic/src/transport/channel/service/reconnect.rs similarity index 100% rename from tonic/src/transport/service/reconnect.rs rename to tonic/src/transport/channel/service/reconnect.rs diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs new file mode 100644 index 000000000..dcf894595 --- /dev/null +++ b/tonic/src/transport/channel/service/tls.rs @@ -0,0 +1,92 @@ +use std::fmt; +use std::io::Cursor; +use std::sync::Arc; + +use rustls_pki_types::ServerName; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::rustls::RootCertStore; +use tokio_rustls::{rustls::ClientConfig, TlsConnector as RustlsConnector}; + +use super::io::BoxedIo; +use crate::transport::service::tls::{add_certs_from_pem, load_identity, ALPN_H2}; +use crate::transport::tls::{Certificate, Identity}; + +#[derive(Debug)] +enum TlsError { + H2NotNegotiated, +} + +impl fmt::Display for TlsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."), + } + } +} + +impl std::error::Error for TlsError {} + +#[derive(Clone)] +pub(crate) struct TlsConnector { + config: Arc, + domain: Arc>, +} + +impl TlsConnector { + pub(crate) fn new( + ca_cert: Option, + identity: Option, + domain: &str, + ) -> Result { + let builder = ClientConfig::builder(); + let mut roots = RootCertStore::empty(); + + #[cfg(feature = "tls-roots")] + roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); + + #[cfg(feature = "tls-webpki-roots")] + roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + if let Some(cert) = ca_cert { + add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + } + + let builder = builder.with_root_certificates(roots); + let mut config = match identity { + Some(identity) => { + let (client_cert, client_key) = load_identity(identity)?; + builder.with_client_auth_cert(client_cert, client_key)? + } + None => builder.with_no_client_auth(), + }; + + config.alpn_protocols.push(ALPN_H2.into()); + Ok(Self { + config: Arc::new(config), + domain: Arc::new(ServerName::try_from(domain)?.to_owned()), + }) + } + + pub(crate) async fn connect(&self, io: I) -> Result + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let io = RustlsConnector::from(self.config.clone()) + .connect(self.domain.as_ref().to_owned(), io) + .await?; + + let (_, session) = io.get_ref(); + if session.alpn_protocol() != Some(ALPN_H2) { + return Err(TlsError::H2NotNegotiated)?; + } + + Ok(BoxedIo::new(io)) + } +} + +#[cfg(feature = "channel")] +impl fmt::Debug for TlsConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsConnector").finish() + } +} diff --git a/tonic/src/transport/service/user_agent.rs b/tonic/src/transport/channel/service/user_agent.rs similarity index 100% rename from tonic/src/transport/service/user_agent.rs rename to tonic/src/transport/channel/service/user_agent.rs diff --git a/tonic/src/transport/channel/tls.rs b/tonic/src/transport/channel/tls.rs index cead6867a..b9e04f544 100644 --- a/tonic/src/transport/channel/tls.rs +++ b/tonic/src/transport/channel/tls.rs @@ -1,5 +1,5 @@ +use super::service::TlsConnector; use crate::transport::{ - service::TlsConnector, tls::{Certificate, Identity}, Error, }; diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 9fdaab6ad..5a41e9524 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -112,7 +112,7 @@ pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; pub use hyper::{Body, Uri}; #[cfg(feature = "channel")] -pub(crate) use self::service::executor::Executor; +pub(crate) use self::channel::service::executor::Executor; #[cfg(all(feature = "channel", feature = "tls"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))] diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 865ea814f..7821f691c 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,6 +1,4 @@ use crate::transport::server::Connected; -#[cfg(feature = "channel")] -use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -9,85 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "tls")] use tokio_rustls::server::TlsStream; -pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + 'static -{ -} - -impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} - -#[cfg(feature = "channel")] -pub(crate) struct BoxedIo(Pin>); - -#[cfg(feature = "channel")] -impl BoxedIo { - pub(in crate::transport) fn new(io: I) -> Self { - BoxedIo(Box::pin(io)) - } -} - -#[cfg(feature = "channel")] -impl Connection for BoxedIo { - fn connected(&self) -> HyperConnected { - HyperConnected::new() - } -} - -#[cfg(feature = "channel")] -impl Connected for BoxedIo { - type ConnectInfo = NoneConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - NoneConnectInfo - } -} - -#[cfg(feature = "channel")] -#[derive(Copy, Clone)] -pub(crate) struct NoneConnectInfo; - -#[cfg(feature = "channel")] -impl AsyncRead for BoxedIo { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) - } -} - -#[cfg(feature = "channel")] -impl AsyncWrite for BoxedIo { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) - } - - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } -} - pub(crate) enum ServerIo { Io(IO), #[cfg(feature = "tls")] diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 3a51d6a90..98129790f 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -1,34 +1,13 @@ -#[cfg(feature = "channel")] -mod add_origin; -#[cfg(feature = "channel")] -mod connection; -#[cfg(feature = "channel")] -mod connector; -#[cfg(feature = "channel")] -mod discover; -#[cfg(feature = "channel")] -pub(crate) mod executor; pub(crate) mod grpc_timeout; -mod io; -#[cfg(feature = "channel")] -mod reconnect; +pub(crate) mod io; mod router; #[cfg(feature = "tls")] -mod tls; -#[cfg(feature = "channel")] -mod user_agent; +pub(super) mod tls; pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; #[cfg(feature = "tls")] pub(crate) use self::tls::TlsAcceptor; -#[cfg(all(feature = "channel", feature = "tls"))] -pub(crate) use self::tls::TlsConnector; -#[cfg(feature = "channel")] -pub(crate) use self::{ - add_origin::AddOrigin, connection::Connection, connector::Connector, - discover::DynamicServiceStream, executor::SharedExec, user_agent::UserAgent, -}; pub use self::router::Routes; pub use self::router::RoutesBuilder; diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index ed10102b4..62f763e2e 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -1,103 +1,27 @@ use std::io::Cursor; use std::{fmt, sync::Arc}; -#[cfg(feature = "channel")] -use rustls_pki_types::ServerName; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncWrite}; -#[cfg(feature = "channel")] -use tokio_rustls::{rustls::ClientConfig, TlsConnector as RustlsConnector}; use tokio_rustls::{ rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}, TlsAcceptor as RustlsAcceptor, }; -#[cfg(feature = "channel")] -use super::io::BoxedIo; use crate::transport::{ server::{Connected, TlsStream}, Certificate, Identity, }; /// h2 alpn in plain format for rustls. -const ALPN_H2: &[u8] = b"h2"; +pub(crate) const ALPN_H2: &[u8] = b"h2"; #[derive(Debug)] enum TlsError { - #[cfg(feature = "channel")] - H2NotNegotiated, CertificateParseError, PrivateKeyParseError, } -#[cfg(feature = "channel")] -#[derive(Clone)] -pub(crate) struct TlsConnector { - config: Arc, - domain: Arc>, -} - -#[cfg(feature = "channel")] -impl TlsConnector { - pub(crate) fn new( - ca_cert: Option, - identity: Option, - domain: &str, - ) -> Result { - let builder = ClientConfig::builder(); - let mut roots = RootCertStore::empty(); - - #[cfg(feature = "tls-roots")] - roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); - - #[cfg(feature = "tls-webpki-roots")] - roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - - if let Some(cert) = ca_cert { - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; - } - - let builder = builder.with_root_certificates(roots); - let mut config = match identity { - Some(identity) => { - let (client_cert, client_key) = load_identity(identity)?; - builder.with_client_auth_cert(client_cert, client_key)? - } - None => builder.with_no_client_auth(), - }; - - config.alpn_protocols.push(ALPN_H2.into()); - Ok(Self { - config: Arc::new(config), - domain: Arc::new(ServerName::try_from(domain)?.to_owned()), - }) - } - - #[cfg(feature = "channel")] - pub(crate) async fn connect(&self, io: I) -> Result - where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, - { - let io = RustlsConnector::from(self.config.clone()) - .connect(self.domain.as_ref().to_owned(), io) - .await?; - - let (_, session) = io.get_ref(); - if session.alpn_protocol() != Some(ALPN_H2) { - return Err(TlsError::H2NotNegotiated)?; - } - - Ok(BoxedIo::new(io)) - } -} - -#[cfg(feature = "channel")] -impl fmt::Debug for TlsConnector { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsConnector").finish() - } -} - #[derive(Clone)] pub(crate) struct TlsAcceptor { inner: Arc, @@ -153,8 +77,6 @@ impl fmt::Debug for TlsAcceptor { impl fmt::Display for TlsError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - #[cfg(feature = "channel")] - TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."), TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."), TlsError::PrivateKeyParseError => write!( f, @@ -166,21 +88,21 @@ impl fmt::Display for TlsError { impl std::error::Error for TlsError {} -fn load_identity( +pub(crate) fn load_identity( identity: Identity, -) -> Result<(Vec>, PrivateKeyDer<'static>), TlsError> { +) -> Result<(Vec>, PrivateKeyDer<'static>), crate::Error> { let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert)) .collect::, _>>() .map_err(|_| TlsError::CertificateParseError)?; let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(identity.key)) else { - return Err(TlsError::PrivateKeyParseError); + return Err(TlsError::PrivateKeyParseError.into()); }; Ok((cert, key)) } -fn add_certs_from_pem( +pub(crate) fn add_certs_from_pem( mut certs: &mut dyn std::io::BufRead, roots: &mut RootCertStore, ) -> Result<(), crate::Error> {