From 4f0e506a10896881d779c3c70d3d773adb6ff26e Mon Sep 17 00:00:00 2001 From: xmh0511 <970252187@qq.com> Date: Thu, 28 Mar 2024 16:44:08 +0800 Subject: [PATCH] fix automatically drop IpstackTcpStream --- src/lib.rs | 18 +++++- src/stream/mod.rs | 1 + src/stream/tcp.rs | 140 +++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 144 insertions(+), 15 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3c76e04..684c36e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,10 @@ use log::{error, trace}; use crate::{ packet::IpStackPacketProtocol, - stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, + stream::{ + IpStackStream, IpStackTcpStream, IpStackTcpStreamInner, IpStackUdpStream, + IpStackUnknownTransport, + }, }; mod error; mod packet; @@ -88,6 +91,17 @@ impl IpStack { { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); + let (drop_sender, mut drop_receiver) = mpsc::unbounded_channel::(); + tokio::spawn(async move { + while let Some(mut inner) = drop_receiver.recv().await { + tokio::spawn(async move { + if let Err(e) = inner.shutdown().await { + trace!("fail to drop {e:?}"); + } + }); + } + }); + tokio::spawn(async move { let mut streams: HashMap> = HashMap::new(); let mut buffer = [0u8; u16::MAX as usize]; @@ -116,7 +130,7 @@ impl IpStack { Vacant(entry) => { match packet.transport_protocol(){ IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){ + match IpStackTcpStream::new(drop_sender.clone(),packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){ Ok(stream) => { entry.insert(stream.stream_sender()); accept_sender.send(IpStackStream::Tcp(stream))?; diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 4c9bd68..04d7087 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,6 +1,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; pub use self::tcp::IpStackTcpStream; +pub(crate) use self::tcp::IpStackTcpStreamInner; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 8c396ad..fbf7a19 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -50,7 +50,7 @@ impl Shutdown { } #[derive(Debug)] -pub struct IpStackTcpStream { +pub(crate) struct IpStackTcpStreamInner { src_addr: SocketAddr, dst_addr: SocketAddr, stream_sender: UnboundedSender, @@ -63,7 +63,7 @@ pub struct IpStackTcpStream { write_notify: Option, } -impl IpStackTcpStream { +impl IpStackTcpStreamInner { pub(crate) fn new( src_addr: SocketAddr, dst_addr: SocketAddr, @@ -71,10 +71,10 @@ impl IpStackTcpStream { pkt_sender: UnboundedSender, mtu: u16, tcp_timeout: Duration, - ) -> Result { + ) -> Result { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - let stream = IpStackTcpStream { + let stream = IpStackTcpStreamInner { src_addr, dst_addr, stream_sender, @@ -191,16 +191,16 @@ impl IpStackTcpStream { }) } - pub fn local_addr(&self) -> SocketAddr { - self.src_addr - } + // pub fn local_addr(&self) -> SocketAddr { + // self.src_addr + // } - pub fn peer_addr(&self) -> SocketAddr { - self.dst_addr - } + // pub fn peer_addr(&self) -> SocketAddr { + // self.dst_addr + // } } -impl AsyncRead for IpStackTcpStream { +impl AsyncRead for IpStackTcpStreamInner { fn poll_read( mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, @@ -439,7 +439,7 @@ impl AsyncRead for IpStackTcpStream { } } -impl AsyncWrite for IpStackTcpStream { +impl AsyncWrite for IpStackTcpStreamInner { fn poll_write( mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, @@ -528,10 +528,124 @@ impl AsyncWrite for IpStackTcpStream { } } -impl Drop for IpStackTcpStream { +impl Drop for IpStackTcpStreamInner { fn drop(&mut self) { if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { _ = self.packet_sender.send(p); } } } + +#[derive(Debug)] +pub struct IpStackTcpStream { + inner: Option, + drop_sender: UnboundedSender, + src_addr: SocketAddr, + dst_addr: SocketAddr, + stream_sender: UnboundedSender, +} + +impl IpStackTcpStream { + pub(crate) fn new( + drop_sender: UnboundedSender, + src_addr: SocketAddr, + dst_addr: SocketAddr, + tcp: TcpPacket, + pkt_sender: UnboundedSender, + mtu: u16, + tcp_timeout: Duration, + ) -> Result { + let stream = + IpStackTcpStreamInner::new(src_addr, dst_addr, tcp, pkt_sender, mtu, tcp_timeout)?; + Ok(IpStackTcpStream { + stream_sender: stream.stream_sender(), + inner: Some(stream), + drop_sender, + src_addr, + dst_addr, + }) + } + pub(crate) fn stream_sender(&self) -> UnboundedSender { + self.stream_sender.clone() + } + + pub fn local_addr(&self) -> SocketAddr { + self.src_addr + } + + pub fn peer_addr(&self) -> SocketAddr { + self.dst_addr + } +} + +impl AsyncRead for IpStackTcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_read(cx, buf) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } +} + +impl AsyncWrite for IpStackTcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_write(cx, buf) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_flush(cx) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_shutdown(cx) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } +} + +impl Drop for IpStackTcpStream { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + if let Err(e) = self.drop_sender.send(inner) { + trace!("fail to send IpStackTcpStreamInner to drop {:?}", e); + } + } + } +}