Skip to content

Commit

Permalink
fix automatically drop IpstackTcpStream
Browse files Browse the repository at this point in the history
  • Loading branch information
xmh0511 committed Mar 28, 2024
1 parent 774f909 commit 4f0e506
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 15 deletions.
18 changes: 16 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ use log::{error, trace};

use crate::{
packet::IpStackPacketProtocol,
stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport},
stream::{
IpStackStream, IpStackTcpStream, IpStackTcpStreamInner, IpStackUdpStream,
IpStackUnknownTransport,
},
};
mod error;
mod packet;
Expand Down Expand Up @@ -88,6 +91,17 @@ impl IpStack {
{
let (accept_sender, accept_receiver) = mpsc::unbounded_channel::<IpStackStream>();

let (drop_sender, mut drop_receiver) = mpsc::unbounded_channel::<IpStackTcpStreamInner>();
tokio::spawn(async move {
while let Some(mut inner) = drop_receiver.recv().await {
tokio::spawn(async move {
if let Err(e) = inner.shutdown().await {
trace!("fail to drop {e:?}");
}
});
}
});

tokio::spawn(async move {
let mut streams: HashMap<NetworkTuple, UnboundedSender<NetworkPacket>> = HashMap::new();
let mut buffer = [0u8; u16::MAX as usize];
Expand Down Expand Up @@ -116,7 +130,7 @@ impl IpStack {
Vacant(entry) => {
match packet.transport_protocol(){
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){
match IpStackTcpStream::new(drop_sender.clone(),packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){
Ok(stream) => {
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Tcp(stream))?;
Expand Down
1 change: 1 addition & 0 deletions src/stream/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

pub use self::tcp::IpStackTcpStream;
pub(crate) use self::tcp::IpStackTcpStreamInner;
pub use self::udp::IpStackUdpStream;
pub use self::unknown::IpStackUnknownTransport;

Expand Down
140 changes: 127 additions & 13 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl Shutdown {
}

#[derive(Debug)]
pub struct IpStackTcpStream {
pub(crate) struct IpStackTcpStreamInner {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: UnboundedSender<NetworkPacket>,
Expand All @@ -63,18 +63,18 @@ pub struct IpStackTcpStream {
write_notify: Option<Waker>,
}

impl IpStackTcpStream {
impl IpStackTcpStreamInner {
pub(crate) fn new(
src_addr: SocketAddr,
dst_addr: SocketAddr,
tcp: TcpPacket,
pkt_sender: UnboundedSender<NetworkPacket>,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
) -> Result<IpStackTcpStreamInner, IpStackError> {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

let stream = IpStackTcpStream {
let stream = IpStackTcpStreamInner {
src_addr,
dst_addr,
stream_sender,
Expand Down Expand Up @@ -191,16 +191,16 @@ impl IpStackTcpStream {
})
}

pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}
// pub fn local_addr(&self) -> SocketAddr {
// self.src_addr
// }

pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
// pub fn peer_addr(&self) -> SocketAddr {
// self.dst_addr
// }
}

impl AsyncRead for IpStackTcpStream {
impl AsyncRead for IpStackTcpStreamInner {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
Expand Down Expand Up @@ -439,7 +439,7 @@ impl AsyncRead for IpStackTcpStream {
}
}

impl AsyncWrite for IpStackTcpStream {
impl AsyncWrite for IpStackTcpStreamInner {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
Expand Down Expand Up @@ -528,10 +528,124 @@ impl AsyncWrite for IpStackTcpStream {
}
}

impl Drop for IpStackTcpStream {
impl Drop for IpStackTcpStreamInner {
fn drop(&mut self) {
if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) {
_ = self.packet_sender.send(p);
}
}
}

#[derive(Debug)]
pub struct IpStackTcpStream {
inner: Option<IpStackTcpStreamInner>,
drop_sender: UnboundedSender<IpStackTcpStreamInner>,
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: UnboundedSender<NetworkPacket>,
}

impl IpStackTcpStream {
pub(crate) fn new(
drop_sender: UnboundedSender<IpStackTcpStreamInner>,
src_addr: SocketAddr,
dst_addr: SocketAddr,
tcp: TcpPacket,
pkt_sender: UnboundedSender<NetworkPacket>,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
let stream =
IpStackTcpStreamInner::new(src_addr, dst_addr, tcp, pkt_sender, mtu, tcp_timeout)?;
Ok(IpStackTcpStream {
stream_sender: stream.stream_sender(),
inner: Some(stream),
drop_sender,
src_addr,
dst_addr,
})
}
pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}

pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}

pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
}

impl AsyncRead for IpStackTcpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if let Some(inner) = &mut self.inner {
Pin::new(inner).poll_read(cx, buf)
} else {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"",
)))
}
}
}

impl AsyncWrite for IpStackTcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
if let Some(inner) = &mut self.inner {
Pin::new(inner).poll_write(cx, buf)
} else {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"",
)))
}
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
if let Some(inner) = &mut self.inner {
Pin::new(inner).poll_flush(cx)
} else {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"",
)))
}
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
if let Some(inner) = &mut self.inner {
Pin::new(inner).poll_shutdown(cx)
} else {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"",
)))
}
}
}

impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
if let Err(e) = self.drop_sender.send(inner) {
trace!("fail to send IpStackTcpStreamInner to drop {:?}", e);
}
}
}
}

0 comments on commit 4f0e506

Please sign in to comment.