Skip to content

Commit

Permalink
feat(tls): Add tls handshake timeout support
Browse files Browse the repository at this point in the history
  • Loading branch information
honsunrise committed Jan 13, 2025
1 parent 5ad89bf commit 3ad4f97
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 8 deletions.
17 changes: 13 additions & 4 deletions tonic/src/transport/channel/service/tls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::fmt;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time;
use tokio_rustls::{
rustls::{
crypto,
Expand All @@ -23,6 +24,7 @@ pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<ServerName<'static>>,
assume_http2: bool,
timeout: Option<Duration>,
}

impl TlsConnector {
Expand All @@ -34,6 +36,7 @@ impl TlsConnector {
assume_http2: bool,
#[cfg(feature = "tls-native-roots")] with_native_roots: bool,
#[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool,
timeout: Option<Duration>,
) -> Result<Self, crate::BoxError> {
fn with_provider(
provider: Arc<crypto::CryptoProvider>,
Expand Down Expand Up @@ -92,16 +95,22 @@ impl TlsConnector {
config: Arc::new(config),
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
assume_http2,
timeout,
})
}

pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::BoxError>
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let io = RustlsConnector::from(self.config.clone())
.connect(self.domain.as_ref().to_owned(), io)
.await?;
let conn_fut =
RustlsConnector::from(self.config.clone()).connect(self.domain.as_ref().to_owned(), io);
let io = match self.timeout {
Some(timeout) => time::timeout(timeout, conn_fut)
.await
.map_err(|_| TlsError::HandshakeTimeout)?,
None => conn_fut.await,
}?;

// Generally we require ALPN to be negotiated, but if the user has
// explicitly set `assume_http2` to true, we'll allow it to be missing.
Expand Down
11 changes: 11 additions & 0 deletions tonic/src/transport/channel/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::transport::{
Error,
};
use http::Uri;
use std::time::Duration;
use tokio_rustls::rustls::pki_types::TrustAnchor;

/// Configures TLS settings for endpoints.
Expand All @@ -18,6 +19,7 @@ pub struct ClientTlsConfig {
with_native_roots: bool,
#[cfg(feature = "tls-webpki-roots")]
with_webpki_roots: bool,
timeout: Option<Duration>,
}

impl ClientTlsConfig {
Expand Down Expand Up @@ -112,6 +114,14 @@ impl ClientTlsConfig {
config
}

/// Sets the timeout for the TLS handshake.
pub fn timeout(self, timeout: Duration) -> Self {
ClientTlsConfig {
timeout: Some(timeout),
..self
}
}

pub(crate) fn into_tls_connector(self, uri: &Uri) -> Result<TlsConnector, crate::BoxError> {
let domain = match &self.domain {
Some(domain) => domain,
Expand All @@ -127,6 +137,7 @@ impl ClientTlsConfig {
self.with_native_roots,
#[cfg(feature = "tls-webpki-roots")]
self.with_webpki_roots,
self.timeout,
)
}
}
19 changes: 16 additions & 3 deletions tonic/src/transport/server/service/tls.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
use std::{fmt, sync::Arc};
use std::{fmt, sync::Arc, time::Duration};

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time;
use tokio_rustls::{
rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig},
server::TlsStream,
TlsAcceptor as RustlsAcceptor,
};

use crate::transport::{
service::tls::{convert_certificate_to_pki_types, convert_identity_to_pki_types, ALPN_H2},
service::tls::{
convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2,
},
Certificate, Identity,
};

#[derive(Clone)]
pub(crate) struct TlsAcceptor {
inner: Arc<ServerConfig>,
timeout: Option<Duration>,
}

impl TlsAcceptor {
Expand All @@ -23,6 +27,7 @@ impl TlsAcceptor {
client_ca_root: Option<&Certificate>,
client_auth_optional: bool,
ignore_client_order: bool,
timeout: Option<Duration>,
) -> Result<Self, crate::BoxError> {
let builder = ServerConfig::builder();

Expand All @@ -48,6 +53,7 @@ impl TlsAcceptor {
config.alpn_protocols.push(ALPN_H2.into());
Ok(Self {
inner: Arc::new(config),
timeout,
})
}

Expand All @@ -56,7 +62,14 @@ impl TlsAcceptor {
IO: AsyncRead + AsyncWrite + Unpin,
{
let acceptor = RustlsAcceptor::from(self.inner.clone());
acceptor.accept(io).await.map_err(Into::into)
let accept_fut = acceptor.accept(io);
match self.timeout {
Some(timeout) => time::timeout(timeout, accept_fut)
.await
.map_err(|_| TlsError::HandshakeTimeout)?,
None => accept_fut.await,
}
.map_err(Into::into)
}
}

Expand Down
15 changes: 14 additions & 1 deletion tonic/src/transport/server/tls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt;
use std::{fmt, time::Duration};

use super::service::TlsAcceptor;
use crate::transport::tls::{Certificate, Identity};
Expand All @@ -10,6 +10,7 @@ pub struct ServerTlsConfig {
client_ca_root: Option<Certificate>,
client_auth_optional: bool,
ignore_client_order: bool,
timeout: Option<Duration>,
}

impl fmt::Debug for ServerTlsConfig {
Expand Down Expand Up @@ -64,12 +65,24 @@ impl ServerTlsConfig {
}
}

/// Sets the timeout for the TLS handshake.
///
/// # Default
/// By default, this option is set to `None`.
pub fn timeout(self, timeout: Duration) -> Self {
ServerTlsConfig {
timeout: Some(timeout),
..self
}
}

pub(crate) fn tls_acceptor(&self) -> Result<TlsAcceptor, crate::BoxError> {
TlsAcceptor::new(
self.identity.as_ref().unwrap(),
self.client_ca_root.as_ref(),
self.client_auth_optional,
self.ignore_client_order,
self.timeout,
)
}
}
2 changes: 2 additions & 0 deletions tonic/src/transport/service/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub(crate) enum TlsError {
NativeCertsNotFound,
CertificateParseError,
PrivateKeyParseError,
HandshakeTimeout,
}

impl fmt::Display for TlsError {
Expand All @@ -29,6 +30,7 @@ impl fmt::Display for TlsError {
f,
"Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
),
TlsError::HandshakeTimeout => write!(f, "TLS handshake timeout."),
}
}
}
Expand Down

0 comments on commit 3ad4f97

Please sign in to comment.