diff --git a/Cargo.lock b/Cargo.lock index 2955621342..d9c95c54c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2179,6 +2179,7 @@ dependencies = [ "url", "uuid", "webpki", + "webpki-roots", "whoami", ] @@ -2857,6 +2858,15 @@ dependencies = [ "untrusted", ] +[[package]] +name = "webpki-roots" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f20dea7535251981a9670857150d571846545088359b28e4951d350bdaf179f" +dependencies = [ + "webpki", +] + [[package]] name = "wepoll-sys" version = "3.0.1" diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index dab47b00a8..d135c2c565 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -47,7 +47,7 @@ _rt-actix = [] _rt-async-std = [] _rt-tokio = [] _tls-native-tls = [] -_tls-rustls = [ "rustls", "webpki" ] +_tls-rustls = [ "rustls", "webpki", "webpki-roots" ] # support offline/decoupled building (enables serialization of `Describe`) offline = [ "serde", "either/serde" ] @@ -91,7 +91,7 @@ parking_lot = "0.11.0" rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] } regex = { version = "1.3.9", optional = true } rsa = { version = "0.3.0", optional = true } -rustls = { version = "0.18.1", optional = true } +rustls = { version = "0.18.1", features = [ "dangerous_configuration" ], optional = true } serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true } serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true } sha-1 = { version = "0.9.0", default-features = false, optional = true } @@ -103,6 +103,7 @@ smallvec = "1.4.0" url = { version = "2.1.1", default-features = false } uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] } webpki = { version = "0.21.3", optional = true } +webpki-roots = { version = "0.20.0", optional = true } whoami = "0.9.0" stringprep = "0.1.2" lru-cache = "0.1.2" diff --git a/sqlx-core/src/net/tls.rs b/sqlx-core/src/net/tls/mod.rs similarity index 89% rename from sqlx-core/src/net/tls.rs rename to sqlx-core/src/net/tls/mod.rs index f3e0806b41..4fb1cfb8a7 100644 --- a/sqlx-core/src/net/tls.rs +++ b/sqlx-core/src/net/tls/mod.rs @@ -6,11 +6,14 @@ use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; -use sqlx_rt::{fs, AsyncRead, AsyncWrite, TlsStream}; +use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream}; use crate::error::Error; use std::mem::replace; +#[cfg(feature = "_tls-rustls")] +mod rustls; + pub enum MaybeTlsStream where S: AsyncRead + AsyncWrite + Unpin, @@ -73,7 +76,10 @@ async fn configure_tls_connector( accept_invalid_hostnames: bool, root_cert_path: Option<&Path>, ) -> Result { - use sqlx_rt::native_tls::{Certificate, TlsConnector}; + use sqlx_rt::{ + fs, + native_tls::{Certificate, TlsConnector}, + }; let mut builder = TlsConnector::builder(); builder @@ -99,29 +105,7 @@ async fn configure_tls_connector( } #[cfg(feature = "_tls-rustls")] -async fn configure_tls_connector( - _accept_invalid_certs: bool, - _accept_invalid_hostnames: bool, - root_cert_path: Option<&Path>, -) -> Result { - // FIXME: Support accept_invalid_certs / accept_invalid_hostnames - - use rustls::ClientConfig; - use std::io::Cursor; - use std::sync::Arc; - - let mut config = ClientConfig::new(); - - if let Some(ca) = root_cert_path { - let data = fs::read(ca).await?; - let mut cursor = Cursor::new(data); - config.root_store.add_pem_file(&mut cursor).map_err(|_| { - Error::Tls(format!("Invalid certificate file: {}", ca.display()).into()) - })?; - } - - Ok(Arc::new(config).into()) -} +use self::rustls::configure_tls_connector; impl AsyncRead for MaybeTlsStream where diff --git a/sqlx-core/src/net/tls/rustls.rs b/sqlx-core/src/net/tls/rustls.rs new file mode 100644 index 0000000000..fcedff754f --- /dev/null +++ b/sqlx-core/src/net/tls/rustls.rs @@ -0,0 +1,78 @@ +use rustls::{ + Certificate, ClientConfig, RootCertStore, ServerCertVerified, ServerCertVerifier, TLSError, + WebPKIVerifier, +}; +use sqlx_rt::fs; +use std::sync::Arc; +use std::{io::Cursor, path::Path}; +use webpki::DNSNameRef; + +use crate::error::Error; + +pub async fn configure_tls_connector( + accept_invalid_certs: bool, + accept_invalid_hostnames: bool, + root_cert_path: Option<&Path>, +) -> Result { + let mut config = ClientConfig::new(); + + if accept_invalid_certs { + config + .dangerous() + .set_certificate_verifier(Arc::new(DummyTlsVerifier)); + } else { + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + + if let Some(ca) = root_cert_path { + let data = fs::read(ca).await?; + let mut cursor = Cursor::new(data); + config.root_store.add_pem_file(&mut cursor).map_err(|_| { + Error::Tls(format!("Invalid certificate file: {}", ca.display()).into()) + })?; + } + + if accept_invalid_hostnames { + config + .dangerous() + .set_certificate_verifier(Arc::new(NoHostnameTlsVerifier)); + } + } + + Ok(Arc::new(config).into()) +} + +struct DummyTlsVerifier; + +impl ServerCertVerifier for DummyTlsVerifier { + fn verify_server_cert( + &self, + _roots: &RootCertStore, + _presented_certs: &[Certificate], + _dns_name: DNSNameRef<'_>, + _ocsp_response: &[u8], + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +} + +pub struct NoHostnameTlsVerifier; + +impl ServerCertVerifier for NoHostnameTlsVerifier { + fn verify_server_cert( + &self, + roots: &RootCertStore, + presented_certs: &[Certificate], + dns_name: DNSNameRef<'_>, + ocsp_response: &[u8], + ) -> Result { + let verifier = WebPKIVerifier::new(); + match verifier.verify_server_cert(roots, presented_certs, dns_name, ocsp_response) { + Err(TLSError::WebPKIError(webpki::Error::CertNotValidForName)) => { + Ok(ServerCertVerified::assertion()) + } + res => res, + } + } +}