Skip to content

Commit

Permalink
Improve memory management
Browse files Browse the repository at this point in the history
  • Loading branch information
Sajjad Pourali committed May 5, 2024
1 parent f387f70 commit 1b51ac0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 48 deletions.
12 changes: 6 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ where
if let Some(stream) = process_device_read(
&buffer[offset..n],
&mut sessions,
&pkt_sender,
pkt_sender.clone(),
&config,
) {
accept_sender.send(stream)?;
Expand All @@ -156,7 +156,7 @@ where
fn process_device_read(
data: &[u8],
sessions: &mut SessionCollection,
pkt_sender: &PacketSender,
pkt_sender: PacketSender,
config: &IpStackConfig,
) -> Option<IpStackStream> {
let Ok(packet) = NetworkPacket::parse(data) else {
Expand All @@ -171,7 +171,7 @@ fn process_device_read(
packet.payload,
&packet.ip,
config.mtu,
pkt_sender.clone(),
pkt_sender,
),
));
}
Expand All @@ -198,15 +198,15 @@ fn process_device_read(
fn create_stream(
packet: NetworkPacket,
config: &IpStackConfig,
pkt_sender: &PacketSender,
pkt_sender: PacketSender,
) -> Option<(PacketSender, IpStackStream)> {
match packet.transport_protocol() {
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
packet.src_addr(),
packet.dst_addr(),
h,
pkt_sender.clone(),
pkt_sender,
config.mtu,
config.tcp_timeout,
) {
Expand All @@ -226,7 +226,7 @@ fn create_stream(
packet.src_addr(),
packet.dst_addr(),
packet.payload,
pkt_sender.clone(),
pkt_sender,
config.mtu,
config.udp_timeout,
);
Expand Down
12 changes: 6 additions & 6 deletions src/stream/tcb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio::time::Sleep;
const MAX_UNACK: u32 = 1024 * 16; // 16KB
const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB

#[derive(Clone, Debug, PartialEq)]
#[derive(Debug, PartialEq)]
pub enum TcpState {
SynReceived(bool), // bool means if syn/ack is sent
Established,
Expand All @@ -14,7 +14,7 @@ pub enum TcpState {
Closed,
}

#[derive(Clone, Debug, PartialEq)]
#[derive(Debug, PartialEq)]
pub(super) enum PacketStatus {
WindowUpdate,
Invalid,
Expand Down Expand Up @@ -106,8 +106,8 @@ impl Tcb {
pub(super) fn change_state(&mut self, state: TcpState) {
self.state = state;
}
pub(super) fn get_state(&self) -> TcpState {
self.state.clone()
pub(super) fn get_state(&self) -> &TcpState {
&self.state
}
pub(super) fn change_send_window(&mut self, window: u16) {
let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64)
Expand Down Expand Up @@ -202,7 +202,7 @@ impl Tcb {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct InflightPacket {
pub seq: u32,
pub payload: Vec<u8>,
Expand All @@ -222,7 +222,7 @@ impl InflightPacket {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
struct UnorderedPacket {
payload: Vec<u8>,
// pub recv_time: SystemTime, // todo
Expand Down
47 changes: 15 additions & 32 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ use crate::packet::{IpHeader, NetworkPacket};

use super::tcb::PacketStatus;

#[derive(Debug, Default)]
#[derive(Debug)]
enum Shutdown {
Ready,
Pending(Waker),
#[default]
None,
}

Expand Down Expand Up @@ -77,7 +76,7 @@ impl IpStackTcpStream {
packet_to_send: None,
tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout),
mtu,
shutdown: Shutdown::default(),
shutdown: Shutdown::None,
write_notify: None,
};
if tcp.inner().syn {
Expand Down Expand Up @@ -199,12 +198,12 @@ impl AsyncRead for IpStackTcpStream {
.send(packet)
.or(Err(ErrorKind::UnexpectedEof))?;
}
if self.tcb.get_state() == TcpState::Closed {
if *self.tcb.get_state() == TcpState::Closed {
self.shutdown.ready();
return Poll::Ready(Ok(()));
}

if self.tcb.get_state() == TcpState::FinWait2(false) {
if *self.tcb.get_state() == TcpState::FinWait2(false) {
self.packet_to_send =
Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
Expand All @@ -226,7 +225,7 @@ impl AsyncRead for IpStackTcpStream {
}
self.tcb.reset_timeout();

if self.tcb.get_state() == TcpState::SynReceived(false) {
if *self.tcb.get_state() == TcpState::SynReceived(false) {
self.packet_to_send =
Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
Expand All @@ -246,15 +245,15 @@ impl AsyncRead for IpStackTcpStream {
.or(Err(ErrorKind::UnexpectedEof))?;
return Poll::Ready(Ok(()));
}
if self.tcb.get_state() == TcpState::FinWait1(true) {
if *self.tcb.get_state() == TcpState::FinWait1(true) {
self.packet_to_send =
Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.add_ack(1);
self.tcb.change_state(TcpState::FinWait2(true));
continue;
} else if matches!(self.shutdown, Shutdown::Pending(_))
&& self.tcb.get_state() == TcpState::Established
&& *self.tcb.get_state() == TcpState::Established
&& self.tcb.get_last_ack() == self.tcb.get_seq()
{
self.packet_to_send =
Expand All @@ -279,13 +278,13 @@ impl AsyncRead for IpStackTcpStream {
continue;
}

if self.tcb.get_state() == TcpState::SynReceived(true) {
if *self.tcb.get_state() == TcpState::SynReceived(true) {
if t.flags() == ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.change_state(TcpState::Established);
}
} else if self.tcb.get_state() == TcpState::Established {
} else if *self.tcb.get_state() == TcpState::Established {
if t.flags() == ACK {
match self.tcb.check_pkt_type(&t, &p.payload) {
PacketStatus::WindowUpdate => {
Expand Down Expand Up @@ -327,21 +326,13 @@ impl AsyncRead for IpStackTcpStream {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb
.add_unordered_packet(t.inner().sequence_number, p.payload);
// buf.put_slice(&p.payload);
// self.tcb.add_ack(p.payload.len() as u32);
// self.packet_to_send = Some(self.create_rev_packet(
// ACK,
// TTL,
// None,
// Vec::new(),
// )?);

self.tcb.change_send_window(t.inner().window_size);
if let Some(ref n) = self.write_notify {
n.wake_by_ref();
self.write_notify = None;
};
continue;
// return Poll::Ready(Ok(()));
}
PacketStatus::Ack => {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
Expand Down Expand Up @@ -376,21 +367,13 @@ impl AsyncRead for IpStackTcpStream {
continue;
}

// self.tcb.add_ack(p.payload.len() as u32);
self.tcb.change_send_window(t.inner().window_size);
// buf.put_slice(&p.payload);
// self.packet_to_send = Some(self.create_rev_packet(
// ACK,
// TTL,
// None,
// Vec::new(),
// )?);
// return Poll::Ready(Ok(()));

self.tcb
.add_unordered_packet(t.inner().sequence_number, p.payload);
continue;
}
} else if self.tcb.get_state() == TcpState::FinWait1(false) {
} else if *self.tcb.get_state() == TcpState::FinWait1(false) {
if t.flags() == ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.add_ack(1);
Expand All @@ -404,7 +387,7 @@ impl AsyncRead for IpStackTcpStream {
self.tcb.change_state(TcpState::FinWait2(true));
continue;
}
} else if self.tcb.get_state() == TcpState::FinWait2(true) {
} else if *self.tcb.get_state() == TcpState::FinWait2(true) {
if t.flags() == ACK {
self.tcb.change_state(TcpState::FinWait2(false));
} else if t.flags() == (FIN | ACK) {
Expand All @@ -427,7 +410,7 @@ impl AsyncWrite for IpStackTcpStream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
if self.tcb.get_state() != TcpState::Established {
if *self.tcb.get_state() != TcpState::Established {
return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
}
self.tcb.reset_timeout();
Expand Down Expand Up @@ -462,7 +445,7 @@ impl AsyncWrite for IpStackTcpStream {
mut self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
if self.tcb.get_state() != TcpState::Established {
if *self.tcb.get_state() != TcpState::Established {
return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
}
if let Some(i) = self
Expand Down
8 changes: 4 additions & 4 deletions src/stream/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub struct IpStackUdpStream {
dst_addr: SocketAddr,
stream_sender: PacketSender,
stream_receiver: PacketReceiver,
packet_sender: PacketSender,
pkt_sender: PacketSender,
first_payload: Option<Vec<u8>>,
timeout: Pin<Box<Sleep>>,
udp_timeout: Duration,
Expand All @@ -28,7 +28,7 @@ impl IpStackUdpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
payload: Vec<u8>,
packet_sender: PacketSender,
pkt_sender: PacketSender,
mtu: u16,
udp_timeout: Duration,
) -> Self {
Expand All @@ -39,7 +39,7 @@ impl IpStackUdpStream {
dst_addr,
stream_sender,
stream_receiver,
packet_sender,
pkt_sender,
first_payload: Some(payload),
timeout: Box::pin(tokio::time::sleep_until(deadline)),
udp_timeout,
Expand Down Expand Up @@ -156,7 +156,7 @@ impl AsyncWrite for IpStackUdpStream {
self.reset_timeout();
let packet = self.create_rev_packet(TTL, buf.to_vec())?;
let payload_len = packet.payload.len();
self.packet_sender
self.pkt_sender
.send(packet)
.or(Err(std::io::ErrorKind::UnexpectedEof))?;
std::task::Poll::Ready(Ok(payload_len))
Expand Down

0 comments on commit 1b51ac0

Please sign in to comment.