From 1c06435b8ea1df6831324230c5706842fcdbad68 Mon Sep 17 00:00:00 2001 From: Philippe Jalaber Date: Wed, 19 Jul 2023 16:09:50 +0200 Subject: [PATCH 01/29] Compilation + test using zerocopy ok --- yamux/Cargo.toml | 1 + yamux/src/chunks.rs | 15 +- yamux/src/connection.rs | 130 ++++++++------- yamux/src/connection/cleanup.rs | 2 +- yamux/src/connection/closing.rs | 2 +- yamux/src/connection/stream.rs | 17 +- yamux/src/frame.rs | 193 +++++++++++++++------- yamux/src/frame/header.rs | 279 ++++++++++++++++++-------------- yamux/src/frame/io.rs | 203 +++++++++++------------ 9 files changed, 483 insertions(+), 359 deletions(-) diff --git a/yamux/Cargo.toml b/yamux/Cargo.toml index 431bd6c7..372ba0e3 100644 --- a/yamux/Cargo.toml +++ b/yamux/Cargo.toml @@ -17,6 +17,7 @@ parking_lot = "0.12" rand = "0.8.3" static_assertions = "1" pin-project = "1.1.0" +zerocopy = { version = "0.7.0", features = ["derive"] } [dev-dependencies] quickcheck = "1.0" diff --git a/yamux/src/chunks.rs b/yamux/src/chunks.rs index 0e66d894..7f656a77 100644 --- a/yamux/src/chunks.rs +++ b/yamux/src/chunks.rs @@ -36,12 +36,15 @@ impl Chunks { } /// Add another chunk of bytes to the end. - pub(crate) fn push(&mut self, x: Vec) { - self.len += x.len(); - if !x.is_empty() { - self.seq.push_back(Chunk { - cursor: io::Cursor::new(x), - }) + pub(crate) fn push(&mut self, x: Vec, offset: usize) { + let x_len = x.len(); + let cursor = io::Cursor::new(x); + let mut chunk = Chunk { cursor }; + chunk.advance(offset); + if !chunk.is_empty() { + assert_eq!(chunk.len(), x_len - offset); + self.len += chunk.len() + offset; + self.seq.push_back(chunk); } } diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index de2e07b7..d2ee68c9 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -91,6 +91,7 @@ mod cleanup; mod closing; mod stream; +use crate::frame::header::HEADER_SIZE; use crate::tagged_stream::TaggedStream; use crate::{ error::ConnectionError, @@ -355,7 +356,7 @@ struct Active { socket: Fuse>, next_id: u32, - streams: IntMap>>, + streams: IntMap>>, stream_receivers: SelectAll>>, no_streams_waker: Option, @@ -519,8 +520,14 @@ impl Active { if extra_credit > 0 { let mut frame = Frame::window_update(id, extra_credit); - frame.header_mut().syn(); - log::trace!("{}/{}: sending initial {}", self.id, id, frame.header()); + let mut parsed_frame = frame.parse_mut().expect("valid frame"); + parsed_frame.header_mut().syn(); + log::trace!( + "{}/{}: sending initial {}", + self.id, + id, + parsed_frame.header() + ); self.pending_frames.push_back(frame.into()); } @@ -531,17 +538,18 @@ impl Active { } log::debug!("{}: new outbound {} of {}", self.id, stream, self); - self.streams.insert(id, stream.clone_shared()); + self.streams.insert(id.val(), stream.clone_shared()); Poll::Ready(Ok(stream)) } fn on_send_frame(&mut self, frame: Frame>) { + let parsed_frame = frame.parse().expect("valid frame"); log::trace!( "{}/{}: sending: {}", self.id, - frame.header().stream_id(), - frame.header() + parsed_frame.header().stream_id(), + parsed_frame.header() ); self.pending_frames.push_back(frame.into()); } @@ -553,7 +561,10 @@ impl Active { } fn on_drop_stream(&mut self, stream_id: StreamId) { - let s = self.streams.remove(&stream_id).expect("stream not found"); + let s = self + .streams + .remove(&stream_id.val()) + .expect("stream not found"); log::trace!("{}: removing dropped stream {}", self.id, stream_id); let frame = { @@ -564,7 +575,7 @@ impl Active { State::Open { .. } => { let mut header = Header::data(stream_id, 0); header.rst(); - Some(Frame::new(header)) + Some(Frame::from_header(header)) } // The stream was dropped without calling `poll_close`. // We have already received a FIN from remote and send one @@ -572,7 +583,7 @@ impl Active { State::RecvClosed => { let mut header = Header::data(stream_id, 0); header.fin(); - Some(Frame::new(header)) + Some(Frame::from_header(header)) } // The stream was properly closed. We already sent our FIN frame. // The remote may be out of credit though and blocked on @@ -585,7 +596,7 @@ impl Active { // which we will never send, so reset the stream now. let mut header = Header::data(stream_id, 0); header.rst(); - Some(Frame::new(header)) + Some(Frame::from_header(header)) } else { // The remote has either still credit or will be given more // (due to an enqueued window update or because the update @@ -609,7 +620,8 @@ impl Active { frame }; if let Some(f) = frame { - log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header()); + let pf = f.parse().expect("valid frame"); + log::trace!("{}/{}: sending: {}", self.id, stream_id, pf.header()); self.pending_frames.push_back(f.into()); } } @@ -621,11 +633,12 @@ impl Active { /// Otherwise we process the frame and potentially return a new `Stream` /// if one was opened by the remote. fn on_frame(&mut self, frame: Frame<()>) -> Result> { - log::trace!("{}: received: {}", self.id, frame.header()); + let parsed_frame = frame.parse().expect("valid frame"); + log::trace!("{}: received: {}", self.id, parsed_frame.header()); - if frame.header().flags().contains(header::ACK) { - let id = frame.header().stream_id(); - if let Some(stream) = self.streams.get(&id) { + if parsed_frame.header().flags().contains(header::ACK) { + let id = parsed_frame.header().stream_id(); + if let Some(stream) = self.streams.get(&id.val()) { stream .lock() .update_state(self.id, id, State::Open { acknowledged: true }); @@ -635,7 +648,7 @@ impl Active { } } - let action = match frame.header().tag() { + let action = match parsed_frame.header().tag().expect("valid header's tag") { Tag::Data => self.on_data(frame.into_data()), Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()), Tag::Ping => self.on_ping(&frame.into_ping()), @@ -646,21 +659,25 @@ impl Active { Action::New(stream, update) => { log::trace!("{}: new inbound {} of {}", self.id, stream, self); if let Some(f) = update { - log::trace!("{}/{}: sending update", self.id, f.header().stream_id()); + let pf = f.parse().expect("valid frame"); + log::trace!("{}/{}: sending update", self.id, pf.header().stream_id()); self.pending_frames.push_back(f.into()); } return Ok(Some(stream)); } Action::Update(f) => { - log::trace!("{}: sending update: {:?}", self.id, f.header()); + let pf = f.parse().expect("valid frame"); + log::trace!("{}: sending update: {:?}", self.id, pf.header()); self.pending_frames.push_back(f.into()); } Action::Ping(f) => { - log::trace!("{}/{}: pong", self.id, f.header().stream_id()); + let pf = f.parse().expect("valid frame"); + log::trace!("{}/{}: pong", self.id, pf.header().stream_id()); self.pending_frames.push_back(f.into()); } Action::Reset(f) => { - log::trace!("{}/{}: sending reset", self.id, f.header().stream_id()); + let pf = f.parse().expect("valid frame"); + log::trace!("{}/{}: sending reset", self.id, pf.header().stream_id()); self.pending_frames.push_back(f.into()); } Action::Terminate(f) => { @@ -673,11 +690,12 @@ impl Active { } fn on_data(&mut self, frame: Frame) -> Action { - let stream_id = frame.header().stream_id(); + let parsed_frame = frame.parse().expect("valid frame"); + let stream_id = parsed_frame.header().stream_id(); - if frame.header().flags().contains(header::RST) { + if parsed_frame.header().flags().contains(header::RST) { // stream reset - if let Some(s) = self.streams.get_mut(&stream_id) { + if let Some(s) = self.streams.get_mut(&stream_id.val()) { let mut shared = s.lock(); shared.update_state(self.id, stream_id, State::Closed); if let Some(w) = shared.reader.take() { @@ -690,15 +708,15 @@ impl Active { return Action::None; } - let is_finish = frame.header().flags().contains(header::FIN); // half-close + let is_finish = parsed_frame.header().flags().contains(header::FIN); // half-close - if frame.header().flags().contains(header::SYN) { + if parsed_frame.header().flags().contains(header::SYN) { // new stream if !self.is_valid_remote_id(stream_id, Tag::Data) { log::error!("{}: invalid stream id {}", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } - if frame.body().len() > DEFAULT_CREDIT as usize { + if parsed_frame.body().len() > DEFAULT_CREDIT as usize { log::error!( "{}/{}: 1st body of stream exceeds default credit", self.id, @@ -706,7 +724,7 @@ impl Active { ); return Action::Terminate(Frame::protocol_error()); } - if self.streams.contains_key(&stream_id) { + if self.streams.contains_key(&stream_id.val()) { log::error!("{}/{}: stream already exists", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } @@ -721,14 +739,16 @@ impl Active { if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } - shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); + shared.window = shared.window.saturating_sub(parsed_frame.body_len()); + shared.buffer.push(frame.into_buffer(), HEADER_SIZE); if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { if let Some(credit) = shared.next_window_update() { shared.window += credit; + let mut frame = Frame::window_update(stream_id, credit); - frame.header_mut().ack(); + let mut parsed_frame = frame.parse_mut().expect("valid frame"); + parsed_frame.header_mut().ack(); window_update = Some(frame) } } @@ -736,13 +756,13 @@ impl Active { if window_update.is_none() { stream.set_flag(stream::Flag::Ack) } - self.streams.insert(stream_id, stream.clone_shared()); + self.streams.insert(stream_id.val(), stream.clone_shared()); return Action::New(stream, window_update); } - if let Some(s) = self.streams.get_mut(&stream_id) { + if let Some(s) = self.streams.get_mut(&stream_id.val()) { let mut shared = s.lock(); - if frame.body().len() > shared.window as usize { + if parsed_frame.body().len() > shared.window as usize { log::error!( "{}/{}: frame body larger than window of stream", self.id, @@ -762,10 +782,10 @@ impl Active { ); let mut header = Header::data(stream_id, 0); header.rst(); - return Action::Reset(Frame::new(header)); + return Action::Reset(Frame::from_header(header)); } - shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); + shared.window = shared.window.saturating_sub(parsed_frame.body_len()); + shared.buffer.push(frame.into_buffer(), HEADER_SIZE); if let Some(w) = shared.reader.take() { w.wake() } @@ -796,11 +816,12 @@ impl Active { } fn on_window_update(&mut self, frame: &Frame) -> Action { - let stream_id = frame.header().stream_id(); + let parsed_frame = frame.parse().expect("valid frame"); + let stream_id = parsed_frame.header().stream_id(); - if frame.header().flags().contains(header::RST) { + if parsed_frame.header().flags().contains(header::RST) { // stream reset - if let Some(s) = self.streams.get_mut(&stream_id) { + if let Some(s) = self.streams.get_mut(&stream_id.val()) { let mut shared = s.lock(); shared.update_state(self.id, stream_id, State::Closed); if let Some(w) = shared.reader.take() { @@ -813,15 +834,15 @@ impl Active { return Action::None; } - let is_finish = frame.header().flags().contains(header::FIN); // half-close + let is_finish = parsed_frame.header().flags().contains(header::FIN); // half-close - if frame.header().flags().contains(header::SYN) { + if parsed_frame.header().flags().contains(header::SYN) { // new stream if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) { log::error!("{}: invalid stream id {}", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } - if self.streams.contains_key(&stream_id) { + if self.streams.contains_key(&stream_id.val()) { log::error!("{}/{}: stream already exists", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } @@ -830,7 +851,7 @@ impl Active { return Action::Terminate(Frame::protocol_error()); } - let credit = frame.header().credit() + DEFAULT_CREDIT; + let credit = parsed_frame.header().credit() + DEFAULT_CREDIT; let mut stream = self.make_new_inbound_stream(stream_id, credit); stream.set_flag(stream::Flag::Ack); @@ -839,13 +860,13 @@ impl Active { .shared() .update_state(self.id, stream_id, State::RecvClosed); } - self.streams.insert(stream_id, stream.clone_shared()); + self.streams.insert(stream_id.val(), stream.clone_shared()); return Action::New(stream, None); } - if let Some(s) = self.streams.get_mut(&stream_id) { + if let Some(s) = self.streams.get_mut(&stream_id.val()) { let mut shared = s.lock(); - shared.credit += frame.header().credit(); + shared.credit += parsed_frame.header().credit(); if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } @@ -872,15 +893,16 @@ impl Active { } fn on_ping(&mut self, frame: &Frame) -> Action { - let stream_id = frame.header().stream_id(); - if frame.header().flags().contains(header::ACK) { + let parsed_frame = frame.parse().expect("valid frame"); + let stream_id = parsed_frame.header().stream_id(); + if parsed_frame.header().flags().contains(header::ACK) { // pong return Action::None; } - if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) { - let mut hdr = Header::ping(frame.header().nonce()); + if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id.val()) { + let mut hdr = Header::ping(parsed_frame.header().nonce()); hdr.ack(); - return Action::Ping(Frame::new(hdr)); + return Action::Ping(Frame::from_header(hdr)); } log::trace!( "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}", @@ -947,8 +969,8 @@ impl Active { // - Its ID is odd and we are the client. // - Its ID is even and we are the server. .filter(|(id, _)| match self.mode { - Mode::Client => id.is_client(), - Mode::Server => id.is_server(), + Mode::Client => StreamId::new(**id).is_client(), + Mode::Server => StreamId::new(**id).is_server(), }) .filter(|(_, s)| s.lock().is_pending_ack()) .count() @@ -971,7 +993,7 @@ impl Active { fn drop_all_streams(&mut self) { for (id, s) in self.streams.drain() { let mut shared = s.lock(); - shared.update_state(self.id, id, State::Closed); + shared.update_state(self.id, StreamId::new(id), State::Closed); if let Some(w) = shared.reader.take() { w.wake() } diff --git a/yamux/src/connection/cleanup.rs b/yamux/src/connection/cleanup.rs index b4dfa816..87c6d0ca 100644 --- a/yamux/src/connection/cleanup.rs +++ b/yamux/src/connection/cleanup.rs @@ -33,7 +33,7 @@ impl Future for Cleanup { type Output = ConnectionError; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.get_mut(); + let this = self.get_mut(); loop { match this.state { diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs index d503941f..b02c465e 100644 --- a/yamux/src/connection/closing.rs +++ b/yamux/src/connection/closing.rs @@ -45,7 +45,7 @@ where type Output = Result<()>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.get_mut(); + let this = self.get_mut(); loop { match this.state { diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index ae745577..0d948c09 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -8,7 +8,7 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. -use crate::frame::header::ACK; +use crate::frame::header::{ACK, HEADER_SIZE}; use crate::{ chunks::Chunks, connection::{self, StreamCommand}, @@ -218,7 +218,8 @@ impl Stream { drop(shared); let mut frame = Frame::window_update(self.id, credit).right(); - self.add_flag(frame.header_mut()); + let mut parsed_frame = frame.parse_mut().expect("valid frame"); + self.add_flag(parsed_frame.header_mut()); let cmd = StreamCommand::SendFrame(frame); self.sender .start_send(cmd) @@ -257,7 +258,8 @@ impl futures::stream::Stream for Stream { let mut shared = self.shared(); if let Some(bytes) = shared.buffer.pop() { - let off = bytes.offset(); + // Every chunk starts with a frame header, so we add HEADER_SIZE to offset. + let off = bytes.offset() + HEADER_SIZE; let mut vec = bytes.into_vec(); if off != 0 { // This should generally not happen when the stream is used only as @@ -269,7 +271,7 @@ impl futures::stream::Stream for Stream { self.conn, self.id ); - vec = vec.split_off(off) + vec = vec.split_off(off); } return Poll::Ready(Some(Ok(Packet(vec)))); } @@ -367,18 +369,19 @@ impl AsyncWrite for Stream { let k = std::cmp::min(shared.credit as usize, buf.len()); let k = std::cmp::min(k, self.config.split_send_size); shared.credit = shared.credit.saturating_sub(k as u32); - Vec::from(&buf[..k]) + &buf[..k] }; let n = body.len(); let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); - self.add_flag(frame.header_mut()); + let mut parsed_frame = frame.parse_mut().expect("valid frame"); + self.add_flag(parsed_frame.header_mut()); log::trace!("{}/{}: write {} bytes", self.conn, self.id, n); // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending // We are tracking this information: // a) to be consistent with outbound streams // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test. - if frame.header().flags().contains(ACK) { + if parsed_frame.header().flags().contains(ACK) { self.shared() .update_state(self.conn, self.id, State::Open { acknowledged: true }); } diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 692840a4..72d1f170 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -13,47 +13,99 @@ mod io; use futures::future::Either; use header::{Data, GoAway, Header, Ping, StreamId, WindowUpdate}; -use std::{convert::TryInto, num::TryFromIntError}; +use std::{convert::TryInto, fmt::Debug, marker::PhantomData, num::TryFromIntError}; +use zerocopy::{AsBytes, ByteSlice, ByteSliceMut, Ref}; pub use io::FrameDecodeError; pub(crate) use io::Io; -/// A Yamux message frame consisting of header and body. -#[derive(Clone, Debug, PartialEq, Eq)] +use self::header::HEADER_SIZE; + +/// A Yamux message frame consisting of header and body in a single buffer +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Frame { - header: Header, - body: Vec, + buffer: Vec, + _marker: std::marker::PhantomData, +} + +impl Default for Frame { + fn default() -> Self { + Self { + buffer: Vec::new(), + _marker: PhantomData, + } + } } impl Frame { - pub fn new(header: Header) -> Self { - Frame { - header, - body: Vec::new(), + pub(crate) fn new(buffer: Vec) -> Self { + Self { + buffer, + _marker: std::marker::PhantomData, } } - pub fn header(&self) -> &Header { - &self.header + pub(crate) fn from_header(header: Header) -> Self { + let mut buffer = vec![0; HEADER_SIZE]; + header.write_to(&mut buffer).expect("write_to success"); + Self::new(buffer) + } + + fn make_parsed_frame( + header: Ref>, + body: B, + ) -> Result, io::FrameDecodeError> { + let frame = ParsedFrame { header, body }; + let version = frame.header.version().val(); + if version != 0 { + Err(FrameDecodeError::Header(crate::HeaderDecodeError::Version( + version, + ))) + } else { + frame.header.tag().map(|_| frame).map_err(|e| e.into()) + } } - pub fn header_mut(&mut self) -> &mut Header { - &mut self.header + pub(crate) fn parse(&self) -> Result, io::FrameDecodeError> { + let (header, body) = Ref::new_from_prefix(&self.buffer[..]).expect("construct a valid Ref"); + Self::make_parsed_frame(header, body) + } + + pub(crate) fn parse_mut(&mut self) -> Result, io::FrameDecodeError> { + let (header, body) = + Ref::new_from_prefix(&mut self.buffer[..]).expect("construct a valid Ref"); + Self::make_parsed_frame(header, body) + } + + pub fn into_buffer(self) -> Vec { + self.buffer + } + + pub fn buffer(&self) -> &[u8] { + &self.buffer + } + + pub fn buffer_mut(&mut self) -> &mut [u8] { + &mut self.buffer + } + + pub fn append_bytes(&mut self, bytes: &mut Vec) { + self.buffer.append(bytes); } /// Introduce this frame to the right of a binary frame type. pub(crate) fn right(self) -> Frame> { Frame { - header: self.header.right(), - body: self.body, + buffer: self.buffer, + _marker: PhantomData, } } /// Introduce this frame to the left of a binary frame type. pub(crate) fn left(self) -> Frame> { Frame { - header: self.header.left(), - body: self.body, + buffer: self.buffer, + _marker: PhantomData, } } } @@ -61,8 +113,8 @@ impl Frame { impl From> for Frame<()> { fn from(f: Frame) -> Frame<()> { Frame { - header: f.header.into(), - body: f.body, + buffer: f.buffer, + _marker: PhantomData, } } } @@ -70,32 +122,35 @@ impl From> for Frame<()> { impl Frame<()> { pub(crate) fn into_data(self) -> Frame { Frame { - header: self.header.into_data(), - body: self.body, + buffer: self.buffer, + _marker: PhantomData, } } pub(crate) fn into_window_update(self) -> Frame { Frame { - header: self.header.into_window_update(), - body: self.body, + buffer: self.buffer, + _marker: PhantomData, } } pub(crate) fn into_ping(self) -> Frame { Frame { - header: self.header.into_ping(), - body: self.body, + buffer: self.buffer, + _marker: PhantomData, } } } impl Frame { - pub fn data(id: StreamId, b: Vec) -> Result { - Ok(Frame { - header: Header::data(id, b.len().try_into()?), - body: b, - }) + pub fn data(id: StreamId, body: &[u8]) -> Result { + let header = Header::data(id, body.len().try_into()?); + let mut buffer = vec![0; HEADER_SIZE + body.len()]; + header + .write_to(&mut buffer[..HEADER_SIZE]) + .expect("write_to success"); + buffer[HEADER_SIZE..].copy_from_slice(body); + Ok(Frame::new(buffer)) } pub fn close_stream(id: StreamId, ack: bool) -> Self { @@ -105,52 +160,68 @@ impl Frame { header.ack() } - Frame::new(header) + Frame::from_header(header) } +} - pub fn body(&self) -> &[u8] { - &self.body +impl Frame { + pub fn window_update(id: StreamId, credit: u32) -> Frame { + Frame::from_header(Header::window_update(id, credit)) } +} - pub fn body_len(&self) -> u32 { - // Safe cast since we construct `Frame::`s only with - // `Vec` of length [0, u32::MAX] in `Frame::data` above. - self.body().len() as u32 +impl Frame { + pub fn term() -> Frame { + Frame::::from_header(Header::term()) + } + + pub fn protocol_error() -> Frame { + Frame::::from_header(Header::protocol_error()) } - pub fn into_body(self) -> Vec { - self.body + pub fn internal_error() -> Frame { + Frame::::from_header(Header::internal_error()) } } -impl Frame { - pub fn window_update(id: StreamId, credit: u32) -> Self { - Frame { - header: Header::window_update(id, credit), - body: Vec::new(), - } +/// A zero-copied-parsed view of a Frame +pub struct ParsedFrame { + header: Ref>, + body: B, +} + +impl Debug for ParsedFrame { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Frame") + .field("header", &self.header) + .field("body", &"..") + .finish() } } -impl Frame { - pub fn term() -> Self { - Frame { - header: Header::term(), - body: Vec::new(), - } +impl ParsedFrame { + pub fn header(&self) -> &Header { + &self.header } - pub fn protocol_error() -> Self { - Frame { - header: Header::protocol_error(), - body: Vec::new(), - } + pub fn body(&self) -> &B { + &self.body } - pub fn internal_error() -> Self { - Frame { - header: Header::internal_error(), - body: Vec::new(), - } + pub fn body_len(&self) -> u32 { + // Safe cast since we construct `Frame::`s only with + // `Vec` of length [0, u32::MAX] in `Frame::data` above. + self.body.len() as u32 + } + + #[cfg(test)] + pub fn bytes(&self) -> &[u8] { + self.header.bytes() + } +} + +impl ParsedFrame { + pub fn header_mut(&mut self) -> &mut Header { + &mut self.header } } diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index cbbf704d..1d7166a4 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -9,19 +9,51 @@ // at https://opensource.org/licenses/MIT. use futures::future::Either; +use std::convert::TryFrom; use std::fmt; +use std::fmt::Debug; +use std::ops::BitOrAssign; +use zerocopy::byteorder::network_endian::{U16, U32}; +use zerocopy::{AsBytes, FromBytes, FromZeroes}; /// The message frame header. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, FromZeroes, FromBytes, AsBytes)] +#[repr(packed)] pub struct Header { version: Version, - tag: Tag, + tag: u8, flags: Flags, stream_id: StreamId, length: Len, _marker: std::marker::PhantomData, } +impl Debug for Header { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Header") + .field("version", &self.version) + .field("tag", &self.tag) + .field("flags", &self.flags) + .field("stream_id", &self.stream_id) + .field("length", &self.length) + .field("_marker", &self._marker) + .finish() + } +} + +impl PartialEq for Header { + fn eq(&self, other: &Self) -> bool { + self.version == other.version + && self.tag == other.tag + && self.flags == other.flags + && self.stream_id == other.stream_id + && self.length == other.length + && self._marker == other._marker + } +} + +impl Eq for Header {} + impl fmt::Display for Header { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -36,8 +68,12 @@ impl fmt::Display for Header { } impl Header { - pub fn tag(&self) -> Tag { - self.tag + pub fn version(&self) -> Version { + self.version + } + + pub fn tag(&self) -> Result { + Tag::try_from(self.tag) } pub fn flags(&self) -> Flags { @@ -54,11 +90,11 @@ impl Header { #[cfg(test)] pub fn set_len(&mut self, len: u32) { - self.length = Len(len) + self.length = Len(len.into()) } /// Arbitrary type cast, use with caution. - fn cast(self) -> Header { + pub fn cast(self) -> Header { Header { version: self.version, tag: self.tag, @@ -68,16 +104,6 @@ impl Header { _marker: std::marker::PhantomData, } } - - /// Introduce this header to the right of a binary header type. - pub(crate) fn right(self) -> Header> { - self.cast() - } - - /// Introduce this header to the left of a binary header type. - pub(crate) fn left(self) -> Header> { - self.cast() - } } impl From> for Header<()> { @@ -86,48 +112,31 @@ impl From> for Header<()> { } } -impl Header<()> { - pub(crate) fn into_data(self) -> Header { - debug_assert_eq!(self.tag, Tag::Data); - self.cast() - } - - pub(crate) fn into_window_update(self) -> Header { - debug_assert_eq!(self.tag, Tag::WindowUpdate); - self.cast() - } - - pub(crate) fn into_ping(self) -> Header { - debug_assert_eq!(self.tag, Tag::Ping); - self.cast() - } -} - impl Header { /// Set the [`SYN`] flag. pub fn syn(&mut self) { - self.flags.0 |= SYN.0 + self.flags |= SYN; } } impl Header { /// Set the [`ACK`] flag. pub fn ack(&mut self) { - self.flags.0 |= ACK.0 + self.flags |= ACK; } } impl Header { /// Set the [`FIN`] flag. pub fn fin(&mut self) { - self.flags.0 |= FIN.0 + self.flags |= FIN; } } impl Header { /// Set the [`RST`] flag. pub fn rst(&mut self) { - self.flags.0 |= RST.0 + self.flags |= RST; } } @@ -136,10 +145,10 @@ impl Header { pub fn data(id: StreamId, len: u32) -> Self { Header { version: Version(0), - tag: Tag::Data, - flags: Flags(0), + tag: Tag::Data as u8, + flags: Flags(U16::ZERO), stream_id: id, - length: Len(len), + length: Len(len.into()), _marker: std::marker::PhantomData, } } @@ -150,17 +159,17 @@ impl Header { pub fn window_update(id: StreamId, credit: u32) -> Self { Header { version: Version(0), - tag: Tag::WindowUpdate, - flags: Flags(0), + tag: Tag::WindowUpdate as u8, + flags: Flags(U16::ZERO), stream_id: id, - length: Len(credit), + length: Len(credit.into()), _marker: std::marker::PhantomData, } } /// The credit this window update grants to the remote. pub fn credit(&self) -> u32 { - self.length.0 + self.length.val() } } @@ -169,17 +178,17 @@ impl Header { pub fn ping(nonce: u32) -> Self { Header { version: Version(0), - tag: Tag::Ping, - flags: Flags(0), - stream_id: StreamId(0), - length: Len(nonce), + tag: Tag::Ping as u8, + flags: Flags(U16::ZERO), + stream_id: StreamId(U32::ZERO), + length: Len(nonce.into()), _marker: std::marker::PhantomData, } } /// The nonce of this ping. pub fn nonce(&self) -> u32 { - self.length.0 + self.length.val() } } @@ -202,10 +211,10 @@ impl Header { fn go_away(code: u32) -> Self { Header { version: Version(0), - tag: Tag::GoAway, - flags: Flags(0), - stream_id: StreamId(0), - length: Len(code), + tag: Tag::GoAway as u8, + flags: Flags(U16::ZERO), + stream_id: StreamId(U32::ZERO), + length: Len(code.into()), _marker: std::marker::PhantomData, } } @@ -264,41 +273,76 @@ pub(super) mod private { /// A tag is the runtime representation of a message type. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum Tag { - Data, - WindowUpdate, - Ping, - GoAway, + Data = 0, + WindowUpdate = 1, + Ping = 2, + GoAway = 3, +} + +impl TryFrom for Tag { + type Error = HeaderDecodeError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Self::Data), + 1 => Ok(Self::WindowUpdate), + 2 => Ok(Self::Ping), + 3 => Ok(Self::GoAway), + _ => Err(HeaderDecodeError::Type(value)), + } + } } /// The protocol version a message corresponds to. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] +#[repr(C)] pub struct Version(u8); +impl Version { + pub fn val(self) -> u8 { + self.0 + } +} + /// The message length. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct Len(u32); +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] +#[repr(C)] +pub struct Len(U32); impl Len { pub fn val(self) -> u32 { - self.0 + self.0.get() } } -pub const CONNECTION_ID: StreamId = StreamId(0); +pub const CONNECTION_ID: StreamId = StreamId(U32::ZERO); /// The ID of a stream. /// /// The value 0 denotes no particular stream but the whole session. -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] -pub struct StreamId(u32); +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] +#[repr(C)] +pub struct StreamId(U32); + +impl PartialOrd for StreamId { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.get().partial_cmp(&other.0.get()) + } +} + +impl Ord for StreamId { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.get().cmp(&other.0.get()) + } +} impl StreamId { pub(crate) fn new(val: u32) -> Self { - StreamId(val) + StreamId(val.into()) } pub fn is_server(self) -> bool { - self.0 % 2 == 0 + self.0.get() % 2 == 0 } pub fn is_client(self) -> bool { @@ -310,7 +354,7 @@ impl StreamId { } pub fn val(self) -> u32 { - self.0 + self.0.get() } } @@ -320,72 +364,47 @@ impl fmt::Display for StreamId { } } -impl nohash_hasher::IsEnabled for StreamId {} - /// Possible flags set on a message. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct Flags(u16); +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] +#[repr(C)] +pub struct Flags(U16); impl Flags { pub fn contains(self, other: Flags) -> bool { - self.0 & other.0 == other.0 + let other = other.0.get(); + self.0.get() & other == other } pub fn val(self) -> u16 { - self.0 + self.0.get() + } + + pub fn set(&mut self, val: u16) { + self.0.set(val) + } +} + +impl BitOrAssign for Flags { + fn bitor_assign(&mut self, rhs: Self) { + self.set(self.val() | rhs.val()); } } /// Indicates the start of a new stream. -pub const SYN: Flags = Flags(1); +pub const SYN: Flags = Flags(U16::from_bytes([0, 1])); /// Acknowledges the start of a new stream. -pub const ACK: Flags = Flags(2); +pub const ACK: Flags = Flags(U16::from_bytes([0, 2])); /// Indicates the half-closing of a stream. -pub const FIN: Flags = Flags(4); +pub const FIN: Flags = Flags(U16::from_bytes([0, 4])); /// Indicates an immediate stream reset. -pub const RST: Flags = Flags(8); +pub const RST: Flags = Flags(U16::from_bytes([0, 8])); /// The serialised header size in bytes. pub const HEADER_SIZE: usize = 12; -/// Encode a [`Header`] value. -pub fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { - let mut buf = [0; HEADER_SIZE]; - buf[0] = hdr.version.0; - buf[1] = hdr.tag as u8; - buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes()); - buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes()); - buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes()); - buf -} - -/// Decode a [`Header`] value. -pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { - if buf[0] != 0 { - return Err(HeaderDecodeError::Version(buf[0])); - } - - let hdr = Header { - version: Version(buf[0]), - tag: match buf[1] { - 0 => Tag::Data, - 1 => Tag::WindowUpdate, - 2 => Tag::Ping, - 3 => Tag::GoAway, - t => return Err(HeaderDecodeError::Type(t)), - }, - flags: Flags(u16::from_be_bytes([buf[2], buf[3]])), - stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])), - length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])), - _marker: std::marker::PhantomData, - }; - - Ok(hdr) -} - /// Possible errors while decoding a message frame header. #[non_exhaustive] #[derive(Debug)] @@ -409,9 +428,32 @@ impl std::error::Error for HeaderDecodeError {} #[cfg(test)] mod tests { + use crate::frame::Frame; + use super::*; use quickcheck::{Arbitrary, Gen, QuickCheck}; + impl Arbitrary for Flags { + fn arbitrary(g: &mut Gen) -> Self { + let flags: u16 = Arbitrary::arbitrary(g); + Flags(flags.into()) + } + } + + impl Arbitrary for Len { + fn arbitrary(g: &mut Gen) -> Self { + let len: u32 = Arbitrary::arbitrary(g); + Len(len.into()) + } + } + + impl Arbitrary for StreamId { + fn arbitrary(g: &mut Gen) -> Self { + let stream_id: u32 = Arbitrary::arbitrary(g); + StreamId(stream_id.into()) + } + } + impl Arbitrary for Header<()> { fn arbitrary(g: &mut Gen) -> Self { let tag = *g @@ -420,10 +462,10 @@ mod tests { Header { version: Version(0), - tag, - flags: Flags(Arbitrary::arbitrary(g)), - stream_id: StreamId(Arbitrary::arbitrary(g)), - length: Len(Arbitrary::arbitrary(g)), + tag: tag as u8, + flags: Arbitrary::arbitrary(g), + stream_id: Arbitrary::arbitrary(g), + length: Arbitrary::arbitrary(g), _marker: std::marker::PhantomData, } } @@ -432,8 +474,9 @@ mod tests { #[test] fn encode_decode_identity() { fn property(hdr: Header<()>) -> bool { - match decode(&encode(&hdr)) { - Ok(x) => x == hdr, + let frame = Frame::from_header(hdr); + match frame.parse() { + Ok(pf) => pf.bytes() == frame.buffer(), Err(e) => { eprintln!("decode error: {}", e); false diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 795d9f5c..31cf491b 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -12,7 +12,7 @@ use super::{ header::{self, HeaderDecodeError}, Frame, }; -use crate::connection::Id; +use crate::{connection::Id, frame::header::HEADER_SIZE}; use futures::{prelude::*, ready}; use std::{ fmt, io, @@ -45,32 +45,26 @@ impl Io { /// The stages of writing a new `Frame`. enum WriteState { Init, - Header { - header: [u8; header::HEADER_SIZE], - buffer: Vec, - offset: usize, - }, - Body { - buffer: Vec, - offset: usize, - }, + Data { frame: Frame<()>, offset: usize }, } impl fmt::Debug for WriteState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { WriteState::Init => f.write_str("(WriteState::Init)"), - WriteState::Header { offset, .. } => { - write!(f, "(WriteState::Header (offset {}))", offset) - } - WriteState::Body { offset, buffer } => { - write!( + WriteState::Data { frame, offset } => match frame.parse() { + Ok(parsed_frame) => write!( f, - "(WriteState::Body (offset {}) (buffer-len {}))", + "(WriteState::Body (header {:?}) (offset {}))", + parsed_frame.header(), offset, - buffer.len() - ) - } + ), + Err(e) => write!( + f, + "(WriteState::Body (invalid header ({})) (offset {}))", + e, offset + ), + }, } } } @@ -84,11 +78,10 @@ impl Sink> for Io { log::trace!("{}: write: {:?}", this.id, this.write_state); match &mut this.write_state { WriteState::Init => return Poll::Ready(Ok(())), - WriteState::Header { - header, - buffer, + WriteState::Data { + frame, ref mut offset, - } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) { + } => match Pin::new(&mut this.io).poll_write(cx, &frame.buffer()[*offset..]) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Ready(Ok(n)) => { @@ -96,28 +89,7 @@ impl Sink> for Io { return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } *offset += n; - if *offset == header.len() { - if !buffer.is_empty() { - let buffer = std::mem::take(buffer); - this.write_state = WriteState::Body { buffer, offset: 0 }; - } else { - this.write_state = WriteState::Init; - } - } - } - }, - WriteState::Body { - buffer, - ref mut offset, - } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *offset += n; - if *offset == buffer.len() { + if *offset == frame.buffer().len() { this.write_state = WriteState::Init; } } @@ -126,14 +98,8 @@ impl Sink> for Io { } } - fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> { - let header = header::encode(&f.header); - let buffer = f.body; - self.get_mut().write_state = WriteState::Header { - header, - buffer, - offset: 0, - }; + fn start_send(self: Pin<&mut Self>, frame: Frame<()>) -> Result<(), Self::Error> { + self.get_mut().write_state = WriteState::Data { frame, offset: 0 }; Ok(()) } @@ -155,68 +121,65 @@ enum ReadState { /// Initial reading state. Init, /// Reading the frame header. - Header { - offset: usize, - buffer: [u8; header::HEADER_SIZE], - }, + Header { frame: Frame<()>, offset: usize }, /// Reading the frame body. - Body { - header: header::Header<()>, - offset: usize, - buffer: Vec, - }, + Body { frame: Frame<()>, offset: usize }, } +const READ_BUFFER_DEFAULT_CAPACITY: usize = 2048; + impl Stream for Io { type Item = Result, FrameDecodeError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = &mut *self; + let this = &mut *self; loop { log::trace!("{}: read: {:?}", this.id, this.read_state); match this.read_state { ReadState::Init => { + let mut buffer = Vec::with_capacity(READ_BUFFER_DEFAULT_CAPACITY); + buffer.append(&mut vec![0_u8; HEADER_SIZE]); this.read_state = ReadState::Header { offset: 0, - buffer: [0; header::HEADER_SIZE], + frame: Frame::<()>::new(buffer), }; } ReadState::Header { ref mut offset, - ref mut buffer, + ref mut frame, } => { - if *offset == header::HEADER_SIZE { - let header = match header::decode(buffer) { - Ok(hd) => hd, - Err(e) => return Poll::Ready(Some(Err(e.into()))), + if *offset == HEADER_SIZE { + let parsed_frame = match frame.parse_mut() { + Ok(frame) => frame, + Err(e) => return Poll::Ready(Some(Err(e))), }; - - log::trace!("{}: read: {}", this.id, header); - - if header.tag() != header::Tag::Data { + log::trace!("{}: read: {:?}", this.id, parsed_frame); + if parsed_frame.header().tag().expect("valid tag") != header::Tag::Data { + let frame = std::mem::take(frame); this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame::new(header)))); + return Poll::Ready(Some(Ok(frame))); } - let body_len = header.len().val() as usize; + let body_len = parsed_frame.header().len().val() as usize; if body_len > this.max_body_len { return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( body_len, )))); } + frame.append_bytes(&mut vec![0; body_len]); this.read_state = ReadState::Body { - header, - offset: 0, - buffer: vec![0; body_len], + frame: std::mem::take(frame), + offset: HEADER_SIZE, }; continue; } - let buf = &mut buffer[*offset..header::HEADER_SIZE]; - match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { + match ready!(Pin::new(&mut this.io) + .poll_read(cx, &mut frame.buffer_mut()[*offset..HEADER_SIZE]))? + { 0 => { if *offset == 0 { return Poll::Ready(None); @@ -228,21 +191,21 @@ impl Stream for Io { } } ReadState::Body { - ref header, + ref mut frame, ref mut offset, - ref mut buffer, } => { - let body_len = header.len().val() as usize; + let parsed_frame = frame.parse().expect("valid frame"); + let body_len = parsed_frame.header().len().val() as usize; - if *offset == body_len { - let h = header.clone(); - let v = std::mem::take(buffer); + if *offset == HEADER_SIZE + body_len { + let frame = std::mem::take(frame); this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame { header: h, body: v }))); + return Poll::Ready(Some(Ok(frame))); } - let buf = &mut buffer[*offset..body_len]; - match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { + match ready!(Pin::new(&mut this.io) + .poll_read(cx, &mut frame.buffer_mut()[*offset..HEADER_SIZE + body_len]))? + { 0 => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); @@ -259,22 +222,32 @@ impl fmt::Debug for ReadState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ReadState::Init => f.write_str("(ReadState::Init)"), - ReadState::Header { offset, .. } => { - write!(f, "(ReadState::Header (offset {}))", offset) - } - ReadState::Body { - header, - offset, - buffer, - } => { - write!( + ReadState::Header { frame, offset } => match frame.parse() { + Ok(parsed_frame) => write!( f, - "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", - header, + "(ReadState::Header (header {:?}) (offset {}))", + parsed_frame.header(), offset, - buffer.len() - ) - } + ), + Err(e) => write!( + f, + "(ReadState::Header (invalid header ({})) (offset {}))", + e, offset + ), + }, + ReadState::Body { frame, offset } => match frame.parse() { + Ok(parsed_frame) => write!( + f, + "(ReadState::Body (header {:?}) (offset {}))", + parsed_frame.header(), + offset, + ), + Err(e) => write!( + f, + "(ReadState::Body (invalid header ({})) (offset {}))", + e, offset + ), + }, } } } @@ -328,19 +301,22 @@ mod tests { use super::*; use quickcheck::{Arbitrary, Gen, QuickCheck}; use rand::RngCore; + use zerocopy::AsBytes; impl Arbitrary for Frame<()> { fn arbitrary(g: &mut Gen) -> Self { let mut header: header::Header<()> = Arbitrary::arbitrary(g); - let body = if header.tag() == header::Tag::Data { + if header.tag().unwrap() == header::Tag::Data { header.set_len(header.len().val() % 4096); - let mut b = vec![0; header.len().val() as usize]; - rand::thread_rng().fill_bytes(&mut b); - b + let mut buffer = vec![0; HEADER_SIZE + header.len().val() as usize]; + rand::thread_rng().fill_bytes(&mut buffer[HEADER_SIZE..]); + header + .write_to(&mut buffer[..HEADER_SIZE]) + .expect("write_to success"); + Frame::new(buffer) } else { - Vec::new() - }; - Frame { header, body } + Frame::from_header(header) + } } } @@ -348,8 +324,13 @@ mod tests { fn encode_decode_identity() { fn property(f: Frame<()>) -> bool { futures::executor::block_on(async move { + let pf = f.parse().expect("valid frame"); let id = crate::connection::Id::random(); - let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); + let mut io = Io::new( + id, + futures::io::Cursor::new(Vec::new()), + pf.header().len().val() as usize, + ); if io.send(f.clone()).await.is_err() { return false; } From 39fa1e56cab6c353cf066846679d713ed5378ab9 Mon Sep 17 00:00:00 2001 From: Philippe Jalaber Date: Thu, 7 Sep 2023 15:07:36 +0200 Subject: [PATCH 02/29] Fix after comments --- yamux/src/connection.rs | 45 ++++++++-------- yamux/src/connection/stream.rs | 23 ++++---- yamux/src/frame.rs | 95 ++++++++++++++++++++-------------- yamux/src/frame/io.rs | 6 +-- 4 files changed, 94 insertions(+), 75 deletions(-) diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index d2ee68c9..ff443b0d 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -548,7 +548,7 @@ impl Active { log::trace!( "{}/{}: sending: {}", self.id, - parsed_frame.header().stream_id(), + parsed_frame.stream_id(), parsed_frame.header() ); self.pending_frames.push_back(frame.into()); @@ -620,8 +620,7 @@ impl Active { frame }; if let Some(f) = frame { - let pf = f.parse().expect("valid frame"); - log::trace!("{}/{}: sending: {}", self.id, stream_id, pf.header()); + log::trace!("{}/{}: sending: {:?}", self.id, stream_id, f); self.pending_frames.push_back(f.into()); } } @@ -634,10 +633,10 @@ impl Active { /// if one was opened by the remote. fn on_frame(&mut self, frame: Frame<()>) -> Result> { let parsed_frame = frame.parse().expect("valid frame"); - log::trace!("{}: received: {}", self.id, parsed_frame.header()); + log::trace!("{}: received: {:?}", self.id, parsed_frame); - if parsed_frame.header().flags().contains(header::ACK) { - let id = parsed_frame.header().stream_id(); + if parsed_frame.has_flag(header::ACK) { + let id = parsed_frame.stream_id(); if let Some(stream) = self.streams.get(&id.val()) { stream .lock() @@ -648,7 +647,7 @@ impl Active { } } - let action = match parsed_frame.header().tag().expect("valid header's tag") { + let action = match parsed_frame.tag().expect("valid header's tag") { Tag::Data => self.on_data(frame.into_data()), Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()), Tag::Ping => self.on_ping(&frame.into_ping()), @@ -660,7 +659,7 @@ impl Active { log::trace!("{}: new inbound {} of {}", self.id, stream, self); if let Some(f) = update { let pf = f.parse().expect("valid frame"); - log::trace!("{}/{}: sending update", self.id, pf.header().stream_id()); + log::trace!("{}/{}: sending update", self.id, pf.stream_id()); self.pending_frames.push_back(f.into()); } return Ok(Some(stream)); @@ -672,12 +671,12 @@ impl Active { } Action::Ping(f) => { let pf = f.parse().expect("valid frame"); - log::trace!("{}/{}: pong", self.id, pf.header().stream_id()); + log::trace!("{}/{}: pong", self.id, pf.stream_id()); self.pending_frames.push_back(f.into()); } Action::Reset(f) => { let pf = f.parse().expect("valid frame"); - log::trace!("{}/{}: sending reset", self.id, pf.header().stream_id()); + log::trace!("{}/{}: sending reset", self.id, pf.stream_id()); self.pending_frames.push_back(f.into()); } Action::Terminate(f) => { @@ -691,9 +690,9 @@ impl Active { fn on_data(&mut self, frame: Frame) -> Action { let parsed_frame = frame.parse().expect("valid frame"); - let stream_id = parsed_frame.header().stream_id(); + let stream_id = parsed_frame.stream_id(); - if parsed_frame.header().flags().contains(header::RST) { + if parsed_frame.has_flag(header::RST) { // stream reset if let Some(s) = self.streams.get_mut(&stream_id.val()) { let mut shared = s.lock(); @@ -708,9 +707,9 @@ impl Active { return Action::None; } - let is_finish = parsed_frame.header().flags().contains(header::FIN); // half-close + let is_finish = parsed_frame.has_flag(header::FIN); // half-close - if parsed_frame.header().flags().contains(header::SYN) { + if parsed_frame.has_flag(header::SYN) { // new stream if !self.is_valid_remote_id(stream_id, Tag::Data) { log::error!("{}: invalid stream id {}", self.id, stream_id); @@ -817,9 +816,9 @@ impl Active { fn on_window_update(&mut self, frame: &Frame) -> Action { let parsed_frame = frame.parse().expect("valid frame"); - let stream_id = parsed_frame.header().stream_id(); + let stream_id = parsed_frame.stream_id(); - if parsed_frame.header().flags().contains(header::RST) { + if parsed_frame.has_flag(header::RST) { // stream reset if let Some(s) = self.streams.get_mut(&stream_id.val()) { let mut shared = s.lock(); @@ -834,9 +833,9 @@ impl Active { return Action::None; } - let is_finish = parsed_frame.header().flags().contains(header::FIN); // half-close + let is_finish = parsed_frame.has_flag(header::FIN); // half-close - if parsed_frame.header().flags().contains(header::SYN) { + if parsed_frame.has_flag(header::SYN) { // new stream if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) { log::error!("{}: invalid stream id {}", self.id, stream_id); @@ -851,7 +850,7 @@ impl Active { return Action::Terminate(Frame::protocol_error()); } - let credit = parsed_frame.header().credit() + DEFAULT_CREDIT; + let credit = parsed_frame.credit() + DEFAULT_CREDIT; let mut stream = self.make_new_inbound_stream(stream_id, credit); stream.set_flag(stream::Flag::Ack); @@ -866,7 +865,7 @@ impl Active { if let Some(s) = self.streams.get_mut(&stream_id.val()) { let mut shared = s.lock(); - shared.credit += parsed_frame.header().credit(); + shared.credit += parsed_frame.credit(); if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } @@ -894,13 +893,13 @@ impl Active { fn on_ping(&mut self, frame: &Frame) -> Action { let parsed_frame = frame.parse().expect("valid frame"); - let stream_id = parsed_frame.header().stream_id(); - if parsed_frame.header().flags().contains(header::ACK) { + let stream_id = parsed_frame.stream_id(); + if parsed_frame.has_flag(header::ACK) { // pong return Action::None; } if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id.val()) { - let mut hdr = Header::ping(parsed_frame.header().nonce()); + let mut hdr = Header::ping(parsed_frame.nonce()); hdr.ack(); return Action::Ping(Frame::from_header(hdr)); } diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index 0d948c09..07fa41c9 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -8,12 +8,12 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. -use crate::frame::header::{ACK, HEADER_SIZE}; +use crate::frame::header::{Flags, ACK, HEADER_SIZE}; use crate::{ chunks::Chunks, connection::{self, StreamCommand}, frame::{ - header::{Data, Header, StreamId, WindowUpdate}, + header::{Data, StreamId, WindowUpdate}, Frame, }, Config, WindowUpdateMode, DEFAULT_CREDIT, @@ -183,18 +183,21 @@ impl Stream { } /// Set ACK or SYN flag if necessary. - fn add_flag(&mut self, header: &mut Header>) { + fn add_flag(&mut self, frame: &mut Frame>) -> Flags { + let mut parsed_frame = frame.parse_mut().expect("valid frame"); + let header = parsed_frame.header_mut(); match self.flag { - Flag::None => (), + Flag::None => {} Flag::Syn => { header.syn(); - self.flag = Flag::None + self.flag = Flag::None; } Flag::Ack => { header.ack(); - self.flag = Flag::None + self.flag = Flag::None; } } + header.flags() } /// Send new credit to the sending side via a window update message if @@ -218,8 +221,7 @@ impl Stream { drop(shared); let mut frame = Frame::window_update(self.id, credit).right(); - let mut parsed_frame = frame.parse_mut().expect("valid frame"); - self.add_flag(parsed_frame.header_mut()); + self.add_flag(&mut frame); let cmd = StreamCommand::SendFrame(frame); self.sender .start_send(cmd) @@ -373,15 +375,14 @@ impl AsyncWrite for Stream { }; let n = body.len(); let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); - let mut parsed_frame = frame.parse_mut().expect("valid frame"); - self.add_flag(parsed_frame.header_mut()); + let flags = self.add_flag(&mut frame); log::trace!("{}/{}: write {} bytes", self.conn, self.id, n); // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending // We are tracking this information: // a) to be consistent with outbound streams // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test. - if parsed_frame.header().flags().contains(ACK) { + if flags.contains(ACK) { self.shared() .update_state(self.conn, self.id, State::Open { acknowledged: true }); } diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 72d1f170..bd5165cd 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -19,7 +19,9 @@ use zerocopy::{AsBytes, ByteSlice, ByteSliceMut, Ref}; pub use io::FrameDecodeError; pub(crate) use io::Io; -use self::header::HEADER_SIZE; +use crate::HeaderDecodeError; + +use self::header::{HEADER_SIZE, Flags, Tag}; /// A Yamux message frame consisting of header and body in a single buffer #[derive(Clone, Debug, Eq, PartialEq)] @@ -28,15 +30,6 @@ pub struct Frame { _marker: std::marker::PhantomData, } -impl Default for Frame { - fn default() -> Self { - Self { - buffer: Vec::new(), - _marker: PhantomData, - } - } -} - impl Frame { pub(crate) fn new(buffer: Vec) -> Self { Self { @@ -45,6 +38,22 @@ impl Frame { } } + /// Introduce this frame to the right of a binary frame type. + pub(crate) fn right(self) -> Frame> { + Frame { + buffer: self.buffer, + _marker: PhantomData, + } + } + + /// Introduce this frame to the left of a binary frame type. + pub(crate) fn left(self) -> Frame> { + Frame { + buffer: self.buffer, + _marker: PhantomData, + } + } + pub(crate) fn from_header(header: Header) -> Self { let mut buffer = vec![0; HEADER_SIZE]; header.write_to(&mut buffer).expect("write_to success"); @@ -77,14 +86,6 @@ impl Frame { Self::make_parsed_frame(header, body) } - pub fn into_buffer(self) -> Vec { - self.buffer - } - - pub fn buffer(&self) -> &[u8] { - &self.buffer - } - pub fn buffer_mut(&mut self) -> &mut [u8] { &mut self.buffer } @@ -92,22 +93,6 @@ impl Frame { pub fn append_bytes(&mut self, bytes: &mut Vec) { self.buffer.append(bytes); } - - /// Introduce this frame to the right of a binary frame type. - pub(crate) fn right(self) -> Frame> { - Frame { - buffer: self.buffer, - _marker: PhantomData, - } - } - - /// Introduce this frame to the left of a binary frame type. - pub(crate) fn left(self) -> Frame> { - Frame { - buffer: self.buffer, - _marker: PhantomData, - } - } } impl From> for Frame<()> { @@ -164,6 +149,16 @@ impl Frame { } } +impl Frame { + pub fn buffer(&self) -> &[u8] { + &self.buffer + } + + pub fn into_buffer(self) -> Vec { + self.buffer + } +} + impl Frame { pub fn window_update(id: StreamId, credit: u32) -> Frame { Frame::from_header(Header::window_update(id, credit)) @@ -199,6 +194,24 @@ impl Debug for ParsedFrame { } } +impl ParsedFrame { + pub fn header_mut(&mut self) -> &mut Header { + &mut self.header + } +} + +impl ParsedFrame { + pub fn credit(&self) -> u32 { + self.header().credit() + } +} + +impl ParsedFrame { + pub fn nonce(&self) -> u32 { + self.header().nonce() + } +} + impl ParsedFrame { pub fn header(&self) -> &Header { &self.header @@ -218,10 +231,16 @@ impl ParsedFrame { pub fn bytes(&self) -> &[u8] { self.header.bytes() } -} -impl ParsedFrame { - pub fn header_mut(&mut self) -> &mut Header { - &mut self.header + pub fn has_flag(&self, flag: Flags) -> bool { + self.header().flags().contains(flag) + } + + pub fn stream_id(&self) -> StreamId { + self.header().stream_id() + } + + pub fn tag(&self) -> Result { + self.header().tag() } } diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 31cf491b..5f18c1d8 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -155,7 +155,7 @@ impl Stream for Io { }; log::trace!("{}: read: {:?}", this.id, parsed_frame); if parsed_frame.header().tag().expect("valid tag") != header::Tag::Data { - let frame = std::mem::take(frame); + let frame = std::mem::replace(frame, Frame::new(Vec::new())); this.read_state = ReadState::Init; return Poll::Ready(Some(Ok(frame))); } @@ -170,7 +170,7 @@ impl Stream for Io { frame.append_bytes(&mut vec![0; body_len]); this.read_state = ReadState::Body { - frame: std::mem::take(frame), + frame: std::mem::replace(frame, Frame::new(Vec::new())), offset: HEADER_SIZE, }; @@ -198,7 +198,7 @@ impl Stream for Io { let body_len = parsed_frame.header().len().val() as usize; if *offset == HEADER_SIZE + body_len { - let frame = std::mem::take(frame); + let frame = std::mem::replace(frame, Frame::new(Vec::new())); this.read_state = ReadState::Init; return Poll::Ready(Some(Ok(frame))); } From 5d4e2d20090f1f865eb59f443dccab4ee7518d92 Mon Sep 17 00:00:00 2001 From: Philippe Jalaber Date: Fri, 8 Sep 2023 09:38:22 +0200 Subject: [PATCH 03/29] cargo fmt --- yamux/src/frame.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index bd5165cd..5810583a 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -21,7 +21,7 @@ pub(crate) use io::Io; use crate::HeaderDecodeError; -use self::header::{HEADER_SIZE, Flags, Tag}; +use self::header::{Flags, Tag, HEADER_SIZE}; /// A Yamux message frame consisting of header and body in a single buffer #[derive(Clone, Debug, Eq, PartialEq)] From 62002ae47b28232c20299e9683ffdcc79584d271 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sun, 17 Sep 2023 16:20:13 +1000 Subject: [PATCH 04/29] WIP Minimal diff --- yamux/src/chunks.rs | 15 +- yamux/src/connection.rs | 129 ++++++-------- yamux/src/connection/stream.rs | 26 ++- yamux/src/frame.rs | 185 ++++++--------------- yamux/src/frame/header.rs | 295 +++++++++++++++++---------------- yamux/src/frame/io.rs | 248 ++++++++++++--------------- 6 files changed, 383 insertions(+), 515 deletions(-) diff --git a/yamux/src/chunks.rs b/yamux/src/chunks.rs index 7f656a77..0e66d894 100644 --- a/yamux/src/chunks.rs +++ b/yamux/src/chunks.rs @@ -36,15 +36,12 @@ impl Chunks { } /// Add another chunk of bytes to the end. - pub(crate) fn push(&mut self, x: Vec, offset: usize) { - let x_len = x.len(); - let cursor = io::Cursor::new(x); - let mut chunk = Chunk { cursor }; - chunk.advance(offset); - if !chunk.is_empty() { - assert_eq!(chunk.len(), x_len - offset); - self.len += chunk.len() + offset; - self.seq.push_back(chunk); + pub(crate) fn push(&mut self, x: Vec) { + self.len += x.len(); + if !x.is_empty() { + self.seq.push_back(Chunk { + cursor: io::Cursor::new(x), + }) } } diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index ff443b0d..de2e07b7 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -91,7 +91,6 @@ mod cleanup; mod closing; mod stream; -use crate::frame::header::HEADER_SIZE; use crate::tagged_stream::TaggedStream; use crate::{ error::ConnectionError, @@ -356,7 +355,7 @@ struct Active { socket: Fuse>, next_id: u32, - streams: IntMap>>, + streams: IntMap>>, stream_receivers: SelectAll>>, no_streams_waker: Option, @@ -520,14 +519,8 @@ impl Active { if extra_credit > 0 { let mut frame = Frame::window_update(id, extra_credit); - let mut parsed_frame = frame.parse_mut().expect("valid frame"); - parsed_frame.header_mut().syn(); - log::trace!( - "{}/{}: sending initial {}", - self.id, - id, - parsed_frame.header() - ); + frame.header_mut().syn(); + log::trace!("{}/{}: sending initial {}", self.id, id, frame.header()); self.pending_frames.push_back(frame.into()); } @@ -538,18 +531,17 @@ impl Active { } log::debug!("{}: new outbound {} of {}", self.id, stream, self); - self.streams.insert(id.val(), stream.clone_shared()); + self.streams.insert(id, stream.clone_shared()); Poll::Ready(Ok(stream)) } fn on_send_frame(&mut self, frame: Frame>) { - let parsed_frame = frame.parse().expect("valid frame"); log::trace!( "{}/{}: sending: {}", self.id, - parsed_frame.stream_id(), - parsed_frame.header() + frame.header().stream_id(), + frame.header() ); self.pending_frames.push_back(frame.into()); } @@ -561,10 +553,7 @@ impl Active { } fn on_drop_stream(&mut self, stream_id: StreamId) { - let s = self - .streams - .remove(&stream_id.val()) - .expect("stream not found"); + let s = self.streams.remove(&stream_id).expect("stream not found"); log::trace!("{}: removing dropped stream {}", self.id, stream_id); let frame = { @@ -575,7 +564,7 @@ impl Active { State::Open { .. } => { let mut header = Header::data(stream_id, 0); header.rst(); - Some(Frame::from_header(header)) + Some(Frame::new(header)) } // The stream was dropped without calling `poll_close`. // We have already received a FIN from remote and send one @@ -583,7 +572,7 @@ impl Active { State::RecvClosed => { let mut header = Header::data(stream_id, 0); header.fin(); - Some(Frame::from_header(header)) + Some(Frame::new(header)) } // The stream was properly closed. We already sent our FIN frame. // The remote may be out of credit though and blocked on @@ -596,7 +585,7 @@ impl Active { // which we will never send, so reset the stream now. let mut header = Header::data(stream_id, 0); header.rst(); - Some(Frame::from_header(header)) + Some(Frame::new(header)) } else { // The remote has either still credit or will be given more // (due to an enqueued window update or because the update @@ -620,7 +609,7 @@ impl Active { frame }; if let Some(f) = frame { - log::trace!("{}/{}: sending: {:?}", self.id, stream_id, f); + log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header()); self.pending_frames.push_back(f.into()); } } @@ -632,12 +621,11 @@ impl Active { /// Otherwise we process the frame and potentially return a new `Stream` /// if one was opened by the remote. fn on_frame(&mut self, frame: Frame<()>) -> Result> { - let parsed_frame = frame.parse().expect("valid frame"); - log::trace!("{}: received: {:?}", self.id, parsed_frame); + log::trace!("{}: received: {}", self.id, frame.header()); - if parsed_frame.has_flag(header::ACK) { - let id = parsed_frame.stream_id(); - if let Some(stream) = self.streams.get(&id.val()) { + if frame.header().flags().contains(header::ACK) { + let id = frame.header().stream_id(); + if let Some(stream) = self.streams.get(&id) { stream .lock() .update_state(self.id, id, State::Open { acknowledged: true }); @@ -647,7 +635,7 @@ impl Active { } } - let action = match parsed_frame.tag().expect("valid header's tag") { + let action = match frame.header().tag() { Tag::Data => self.on_data(frame.into_data()), Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()), Tag::Ping => self.on_ping(&frame.into_ping()), @@ -658,25 +646,21 @@ impl Active { Action::New(stream, update) => { log::trace!("{}: new inbound {} of {}", self.id, stream, self); if let Some(f) = update { - let pf = f.parse().expect("valid frame"); - log::trace!("{}/{}: sending update", self.id, pf.stream_id()); + log::trace!("{}/{}: sending update", self.id, f.header().stream_id()); self.pending_frames.push_back(f.into()); } return Ok(Some(stream)); } Action::Update(f) => { - let pf = f.parse().expect("valid frame"); - log::trace!("{}: sending update: {:?}", self.id, pf.header()); + log::trace!("{}: sending update: {:?}", self.id, f.header()); self.pending_frames.push_back(f.into()); } Action::Ping(f) => { - let pf = f.parse().expect("valid frame"); - log::trace!("{}/{}: pong", self.id, pf.stream_id()); + log::trace!("{}/{}: pong", self.id, f.header().stream_id()); self.pending_frames.push_back(f.into()); } Action::Reset(f) => { - let pf = f.parse().expect("valid frame"); - log::trace!("{}/{}: sending reset", self.id, pf.stream_id()); + log::trace!("{}/{}: sending reset", self.id, f.header().stream_id()); self.pending_frames.push_back(f.into()); } Action::Terminate(f) => { @@ -689,12 +673,11 @@ impl Active { } fn on_data(&mut self, frame: Frame) -> Action { - let parsed_frame = frame.parse().expect("valid frame"); - let stream_id = parsed_frame.stream_id(); + let stream_id = frame.header().stream_id(); - if parsed_frame.has_flag(header::RST) { + if frame.header().flags().contains(header::RST) { // stream reset - if let Some(s) = self.streams.get_mut(&stream_id.val()) { + if let Some(s) = self.streams.get_mut(&stream_id) { let mut shared = s.lock(); shared.update_state(self.id, stream_id, State::Closed); if let Some(w) = shared.reader.take() { @@ -707,15 +690,15 @@ impl Active { return Action::None; } - let is_finish = parsed_frame.has_flag(header::FIN); // half-close + let is_finish = frame.header().flags().contains(header::FIN); // half-close - if parsed_frame.has_flag(header::SYN) { + if frame.header().flags().contains(header::SYN) { // new stream if !self.is_valid_remote_id(stream_id, Tag::Data) { log::error!("{}: invalid stream id {}", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } - if parsed_frame.body().len() > DEFAULT_CREDIT as usize { + if frame.body().len() > DEFAULT_CREDIT as usize { log::error!( "{}/{}: 1st body of stream exceeds default credit", self.id, @@ -723,7 +706,7 @@ impl Active { ); return Action::Terminate(Frame::protocol_error()); } - if self.streams.contains_key(&stream_id.val()) { + if self.streams.contains_key(&stream_id) { log::error!("{}/{}: stream already exists", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } @@ -738,16 +721,14 @@ impl Active { if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } - shared.window = shared.window.saturating_sub(parsed_frame.body_len()); - shared.buffer.push(frame.into_buffer(), HEADER_SIZE); + shared.window = shared.window.saturating_sub(frame.body_len()); + shared.buffer.push(frame.into_body()); if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { if let Some(credit) = shared.next_window_update() { shared.window += credit; - let mut frame = Frame::window_update(stream_id, credit); - let mut parsed_frame = frame.parse_mut().expect("valid frame"); - parsed_frame.header_mut().ack(); + frame.header_mut().ack(); window_update = Some(frame) } } @@ -755,13 +736,13 @@ impl Active { if window_update.is_none() { stream.set_flag(stream::Flag::Ack) } - self.streams.insert(stream_id.val(), stream.clone_shared()); + self.streams.insert(stream_id, stream.clone_shared()); return Action::New(stream, window_update); } - if let Some(s) = self.streams.get_mut(&stream_id.val()) { + if let Some(s) = self.streams.get_mut(&stream_id) { let mut shared = s.lock(); - if parsed_frame.body().len() > shared.window as usize { + if frame.body().len() > shared.window as usize { log::error!( "{}/{}: frame body larger than window of stream", self.id, @@ -781,10 +762,10 @@ impl Active { ); let mut header = Header::data(stream_id, 0); header.rst(); - return Action::Reset(Frame::from_header(header)); + return Action::Reset(Frame::new(header)); } - shared.window = shared.window.saturating_sub(parsed_frame.body_len()); - shared.buffer.push(frame.into_buffer(), HEADER_SIZE); + shared.window = shared.window.saturating_sub(frame.body_len()); + shared.buffer.push(frame.into_body()); if let Some(w) = shared.reader.take() { w.wake() } @@ -815,12 +796,11 @@ impl Active { } fn on_window_update(&mut self, frame: &Frame) -> Action { - let parsed_frame = frame.parse().expect("valid frame"); - let stream_id = parsed_frame.stream_id(); + let stream_id = frame.header().stream_id(); - if parsed_frame.has_flag(header::RST) { + if frame.header().flags().contains(header::RST) { // stream reset - if let Some(s) = self.streams.get_mut(&stream_id.val()) { + if let Some(s) = self.streams.get_mut(&stream_id) { let mut shared = s.lock(); shared.update_state(self.id, stream_id, State::Closed); if let Some(w) = shared.reader.take() { @@ -833,15 +813,15 @@ impl Active { return Action::None; } - let is_finish = parsed_frame.has_flag(header::FIN); // half-close + let is_finish = frame.header().flags().contains(header::FIN); // half-close - if parsed_frame.has_flag(header::SYN) { + if frame.header().flags().contains(header::SYN) { // new stream if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) { log::error!("{}: invalid stream id {}", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } - if self.streams.contains_key(&stream_id.val()) { + if self.streams.contains_key(&stream_id) { log::error!("{}/{}: stream already exists", self.id, stream_id); return Action::Terminate(Frame::protocol_error()); } @@ -850,7 +830,7 @@ impl Active { return Action::Terminate(Frame::protocol_error()); } - let credit = parsed_frame.credit() + DEFAULT_CREDIT; + let credit = frame.header().credit() + DEFAULT_CREDIT; let mut stream = self.make_new_inbound_stream(stream_id, credit); stream.set_flag(stream::Flag::Ack); @@ -859,13 +839,13 @@ impl Active { .shared() .update_state(self.id, stream_id, State::RecvClosed); } - self.streams.insert(stream_id.val(), stream.clone_shared()); + self.streams.insert(stream_id, stream.clone_shared()); return Action::New(stream, None); } - if let Some(s) = self.streams.get_mut(&stream_id.val()) { + if let Some(s) = self.streams.get_mut(&stream_id) { let mut shared = s.lock(); - shared.credit += parsed_frame.credit(); + shared.credit += frame.header().credit(); if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } @@ -892,16 +872,15 @@ impl Active { } fn on_ping(&mut self, frame: &Frame) -> Action { - let parsed_frame = frame.parse().expect("valid frame"); - let stream_id = parsed_frame.stream_id(); - if parsed_frame.has_flag(header::ACK) { + let stream_id = frame.header().stream_id(); + if frame.header().flags().contains(header::ACK) { // pong return Action::None; } - if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id.val()) { - let mut hdr = Header::ping(parsed_frame.nonce()); + if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) { + let mut hdr = Header::ping(frame.header().nonce()); hdr.ack(); - return Action::Ping(Frame::from_header(hdr)); + return Action::Ping(Frame::new(hdr)); } log::trace!( "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}", @@ -968,8 +947,8 @@ impl Active { // - Its ID is odd and we are the client. // - Its ID is even and we are the server. .filter(|(id, _)| match self.mode { - Mode::Client => StreamId::new(**id).is_client(), - Mode::Server => StreamId::new(**id).is_server(), + Mode::Client => id.is_client(), + Mode::Server => id.is_server(), }) .filter(|(_, s)| s.lock().is_pending_ack()) .count() @@ -992,7 +971,7 @@ impl Active { fn drop_all_streams(&mut self) { for (id, s) in self.streams.drain() { let mut shared = s.lock(); - shared.update_state(self.id, StreamId::new(id), State::Closed); + shared.update_state(self.id, id, State::Closed); if let Some(w) = shared.reader.take() { w.wake() } diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index 07fa41c9..9f17ca88 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -8,12 +8,12 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. -use crate::frame::header::{Flags, ACK, HEADER_SIZE}; +use crate::frame::header::ACK; use crate::{ chunks::Chunks, connection::{self, StreamCommand}, frame::{ - header::{Data, StreamId, WindowUpdate}, + header::{Data, Header, StreamId, WindowUpdate}, Frame, }, Config, WindowUpdateMode, DEFAULT_CREDIT, @@ -183,21 +183,18 @@ impl Stream { } /// Set ACK or SYN flag if necessary. - fn add_flag(&mut self, frame: &mut Frame>) -> Flags { - let mut parsed_frame = frame.parse_mut().expect("valid frame"); - let header = parsed_frame.header_mut(); + fn add_flag(&mut self, header: &mut Header>) { match self.flag { - Flag::None => {} + Flag::None => (), Flag::Syn => { header.syn(); - self.flag = Flag::None; + self.flag = Flag::None } Flag::Ack => { header.ack(); - self.flag = Flag::None; + self.flag = Flag::None } } - header.flags() } /// Send new credit to the sending side via a window update message if @@ -221,7 +218,7 @@ impl Stream { drop(shared); let mut frame = Frame::window_update(self.id, credit).right(); - self.add_flag(&mut frame); + self.add_flag(frame.header_mut()); let cmd = StreamCommand::SendFrame(frame); self.sender .start_send(cmd) @@ -260,8 +257,7 @@ impl futures::stream::Stream for Stream { let mut shared = self.shared(); if let Some(bytes) = shared.buffer.pop() { - // Every chunk starts with a frame header, so we add HEADER_SIZE to offset. - let off = bytes.offset() + HEADER_SIZE; + let off = bytes.offset(); let mut vec = bytes.into_vec(); if off != 0 { // This should generally not happen when the stream is used only as @@ -273,7 +269,7 @@ impl futures::stream::Stream for Stream { self.conn, self.id ); - vec = vec.split_off(off); + vec = vec.split_off(off) } return Poll::Ready(Some(Ok(Packet(vec)))); } @@ -375,14 +371,14 @@ impl AsyncWrite for Stream { }; let n = body.len(); let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); - let flags = self.add_flag(&mut frame); + self.add_flag(frame.header_mut()); log::trace!("{}/{}: write {} bytes", self.conn, self.id, n); // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending // We are tracking this information: // a) to be consistent with outbound streams // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test. - if flags.contains(ACK) { + if frame.header().flags().contains(ACK) { self.shared() .update_state(self.conn, self.id, State::Open { acknowledged: true }); } diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 5810583a..3a298744 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -27,17 +27,59 @@ use self::header::{Flags, Tag, HEADER_SIZE}; #[derive(Clone, Debug, Eq, PartialEq)] pub struct Frame { buffer: Vec, - _marker: std::marker::PhantomData, + _marker: PhantomData, } impl Frame { - pub(crate) fn new(buffer: Vec) -> Self { + pub(crate) fn new(header: Header) -> Self { + let total_buffer_size = HEADER_SIZE + header.len().val() as usize; + + let mut buffer = Vec::with_capacity(total_buffer_size); + header + .write_to_prefix(&mut buffer) + .expect("buffer always fits the header"); + Self { buffer, - _marker: std::marker::PhantomData, + _marker: PhantomData, } } + pub(crate) fn header(&self) -> &Header { + Ref::<_, Header>::new_from_prefix(self.buffer.as_slice()) + .expect("buffer always holds a valid header") + .0 + .into_ref() + } + + pub(crate) fn header_mut(&mut self) -> &mut Header { + Ref::<_, Header>::new_from_prefix(self.buffer.as_mut_slice()) + .expect("buffer always holds a valid header") + .0 + .into_mut() + } + + pub(crate) fn buffer(&self) -> &[u8] { + self.buffer.as_slice() + } + + pub(crate) fn body(&self) -> &[u8] { + &self.buffer[HEADER_SIZE..] + } + + pub(crate) fn body_len(&self) -> u32 { + self.body().len() as u32 + } + + pub(crate) fn into_body(mut self) -> Vec { + // FIXME: Should we implement this more efficiently with `BytesMut`? I think that one would allow us to split of the body without allocating again .. + self.buffer.split_off(HEADER_SIZE) + } + + pub(crate) fn body_mut(&mut self) -> &mut [u8] { + &mut self.buffer[HEADER_SIZE..] + } + /// Introduce this frame to the right of a binary frame type. pub(crate) fn right(self) -> Frame> { Frame { @@ -53,46 +95,6 @@ impl Frame { _marker: PhantomData, } } - - pub(crate) fn from_header(header: Header) -> Self { - let mut buffer = vec![0; HEADER_SIZE]; - header.write_to(&mut buffer).expect("write_to success"); - Self::new(buffer) - } - - fn make_parsed_frame( - header: Ref>, - body: B, - ) -> Result, io::FrameDecodeError> { - let frame = ParsedFrame { header, body }; - let version = frame.header.version().val(); - if version != 0 { - Err(FrameDecodeError::Header(crate::HeaderDecodeError::Version( - version, - ))) - } else { - frame.header.tag().map(|_| frame).map_err(|e| e.into()) - } - } - - pub(crate) fn parse(&self) -> Result, io::FrameDecodeError> { - let (header, body) = Ref::new_from_prefix(&self.buffer[..]).expect("construct a valid Ref"); - Self::make_parsed_frame(header, body) - } - - pub(crate) fn parse_mut(&mut self) -> Result, io::FrameDecodeError> { - let (header, body) = - Ref::new_from_prefix(&mut self.buffer[..]).expect("construct a valid Ref"); - Self::make_parsed_frame(header, body) - } - - pub fn buffer_mut(&mut self) -> &mut [u8] { - &mut self.buffer - } - - pub fn append_bytes(&mut self, bytes: &mut Vec) { - self.buffer.append(bytes); - } } impl From> for Frame<()> { @@ -130,12 +132,11 @@ impl Frame<()> { impl Frame { pub fn data(id: StreamId, body: &[u8]) -> Result { let header = Header::data(id, body.len().try_into()?); - let mut buffer = vec![0; HEADER_SIZE + body.len()]; - header - .write_to(&mut buffer[..HEADER_SIZE]) - .expect("write_to success"); - buffer[HEADER_SIZE..].copy_from_slice(body); - Ok(Frame::new(buffer)) + + let mut frame = Frame::new(header); + frame.body_mut().copy_from_slice(body); + + Ok(frame) } pub fn close_stream(id: StreamId, ack: bool) -> Self { @@ -145,102 +146,26 @@ impl Frame { header.ack() } - Frame::from_header(header) - } -} - -impl Frame { - pub fn buffer(&self) -> &[u8] { - &self.buffer - } - - pub fn into_buffer(self) -> Vec { - self.buffer + Frame::new(header) } } impl Frame { pub fn window_update(id: StreamId, credit: u32) -> Frame { - Frame::from_header(Header::window_update(id, credit)) + Frame::new(Header::window_update(id, credit)) } } impl Frame { pub fn term() -> Frame { - Frame::::from_header(Header::term()) + Frame::::new(Header::term()) } pub fn protocol_error() -> Frame { - Frame::::from_header(Header::protocol_error()) + Frame::::new(Header::protocol_error()) } pub fn internal_error() -> Frame { - Frame::::from_header(Header::internal_error()) - } -} - -/// A zero-copied-parsed view of a Frame -pub struct ParsedFrame { - header: Ref>, - body: B, -} - -impl Debug for ParsedFrame { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Frame") - .field("header", &self.header) - .field("body", &"..") - .finish() - } -} - -impl ParsedFrame { - pub fn header_mut(&mut self) -> &mut Header { - &mut self.header - } -} - -impl ParsedFrame { - pub fn credit(&self) -> u32 { - self.header().credit() - } -} - -impl ParsedFrame { - pub fn nonce(&self) -> u32 { - self.header().nonce() - } -} - -impl ParsedFrame { - pub fn header(&self) -> &Header { - &self.header - } - - pub fn body(&self) -> &B { - &self.body - } - - pub fn body_len(&self) -> u32 { - // Safe cast since we construct `Frame::`s only with - // `Vec` of length [0, u32::MAX] in `Frame::data` above. - self.body.len() as u32 - } - - #[cfg(test)] - pub fn bytes(&self) -> &[u8] { - self.header.bytes() - } - - pub fn has_flag(&self, flag: Flags) -> bool { - self.header().flags().contains(flag) - } - - pub fn stream_id(&self) -> StreamId { - self.header().stream_id() - } - - pub fn tag(&self) -> Result { - self.header().tag() + Frame::::new(Header::internal_error()) } } diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index 1d7166a4..1d280b1b 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -9,15 +9,13 @@ // at https://opensource.org/licenses/MIT. use futures::future::Either; -use std::convert::TryFrom; +use std::cmp::Ordering; use std::fmt; -use std::fmt::Debug; -use std::ops::BitOrAssign; -use zerocopy::byteorder::network_endian::{U16, U32}; +use zerocopy::big_endian::{U16, U32}; use zerocopy::{AsBytes, FromBytes, FromZeroes}; /// The message frame header. -#[derive(Clone, FromZeroes, FromBytes, AsBytes)] +#[derive(Clone, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] #[repr(packed)] pub struct Header { version: Version, @@ -28,7 +26,7 @@ pub struct Header { _marker: std::marker::PhantomData, } -impl Debug for Header { +impl fmt::Debug for Header { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Header") .field("version", &self.version) @@ -41,19 +39,13 @@ impl Debug for Header { } } -impl PartialEq for Header { - fn eq(&self, other: &Self) -> bool { - self.version == other.version - && self.tag == other.tag - && self.flags == other.flags - && self.stream_id == other.stream_id - && self.length == other.length - && self._marker == other._marker +impl Header { + #[must_use] + pub(crate) fn has_valid_tag(&self) -> bool { + (0..4).contains(&self.tag) } } -impl Eq for Header {} - impl fmt::Display for Header { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -68,12 +60,14 @@ impl fmt::Display for Header { } impl Header { - pub fn version(&self) -> Version { - self.version - } - - pub fn tag(&self) -> Result { - Tag::try_from(self.tag) + pub fn tag(&self) -> Tag { + match self.tag { + 0 => Tag::Data, + 1 => Tag::WindowUpdate, + 2 => Tag::Ping, + 3 => Tag::GoAway, + _ => unreachable!("header always has valid tag"), // TODO: Fix this once `zerocopy` has `TryFromBytes` + } } pub fn flags(&self) -> Flags { @@ -90,11 +84,11 @@ impl Header { #[cfg(test)] pub fn set_len(&mut self, len: u32) { - self.length = Len(len.into()) + self.length = Len(len) } /// Arbitrary type cast, use with caution. - pub fn cast(self) -> Header { + fn cast(self) -> Header { Header { version: self.version, tag: self.tag, @@ -104,6 +98,16 @@ impl Header { _marker: std::marker::PhantomData, } } + + /// Introduce this header to the right of a binary header type. + pub(crate) fn right(self) -> Header> { + self.cast() + } + + /// Introduce this header to the left of a binary header type. + pub(crate) fn left(self) -> Header> { + self.cast() + } } impl From> for Header<()> { @@ -112,31 +116,48 @@ impl From> for Header<()> { } } +impl Header<()> { + pub(crate) fn into_data(self) -> Header { + // FIXME debug_assert_eq!(self.tag, Tag::Data); + self.cast() + } + + pub(crate) fn into_window_update(self) -> Header { + // FIXME debug_assert_eq!(self.tag, Tag::WindowUpdate); + self.cast() + } + + pub(crate) fn into_ping(self) -> Header { + // FIXME debug_assert_eq!(self.tag, Tag::Ping); + self.cast() + } +} + impl Header { /// Set the [`SYN`] flag. pub fn syn(&mut self) { - self.flags |= SYN; + self.flags.0.set(self.flags.val() | SYN.0.get()) } } impl Header { /// Set the [`ACK`] flag. pub fn ack(&mut self) { - self.flags |= ACK; + // self.flags.0 |= ACK.0 } } impl Header { /// Set the [`FIN`] flag. pub fn fin(&mut self) { - self.flags |= FIN; + // self.flags.0 |= FIN.0 } } impl Header { /// Set the [`RST`] flag. pub fn rst(&mut self) { - self.flags |= RST; + // self.flags.0 |= RST.0 } } @@ -146,9 +167,9 @@ impl Header { Header { version: Version(0), tag: Tag::Data as u8, - flags: Flags(U16::ZERO), + flags: Flags(U16::new(0)), stream_id: id, - length: Len(len.into()), + length: Len(len), _marker: std::marker::PhantomData, } } @@ -160,16 +181,16 @@ impl Header { Header { version: Version(0), tag: Tag::WindowUpdate as u8, - flags: Flags(U16::ZERO), + flags: Flags(U16::new(0)), stream_id: id, - length: Len(credit.into()), + length: Len(credit), _marker: std::marker::PhantomData, } } /// The credit this window update grants to the remote. pub fn credit(&self) -> u32 { - self.length.val() + self.length.0 } } @@ -179,16 +200,16 @@ impl Header { Header { version: Version(0), tag: Tag::Ping as u8, - flags: Flags(U16::ZERO), - stream_id: StreamId(U32::ZERO), - length: Len(nonce.into()), + flags: Flags(U16::new(0)), + stream_id: CONNECTION_ID, + length: Len(nonce), _marker: std::marker::PhantomData, } } /// The nonce of this ping. pub fn nonce(&self) -> u32 { - self.length.val() + self.length.0 } } @@ -212,9 +233,9 @@ impl Header { Header { version: Version(0), tag: Tag::GoAway as u8, - flags: Flags(U16::ZERO), - stream_id: StreamId(U32::ZERO), - length: Len(code.into()), + flags: Flags(U16::new(0)), + stream_id: CONNECTION_ID, + length: Len(code), _marker: std::marker::PhantomData, } } @@ -272,6 +293,7 @@ pub(super) mod private { /// A tag is the runtime representation of a message type. #[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(u8)] pub enum Tag { Data = 0, WindowUpdate = 1, @@ -279,39 +301,19 @@ pub enum Tag { GoAway = 3, } -impl TryFrom for Tag { - type Error = HeaderDecodeError; - - fn try_from(value: u8) -> Result { - match value { - 0 => Ok(Self::Data), - 1 => Ok(Self::WindowUpdate), - 2 => Ok(Self::Ping), - 3 => Ok(Self::GoAway), - _ => Err(HeaderDecodeError::Type(value)), - } - } -} - /// The protocol version a message corresponds to. -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] -#[repr(C)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] +#[repr(packed)] pub struct Version(u8); -impl Version { - pub fn val(self) -> u8 { - self.0 - } -} - /// The message length. -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] -#[repr(C)] -pub struct Len(U32); +#[derive(Copy, Clone, Debug, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] +#[repr(packed)] +pub struct Len(u32); impl Len { pub fn val(self) -> u32 { - self.0.get() + self.0 } } @@ -320,25 +322,32 @@ pub const CONNECTION_ID: StreamId = StreamId(U32::ZERO); /// The ID of a stream. /// /// The value 0 denotes no particular stream but the whole session. -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] -#[repr(C)] +#[derive(Copy, Clone, Debug, Hash, Eq, FromBytes, AsBytes, FromZeroes)] +#[repr(packed)] pub struct StreamId(U32); +// TODO: Research why these can't be derived. Is this wrong? +impl PartialEq for StreamId { + fn eq(&self, other: &Self) -> bool { + self.0.get() == other.0.get() + } +} + impl PartialOrd for StreamId { - fn partial_cmp(&self, other: &Self) -> Option { + fn partial_cmp(&self, other: &Self) -> Option { self.0.get().partial_cmp(&other.0.get()) } } impl Ord for StreamId { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { + fn cmp(&self, other: &Self) -> Ordering { self.0.get().cmp(&other.0.get()) } } impl StreamId { pub(crate) fn new(val: u32) -> Self { - StreamId(val.into()) + StreamId(U32::new(val)) } pub fn is_server(self) -> bool { @@ -364,30 +373,21 @@ impl fmt::Display for StreamId { } } +impl nohash_hasher::IsEnabled for StreamId {} + /// Possible flags set on a message. -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, FromZeroes, FromBytes, AsBytes)] -#[repr(C)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] +#[repr(packed)] pub struct Flags(U16); impl Flags { pub fn contains(self, other: Flags) -> bool { - let other = other.0.get(); - self.0.get() & other == other + self.0.get() & other.0.get() == other.0.get() } pub fn val(self) -> u16 { self.0.get() } - - pub fn set(&mut self, val: u16) { - self.0.set(val) - } -} - -impl BitOrAssign for Flags { - fn bitor_assign(&mut self, rhs: Self) { - self.set(self.val() | rhs.val()); - } } /// Indicates the start of a new stream. @@ -405,6 +405,30 @@ pub const RST: Flags = Flags(U16::from_bytes([0, 8])); /// The serialised header size in bytes. pub const HEADER_SIZE: usize = 12; +/// Encode a [`Header`] value. +pub fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { + let mut buf = [0; HEADER_SIZE]; + + hdr.write_to(&mut buf).expect("buffer to be correct length"); + buf +} + +/// Decode a [`Header`] value. +pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { + if buf[0] != 0 { + return Err(HeaderDecodeError::Version(buf[0])); + } + + let tag = buf[1]; + if !(0..4).contains(&tag) { + return Err(HeaderDecodeError::Type(tag)); + } + + let hdr = Header::read_from(buf).expect("buffer to be correct size"); // FIXME do we know this here? + + Ok(hdr) +} + /// Possible errors while decoding a message frame header. #[non_exhaustive] #[derive(Debug)] @@ -426,65 +450,42 @@ impl std::fmt::Display for HeaderDecodeError { impl std::error::Error for HeaderDecodeError {} -#[cfg(test)] -mod tests { - use crate::frame::Frame; - - use super::*; - use quickcheck::{Arbitrary, Gen, QuickCheck}; - - impl Arbitrary for Flags { - fn arbitrary(g: &mut Gen) -> Self { - let flags: u16 = Arbitrary::arbitrary(g); - Flags(flags.into()) - } - } - - impl Arbitrary for Len { - fn arbitrary(g: &mut Gen) -> Self { - let len: u32 = Arbitrary::arbitrary(g); - Len(len.into()) - } - } - - impl Arbitrary for StreamId { - fn arbitrary(g: &mut Gen) -> Self { - let stream_id: u32 = Arbitrary::arbitrary(g); - StreamId(stream_id.into()) - } - } - - impl Arbitrary for Header<()> { - fn arbitrary(g: &mut Gen) -> Self { - let tag = *g - .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]) - .unwrap(); - - Header { - version: Version(0), - tag: tag as u8, - flags: Arbitrary::arbitrary(g), - stream_id: Arbitrary::arbitrary(g), - length: Arbitrary::arbitrary(g), - _marker: std::marker::PhantomData, - } - } - } - - #[test] - fn encode_decode_identity() { - fn property(hdr: Header<()>) -> bool { - let frame = Frame::from_header(hdr); - match frame.parse() { - Ok(pf) => pf.bytes() == frame.buffer(), - Err(e) => { - eprintln!("decode error: {}", e); - false - } - } - } - QuickCheck::new() - .tests(10_000) - .quickcheck(property as fn(Header<()>) -> bool) - } -} +// FIXME +// #[cfg(test)] +// mod tests { +// use super::*; +// use quickcheck::{Arbitrary, Gen, QuickCheck}; +// +// impl Arbitrary for Header<()> { +// fn arbitrary(g: &mut Gen) -> Self { +// let tag = *g +// .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]) +// .unwrap(); +// +// Header { +// version: Version(0), +// tag: tag as u8, +// flags: Flags(Arbitrary::arbitrary(g)), +// stream_id: StreamId(Arbitrary::arbitrary(g)), +// length: Len(Arbitrary::arbitrary(g)), +// _marker: std::marker::PhantomData, +// } +// } +// } +// +// #[test] +// fn encode_decode_identity() { +// fn property(hdr: Header<()>) -> bool { +// match decode(&encode(&hdr)) { +// Ok(x) => x == hdr, +// Err(e) => { +// eprintln!("decode error: {}", e); +// false +// } +// } +// } +// QuickCheck::new() +// .tests(10_000) +// .quickcheck(property as fn(Header<()>) -> bool) +// } +// } diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 5f18c1d8..29401165 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -12,10 +12,12 @@ use super::{ header::{self, HeaderDecodeError}, Frame, }; -use crate::{connection::Id, frame::header::HEADER_SIZE}; +use crate::connection::Id; +use crate::frame::header::Data; +use futures::future::Either; use futures::{prelude::*, ready}; use std::{ - fmt, io, + fmt, io, mem, pin::Pin, task::{Context, Poll}, }; @@ -45,26 +47,14 @@ impl Io { /// The stages of writing a new `Frame`. enum WriteState { Init, - Data { frame: Frame<()>, offset: usize }, + Writing { frame: Frame<()>, offset: usize }, } impl fmt::Debug for WriteState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { WriteState::Init => f.write_str("(WriteState::Init)"), - WriteState::Data { frame, offset } => match frame.parse() { - Ok(parsed_frame) => write!( - f, - "(WriteState::Body (header {:?}) (offset {}))", - parsed_frame.header(), - offset, - ), - Err(e) => write!( - f, - "(WriteState::Body (invalid header ({})) (offset {}))", - e, offset - ), - }, + _ => todo!(), } } } @@ -78,7 +68,7 @@ impl Sink> for Io { log::trace!("{}: write: {:?}", this.id, this.write_state); match &mut this.write_state { WriteState::Init => return Poll::Ready(Ok(())), - WriteState::Data { + WriteState::Writing { frame, ref mut offset, } => match Pin::new(&mut this.io).poll_write(cx, &frame.buffer()[*offset..]) { @@ -99,7 +89,7 @@ impl Sink> for Io { } fn start_send(self: Pin<&mut Self>, frame: Frame<()>) -> Result<(), Self::Error> { - self.get_mut().write_state = WriteState::Data { frame, offset: 0 }; + self.get_mut().write_state = WriteState::Writing { frame, offset: 0 }; Ok(()) } @@ -121,96 +111,97 @@ enum ReadState { /// Initial reading state. Init, /// Reading the frame header. - Header { frame: Frame<()>, offset: usize }, + Header { + offset: usize, + buffer: [u8; header::HEADER_SIZE], + }, /// Reading the frame body. Body { frame: Frame<()>, offset: usize }, } -const READ_BUFFER_DEFAULT_CAPACITY: usize = 2048; - impl Stream for Io { type Item = Result, FrameDecodeError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = &mut *self; + let mut this = &mut *self; loop { log::trace!("{}: read: {:?}", this.id, this.read_state); - match this.read_state { + let state = mem::replace(&mut this.read_state, ReadState::Init); + + match state { ReadState::Init => { - let mut buffer = Vec::with_capacity(READ_BUFFER_DEFAULT_CAPACITY); - buffer.append(&mut vec![0_u8; HEADER_SIZE]); this.read_state = ReadState::Header { offset: 0, - frame: Frame::<()>::new(buffer), + buffer: [0; header::HEADER_SIZE], }; } - ReadState::Header { - ref mut offset, - ref mut frame, - } => { - if *offset == HEADER_SIZE { - let parsed_frame = match frame.parse_mut() { - Ok(frame) => frame, - Err(e) => return Poll::Ready(Some(Err(e))), + ReadState::Header { offset, mut buffer } => { + if offset == header::HEADER_SIZE { + let header = match header::decode(&buffer) { + Ok(hd) => hd, + Err(e) => return Poll::Ready(Some(Err(e.into()))), }; - log::trace!("{}: read: {:?}", this.id, parsed_frame); - if parsed_frame.header().tag().expect("valid tag") != header::Tag::Data { - let frame = std::mem::replace(frame, Frame::new(Vec::new())); + + log::trace!("{}: read: {}", this.id, header); + + if header.tag() != header::Tag::Data { this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(frame))); + return Poll::Ready(Some(Ok(Frame::new(header)))); } - let body_len = parsed_frame.header().len().val() as usize; + let body_len = header.len().val() as usize; if body_len > this.max_body_len { return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( body_len, )))); } - frame.append_bytes(&mut vec![0; body_len]); this.read_state = ReadState::Body { - frame: std::mem::replace(frame, Frame::new(Vec::new())), - offset: HEADER_SIZE, + frame: Frame::new(header), + offset: 0, }; continue; } - match ready!(Pin::new(&mut this.io) - .poll_read(cx, &mut frame.buffer_mut()[*offset..HEADER_SIZE]))? - { + let buf = &mut buffer[offset..header::HEADER_SIZE]; + match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { 0 => { - if *offset == 0 { + if offset == 0 { return Poll::Ready(None); } let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - n => *offset += n, + n => { + this.read_state = ReadState::Header { + buffer, + offset: offset + n, + } + } } } - ReadState::Body { - ref mut frame, - ref mut offset, - } => { - let parsed_frame = frame.parse().expect("valid frame"); - let body_len = parsed_frame.header().len().val() as usize; + ReadState::Body { offset, mut frame } => { + let body_len = frame.header().len().val() as usize; - if *offset == HEADER_SIZE + body_len { - let frame = std::mem::replace(frame, Frame::new(Vec::new())); - this.read_state = ReadState::Init; + if offset == body_len { return Poll::Ready(Some(Ok(frame))); } - match ready!(Pin::new(&mut this.io) - .poll_read(cx, &mut frame.buffer_mut()[*offset..HEADER_SIZE + body_len]))? - { + match ready!( + Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[offset..]) + )? { 0 => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - n => *offset += n, + n => { + this.read_state = ReadState::Body { + frame, + offset: offset + n, + } + } } } } @@ -222,32 +213,18 @@ impl fmt::Debug for ReadState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ReadState::Init => f.write_str("(ReadState::Init)"), - ReadState::Header { frame, offset } => match frame.parse() { - Ok(parsed_frame) => write!( - f, - "(ReadState::Header (header {:?}) (offset {}))", - parsed_frame.header(), - offset, - ), - Err(e) => write!( - f, - "(ReadState::Header (invalid header ({})) (offset {}))", - e, offset - ), - }, - ReadState::Body { frame, offset } => match frame.parse() { - Ok(parsed_frame) => write!( + ReadState::Header { offset, .. } => { + write!(f, "(ReadState::Header (offset {}))", offset) + } + ReadState::Body { frame, offset } => { + write!( f, - "(ReadState::Body (header {:?}) (offset {}))", - parsed_frame.header(), + "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", + frame.header(), offset, - ), - Err(e) => write!( - f, - "(ReadState::Body (invalid header ({})) (offset {}))", - e, offset - ), - }, + frame.header().len().val() + ) + } } } } @@ -296,58 +273,51 @@ impl From for FrameDecodeError { } } -#[cfg(test)] -mod tests { - use super::*; - use quickcheck::{Arbitrary, Gen, QuickCheck}; - use rand::RngCore; - use zerocopy::AsBytes; - - impl Arbitrary for Frame<()> { - fn arbitrary(g: &mut Gen) -> Self { - let mut header: header::Header<()> = Arbitrary::arbitrary(g); - if header.tag().unwrap() == header::Tag::Data { - header.set_len(header.len().val() % 4096); - let mut buffer = vec![0; HEADER_SIZE + header.len().val() as usize]; - rand::thread_rng().fill_bytes(&mut buffer[HEADER_SIZE..]); - header - .write_to(&mut buffer[..HEADER_SIZE]) - .expect("write_to success"); - Frame::new(buffer) - } else { - Frame::from_header(header) - } - } - } - - #[test] - fn encode_decode_identity() { - fn property(f: Frame<()>) -> bool { - futures::executor::block_on(async move { - let pf = f.parse().expect("valid frame"); - let id = crate::connection::Id::random(); - let mut io = Io::new( - id, - futures::io::Cursor::new(Vec::new()), - pf.header().len().val() as usize, - ); - if io.send(f.clone()).await.is_err() { - return false; - } - if io.flush().await.is_err() { - return false; - } - io.io.set_position(0); - if let Ok(Some(x)) = io.try_next().await { - x == f - } else { - false - } - }) - } - - QuickCheck::new() - .tests(10_000) - .quickcheck(property as fn(Frame<()>) -> bool) - } -} +// TODO: Fix this. +// #[cfg(test)] +// mod tests { +// use super::*; +// use quickcheck::{Arbitrary, Gen, QuickCheck}; +// use rand::RngCore; +// +// impl Arbitrary for Frame<()> { +// fn arbitrary(g: &mut Gen) -> Self { +// let mut header: header::Header<()> = Arbitrary::arbitrary(g); +// let body = if header.tag() == header::Tag::Data { +// header.set_len(header.len().val() % 4096); +// let mut b = vec![0; header.len().val() as usize]; +// rand::thread_rng().fill_bytes(&mut b); +// b +// } else { +// Vec::new() +// }; +// Frame { header, body } +// } +// } +// +// #[test] +// fn encode_decode_identity() { +// fn property(f: Frame<()>) -> bool { +// futures::executor::block_on(async move { +// let id = crate::connection::Id::random(); +// let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); +// if io.send(f.clone()).await.is_err() { +// return false; +// } +// if io.flush().await.is_err() { +// return false; +// } +// io.io.set_position(0); +// if let Ok(Some(x)) = io.try_next().await { +// x == f +// } else { +// false +// } +// }) +// } +// +// QuickCheck::new() +// .tests(10_000) +// .quickcheck(property as fn(Frame<()>) -> bool) +// } +// } From ee00e1de45f62963049febf0f264953fe9943653 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sun, 17 Sep 2023 16:35:51 +1000 Subject: [PATCH 05/29] Some basic fixes --- yamux/src/frame.rs | 2 +- yamux/src/frame/header.rs | 9 ++++++++- yamux/src/frame/io.rs | 12 +++++------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 3a298744..91956488 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -34,7 +34,7 @@ impl Frame { pub(crate) fn new(header: Header) -> Self { let total_buffer_size = HEADER_SIZE + header.len().val() as usize; - let mut buffer = Vec::with_capacity(total_buffer_size); + let mut buffer = vec![0; total_buffer_size]; header .write_to_prefix(&mut buffer) .expect("buffer always fits the header"); diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index 1d280b1b..2f4326da 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -11,6 +11,7 @@ use futures::future::Either; use std::cmp::Ordering; use std::fmt; +use std::hash::{Hash, Hasher}; use zerocopy::big_endian::{U16, U32}; use zerocopy::{AsBytes, FromBytes, FromZeroes}; @@ -322,7 +323,7 @@ pub const CONNECTION_ID: StreamId = StreamId(U32::ZERO); /// The ID of a stream. /// /// The value 0 denotes no particular stream but the whole session. -#[derive(Copy, Clone, Debug, Hash, Eq, FromBytes, AsBytes, FromZeroes)] +#[derive(Copy, Clone, Debug, Eq, FromBytes, AsBytes, FromZeroes)] #[repr(packed)] pub struct StreamId(U32); @@ -345,6 +346,12 @@ impl Ord for StreamId { } } +impl Hash for StreamId { + fn hash(&self, state: &mut H) { + self.0.get().hash(state) + } +} + impl StreamId { pub(crate) fn new(val: u32) -> Self { StreamId(U32::new(val)) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 29401165..12eadef4 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -70,11 +70,10 @@ impl Sink> for Io { WriteState::Init => return Poll::Ready(Ok(())), WriteState::Writing { frame, - ref mut offset, - } => match Pin::new(&mut this.io).poll_write(cx, &frame.buffer()[*offset..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { + offset, + } => match ready!(Pin::new(&mut this.io).poll_write(cx, &frame.buffer()[*offset..])) { + Err(e) => return Poll::Ready(Err(e)), + Ok(n) => { if n == 0 { return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } @@ -165,8 +164,7 @@ impl Stream for Io { continue; } - let buf = &mut buffer[offset..header::HEADER_SIZE]; - match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { + match ready!(Pin::new(&mut this.io).poll_read(cx, &mut buffer[offset..]))? { 0 => { if offset == 0 { return Poll::Ready(None); From 801bf4f4050f081eb028c00cbbac35de82ca9022 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 12:24:34 +1100 Subject: [PATCH 06/29] Don't use `ready!` macro with `mem::replace` --- yamux/src/frame/io.rs | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 12eadef4..ee48e806 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -133,6 +133,7 @@ impl Stream for Io { offset: 0, buffer: [0; header::HEADER_SIZE], }; + continue; } ReadState::Header { offset, mut buffer } => { if offset == header::HEADER_SIZE { @@ -160,24 +161,25 @@ impl Stream for Io { frame: Frame::new(header), offset: 0, }; - continue; } - match ready!(Pin::new(&mut this.io).poll_read(cx, &mut buffer[offset..]))? { - 0 => { + match Pin::new(&mut this.io).poll_read(cx, &mut buffer[offset..])? { + Poll::Ready(0) => { if offset == 0 { return Poll::Ready(None); } let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - n => { + Poll::Ready(n) => { this.read_state = ReadState::Header { buffer, offset: offset + n, - } + }; + continue; } + Poll::Pending => {} } } ReadState::Body { offset, mut frame } => { @@ -187,22 +189,24 @@ impl Stream for Io { return Poll::Ready(Some(Ok(frame))); } - match ready!( - Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[offset..]) - )? { - 0 => { + match Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[offset..])? { + Poll::Ready(0) => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - n => { + Poll::Ready(n) => { this.read_state = ReadState::Body { frame, offset: offset + n, - } + }; + continue; } + Poll::Pending => {} } } } + + return Poll::Pending; } } } From 8e11260fbf2bd9826aa92d2d067c5d6136893100 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 12:25:17 +1100 Subject: [PATCH 07/29] Inline variable --- yamux/src/frame/io.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index ee48e806..383e6b7b 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -125,9 +125,8 @@ impl Stream for Io { let mut this = &mut *self; loop { log::trace!("{}: read: {:?}", this.id, this.read_state); - let state = mem::replace(&mut this.read_state, ReadState::Init); - match state { + match mem::replace(&mut this.read_state, ReadState::Init) { ReadState::Init => { this.read_state = ReadState::Header { offset: 0, From aeb4cd203e06f7a2cc00e580ee76ee664951fc15 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 12:39:57 +1100 Subject: [PATCH 08/29] Remove `Init` state --- yamux/src/frame/io.rs | 65 +++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 383e6b7b..f81422a9 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -37,7 +37,10 @@ impl Io { Io { id, io, - read_state: ReadState::Init, + read_state: ReadState::Header { + offset: 0, + buffer: [0; header::HEADER_SIZE], + }, write_state: WriteState::Init, max_body_len: max_frame_body_len, } @@ -107,8 +110,6 @@ impl Sink> for Io { /// The stages of reading a new `Frame`. enum ReadState { - /// Initial reading state. - Init, /// Reading the frame header. Header { offset: usize, @@ -118,6 +119,15 @@ enum ReadState { Body { frame: Frame<()>, offset: usize }, } +impl ReadState { + fn header() -> Self { + ReadState::Header { + offset: 0, + buffer: [0u8; header::HEADER_SIZE], + } + } +} + impl Stream for Io { type Item = Result, FrameDecodeError>; @@ -126,16 +136,9 @@ impl Stream for Io { loop { log::trace!("{}: read: {:?}", this.id, this.read_state); - match mem::replace(&mut this.read_state, ReadState::Init) { - ReadState::Init => { - this.read_state = ReadState::Header { - offset: 0, - buffer: [0; header::HEADER_SIZE], - }; - continue; - } + match &mut this.read_state { ReadState::Header { offset, mut buffer } => { - if offset == header::HEADER_SIZE { + if *offset == header::HEADER_SIZE { let header = match header::decode(&buffer) { Ok(hd) => hd, Err(e) => return Poll::Ready(Some(Err(e.into()))), @@ -144,7 +147,7 @@ impl Stream for Io { log::trace!("{}: read: {}", this.id, header); if header.tag() != header::Tag::Data { - this.read_state = ReadState::Init; + this.read_state = ReadState::header(); return Poll::Ready(Some(Ok(Frame::new(header)))); } @@ -163,49 +166,44 @@ impl Stream for Io { continue; } - match Pin::new(&mut this.io).poll_read(cx, &mut buffer[offset..])? { - Poll::Ready(0) => { - if offset == 0 { + match ready!(Pin::new(&mut this.io).poll_read(cx, &mut buffer[*offset..]))? { + 0 => { + if *offset == 0 { return Poll::Ready(None); } let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - Poll::Ready(n) => { + n => { this.read_state = ReadState::Header { buffer, - offset: offset + n, + offset: *offset + n, }; - continue; } - Poll::Pending => {} } } - ReadState::Body { offset, mut frame } => { + ReadState::Body { offset, ref mut frame } => { let body_len = frame.header().len().val() as usize; - if offset == body_len { + if *offset == body_len { + let frame = match mem::replace(&mut self.read_state, ReadState::header()) { + ReadState::Header { .. } => unreachable!("we matched above"), + ReadState::Body { frame, .. } => frame + }; return Poll::Ready(Some(Ok(frame))); } - match Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[offset..])? { - Poll::Ready(0) => { + match ready!(Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[*offset..]))? { + 0 => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - Poll::Ready(n) => { - this.read_state = ReadState::Body { - frame, - offset: offset + n, - }; - continue; + n => { + *offset += n; } - Poll::Pending => {} } } } - - return Poll::Pending; } } } @@ -213,7 +211,6 @@ impl Stream for Io { impl fmt::Debug for ReadState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - ReadState::Init => f.write_str("(ReadState::Init)"), ReadState::Header { offset, .. } => { write!(f, "(ReadState::Header (offset {}))", offset) } From f90c9900d10c018706c19c9f9de2dcebba37baa7 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 12:43:19 +1100 Subject: [PATCH 09/29] Use `?` for decoding header --- yamux/src/frame/io.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index f81422a9..76d7fc0f 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -21,6 +21,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; +use std::ops::AddAssign; /// A [`Stream`] and writer of [`Frame`] values. #[derive(Debug)] @@ -139,10 +140,7 @@ impl Stream for Io { match &mut this.read_state { ReadState::Header { offset, mut buffer } => { if *offset == header::HEADER_SIZE { - let header = match header::decode(&buffer) { - Ok(hd) => hd, - Err(e) => return Poll::Ready(Some(Err(e.into()))), - }; + let header = header::decode(&buffer)?; log::trace!("{}: read: {}", this.id, header); @@ -199,7 +197,7 @@ impl Stream for Io { return Poll::Ready(Some(Err(e))); } n => { - *offset += n; + offset.add_assign(n) } } } From 93c3834e2cb09aa0678f73848ebc0d7ae7f1d99d Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 12:47:36 +1100 Subject: [PATCH 10/29] Use ctor --- yamux/src/frame/io.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 76d7fc0f..c58412d6 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -38,10 +38,7 @@ impl Io { Io { id, io, - read_state: ReadState::Header { - offset: 0, - buffer: [0; header::HEADER_SIZE], - }, + read_state: ReadState::header(), write_state: WriteState::Init, max_body_len: max_frame_body_len, } From 872815ebda3457b6124a2e164c86071ef6e95528 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 12:57:38 +1100 Subject: [PATCH 11/29] Use type-system to only allocate for data frames --- yamux/src/connection.rs | 2 +- yamux/src/frame.rs | 39 ++++++++++++++++++++++-------- yamux/src/frame/header.rs | 2 +- yamux/src/frame/io.rs | 51 +++++++++++++++++++++------------------ 4 files changed, 58 insertions(+), 36 deletions(-) diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index de2e07b7..91cae542 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -880,7 +880,7 @@ impl Active { if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) { let mut hdr = Header::ping(frame.header().nonce()); hdr.ack(); - return Action::Ping(Frame::new(hdr)); + return Action::Ping(Frame::no_body(hdr)); } log::trace!( "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}", diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 91956488..39054d74 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -31,13 +31,11 @@ pub struct Frame { } impl Frame { - pub(crate) fn new(header: Header) -> Self { - let total_buffer_size = HEADER_SIZE + header.len().val() as usize; - - let mut buffer = vec![0; total_buffer_size]; + pub(crate) fn no_body(header: Header) -> Self { + let mut buffer = vec![0; HEADER_SIZE]; header - .write_to_prefix(&mut buffer) - .expect("buffer always fits the header"); + .write_to(&mut buffer) + .expect("buffer is size of header"); Self { buffer, @@ -95,6 +93,13 @@ impl Frame { _marker: PhantomData, } } + + pub(crate) fn into_generic_frame(self) -> Frame<()> { + Frame { + buffer: self.buffer, + _marker: PhantomData, + } + } } impl From> for Frame<()> { @@ -130,6 +135,20 @@ impl Frame<()> { } impl Frame { + pub(crate) fn new(header: Header) -> Self { + let total_buffer_size = HEADER_SIZE + header.len().val() as usize; + + let mut buffer = vec![0; total_buffer_size]; + header + .write_to_prefix(&mut buffer) + .expect("buffer always fits the header"); + + Self { + buffer, + _marker: PhantomData, + } + } + pub fn data(id: StreamId, body: &[u8]) -> Result { let header = Header::data(id, body.len().try_into()?); @@ -152,20 +171,20 @@ impl Frame { impl Frame { pub fn window_update(id: StreamId, credit: u32) -> Frame { - Frame::new(Header::window_update(id, credit)) + Frame::no_body(Header::window_update(id, credit)) } } impl Frame { pub fn term() -> Frame { - Frame::::new(Header::term()) + Frame::::no_body(Header::term()) } pub fn protocol_error() -> Frame { - Frame::::new(Header::protocol_error()) + Frame::::no_body(Header::protocol_error()) } pub fn internal_error() -> Frame { - Frame::::new(Header::internal_error()) + Frame::::no_body(Header::internal_error()) } } diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index 2f4326da..d33bfb72 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -89,7 +89,7 @@ impl Header { } /// Arbitrary type cast, use with caution. - fn cast(self) -> Header { + pub(crate) fn cast(self) -> Header { Header { version: self.version, tag: self.tag, diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index c58412d6..a81e8a5d 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -16,12 +16,12 @@ use crate::connection::Id; use crate::frame::header::Data; use futures::future::Either; use futures::{prelude::*, ready}; +use std::ops::AddAssign; use std::{ fmt, io, mem, pin::Pin, task::{Context, Poll}, }; -use std::ops::AddAssign; /// A [`Stream`] and writer of [`Frame`] values. #[derive(Debug)] @@ -69,21 +69,21 @@ impl Sink> for Io { log::trace!("{}: write: {:?}", this.id, this.write_state); match &mut this.write_state { WriteState::Init => return Poll::Ready(Ok(())), - WriteState::Writing { - frame, - offset, - } => match ready!(Pin::new(&mut this.io).poll_write(cx, &frame.buffer()[*offset..])) { - Err(e) => return Poll::Ready(Err(e)), - Ok(n) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *offset += n; - if *offset == frame.buffer().len() { - this.write_state = WriteState::Init; + WriteState::Writing { frame, offset } => { + match ready!(Pin::new(&mut this.io).poll_write(cx, &frame.buffer()[*offset..])) + { + Err(e) => return Poll::Ready(Err(e)), + Ok(n) => { + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + *offset += n; + if *offset == frame.buffer().len() { + this.write_state = WriteState::Init; + } } } - }, + } } } } @@ -114,7 +114,7 @@ enum ReadState { buffer: [u8; header::HEADER_SIZE], }, /// Reading the frame body. - Body { frame: Frame<()>, offset: usize }, + Body { frame: Frame, offset: usize }, } impl ReadState { @@ -143,7 +143,7 @@ impl Stream for Io { if header.tag() != header::Tag::Data { this.read_state = ReadState::header(); - return Poll::Ready(Some(Ok(Frame::new(header)))); + return Poll::Ready(Some(Ok(Frame::no_body(header)))); } let body_len = header.len().val() as usize; @@ -155,7 +155,7 @@ impl Stream for Io { } this.read_state = ReadState::Body { - frame: Frame::new(header), + frame: Frame::new(header.cast()), // Safe to cast here because we asserted above that it is a data frame. offset: 0, }; continue; @@ -177,25 +177,28 @@ impl Stream for Io { } } } - ReadState::Body { offset, ref mut frame } => { + ReadState::Body { + offset, + ref mut frame, + } => { let body_len = frame.header().len().val() as usize; if *offset == body_len { let frame = match mem::replace(&mut self.read_state, ReadState::header()) { ReadState::Header { .. } => unreachable!("we matched above"), - ReadState::Body { frame, .. } => frame + ReadState::Body { frame, .. } => frame, }; - return Poll::Ready(Some(Ok(frame))); + return Poll::Ready(Some(Ok(frame.into_generic_frame()))); } - match ready!(Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[*offset..]))? { + match ready!( + Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[*offset..]) + )? { 0 => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - n => { - offset.add_assign(n) - } + n => offset.add_assign(n), } } } From 722b7c813572a2a96b72e2eec8e5dc4869d5946e Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 13:00:32 +1100 Subject: [PATCH 12/29] Don't use `cast` outside of `header` module --- yamux/src/frame/header.rs | 12 +++++++++++- yamux/src/frame/io.rs | 15 +++++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index d33bfb72..e91d7645 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -89,7 +89,7 @@ impl Header { } /// Arbitrary type cast, use with caution. - pub(crate) fn cast(self) -> Header { + fn cast(self) -> Header { Header { version: self.version, tag: self.tag, @@ -134,6 +134,16 @@ impl Header<()> { } } +impl Header<()> { + pub(crate) fn try_into_data(self) -> Result, Self> { + if self.tag == Tag::Data as u8 { + return Ok(self.cast()); + } + + Err(self) + } +} + impl Header { /// Set the [`SYN`] flag. pub fn syn(&mut self) { diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index a81e8a5d..46e09280 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -13,7 +13,7 @@ use super::{ Frame, }; use crate::connection::Id; -use crate::frame::header::Data; +use crate::frame::header::{Data, Header}; use futures::future::Either; use futures::{prelude::*, ready}; use std::ops::AddAssign; @@ -141,10 +141,13 @@ impl Stream for Io { log::trace!("{}: read: {}", this.id, header); - if header.tag() != header::Tag::Data { - this.read_state = ReadState::header(); - return Poll::Ready(Some(Ok(Frame::no_body(header)))); - } + let header = match header.try_into_data() { + Ok(data_header) => data_header, + Err(other_header) => { + this.read_state = ReadState::header(); + return Poll::Ready(Some(Ok(Frame::no_body(other_header)))); + } + }; let body_len = header.len().val() as usize; @@ -155,7 +158,7 @@ impl Stream for Io { } this.read_state = ReadState::Body { - frame: Frame::new(header.cast()), // Safe to cast here because we asserted above that it is a data frame. + frame: Frame::new(header), offset: 0, }; continue; From 91e812a9af5c1ac4854ff70727da877845c5e9c1 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 5 Oct 2023 13:01:06 +1100 Subject: [PATCH 13/29] Add TODO --- yamux/src/frame/header.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index e91d7645..e4c259fb 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -79,6 +79,7 @@ impl Header { self.stream_id } + // FIXME: This should not be a generic accessor because its semantic depends on the type. pub fn len(&self) -> Len { self.length } From c095aacf0e3761f86a2baaa73d0f6bdad23442ae Mon Sep 17 00:00:00 2001 From: Philippe Jalaber Date: Wed, 4 Oct 2023 16:29:55 +0200 Subject: [PATCH 14/29] Replace header::decode() with Frame::try_from_header_buffer() Revert Header Len(u32) to Len(big_endian::U32) Fix all tests by calling Frame::ensure_buffer_len() before reading a data frame body in io.rs --- yamux/Cargo.toml | 1 + yamux/src/connection.rs | 2 +- yamux/src/frame.rs | 47 +++++++-- yamux/src/frame/header.rs | 208 ++++++++++++++------------------------ yamux/src/frame/io.rs | 124 +++++++++++------------ 5 files changed, 178 insertions(+), 204 deletions(-) diff --git a/yamux/Cargo.toml b/yamux/Cargo.toml index 372ba0e3..5c01fafa 100644 --- a/yamux/Cargo.toml +++ b/yamux/Cargo.toml @@ -21,3 +21,4 @@ zerocopy = { version = "0.7.0", features = ["derive"] } [dev-dependencies] quickcheck = "1.0" +futures = "0.3.4" diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 91cae542..9617f2de 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -500,7 +500,7 @@ impl Active { } } - fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_new_outbound(&mut self, cx: &Context<'_>) -> Poll> { if self.streams.len() >= self.config.max_num_streams { log::error!("{}: maximum number of streams reached", self.id); return Poll::Ready(Err(ConnectionError::TooManyStreams)); diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 39054d74..8a67c0ba 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -14,16 +14,15 @@ mod io; use futures::future::Either; use header::{Data, GoAway, Header, Ping, StreamId, WindowUpdate}; use std::{convert::TryInto, fmt::Debug, marker::PhantomData, num::TryFromIntError}; -use zerocopy::{AsBytes, ByteSlice, ByteSliceMut, Ref}; +use zerocopy::{AsBytes, Ref}; pub use io::FrameDecodeError; pub(crate) use io::Io; -use crate::HeaderDecodeError; +use self::header::HEADER_SIZE; -use self::header::{Flags, Tag, HEADER_SIZE}; - -/// A Yamux message frame consisting of header and body in a single buffer +/// A Yamux message frame consisting of a single buffer with header followed by body. +/// The header can be zerocopy parsed into a Header struct by calling header()/header_mut(). #[derive(Clone, Debug, Eq, PartialEq)] pub struct Frame { buffer: Vec, @@ -43,14 +42,25 @@ impl Frame { } } - pub(crate) fn header(&self) -> &Header { + pub fn try_from_header_buffer(buffer: [u8; HEADER_SIZE]) -> Result { + let frame = Self { + buffer: buffer.to_vec(), + _marker: PhantomData, + }; + let header = frame.header(); + header.validate()?; + + Ok(frame) + } + + pub fn header(&self) -> &Header { Ref::<_, Header>::new_from_prefix(self.buffer.as_slice()) .expect("buffer always holds a valid header") .0 .into_ref() } - pub(crate) fn header_mut(&mut self) -> &mut Header { + pub fn header_mut(&mut self) -> &mut Header { Ref::<_, Header>::new_from_prefix(self.buffer.as_mut_slice()) .expect("buffer always holds a valid header") .0 @@ -112,6 +122,14 @@ impl From> for Frame<()> { } impl Frame<()> { + pub(crate) fn try_into_data(self) -> Result, Self> { + if self.header().is_data() { + Ok(self.into_data()) + } else { + Err(self) + } + } + pub(crate) fn into_data(self) -> Frame { Frame { buffer: self.buffer, @@ -135,7 +153,7 @@ impl Frame<()> { } impl Frame { - pub(crate) fn new(header: Header) -> Self { + pub fn new(header: Header) -> Self { let total_buffer_size = HEADER_SIZE + header.len().val() as usize; let mut buffer = vec![0; total_buffer_size]; @@ -153,6 +171,7 @@ impl Frame { let header = Header::data(id, body.len().try_into()?); let mut frame = Frame::new(header); + frame.body_mut().copy_from_slice(body); Ok(frame) @@ -167,6 +186,12 @@ impl Frame { Frame::new(header) } + + fn ensure_buffer_len(&mut self) { + self.buffer + .resize(HEADER_SIZE + self.header().len().val() as usize, 0); + } + } impl Frame { @@ -176,15 +201,15 @@ impl Frame { } impl Frame { - pub fn term() -> Frame { + pub fn term() -> Self { Frame::::no_body(Header::term()) } - pub fn protocol_error() -> Frame { + pub fn protocol_error() -> Self { Frame::::no_body(Header::protocol_error()) } - pub fn internal_error() -> Frame { + pub fn internal_error() -> Self { Frame::::no_body(Header::internal_error()) } } diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index e4c259fb..e17de050 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -9,7 +9,6 @@ // at https://opensource.org/licenses/MIT. use futures::future::Either; -use std::cmp::Ordering; use std::fmt; use std::hash::{Hash, Hasher}; use zerocopy::big_endian::{U16, U32}; @@ -40,13 +39,6 @@ impl fmt::Debug for Header { } } -impl Header { - #[must_use] - pub(crate) fn has_valid_tag(&self) -> bool { - (0..4).contains(&self.tag) - } -} - impl fmt::Display for Header { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -86,7 +78,7 @@ impl Header { #[cfg(test)] pub fn set_len(&mut self, len: u32) { - self.length = Len(len) + self.length = Len(len.into()); } /// Arbitrary type cast, use with caution. @@ -101,15 +93,23 @@ impl Header { } } - /// Introduce this header to the right of a binary header type. - pub(crate) fn right(self) -> Header> { - self.cast() + pub(crate) fn is_data(&self) -> bool { + self.tag == Tag::Data as u8 } - /// Introduce this header to the left of a binary header type. - pub(crate) fn left(self) -> Header> { - self.cast() + /// Validate a [`Header`] value. + pub fn validate(&self) -> Result<(), HeaderDecodeError> { + if self.version.0 != 0 { + return Err(HeaderDecodeError::Version(self.version.0)); + } + + if !(0..4).contains(&self.tag) { + return Err(HeaderDecodeError::Type(self.tag)); + } + + Ok(()) } + } impl From> for Header<()> { @@ -120,29 +120,9 @@ impl From> for Header<()> { impl Header<()> { pub(crate) fn into_data(self) -> Header { - // FIXME debug_assert_eq!(self.tag, Tag::Data); - self.cast() - } - - pub(crate) fn into_window_update(self) -> Header { - // FIXME debug_assert_eq!(self.tag, Tag::WindowUpdate); + debug_assert!(self.is_data()); self.cast() } - - pub(crate) fn into_ping(self) -> Header { - // FIXME debug_assert_eq!(self.tag, Tag::Ping); - self.cast() - } -} - -impl Header<()> { - pub(crate) fn try_into_data(self) -> Result, Self> { - if self.tag == Tag::Data as u8 { - return Ok(self.cast()); - } - - Err(self) - } } impl Header { @@ -155,21 +135,21 @@ impl Header { impl Header { /// Set the [`ACK`] flag. pub fn ack(&mut self) { - // self.flags.0 |= ACK.0 + self.flags.0.set(self.flags.val() | ACK.0.get()); } } impl Header { /// Set the [`FIN`] flag. pub fn fin(&mut self) { - // self.flags.0 |= FIN.0 + self.flags.0.set(self.flags.val() | FIN.0.get()); } } impl Header { /// Set the [`RST`] flag. pub fn rst(&mut self) { - // self.flags.0 |= RST.0 + self.flags.0.set(self.flags.val() | RST.0.get()); } } @@ -181,7 +161,7 @@ impl Header { tag: Tag::Data as u8, flags: Flags(U16::new(0)), stream_id: id, - length: Len(len), + length: Len(len.into()), _marker: std::marker::PhantomData, } } @@ -195,14 +175,14 @@ impl Header { tag: Tag::WindowUpdate as u8, flags: Flags(U16::new(0)), stream_id: id, - length: Len(credit), + length: Len(credit.into()), _marker: std::marker::PhantomData, } } /// The credit this window update grants to the remote. pub fn credit(&self) -> u32 { - self.length.0 + self.length.0.into() } } @@ -214,14 +194,14 @@ impl Header { tag: Tag::Ping as u8, flags: Flags(U16::new(0)), stream_id: CONNECTION_ID, - length: Len(nonce), + length: Len(nonce.into()), _marker: std::marker::PhantomData, } } /// The nonce of this ping. pub fn nonce(&self) -> u32 { - self.length.0 + self.length.0.into() } } @@ -247,7 +227,7 @@ impl Header { tag: Tag::GoAway as u8, flags: Flags(U16::new(0)), stream_id: CONNECTION_ID, - length: Len(code), + length: Len(code.into()), _marker: std::marker::PhantomData, } } @@ -321,11 +301,11 @@ pub struct Version(u8); /// The message length. #[derive(Copy, Clone, Debug, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] #[repr(packed)] -pub struct Len(u32); +pub struct Len(U32); impl Len { pub fn val(self) -> u32 { - self.0 + self.0.into() } } @@ -334,29 +314,10 @@ pub const CONNECTION_ID: StreamId = StreamId(U32::ZERO); /// The ID of a stream. /// /// The value 0 denotes no particular stream but the whole session. -#[derive(Copy, Clone, Debug, Eq, FromBytes, AsBytes, FromZeroes)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, FromBytes, AsBytes, FromZeroes)] #[repr(packed)] pub struct StreamId(U32); -// TODO: Research why these can't be derived. Is this wrong? -impl PartialEq for StreamId { - fn eq(&self, other: &Self) -> bool { - self.0.get() == other.0.get() - } -} - -impl PartialOrd for StreamId { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.get().partial_cmp(&other.0.get()) - } -} - -impl Ord for StreamId { - fn cmp(&self, other: &Self) -> Ordering { - self.0.get().cmp(&other.0.get()) - } -} - impl Hash for StreamId { fn hash(&self, state: &mut H) { self.0.get().hash(state) @@ -423,30 +384,6 @@ pub const RST: Flags = Flags(U16::from_bytes([0, 8])); /// The serialised header size in bytes. pub const HEADER_SIZE: usize = 12; -/// Encode a [`Header`] value. -pub fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { - let mut buf = [0; HEADER_SIZE]; - - hdr.write_to(&mut buf).expect("buffer to be correct length"); - buf -} - -/// Decode a [`Header`] value. -pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { - if buf[0] != 0 { - return Err(HeaderDecodeError::Version(buf[0])); - } - - let tag = buf[1]; - if !(0..4).contains(&tag) { - return Err(HeaderDecodeError::Type(tag)); - } - - let hdr = Header::read_from(buf).expect("buffer to be correct size"); // FIXME do we know this here? - - Ok(hdr) -} - /// Possible errors while decoding a message frame header. #[non_exhaustive] #[derive(Debug)] @@ -468,42 +405,53 @@ impl std::fmt::Display for HeaderDecodeError { impl std::error::Error for HeaderDecodeError {} -// FIXME -// #[cfg(test)] -// mod tests { -// use super::*; -// use quickcheck::{Arbitrary, Gen, QuickCheck}; -// -// impl Arbitrary for Header<()> { -// fn arbitrary(g: &mut Gen) -> Self { -// let tag = *g -// .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]) -// .unwrap(); -// -// Header { -// version: Version(0), -// tag: tag as u8, -// flags: Flags(Arbitrary::arbitrary(g)), -// stream_id: StreamId(Arbitrary::arbitrary(g)), -// length: Len(Arbitrary::arbitrary(g)), -// _marker: std::marker::PhantomData, -// } -// } -// } -// -// #[test] -// fn encode_decode_identity() { -// fn property(hdr: Header<()>) -> bool { -// match decode(&encode(&hdr)) { -// Ok(x) => x == hdr, -// Err(e) => { -// eprintln!("decode error: {}", e); -// false -// } -// } -// } -// QuickCheck::new() -// .tests(10_000) -// .quickcheck(property as fn(Header<()>) -> bool) -// } -// } +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::{Arbitrary, Gen, QuickCheck}; + + impl Arbitrary for Header<()> { + fn arbitrary(g: &mut Gen) -> Self { + let tag = *g + .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]) + .unwrap(); + + Header { + version: Version(0), + tag: tag as u8, + flags: Flags(u16::arbitrary(g).into()), + stream_id: StreamId(u32::arbitrary(g).into()), + length: Len(u32::arbitrary(g).into()), + _marker: std::marker::PhantomData, + } + } + } + + fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { + let hdr = Header::read_from(buf).expect("buffer to be correct size"); + hdr.validate().map(|_| hdr) + } + + /// Encode a [`Header`] value. + fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { + let mut buf = [0; HEADER_SIZE]; + hdr.write_to(&mut buf).expect("buffer to be correct length"); + buf + } + + #[test] + fn encode_decode_identity() { + fn property(hdr: Header<()>) -> bool { + match decode(&encode(&hdr)) { + Ok(x) => x == hdr, + Err(e) => { + eprintln!("decode error: {}", e); + false + } + } + } + QuickCheck::new() + .tests(10_000) + .quickcheck(property as fn(Header<()>) -> bool) + } +} diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 46e09280..8c7033cf 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -13,8 +13,7 @@ use super::{ Frame, }; use crate::connection::Id; -use crate::frame::header::{Data, Header}; -use futures::future::Either; +use crate::frame::header::Data; use futures::{prelude::*, ready}; use std::ops::AddAssign; use std::{ @@ -130,26 +129,26 @@ impl Stream for Io { type Item = Result, FrameDecodeError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = &mut *self; + let this = &mut *self; loop { log::trace!("{}: read: {:?}", this.id, this.read_state); match &mut this.read_state { ReadState::Header { offset, mut buffer } => { if *offset == header::HEADER_SIZE { - let header = header::decode(&buffer)?; + let frame = Frame::try_from_header_buffer(buffer)?; - log::trace!("{}: read: {}", this.id, header); + log::trace!("{}: read: {:?}", this.id, frame); - let header = match header.try_into_data() { - Ok(data_header) => data_header, - Err(other_header) => { + let mut frame = match frame.try_into_data() { + Ok(data_frame) => data_frame, + Err(other_frame) => { this.read_state = ReadState::header(); - return Poll::Ready(Some(Ok(Frame::no_body(other_header)))); + return Poll::Ready(Some(Ok(other_frame))); } }; - let body_len = header.len().val() as usize; + let body_len = frame.header().len().val() as usize; if body_len > this.max_body_len { return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( @@ -157,10 +156,9 @@ impl Stream for Io { )))); } - this.read_state = ReadState::Body { - frame: Frame::new(header), - offset: 0, - }; + frame.ensure_buffer_len(); + + this.read_state = ReadState::Body { frame, offset: 0 }; continue; } @@ -272,51 +270,53 @@ impl From for FrameDecodeError { } } -// TODO: Fix this. -// #[cfg(test)] -// mod tests { -// use super::*; -// use quickcheck::{Arbitrary, Gen, QuickCheck}; -// use rand::RngCore; -// -// impl Arbitrary for Frame<()> { -// fn arbitrary(g: &mut Gen) -> Self { -// let mut header: header::Header<()> = Arbitrary::arbitrary(g); -// let body = if header.tag() == header::Tag::Data { -// header.set_len(header.len().val() % 4096); -// let mut b = vec![0; header.len().val() as usize]; -// rand::thread_rng().fill_bytes(&mut b); -// b -// } else { -// Vec::new() -// }; -// Frame { header, body } -// } -// } -// -// #[test] -// fn encode_decode_identity() { -// fn property(f: Frame<()>) -> bool { -// futures::executor::block_on(async move { -// let id = crate::connection::Id::random(); -// let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); -// if io.send(f.clone()).await.is_err() { -// return false; -// } -// if io.flush().await.is_err() { -// return false; -// } -// io.io.set_position(0); -// if let Ok(Some(x)) = io.try_next().await { -// x == f -// } else { -// false -// } -// }) -// } -// -// QuickCheck::new() -// .tests(10_000) -// .quickcheck(property as fn(Frame<()>) -> bool) -// } -// } +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::{Arbitrary, Gen, QuickCheck}; + use rand::RngCore; + + impl Arbitrary for Frame<()> { + fn arbitrary(g: &mut Gen) -> Self { + let mut header: header::Header<()> = Arbitrary::arbitrary(g); + if header.tag() == header::Tag::Data { + header.set_len(header.len().val() % 4096); + let mut frame = Frame::new(header.into_data()); + rand::thread_rng().fill_bytes(frame.body_mut()); + frame.into_generic_frame() + } else { + Frame::no_body(header) + } + } + } + + #[test] + fn encode_decode_identity() { + fn property(f: Frame<()>) -> bool { + futures::executor::block_on(async move { + let id = crate::connection::Id::random(); + let mut io = Io::new( + id, + futures::io::Cursor::new(Vec::new()), + f.body_len() as usize, + ); + if io.send(f.clone()).await.is_err() { + return false; + } + if io.flush().await.is_err() { + return false; + } + io.io.set_position(0); + if let Ok(Some(x)) = io.try_next().await { + x == f + } else { + false + } + }) + } + + QuickCheck::new() + .tests(10_000) + .quickcheck(property as fn(Frame<()>) -> bool) + } +} From 2cab6b40b8f3b7c32b91c2d1cb67c89d98974409 Mon Sep 17 00:00:00 2001 From: Philippe Jalaber Date: Thu, 5 Oct 2023 13:50:36 +0200 Subject: [PATCH 15/29] Cargo fmt --- yamux/src/frame.rs | 1 - yamux/src/frame/header.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 8a67c0ba..dd392bba 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -191,7 +191,6 @@ impl Frame { self.buffer .resize(HEADER_SIZE + self.header().len().val() as usize, 0); } - } impl Frame { diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index e17de050..ee3baf97 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -109,7 +109,6 @@ impl Header { Ok(()) } - } impl From> for Header<()> { From 166f8ffeb725e0c22b2eb000399b8bf04a58a74a Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 11:42:21 +1100 Subject: [PATCH 16/29] Reduce diff --- yamux/src/frame/io.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 8c7033cf..3afd1a8f 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -15,7 +15,6 @@ use super::{ use crate::connection::Id; use crate::frame::header::Data; use futures::{prelude::*, ready}; -use std::ops::AddAssign; use std::{ fmt, io, mem, pin::Pin, @@ -192,14 +191,13 @@ impl Stream for Io { return Poll::Ready(Some(Ok(frame.into_generic_frame()))); } - match ready!( - Pin::new(&mut this.io).poll_read(cx, &mut frame.body_mut()[*offset..]) - )? { + let buf = &mut frame.body_mut()[*offset..]; + match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { 0 => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - n => offset.add_assign(n), + n => *offset += n, } } } From 5c6b1728ce2f08933831b5acac20905744930050 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 11:57:30 +1100 Subject: [PATCH 17/29] Bring back header::decode --- yamux/src/frame.rs | 33 ++++++++++-------------------- yamux/src/frame/header.rs | 42 ++++++++++++++++++++++----------------- yamux/src/frame/io.rs | 9 ++++----- 3 files changed, 39 insertions(+), 45 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index dd392bba..4566029a 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -42,17 +42,6 @@ impl Frame { } } - pub fn try_from_header_buffer(buffer: [u8; HEADER_SIZE]) -> Result { - let frame = Self { - buffer: buffer.to_vec(), - _marker: PhantomData, - }; - let header = frame.header(); - header.validate()?; - - Ok(frame) - } - pub fn header(&self) -> &Header { Ref::<_, Header>::new_from_prefix(self.buffer.as_slice()) .expect("buffer always holds a valid header") @@ -122,12 +111,17 @@ impl From> for Frame<()> { } impl Frame<()> { - pub(crate) fn try_into_data(self) -> Result, Self> { - if self.header().is_data() { - Ok(self.into_data()) - } else { - Err(self) - } + pub(crate) fn try_from_header_buffer( + buffer: [u8; HEADER_SIZE], + ) -> Result, Frame>, FrameDecodeError> { + let header = header::decode(&buffer)?; + + let either = match header.try_into_data() { + Ok(data) => Either::Right(Frame::new(data)), + Err(other) => Either::Left(Frame::no_body(other)), + }; + + Ok(either) } pub(crate) fn into_data(self) -> Frame { @@ -186,11 +180,6 @@ impl Frame { Frame::new(header) } - - fn ensure_buffer_len(&mut self) { - self.buffer - .resize(HEADER_SIZE + self.header().len().val() as usize, 0); - } } impl Frame { diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index ee3baf97..b86b46d4 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -96,19 +96,6 @@ impl Header { pub(crate) fn is_data(&self) -> bool { self.tag == Tag::Data as u8 } - - /// Validate a [`Header`] value. - pub fn validate(&self) -> Result<(), HeaderDecodeError> { - if self.version.0 != 0 { - return Err(HeaderDecodeError::Version(self.version.0)); - } - - if !(0..4).contains(&self.tag) { - return Err(HeaderDecodeError::Type(self.tag)); - } - - Ok(()) - } } impl From> for Header<()> { @@ -122,6 +109,14 @@ impl Header<()> { debug_assert!(self.is_data()); self.cast() } + + pub(crate) fn try_into_data(self) -> Result, Self> { + if self.tag == Tag::Data as u8 { + return Ok(self.into_data()); + } + + Err(self) + } } impl Header { @@ -383,6 +378,22 @@ pub const RST: Flags = Flags(U16::from_bytes([0, 8])); /// The serialised header size in bytes. pub const HEADER_SIZE: usize = 12; +// Decode a [`Header`] value. +pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { + if buf[0] != 0 { + return Err(HeaderDecodeError::Version(buf[0])); + } + + let tag = buf[1]; + if !(0..4).contains(&tag) { + return Err(HeaderDecodeError::Type(tag)); + } + + let hdr = Header::read_from(buf).expect("buffer to be correct size"); + + Ok(hdr) +} + /// Possible errors while decoding a message frame header. #[non_exhaustive] #[derive(Debug)] @@ -426,11 +437,6 @@ mod tests { } } - fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { - let hdr = Header::read_from(buf).expect("buffer to be correct size"); - hdr.validate().map(|_| hdr) - } - /// Encode a [`Header`] value. fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { let mut buf = [0; HEADER_SIZE]; diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 3afd1a8f..41f78c13 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -14,6 +14,7 @@ use super::{ }; use crate::connection::Id; use crate::frame::header::Data; +use futures::future::Either; use futures::{prelude::*, ready}; use std::{ fmt, io, mem, @@ -139,9 +140,9 @@ impl Stream for Io { log::trace!("{}: read: {:?}", this.id, frame); - let mut frame = match frame.try_into_data() { - Ok(data_frame) => data_frame, - Err(other_frame) => { + let frame = match frame { + Either::Right(data_frame) => data_frame, + Either::Left(other_frame) => { this.read_state = ReadState::header(); return Poll::Ready(Some(Ok(other_frame))); } @@ -155,8 +156,6 @@ impl Stream for Io { )))); } - frame.ensure_buffer_len(); - this.read_state = ReadState::Body { frame, offset: 0 }; continue; } From 524994f814544db42eb0701e08ede9cc6afbf401 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:03:15 +1100 Subject: [PATCH 18/29] Reduce diff --- yamux/src/frame.rs | 57 ++++++++++++++++++++++--------------------- yamux/src/frame/io.rs | 2 +- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 4566029a..62e1bc34 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -60,23 +60,6 @@ impl Frame { self.buffer.as_slice() } - pub(crate) fn body(&self) -> &[u8] { - &self.buffer[HEADER_SIZE..] - } - - pub(crate) fn body_len(&self) -> u32 { - self.body().len() as u32 - } - - pub(crate) fn into_body(mut self) -> Vec { - // FIXME: Should we implement this more efficiently with `BytesMut`? I think that one would allow us to split of the body without allocating again .. - self.buffer.split_off(HEADER_SIZE) - } - - pub(crate) fn body_mut(&mut self) -> &mut [u8] { - &mut self.buffer[HEADER_SIZE..] - } - /// Introduce this frame to the right of a binary frame type. pub(crate) fn right(self) -> Frame> { Frame { @@ -147,6 +130,15 @@ impl Frame<()> { } impl Frame { + pub fn data(id: StreamId, b: &[u8]) -> Result { + let header = Header::data(id, b.len().try_into()?); + + let mut frame = Frame::new(header); + frame.body_mut().copy_from_slice(b); + + Ok(frame) + } + pub fn new(header: Header) -> Self { let total_buffer_size = HEADER_SIZE + header.len().val() as usize; @@ -161,16 +153,6 @@ impl Frame { } } - pub fn data(id: StreamId, body: &[u8]) -> Result { - let header = Header::data(id, body.len().try_into()?); - - let mut frame = Frame::new(header); - - frame.body_mut().copy_from_slice(body); - - Ok(frame) - } - pub fn close_stream(id: StreamId, ack: bool) -> Self { let mut header = Header::data(id, 0); header.fin(); @@ -180,10 +162,29 @@ impl Frame { Frame::new(header) } + + + pub(crate) fn body(&self) -> &[u8] { + &self.buffer[HEADER_SIZE..] + } + + + pub(crate) fn body_mut(&mut self) -> &mut [u8] { + &mut self.buffer[HEADER_SIZE..] + } + + pub(crate) fn body_len(&self) -> u32 { + self.body().len() as u32 + } + + pub(crate) fn into_body(mut self) -> Vec { + // FIXME: Should we implement this more efficiently with `BytesMut`? I think that one would allow us to split of the body without allocating again .. + self.buffer.split_off(HEADER_SIZE) + } } impl Frame { - pub fn window_update(id: StreamId, credit: u32) -> Frame { + pub fn window_update(id: StreamId, credit: u32) -> Self { Frame::no_body(Header::window_update(id, credit)) } } diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 41f78c13..3dc0f012 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -295,7 +295,7 @@ mod tests { let mut io = Io::new( id, futures::io::Cursor::new(Vec::new()), - f.body_len() as usize, + f.buffer.len(), ); if io.send(f.clone()).await.is_err() { return false; From 0108a3d8f1498642c045de1568686707c03a38dd Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:05:01 +1100 Subject: [PATCH 19/29] Reduce diff --- yamux/src/frame.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 62e1bc34..a6165241 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -163,21 +163,19 @@ impl Frame { Frame::new(header) } - - pub(crate) fn body(&self) -> &[u8] { + pub fn body(&self) -> &[u8] { &self.buffer[HEADER_SIZE..] } - - pub(crate) fn body_mut(&mut self) -> &mut [u8] { + pub fn body_mut(&mut self) -> &mut [u8] { &mut self.buffer[HEADER_SIZE..] } - pub(crate) fn body_len(&self) -> u32 { + pub fn body_len(&self) -> u32 { self.body().len() as u32 } - pub(crate) fn into_body(mut self) -> Vec { + pub fn into_body(mut self) -> Vec { // FIXME: Should we implement this more efficiently with `BytesMut`? I think that one would allow us to split of the body without allocating again .. self.buffer.split_off(HEADER_SIZE) } From 8d32d16a918b0f19c139b6d56eb69f4347c9a37f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:06:56 +1100 Subject: [PATCH 20/29] Resolve todo --- yamux/src/frame/io.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 3dc0f012..47a359ed 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -54,7 +54,14 @@ impl fmt::Debug for WriteState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { WriteState::Init => f.write_str("(WriteState::Init)"), - _ => todo!(), + WriteState::Writing { offset, frame } => { + write!( + f, + "(WriteState::Writing (offset {}) (buffer-len {}))", + offset, + frame.len() + ) + } } } } From 851341e96d642f6848ae285d576c40a4495aa62f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:08:37 +1100 Subject: [PATCH 21/29] Reduce diff --- yamux/src/frame.rs | 4 ++-- yamux/src/frame/io.rs | 23 +++++++++-------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index a6165241..50f6e464 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -95,9 +95,9 @@ impl From> for Frame<()> { impl Frame<()> { pub(crate) fn try_from_header_buffer( - buffer: [u8; HEADER_SIZE], + buffer: &[u8; HEADER_SIZE], ) -> Result, Frame>, FrameDecodeError> { - let header = header::decode(&buffer)?; + let header = header::decode(buffer)?; let either = match header.try_into_data() { Ok(data) => Either::Right(Frame::new(data)), diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 47a359ed..fd2809b7 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -59,7 +59,7 @@ impl fmt::Debug for WriteState { f, "(WriteState::Writing (offset {}) (buffer-len {}))", offset, - frame.len() + frame.buffer().len() ) } } @@ -141,7 +141,10 @@ impl Stream for Io { log::trace!("{}: read: {:?}", this.id, this.read_state); match &mut this.read_state { - ReadState::Header { offset, mut buffer } => { + ReadState::Header { + offset, + ref mut buffer, + } => { if *offset == header::HEADER_SIZE { let frame = Frame::try_from_header_buffer(buffer)?; @@ -167,7 +170,8 @@ impl Stream for Io { continue; } - match ready!(Pin::new(&mut this.io).poll_read(cx, &mut buffer[*offset..]))? { + let buf = &mut buffer[*offset..header::HEADER_SIZE]; + match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { 0 => { if *offset == 0 { return Poll::Ready(None); @@ -175,12 +179,7 @@ impl Stream for Io { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(e))); } - n => { - this.read_state = ReadState::Header { - buffer, - offset: *offset + n, - }; - } + n => *offset += n, } } ReadState::Body { @@ -299,11 +298,7 @@ mod tests { fn property(f: Frame<()>) -> bool { futures::executor::block_on(async move { let id = crate::connection::Id::random(); - let mut io = Io::new( - id, - futures::io::Cursor::new(Vec::new()), - f.buffer.len(), - ); + let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.buffer.len()); if io.send(f.clone()).await.is_err() { return false; } From 9b264096f3b5d285eea04b7f5f450854204cfb8c Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:14:05 +1100 Subject: [PATCH 22/29] Reduce diff --- yamux/src/frame/io.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index fd2809b7..c0e32d4e 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -140,9 +140,9 @@ impl Stream for Io { loop { log::trace!("{}: read: {:?}", this.id, this.read_state); - match &mut this.read_state { + match this.read_state { ReadState::Header { - offset, + ref mut offset, ref mut buffer, } => { if *offset == header::HEADER_SIZE { @@ -183,7 +183,7 @@ impl Stream for Io { } } ReadState::Body { - offset, + ref mut offset, ref mut frame, } => { let body_len = frame.header().len().val() as usize; From 3df462d47c6da793c0d3705a81ef9f63cbd31f1f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:15:55 +1100 Subject: [PATCH 23/29] Simplify things a bit further --- yamux/src/frame/io.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index c0e32d4e..7e48e907 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -139,7 +139,6 @@ impl Stream for Io { let this = &mut *self; loop { log::trace!("{}: read: {:?}", this.id, this.read_state); - match this.read_state { ReadState::Header { ref mut offset, @@ -158,7 +157,7 @@ impl Stream for Io { } }; - let body_len = frame.header().len().val() as usize; + let body_len = frame.body_len() as usize; if body_len > this.max_body_len { return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( @@ -186,7 +185,7 @@ impl Stream for Io { ref mut offset, ref mut frame, } => { - let body_len = frame.header().len().val() as usize; + let body_len = frame.body_len() as usize; if *offset == body_len { let frame = match mem::replace(&mut self.read_state, ReadState::header()) { From 09cda48c8fafbc3f4036e5be9085c278afe8d43f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:16:52 +1100 Subject: [PATCH 24/29] Use body_len in debug impl --- yamux/src/frame/io.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 7e48e907..0d0352a7 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -221,7 +221,7 @@ impl fmt::Debug for ReadState { "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", frame.header(), offset, - frame.header().len().val() + frame.body_len() ) } } From 5e3e65b1bcee595669042b9770f5877c62696fb6 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:21:03 +1100 Subject: [PATCH 25/29] Remove generic length accessor --- yamux/src/frame.rs | 2 +- yamux/src/frame/header.rs | 17 ++++++++++------- yamux/src/frame/io.rs | 7 ++++--- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 50f6e464..5808d659 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -140,7 +140,7 @@ impl Frame { } pub fn new(header: Header) -> Self { - let total_buffer_size = HEADER_SIZE + header.len().val() as usize; + let total_buffer_size = HEADER_SIZE + header.body_len(); let mut buffer = vec![0; total_buffer_size]; header diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index b86b46d4..8e07d703 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -71,14 +71,9 @@ impl Header { self.stream_id } - // FIXME: This should not be a generic accessor because its semantic depends on the type. - pub fn len(&self) -> Len { - self.length - } - #[cfg(test)] - pub fn set_len(&mut self, len: u32) { - self.length = Len(len.into()); + pub fn set_len(&mut self, len: usize) { + self.length = Len((len as u32).into()); } /// Arbitrary type cast, use with caution. @@ -159,6 +154,14 @@ impl Header { _marker: std::marker::PhantomData, } } + + /// Returns the length of the body. + /// + /// The `length` field in the header has a different semantic meaning depending on the tag. + /// For [`Tag::Data`], it describes the length of the body. + pub fn body_len(&self) -> usize { + self.length.val() as usize + } } impl Header { diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 0d0352a7..479677fb 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -280,10 +280,11 @@ mod tests { impl Arbitrary for Frame<()> { fn arbitrary(g: &mut Gen) -> Self { - let mut header: header::Header<()> = Arbitrary::arbitrary(g); + let header: header::Header<()> = Arbitrary::arbitrary(g); if header.tag() == header::Tag::Data { - header.set_len(header.len().val() % 4096); - let mut frame = Frame::new(header.into_data()); + let mut header = header.into_data(); + header.set_len(header.body_len() % 4096); + let mut frame = Frame::new(header); rand::thread_rng().fill_bytes(frame.body_mut()); frame.into_generic_frame() } else { From ab2966436e09e69269d560cb648fba3d40efa42f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 6 Oct 2023 12:22:23 +1100 Subject: [PATCH 26/29] Reduce diff --- yamux/src/frame/header.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index 8e07d703..234dee4a 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -87,10 +87,6 @@ impl Header { _marker: std::marker::PhantomData, } } - - pub(crate) fn is_data(&self) -> bool { - self.tag == Tag::Data as u8 - } } impl From> for Header<()> { @@ -101,7 +97,7 @@ impl From> for Header<()> { impl Header<()> { pub(crate) fn into_data(self) -> Header { - debug_assert!(self.is_data()); + debug_assert_eq!(self.tag, Tag::Data as u8); self.cast() } From b65f2bdeebd7fb2ed7068ca9e8e8885e1798f138 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 9 Oct 2023 13:21:06 +1100 Subject: [PATCH 27/29] Ensure we check max body len before allocating --- yamux/src/frame.rs | 4 ++++ yamux/src/frame/io.rs | 10 +--------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 5808d659..a7fa0e0c 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -96,10 +96,14 @@ impl From> for Frame<()> { impl Frame<()> { pub(crate) fn try_from_header_buffer( buffer: &[u8; HEADER_SIZE], + max_body_len: usize, ) -> Result, Frame>, FrameDecodeError> { let header = header::decode(buffer)?; let either = match header.try_into_data() { + Ok(data) if data.body_len() > max_body_len => { + return Err(FrameDecodeError::FrameTooLarge(data.body_len())); + } Ok(data) => Either::Right(Frame::new(data)), Err(other) => Either::Left(Frame::no_body(other)), }; diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 479677fb..b92bf639 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -145,7 +145,7 @@ impl Stream for Io { ref mut buffer, } => { if *offset == header::HEADER_SIZE { - let frame = Frame::try_from_header_buffer(buffer)?; + let frame = Frame::try_from_header_buffer(buffer, this.max_body_len)?; log::trace!("{}: read: {:?}", this.id, frame); @@ -157,14 +157,6 @@ impl Stream for Io { } }; - let body_len = frame.body_len() as usize; - - if body_len > this.max_body_len { - return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( - body_len, - )))); - } - this.read_state = ReadState::Body { frame, offset: 0 }; continue; } From ee9c9203760de186fb08c179106ab51d490a3d44 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 9 Oct 2023 13:57:25 +1100 Subject: [PATCH 28/29] Don't allocate unless necessary --- yamux/Cargo.toml | 1 + yamux/src/connection.rs | 4 ++-- yamux/src/frame.rs | 19 ++++++++++--------- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/yamux/Cargo.toml b/yamux/Cargo.toml index 5c01fafa..7f500d8a 100644 --- a/yamux/Cargo.toml +++ b/yamux/Cargo.toml @@ -18,6 +18,7 @@ rand = "0.8.3" static_assertions = "1" pin-project = "1.1.0" zerocopy = { version = "0.7.0", features = ["derive"] } +bytes = "1.5.0" [dev-dependencies] quickcheck = "1.0" diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 9617f2de..277a9f9a 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -722,7 +722,7 @@ impl Active { shared.update_state(self.id, stream_id, State::RecvClosed); } shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); + shared.buffer.push(frame.into_body().into()); if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { if let Some(credit) = shared.next_window_update() { @@ -765,7 +765,7 @@ impl Active { return Action::Reset(Frame::new(header)); } shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); + shared.buffer.push(frame.into_body().into()); if let Some(w) = shared.reader.take() { w.wake() } diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index a7fa0e0c..65226a1d 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -11,6 +11,7 @@ pub mod header; mod io; +use bytes::{Buf, Bytes, BytesMut}; use futures::future::Either; use header::{Data, GoAway, Header, Ping, StreamId, WindowUpdate}; use std::{convert::TryInto, fmt::Debug, marker::PhantomData, num::TryFromIntError}; @@ -25,13 +26,13 @@ use self::header::HEADER_SIZE; /// The header can be zerocopy parsed into a Header struct by calling header()/header_mut(). #[derive(Clone, Debug, Eq, PartialEq)] pub struct Frame { - buffer: Vec, + buffer: BytesMut, _marker: PhantomData, } impl Frame { pub(crate) fn no_body(header: Header) -> Self { - let mut buffer = vec![0; HEADER_SIZE]; + let mut buffer = BytesMut::zeroed(HEADER_SIZE); header .write_to(&mut buffer) .expect("buffer is size of header"); @@ -43,21 +44,21 @@ impl Frame { } pub fn header(&self) -> &Header { - Ref::<_, Header>::new_from_prefix(self.buffer.as_slice()) + Ref::<_, Header>::new_from_prefix(self.buffer.as_ref()) .expect("buffer always holds a valid header") .0 .into_ref() } pub fn header_mut(&mut self) -> &mut Header { - Ref::<_, Header>::new_from_prefix(self.buffer.as_mut_slice()) + Ref::<_, Header>::new_from_prefix(self.buffer.as_mut()) .expect("buffer always holds a valid header") .0 .into_mut() } pub(crate) fn buffer(&self) -> &[u8] { - self.buffer.as_slice() + self.buffer.as_ref() } /// Introduce this frame to the right of a binary frame type. @@ -146,7 +147,7 @@ impl Frame { pub fn new(header: Header) -> Self { let total_buffer_size = HEADER_SIZE + header.body_len(); - let mut buffer = vec![0; total_buffer_size]; + let mut buffer = BytesMut::zeroed(total_buffer_size); header .write_to_prefix(&mut buffer) .expect("buffer always fits the header"); @@ -179,9 +180,9 @@ impl Frame { self.body().len() as u32 } - pub fn into_body(mut self) -> Vec { - // FIXME: Should we implement this more efficiently with `BytesMut`? I think that one would allow us to split of the body without allocating again .. - self.buffer.split_off(HEADER_SIZE) + pub fn into_body(mut self) -> Bytes { + self.buffer.advance(HEADER_SIZE); + self.buffer.freeze() } } From ae3bc3d381a17475f6604a80da3c6d55bbf00628 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sun, 29 Oct 2023 20:15:43 +1100 Subject: [PATCH 29/29] WIP: Use `AsyncWrite::poll_write_vectored` --- yamux/Cargo.toml | 3 - yamux/src/connection.rs | 2 +- yamux/src/connection/stream.rs | 2 +- yamux/src/frame.rs | 147 ++++++++++-------------- yamux/src/frame/header.rs | 199 ++++++++++++++++----------------- yamux/src/frame/io.rs | 41 ++++--- 6 files changed, 186 insertions(+), 208 deletions(-) diff --git a/yamux/Cargo.toml b/yamux/Cargo.toml index 7f500d8a..431bd6c7 100644 --- a/yamux/Cargo.toml +++ b/yamux/Cargo.toml @@ -17,9 +17,6 @@ parking_lot = "0.12" rand = "0.8.3" static_assertions = "1" pin-project = "1.1.0" -zerocopy = { version = "0.7.0", features = ["derive"] } -bytes = "1.5.0" [dev-dependencies] quickcheck = "1.0" -futures = "0.3.4" diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 277a9f9a..a0903cb7 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -880,7 +880,7 @@ impl Active { if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) { let mut hdr = Header::ping(frame.header().nonce()); hdr.ack(); - return Action::Ping(Frame::no_body(hdr)); + return Action::Ping(Frame::new(hdr)); } log::trace!( "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}", diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index 9f17ca88..ae745577 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -367,7 +367,7 @@ impl AsyncWrite for Stream { let k = std::cmp::min(shared.credit as usize, buf.len()); let k = std::cmp::min(k, self.config.split_send_size); shared.credit = shared.credit.saturating_sub(k as u32); - &buf[..k] + Vec::from(&buf[..k]) }; let n = body.len(); let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 65226a1d..1f8bace5 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -11,85 +11,66 @@ pub mod header; mod io; -use bytes::{Buf, Bytes, BytesMut}; use futures::future::Either; -use header::{Data, GoAway, Header, Ping, StreamId, WindowUpdate}; -use std::{convert::TryInto, fmt::Debug, marker::PhantomData, num::TryFromIntError}; -use zerocopy::{AsBytes, Ref}; +use header::{Data, GoAway, Header, StreamId, WindowUpdate}; +use std::{convert::TryInto, num::TryFromIntError}; pub use io::FrameDecodeError; pub(crate) use io::Io; +use crate::frame::header::{HEADER_SIZE, Ping}; -use self::header::HEADER_SIZE; - -/// A Yamux message frame consisting of a single buffer with header followed by body. -/// The header can be zerocopy parsed into a Header struct by calling header()/header_mut(). -#[derive(Clone, Debug, Eq, PartialEq)] +/// A Yamux message frame consisting of header and body. +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Frame { - buffer: BytesMut, - _marker: PhantomData, + header: Header, + body: Vec, } impl Frame { - pub(crate) fn no_body(header: Header) -> Self { - let mut buffer = BytesMut::zeroed(HEADER_SIZE); - header - .write_to(&mut buffer) - .expect("buffer is size of header"); - - Self { - buffer, - _marker: PhantomData, + pub fn new(header: Header) -> Self { + Frame { + header, + body: Vec::new(), } } - pub fn header(&self) -> &Header { - Ref::<_, Header>::new_from_prefix(self.buffer.as_ref()) - .expect("buffer always holds a valid header") - .0 - .into_ref() + &self.header } pub fn header_mut(&mut self) -> &mut Header { - Ref::<_, Header>::new_from_prefix(self.buffer.as_mut()) - .expect("buffer always holds a valid header") - .0 - .into_mut() - } - - pub(crate) fn buffer(&self) -> &[u8] { - self.buffer.as_ref() + &mut self.header } /// Introduce this frame to the right of a binary frame type. pub(crate) fn right(self) -> Frame> { Frame { - buffer: self.buffer, - _marker: PhantomData, + header: self.header.right(), + body: self.body, } } /// Introduce this frame to the left of a binary frame type. pub(crate) fn left(self) -> Frame> { Frame { - buffer: self.buffer, - _marker: PhantomData, + header: self.header.left(), + body: self.body, } } - pub(crate) fn into_generic_frame(self) -> Frame<()> { - Frame { - buffer: self.buffer, - _marker: PhantomData, - } + pub(crate) fn body(&self) -> &[u8] { + self.body.as_slice() + } + + pub(crate) fn len(&self) -> usize { + self.body.len() + HEADER_SIZE } } impl From> for Frame<()> { fn from(f: Frame) -> Frame<()> { Frame { - buffer: f.buffer, - _marker: PhantomData, + header: f.header.into(), + body: f.body, } } } @@ -106,7 +87,7 @@ impl Frame<()> { return Err(FrameDecodeError::FrameTooLarge(data.body_len())); } Ok(data) => Either::Right(Frame::new(data)), - Err(other) => Either::Left(Frame::no_body(other)), + Err(other) => Either::Left(Frame::new(other)), }; Ok(either) @@ -114,48 +95,32 @@ impl Frame<()> { pub(crate) fn into_data(self) -> Frame { Frame { - buffer: self.buffer, - _marker: PhantomData, + header: self.header.into_data(), + body: self.body, } } pub(crate) fn into_window_update(self) -> Frame { Frame { - buffer: self.buffer, - _marker: PhantomData, + header: self.header.into_window_update(), + body: self.body, } } pub(crate) fn into_ping(self) -> Frame { Frame { - buffer: self.buffer, - _marker: PhantomData, + header: self.header.into_ping(), + body: self.body, } } } impl Frame { - pub fn data(id: StreamId, b: &[u8]) -> Result { - let header = Header::data(id, b.len().try_into()?); - - let mut frame = Frame::new(header); - frame.body_mut().copy_from_slice(b); - - Ok(frame) - } - - pub fn new(header: Header) -> Self { - let total_buffer_size = HEADER_SIZE + header.body_len(); - - let mut buffer = BytesMut::zeroed(total_buffer_size); - header - .write_to_prefix(&mut buffer) - .expect("buffer always fits the header"); - - Self { - buffer, - _marker: PhantomData, - } + pub fn data(id: StreamId, b: Vec) -> Result { + Ok(Frame { + header: Header::data(id, b.len().try_into()?), + body: b, + }) } pub fn close_stream(id: StreamId, ack: bool) -> Self { @@ -168,40 +133,52 @@ impl Frame { Frame::new(header) } - pub fn body(&self) -> &[u8] { - &self.buffer[HEADER_SIZE..] - } - pub fn body_mut(&mut self) -> &mut [u8] { - &mut self.buffer[HEADER_SIZE..] + self.body.as_mut_slice() } pub fn body_len(&self) -> u32 { - self.body().len() as u32 + let len_in_header = self.header.body_len(); + let actual_len = self.body.len(); + + // debug_assert_eq!(len_in_header, actual_len); + + len_in_header as u32 } - pub fn into_body(mut self) -> Bytes { - self.buffer.advance(HEADER_SIZE); - self.buffer.freeze() + pub fn into_body(self) -> Vec { + self.body } } impl Frame { pub fn window_update(id: StreamId, credit: u32) -> Self { - Frame::no_body(Header::window_update(id, credit)) + Frame { + header: Header::window_update(id, credit), + body: Vec::new(), + } } } impl Frame { pub fn term() -> Self { - Frame::::no_body(Header::term()) + Frame { + header: Header::term(), + body: Vec::new(), + } } pub fn protocol_error() -> Self { - Frame::::no_body(Header::protocol_error()) + Frame { + header: Header::protocol_error(), + body: Vec::new(), + } } pub fn internal_error() -> Self { - Frame::::no_body(Header::internal_error()) + Frame { + header: Header::internal_error(), + body: Vec::new(), + } } } diff --git a/yamux/src/frame/header.rs b/yamux/src/frame/header.rs index 234dee4a..4b4fa0c5 100644 --- a/yamux/src/frame/header.rs +++ b/yamux/src/frame/header.rs @@ -10,35 +10,18 @@ use futures::future::Either; use std::fmt; -use std::hash::{Hash, Hasher}; -use zerocopy::big_endian::{U16, U32}; -use zerocopy::{AsBytes, FromBytes, FromZeroes}; /// The message frame header. -#[derive(Clone, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] -#[repr(packed)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Header { version: Version, - tag: u8, + tag: Tag, flags: Flags, stream_id: StreamId, length: Len, _marker: std::marker::PhantomData, } -impl fmt::Debug for Header { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Header") - .field("version", &self.version) - .field("tag", &self.tag) - .field("flags", &self.flags) - .field("stream_id", &self.stream_id) - .field("length", &self.length) - .field("_marker", &self._marker) - .finish() - } -} - impl fmt::Display for Header { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -54,13 +37,7 @@ impl fmt::Display for Header { impl Header { pub fn tag(&self) -> Tag { - match self.tag { - 0 => Tag::Data, - 1 => Tag::WindowUpdate, - 2 => Tag::Ping, - 3 => Tag::GoAway, - _ => unreachable!("header always has valid tag"), // TODO: Fix this once `zerocopy` has `TryFromBytes` - } + self.tag } pub fn flags(&self) -> Flags { @@ -73,7 +50,7 @@ impl Header { #[cfg(test)] pub fn set_len(&mut self, len: usize) { - self.length = Len((len as u32).into()); + self.length = Len(len as u32) } /// Arbitrary type cast, use with caution. @@ -87,6 +64,16 @@ impl Header { _marker: std::marker::PhantomData, } } + + /// Introduce this header to the right of a binary header type. + pub(crate) fn right(self) -> Header> { + self.cast() + } + + /// Introduce this header to the left of a binary header type. + pub(crate) fn left(self) -> Header> { + self.cast() + } } impl From> for Header<()> { @@ -97,44 +84,55 @@ impl From> for Header<()> { impl Header<()> { pub(crate) fn into_data(self) -> Header { - debug_assert_eq!(self.tag, Tag::Data as u8); + debug_assert_eq!(self.tag, Tag::Data); self.cast() } pub(crate) fn try_into_data(self) -> Result, Self> { - if self.tag == Tag::Data as u8 { + if self.tag == Tag::Data { return Ok(self.into_data()); } Err(self) } + + pub(crate) fn into_window_update(self) -> Header { + debug_assert_eq!(self.tag, Tag::WindowUpdate); + self.cast() + } + + pub(crate) fn into_ping(self) -> Header { + debug_assert_eq!(self.tag, Tag::Ping); + self.cast() + } + } impl Header { /// Set the [`SYN`] flag. pub fn syn(&mut self) { - self.flags.0.set(self.flags.val() | SYN.0.get()) + self.flags.0 |= SYN.0 } } impl Header { /// Set the [`ACK`] flag. pub fn ack(&mut self) { - self.flags.0.set(self.flags.val() | ACK.0.get()); + self.flags.0 |= ACK.0 } } impl Header { /// Set the [`FIN`] flag. pub fn fin(&mut self) { - self.flags.0.set(self.flags.val() | FIN.0.get()); + self.flags.0 |= FIN.0 } } impl Header { /// Set the [`RST`] flag. pub fn rst(&mut self) { - self.flags.0.set(self.flags.val() | RST.0.get()); + self.flags.0 |= RST.0 } } @@ -143,10 +141,10 @@ impl Header { pub fn data(id: StreamId, len: u32) -> Self { Header { version: Version(0), - tag: Tag::Data as u8, - flags: Flags(U16::new(0)), + tag: Tag::Data, + flags: Flags(0), stream_id: id, - length: Len(len.into()), + length: Len(len), _marker: std::marker::PhantomData, } } @@ -165,17 +163,17 @@ impl Header { pub fn window_update(id: StreamId, credit: u32) -> Self { Header { version: Version(0), - tag: Tag::WindowUpdate as u8, - flags: Flags(U16::new(0)), + tag: Tag::WindowUpdate, + flags: Flags(0), stream_id: id, - length: Len(credit.into()), + length: Len(credit), _marker: std::marker::PhantomData, } } /// The credit this window update grants to the remote. pub fn credit(&self) -> u32 { - self.length.0.into() + self.length.0 } } @@ -184,17 +182,17 @@ impl Header { pub fn ping(nonce: u32) -> Self { Header { version: Version(0), - tag: Tag::Ping as u8, - flags: Flags(U16::new(0)), - stream_id: CONNECTION_ID, - length: Len(nonce.into()), + tag: Tag::Ping, + flags: Flags(0), + stream_id: StreamId(0), + length: Len(nonce), _marker: std::marker::PhantomData, } } /// The nonce of this ping. pub fn nonce(&self) -> u32 { - self.length.0.into() + self.length.0 } } @@ -217,10 +215,10 @@ impl Header { fn go_away(code: u32) -> Self { Header { version: Version(0), - tag: Tag::GoAway as u8, - flags: Flags(U16::new(0)), - stream_id: CONNECTION_ID, - length: Len(code.into()), + tag: Tag::GoAway, + flags: Flags(0), + stream_id: StreamId(0), + length: Len(code), _marker: std::marker::PhantomData, } } @@ -278,52 +276,42 @@ pub(super) mod private { /// A tag is the runtime representation of a message type. #[derive(Copy, Clone, Debug, PartialEq, Eq)] -#[repr(u8)] pub enum Tag { - Data = 0, - WindowUpdate = 1, - Ping = 2, - GoAway = 3, + Data, + WindowUpdate, + Ping, + GoAway, } /// The protocol version a message corresponds to. -#[derive(Copy, Clone, Debug, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] -#[repr(packed)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct Version(u8); /// The message length. -#[derive(Copy, Clone, Debug, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] -#[repr(packed)] -pub struct Len(U32); +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Len(u32); impl Len { pub fn val(self) -> u32 { - self.0.into() + self.0 } } -pub const CONNECTION_ID: StreamId = StreamId(U32::ZERO); +pub const CONNECTION_ID: StreamId = StreamId(0); /// The ID of a stream. /// /// The value 0 denotes no particular stream but the whole session. -#[derive(Copy, Clone, Debug, Eq, PartialEq, FromBytes, AsBytes, FromZeroes)] -#[repr(packed)] -pub struct StreamId(U32); - -impl Hash for StreamId { - fn hash(&self, state: &mut H) { - self.0.get().hash(state) - } -} +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct StreamId(u32); impl StreamId { pub(crate) fn new(val: u32) -> Self { - StreamId(U32::new(val)) + StreamId(val) } pub fn is_server(self) -> bool { - self.0.get() % 2 == 0 + self.0 % 2 == 0 } pub fn is_client(self) -> bool { @@ -335,7 +323,7 @@ impl StreamId { } pub fn val(self) -> u32 { - self.0.get() + self.0 } } @@ -348,47 +336,65 @@ impl fmt::Display for StreamId { impl nohash_hasher::IsEnabled for StreamId {} /// Possible flags set on a message. -#[derive(Copy, Clone, Debug, PartialEq, Eq, FromBytes, AsBytes, FromZeroes)] -#[repr(packed)] -pub struct Flags(U16); +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Flags(u16); impl Flags { pub fn contains(self, other: Flags) -> bool { - self.0.get() & other.0.get() == other.0.get() + self.0 & other.0 == other.0 } pub fn val(self) -> u16 { - self.0.get() + self.0 } } /// Indicates the start of a new stream. -pub const SYN: Flags = Flags(U16::from_bytes([0, 1])); +pub const SYN: Flags = Flags(1); /// Acknowledges the start of a new stream. -pub const ACK: Flags = Flags(U16::from_bytes([0, 2])); +pub const ACK: Flags = Flags(2); /// Indicates the half-closing of a stream. -pub const FIN: Flags = Flags(U16::from_bytes([0, 4])); +pub const FIN: Flags = Flags(4); /// Indicates an immediate stream reset. -pub const RST: Flags = Flags(U16::from_bytes([0, 8])); +pub const RST: Flags = Flags(8); /// The serialised header size in bytes. pub const HEADER_SIZE: usize = 12; -// Decode a [`Header`] value. +/// Encode a [`Header`] value. +pub fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { + let mut buf = [0; HEADER_SIZE]; + buf[0] = hdr.version.0; + buf[1] = hdr.tag as u8; + buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes()); + buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes()); + buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes()); + buf +} + +/// Decode a [`Header`] value. pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { if buf[0] != 0 { return Err(HeaderDecodeError::Version(buf[0])); } - let tag = buf[1]; - if !(0..4).contains(&tag) { - return Err(HeaderDecodeError::Type(tag)); - } - - let hdr = Header::read_from(buf).expect("buffer to be correct size"); + let hdr = Header { + version: Version(buf[0]), + tag: match buf[1] { + 0 => Tag::Data, + 1 => Tag::WindowUpdate, + 2 => Tag::Ping, + 3 => Tag::GoAway, + t => return Err(HeaderDecodeError::Type(t)), + }, + flags: Flags(u16::from_be_bytes([buf[2], buf[3]])), + stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])), + length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])), + _marker: std::marker::PhantomData, + }; Ok(hdr) } @@ -427,22 +433,15 @@ mod tests { Header { version: Version(0), - tag: tag as u8, - flags: Flags(u16::arbitrary(g).into()), - stream_id: StreamId(u32::arbitrary(g).into()), - length: Len(u32::arbitrary(g).into()), + tag, + flags: Flags(Arbitrary::arbitrary(g)), + stream_id: StreamId(Arbitrary::arbitrary(g)), + length: Len(Arbitrary::arbitrary(g)), _marker: std::marker::PhantomData, } } } - /// Encode a [`Header`] value. - fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { - let mut buf = [0; HEADER_SIZE]; - hdr.write_to(&mut buf).expect("buffer to be correct length"); - buf - } - #[test] fn encode_decode_identity() { fn property(hdr: Header<()>) -> bool { diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index b92bf639..b180931c 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -13,7 +13,7 @@ use super::{ Frame, }; use crate::connection::Id; -use crate::frame::header::Data; +use crate::frame::header::{Data, HEADER_SIZE}; use futures::future::Either; use futures::{prelude::*, ready}; use std::{ @@ -21,6 +21,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; +use std::io::{IoSlice}; /// A [`Stream`] and writer of [`Frame`] values. #[derive(Debug)] @@ -59,7 +60,7 @@ impl fmt::Debug for WriteState { f, "(WriteState::Writing (offset {}) (buffer-len {}))", offset, - frame.buffer().len() + frame.len() ) } } @@ -76,18 +77,22 @@ impl Sink> for Io { match &mut this.write_state { WriteState::Init => return Poll::Ready(Ok(())), WriteState::Writing { frame, offset } => { - match ready!(Pin::new(&mut this.io).poll_write(cx, &frame.buffer()[*offset..])) - { - Err(e) => return Poll::Ready(Err(e)), - Ok(n) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *offset += n; - if *offset == frame.buffer().len() { - this.write_state = WriteState::Init; - } - } + let io = Pin::new(&mut this.io); + + let n = if *offset < HEADER_SIZE { + ready!(io.poll_write_vectored(cx, &[IoSlice::new(&header::encode(frame.header())[*offset..]), IoSlice::new(frame.body())]))? + } else { + let body_offset = *offset - HEADER_SIZE; + ready!(io.poll_write_vectored(cx, &[IoSlice::new(&frame.body()[body_offset..])]))? + }; + + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + + *offset += n; + if *offset == frame.len() { + this.write_state = WriteState::Init; } } } @@ -184,7 +189,7 @@ impl Stream for Io { ReadState::Header { .. } => unreachable!("we matched above"), ReadState::Body { frame, .. } => frame, }; - return Poll::Ready(Some(Ok(frame.into_generic_frame()))); + return Poll::Ready(Some(Ok(frame.into()))); } let buf = &mut frame.body_mut()[*offset..]; @@ -278,9 +283,9 @@ mod tests { header.set_len(header.body_len() % 4096); let mut frame = Frame::new(header); rand::thread_rng().fill_bytes(frame.body_mut()); - frame.into_generic_frame() + frame.into() } else { - Frame::no_body(header) + Frame::new(header) } } } @@ -290,7 +295,7 @@ mod tests { fn property(f: Frame<()>) -> bool { futures::executor::block_on(async move { let id = crate::connection::Id::random(); - let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.buffer.len()); + let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.len()); if io.send(f.clone()).await.is_err() { return false; }