Skip to content

Commit

Permalink
feat(tls): AWS Libcrypto Support (#2008)
Browse files Browse the repository at this point in the history
* add tls crypto provider feature

* fix features

* fix

* fix

* add tls-aws-lc to conditional compilation

* fix conditional compilation

* fix conditional compilation

* revert formatting

* revert formatting

* add tls-any and deprecate tls

* formatting

* revert formatting

* revert formatting

* clean up #[cfg(..)]

* tests pass

* update workflow for new features

* internal feature flag

* revert formatting

* update docs

* specify rustls version in tests

* tls only depends on tls-ring

* update CI + deps

* minor change for force push

* fmt

* fix docs

* fix ring docs link

* Update Cargo.toml

Co-authored-by: Lucio Franco <[email protected]>

---------

Co-authored-by: Lucio Franco <[email protected]>
  • Loading branch information
jenr24-architect and LucioFranco authored Nov 6, 2024
1 parent 8def0bb commit b28b59b
Show file tree
Hide file tree
Showing 17 changed files with 97 additions and 83 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,13 @@ jobs:
- uses: taiki-e/install-action@cargo-udeps
- uses: taiki-e/install-action@protoc
- uses: Swatinem/rust-cache@v2
- run: cargo hack udeps --workspace --exclude-features tls --each-feature
- run: cargo udeps --package tonic --features tls,transport
- run: cargo udeps --package tonic --features tls,server
- run: cargo udeps --package tonic --features tls,channel
- run: cargo hack udeps --workspace --exclude-features=_tls-any,tls,tls-aws-lc,tls-ring --each-feature
- run: cargo udeps --package tonic --features tls-ring,transport
- run: cargo udeps --package tonic --features tls-ring,server
- run: cargo udeps --package tonic --features tls-ring,channel
- run: cargo udeps --package tonic --features tls-aws-lc,transport
- run: cargo udeps --package tonic --features tls-aws-lc,server
- run: cargo udeps --package tonic --features tls-aws-lc,channel

check:
runs-on: ${{ matrix.os }}
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ async-stream = "0.3"
http = "1"
http-body = "1"
hyper-util = "0.1"
rustls = {version = "0.23", features = ["ring"]}
tokio-stream = {version = "0.1.5", features = ["net"]}
tower = "0.5"
tower-http = { version = "0.6", features = ["set-header", "trace"] }
Expand Down
3 changes: 3 additions & 0 deletions tests/integration_tests/tests/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ async fn connect_returns_err() {

#[tokio::test]
async fn connect_handles_tls() {
rustls::crypto::ring::default_provider()
.install_default()
.unwrap();
TestClient::connect("https://example.com").await.unwrap();
}

Expand Down
11 changes: 7 additions & 4 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ gzip = ["dep:flate2"]
zstd = ["dep:zstd"]
default = ["transport", "codegen", "prost"]
prost = ["dep:prost"]
tls = ["dep:rustls-pemfile", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
_tls-any = ["dep:rustls-pemfile", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"] # Internal. Please choose one of `tls-ring` or `tls-aws-lc`
tls = ["tls-ring"] # Deprecated. Please use `tls-ring` or `tls-aws-lc` instead.
tls-ring = ["_tls-any", "tokio-rustls/ring"]
tls-aws-lc = ["_tls-any", "tokio-rustls/aws-lc-rs"]
tls-roots = ["tls-native-roots"] # Deprecated. Please use `tls-native-roots` instead.
tls-native-roots = ["tls", "channel", "dep:rustls-native-certs"]
tls-webpki-roots = ["tls", "channel", "dep:webpki-roots"]
tls-native-roots = ["_tls-any", "channel", "dep:rustls-native-certs"]
tls-webpki-roots = ["_tls-any","channel", "dep:webpki-roots"]
router = ["dep:axum", "dep:tower", "tower?/util"]
server = [
"router",
Expand Down Expand Up @@ -90,7 +93,7 @@ axum = {version = "0.7", default-features = false, optional = true}
# rustls
rustls-pemfile = { version = "2.0", optional = true }
rustls-native-certs = { version = "0.8", optional = true }
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"], optional = true }
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12"], optional = true }
webpki-roots = { version = "0.26", optional = true }

# compression
Expand Down
8 changes: 6 additions & 2 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
//! - `router`: Enables the [`axum`] based service router. Enabled by default.
//! - `codegen`: Enables all the required exports and optional dependencies required
//! for [`tonic-build`]. Enabled by default.
//! - `tls`: Enables the [`rustls`] based TLS options for the `transport` feature. Not
//! enabled by default.
//! - `tls`: Deprecated. An alias to `tls-ring`
//! - `tls-ring`: Enables the [`rustls`] based TLS options for the `transport` feature using
//! the [`ring`]` libcrypto provider. Not enabled by default.
//! - `tls-aws-lc`: Enables the [`rustls`] based TLS options for the `transport` feature using
//! the [`aws-lc-rs`] libcrypto provider. Not enabled by default.
//! - `tls-roots`: Deprecated. An alias to `tls-native-roots` feature.
//! - `tls-native-roots`: Adds system trust roots to [`rustls`]-based gRPC clients using the
//! [`rustls-native-certs`] crate. Not enabled by default.
Expand Down Expand Up @@ -71,6 +74,7 @@
//! [`hyper`]: https://docs.rs/hyper
//! [`tower`]: https://docs.rs/tower
//! [`tonic-build`]: https://docs.rs/tonic-build
//! [`ring`]: https://docs.rs/ring
//! [`tonic-examples`]: https://github.com/hyperium/tonic/tree/master/examples
//! [`Codec`]: codec/trait.Codec.html
//! [`Channel`]: transport/struct.Channel.html
Expand Down
12 changes: 6 additions & 6 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use crate::metadata::{MetadataMap, MetadataValue};
#[cfg(feature = "server")]
use crate::transport::server::TcpConnectInfo;
#[cfg(all(feature = "server", feature = "tls"))]
#[cfg(all(feature = "server", feature = "_tls-any"))]
use crate::transport::server::TlsConnectInfo;
use http::Extensions;
#[cfg(feature = "server")]
use std::net::SocketAddr;
#[cfg(all(feature = "server", feature = "tls"))]
#[cfg(all(feature = "server", feature = "_tls-any"))]
use std::sync::Arc;
use std::time::Duration;
#[cfg(all(feature = "server", feature = "tls"))]
#[cfg(all(feature = "server", feature = "_tls-any"))]
use tokio_rustls::rustls::pki_types::CertificateDer;
use tokio_stream::Stream;

Expand Down Expand Up @@ -218,7 +218,7 @@ impl<T> Request<T> {
.get::<TcpConnectInfo>()
.and_then(|i| i.local_addr());

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
let addr = addr.or_else(|| {
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
Expand All @@ -240,7 +240,7 @@ impl<T> Request<T> {
.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr());

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
let addr = addr.or_else(|| {
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
Expand All @@ -256,7 +256,7 @@ impl<T> Request<T> {
/// and is mostly used for mTLS. This currently only returns
/// `Some` on the server side of the `transport` server with
/// TLS enabled connections.
#[cfg(all(feature = "server", feature = "tls"))]
#[cfg(all(feature = "server", feature = "_tls-any"))]
pub fn peer_certs(&self) -> Option<Arc<Vec<CertificateDer<'static>>>> {
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
Expand Down
14 changes: 7 additions & 7 deletions tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use super::service::TlsConnector;
use super::service::{self, Executor, SharedExec};
use super::Channel;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use super::ClientTlsConfig;
use crate::transport::Error;
use bytes::Bytes;
Expand All @@ -23,7 +23,7 @@ pub struct Endpoint {
pub(crate) timeout: Option<Duration>,
pub(crate) concurrency_limit: Option<usize>,
pub(crate) rate_limit: Option<(u64, Duration)>,
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
pub(crate) tls: Option<TlsConnector>,
pub(crate) buffer_size: Option<usize>,
pub(crate) init_stream_window_size: Option<u32>,
Expand All @@ -49,7 +49,7 @@ impl Endpoint {
D::Error: Into<crate::BoxError>,
{
let me = dst.try_into().map_err(|e| Error::from_source(e.into()))?;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
if me.uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
return me.tls_config(ClientTlsConfig::new().with_enabled_roots());
}
Expand Down Expand Up @@ -244,7 +244,7 @@ impl Endpoint {
}

/// Configures TLS for the endpoint.
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, Error> {
Ok(Endpoint {
tls: Some(
Expand Down Expand Up @@ -320,7 +320,7 @@ impl Endpoint {
pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
service::Connector::new(
c,
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
self.tls.clone(),
)
}
Expand Down Expand Up @@ -445,7 +445,7 @@ impl From<Uri> for Endpoint {
concurrency_limit: None,
rate_limit: None,
timeout: None,
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
tls: None,
buffer_size: None,
init_stream_window_size: None,
Expand Down
4 changes: 2 additions & 2 deletions tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
mod endpoint;
pub(crate) mod service;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
mod tls;

pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
pub use tls::ClientTlsConfig;

use self::service::{Connection, DynamicServiceStream, Executor, SharedExec};
Expand Down
24 changes: 12 additions & 12 deletions tonic/src/transport/channel/service/connector.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
use super::BoxedIo;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use super::TlsConnector;
use crate::transport::channel::BoxFuture;
use crate::ConnectError;
use http::Uri;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use std::fmt;
use std::task::{Context, Poll};

use hyper::rt;

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use hyper_util::rt::TokioIo;
use tower_service::Service;

pub(crate) struct Connector<C> {
inner: C,
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
tls: Option<TlsConnector>,
}

impl<C> Connector<C> {
pub(crate) fn new(inner: C, #[cfg(feature = "tls")] tls: Option<TlsConnector>) -> Self {
pub(crate) fn new(inner: C, #[cfg(feature = "_tls-any")] tls: Option<TlsConnector>) -> Self {
Self {
inner,
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
tls,
}
}
Expand All @@ -48,18 +48,18 @@ where
}

fn call(&mut self, uri: Uri) -> Self::Future {
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
let tls = self.tls.clone();

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
let is_https = uri.scheme_str() == Some("https");
let connect = self.inner.call(uri);

Box::pin(async move {
async {
let io = connect.await?;

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
if is_https {
return if let Some(tls) = tls {
let io = tls.connect(TokioIo::new(io)).await?;
Expand All @@ -78,17 +78,17 @@ where
}

/// Error returned when trying to connect to an HTTPS endpoint without TLS enabled.
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
#[derive(Debug)]
pub(crate) struct HttpsUriWithoutTlsSupport(());

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
impl fmt::Display for HttpsUriWithoutTlsSupport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Connecting to HTTPS without TLS enabled")
}
}

// std::error::Error only requires a type to impl Debug and Display
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
impl std::error::Error for HttpsUriWithoutTlsSupport {}
4 changes: 2 additions & 2 deletions tonic/src/transport/channel/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub(crate) use self::connector::Connector;
mod executor;
pub(super) use self::executor::{Executor, SharedExec};

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
mod tls;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
pub(super) use self::tls::TlsConnector;
12 changes: 6 additions & 6 deletions tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub mod server;

mod error;
mod service;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
mod tls;

#[doc(inline)]
Expand All @@ -109,15 +109,15 @@ pub use self::server::Server;
/// Deprecated. Please use [`crate::status::TimeoutExpired`] instead.
pub use crate::status::TimeoutExpired;

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
pub use self::tls::Certificate;
pub use hyper::{body::Body, Uri};
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
pub use tokio_rustls::rustls::pki_types::CertificateDer;

#[cfg(all(feature = "channel", feature = "tls"))]
#[cfg(all(feature = "channel", feature = "_tls-any"))]
pub use self::channel::ClientTlsConfig;
#[cfg(all(feature = "server", feature = "tls"))]
#[cfg(all(feature = "server", feature = "_tls-any"))]
pub use self::server::ServerTlsConfig;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
pub use self::tls::Identity;
12 changes: 6 additions & 6 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::net::SocketAddr;
use tokio::net::TcpStream;

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use std::sync::Arc;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use tokio_rustls::rustls::pki_types::CertificateDer;
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
use tokio_rustls::server::TlsStream;

/// Trait that connected IO resources implement and use to produce info about the connection.
Expand Down Expand Up @@ -102,7 +102,7 @@ impl Connected for tokio::io::DuplexStream {
fn connect_info(&self) -> Self::ConnectInfo {}
}

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
impl<T> Connected for TlsStream<T>
where
T: Connected,
Expand All @@ -128,14 +128,14 @@ where
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
#[derive(Debug, Clone)]
pub struct TlsConnectInfo<T> {
inner: T,
certs: Option<Arc<Vec<CertificateDer<'static>>>>,
}

#[cfg(feature = "tls")]
#[cfg(feature = "_tls-any")]
impl<T> TlsConnectInfo<T> {
/// Get a reference to the underlying connection info.
pub fn get_ref(&self) -> &T {
Expand Down
Loading

0 comments on commit b28b59b

Please sign in to comment.