From a2ed158c3cbc473e99c45f29a0422c2f2a6bb98f Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 8 Sep 2023 10:57:50 +0200 Subject: [PATCH 1/2] Clean up library imports --- src/client.rs | 10 ++++++++-- src/common/handshake.rs | 6 ++++-- src/common/mod.rs | 9 +++++---- src/common/test_stream.rs | 10 ++++++---- src/lib.rs | 30 +++++++++++++++--------------- src/server.rs | 9 +++++++-- 6 files changed, 45 insertions(+), 29 deletions(-) diff --git a/src/client.rs b/src/client.rs index f8d8d07d..f03448fe 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,9 +1,15 @@ -use super::*; -use crate::common::IoSession; +use std::io; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, RawSocket}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use rustls::ClientConnection; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::common::{IoSession, Stream, TlsState}; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. diff --git a/src/common/handshake.rs b/src/common/handshake.rs index 67252778..ac78165a 100644 --- a/src/common/handshake.rs +++ b/src/common/handshake.rs @@ -1,12 +1,14 @@ -use crate::common::{Stream, TlsState}; -use rustls::{ConnectionCommon, SideData}; use std::future::Future; use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, mem}; + +use rustls::{ConnectionCommon, SideData}; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::common::{Stream, TlsState}; + pub(crate) trait IoSession { type Io; type Session; diff --git a/src/common/mod.rs b/src/common/mod.rs index fde34c0d..664f8f93 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,13 +1,14 @@ -mod handshake; - -pub(crate) use handshake::{IoSession, MidHandshake}; -use rustls::{ConnectionCommon, SideData}; use std::io::{self, IoSlice, Read, Write}; use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::task::{Context, Poll}; + +use rustls::{ConnectionCommon, SideData}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +mod handshake; +pub(crate) use handshake::{IoSession, MidHandshake}; + #[derive(Debug)] pub enum TlsState { #[cfg(feature = "early-data")] diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index b643ed24..d58eabc0 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -1,12 +1,14 @@ -use super::Stream; -use futures_util::future::poll_fn; -use futures_util::task::noop_waker_ref; -use rustls::{ClientConnection, Connection, ServerConnection}; use std::io::{self, Cursor, Read, Write}; use std::pin::Pin; use std::task::{Context, Poll}; + +use futures_util::future::poll_fn; +use futures_util::task::noop_waker_ref; +use rustls::{ClientConnection, Connection, ServerConnection}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use super::Stream; + struct Good<'a>(&'a mut Connection); impl<'a> AsyncRead for Good<'a> { diff --git a/src/lib.rs b/src/lib.rs index 000245ce..e39ee353 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,6 +36,20 @@ //! //! see +use std::future::Future; +use std::io; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +pub use rustls; +use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + macro_rules! ready { ( $e:expr ) => { match $e { @@ -47,23 +61,9 @@ macro_rules! ready { pub mod client; mod common; +use common::{MidHandshake, TlsState}; pub mod server; -use common::{MidHandshake, Stream, TlsState}; -use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; -use std::future::Future; -use std::io; -#[cfg(unix)] -use std::os::unix::io::{AsRawFd, RawFd}; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, RawSocket}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -pub use rustls; - /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] pub struct TlsConnector { diff --git a/src/server.rs b/src/server.rs index f39f80f7..9444a625 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,10 +1,15 @@ +use std::io; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, RawSocket}; +use std::pin::Pin; +use std::task::{Context, Poll}; -use super::*; -use crate::common::IoSession; +use rustls::ServerConnection; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::common::{IoSession, Stream, TlsState}; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. From 3e97391da87111031f65defec8621e8324f8f56c Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 8 Sep 2023 11:12:48 +0200 Subject: [PATCH 2/2] Update to rustls 0.22 alpha --- Cargo.toml | 8 ++++---- examples/client.rs | 22 ++++------------------ examples/server.rs | 19 +++++++++---------- src/lib.rs | 20 +++++++++++--------- tests/badssl.rs | 21 +++++---------------- tests/early-data.rs | 25 ++++++++----------------- tests/test.rs | 45 +++++++++++++++++++-------------------------- tests/utils.rs | 29 ++++++++++++----------------- 8 files changed, 72 insertions(+), 117 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ca5b7170..ebc3ec63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ exclude = ["/.github", "/examples", "/scripts"] [dependencies] tokio = "1.0" -rustls = { version = "0.21.6", default-features = false } +rustls = { version = "=0.22.0-alpha.2", default-features = false } [features] default = ["logging", "tls12"] @@ -29,6 +29,6 @@ argh = "0.1" tokio = { version = "1.0", features = ["full"] } futures-util = "0.3.1" lazy_static = "1" -webpki-roots = "0.25" -rustls-pemfile = "1" -webpki = { package = "rustls-webpki", version = "0.101.2", features = ["alloc", "std"] } +webpki-roots = "=0.26.0-alpha.1" +rustls-pemfile = "=2.0.0-alpha.1" +webpki = { package = "rustls-webpki", version = "=0.102.0-alpha.2", features = ["alloc", "std"] } diff --git a/examples/client.rs b/examples/client.rs index 3b6eb0f6..193f0133 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,7 +8,6 @@ use std::sync::Arc; use argh::FromArgs; use tokio::io::{copy, split, stdin as tokio_stdin, stdout as tokio_stdout, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio_rustls::rustls::{self, OwnedTrustAnchor}; use tokio_rustls::TlsConnector; /// Tokio Rustls client example @@ -45,24 +44,11 @@ async fn main() -> io::Result<()> { let mut root_cert_store = rustls::RootCertStore::empty(); if let Some(cafile) = &options.cafile { let mut pem = BufReader::new(File::open(cafile)?); - let certs = rustls_pemfile::certs(&mut pem)?; - let trust_anchors = certs.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }); - root_cert_store.add_trust_anchors(trust_anchors); + for cert in rustls_pemfile::certs(&mut pem) { + root_cert_store.add(cert?).unwrap(); + } } else { - root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); } let config = rustls::ClientConfig::builder() diff --git a/examples/server.rs b/examples/server.rs index 94dcc405..b81b76c3 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -8,8 +8,8 @@ use argh::FromArgs; use rustls_pemfile::{certs, rsa_private_keys}; use tokio::io::{copy, sink, split, AsyncWriteExt}; use tokio::net::TcpListener; -use tokio_rustls::rustls::{self, Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; +use webpki::types::{CertificateDer, PrivateKeyDer}; /// Tokio Rustls server example #[derive(FromArgs)] @@ -31,16 +31,15 @@ struct Options { echo_mode: bool, } -fn load_certs(path: &Path) -> io::Result> { - certs(&mut BufReader::new(File::open(path)?)) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) - .map(|mut certs| certs.drain(..).map(Certificate).collect()) +fn load_certs(path: &Path) -> io::Result>> { + certs(&mut BufReader::new(File::open(path)?)).collect() } -fn load_keys(path: &Path) -> io::Result> { +fn load_keys(path: &Path) -> io::Result> { rsa_private_keys(&mut BufReader::new(File::open(path)?)) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) - .map(|mut keys| keys.drain(..).map(PrivateKey).collect()) + .next() + .unwrap() + .map(Into::into) } #[tokio::main] @@ -53,13 +52,13 @@ async fn main() -> io::Result<()> { .next() .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?; let certs = load_certs(&options.cert)?; - let mut keys = load_keys(&options.key)?; + let key = load_keys(&options.key)?; let flag_echo = options.echo_mode; let config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, keys.remove(0)) + .with_single_cert(certs, key) .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; let acceptor = TlsAcceptor::from(Arc::new(config)); diff --git a/src/lib.rs b/src/lib.rs index e39ee353..3045d927 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; pub use rustls; +use rustls::crypto::ring::Ring; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -67,7 +68,7 @@ pub mod server; /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] pub struct TlsConnector { - inner: Arc, + inner: Arc>, #[cfg(feature = "early-data")] early_data: bool, } @@ -75,11 +76,11 @@ pub struct TlsConnector { /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { - inner: Arc, + inner: Arc>, } -impl From> for TlsConnector { - fn from(inner: Arc) -> TlsConnector { +impl From>> for TlsConnector { + fn from(inner: Arc>) -> TlsConnector { TlsConnector { inner, #[cfg(feature = "early-data")] @@ -88,8 +89,8 @@ impl From> for TlsConnector { } } -impl From> for TlsAcceptor { - fn from(inner: Arc) -> TlsAcceptor { +impl From>> for TlsAcceptor { + fn from(inner: Arc>) -> TlsAcceptor { TlsAcceptor { inner } } } @@ -210,9 +211,10 @@ where /// # Example /// /// ```no_run + /// # use rustls::crypto::ring::Ring; /// # fn choose_server_config( /// # _: rustls::server::ClientHello, - /// # ) -> std::sync::Arc { + /// # ) -> std::sync::Arc> { /// # unimplemented!(); /// # } /// # #[allow(unused_variables)] @@ -304,11 +306,11 @@ where self.accepted.client_hello() } - pub fn into_stream(self, config: Arc) -> Accept { + pub fn into_stream(self, config: Arc>) -> Accept { self.into_stream_with(config, |_| ()) } - pub fn into_stream_with(self, config: Arc, f: F) -> Accept + pub fn into_stream_with(self, config: Arc>, f: F) -> Accept where F: FnOnce(&mut ServerConnection), { diff --git a/tests/badssl.rs b/tests/badssl.rs index a130b6c8..21afdefa 100644 --- a/tests/badssl.rs +++ b/tests/badssl.rs @@ -2,16 +2,17 @@ use std::io; use std::net::ToSocketAddrs; use std::sync::Arc; +use rustls::crypto::ring::Ring; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio_rustls::{ client::TlsStream, - rustls::{self, ClientConfig, OwnedTrustAnchor}, + rustls::{self, ClientConfig}, TlsConnector, }; async fn get( - config: Arc, + config: Arc>, domain: &str, port: u16, ) -> io::Result<(TlsStream, String)> { @@ -34,13 +35,7 @@ async fn get( #[tokio::test] async fn test_tls12() -> io::Result<()> { let mut root_store = rustls::RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let config = rustls::ClientConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() @@ -72,13 +67,7 @@ fn test_tls13() { #[tokio::test] async fn test_modern() -> io::Result<()> { let mut root_store = rustls::RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let config = rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) diff --git a/tests/early-data.rs b/tests/early-data.rs index 9d83bcca..dc44e29d 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -10,16 +10,13 @@ use std::thread; use std::time::Duration; use futures_util::{future, future::Future, ready}; -use rustls::RootCertStore; +use rustls::crypto::ring::Ring; +use rustls::{self, ClientConfig, RootCertStore}; use tokio::io::{split, AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; use tokio::sync::oneshot; use tokio::time::sleep; -use tokio_rustls::{ - client::TlsStream, - rustls::{self, ClientConfig, OwnedTrustAnchor}, - TlsConnector, -}; +use tokio_rustls::{client::TlsStream, TlsConnector}; struct Read1(T); @@ -42,7 +39,7 @@ impl Future for Read1 { } async fn send( - config: Arc, + config: Arc>, addr: SocketAddr, data: &[u8], ) -> io::Result> { @@ -132,17 +129,11 @@ async fn test_0rtt() -> io::Result<()> { wait_for_server(format!("127.0.0.1:{}", server_port).as_str()).await; let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); - let certs = rustls_pemfile::certs(&mut chain).unwrap(); - let trust_anchors = certs.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }); let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(trust_anchors); + for cert in rustls_pemfile::certs(&mut chain) { + root_store.add(cert.unwrap()).unwrap(); + } + let mut config = rustls::ClientConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() diff --git a/tests/test.rs b/tests/test.rs index 9f83db8d..8f9699b8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -7,7 +7,8 @@ use std::{io, thread}; use futures_util::future::TryFutureExt; use lazy_static::lazy_static; -use rustls::{ClientConfig, OwnedTrustAnchor}; +use rustls::crypto::ring::Ring; +use rustls::ClientConfig; use rustls_pemfile::{certs, rsa_private_keys}; use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; @@ -22,17 +23,17 @@ const RSA: &str = include_str!("end.rsa"); lazy_static! { static ref TEST_SERVER: (SocketAddr, &'static str, &'static [u8]) = { let cert = certs(&mut BufReader::new(Cursor::new(CERT))) - .unwrap() - .drain(..) - .map(rustls::Certificate) + .map(|result| result.unwrap()) .collect(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut keys = keys.drain(..).map(rustls::PrivateKey); + let key = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))) + .next() + .unwrap() + .unwrap(); let config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_single_cert(cert, keys.next().unwrap()) + .with_single_cert(cert, key.into()) .unwrap(); let acceptor = TlsAcceptor::from(Arc::new(config)); @@ -83,7 +84,11 @@ fn start_server() -> &'static (SocketAddr, &'static str, &'static [u8]) { &TEST_SERVER } -async fn start_client(addr: SocketAddr, domain: &str, config: Arc) -> io::Result<()> { +async fn start_client( + addr: SocketAddr, + domain: &str, + config: Arc>, +) -> io::Result<()> { const FILE: &[u8] = include_bytes!("../README.md"); let domain = rustls::ServerName::try_from(domain).unwrap(); @@ -111,16 +116,10 @@ async fn pass() -> io::Result<()> { use std::time::*; tokio::time::sleep(Duration::from_secs(1)).await; - let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap(); let mut root_store = rustls::RootCertStore::empty(); - root_store.add_trust_anchors(chain.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + for cert in certs(&mut std::io::Cursor::new(*chain)) { + root_store.add(cert.unwrap()).unwrap(); + } let config = rustls::ClientConfig::builder() .with_safe_defaults() @@ -137,16 +136,10 @@ async fn pass() -> io::Result<()> { async fn fail() -> io::Result<()> { let (addr, domain, chain) = start_server(); - let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap(); let mut root_store = rustls::RootCertStore::empty(); - root_store.add_trust_anchors(chain.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + for cert in certs(&mut std::io::Cursor::new(*chain)) { + root_store.add(cert.unwrap()).unwrap(); + } let config = rustls::ClientConfig::builder() .with_safe_defaults() diff --git a/tests/utils.rs b/tests/utils.rs index 8e8237f9..ae916f03 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -2,39 +2,34 @@ mod utils { use std::io::{BufReader, Cursor}; use std::sync::Arc; - use rustls::{ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerConfig}; + use rustls::crypto::ring::Ring; + use rustls::{ClientConfig, RootCertStore, ServerConfig}; use rustls_pemfile::{certs, rsa_private_keys}; #[allow(dead_code)] - pub fn make_configs() -> (Arc, Arc) { + pub fn make_configs() -> (Arc>, Arc>) { const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); let cert = certs(&mut BufReader::new(Cursor::new(CERT))) - .unwrap() - .drain(..) - .map(rustls::Certificate) + .map(|result| result.unwrap()) .collect(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut keys = keys.drain(..).map(PrivateKey); + let key = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))) + .next() + .unwrap() + .unwrap(); let sconfig = ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_single_cert(cert, keys.next().unwrap()) + .with_single_cert(cert, key.into()) .unwrap(); let mut client_root_cert_store = RootCertStore::empty(); let mut chain = BufReader::new(Cursor::new(CHAIN)); - let certs = certs(&mut chain).unwrap(); - client_root_cert_store.add_trust_anchors(certs.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + for cert in certs(&mut chain) { + client_root_cert_store.add(cert.unwrap()).unwrap(); + } let cconfig = ClientConfig::builder() .with_safe_defaults()