Skip to content

Commit

Permalink
Support SSL in windows
Browse files Browse the repository at this point in the history
  • Loading branch information
elemount authored and Shuode Li committed Dec 27, 2018
1 parent e88ee0a commit 693ccda
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 23 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ lto = true
default = ['mysql_common']
nightly = ['mysql_common']
rustc_serialize = ['mysql_common/rustc_serialize', 'rustc-serialize']
ssl = ['mysql_common', "openssl", "security-framework"]
ssl = ['mysql_common', "openssl", "security-framework", "schannel"]

[dev-dependencies]
serde_derive = "1"
Expand Down Expand Up @@ -79,6 +79,10 @@ version = "~0.2"
optional = true
features = ["OSX_10_9"]

[target.'cfg(target_os = "windows")'.dependencies.schannel]
version = "~0.1"
optional = true

[target.'cfg(target_os = "windows")'.dependencies]
named_pipe = "~0.3"
winapi = "~0.3"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ features = ["rustc-serialize"]
```

### Windows support (since 0.18.0)
Windows is supported but currently rust-mysql-simple has no support for SSL on Windows.
Windows is supported.
2 changes: 2 additions & 0 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ before_test:
$newText = ([System.IO.File]::ReadAllText($iniPath)).Replace("# enable-named-pipe", "enable-named-pipe")
$newText = $newText + "`nssl-ca=c:/clone/tests/ca-cert.pem`nssl-cert=c:/clone/tests/server-cert.pem`nssl-key=c:/clone/tests/server-key.pem"
[System.IO.File]::WriteAllText($iniPath, $newText)
Restart-Service MySQL57
Expand Down
22 changes: 13 additions & 9 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,12 @@ impl Conn {
}
}

#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
#[cfg(not(feature = "ssl"))]
fn switch_to_ssl(&mut self) -> MyResult<()> {
unimplemented!();
}

#[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))]
fn switch_to_ssl(&mut self) -> MyResult<()> {
match self.stream.take() {
Some(ConnStream::Plain(stream)) => {
Expand All @@ -798,11 +803,6 @@ impl Conn {
Ok(())
}

#[cfg(any(not(feature = "ssl"), target_os = "windows"))]
fn switch_to_ssl(&mut self) -> MyResult<()> {
unimplemented!();
}

fn connect_stream(&mut self) -> MyResult<()> {
let read_timeout = self.opts.get_read_timeout().cloned();
let write_timeout = self.opts.get_write_timeout().cloned();
Expand Down Expand Up @@ -2075,7 +2075,11 @@ mod test {
builder.into()
}

#[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
#[cfg(all(
feature = "ssl",
not(target_os = "macos"),
any(unix, target_os = "windows")
))]
pub fn get_opts() -> Opts {
let pwd: String = env::var("MYSQL_SERVER_PASS").unwrap_or(PASS.to_string());
let port: u16 = env::var("MYSQL_SERVER_PORT")
Expand All @@ -2099,7 +2103,7 @@ mod test {
builder.into()
}

#[cfg(any(not(feature = "ssl"), target_os = "windows"))]
#[cfg(not(feature = "ssl"))]
pub fn get_opts() -> Opts {
let pwd: String = env::var("MYSQL_SERVER_PASS").unwrap_or(PASS.to_string());
let port: u16 = env::var("MYSQL_SERVER_PORT")
Expand Down Expand Up @@ -2556,7 +2560,7 @@ mod test {
}

#[test]
#[cfg(all(feature = "ssl", any(target_os = "macos", unix)))]
#[cfg(all(feature = "ssl", any(target_os = "macos", target_os = "windows", unix)))]
fn should_connect_via_ssl() {
let mut opts = OptsBuilder::from_opts(get_opts());
opts.prefer_socket(false);
Expand Down
17 changes: 8 additions & 9 deletions src/conn/opts.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::consts::CapabilityFlags;

use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
#[cfg(all(feature = "ssl", not(target_os = "windows")))]
#[cfg(all(feature = "ssl"))]
use std::path;
use std::str::FromStr;

Expand Down Expand Up @@ -29,8 +29,8 @@ pub type SslOpts = Option<Option<(path::PathBuf, String, Vec<path::PathBuf>)>>;
pub type SslOpts = Option<(path::PathBuf, Option<(path::PathBuf, path::PathBuf)>)>;

#[cfg(all(feature = "ssl", target_os = "windows"))]
/// Not implemented on Windows
pub type SslOpts = Option<()>;
/// Ssl options: Option<(pem_ca_cert, Option<(pem_client_cert, pem_client_key)>)>.`
pub type SslOpts = Option<(path::PathBuf, Option<(path::PathBuf, path::PathBuf)>)>;

#[cfg(not(feature = "ssl"))]
/// Requires `ssl` feature
Expand Down Expand Up @@ -445,7 +445,11 @@ impl OptsBuilder {
self
}

#[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
#[cfg(all(
feature = "ssl",
not(target_os = "macos"),
any(unix, target_os = "windows")
))]
/// SSL certificates and keys in pem format.
///
/// If not None, then ssl connection implied.
Expand Down Expand Up @@ -487,11 +491,6 @@ impl OptsBuilder {
self
}

/// Not implemented on windows
#[cfg(all(feature = "ssl", target_os = "windows"))]
pub fn ssl_opts<A, B, C>(&mut self, _: Option<SslOpts>) -> &mut Self {
panic!("OptsBuilder::ssl_opts is not implemented on Windows");
}

/// Requires `ssl` feature
#[cfg(not(feature = "ssl"))]
Expand Down
113 changes: 110 additions & 3 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::net::SocketAddr;
use std::slice::Chunks;
use std::time::Duration;

#[cfg(all(feature = "ssl", not(target_os = "windows")))]
#[cfg(all(feature = "ssl"))]
use crate::conn::SslOpts;

use super::consts;
Expand All @@ -32,6 +32,14 @@ use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression};
use named_pipe as np;
#[cfg(all(feature = "ssl", all(unix, not(target_os = "macos"))))]
use openssl::ssl::{self, SslContext, SslStream};
#[cfg(all(feature = "ssl", target_os = "windows"))]
use schannel::cert_context::CertContext;
#[cfg(all(feature = "ssl", target_os = "windows"))]
use schannel::cert_store;
#[cfg(all(feature = "ssl", target_os = "windows"))]
use schannel::schannel_cred;
#[cfg(all(feature = "ssl", target_os = "windows"))]
use schannel::tls_stream;
#[cfg(all(feature = "ssl", target_os = "macos"))]
use security_framework::certificate::SecCertificate;
#[cfg(all(feature = "ssl", target_os = "macos"))]
Expand Down Expand Up @@ -763,6 +771,103 @@ impl Stream {
}
}

#[cfg(all(feature = "ssl", target_os = "windows"))]
impl Stream {
pub fn make_secure(
mut self,
verify_peer: bool,
ip_or_hostname: Option<&str>,
ssl_opts: &SslOpts,
) -> MyResult<Stream> {
use std::path::Path;

fn load_cert_data(path: &Path) -> MyResult<String> {
let mut client_file = ::std::fs::File::open(path)?;
let mut client_data = String::new();
client_file.read_to_string(&mut client_data)?;
Ok(client_data)
}

fn load_client_cert(path: &Path) -> MyResult<CertContext> {
let cert_data = load_cert_data(path)?;
let cert = CertContext::from_pem(&cert_data)?;
Ok(cert)
}

fn load_client_cert_with_key(cert_path: &Path, key_path: &Path) -> MyResult<CertContext> {
let mut cert_data = load_cert_data(cert_path)?;
let cert = CertContext::from_pem(&cert_data)?;
let key_data = load_cert_data(key_path)?;
cert_data.push_str(&key_data);
Ok(cert)
}

fn load_ca_store(path: &Path) -> MyResult<cert_store::CertStore> {
let ca_cert = load_client_cert(path)?;
let mut cert_store = cert_store::Memory::new().unwrap().into_store();
cert_store.add_cert(&ca_cert, cert_store::CertAdd::Always)?;
Ok(cert_store)
}

if self.is_insecure() {
let mut stream_builder = tls_stream::Builder::new();
let mut cred_builder = schannel_cred::Builder::default();
cred_builder.enabled_protocols(&[
schannel_cred::Protocol::Tls10,
schannel_cred::Protocol::Tls11,
]);
cred_builder.supported_algorithms(&[
schannel_cred::Algorithm::DhEphem,
schannel_cred::Algorithm::RsaSign,
schannel_cred::Algorithm::Aes256,
schannel_cred::Algorithm::Sha1,
]);
if verify_peer {
stream_builder.domain(ip_or_hostname.as_ref().unwrap_or(&("localhost".into())));
}

match *ssl_opts {
Some((ref ca_cert, None)) => {
stream_builder.cert_store(load_ca_store(&ca_cert)?);
}
Some((ref ca_cert, Some((ref client_cert, ref client_key)))) => {
cred_builder.cert(load_client_cert_with_key(&client_cert, &client_key)?);
stream_builder.cert_store(load_ca_store(&ca_cert)?);
}
_ => unreachable!(),
}

let cred = cred_builder.acquire(schannel_cred::Direction::Outbound)?;
match self {
Stream::TcpStream(ref mut opt_stream) if opt_stream.is_some() => {
let stream = opt_stream.take().unwrap();
match stream {
TcpStream::Insecure(mut stream) => {
stream.flush()?;
let s_stream = match stream_builder
.connect(cred, stream.into_inner().unwrap())
{
Ok(s_stream) => s_stream,
Err(tls_stream::HandshakeError::Failure(err)) => {
return Err(err.into());
}
Err(tls_stream::HandshakeError::Interrupted(_)) => unreachable!(),
};
Ok(Stream::TcpStream(Some(TcpStream::Secure(BufStream::new(
s_stream,
)))))
}
_ => unreachable!(),
}
}
_ => unreachable!(),
}
} else {
Ok(self)
}
}
}

#[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
impl Stream {
pub fn make_secure(
Expand Down Expand Up @@ -838,13 +943,15 @@ impl Drop for Stream {
pub enum TcpStream {
#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
Secure(BufStream<SslStream<net::TcpStream>>),
#[cfg(all(feature = "ssl", target_os = "windows"))]
Secure(BufStream<tls_stream::TlsStream<net::TcpStream>>),
Insecure(BufStream<net::TcpStream>),
}

impl AsMut<dyn IoPack> for TcpStream {
fn as_mut(&mut self) -> &mut dyn IoPack {
match *self {
#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
#[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))]
TcpStream::Secure(ref mut stream) => stream,
TcpStream::Insecure(ref mut stream) => stream,
}
Expand All @@ -854,7 +961,7 @@ impl AsMut<dyn IoPack> for TcpStream {
impl fmt::Debug for TcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
#[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))]
TcpStream::Secure(_) => write!(f, "Secure stream"),
TcpStream::Insecure(ref s) => write!(f, "Insecure stream {:?}", s),
}
Expand Down

0 comments on commit 693ccda

Please sign in to comment.