diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index d928d05..9c2973f 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -12,14 +12,17 @@ use std::{ cmp, future::Future, io::{Error, ErrorKind}, + mem::MaybeUninit, net::SocketAddr, pin::Pin, task::{Context, Poll, Waker}, time::Duration, }; use tokio::{ - io::{AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, + runtime::Handle, sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + task, }; use log::{trace, warn}; @@ -244,6 +247,7 @@ impl AsyncRead for IpStackTcpStream { self.shutdown.ready(); return Poll::Ready(Ok(())); } + continue; } if let Some(b) = self.tcb.get_unordered_packets() { self.tcb.add_ack(b.len() as u32); @@ -265,6 +269,7 @@ impl AsyncRead for IpStackTcpStream { { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait1(false)); continue; } @@ -402,6 +407,7 @@ impl AsyncRead for IpStackTcpStream { } } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) { if t.flags() == ACK { + self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if t.flags() == (FIN | ACK) { @@ -417,6 +423,7 @@ impl AsyncRead for IpStackTcpStream { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { + self.tcb.add_ack(1); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait2(false)); @@ -424,9 +431,7 @@ impl AsyncRead for IpStackTcpStream { } } Poll::Ready(None) => return Poll::Ready(Ok(())), - Poll::Pending => { - return Poll::Pending; - } + Poll::Pending => return Poll::Pending, } } } @@ -509,21 +514,31 @@ impl AsyncWrite for IpStackTcpStream { mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - match &self.shutdown { - Shutdown::Ready => Poll::Ready(Ok(())), - Shutdown::Pending(_) => Poll::Pending, - Shutdown::None => { - self.shutdown.pending(cx.waker().clone()); - Poll::Pending - } + if matches!(self.shutdown, Shutdown::Ready) { + return Poll::Ready(Ok(())); + } else if matches!(self.shutdown, Shutdown::None) { + self.shutdown.pending(cx.waker().clone()); } + self.poll_read( + cx, + &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::::uninit()]), + ) } } impl Drop for IpStackTcpStream { fn drop(&mut self) { - if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { - _ = self.packet_sender.send(p); - } + task::block_in_place(move || { + Handle::current().block_on(async move { + _ = self.shutdown().await; + println!( + "Shudown done, Drop for IpStackTcpStream {:?}", + self.dst_addr + ); + if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { + _ = self.packet_sender.send(p); + } + }); + }); } }