Skip to content

Commit

Permalink
Refactor TLS code to be a bit easier to read
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Oct 20, 2020
1 parent 2eb7c06 commit 70a22e2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
7 changes: 4 additions & 3 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ decimal = [ "rust_decimal", "num-bigint" ]
json = [ "serde", "serde_json" ]

# runtimes
runtime-actix-native-tls = [ "sqlx-rt/runtime-actix-native-tls", "_rt-actix" ]
runtime-async-std-native-tls = [ "sqlx-rt/runtime-async-std-native-tls", "_rt-async-std" ]
runtime-tokio-native-tls = [ "sqlx-rt/runtime-tokio-native-tls", "_rt-tokio" ]
runtime-actix-native-tls = [ "sqlx-rt/runtime-actix-native-tls", "_tls-native-tls", "_rt-actix" ]
runtime-async-std-native-tls = [ "sqlx-rt/runtime-async-std-native-tls", "_tls-native-tls", "_rt-async-std" ]
runtime-tokio-native-tls = [ "sqlx-rt/runtime-tokio-native-tls", "_tls-native-tls", "_rt-tokio" ]

# for conditional compilation
_rt-actix = []
_rt-async-std = []
_rt-tokio = []
_tls-native-tls = []

# support offline/decoupled building (enables serialization of `Describe`)
offline = [ "serde", "either/serde" ]
Expand Down
14 changes: 8 additions & 6 deletions sqlx-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ impl Error {
pub(crate) fn config(err: impl StdError + Send + Sync + 'static) -> Self {
Error::Configuration(err.into())
}

#[allow(dead_code)]
#[inline]
pub(crate) fn tls(err: impl StdError + Send + Sync + 'static) -> Self {
Error::Tls(err.into())
}
}

pub(crate) fn mismatched_types<DB: Database, T: Type<DB>>(ty: &DB::TypeInfo) -> BoxDynError {
Expand Down Expand Up @@ -240,6 +234,14 @@ impl From<crate::migrate::MigrateError> for Error {
}
}

#[cfg(feature = "_tls-native-tls")]
impl From<sqlx_rt::native_tls::Error> for Error {
#[inline]
fn from(error: sqlx_rt::native_tls::Error) -> Self {
Error::Tls(Box::new(error))
}
}

// Format an error message as a `Protocol` error
macro_rules! err_protocol {
($expr:expr) => {
Expand Down
13 changes: 4 additions & 9 deletions sqlx-core/src/net/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ where
if !accept_invalid_certs {
if let Some(ca) = root_cert_path {
let data = fs::read(ca).await?;
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
let cert = Certificate::from_pem(&data)?;

builder.add_root_certificate(cert);
}
}

#[cfg(not(feature = "_rt-async-std"))]
let connector = builder.build().map_err(Error::tls)?;
let connector = sqlx_rt::TlsConnector::from(builder.build()?);

#[cfg(feature = "_rt-async-std")]
let connector = builder;
let connector = sqlx_rt::TlsConnector::from(builder);

let stream = match replace(self, MaybeTlsStream::Upgrading) {
MaybeTlsStream::Raw(stream) => stream,
Expand All @@ -75,12 +75,7 @@ where
}
};

*self = MaybeTlsStream::Tls(
sqlx_rt::TlsConnector::from(connector)
.connect(host, stream)
.await
.map_err(|err| Error::Tls(err.into()))?,
);
*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);

Ok(())
}
Expand Down

0 comments on commit 70a22e2

Please sign in to comment.