Skip to content

Commit

Permalink
Merge pull request #9 from second-state/feat/wasi_tls
Browse files Browse the repository at this point in the history
Feat/wasi tls
  • Loading branch information
juntao authored Nov 11, 2023
2 parents 117e143 + f46c069 commit 0798e95
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 36 deletions.
40 changes: 21 additions & 19 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0"
name = "mysql_async_wasi"
readme = "README.md"
repository = "https://github.com/WasmEdge/mysql_async_wasi"
version = "0.31.5"
version = "0.31.6"
exclude = ["test/*"]
edition = "2018"
categories = ["asynchronous", "database"]
Expand Down Expand Up @@ -53,17 +53,33 @@ tokio_wasi = { version = "1", features = [
] }
tokio-util_wasi = { version = "0.7.2", features = ["codec", "io"] }
wasmedge_wasi_socket = "0.5"
wasmedge_rustls_api = { version = "0.1.0", optional = true, features = [
"tokio_async",
] }

# [target.'cfg(not(target_os="wasi"))'.dev-dependencies]
# tempfile = "3.1.0"
# socket2 = { version = "0.4.0", features = ["all"] }
# tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] }
# rand = "0.8.0"

[target.'cfg(target_os="wasi")'.dev-dependencies]
[dev-dependencies]
tempfile = "3.1.0"
tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] }
tokio-util = { version = "0.7.2", features = ["codec", "io"] }

[target.'cfg(target_os="wasi")'.dev-dependencies]
tempfile = "3.1.0"
tokio_wasi = { version = "1", features = [
"io-util",
"fs",
"net",
"time",
"rt",
"macros",
] }
rand = "0.8.0"


[dependencies.tokio-rustls]
version = "0.23.4"
Expand Down Expand Up @@ -94,25 +110,9 @@ optional = true
version = "0.22.1"
optional = true

[dev-dependencies]
tempfile = "3.1.0"
tokio_wasi = { version = "1", features = [
"io-util",
"fs",
"net",
"time",
"rt",
"macros",
] }
rand = "0.8.0"

[features]
default = ["common", "rust_backend"]
default-rustls = [
"common",
"rust_backend",
"rustls-tls",
]
default-rustls = ["common", "rust_backend", "wasmedge-tls"]
common = [
"mysql_common/bigdecimal03",
"mysql_common/rust_decimal",
Expand All @@ -129,6 +129,7 @@ rustls-tls = [
"webpki-roots",
"rustls-pemfile",
]
wasmedge-tls = ["wasmedge_rustls_api"]
minimal = ["flate2/zlib"]
zlib = ["flate2/zlib"]
full = ["default", "zlib"]
Expand All @@ -140,3 +141,4 @@ path = "src/lib.rs"

[profile.bench]
debug = true

22 changes: 15 additions & 7 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,14 +415,22 @@ impl Conn {

/// Returns true if io stream is encrypted.
fn is_secure(&self) -> bool {
#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
if let Some(ref stream) = self.inner.stream {
stream.is_secure()
} else {
false
}

#[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
#[cfg(not(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
)))]
false
}

Expand Down Expand Up @@ -486,7 +494,7 @@ impl Conn {
};
Ok(())
}
#[cfg(not(target_os = "wasi"))]
#[cfg(any(not(target_os = "wasi"), feature = "wasmedge-tls"))]
async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
if self
.inner
Expand All @@ -503,12 +511,12 @@ impl Conn {
}

let collation = if self.inner.version >= (5, 5, 3) {
UTF8MB4_GENERAL_CI
mysql_common::constants::UTF8MB4_GENERAL_CI
} else {
UTF8_GENERAL_CI
crate::consts::UTF8MB4_GENERAL_CI
};

let ssl_request = SslRequest::new(
let ssl_request = mysql_common::packets::SslRequest::new(
self.inner.capabilities,
DEFAULT_MAX_ALLOWED_PACKET as u32,
collation as u8,
Expand Down Expand Up @@ -846,7 +854,7 @@ impl Conn {
conn.inner.stream = Some(stream);
conn.setup_stream()?;
conn.handle_handshake().await?;
#[cfg(not(target_os = "wasi"))]
#[cfg(any(not(target_os = "wasi"), feature = "wasmedge-tls"))]
conn.switch_to_ssl_if_needed().await?;
conn.do_handshake_response().await?;
conn.continue_auth().await?;
Expand Down
6 changes: 5 additions & 1 deletion src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ impl Error {
pub enum IoError {
#[error("Input/output error: {}", _0)]
Io(#[source] io::Error),
#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
#[error("TLS error: `{}'", _0)]
Tls(#[source] tls::TlsError),
}
Expand Down
9 changes: 8 additions & 1 deletion src/error/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#![cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#![cfg(any(
feature = "native-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]

pub mod native_tls_error;
pub mod rustls_error;
Expand All @@ -8,3 +12,6 @@ pub use native_tls_error::TlsError;

#[cfg(feature = "rustls")]
pub use rustls_error::TlsError;

#[cfg(feature = "wasmedge-tls")]
pub use wasmedge_rustls_api::TlsError;
56 changes: 50 additions & 6 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ pub(crate) enum Endpoint {
Secure(#[pin] tokio_native_tls::TlsStream<TcpStream>),
#[cfg(feature = "rustls-tls")]
Secure(#[pin] tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
#[cfg(feature = "wasmedge-tls")]
Secure(#[pin] wasmedge_rustls_api::stream::async_stream::TlsStream<tokio::net::TcpStream>),
#[cfg(unix)]
Socket(#[pin] Socket),
}
Expand Down Expand Up @@ -150,7 +152,14 @@ impl Future for CheckTcpStream<'_> {
}

impl Endpoint {
#[cfg(all(any(feature = "native-tls-tls", feature = "rustls-tls"), unix))]
#[cfg(all(
any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
),
unix
))]
fn is_socket(&self) -> bool {
match self {
Self::Socket(_) => true,
Expand All @@ -177,6 +186,12 @@ impl Endpoint {
CheckTcpStream(stream).await?;
Ok(())
}
#[cfg(feature = "wasmedge-tls")]
Endpoint::Secure(tls_stream) => {
let stream = tls_stream.get_mut().0;
CheckTcpStream(stream).await?;
Ok(())
}
#[cfg(unix)]
Endpoint::Socket(socket) => {
socket.write(&[]).await?;
Expand All @@ -186,12 +201,20 @@ impl Endpoint {
}
}

#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
pub fn is_secure(&self) -> bool {
matches!(self, Endpoint::Secure(_))
}

#[cfg(all(not(feature = "native-tls"), not(feature = "rustls")))]
#[cfg(all(
not(feature = "native-tls"),
not(feature = "rustls"),
not(feature = "wasmedge-tls")
))]
pub async fn _make_secure(
&mut self,
_domain: String,
Expand All @@ -216,6 +239,11 @@ impl Endpoint {
let stream = stream.get_ref().0;
stream.set_nodelay(val)?;
}
#[cfg(feature = "wasmedge-tls")]
Endpoint::Secure(ref stream) => {
let stream = stream.get_ref().0;
stream.set_nodelay(val)?;
}
#[cfg(unix)]
Endpoint::Socket(_) => (/* inapplicable */),
}
Expand Down Expand Up @@ -262,6 +290,8 @@ impl AsyncRead for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf),
})
Expand All @@ -283,6 +313,8 @@ impl AsyncWrite for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf),
})
Expand All @@ -301,6 +333,8 @@ impl AsyncWrite for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx),
})
Expand All @@ -319,6 +353,8 @@ impl AsyncWrite for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx),
})
Expand Down Expand Up @@ -409,12 +445,14 @@ impl Stream {
pub(crate) fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val)
}
#[cfg(not(target_os = "wasi"))]
#[cfg(any(not(target_os = "wasi"), feature = "wasmedge-tls"))]
pub(crate) async fn make_secure(
&mut self,
domain: String,
ssl_opts: SslOpts,
ssl_opts: crate::SslOpts,
) -> crate::error::Result<()> {
use tokio_util::codec::FramedParts;

let codec = self.codec.take().unwrap();
let FramedParts { mut io, codec, .. } = codec.into_parts();
io.make_secure(domain, ssl_opts).await?;
Expand All @@ -423,7 +461,11 @@ impl Stream {
Ok(())
}

#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
pub(crate) fn is_secure(&self) -> bool {
self.codec.as_ref().unwrap().get_ref().is_secure()
}
Expand Down Expand Up @@ -506,6 +548,8 @@ mod test {
super::Endpoint::Plain(Some(stream)) => stream,
#[cfg(feature = "rustls-tls")]
super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
#[cfg(feature = "wasmedge-tls")]
super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
#[cfg(feature = "native-tls")]
super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(),
_ => unreachable!(),
Expand Down
3 changes: 2 additions & 1 deletion src/io/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg(any(feature = "native-tls", feature = "rustls"))]
#![cfg(any(feature = "native-tls", feature = "rustls", feature = "wasmedge-tls"))]

mod native_tls_io;
mod rustls_io;
mod wasmedge_rustls_io;
34 changes: 34 additions & 0 deletions src/io/tls/wasmedge_rustls_io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#![cfg(feature = "wasmedge-tls")]

use wasmedge_rustls_api::{stream::async_stream::TlsStream, ClientConfig};

use crate::{io::Endpoint, Result};

impl Endpoint {
pub async fn make_secure(&mut self, domain: String, _ssl_opts: crate::SslOpts) -> Result<()> {
#[cfg(unix)]
if self.is_socket() {
// won't secure socket connection
return Ok(());
}

let config = ClientConfig::default();

*self = match self {
Endpoint::Plain(ref mut stream) => {
let stream = stream.take().unwrap();

let connection = TlsStream::connect(&config, domain, stream)
.await
.map_err(|e| e.0)?;

Endpoint::Secure(connection)
}
Endpoint::Secure(_) => unreachable!(),
#[cfg(unix)]
Endpoint::Socket(_) => unreachable!(),
};

Ok(())
}
}
2 changes: 1 addition & 1 deletion src/opts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl HostPortOrUrl {
/// ```
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
pub struct SslOpts {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "rustls-tls",))]
client_identity: Option<ClientIdentity>,
root_cert_path: Option<Cow<'static, Path>>,
skip_domain_validation: bool,
Expand Down

0 comments on commit 0798e95

Please sign in to comment.