From faec3d5b103c217d9d2f7209c267c90b867a8aa4 Mon Sep 17 00:00:00 2001 From: tottoto Date: Sun, 19 Nov 2023 09:08:10 +0900 Subject: [PATCH] feat(channel): Make channel feature additive --- examples/Cargo.toml | 4 +- tonic/Cargo.toml | 23 ++-- tonic/src/transport/channel/endpoint.rs | 17 +-- tonic/src/transport/channel/mod.rs | 3 +- .../{ => channel}/service/add_origin.rs | 0 .../{ => channel}/service/connection.rs | 3 +- .../{ => channel}/service/connector.rs | 9 +- .../{ => channel}/service/discover.rs | 0 .../{ => channel}/service/executor.rs | 0 tonic/src/transport/channel/service/io.rs | 69 ++++++++++++ tonic/src/transport/channel/service/mod.rs | 26 +++++ .../{ => channel}/service/reconnect.rs | 0 tonic/src/transport/channel/service/tls.rs | 100 ++++++++++++++++++ .../{ => channel}/service/user_agent.rs | 0 tonic/src/transport/channel/tls.rs | 2 +- tonic/src/transport/error.rs | 6 ++ tonic/src/transport/mod.rs | 9 +- tonic/src/transport/service/io.rs | 73 ------------- tonic/src/transport/service/mod.rs | 19 +--- tonic/src/transport/service/tls.rs | 96 ++--------------- 20 files changed, 255 insertions(+), 204 deletions(-) rename tonic/src/transport/{ => channel}/service/add_origin.rs (100%) rename tonic/src/transport/{ => channel}/service/connection.rs (97%) rename tonic/src/transport/{ => channel}/service/connector.rs (98%) rename tonic/src/transport/{ => channel}/service/discover.rs (100%) rename tonic/src/transport/{ => channel}/service/executor.rs (100%) create mode 100644 tonic/src/transport/channel/service/io.rs create mode 100644 tonic/src/transport/channel/service/mod.rs rename tonic/src/transport/{ => channel}/service/reconnect.rs (100%) create mode 100644 tonic/src/transport/channel/service/tls.rs rename tonic/src/transport/{ => channel}/service/user_agent.rs (100%) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 2239336e1..b485033e3 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -298,13 +298,13 @@ hyper-warp-multiplex = ["hyper-warp"] uds = ["tokio-stream/net", "dep:tower", "dep:hyper"] streaming = ["tokio-stream", "dep:h2"] mock = ["tokio-stream", "dep:tower"] -tower = ["dep:hyper", "dep:tower", "dep:http"] +tower = ["dep:hyper", "tower/timeout", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls"] tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] dynamic-load-balance = ["dep:tower"] -timeout = ["tokio/time", "dep:tower"] +timeout = ["tokio/time", "tower/timeout"] tls-client-auth = ["tonic/tls"] types = ["dep:tonic-types"] h2c = ["dep:hyper", "dep:tower", "dep:http"] diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 99388e357..a9f452300 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -26,23 +26,26 @@ version = "0.11.0" codegen = ["dep:async-trait"] gzip = ["dep:flate2"] zstd = ["dep:zstd"] -default = ["transport", "codegen", "prost"] +default = ["channel", "codegen", "prost"] prost = ["dep:prost"] tls = ["dep:rustls-pki-types", "dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"] tls-roots = ["tls-roots-common", "dep:rustls-native-certs"] -tls-roots-common = ["tls"] +tls-roots-common = ["tls", "channel"] tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"] transport = [ "dep:async-stream", "dep:axum", - "channel", "dep:h2", - "dep:hyper", + "dep:hyper", "hyper?/server", "dep:tokio", "tokio?/net", "tokio?/time", - "dep:tower", + "dep:tower", "tower?/util", "tower?/limit", +] +channel = [ + "transport", + "dep:hyper", "hyper?/client", + "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make", "dep:hyper-timeout", ] -channel = [] # [[bench]] # name = "bench_main" @@ -68,13 +71,15 @@ async-trait = {version = "0.1.13", optional = true} # transport h2 = {version = "0.3.24", optional = true} -hyper = {version = "0.14.26", features = ["full"], optional = true} -hyper-timeout = {version = "0.4", optional = true} +hyper = {version = "0.14.26", features = ["http1", "http2", "runtime", "stream"], optional = true} tokio = {version = "1.0.1", optional = true} tokio-stream = "0.1" -tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} +tower = {version = "0.4.7", default-features = false, optional = true} axum = {version = "0.6.9", default_features = false, optional = true} +# channel +hyper-timeout = {version = "0.4", optional = true} + # rustls async-stream = { version = "0.3", optional = true } rustls-pki-types = { version = "1.0", optional = true } diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 597b4bc4d..481082c85 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -1,10 +1,11 @@ -use super::super::service; +use super::service::Connector; +#[cfg(feature = "tls")] +use super::service::TlsConnector; +use super::service::{Executor, SharedExec}; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; -#[cfg(feature = "tls")] -use crate::transport::service::TlsConnector; -use crate::transport::{service::SharedExec, Error, Executor}; +use crate::transport::Error; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; @@ -318,15 +319,15 @@ impl Endpoint { self } - pub(crate) fn connector(&self, c: C) -> service::Connector { + pub(crate) fn connector(&self, c: C) -> Connector { #[cfg(all(feature = "tls", not(feature = "tls-roots-common")))] - let connector = service::Connector::new(c, self.tls.clone()); + let connector = Connector::new(c, self.tls.clone()); #[cfg(all(feature = "tls", feature = "tls-roots-common"))] - let connector = service::Connector::new(c, self.tls.clone(), self.tls_assume_http2); + let connector = Connector::new(c, self.tls.clone(), self.tls_assume_http2); #[cfg(not(feature = "tls"))] - let connector = service::Connector::new(c); + let connector = Connector::new(c); connector } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b510a6980..8a9f1b8c4 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -1,6 +1,7 @@ //! Client implementation and builder. mod endpoint; +pub(crate) mod service; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; @@ -9,7 +10,7 @@ pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; -use super::service::{Connection, DynamicServiceStream, SharedExec}; +use self::service::{Connection, DynamicServiceStream, SharedExec}; use crate::body::BoxBody; use crate::transport::Executor; use bytes::Bytes; diff --git a/tonic/src/transport/service/add_origin.rs b/tonic/src/transport/channel/service/add_origin.rs similarity index 100% rename from tonic/src/transport/service/add_origin.rs rename to tonic/src/transport/channel/service/add_origin.rs diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/channel/service/connection.rs similarity index 97% rename from tonic/src/transport/service/connection.rs rename to tonic/src/transport/channel/service/connection.rs index 46a88dda5..ece496489 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/channel/service/connection.rs @@ -1,4 +1,5 @@ -use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; +use super::{reconnect::Reconnect, AddOrigin, UserAgent}; +use crate::transport::service::GrpcTimeout; use crate::{ body::BoxBody, transport::{BoxFuture, Endpoint}, diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/channel/service/connector.rs similarity index 98% rename from tonic/src/transport/service/connector.rs rename to tonic/src/transport/channel/service/connector.rs index 8b7e7f63a..d5dd7bb46 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/channel/service/connector.rs @@ -1,13 +1,14 @@ -use super::super::BoxFuture; -use super::io::BoxedIo; -#[cfg(feature = "tls")] -use super::tls::TlsConnector; use http::Uri; use std::fmt; use std::task::{Context, Poll}; use tower::make::MakeConnection; use tower_service::Service; +use super::io::BoxedIo; +#[cfg(feature = "tls")] +use super::tls::TlsConnector; +use crate::transport::BoxFuture; + pub(crate) struct Connector { inner: C, #[cfg(feature = "tls")] diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/channel/service/discover.rs similarity index 100% rename from tonic/src/transport/service/discover.rs rename to tonic/src/transport/channel/service/discover.rs diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/channel/service/executor.rs similarity index 100% rename from tonic/src/transport/service/executor.rs rename to tonic/src/transport/channel/service/executor.rs diff --git a/tonic/src/transport/channel/service/io.rs b/tonic/src/transport/channel/service/io.rs new file mode 100644 index 000000000..f419388a6 --- /dev/null +++ b/tonic/src/transport/channel/service/io.rs @@ -0,0 +1,69 @@ +use std::io::{self, IoSlice}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use hyper::client::connect::{Connected as HyperConnected, Connection}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(in crate::transport) trait Io: + AsyncRead + AsyncWrite + Send + 'static +{ +} + +impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} + +pub(crate) struct BoxedIo(Pin>); + +impl BoxedIo { + pub(in crate::transport) fn new(io: I) -> Self { + BoxedIo(Box::pin(io)) + } +} + +impl Connection for BoxedIo { + fn connected(&self) -> HyperConnected { + HyperConnected::new() + } +} + +#[cfg(feature = "channel")] +impl AsyncRead for BoxedIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +#[cfg(feature = "channel")] +impl AsyncWrite for BoxedIo { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } +} diff --git a/tonic/src/transport/channel/service/mod.rs b/tonic/src/transport/channel/service/mod.rs new file mode 100644 index 000000000..bb40ae592 --- /dev/null +++ b/tonic/src/transport/channel/service/mod.rs @@ -0,0 +1,26 @@ +mod add_origin; +pub(crate) use self::add_origin::AddOrigin; + +mod connector; +pub(crate) use self::connector::Connector; + +mod connection; +pub(crate) use self::connection::Connection; + +mod discover; +pub(crate) use self::discover::DynamicServiceStream; + +pub(crate) mod executor; +pub(crate) use self::executor::{Executor, SharedExec}; + +pub(crate) mod io; + +mod reconnect; + +mod user_agent; +pub(crate) use self::user_agent::UserAgent; + +#[cfg(feature = "tls")] +mod tls; +#[cfg(feature = "tls")] +pub(crate) use self::tls::TlsConnector; diff --git a/tonic/src/transport/service/reconnect.rs b/tonic/src/transport/channel/service/reconnect.rs similarity index 100% rename from tonic/src/transport/service/reconnect.rs rename to tonic/src/transport/channel/service/reconnect.rs diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs new file mode 100644 index 000000000..1a2b108d0 --- /dev/null +++ b/tonic/src/transport/channel/service/tls.rs @@ -0,0 +1,100 @@ +use std::fmt; +use std::io::Cursor; +use std::sync::Arc; + +use rustls_pki_types::ServerName; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::rustls::RootCertStore; +use tokio_rustls::{rustls::ClientConfig, TlsConnector as RustlsConnector}; + +use super::io::BoxedIo; +use crate::transport::service::tls::{add_certs_from_pem, load_identity, ALPN_H2}; +use crate::transport::tls::{Certificate, Identity}; + +#[derive(Debug)] +enum TlsError { + H2NotNegotiated, +} + +impl fmt::Display for TlsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."), + } + } +} + +impl std::error::Error for TlsError {} + +#[derive(Clone)] +pub(crate) struct TlsConnector { + config: Arc, + domain: Arc>, + assume_http2: bool, +} + +impl TlsConnector { + pub(crate) fn new( + ca_cert: Option, + identity: Option, + domain: &str, + assume_http2: bool, + ) -> Result { + let builder = ClientConfig::builder(); + let mut roots = RootCertStore::empty(); + + #[cfg(feature = "tls-roots")] + roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); + + #[cfg(feature = "tls-webpki-roots")] + roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + if let Some(cert) = ca_cert { + add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + } + + let builder = builder.with_root_certificates(roots); + let mut config = match identity { + Some(identity) => { + let (client_cert, client_key) = load_identity(identity)?; + builder.with_client_auth_cert(client_cert, client_key)? + } + None => builder.with_no_client_auth(), + }; + + config.alpn_protocols.push(ALPN_H2.into()); + Ok(Self { + config: Arc::new(config), + domain: Arc::new(ServerName::try_from(domain)?.to_owned()), + assume_http2, + }) + } + + pub(crate) async fn connect(&self, io: I) -> Result + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let io = RustlsConnector::from(self.config.clone()) + .connect(self.domain.as_ref().to_owned(), io) + .await?; + + // Generally we require ALPN to be negotiated, but if the user has + // explicitly set `assume_http2` to true, we'll allow it to be missing. + let (_, session) = io.get_ref(); + let alpn_protocol = session.alpn_protocol(); + if alpn_protocol != Some(ALPN_H2) { + if alpn_protocol.is_some() || !self.assume_http2 { + return Err(TlsError::H2NotNegotiated.into()); + } + } + + Ok(BoxedIo::new(io)) + } +} + +#[cfg(feature = "channel")] +impl fmt::Debug for TlsConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsConnector").finish() + } +} diff --git a/tonic/src/transport/service/user_agent.rs b/tonic/src/transport/channel/service/user_agent.rs similarity index 100% rename from tonic/src/transport/service/user_agent.rs rename to tonic/src/transport/channel/service/user_agent.rs diff --git a/tonic/src/transport/channel/tls.rs b/tonic/src/transport/channel/tls.rs index 346071fad..039630915 100644 --- a/tonic/src/transport/channel/tls.rs +++ b/tonic/src/transport/channel/tls.rs @@ -1,5 +1,5 @@ +use super::service::TlsConnector; use crate::transport::{ - service::TlsConnector, tls::{Certificate, Identity}, Error, }; diff --git a/tonic/src/transport/error.rs b/tonic/src/transport/error.rs index d2a1c7bb2..92a910498 100644 --- a/tonic/src/transport/error.rs +++ b/tonic/src/transport/error.rs @@ -15,7 +15,9 @@ struct ErrorImpl { #[derive(Debug)] pub(crate) enum Kind { Transport, + #[cfg(feature = "channel")] InvalidUri, + #[cfg(feature = "channel")] InvalidUserAgent, } @@ -35,10 +37,12 @@ impl Error { Error::new(Kind::Transport).with(source) } + #[cfg(feature = "channel")] pub(crate) fn new_invalid_uri() -> Self { Error::new(Kind::InvalidUri) } + #[cfg(feature = "channel")] pub(crate) fn new_invalid_user_agent() -> Self { Error::new(Kind::InvalidUserAgent) } @@ -46,7 +50,9 @@ impl Error { fn description(&self) -> &str { match &self.inner.kind { Kind::Transport => "transport error", + #[cfg(feature = "channel")] Kind::InvalidUri => "invalid URI", + #[cfg(feature = "channel")] Kind::InvalidUserAgent => "user agent is not a valid header value", } } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index a0435c797..5a41e9524 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -87,6 +87,7 @@ //! //! [rustls]: https://docs.rs/rustls/0.16.0/rustls/ +#[cfg(feature = "channel")] pub mod channel; pub mod server; @@ -110,10 +111,11 @@ pub use self::tls::Certificate; pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; pub use hyper::{Body, Uri}; -pub(crate) use self::service::executor::Executor; +#[cfg(feature = "channel")] +pub(crate) use self::channel::service::executor::Executor; -#[cfg(feature = "tls")] -#[cfg_attr(docsrs, doc(cfg(feature = "tls")))] +#[cfg(all(feature = "channel", feature = "tls"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))] pub use self::channel::ClientTlsConfig; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] @@ -122,4 +124,5 @@ pub use self::server::ServerTlsConfig; #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Identity; +#[cfg(feature = "channel")] type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 2230b9b2e..7821f691c 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,5 +1,4 @@ use crate::transport::server::Connected; -use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -8,78 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "tls")] use tokio_rustls::server::TlsStream; -pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + 'static -{ -} - -impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} - -pub(crate) struct BoxedIo(Pin>); - -impl BoxedIo { - pub(in crate::transport) fn new(io: I) -> Self { - BoxedIo(Box::pin(io)) - } -} - -impl Connection for BoxedIo { - fn connected(&self) -> HyperConnected { - HyperConnected::new() - } -} - -impl Connected for BoxedIo { - type ConnectInfo = NoneConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - NoneConnectInfo - } -} - -#[derive(Copy, Clone)] -pub(crate) struct NoneConnectInfo; - -impl AsyncRead for BoxedIo { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) - } -} - -impl AsyncWrite for BoxedIo { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) - } - - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } -} - pub(crate) enum ServerIo { Io(IO), #[cfg(feature = "tls")] diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 69d850f10..98129790f 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -1,26 +1,13 @@ -mod add_origin; -mod connection; -mod connector; -mod discover; -pub(crate) mod executor; pub(crate) mod grpc_timeout; -mod io; -mod reconnect; +pub(crate) mod io; mod router; #[cfg(feature = "tls")] -mod tls; -mod user_agent; +pub(super) mod tls; -pub(crate) use self::add_origin::AddOrigin; -pub(crate) use self::connection::Connection; -pub(crate) use self::connector::Connector; -pub(crate) use self::discover::DynamicServiceStream; -pub(crate) use self::executor::SharedExec; pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; #[cfg(feature = "tls")] -pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; -pub(crate) use self::user_agent::UserAgent; +pub(crate) use self::tls::TlsAcceptor; pub use self::router::Routes; pub use self::router::RoutesBuilder; diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 96e1fe652..62f763e2e 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -1,102 +1,27 @@ -use std::{ - io::Cursor, - {fmt, sync::Arc}, -}; +use std::io::Cursor; +use std::{fmt, sync::Arc}; -use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::{ - rustls::{server::WebPkiClientVerifier, ClientConfig, RootCertStore, ServerConfig}, - TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector, + rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}, + TlsAcceptor as RustlsAcceptor, }; -use super::io::BoxedIo; use crate::transport::{ server::{Connected, TlsStream}, Certificate, Identity, }; /// h2 alpn in plain format for rustls. -const ALPN_H2: &[u8] = b"h2"; +pub(crate) const ALPN_H2: &[u8] = b"h2"; #[derive(Debug)] enum TlsError { - H2NotNegotiated, CertificateParseError, PrivateKeyParseError, } -#[derive(Clone)] -pub(crate) struct TlsConnector { - config: Arc, - domain: Arc>, - assume_http2: bool, -} - -impl TlsConnector { - pub(crate) fn new( - ca_cert: Option, - identity: Option, - domain: &str, - assume_http2: bool, - ) -> Result { - let builder = ClientConfig::builder(); - let mut roots = RootCertStore::empty(); - - #[cfg(feature = "tls-roots")] - roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); - - #[cfg(feature = "tls-webpki-roots")] - roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - - if let Some(cert) = ca_cert { - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; - } - - let builder = builder.with_root_certificates(roots); - let mut config = match identity { - Some(identity) => { - let (client_cert, client_key) = load_identity(identity)?; - builder.with_client_auth_cert(client_cert, client_key)? - } - None => builder.with_no_client_auth(), - }; - - config.alpn_protocols.push(ALPN_H2.into()); - Ok(Self { - config: Arc::new(config), - domain: Arc::new(ServerName::try_from(domain)?.to_owned()), - assume_http2, - }) - } - - pub(crate) async fn connect(&self, io: I) -> Result - where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, - { - let io = RustlsConnector::from(self.config.clone()) - .connect(self.domain.as_ref().to_owned(), io) - .await?; - - // Generally we require ALPN to be negotiated, but if the user has - // explicitly set `assume_http2` to true, we'll allow it to be missing. - let (_, session) = io.get_ref(); - let alpn_protocol = session.alpn_protocol(); - if alpn_protocol != Some(ALPN_H2) { - if alpn_protocol.is_some() || !self.assume_http2 { - return Err(TlsError::H2NotNegotiated.into()); - } - } - Ok(BoxedIo::new(io)) - } -} - -impl fmt::Debug for TlsConnector { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsConnector").finish() - } -} - #[derive(Clone)] pub(crate) struct TlsAcceptor { inner: Arc, @@ -152,7 +77,6 @@ impl fmt::Debug for TlsAcceptor { impl fmt::Display for TlsError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."), TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."), TlsError::PrivateKeyParseError => write!( f, @@ -164,21 +88,21 @@ impl fmt::Display for TlsError { impl std::error::Error for TlsError {} -fn load_identity( +pub(crate) fn load_identity( identity: Identity, -) -> Result<(Vec>, PrivateKeyDer<'static>), TlsError> { +) -> Result<(Vec>, PrivateKeyDer<'static>), crate::Error> { let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert)) .collect::, _>>() .map_err(|_| TlsError::CertificateParseError)?; let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(identity.key)) else { - return Err(TlsError::PrivateKeyParseError); + return Err(TlsError::PrivateKeyParseError.into()); }; Ok((cert, key)) } -fn add_certs_from_pem( +pub(crate) fn add_certs_from_pem( mut certs: &mut dyn std::io::BufRead, roots: &mut RootCertStore, ) -> Result<(), crate::Error> {