From f66b5515910541dbfbcc86b564c5c1890c7c1ebf Mon Sep 17 00:00:00 2001 From: Honsun Zhu Date: Thu, 28 Nov 2024 11:23:05 +0800 Subject: [PATCH] feat(tls): Add tls handshake timeout support --- tonic/src/transport/channel/service/tls.rs | 17 +++++++++++++---- tonic/src/transport/channel/tls.rs | 11 +++++++++++ tonic/src/transport/service/tls.rs | 2 ++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs index 5dd227f81..cce5bd1fd 100644 --- a/tonic/src/transport/channel/service/tls.rs +++ b/tonic/src/transport/channel/service/tls.rs @@ -1,8 +1,9 @@ use std::fmt; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time; use tokio_rustls::{ rustls::{ crypto, @@ -23,6 +24,7 @@ pub(crate) struct TlsConnector { config: Arc, domain: Arc>, assume_http2: bool, + timeout: Option, } impl TlsConnector { @@ -34,6 +36,7 @@ impl TlsConnector { assume_http2: bool, #[cfg(feature = "tls-native-roots")] with_native_roots: bool, #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, + timeout: Option, ) -> Result { fn with_provider( provider: Arc, @@ -92,6 +95,7 @@ impl TlsConnector { config: Arc::new(config), domain: Arc::new(ServerName::try_from(domain)?.to_owned()), assume_http2, + timeout, }) } @@ -99,9 +103,14 @@ impl TlsConnector { where I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let io = RustlsConnector::from(self.config.clone()) - .connect(self.domain.as_ref().to_owned(), io) - .await?; + let conn_fut = + RustlsConnector::from(self.config.clone()).connect(self.domain.as_ref().to_owned(), io); + let io = match self.timeout { + Some(timeout) => time::timeout(timeout, conn_fut) + .await + .map_err(|_| TlsError::HandshakeTimeout)??, + None => conn_fut.await?, + }; // Generally we require ALPN to be negotiated, but if the user has // explicitly set `assume_http2` to true, we'll allow it to be missing. diff --git a/tonic/src/transport/channel/tls.rs b/tonic/src/transport/channel/tls.rs index 0c2eb37e0..2b893d50b 100644 --- a/tonic/src/transport/channel/tls.rs +++ b/tonic/src/transport/channel/tls.rs @@ -4,6 +4,7 @@ use crate::transport::{ Error, }; use http::Uri; +use std::time::Duration; use tokio_rustls::rustls::pki_types::TrustAnchor; /// Configures TLS settings for endpoints. @@ -18,6 +19,7 @@ pub struct ClientTlsConfig { with_native_roots: bool, #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, + timeout: Option, } impl ClientTlsConfig { @@ -112,6 +114,14 @@ impl ClientTlsConfig { config } + /// Sets the timeout for the TLS handshake. + pub fn timeout(self, timeout: Duration) -> Self { + ClientTlsConfig { + timeout: Some(timeout), + ..self + } + } + pub(crate) fn into_tls_connector(self, uri: &Uri) -> Result { let domain = match &self.domain { Some(domain) => domain, @@ -127,6 +137,7 @@ impl ClientTlsConfig { self.with_native_roots, #[cfg(feature = "tls-webpki-roots")] self.with_webpki_roots, + self.timeout, ) } } diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 8cb30c73c..0d6e9bc87 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -15,6 +15,7 @@ pub(crate) enum TlsError { NativeCertsNotFound, CertificateParseError, PrivateKeyParseError, + HandshakeTimeout, } impl fmt::Display for TlsError { @@ -29,6 +30,7 @@ impl fmt::Display for TlsError { f, "Error parsing TLS private key - no RSA or PKCS8-encoded keys found." ), + TlsError::HandshakeTimeout => write!(f, "TLS handshake timeout."), } } }