Skip to content

Commit

Permalink
Forward vectored writes
Browse files Browse the repository at this point in the history
  • Loading branch information
paolobarbolini committed Mar 3, 2024
1 parent 6f18143 commit da19bfc
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 9 deletions.
76 changes: 76 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,82 @@ where
}
}

/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[allow(clippy::match_single_binding)]
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(ref mut pos, ref mut data) => {
use std::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let mut written = 0;

for buf in bufs {
if buf.is_empty() {
continue;
}

let len = match early_data.write(buf) {
Ok(0) => break,
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};

written += len;
data.extend_from_slice(&buf[..len]);

if len < buf.len() {
break;
}
}

if written != 0 {
return Poll::Ready(Ok(written));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
this.state = TlsState::Stream;

if let Some(waker) = this.early_waker.take() {
waker.wake();
}

stream.as_mut_pin().poll_write_vectored(cx, bufs)
}
_ => stream.as_mut_pin().poll_write_vectored(cx, bufs),
}
}

#[inline]
fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Expand Down
37 changes: 37 additions & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,43 @@ where
Poll::Ready(Ok(pos))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if !bufs.iter().any(|buf| !buf.is_empty()) {
return Poll::Ready(Ok(0));
}

loop {
let mut would_block = false;
let written = self.session.writer().write_vectored(bufs)?;

while self.session.wants_write() {
match self.write_io(cx) {
Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true;
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

return match (written, would_block) {
(0, true) => Poll::Pending,
(0, false) => continue,
(n, _) => Poll::Ready(Ok(n)),
};
}
}

#[inline]
fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.writer().flush()?;
while self.session.wants_write() {
Expand Down
11 changes: 10 additions & 1 deletion src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ impl AsyncWrite for Expected {

#[tokio::test]
async fn stream_good() -> io::Result<()> {
stream_good_impl(false).await
}

#[tokio::test]
async fn stream_good_vectored() -> io::Result<()> {
stream_good_impl(true).await
}

async fn stream_good_impl(vectored: bool) -> io::Result<()> {
const FILE: &[u8] = include_bytes!("../../README.md");

let (server, mut client) = make_pair();
Expand All @@ -139,7 +148,7 @@ async fn stream_good() -> io::Result<()> {
dbg!(stream.read_to_end(&mut buf).await)?;
assert_eq!(buf, FILE);

dbg!(stream.write_all(b"Hello World!").await)?;
dbg!(utils::write(&mut stream, b"Hello World!", vectored).await)?;
stream.session.send_close_notify();

dbg!(stream.shutdown().await)?;
Expand Down
20 changes: 20 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,26 @@ where
}
}

#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
}
}

#[inline]
fn is_write_vectored(&self) -> bool {
match self {
TlsStream::Client(x) => x.is_write_vectored(),
TlsStream::Server(x) => x.is_write_vectored(),
}
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Expand Down
18 changes: 18 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ where
stream.as_mut_pin().poll_write(cx, buf)
}

/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_write_vectored(cx, bufs)
}

#[inline]
fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Expand Down
28 changes: 25 additions & 3 deletions tests/badssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ async fn get(
config: Arc<ClientConfig>,
domain: &str,
port: u16,
vectored: bool,
) -> io::Result<(TlsStream<TcpStream>, String)> {
let connector = TlsConnector::from(config);
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
Expand All @@ -24,7 +25,7 @@ async fn get(

let stream = TcpStream::connect(&addr).await?;
let mut stream = connector.connect(domain, stream).await?;
stream.write_all(input.as_bytes()).await?;
utils::write(&mut stream, input.as_bytes(), vectored).await?;
stream.flush().await?;
stream.read_to_end(&mut buf).await?;

Expand All @@ -33,6 +34,15 @@ async fn get(

#[tokio::test]
async fn test_tls12() -> io::Result<()> {
test_tls12_impl(false).await
}

#[tokio::test]
async fn test_tls12_vectored() -> io::Result<()> {
test_tls12_impl(true).await
}

async fn test_tls12_impl(vectored: bool) -> io::Result<()> {
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS12])
Expand All @@ -42,7 +52,7 @@ async fn test_tls12() -> io::Result<()> {
let config = Arc::new(config);
let domain = "tls-v1-2.badssl.com";

let (_, output) = get(config.clone(), domain, 1012).await?;
let (_, output) = get(config.clone(), domain, 1012, vectored).await?;
assert!(
output.contains("<title>tls-v1-2.badssl.com</title>"),
"failed badssl test, output: {}",
Expand All @@ -61,6 +71,15 @@ fn test_tls13() {

#[tokio::test]
async fn test_modern() -> io::Result<()> {
test_modern_impl(false).await
}

#[tokio::test]
async fn test_modern_vectored() -> io::Result<()> {
test_modern_impl(true).await
}

async fn test_modern_impl(vectored: bool) -> io::Result<()> {
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = rustls::ClientConfig::builder()
Expand All @@ -69,7 +88,7 @@ async fn test_modern() -> io::Result<()> {
let config = Arc::new(config);
let domain = "mozilla-modern.badssl.com";

let (_, output) = get(config.clone(), domain, 443).await?;
let (_, output) = get(config.clone(), domain, 443, vectored).await?;
assert!(
output.contains("<title>mozilla-modern.badssl.com</title>"),
"failed badssl test, output: {}",
Expand All @@ -78,3 +97,6 @@ async fn test_modern() -> io::Result<()> {

Ok(())
}

// Include `utils` module
include!("utils.rs");
20 changes: 16 additions & 4 deletions tests/early-data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ async fn send(
config: Arc<ClientConfig>,
addr: SocketAddr,
data: &[u8],
vectored: bool,
) -> io::Result<TlsStream<TcpStream>> {
let connector = TlsConnector::from(config).early_data(true);
let stream = TcpStream::connect(&addr).await?;
Expand Down Expand Up @@ -81,7 +82,7 @@ async fn send(

wait.await.unwrap();

wd.write_all(data).await?;
utils::write(&mut wd, data, vectored).await?;
wd.flush().await?;
wd.shutdown().await?;

Expand Down Expand Up @@ -111,7 +112,15 @@ async fn wait_for_server(addr: &str) {

#[tokio::test]
async fn test_0rtt() -> io::Result<()> {
let server_port = 12354;
test_0rtt_impl(12354, false).await
}

#[tokio::test]
async fn test_0rtt_vectored() -> io::Result<()> {
test_0rtt_impl(12353, true).await
}

async fn test_0rtt_impl(server_port: u16, vectored: bool) -> io::Result<()> {
let mut handle = Command::new("openssl")
.arg("s_server")
.arg("-early_data")
Expand Down Expand Up @@ -152,10 +161,10 @@ async fn test_0rtt() -> io::Result<()> {
}
});

let io = send(config.clone(), addr, b"hello").await?;
let io = send(config.clone(), addr, b"hello", vectored).await?;
assert!(!io.get_ref().1.is_early_data_accepted());

let io = send(config, addr, b"world!").await?;
let io = send(config, addr, b"world!", vectored).await?;
assert!(io.get_ref().1.is_early_data_accepted());

let stdout = handle.0.stdout.as_mut().unwrap();
Expand All @@ -168,3 +177,6 @@ async fn test_0rtt() -> io::Result<()> {

Ok(())
}

// Include `utils` module
include!("utils.rs");
27 changes: 26 additions & 1 deletion tests/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
mod utils {
use std::io::{BufReader, Cursor};
use std::io::{BufReader, Cursor, IoSlice};
use std::sync::Arc;

use rustls::{ClientConfig, RootCertStore, ServerConfig};
use rustls_pemfile::{certs, rsa_private_keys};
use tokio::io::{self, AsyncWrite, AsyncWriteExt};

#[allow(dead_code)]
pub fn make_configs() -> (Arc<ServerConfig>, Arc<ClientConfig>) {
Expand Down Expand Up @@ -35,4 +36,28 @@ mod utils {

(Arc::new(sconfig), Arc::new(cconfig))
}

#[allow(dead_code)]
pub async fn write<W: AsyncWrite + Unpin>(
w: &mut W,
data: &[u8],
vectored: bool,
) -> io::Result<()> {
if vectored {
let mut data = data;

while !data.is_empty() {
let chunk_size = (data.len() / 4).max(1);
let vectors = data
.chunks(chunk_size)
.map(IoSlice::new)
.collect::<Vec<_>>();
let written = w.write_vectored(&vectors).await?;
data = &data[written..];
}
} else {
w.write_all(data).await?;
}
Ok(())
}
}

0 comments on commit da19bfc

Please sign in to comment.