diff --git a/examples/tun2.rs b/examples/tun2.rs index 1df739f..ca81ca3 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -18,6 +18,7 @@ //! route add 1.2.3.4 mask 255.255.255.255 10.0.0.1 metric 100 # Windows //! sudo route add 1.2.3.4/32 10.0.0.1 # macOS //! ``` +//! //! Now you can test it with `nc 1.2.3.4 any_port` or `nc -u 1.2.3.4 any_port`. //! You can watch the echo information in the `nc` console. //! ``` @@ -55,6 +56,14 @@ struct Args { #[arg(short, long, value_name = "IP:port")] server_addr: SocketAddr, + /// tcp timeout + #[arg(long, value_name = "seconds", default_value = "60")] + tcp_timeout: u64, + + /// udp timeout + #[arg(long, value_name = "seconds", default_value = "10")] + udp_timeout: u64, + /// Verbosity level #[arg(short, long, value_name = "level", value_enum, default_value = "info")] pub verbosity: ArgVerbosity, @@ -69,64 +78,83 @@ async fn main() -> Result<(), Box> { let ipv4 = Ipv4Addr::new(10, 0, 0, 33); let netmask = Ipv4Addr::new(255, 255, 255, 0); - let gateway = Ipv4Addr::new(10, 0, 0, 1); + let _gateway = Ipv4Addr::new(10, 0, 0, 1); - let mut config = tun2::Configuration::default(); - config.address(ipv4).netmask(netmask).mtu(MTU).up(); - config.destination(gateway); + let mut tun_config = tun2::Configuration::default(); + tun_config.address(ipv4).netmask(netmask).mtu(MTU).up(); + #[cfg(not(target_os = "windows"))] + tun_config.destination(_gateway); // avoid routing all traffic to tun on Windows platform #[cfg(target_os = "linux")] - config.platform_config(|config| { - config.ensure_root_privileges(true); + tun_config.platform_config(|p_cfg| { + p_cfg.ensure_root_privileges(true); }); #[cfg(target_os = "windows")] - config.platform_config(|config| { - config.device_guid(Some(12324323423423434234_u128)); + tun_config.platform_config(|p_cfg| { + p_cfg.device_guid(Some(12324323423423434234_u128)); }); let mut ipstack_config = ipstack::IpStackConfig::default(); ipstack_config.mtu(MTU); + ipstack_config.tcp_timeout(std::time::Duration::from_secs(args.tcp_timeout)); + ipstack_config.udp_timeout(std::time::Duration::from_secs(args.udp_timeout)); - let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun2::create_as_async(&config)?); + let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun2::create_as_async(&tun_config)?); let server_addr = args.server_addr; + let count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let serial_number = std::sync::atomic::AtomicUsize::new(0); + loop { + let count = count.clone(); + let number = serial_number.fetch_add(1, std::sync::atomic::Ordering::Relaxed); match ip_stack.accept().await? { IpStackStream::Tcp(mut tcp) => { let mut s = match TcpStream::connect(server_addr).await { Ok(s) => s, Err(e) => { - println!("connect TCP server failed \"{}\"", e); + log::info!("connect TCP server failed \"{}\"", e); continue; } }; - println!("==== New TCP connection ===="); + let c = count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; + let number1 = number; + log::info!("#{number1} TCP connecting, session count {c}"); tokio::spawn(async move { - let _ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await; - println!("====== end tcp connection ======"); + if let Err(err) = tokio::io::copy_bidirectional(&mut tcp, &mut s).await { + log::info!("#{number1} TCP error: {}", err); + } + let c = count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed) - 1; + log::info!("#{number1} TCP closed, session count {c}"); }); } IpStackStream::Udp(mut udp) => { let mut s = match UdpStream::connect(server_addr).await { Ok(s) => s, Err(e) => { - println!("connect UDP server failed \"{}\"", e); + log::info!("connect UDP server failed \"{}\"", e); continue; } }; - println!("==== New UDP connection ===="); + let c = count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; + let number2 = number; + log::info!("#{number2} UDP connecting, session count {c}"); tokio::spawn(async move { - let _ = tokio::io::copy_bidirectional(&mut udp, &mut s).await; - println!("==== end UDP connection ===="); + if let Err(err) = tokio::io::copy_bidirectional(&mut udp, &mut s).await { + log::info!("#{number2} UDP error: {}", err); + } + let c = count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed) - 1; + log::info!("#{number2} UDP closed, session count {c}"); }); } IpStackStream::UnknownTransport(u) => { + let n = number; if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { - println!("ICMPv4 echo"); + log::info!("#{n} ICMPv4 echo"); let echo = IcmpEchoHeader { id: req.id, seq: req.seq, @@ -137,15 +165,15 @@ async fn main() -> Result<(), Box> { payload.extend_from_slice(req_payload); u.send(payload)?; } else { - println!("ICMPv4"); + log::info!("#{n} ICMPv4"); } continue; } - println!("unknown transport - Ip Protocol {:?}", u.ip_protocol()); + log::info!("#{n} unknown transport - Ip Protocol {:?}", u.ip_protocol()); continue; } IpStackStream::UnknownNetwork(pkt) => { - println!("unknown transport - {} bytes", pkt.len()); + log::info!("#{number} unknown transport - {} bytes", pkt.len()); continue; } }; diff --git a/src/lib.rs b/src/lib.rs index 6fa7ade..c0659b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,10 @@ use tokio::{ task::JoinHandle, }; +pub(crate) type PacketSender = UnboundedSender; +pub(crate) type PacketReceiver = UnboundedReceiver; +pub(crate) type SessionCollection = AHashMap; + mod error; mod packet; pub mod stream; @@ -62,17 +66,21 @@ impl Default for IpStackConfig { } impl IpStackConfig { - pub fn tcp_timeout(&mut self, timeout: Duration) { + pub fn tcp_timeout(&mut self, timeout: Duration) -> &mut Self { self.tcp_timeout = timeout; + self } - pub fn udp_timeout(&mut self, timeout: Duration) { + pub fn udp_timeout(&mut self, timeout: Duration) -> &mut Self { self.udp_timeout = timeout; + self } - pub fn mtu(&mut self, mtu: u16) { + pub fn mtu(&mut self, mtu: u16) -> &mut Self { self.mtu = mtu; + self } - pub fn packet_information(&mut self, packet_information: bool) { + pub fn packet_information(&mut self, packet_information: bool) -> &mut Self { self.packet_information = packet_information; + self } } @@ -111,12 +119,9 @@ fn run( where D: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let mut streams: AHashMap> = AHashMap::new(); - let offset = if config.packet_information && cfg!(unix) { - 4 - } else { - 0 - }; + let mut sessions: SessionCollection = AHashMap::new(); + let pi = config.packet_information; + let offset = if pi && cfg!(unix) { 4 } else { 0 }; let mut buffer = [0_u8; u16::MAX as usize + 4]; let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); @@ -124,9 +129,9 @@ where loop { select! { Ok(n) = device.read(&mut buffer) => { - if let Some(stream) = process_read( + if let Some(stream) = process_device_read( &buffer[offset..n], - &mut streams, + &mut sessions, &pkt_sender, &config, ) { @@ -134,12 +139,12 @@ where } } Some(packet) = pkt_receiver.recv() => { - process_recv( + process_upstream_recv( packet, - &mut streams, + &mut sessions, &mut device, #[cfg(unix)] - config.packet_information, + pi, ) .await?; } @@ -148,10 +153,10 @@ where }) } -fn process_read( +fn process_device_read( data: &[u8], - streams: &mut AHashMap>, - pkt_sender: &UnboundedSender, + sessions: &mut SessionCollection, + pkt_sender: &PacketSender, config: &IpStackConfig, ) -> Option { let Ok(packet) = NetworkPacket::parse(data) else { @@ -171,7 +176,7 @@ fn process_read( )); } - match streams.entry(packet.network_tuple()) { + match sessions.entry(packet.network_tuple()) { Occupied(mut entry) => { if let Err(e) = entry.get().send(packet) { trace!("New stream because: {}", e); @@ -193,8 +198,8 @@ fn process_read( fn create_stream( packet: NetworkPacket, config: &IpStackConfig, - pkt_sender: &UnboundedSender, -) -> Option<(UnboundedSender, IpStackStream)> { + pkt_sender: &PacketSender, +) -> Option<(PacketSender, IpStackStream)> { match packet.transport_protocol() { IpStackPacketProtocol::Tcp(h) => { match IpStackTcpStream::new( @@ -233,9 +238,9 @@ fn create_stream( } } -async fn process_recv( +async fn process_upstream_recv( packet: NetworkPacket, - streams: &mut AHashMap>, + sessions: &mut SessionCollection, device: &mut D, #[cfg(unix)] packet_information: bool, ) -> Result<()> @@ -243,23 +248,23 @@ where D: AsyncWrite + Unpin + 'static, { if packet.ttl() == 0 { - streams.remove(&packet.reverse_network_tuple()); + sessions.remove(&packet.reverse_network_tuple()); return Ok(()); } #[allow(unused_mut)] - let Ok(mut packet_byte) = packet.to_bytes() else { + let Ok(mut packet_bytes) = packet.to_bytes() else { trace!("to_bytes error"); return Ok(()); }; #[cfg(unix)] if packet_information { if packet.src_addr().is_ipv4() { - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); + packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); } else { - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); + packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); } } - device.write_all(&packet_byte).await?; + device.write_all(&packet_bytes).await?; // device.flush().await.unwrap(); Ok(()) diff --git a/src/packet.rs b/src/packet.rs index f1eb7c8..3a25340 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -22,7 +22,7 @@ pub mod tcp_flags { #[derive(Debug, Clone)] pub(crate) enum IpStackPacketProtocol { - Tcp(TcpPacket), + Tcp(TcpHeaderWrapper), Unknown, Udp, } @@ -145,11 +145,11 @@ impl NetworkPacket { } #[derive(Debug, Clone)] -pub(super) struct TcpPacket { +pub(super) struct TcpHeaderWrapper { header: TcpHeader, } -impl TcpPacket { +impl TcpHeaderWrapper { pub fn inner(&self) -> &TcpHeader { &self.header } @@ -185,9 +185,9 @@ impl TcpPacket { } } -impl From<&TcpHeader> for TcpPacket { +impl From<&TcpHeader> for TcpHeaderWrapper { fn from(header: &TcpHeader) -> Self { - TcpPacket { + TcpHeaderWrapper { header: header.clone(), } } diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 2b70cf6..ff44486 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,11 +1,11 @@ -use crate::packet::TcpPacket; +use crate::packet::TcpHeaderWrapper; use std::{collections::BTreeMap, pin::Pin, time::Duration}; use tokio::time::Sleep; const MAX_UNACK: u32 = 1024 * 16; // 16KB const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum TcpState { SynReceived(bool), // bool means if syn/ack is sent Established, @@ -13,7 +13,8 @@ pub enum TcpState { FinWait2(bool), // bool means waiting for ack Closed, } -#[derive(Clone, Debug)] + +#[derive(Clone, Debug, PartialEq)] pub(super) enum PacketStatus { WindowUpdate, Invalid, @@ -25,16 +26,16 @@ pub(super) enum PacketStatus { #[derive(Debug)] pub(super) struct Tcb { - pub(super) seq: u32, + seq: u32, pub(super) retransmission: Option, - pub(super) ack: u32, - pub(super) last_ack: u32, + ack: u32, + last_ack: u32, pub(super) timeout: Pin>, tcp_timeout: Duration, recv_window: u16, - pub(super) send_window: u16, + send_window: u16, state: TcpState, - pub(super) avg_send_window: (u64, u64), + avg_send_window: (u64, u64), // (avg, count) pub(super) inflight_packets: Vec, unordered_packets: BTreeMap, } @@ -85,9 +86,7 @@ impl Tcb { // for (seq,_) in self.unordered_packets.iter() { // dbg!(seq); // } - self.unordered_packets - .remove(&self.ack) - .map(|p| p.payload.clone()) + self.unordered_packets.remove(&self.ack).map(|p| p.payload) } pub(super) fn add_seq_one(&mut self) { self.seq = self.seq.wrapping_add(1); @@ -101,11 +100,14 @@ impl Tcb { pub(super) fn get_ack(&self) -> u32 { self.ack } + pub(super) fn get_last_ack(&self) -> u32 { + self.last_ack + } pub(super) fn change_state(&mut self, state: TcpState) { self.state = state; } - pub(super) fn get_state(&self) -> &TcpState { - &self.state + pub(super) fn get_state(&self) -> TcpState { + self.state.clone() } pub(super) fn change_send_window(&mut self, window: u16) { let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) @@ -117,6 +119,9 @@ impl Tcb { pub(super) fn get_send_window(&self) -> u16 { self.send_window } + pub(super) fn get_avg_send_window(&self) -> u64 { + self.avg_send_window.0 + } pub(super) fn change_recv_window(&mut self, window: u16) { self.recv_window = window; } @@ -136,33 +141,27 @@ impl Tcb { // } // } - pub(super) fn check_pkt_type(&self, incoming_packet: &TcpPacket, p: &[u8]) -> PacketStatus { - let received_ack_distance = self - .seq - .wrapping_sub(incoming_packet.inner().acknowledgment_number); + pub(super) fn check_pkt_type(&self, header: &TcpHeaderWrapper, p: &[u8]) -> PacketStatus { + let tcp_header = header.inner(); + let received_ack_distance = self.seq.wrapping_sub(tcp_header.acknowledgment_number); let current_ack_distance = self.seq.wrapping_sub(self.last_ack); if received_ack_distance > current_ack_distance - || (incoming_packet.inner().acknowledgment_number != self.seq - && self - .seq - .saturating_sub(incoming_packet.inner().acknowledgment_number) - == 0) + || (tcp_header.acknowledgment_number != self.seq + && self.seq.saturating_sub(tcp_header.acknowledgment_number) == 0) { PacketStatus::Invalid - } else if self.last_ack == incoming_packet.inner().acknowledgment_number { + } else if self.last_ack == tcp_header.acknowledgment_number { if !p.is_empty() { PacketStatus::NewPacket - } else if self.send_window == incoming_packet.inner().window_size - && self.seq != self.last_ack - { + } else if self.send_window == tcp_header.window_size && self.seq != self.last_ack { PacketStatus::RetransmissionRequest - } else if self.ack.wrapping_sub(1) == incoming_packet.inner().sequence_number { + } else if self.ack.wrapping_sub(1) == tcp_header.sequence_number { PacketStatus::KeepAlive } else { PacketStatus::WindowUpdate } - } else if self.last_ack < incoming_packet.inner().acknowledgment_number { + } else if self.last_ack < tcp_header.acknowledgment_number { if !p.is_empty() { PacketStatus::NewPacket } else { @@ -176,12 +175,12 @@ impl Tcb { let distance = ack.wrapping_sub(self.last_ack); self.last_ack = self.last_ack.wrapping_add(distance); - if matches!(self.state, TcpState::Established) { + if self.state == TcpState::Established { if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { let mut inflight_packet = self.inflight_packets.remove(i); - let distance = ack.wrapping_sub(inflight_packet.seq); - if (distance as usize) < inflight_packet.payload.len() { - inflight_packet.payload.drain(0..distance as usize); + let distance = ack.wrapping_sub(inflight_packet.seq) as usize; + if distance < inflight_packet.payload.len() { + inflight_packet.payload.drain(0..distance); inflight_packet.seq = ack; self.inflight_packets.push(inflight_packet); } @@ -219,7 +218,7 @@ impl InflightPacket { } } pub(crate) fn contains(&self, seq: u32) -> bool { - self.seq < seq && self.seq + self.payload.len() as u32 >= seq + self.seq < seq && seq <= self.seq + self.payload.len() as u32 } } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index fdf7d68..b5aa424 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -2,10 +2,10 @@ use crate::{ error::IpStackError, packet::{ tcp_flags::{ACK, FIN, NON, PSH, RST, SYN}, - IpStackPacketProtocol, TcpPacket, TransportHeader, + IpStackPacketProtocol, TcpHeaderWrapper, TransportHeader, }, stream::tcb::{Tcb, TcpState}, - DROP_TTL, TTL, + PacketReceiver, PacketSender, DROP_TTL, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel}; use std::{ @@ -18,10 +18,7 @@ use std::{ task::{Context, Poll, Waker}, time::Duration, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{UnboundedReceiver, UnboundedSender}, -}; +use tokio::io::{AsyncRead, AsyncWrite}; use log::{trace, warn}; @@ -53,8 +50,8 @@ impl Shutdown { pub(crate) struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, - stream_receiver: UnboundedReceiver, - packet_sender: UnboundedSender, + stream_receiver: PacketReceiver, + packet_sender: PacketSender, packet_to_send: Option, tcb: Tcb, mtu: u16, @@ -66,9 +63,9 @@ impl IpStackTcpStream { pub(crate) fn new( src_addr: SocketAddr, dst_addr: SocketAddr, - tcp: TcpPacket, - pkt_sender: UnboundedSender, - stream_receiver: UnboundedReceiver, + tcp: TcpHeaderWrapper, + packet_sender: PacketSender, + stream_receiver: PacketReceiver, mtu: u16, tcp_timeout: Duration, ) -> Result { @@ -76,7 +73,7 @@ impl IpStackTcpStream { src_addr, dst_addr, stream_receiver, - packet_sender: pkt_sender.clone(), + packet_sender, packet_to_send: None, tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), mtu, @@ -87,7 +84,10 @@ impl IpStackTcpStream { return Ok(stream); } if !tcp.inner().rst { - _ = pkt_sender.send(stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?); + let pkt = stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; + if let Err(err) = stream.packet_sender.send(pkt) { + log::warn!("Error sending RST/ACK packet: {:?}", err); + } } Err(IpStackError::InvalidTcpPacket) } @@ -159,7 +159,8 @@ impl IpStackTcpStream { tcp_header.header_len() as u16, ); payload.truncate(payload_len as usize); - ip_h.payload_length = (payload.len() + tcp_header.header_len()) as u16; + let len = payload.len() + tcp_header.header_len(); + ip_h.set_payload_length(len).map_err(IpStackError::from)?; IpHeader::Ipv6(ip_h) } @@ -193,17 +194,27 @@ impl AsyncRead for IpStackTcpStream { buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { loop { - if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) - && self.packet_to_send.is_none() - { + if let Some(packet) = self.packet_to_send.take() { + self.packet_sender + .send(packet) + .or(Err(ErrorKind::UnexpectedEof))?; + } + if self.tcb.get_state() == TcpState::Closed { + self.shutdown.ready(); + return Poll::Ready(Ok(())); + } + + if self.tcb.get_state() == TcpState::FinWait2(false) { self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); - return Poll::Ready(Ok(())); + return Poll::Ready(Err(Error::from(ErrorKind::ConnectionAborted))); } + let min = self.tcb.get_available_read_buffer_size() as u16; self.tcb.change_recv_window(min); + if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { trace!("timeout reached for {:?}", self.dst_addr); self.packet_sender @@ -213,26 +224,16 @@ impl AsyncRead for IpStackTcpStream { self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } - self.tcb.reset_timeout(); - if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { + if self.tcb.get_state() == TcpState::SynReceived(false) { self.packet_to_send = Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::SynReceived(true)); - } - - if let Some(packet) = self.packet_to_send.take() { - self.packet_sender - .send(packet) - .or(Err(ErrorKind::UnexpectedEof))?; - if matches!(self.tcb.get_state(), TcpState::Closed) { - self.shutdown.ready(); - return Poll::Ready(Ok(())); - } continue; } + if let Some(b) = self .tcb .get_unordered_packets() @@ -245,7 +246,7 @@ impl AsyncRead for IpStackTcpStream { .or(Err(ErrorKind::UnexpectedEof))?; return Poll::Ready(Ok(())); } - if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) { + if self.tcb.get_state() == TcpState::FinWait1(true) { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); @@ -253,8 +254,8 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) - && matches!(self.tcb.get_state(), TcpState::Established) - && self.tcb.last_ack == self.tcb.seq + && self.tcb.get_state() == TcpState::Established + && self.tcb.get_last_ack() == self.tcb.get_seq() { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); @@ -274,20 +275,17 @@ impl AsyncRead for IpStackTcpStream { self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); } - if matches!( - self.tcb.check_pkt_type(&t, &p.payload), - PacketStatus::Invalid - ) { + if self.tcb.check_pkt_type(&t, &p.payload) == PacketStatus::Invalid { continue; } - if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) { + if self.tcb.get_state() == TcpState::SynReceived(true) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.change_send_window(t.inner().window_size); self.tcb.change_state(TcpState::Established); } - } else if matches!(self.tcb.get_state(), TcpState::Established) { + } else if self.tcb.get_state() == TcpState::Established { if t.flags() == ACK { match self.tcb.check_pkt_type(&t, &p.payload) { PacketStatus::WindowUpdate => { @@ -394,7 +392,7 @@ impl AsyncRead for IpStackTcpStream { .add_unordered_packet(t.inner().sequence_number, &p.payload); continue; } - } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) { + } else if self.tcb.get_state() == TcpState::FinWait1(false) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.add_ack(1); @@ -408,7 +406,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_state(TcpState::FinWait2(true)); continue; } - } else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) { + } else if self.tcb.get_state() == TcpState::FinWait2(true) { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { @@ -431,12 +429,12 @@ impl AsyncWrite for IpStackTcpStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if !matches!(self.tcb.get_state(), TcpState::Established) { + if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } self.tcb.reset_timeout(); - if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2 + if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 || self.tcb.is_send_buffer_full() { self.write_notify = Some(cx.waker().clone()); @@ -451,7 +449,7 @@ impl AsyncWrite for IpStackTcpStream { } let packet = self.create_rev_packet(PSH | ACK, TTL, None, buf.to_vec())?; - let seq = self.tcb.seq; + let seq = self.tcb.get_seq(); let payload_len = packet.payload.len(); let payload = packet.payload.clone(); self.packet_sender @@ -466,7 +464,7 @@ impl AsyncWrite for IpStackTcpStream { mut self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { - if !matches!(self.tcb.get_state(), TcpState::Established) { + if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } if let Some(i) = self @@ -484,9 +482,9 @@ impl AsyncWrite for IpStackTcpStream { } else if let Some(_i) = self.tcb.retransmission { { warn!("{}", _i); - warn!("{}", self.tcb.seq); - warn!("{}", self.tcb.last_ack); - warn!("{}", self.tcb.ack); + warn!("{}", self.tcb.get_seq()); + warn!("{}", self.tcb.get_last_ack()); + warn!("{}", self.tcb.get_ack()); for p in self.tcb.inflight_packets.iter() { warn!("{}", p.seq); warn!("{}", p.payload.len()); @@ -516,7 +514,9 @@ impl AsyncWrite for IpStackTcpStream { impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { - _ = self.packet_sender.send(p); + if let Err(err) = self.packet_sender.send(p) { + log::trace!("Error sending NON packet: {:?}", err); + } } } } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index dd1d6eb..e6653b9 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -1,28 +1,24 @@ use super::tcp::IpStackTcpStream as IpStackTcpStreamInner; use crate::{ - packet::{NetworkPacket, TcpPacket}, - IpStackError, + packet::{NetworkPacket, TcpHeaderWrapper}, + IpStackError, PacketSender, }; use std::{net::SocketAddr, pin::Pin, time::Duration}; -use tokio::{ - io::AsyncWriteExt, - sync::mpsc::{self, UnboundedSender}, - time::timeout, -}; +use tokio::{io::AsyncWriteExt, sync::mpsc, time::timeout}; pub struct IpStackTcpStream { inner: Option>, peer_addr: SocketAddr, local_addr: SocketAddr, - stream_sender: mpsc::UnboundedSender, + stream_sender: PacketSender, } impl IpStackTcpStream { pub(crate) fn new( local_addr: SocketAddr, peer_addr: SocketAddr, - tcp: TcpPacket, - pkt_sender: UnboundedSender, + tcp: TcpHeaderWrapper, + pkt_sender: PacketSender, mtu: u16, tcp_timeout: Duration, ) -> Result { @@ -36,9 +32,8 @@ impl IpStackTcpStream { mtu, tcp_timeout, ) - .map(Box::new) .map(|inner| IpStackTcpStream { - inner: Some(inner), + inner: Some(Box::new(inner)), peer_addr, local_addr, stream_sender, @@ -50,7 +45,7 @@ impl IpStackTcpStream { pub fn peer_addr(&self) -> SocketAddr { self.peer_addr } - pub fn stream_sender(&self) -> UnboundedSender { + pub fn stream_sender(&self) -> PacketSender { self.stream_sender.clone() } } @@ -111,7 +106,9 @@ impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Some(mut inner) = self.inner.take() { tokio::spawn(async move { - _ = timeout(Duration::from_secs(2), inner.shutdown()).await; + if let Err(err) = timeout(Duration::from_secs(2), inner.shutdown()).await { + log::warn!("Error while dropping IpStackTcpStream: {:?}", err); + } }); } } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index b7e6bc0..c9edc90 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -1,12 +1,12 @@ use crate::{ packet::{IpHeader, NetworkPacket, TransportHeader}, - IpStackError, TTL, + IpStackError, PacketReceiver, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader}; use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, - sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + sync::mpsc, time::Sleep, }; @@ -14,10 +14,10 @@ use tokio::{ pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, - stream_sender: UnboundedSender, - stream_receiver: UnboundedReceiver, - packet_sender: UnboundedSender, - first_paload: Option>, + stream_sender: PacketSender, + stream_receiver: PacketReceiver, + packet_sender: PacketSender, + first_payload: Option>, timeout: Pin>, udp_timeout: Duration, mtu: u16, @@ -28,7 +28,7 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - packet_sender: UnboundedSender, + packet_sender: PacketSender, mtu: u16, udp_timeout: Duration, ) -> Self { @@ -40,14 +40,14 @@ impl IpStackUdpStream { stream_sender, stream_receiver, packet_sender, - first_paload: Some(payload), + first_payload: Some(payload), timeout: Box::pin(tokio::time::sleep_until(deadline)), udp_timeout, mtu, } } - pub(crate) fn stream_sender(&self) -> UnboundedSender { + pub(crate) fn stream_sender(&self) -> PacketSender { self.stream_sender.clone() } @@ -126,7 +126,7 @@ impl AsyncRead for IpStackUdpStream { cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - if let Some(p) = self.first_paload.take() { + if let Some(p) = self.first_payload.take() { buf.put_slice(&p); return std::task::Poll::Ready(Ok(())); } diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 5eccc5d..838d93f 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -1,10 +1,9 @@ use crate::{ packet::{IpHeader, NetworkPacket, TransportHeader}, - TTL, + PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header}; use std::{io::Error, mem, net::IpAddr}; -use tokio::sync::mpsc::UnboundedSender; pub struct IpStackUnknownTransport { src_addr: IpAddr, @@ -12,7 +11,7 @@ pub struct IpStackUnknownTransport { payload: Vec, protocol: IpNumber, mtu: u16, - packet_sender: UnboundedSender, + packet_sender: PacketSender, } impl IpStackUnknownTransport { @@ -22,7 +21,7 @@ impl IpStackUnknownTransport { payload: Vec, ip: &IpHeader, mtu: u16, - packet_sender: UnboundedSender, + packet_sender: PacketSender, ) -> Self { let protocol = match ip { IpHeader::Ipv4(ip) => ip.protocol,