From e34471a49604b64af986b0f91f627c05e02ecc5e Mon Sep 17 00:00:00 2001 From: Benno Zeeman Date: Fri, 13 Jan 2023 12:08:41 +0100 Subject: [PATCH] feat: implement builder for Endpoint --- Cargo.toml | 1 + src/endpoint.rs | 12 ++- src/endpoint_builder.rs | 201 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + 4 files changed, 213 insertions(+), 3 deletions(-) create mode 100644 src/endpoint_builder.rs diff --git a/Cargo.toml b/Cargo.toml index c6cc462b..6e5cb2f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ tokio = { version = "1.12.0", features = ["sync"] } tracing = "~0.1.26" rustls = { version = "0.20.2", default-features = false, features = ["quic", "dangerous_configuration"] } structopt = {version = "0.3.25", optional = true} +webpki = "0.22.0" [dev-dependencies] color-eyre = "0.5.11" diff --git a/src/endpoint.rs b/src/endpoint.rs index 086c05ee..4c61d9cc 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -8,6 +8,7 @@ // Software. use crate::connection::ConnectionIncoming; +use crate::EndpointBuilder; use super::wire_msg::WireMsg; use super::{ @@ -47,9 +48,9 @@ impl IncomingConnections { /// Endpoint instance which can be used to communicate with peers. #[derive(Clone)] pub struct Endpoint { - inner: quinn::Endpoint, - local_addr: SocketAddr, - public_addr: Option, + pub(crate) inner: quinn::Endpoint, + pub(crate) local_addr: SocketAddr, + pub(crate) public_addr: Option, } impl std::fmt::Debug for Endpoint { @@ -392,6 +393,11 @@ impl Endpoint { }), } } + + /// Builder to create an `Endpoint`. + pub fn builder() -> EndpointBuilder { + EndpointBuilder::default() + } } pub(super) fn listen_for_incoming_connections( diff --git a/src/endpoint_builder.rs b/src/endpoint_builder.rs new file mode 100644 index 00000000..e1562e71 --- /dev/null +++ b/src/endpoint_builder.rs @@ -0,0 +1,201 @@ +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use quinn::{IdleTimeout, TransportConfig, VarInt}; +use tokio::sync::mpsc; + +use crate::{endpoint::listen_for_incoming_connections, Endpoint, IncomingConnections}; + +/// Standard size of our channel bounds +const STANDARD_CHANNEL_SIZE: usize = 10_000; + +/// Build a [`crate::Endpoint`] +#[allow(missing_debug_implementations)] +pub struct EndpointBuilder { + addr: SocketAddr, + max_idle_timeout: Option, + max_concurrent_bidi_streams: VarInt, + max_concurrent_uni_streams: VarInt, + keep_alive_interval: Option, +} + +impl Default for EndpointBuilder { + fn default() -> Self { + Self { + addr: SocketAddr::from(([0, 0, 0, 0], 0)), + max_idle_timeout: Some(IdleTimeout::from(VarInt::from_u32(10_000))), // 10s + max_concurrent_bidi_streams: 100u32.into(), + max_concurrent_uni_streams: 100u32.into(), + keep_alive_interval: None, + } + } +} + +impl EndpointBuilder { + /// Instantiate a builder with default parameters. + /// See source of [`Self::default`] for default values. + pub fn new() -> Self { + Self::default() + } + + /// Local address passed to [`quinn::Endpoint::client`]. + pub fn addr(&mut self, addr: SocketAddr) -> &mut Self { + self.addr = addr; + self + } + + /// Takes time in milliseconds. + /// + /// Maps to [`quinn::TransportConfig::max_idle_timeout`]. + pub fn max_idle_timeout(&mut self, to: Option) -> &mut Self { + self.max_idle_timeout = to.map(|v| IdleTimeout::from(VarInt::from_u32(v))); + self + } + + /// Takes time in milliseconds. + /// + /// Maps to [`quinn::TransportConfig::max_concurrent_bidi_streams`]. + pub fn max_concurrent_bidi_streams(&mut self, max: u32) -> &mut Self { + self.max_concurrent_bidi_streams = VarInt::from_u32(max); + self + } + + /// Takes time in milliseconds. + /// + /// Maps to [`quinn::TransportConfig::max_concurrent_uni_streams`]. + pub fn max_concurrent_uni_streams(&mut self, max: u32) -> &mut Self { + self.max_concurrent_uni_streams = VarInt::from_u32(max); + self + } + + /// Maps to [`quinn::TransportConfig::keep_alive_interval`]. + pub fn keep_alive_interval(&mut self, interval: Option) -> &mut Self { + self.keep_alive_interval = interval; + self + } + + /// Instantiate a server (peer) [`Endpoint`] using the parameters passed to this builder. + pub async fn server(self) -> Result<(Endpoint, IncomingConnections), EndpointError> { + let (cfg_srv, cfg_cli) = self.config()?; + + let mut endpoint = quinn::Endpoint::server(cfg_srv, self.addr)?; + endpoint.set_default_client_config(cfg_cli); + + let (connection_tx, connection_rx) = mpsc::channel(STANDARD_CHANNEL_SIZE); + listen_for_incoming_connections(endpoint.clone(), connection_tx); + + Ok(( + Endpoint { + local_addr: self.addr, + public_addr: None, + inner: endpoint, + }, + IncomingConnections(connection_rx), + )) + } + + /// Instantiate a client (unreachable) [`Endpoint`] using the parameters passed to this builder. + pub async fn client(self) -> Result { + let (_, cfg_cli) = self.config()?; + + let mut endpoint = quinn::Endpoint::client(self.addr)?; + endpoint.set_default_client_config(cfg_cli); + + Ok(Endpoint { + local_addr: self.addr, + public_addr: None, + inner: endpoint, + }) + } + + /// Helper to construct a [`TransportConfig`] from our parameters. + fn transport_config(&self) -> TransportConfig { + let mut config = TransportConfig::default(); + let _ = config.max_idle_timeout(self.max_idle_timeout); + let _ = config.keep_alive_interval(self.keep_alive_interval); + let _ = config.max_concurrent_bidi_streams(self.max_concurrent_bidi_streams); + let _ = config.max_concurrent_uni_streams(self.max_concurrent_uni_streams); + + config + } + + fn config(&self) -> Result<(quinn::ServerConfig, quinn::ClientConfig), EndpointError> { + let transport = Arc::new(self.transport_config()); + + let (mut server, mut client) = config().map_err(EndpointError::Certificate)?; + let _ = server.transport_config(Arc::clone(&transport)); + let _ = client.transport_config(Arc::clone(&transport)); + + Ok((server, client)) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum EndpointError { + #[error("Certificate could not be generated for config")] + Certificate(CertificateError), + + #[error("Failed to bind UDP socket")] + Socket(#[from] std::io::Error), +} + +#[derive(thiserror::Error, Debug)] +pub enum CertificateError { + #[error("Rcgen internal error generating certificate")] + Rcgen(#[from] rcgen::RcgenError), + #[error("Certificate or name validation error")] + WebPki(#[from] webpki::Error), + #[error("Rustls internal error")] + Rustls(#[from] rustls::Error), +} + +// We use a hard-coded server name for self-signed certificates. +pub(crate) const SERVER_NAME: &str = "maidsafe.net"; + +fn config() -> Result<(quinn::ServerConfig, quinn::ClientConfig), CertificateError> { + let mut roots = rustls::RootCertStore::empty(); + let (cert, key) = generate_cert()?; + roots.add(&cert)?; + + let mut client_crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth(); + + // allow client to connect to unknown certificates, eg those generated above + client_crypto + .dangerous() + .set_certificate_verifier(Arc::new(SkipServerVerification)); + + let server = quinn::ServerConfig::with_single_cert(vec![cert], key)?; + let client = quinn::ClientConfig::new(Arc::new(client_crypto)); + + Ok((server, client)) +} + +fn generate_cert() -> Result<(rustls::Certificate, rustls::PrivateKey), rcgen::RcgenError> { + let cert = rcgen::generate_simple_self_signed(vec![SERVER_NAME.to_string()])?; + + let key = cert.serialize_private_key_der(); + let cert = cert.serialize_der().unwrap(); + + let key = rustls::PrivateKey(key); + let cert = rustls::Certificate(cert); + Ok((cert, key)) +} + +// Implementation of `ServerCertVerifier` that verifies everything as trustworthy. +struct SkipServerVerification; + +impl rustls::client::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> Result { + Ok(rustls::client::ServerCertVerified::assertion()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 7337f280..e334572f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,6 +36,7 @@ pub mod config; mod connection; mod endpoint; +mod endpoint_builder; mod error; mod utils; mod wire_msg; @@ -43,6 +44,7 @@ mod wire_msg; pub use config::{Config, ConfigError}; pub use connection::{Connection, ConnectionIncoming, RecvStream, SendStream}; pub use endpoint::{Endpoint, IncomingConnections}; +pub use endpoint_builder::EndpointBuilder; pub use error::{ ClientEndpointError, Close, ConnectionError, EndpointError, InternalConfigError, RecvError, RpcError, SendError, StreamError, TransportErrorCode, UnsupportedStreamOperation,