From 446b65550a403e67408b8a5a8eb312eda95b1cdb Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 14 Dec 2023 02:27:35 +0100 Subject: [PATCH] Introduce tokio_boring::SslStreamBuilder --- boring/src/ssl/mod.rs | 5 ++++ tokio-boring/src/lib.rs | 65 +++++++++++++++++++++++++++++++---------- 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 74dc3fb0..b767457b 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -3903,6 +3903,11 @@ impl SslStreamBuilder { &self.inner.ssl } + /// Returns a mutable reference to the `Ssl` object associated with this builder. + pub fn ssl_mut(&mut self) -> &mut SslRef { + &mut self.inner.ssl + } + /// Set the DTLS MTU size. /// /// It will be ignored if the value is smaller than the minimum packet size diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index a63c3068..c8ad4f3b 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -13,7 +13,6 @@ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] -use boring::error::ErrorStack; use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, @@ -51,7 +50,11 @@ pub async fn connect( where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| config.setup_connect(domain, s), stream).await + let mid_handshake = config + .setup_connect(domain, AsyncStreamBridge::new(stream)) + .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; + + HandshakeFuture(Some(mid_handshake)).await } /// Asynchronously performs a server-side TLS handshake over the provided stream. @@ -62,19 +65,8 @@ pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| acceptor.setup_accept(s), stream).await -} - -async fn handshake( - f: impl FnOnce( - AsyncStreamBridge, - ) -> Result>, ErrorStack>, - stream: S, -) -> Result, HandshakeError> -where - S: AsyncRead + AsyncWrite + Unpin, -{ - let mid_handshake = f(AsyncStreamBridge::new(stream)) + let mid_handshake = acceptor + .setup_accept(AsyncStreamBridge::new(stream)) .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; HandshakeFuture(Some(mid_handshake)).await @@ -88,6 +80,49 @@ fn cvt(r: io::Result) -> Poll> { } } +/// A partially constructed `SslStream`, useful for unusual handshakes. +pub struct SslStreamBuilder { + inner: ssl::SslStreamBuilder>, +} + +impl SslStreamBuilder +where + S: AsyncRead + AsyncWrite + Unpin, +{ + /// Begins creating an `SslStream` atop `stream`. + pub fn new(ssl: ssl::Ssl, stream: S) -> Self { + Self { + inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)), + } + } + + /// Initiates a client-side TLS handshake. + pub async fn accept(self) -> Result, HandshakeError> { + let mid_handshake = self.inner.setup_accept(); + + HandshakeFuture(Some(mid_handshake)).await + } + + /// Initiates a server-side TLS handshake. + pub async fn connect(self) -> Result, HandshakeError> { + let mid_handshake = self.inner.setup_connect(); + + HandshakeFuture(Some(mid_handshake)).await + } +} + +impl SslStreamBuilder { + /// Returns a shared reference to the `Ssl` object associated with this builder. + pub fn ssl(&self) -> &SslRef { + self.inner.ssl() + } + + /// Returns a mutable reference to the `Ssl` object associated with this builder. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.inner.ssl_mut() + } +} + /// A wrapper around an underlying raw stream which implements the SSL /// protocol. ///