Skip to content

Commit

Permalink
feat!: updates for rustls 0.23
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Mar 6, 2024
1 parent ab57fb8 commit cde4681
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 77 deletions.
9 changes: 6 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ edition = "2018"

[dependencies]
futures-io = "0.3"
rustls = { version = "0.22", default-features = false, features = ["tls12"] }
rustls = { version = "0.23", default-features = false, features = ["std"] }
pki-types = { package = "rustls-pki-types", version = "1" }

[features]
default = ["ring"]
default = ["aws-lc-rs", "tls12", "logging"]
aws-lc-rs = ["rustls/aws_lc_rs"]
early-data = []
fips = ["rustls/fips"]
logging = ["rustls/logging"]
ring = ["rustls/ring"]
aws-lc-rs = ["rustls/aws_lc_rs"]
tls12 = ["rustls/tls12"]

[dev-dependencies]
smol = "1"
Expand Down
4 changes: 0 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use super::*;
use crate::common::IoSession;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};

/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
Expand Down
29 changes: 26 additions & 3 deletions src/common/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::common::{Stream, TlsState};
use crate::common::{Stream, SyncWriteAdapter, TlsState};
use futures_io::{AsyncRead, AsyncWrite};
use rustls::server::AcceptedAlert;
use rustls::{ConnectionCommon, SideData};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};
use futures_io::{AsyncRead, AsyncWrite};

pub(crate) trait IoSession {
type Io;
Expand All @@ -19,7 +20,15 @@ pub(crate) trait IoSession {
pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS),
End,
Error { io: IS::Io, error: io::Error },
SendAlert {
io: IS::Io,
alert: AcceptedAlert,
error: io::Error,
},
Error {
io: IS::Io,
error: io::Error,
},
}

impl<IS, SD> Future for MidHandshake<IS>
Expand All @@ -36,6 +45,20 @@ where

let mut stream = match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => stream,
MidHandshake::SendAlert {
mut io,
mut alert,
error,
} => loop {
match alert.write(&mut SyncWriteAdapter { io: &mut io, cx }) {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
*this = MidHandshake::SendAlert { io, error, alert };
return Poll::Pending;
}
Err(_) | Ok(0) => return Poll::Ready(Err((error, io))),
Ok(_) => {}
};
},
// Starting the handshake returned an error; fail the future immediately.
MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))),
_ => panic!("unexpected polling after handshake"),
Expand Down
88 changes: 36 additions & 52 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,7 @@ where
}

pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
Poll::Ready(Ok(n)) => Ok(n),
Poll::Ready(Err(err)) => Err(err),
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

let mut reader = Reader { io: self.io, cx };
let mut reader = SyncReadAdapter { io: self.io, cx };

let n = match self.session.read_tls(&mut reader) {
Ok(n) => n,
Expand Down Expand Up @@ -133,41 +117,7 @@ where
}

pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
#[inline]
fn poll_with<U>(
&mut self,
f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
) -> io::Result<U> {
match f(Pin::new(&mut self.io), self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write(cx, buf))
}

#[inline]
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.poll_with(|io, cx| io.poll_flush(cx))
}
}

let mut writer = Writer { io: self.io, cx };
let mut writer = SyncWriteAdapter { io: self.io, cx };

match self.session.write_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Expand Down Expand Up @@ -371,5 +321,39 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
}
}

pub(crate) struct SyncWriteAdapter<'a, 'b, T> {
pub(crate) io: &'a mut T,
pub(crate) cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> {
#[inline]
fn poll_with<U>(
&mut self,
f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
) -> io::Result<U> {
match f(Pin::new(&mut self.io), self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write(cx, buf))
}

#[inline]
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.poll_with(|io, cx| io.poll_flush(cx))
}
}

#[cfg(test)]
mod test_stream;
39 changes: 29 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! Asynchronous TLS/SSL streams for futures using [Rustls](https://github.com/ctz/rustls).
//! Asynchronous TLS/SSL streams for futures using [Rustls](https://github.com/rustls/rustls).

macro_rules! ready {
( $e:expr ) => {
Expand All @@ -15,6 +15,7 @@ pub mod server;

use common::{MidHandshake, Stream, TlsState};
use futures_io::{AsyncRead, AsyncWrite};
use rustls::server::AcceptedAlert;
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use std::future::Future;
use std::io;
Expand All @@ -26,8 +27,8 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

pub use rustls;
pub use pki_types;
pub use rustls;

/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
#[derive(Clone)]
Expand Down Expand Up @@ -78,7 +79,12 @@ impl TlsConnector {
self.connect_with(domain, stream, |_| ())
}

pub fn connect_with<IO, F>(&self, domain: pki_types::ServerName<'static>, stream: IO, f: F) -> Connect<IO>
pub fn connect_with<IO, F>(
&self,
domain: pki_types::ServerName<'static>,
stream: IO,
f: F,
) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ClientConnection),
Expand Down Expand Up @@ -155,6 +161,7 @@ impl TlsAcceptor {
pub struct LazyConfigAcceptor<IO> {
acceptor: rustls::server::Acceptor,
io: Option<IO>,
alert: Option<(rustls::Error, AcceptedAlert)>,
}

impl<IO> LazyConfigAcceptor<IO>
Expand All @@ -166,6 +173,7 @@ where
Self {
acceptor,
io: Some(io),
alert: None,
}
}
}
Expand All @@ -189,6 +197,16 @@ where
}
};

if let Some((err, mut alert)) = this.alert.take() {
return match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
this.alert = Some((err, alert));
Poll::Pending
}
_ => Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))),
};
}

let mut reader = common::SyncReadAdapter { io, cx };
match this.acceptor.read_tls(&mut reader) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
Expand All @@ -202,9 +220,9 @@ where
let io = this.io.take().unwrap();
return Poll::Ready(Ok(StartHandshake { accepted, io }));
}
Ok(None) => continue,
Err(err) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
Ok(None) => {}
Err((err, alert)) => {
this.alert = Some((err, alert));
}
}
}
Expand Down Expand Up @@ -234,12 +252,13 @@ where
{
let mut conn = match self.accepted.into_connection(config) {
Ok(conn) => conn,
Err(error) => {
return Accept(MidHandshake::Error {
Err((error, alert)) => {
return Accept(MidHandshake::SendAlert {
io: self.io,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
alert,
});
}
};
Expand Down Expand Up @@ -333,11 +352,11 @@ impl<T> TlsStream<T> {
match self {
Client(io) => {
let (io, session) = io.get_ref();
(io, &*session)
(io, session)
}
Server(io) => {
let (io, session) = io.get_ref();
(io, &*session)
(io, session)
}
}
}
Expand Down
5 changes: 0 additions & 5 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};

use super::*;
use crate::common::IoSession;

Expand Down

0 comments on commit cde4681

Please sign in to comment.