Skip to content

Commit

Permalink
feat(channel): Make channel feature additive
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Apr 27, 2024
1 parent 068421a commit faec3d5
Show file tree
Hide file tree
Showing 20 changed files with 255 additions and 204 deletions.
4 changes: 2 additions & 2 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,13 @@ hyper-warp-multiplex = ["hyper-warp"]
uds = ["tokio-stream/net", "dep:tower", "dep:hyper"]
streaming = ["tokio-stream", "dep:h2"]
mock = ["tokio-stream", "dep:tower"]
tower = ["dep:hyper", "dep:tower", "dep:http"]
tower = ["dep:hyper", "tower/timeout", "dep:http"]
json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"]
compression = ["tonic/gzip"]
tls = ["tonic/tls"]
tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"]
dynamic-load-balance = ["dep:tower"]
timeout = ["tokio/time", "dep:tower"]
timeout = ["tokio/time", "tower/timeout"]
tls-client-auth = ["tonic/tls"]
types = ["dep:tonic-types"]
h2c = ["dep:hyper", "dep:tower", "dep:http"]
Expand Down
23 changes: 14 additions & 9 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,26 @@ version = "0.11.0"
codegen = ["dep:async-trait"]
gzip = ["dep:flate2"]
zstd = ["dep:zstd"]
default = ["transport", "codegen", "prost"]
default = ["channel", "codegen", "prost"]
prost = ["dep:prost"]
tls = ["dep:rustls-pki-types", "dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
tls-roots-common = ["tls"]
tls-roots-common = ["tls", "channel"]
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
transport = [
"dep:async-stream",
"dep:axum",
"channel",
"dep:h2",
"dep:hyper",
"dep:hyper", "hyper?/server",
"dep:tokio", "tokio?/net", "tokio?/time",
"dep:tower",
"dep:tower", "tower?/util", "tower?/limit",
]
channel = [
"transport",
"dep:hyper", "hyper?/client",
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
"dep:hyper-timeout",
]
channel = []

# [[bench]]
# name = "bench_main"
Expand All @@ -68,13 +71,15 @@ async-trait = {version = "0.1.13", optional = true}

# transport
h2 = {version = "0.3.24", optional = true}
hyper = {version = "0.14.26", features = ["full"], optional = true}
hyper-timeout = {version = "0.4", optional = true}
hyper = {version = "0.14.26", features = ["http1", "http2", "runtime", "stream"], optional = true}
tokio = {version = "1.0.1", optional = true}
tokio-stream = "0.1"
tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true}
tower = {version = "0.4.7", default-features = false, optional = true}
axum = {version = "0.6.9", default_features = false, optional = true}

# channel
hyper-timeout = {version = "0.4", optional = true}

# rustls
async-stream = { version = "0.3", optional = true }
rustls-pki-types = { version = "1.0", optional = true }
Expand Down
17 changes: 9 additions & 8 deletions tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::super::service;
use super::service::Connector;
#[cfg(feature = "tls")]
use super::service::TlsConnector;
use super::service::{Executor, SharedExec};
use super::Channel;
#[cfg(feature = "tls")]
use super::ClientTlsConfig;
#[cfg(feature = "tls")]
use crate::transport::service::TlsConnector;
use crate::transport::{service::SharedExec, Error, Executor};
use crate::transport::Error;
use bytes::Bytes;
use http::{uri::Uri, HeaderValue};
use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration};
Expand Down Expand Up @@ -318,15 +319,15 @@ impl Endpoint {
self
}

pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
pub(crate) fn connector<C>(&self, c: C) -> Connector<C> {
#[cfg(all(feature = "tls", not(feature = "tls-roots-common")))]
let connector = service::Connector::new(c, self.tls.clone());
let connector = Connector::new(c, self.tls.clone());

#[cfg(all(feature = "tls", feature = "tls-roots-common"))]
let connector = service::Connector::new(c, self.tls.clone(), self.tls_assume_http2);
let connector = Connector::new(c, self.tls.clone(), self.tls_assume_http2);

#[cfg(not(feature = "tls"))]
let connector = service::Connector::new(c);
let connector = Connector::new(c);

connector
}
Expand Down
3 changes: 2 additions & 1 deletion tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Client implementation and builder.
mod endpoint;
pub(crate) mod service;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
Expand All @@ -9,7 +10,7 @@ pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
pub use tls::ClientTlsConfig;

use super::service::{Connection, DynamicServiceStream, SharedExec};
use self::service::{Connection, DynamicServiceStream, SharedExec};
use crate::body::BoxBody;
use crate::transport::Executor;
use bytes::Bytes;
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
use super::{reconnect::Reconnect, AddOrigin, UserAgent};
use crate::transport::service::GrpcTimeout;
use crate::{
body::BoxBody,
transport::{BoxFuture, Endpoint},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use super::super::BoxFuture;
use super::io::BoxedIo;
#[cfg(feature = "tls")]
use super::tls::TlsConnector;
use http::Uri;
use std::fmt;
use std::task::{Context, Poll};
use tower::make::MakeConnection;
use tower_service::Service;

use super::io::BoxedIo;
#[cfg(feature = "tls")]
use super::tls::TlsConnector;
use crate::transport::BoxFuture;

pub(crate) struct Connector<C> {
inner: C,
#[cfg(feature = "tls")]
Expand Down
File renamed without changes.
File renamed without changes.
69 changes: 69 additions & 0 deletions tonic/src/transport/channel/service/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};

use hyper::client::connect::{Connected as HyperConnected, Connection};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pub(in crate::transport) trait Io:
AsyncRead + AsyncWrite + Send + 'static
{
}

impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {}

pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);

impl BoxedIo {
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
BoxedIo(Box::pin(io))
}
}

impl Connection for BoxedIo {
fn connected(&self) -> HyperConnected {
HyperConnected::new()
}
}

#[cfg(feature = "channel")]
impl AsyncRead for BoxedIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

#[cfg(feature = "channel")]
impl AsyncWrite for BoxedIo {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.0.is_write_vectored()
}
}
26 changes: 26 additions & 0 deletions tonic/src/transport/channel/service/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
mod add_origin;
pub(crate) use self::add_origin::AddOrigin;

mod connector;
pub(crate) use self::connector::Connector;

mod connection;
pub(crate) use self::connection::Connection;

mod discover;
pub(crate) use self::discover::DynamicServiceStream;

pub(crate) mod executor;
pub(crate) use self::executor::{Executor, SharedExec};

pub(crate) mod io;

mod reconnect;

mod user_agent;
pub(crate) use self::user_agent::UserAgent;

#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
pub(crate) use self::tls::TlsConnector;
File renamed without changes.
100 changes: 100 additions & 0 deletions tonic/src/transport/channel/service/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use std::fmt;
use std::io::Cursor;
use std::sync::Arc;

use rustls_pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::{rustls::ClientConfig, TlsConnector as RustlsConnector};

use super::io::BoxedIo;
use crate::transport::service::tls::{add_certs_from_pem, load_identity, ALPN_H2};
use crate::transport::tls::{Certificate, Identity};

#[derive(Debug)]
enum TlsError {
H2NotNegotiated,
}

impl fmt::Display for TlsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
}
}
}

impl std::error::Error for TlsError {}

#[derive(Clone)]
pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<ServerName<'static>>,
assume_http2: bool,
}

impl TlsConnector {
pub(crate) fn new(
ca_cert: Option<Certificate>,
identity: Option<Identity>,
domain: &str,
assume_http2: bool,
) -> Result<Self, crate::Error> {
let builder = ClientConfig::builder();
let mut roots = RootCertStore::empty();

#[cfg(feature = "tls-roots")]
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);

#[cfg(feature = "tls-webpki-roots")]
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

if let Some(cert) = ca_cert {
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
}

let builder = builder.with_root_certificates(roots);
let mut config = match identity {
Some(identity) => {
let (client_cert, client_key) = load_identity(identity)?;
builder.with_client_auth_cert(client_cert, client_key)?
}
None => builder.with_no_client_auth(),
};

config.alpn_protocols.push(ALPN_H2.into());
Ok(Self {
config: Arc::new(config),
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
assume_http2,
})
}

pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let io = RustlsConnector::from(self.config.clone())
.connect(self.domain.as_ref().to_owned(), io)
.await?;

// Generally we require ALPN to be negotiated, but if the user has
// explicitly set `assume_http2` to true, we'll allow it to be missing.
let (_, session) = io.get_ref();
let alpn_protocol = session.alpn_protocol();
if alpn_protocol != Some(ALPN_H2) {
if alpn_protocol.is_some() || !self.assume_http2 {
return Err(TlsError::H2NotNegotiated.into());
}
}

Ok(BoxedIo::new(io))
}
}

#[cfg(feature = "channel")]
impl fmt::Debug for TlsConnector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConnector").finish()
}
}
File renamed without changes.
2 changes: 1 addition & 1 deletion tonic/src/transport/channel/tls.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::service::TlsConnector;
use crate::transport::{
service::TlsConnector,
tls::{Certificate, Identity},
Error,
};
Expand Down
6 changes: 6 additions & 0 deletions tonic/src/transport/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ struct ErrorImpl {
#[derive(Debug)]
pub(crate) enum Kind {
Transport,
#[cfg(feature = "channel")]
InvalidUri,
#[cfg(feature = "channel")]
InvalidUserAgent,
}

Expand All @@ -35,18 +37,22 @@ impl Error {
Error::new(Kind::Transport).with(source)
}

#[cfg(feature = "channel")]
pub(crate) fn new_invalid_uri() -> Self {
Error::new(Kind::InvalidUri)
}

#[cfg(feature = "channel")]
pub(crate) fn new_invalid_user_agent() -> Self {
Error::new(Kind::InvalidUserAgent)
}

fn description(&self) -> &str {
match &self.inner.kind {
Kind::Transport => "transport error",
#[cfg(feature = "channel")]
Kind::InvalidUri => "invalid URI",
#[cfg(feature = "channel")]
Kind::InvalidUserAgent => "user agent is not a valid header value",
}
}
Expand Down
Loading

0 comments on commit faec3d5

Please sign in to comment.