Skip to content

Commit

Permalink
fix: handle edge cases in frame read and write
Browse files Browse the repository at this point in the history
  • Loading branch information
sabify committed May 2, 2023
1 parent 5d02960 commit e846a03
Showing 1 changed file with 87 additions and 68 deletions.
155 changes: 87 additions & 68 deletions src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//! LEN is a 16-bit unsigned integer in big endian byte order.
//!
use std::cmp::Ordering;
use std::io::Result;
use std::io::IoSlice;
use std::pin::Pin;
Expand All @@ -28,6 +29,7 @@ macro_rules! ready {
};
}

#[derive(Debug)]
enum State {
Len,
Data(u16),
Expand All @@ -52,6 +54,7 @@ pub struct UotStream<T> {
io: T,
rd: State,
wr: State,
buf: Vec<u8>,
}

impl<T> UotStream<T> {
Expand All @@ -61,7 +64,8 @@ impl<T> UotStream<T> {
Self {
io,
rd: State::new(),
wr: State::new(),
wr: State::Data(0),
buf: vec![],
}
}
}
Expand All @@ -88,42 +92,61 @@ where
let this = self.get_mut();

loop {
if !this.buf.is_empty() {
buf.put_slice(this.buf.as_slice());
};

let mut total = buf.filled().len();
if (matches!(this.rd, State::Len) && total < 2)
|| matches!(this.rd, State::Data(len) if total < len as usize)
{
let n = ready!(Pin::new(&mut this.io).poll_read(cx, buf))
.map(|_| buf.filled().len() - total)?;
// EOF
if n == 0 {
return Poll::Ready(Ok(()));
}
total += n;
}

// make it immutable
let total = total;

// we can safely clear the buffer
this.buf.clear();

match this.rd {
State::Len => {
let mut len_be = [0u8; 2];
let mut len_be_buf = ReadBuf::new(&mut len_be);
let mut total = 0;

while total < 2 {
let n = ready!(Pin::new(&mut this.io).poll_read(cx, &mut len_be_buf))
.map(|_| len_be_buf.filled().len() - total)?;
// EOF
if n == 0 {
this.rd = State::Fin;
return Poll::Ready(Ok(()));
}
total += n;
if total < 2 {
this.buf.reserve_exact(total);
this.buf.extend(buf.filled());
buf.clear();
continue;
}

this.rd = State::Data(u16::from_be_bytes(len_be));
}
State::Data(len) => {
debug_assert!(len as usize <= buf.remaining());
let mut buf_limit = ReadBuf::new(buf.initialize_unfilled_to(len as usize));
let n = ready!(Pin::new(&mut this.io).poll_read(cx, &mut buf_limit))
.map(|_| buf_limit.filled().len())?;
this.rd =
State::Data(u16::from_be_bytes(buf.filled()[..2].try_into().unwrap()));

buf.advance(n);
if n == 0 {
this.rd = State::Fin;
if total > 2 {
this.buf.reserve_exact(buf.filled().len() - 2);
this.buf.extend(&buf.filled()[2..]);
}
buf.clear();
}
State::Data(len) => match total.cmp(&(len as usize)) {
Ordering::Equal => {
this.rd = State::Len;
return Poll::Ready(Ok(()));
} else if n as u16 == len {
}
Ordering::Less => {}
Ordering::Greater => {
this.buf.reserve_exact(buf.filled()[len as usize..].len());
this.buf.extend(&buf.filled()[len as usize..]);
buf.set_filled(len as usize);
this.rd = State::Len;
return Poll::Ready(Ok(()));
} else {
this.rd = State::Data(len - n as u16);
};
}
}
},
State::Fin => return Poll::Ready(Ok(())),
}
}
Expand All @@ -139,53 +162,49 @@ where

let this = self.get_mut();

match this.wr {
State::Len => {
let len_be = (buf.len() as u16).to_be_bytes();
let mut total = 0;
let mut iovec = &mut [IoSlice::new(&len_be), IoSlice::new(buf)][..];
loop {
let n = ready!(Pin::new(&mut this.io).poll_write_vectored(cx, iovec))?;
total += n;
// write zero
loop {
match this.wr {
State::Len => {
unreachable!();
}
State::Data(cursor) => {
let n = if cursor < 2 {
let len_be = &(buf.len() as u16).to_be_bytes()[cursor as usize..];
let iovec = &mut [IoSlice::new(len_be), IoSlice::new(buf)][..];
ready!(Pin::new(&mut this.io).poll_write_vectored(cx, iovec))?
} else {
ready!(Pin::new(&mut this.io).poll_write(cx, buf))?
};

let written_bytes = if cursor < 2 {
if n + cursor as usize > 2 {
n - (2 - cursor) as usize
} else {
0
}
} else {
n
};

if n == 0 {
// EOF
this.wr = State::Fin;
return Poll::Ready(Ok(0));
}
// write partial len
#[allow(clippy::comparison_chain)]
if total < 2 {
iovec[0] = IoSlice::new(&len_be[total..]);
continue;
} else if total == 2 {
iovec = &mut iovec[1..];
continue;

if written_bytes == buf.len() {
this.wr = State::Data(0);
return Poll::Ready(Ok(written_bytes));
} else {
// write len + data
let write_n = total - 2;
if write_n == buf.len() {
this.wr = State::Len;
} else {
this.wr = State::Data((buf.len() - write_n) as u16);
this.wr = State::Data(n as u16 + cursor);

if written_bytes != 0 {
return Poll::Ready(Ok(written_bytes));
}
return Poll::Ready(Ok(write_n));
}
}
State::Fin => return Poll::Ready(Ok(0)),
}
State::Data(len) => {
let n = ready!(Pin::new(&mut this.io).poll_write(cx, &buf[..len as usize]))?;
if n == 0 {
this.wr = State::Fin;
Poll::Ready(Ok(0))
} else if n < len as usize {
this.wr = State::Data(len - n as u16);
Poll::Ready(Ok(n))
} else {
this.wr = State::Len;
Poll::Ready(Ok(n))
}
}
State::Fin => Poll::Ready(Ok(0)),
}
}

Expand Down Expand Up @@ -290,7 +309,7 @@ mod test {
let mut stream = UotStream::new(SlowStream {
buf: Vec::with_capacity(65535),
rlimit: 0,
wlimit: wlimit,
wlimit,
cursor: 0,
});
for i in 1..=512 {
Expand Down

0 comments on commit e846a03

Please sign in to comment.