diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index de2e07b7..a0903cb7 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -500,7 +500,7 @@ impl Active { } } - fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_new_outbound(&mut self, cx: &Context<'_>) -> Poll> { if self.streams.len() >= self.config.max_num_streams { log::error!("{}: maximum number of streams reached", self.id); return Poll::Ready(Err(ConnectionError::TooManyStreams)); @@ -722,7 +722,7 @@ impl Active { shared.update_state(self.id, stream_id, State::RecvClosed); } shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); + shared.buffer.push(frame.into_body().into()); if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { if let Some(credit) = shared.next_window_update() { @@ -765,7 +765,7 @@ impl Active { return Action::Reset(Frame::new(header)); } shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); + shared.buffer.push(frame.into_body().into()); if let Some(w) = shared.reader.take() { w.wake() } diff --git a/yamux/src/connection/cleanup.rs b/yamux/src/connection/cleanup.rs index b4dfa816..87c6d0ca 100644 --- a/yamux/src/connection/cleanup.rs +++ b/yamux/src/connection/cleanup.rs @@ -33,7 +33,7 @@ impl Future for Cleanup { type Output = ConnectionError; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.get_mut(); + let this = self.get_mut(); loop { match this.state { diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs index d503941f..b02c465e 100644 --- a/yamux/src/connection/closing.rs +++ b/yamux/src/connection/closing.rs @@ -45,7 +45,7 @@ where type Output = Result<()>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.get_mut(); + let this = self.get_mut(); loop { match this.state { diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 692840a4..1f8bace5 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -12,11 +12,12 @@ pub mod header; mod io; use futures::future::Either; -use header::{Data, GoAway, Header, Ping, StreamId, WindowUpdate}; +use header::{Data, GoAway, Header, StreamId, WindowUpdate}; use std::{convert::TryInto, num::TryFromIntError}; pub use io::FrameDecodeError; pub(crate) use io::Io; +use crate::frame::header::{HEADER_SIZE, Ping}; /// A Yamux message frame consisting of header and body. #[derive(Clone, Debug, PartialEq, Eq)] @@ -32,7 +33,6 @@ impl Frame { body: Vec::new(), } } - pub fn header(&self) -> &Header { &self.header } @@ -56,6 +56,14 @@ impl Frame { body: self.body, } } + + pub(crate) fn body(&self) -> &[u8] { + self.body.as_slice() + } + + pub(crate) fn len(&self) -> usize { + self.body.len() + HEADER_SIZE + } } impl From> for Frame<()> { @@ -68,6 +76,23 @@ impl From> for Frame<()> { } impl Frame<()> { + pub(crate) fn try_from_header_buffer( + buffer: &[u8; HEADER_SIZE], + max_body_len: usize, + ) -> Result, Frame>, FrameDecodeError> { + let header = header::decode(buffer)?; + + let either = match header.try_into_data() { + Ok(data) if data.body_len() > max_body_len => { + return Err(FrameDecodeError::FrameTooLarge(data.body_len())); + } + Ok(data) => Either::Right(Frame::new(data)), + Err(other) => Either::Left(Frame::new(other)), + }; + + Ok(either) + } + pub(crate) fn into_data(self) -> Frame { Frame { header: self.header.into_data(), @@ -108,14 +133,17 @@ impl Frame { Frame::new(header) } - pub fn body(&self) -> &[u8] { - &self.body + pub fn body_mut(&mut self) -> &mut [u8] { + self.body.as_mut_slice() } pub fn body_len(&self) -> u32 { - // Safe cast since we construct `Frame::`s only with - // `Vec` of length [0, u32::MAX] in `Frame::data` above. - self.body().len() as u32 + let len_in_header = self.header.body_len(); + let actual_len = self.body.len(); + + // debug_assert_eq!(len_in_header, actual_len); + + len_in_header as u32 } pub fn into_body(self) -> Vec { diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index cbbf704d..4b4fa0c5 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -48,13 +48,9 @@ impl Header { self.stream_id } - pub fn len(&self) -> Len { - self.length - } - #[cfg(test)] - pub fn set_len(&mut self, len: u32) { - self.length = Len(len) + pub fn set_len(&mut self, len: usize) { + self.length = Len(len as u32) } /// Arbitrary type cast, use with caution. @@ -92,6 +88,14 @@ impl Header<()> { self.cast() } + pub(crate) fn try_into_data(self) -> Result, Self> { + if self.tag == Tag::Data { + return Ok(self.into_data()); + } + + Err(self) + } + pub(crate) fn into_window_update(self) -> Header { debug_assert_eq!(self.tag, Tag::WindowUpdate); self.cast() @@ -101,6 +105,7 @@ impl Header<()> { debug_assert_eq!(self.tag, Tag::Ping); self.cast() } + } impl Header { @@ -143,6 +148,14 @@ impl Header { _marker: std::marker::PhantomData, } } + + /// Returns the length of the body. + /// + /// The `length` field in the header has a different semantic meaning depending on the tag. + /// For [`Tag::Data`], it describes the length of the body. + pub fn body_len(&self) -> usize { + self.length.val() as usize + } } impl Header { diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 795d9f5c..b180931c 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -13,12 +13,15 @@ use super::{ Frame, }; use crate::connection::Id; +use crate::frame::header::{Data, HEADER_SIZE}; +use futures::future::Either; use futures::{prelude::*, ready}; use std::{ - fmt, io, + fmt, io, mem, pin::Pin, task::{Context, Poll}, }; +use std::io::{IoSlice}; /// A [`Stream`] and writer of [`Frame`] values. #[derive(Debug)] @@ -35,7 +38,7 @@ impl Io { Io { id, io, - read_state: ReadState::Init, + read_state: ReadState::header(), write_state: WriteState::Init, max_body_len: max_frame_body_len, } @@ -45,30 +48,19 @@ impl Io { /// The stages of writing a new `Frame`. enum WriteState { Init, - Header { - header: [u8; header::HEADER_SIZE], - buffer: Vec, - offset: usize, - }, - Body { - buffer: Vec, - offset: usize, - }, + Writing { frame: Frame<()>, offset: usize }, } impl fmt::Debug for WriteState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { WriteState::Init => f.write_str("(WriteState::Init)"), - WriteState::Header { offset, .. } => { - write!(f, "(WriteState::Header (offset {}))", offset) - } - WriteState::Body { offset, buffer } => { + WriteState::Writing { offset, frame } => { write!( f, - "(WriteState::Body (offset {}) (buffer-len {}))", + "(WriteState::Writing (offset {}) (buffer-len {}))", offset, - buffer.len() + frame.len() ) } } @@ -84,56 +76,31 @@ impl Sink> for Io { log::trace!("{}: write: {:?}", this.id, this.write_state); match &mut this.write_state { WriteState::Init => return Poll::Ready(Ok(())), - WriteState::Header { - header, - buffer, - ref mut offset, - } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *offset += n; - if *offset == header.len() { - if !buffer.is_empty() { - let buffer = std::mem::take(buffer); - this.write_state = WriteState::Body { buffer, offset: 0 }; - } else { - this.write_state = WriteState::Init; - } - } + WriteState::Writing { frame, offset } => { + let io = Pin::new(&mut this.io); + + let n = if *offset < HEADER_SIZE { + ready!(io.poll_write_vectored(cx, &[IoSlice::new(&header::encode(frame.header())[*offset..]), IoSlice::new(frame.body())]))? + } else { + let body_offset = *offset - HEADER_SIZE; + ready!(io.poll_write_vectored(cx, &[IoSlice::new(&frame.body()[body_offset..])]))? + }; + + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - }, - WriteState::Body { - buffer, - ref mut offset, - } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *offset += n; - if *offset == buffer.len() { - this.write_state = WriteState::Init; - } + + *offset += n; + if *offset == frame.len() { + this.write_state = WriteState::Init; } - }, + } } } } - fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> { - let header = header::encode(&f.header); - let buffer = f.body; - self.get_mut().write_state = WriteState::Header { - header, - buffer, - offset: 0, - }; + fn start_send(self: Pin<&mut Self>, frame: Frame<()>) -> Result<(), Self::Error> { + self.get_mut().write_state = WriteState::Writing { frame, offset: 0 }; Ok(()) } @@ -152,66 +119,50 @@ impl Sink> for Io { /// The stages of reading a new `Frame`. enum ReadState { - /// Initial reading state. - Init, /// Reading the frame header. Header { offset: usize, buffer: [u8; header::HEADER_SIZE], }, /// Reading the frame body. - Body { - header: header::Header<()>, - offset: usize, - buffer: Vec, - }, + Body { frame: Frame, offset: usize }, +} + +impl ReadState { + fn header() -> Self { + ReadState::Header { + offset: 0, + buffer: [0u8; header::HEADER_SIZE], + } + } } impl Stream for Io { type Item = Result, FrameDecodeError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = &mut *self; + let this = &mut *self; loop { log::trace!("{}: read: {:?}", this.id, this.read_state); match this.read_state { - ReadState::Init => { - this.read_state = ReadState::Header { - offset: 0, - buffer: [0; header::HEADER_SIZE], - }; - } ReadState::Header { ref mut offset, ref mut buffer, } => { if *offset == header::HEADER_SIZE { - let header = match header::decode(buffer) { - Ok(hd) => hd, - Err(e) => return Poll::Ready(Some(Err(e.into()))), - }; + let frame = Frame::try_from_header_buffer(buffer, this.max_body_len)?; - log::trace!("{}: read: {}", this.id, header); + log::trace!("{}: read: {:?}", this.id, frame); - if header.tag() != header::Tag::Data { - this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame::new(header)))); - } - - let body_len = header.len().val() as usize; - - if body_len > this.max_body_len { - return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( - body_len, - )))); - } - - this.read_state = ReadState::Body { - header, - offset: 0, - buffer: vec![0; body_len], + let frame = match frame { + Either::Right(data_frame) => data_frame, + Either::Left(other_frame) => { + this.read_state = ReadState::header(); + return Poll::Ready(Some(Ok(other_frame))); + } }; + this.read_state = ReadState::Body { frame, offset: 0 }; continue; } @@ -228,20 +179,20 @@ impl Stream for Io { } } ReadState::Body { - ref header, ref mut offset, - ref mut buffer, + ref mut frame, } => { - let body_len = header.len().val() as usize; + let body_len = frame.body_len() as usize; if *offset == body_len { - let h = header.clone(); - let v = std::mem::take(buffer); - this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame { header: h, body: v }))); + let frame = match mem::replace(&mut self.read_state, ReadState::header()) { + ReadState::Header { .. } => unreachable!("we matched above"), + ReadState::Body { frame, .. } => frame, + }; + return Poll::Ready(Some(Ok(frame.into()))); } - let buf = &mut buffer[*offset..body_len]; + let buf = &mut frame.body_mut()[*offset..]; match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { 0 => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); @@ -258,21 +209,16 @@ impl Stream for Io { impl fmt::Debug for ReadState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - ReadState::Init => f.write_str("(ReadState::Init)"), ReadState::Header { offset, .. } => { write!(f, "(ReadState::Header (offset {}))", offset) } - ReadState::Body { - header, - offset, - buffer, - } => { + ReadState::Body { frame, offset } => { write!( f, "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", - header, + frame.header(), offset, - buffer.len() + frame.body_len() ) } } @@ -331,16 +277,16 @@ mod tests { impl Arbitrary for Frame<()> { fn arbitrary(g: &mut Gen) -> Self { - let mut header: header::Header<()> = Arbitrary::arbitrary(g); - let body = if header.tag() == header::Tag::Data { - header.set_len(header.len().val() % 4096); - let mut b = vec![0; header.len().val() as usize]; - rand::thread_rng().fill_bytes(&mut b); - b + let header: header::Header<()> = Arbitrary::arbitrary(g); + if header.tag() == header::Tag::Data { + let mut header = header.into_data(); + header.set_len(header.body_len() % 4096); + let mut frame = Frame::new(header); + rand::thread_rng().fill_bytes(frame.body_mut()); + frame.into() } else { - Vec::new() - }; - Frame { header, body } + Frame::new(header) + } } } @@ -349,7 +295,7 @@ mod tests { fn property(f: Frame<()>) -> bool { futures::executor::block_on(async move { let id = crate::connection::Id::random(); - let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); + let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.len()); if io.send(f.clone()).await.is_err() { return false; }