From 1b51ac092030f79d57dcdb3c5239c64e3f97891c Mon Sep 17 00:00:00 2001 From: Sajjad Pourali Date: Sun, 5 May 2024 14:34:42 -0400 Subject: [PATCH] Improve memory management --- src/lib.rs | 12 ++++++------ src/stream/tcb.rs | 12 ++++++------ src/stream/tcp.rs | 47 +++++++++++++++-------------------------------- src/stream/udp.rs | 8 ++++---- 4 files changed, 31 insertions(+), 48 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c0659b8..0787e15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -132,7 +132,7 @@ where if let Some(stream) = process_device_read( &buffer[offset..n], &mut sessions, - &pkt_sender, + pkt_sender.clone(), &config, ) { accept_sender.send(stream)?; @@ -156,7 +156,7 @@ where fn process_device_read( data: &[u8], sessions: &mut SessionCollection, - pkt_sender: &PacketSender, + pkt_sender: PacketSender, config: &IpStackConfig, ) -> Option { let Ok(packet) = NetworkPacket::parse(data) else { @@ -171,7 +171,7 @@ fn process_device_read( packet.payload, &packet.ip, config.mtu, - pkt_sender.clone(), + pkt_sender, ), )); } @@ -198,7 +198,7 @@ fn process_device_read( fn create_stream( packet: NetworkPacket, config: &IpStackConfig, - pkt_sender: &PacketSender, + pkt_sender: PacketSender, ) -> Option<(PacketSender, IpStackStream)> { match packet.transport_protocol() { IpStackPacketProtocol::Tcp(h) => { @@ -206,7 +206,7 @@ fn create_stream( packet.src_addr(), packet.dst_addr(), h, - pkt_sender.clone(), + pkt_sender, config.mtu, config.tcp_timeout, ) { @@ -226,7 +226,7 @@ fn create_stream( packet.src_addr(), packet.dst_addr(), packet.payload, - pkt_sender.clone(), + pkt_sender, config.mtu, config.udp_timeout, ); diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index ab071ad..a6f2d37 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -5,7 +5,7 @@ use tokio::time::Sleep; const MAX_UNACK: u32 = 1024 * 16; // 16KB const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, PartialEq)] pub enum TcpState { SynReceived(bool), // bool means if syn/ack is sent Established, @@ -14,7 +14,7 @@ pub enum TcpState { Closed, } -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, PartialEq)] pub(super) enum PacketStatus { WindowUpdate, Invalid, @@ -106,8 +106,8 @@ impl Tcb { pub(super) fn change_state(&mut self, state: TcpState) { self.state = state; } - pub(super) fn get_state(&self) -> TcpState { - self.state.clone() + pub(super) fn get_state(&self) -> &TcpState { + &self.state } pub(super) fn change_send_window(&mut self, window: u16) { let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) @@ -202,7 +202,7 @@ impl Tcb { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct InflightPacket { pub seq: u32, pub payload: Vec, @@ -222,7 +222,7 @@ impl InflightPacket { } } -#[derive(Debug, Clone)] +#[derive(Debug)] struct UnorderedPacket { payload: Vec, // pub recv_time: SystemTime, // todo diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 8e481f7..8f38f29 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -26,11 +26,10 @@ use crate::packet::{IpHeader, NetworkPacket}; use super::tcb::PacketStatus; -#[derive(Debug, Default)] +#[derive(Debug)] enum Shutdown { Ready, Pending(Waker), - #[default] None, } @@ -77,7 +76,7 @@ impl IpStackTcpStream { packet_to_send: None, tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), mtu, - shutdown: Shutdown::default(), + shutdown: Shutdown::None, write_notify: None, }; if tcp.inner().syn { @@ -199,12 +198,12 @@ impl AsyncRead for IpStackTcpStream { .send(packet) .or(Err(ErrorKind::UnexpectedEof))?; } - if self.tcb.get_state() == TcpState::Closed { + if *self.tcb.get_state() == TcpState::Closed { self.shutdown.ready(); return Poll::Ready(Ok(())); } - if self.tcb.get_state() == TcpState::FinWait2(false) { + if *self.tcb.get_state() == TcpState::FinWait2(false) { self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); @@ -226,7 +225,7 @@ impl AsyncRead for IpStackTcpStream { } self.tcb.reset_timeout(); - if self.tcb.get_state() == TcpState::SynReceived(false) { + if *self.tcb.get_state() == TcpState::SynReceived(false) { self.packet_to_send = Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); @@ -246,7 +245,7 @@ impl AsyncRead for IpStackTcpStream { .or(Err(ErrorKind::UnexpectedEof))?; return Poll::Ready(Ok(())); } - if self.tcb.get_state() == TcpState::FinWait1(true) { + if *self.tcb.get_state() == TcpState::FinWait1(true) { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); @@ -254,7 +253,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) - && self.tcb.get_state() == TcpState::Established + && *self.tcb.get_state() == TcpState::Established && self.tcb.get_last_ack() == self.tcb.get_seq() { self.packet_to_send = @@ -279,13 +278,13 @@ impl AsyncRead for IpStackTcpStream { continue; } - if self.tcb.get_state() == TcpState::SynReceived(true) { + if *self.tcb.get_state() == TcpState::SynReceived(true) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.change_send_window(t.inner().window_size); self.tcb.change_state(TcpState::Established); } - } else if self.tcb.get_state() == TcpState::Established { + } else if *self.tcb.get_state() == TcpState::Established { if t.flags() == ACK { match self.tcb.check_pkt_type(&t, &p.payload) { PacketStatus::WindowUpdate => { @@ -327,21 +326,13 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb .add_unordered_packet(t.inner().sequence_number, p.payload); - // buf.put_slice(&p.payload); - // self.tcb.add_ack(p.payload.len() as u32); - // self.packet_to_send = Some(self.create_rev_packet( - // ACK, - // TTL, - // None, - // Vec::new(), - // )?); + self.tcb.change_send_window(t.inner().window_size); if let Some(ref n) = self.write_notify { n.wake_by_ref(); self.write_notify = None; }; continue; - // return Poll::Ready(Ok(())); } PacketStatus::Ack => { self.tcb.change_last_ack(t.inner().acknowledgment_number); @@ -376,21 +367,13 @@ impl AsyncRead for IpStackTcpStream { continue; } - // self.tcb.add_ack(p.payload.len() as u32); self.tcb.change_send_window(t.inner().window_size); - // buf.put_slice(&p.payload); - // self.packet_to_send = Some(self.create_rev_packet( - // ACK, - // TTL, - // None, - // Vec::new(), - // )?); - // return Poll::Ready(Ok(())); + self.tcb .add_unordered_packet(t.inner().sequence_number, p.payload); continue; } - } else if self.tcb.get_state() == TcpState::FinWait1(false) { + } else if *self.tcb.get_state() == TcpState::FinWait1(false) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.add_ack(1); @@ -404,7 +387,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_state(TcpState::FinWait2(true)); continue; } - } else if self.tcb.get_state() == TcpState::FinWait2(true) { + } else if *self.tcb.get_state() == TcpState::FinWait2(true) { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { @@ -427,7 +410,7 @@ impl AsyncWrite for IpStackTcpStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if self.tcb.get_state() != TcpState::Established { + if *self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } self.tcb.reset_timeout(); @@ -462,7 +445,7 @@ impl AsyncWrite for IpStackTcpStream { mut self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { - if self.tcb.get_state() != TcpState::Established { + if *self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } if let Some(i) = self diff --git a/src/stream/udp.rs b/src/stream/udp.rs index c9edc90..ad8086c 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -16,7 +16,7 @@ pub struct IpStackUdpStream { dst_addr: SocketAddr, stream_sender: PacketSender, stream_receiver: PacketReceiver, - packet_sender: PacketSender, + pkt_sender: PacketSender, first_payload: Option>, timeout: Pin>, udp_timeout: Duration, @@ -28,7 +28,7 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - packet_sender: PacketSender, + pkt_sender: PacketSender, mtu: u16, udp_timeout: Duration, ) -> Self { @@ -39,7 +39,7 @@ impl IpStackUdpStream { dst_addr, stream_sender, stream_receiver, - packet_sender, + pkt_sender, first_payload: Some(payload), timeout: Box::pin(tokio::time::sleep_until(deadline)), udp_timeout, @@ -156,7 +156,7 @@ impl AsyncWrite for IpStackUdpStream { self.reset_timeout(); let packet = self.create_rev_packet(TTL, buf.to_vec())?; let payload_len = packet.payload.len(); - self.packet_sender + self.pkt_sender .send(packet) .or(Err(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len))