Skip to content

Commit

Permalink
Support accept_invalid_certs & accept_invalid_hostnames with rustls
Browse files Browse the repository at this point in the history
Co-authored-by: BlackHoleFox <[email protected]>
  • Loading branch information
jplatte and BlackHoleFox committed Oct 22, 2020
1 parent 20ca367 commit 4048d2d
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 27 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ _rt-actix = []
_rt-async-std = []
_rt-tokio = []
_tls-native-tls = []
_tls-rustls = [ "rustls", "webpki" ]
_tls-rustls = [ "rustls", "webpki", "webpki-roots" ]

# support offline/decoupled building (enables serialization of `Describe`)
offline = [ "serde", "either/serde" ]
Expand Down Expand Up @@ -91,7 +91,7 @@ parking_lot = "0.11.0"
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
regex = { version = "1.3.9", optional = true }
rsa = { version = "0.3.0", optional = true }
rustls = { version = "0.18.1", optional = true }
rustls = { version = "0.18.1", features = [ "dangerous_configuration" ], optional = true }
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
sha-1 = { version = "0.9.0", default-features = false, optional = true }
Expand All @@ -103,6 +103,7 @@ smallvec = "1.4.0"
url = { version = "2.1.1", default-features = false }
uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
webpki = { version = "0.21.3", optional = true }
webpki-roots = { version = "0.20.0", optional = true }
whoami = "0.9.0"
stringprep = "0.1.2"
lru-cache = "0.1.2"
34 changes: 9 additions & 25 deletions sqlx-core/src/net/tls.rs → sqlx-core/src/net/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};

use sqlx_rt::{fs, AsyncRead, AsyncWrite, TlsStream};
use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};

use crate::error::Error;
use std::mem::replace;

#[cfg(feature = "_tls-rustls")]
mod rustls;

pub enum MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
Expand Down Expand Up @@ -73,7 +76,10 @@ async fn configure_tls_connector(
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
) -> Result<sqlx_rt::TlsConnector, Error> {
use sqlx_rt::native_tls::{Certificate, TlsConnector};
use sqlx_rt::{
fs,
native_tls::{Certificate, TlsConnector},
};

let mut builder = TlsConnector::builder();
builder
Expand All @@ -99,29 +105,7 @@ async fn configure_tls_connector(
}

#[cfg(feature = "_tls-rustls")]
async fn configure_tls_connector(
_accept_invalid_certs: bool,
_accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
) -> Result<sqlx_rt::TlsConnector, Error> {
// FIXME: Support accept_invalid_certs / accept_invalid_hostnames

use rustls::ClientConfig;
use std::io::Cursor;
use std::sync::Arc;

let mut config = ClientConfig::new();

if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let mut cursor = Cursor::new(data);
config.root_store.add_pem_file(&mut cursor).map_err(|_| {
Error::Tls(format!("Invalid certificate file: {}", ca.display()).into())
})?;
}

Ok(Arc::new(config).into())
}
use self::rustls::configure_tls_connector;

impl<S> AsyncRead for MaybeTlsStream<S>
where
Expand Down
78 changes: 78 additions & 0 deletions sqlx-core/src/net/tls/rustls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use rustls::{
Certificate, ClientConfig, RootCertStore, ServerCertVerified, ServerCertVerifier, TLSError,
WebPKIVerifier,
};
use sqlx_rt::fs;
use std::sync::Arc;
use std::{io::Cursor, path::Path};
use webpki::DNSNameRef;

use crate::error::Error;

pub async fn configure_tls_connector(
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert_path: Option<&Path>,
) -> Result<sqlx_rt::TlsConnector, Error> {
let mut config = ClientConfig::new();

if accept_invalid_certs {
config
.dangerous()
.set_certificate_verifier(Arc::new(DummyTlsVerifier));
} else {
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);

if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let mut cursor = Cursor::new(data);
config.root_store.add_pem_file(&mut cursor).map_err(|_| {
Error::Tls(format!("Invalid certificate file: {}", ca.display()).into())
})?;
}

if accept_invalid_hostnames {
config
.dangerous()
.set_certificate_verifier(Arc::new(NoHostnameTlsVerifier));
}
}

Ok(Arc::new(config).into())
}

struct DummyTlsVerifier;

impl ServerCertVerifier for DummyTlsVerifier {
fn verify_server_cert(
&self,
_roots: &RootCertStore,
_presented_certs: &[Certificate],
_dns_name: DNSNameRef<'_>,
_ocsp_response: &[u8],
) -> Result<ServerCertVerified, TLSError> {
Ok(ServerCertVerified::assertion())
}
}

pub struct NoHostnameTlsVerifier;

impl ServerCertVerifier for NoHostnameTlsVerifier {
fn verify_server_cert(
&self,
roots: &RootCertStore,
presented_certs: &[Certificate],
dns_name: DNSNameRef<'_>,
ocsp_response: &[u8],
) -> Result<ServerCertVerified, TLSError> {
let verifier = WebPKIVerifier::new();
match verifier.verify_server_cert(roots, presented_certs, dns_name, ocsp_response) {
Err(TLSError::WebPKIError(webpki::Error::CertNotValidForName)) => {
Ok(ServerCertVerified::assertion())
}
res => res,
}
}
}

0 comments on commit 4048d2d

Please sign in to comment.