From d1d76a37f3df5d842f63f179de47cf36c990707b Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sat, 2 Mar 2024 12:51:59 +0800 Subject: [PATCH 1/5] reading code --- examples/tun2.rs | 4 +- examples/tun_wintun.rs | 4 +- src/error.rs | 2 +- src/lib.rs | 23 ++------- src/packet.rs | 6 ++- src/stream/tcb.rs | 8 +-- src/stream/tcp.rs | 74 ++++++++++------------------ src/stream/udp.rs | 109 ++++++++++++++++++++--------------------- src/stream/unknown.rs | 2 +- 9 files changed, 101 insertions(+), 131 deletions(-) diff --git a/examples/tun2.rs b/examples/tun2.rs index 91b03b5..9c9a6e0 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -26,7 +26,7 @@ //! use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber}; use ipstack::stream::IpStackStream; use std::net::{Ipv4Addr, SocketAddr}; use tokio::net::TcpStream; @@ -103,7 +103,7 @@ async fn main() -> Result<(), Box> { }); } IpStackStream::UnknownTransport(u) => { - if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() { + if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { println!("ICMPv4 echo"); diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index 76f4616..397cc85 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -1,7 +1,7 @@ use std::net::{Ipv4Addr, SocketAddr}; use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber}; use ipstack::stream::IpStackStream; use tokio::net::TcpStream; use udp_stream::UdpStream; @@ -82,7 +82,7 @@ async fn main() -> Result<(), Box> { }); } IpStackStream::UnknownTransport(u) => { - if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() { + if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { println!("ICMPv4 echo"); diff --git a/src/error.rs b/src/error.rs index 4d15246..02750fc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,7 +10,7 @@ pub enum IpStackError { #[error("ValueTooBigError {0}")] ValueTooBigErrorU16(#[from] etherparse::err::ValueTooBigError), - #[error("From> {0}")] + #[error("ValueTooBigError {0}")] ValueTooBigErrorU32(#[from] etherparse::err::ValueTooBigError), #[error("ValueTooBigError {0}")] diff --git a/src/lib.rs b/src/lib.rs index 02b0e44..5c90d2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,22 +109,9 @@ impl IpStack { match streams.entry(packet.network_tuple()){ Occupied(entry) =>{ - // let t = packet.transport_protocol(); if let Err(_x) = entry.get().send(packet){ #[cfg(feature = "log")] trace!("{}", _x); - // match t{ - // IpStackPacketProtocol::Tcp(_t) => { - // // dbg!(t.flags()); - // } - // IpStackPacketProtocol::Udp => { - // // dbg!("udp"); - // } - // IpStackPacketProtocol::Unknown => { - // // dbg!("unknown"); - // } - // } - } } Vacant(entry) => { @@ -183,11 +170,11 @@ impl IpStack { IpStack { accept_receiver } } + pub async fn accept(&mut self) -> Result { - if let Some(s) = self.accept_receiver.recv().await { - Ok(s) - } else { - Err(IpStackError::AcceptError) - } + self.accept_receiver + .recv() + .await + .ok_or(IpStackError::AcceptError) } } diff --git a/src/packet.rs b/src/packet.rs index 9685594..1a4da70 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -4,7 +4,7 @@ use etherparse::{NetHeaders, PacketHeaders, TcpHeader, UdpHeader}; use crate::error::IpStackError; -#[derive(Eq, Hash, PartialEq, Debug)] +#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)] pub struct NetworkTuple { pub src: SocketAddr, pub dst: SocketAddr, @@ -21,18 +21,21 @@ pub mod tcp_flags { pub const FIN: u8 = 0b00000001; } +#[derive(Debug, Clone)] pub(crate) enum IpStackPacketProtocol { Tcp(TcpPacket), Unknown, Udp, } +#[derive(Debug, Clone)] pub(crate) enum TransportHeader { Tcp(TcpHeader), Udp(UdpHeader), Unknown, } +#[derive(Debug, Clone)] pub struct NetworkPacket { pub(crate) ip: NetHeaders, pub(crate) transport: TransportHeader, @@ -130,6 +133,7 @@ impl NetworkPacket { } } +#[derive(Debug, Clone)] pub(super) struct TcpPacket { header: TcpHeader, } diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index f3adbe2..a5cf212 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -29,6 +29,7 @@ pub(super) enum PacketStatus { KeepAlive, } +#[derive(Debug)] pub(super) struct Tcb { pub(super) seq: u32, pub(super) retransmission: Option, @@ -47,15 +48,14 @@ pub(super) struct Tcb { impl Tcb { pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb { let seq = 100; + let deadline = tokio::time::Instant::now() + tcp_timeout; Tcb { seq, retransmission: None, ack, last_ack: seq, tcp_timeout, - timeout: Box::pin(tokio::time::sleep_until( - tokio::time::Instant::now() + tcp_timeout, - )), + timeout: Box::pin(tokio::time::sleep_until(deadline)), send_window: u16::MAX, recv_window: 0, state: TcpState::SynReceived(false), @@ -200,6 +200,7 @@ impl Tcb { } } +#[derive(Debug, Clone)] pub struct InflightPacket { pub seq: u32, pub payload: Vec, @@ -219,6 +220,7 @@ impl InflightPacket { } } +#[derive(Debug, Clone)] pub struct UnorderedPacket { pub payload: Vec, pub recv_time: SystemTime, diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 714f5b9..bcd206b 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -4,7 +4,7 @@ use crate::{ stream::tcb::{Tcb, TcpState}, DROP_TTL, TTL, }; -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions}; +use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions}; use std::{ cmp, future::Future, @@ -28,6 +28,7 @@ use crate::packet::NetworkPacket; use super::tcb::PacketStatus; +#[derive(Debug)] pub struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, @@ -65,27 +66,26 @@ impl IpStackTcpStream { write_notify: None, }; if !tcp.inner().syn { + let flags = tcp_flags::RST | tcp_flags::ACK; pkt_sender - .send(stream.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) + .send(stream.create_rev_packet(flags, TTL, None, Vec::new())?) .map_err(|_| IpStackError::InvalidTcpPacket)?; stream.tcb.change_state(TcpState::Closed); } Ok(stream) } + pub(crate) fn stream_sender(&self) -> UnboundedSender { self.stream_sender.clone() } + fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), self.mtu.saturating_sub(ip_header_size + tcp_header_size), ) } + fn create_rev_packet( &self, flags: u8, @@ -119,7 +119,7 @@ impl IpStackTcpStream { let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 6.into(), dst.octets(), src.octets()) + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()) .map_err(IpStackError::from)?; let payload_len = self.calculate_payload_len( ip_h.header_len() as u16, @@ -136,7 +136,7 @@ impl IpStackTcpStream { traffic_class: 0, flow_label: 0.try_into().map_err(IpStackError::from)?, payload_length: 0, - next_header: 6.into(), + next_header: IpNumber::TCP, hop_limit: ttl, source: dst.octets(), destination: src.octets(), @@ -171,9 +171,11 @@ impl IpStackTcpStream { payload, }) } + pub fn local_addr(&self) -> SocketAddr { self.src_addr } + pub fn peer_addr(&self) -> SocketAddr { self.dst_addr } @@ -200,24 +202,16 @@ impl AsyncRead for IpStackTcpStream { ) { #[cfg(feature = "log")] trace!("timeout reached for {:?}", self.dst_addr); + let flags = tcp_flags::RST | tcp_flags::ACK; self.packet_sender - .send(self.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) + .send(self.create_rev_packet(flags, TTL, None, Vec::new())?) .map_err(|_| ErrorKind::UnexpectedEof)?; return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::SYN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::SYN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::SynReceived(true)); } @@ -243,12 +237,8 @@ impl AsyncRead for IpStackTcpStream { } if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) { self.tcb.change_state(TcpState::FinWait1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); continue; } match self.stream_receiver.poll_recv(cx) { @@ -357,12 +347,9 @@ impl AsyncRead for IpStackTcpStream { } if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { self.tcb.add_ack(1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = + Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait2(true)); continue; @@ -398,12 +385,9 @@ impl AsyncRead for IpStackTcpStream { } } else if matches!(self.tcb.get_state(), TcpState::FinWait1) { if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::ACK; + self.packet_to_send = + Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait2(false)); @@ -442,8 +426,8 @@ impl AsyncWrite for IpStackTcpStream { } } - let packet = - self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?; + let flags = tcp_flags::PSH | tcp_flags::ACK; + let packet = self.create_rev_packet(flags, TTL, None, buf.to_vec())?; let seq = self.tcb.seq; let payload_len = packet.payload.len(); let payload = packet.payload.clone(); @@ -466,12 +450,8 @@ impl AsyncWrite for IpStackTcpStream { .and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s)) .and_then(|p| self.tcb.inflight_packets.get(p)) { - let packet = self.create_rev_packet( - tcp_flags::PSH | tcp_flags::ACK, - TTL, - Some(i.seq), - i.payload.to_vec(), - )?; + let flags = tcp_flags::PSH | tcp_flags::ACK; + let packet = self.create_rev_packet(flags, TTL, Some(i.seq), i.payload.to_vec())?; self.packet_sender .send(packet) diff --git a/src/stream/udp.rs b/src/stream/udp.rs index 7c144dd..a2d5469 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -1,25 +1,18 @@ -use core::task; -use std::{ - future::Future, - io::{self, Error, ErrorKind}, - net::SocketAddr, - pin::Pin, - task::Poll, - time::Duration, +use crate::{ + packet::{NetworkPacket, TransportHeader}, + IpStackError, TTL, }; - -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, UdpHeader}; +use etherparse::{ + IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6FlowLabel, Ipv6Header, UdpHeader, +}; +use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, time::Sleep, }; -// use crate::packet::TransportHeader; -use crate::{ - packet::{NetworkPacket, TransportHeader}, - IpStackError, TTL, -}; +#[derive(Debug)] pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, @@ -37,44 +30,46 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - pkt_sender: UnboundedSender, + packet_sender: UnboundedSender, mtu: u16, udp_timeout: Duration, ) -> Self { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); + let deadline = tokio::time::Instant::now() + udp_timeout; IpStackUdpStream { src_addr, dst_addr, stream_sender, stream_receiver, - packet_sender: pkt_sender.clone(), + packet_sender, first_paload: Some(payload), - timeout: Box::pin(tokio::time::sleep_until( - tokio::time::Instant::now() + udp_timeout, - )), + timeout: Box::pin(tokio::time::sleep_until(deadline)), udp_timeout, mtu, } } + pub(crate) fn stream_sender(&self) -> UnboundedSender { self.stream_sender.clone() } - fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> Result { + + fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> std::io::Result { + const UHS: usize = 8; // udp header size is 8 match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 17.into(), dst.octets(), src.octets()) + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()) .map_err(IpStackError::from)?; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.set_payload_len(payload.len() + 8) - .map_err(IpStackError::from)?; // 8 is udp header size + ip_h.set_payload_len(payload.len() + UHS) + .map_err(IpStackError::from)?; let udp_header = UdpHeader::with_ipv4_checksum( self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload, ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: etherparse::NetHeaders::Ipv4(ip_h, Ipv4Extensions::default()), transport: TransportHeader::Udp(udp_header), @@ -84,25 +79,25 @@ impl IpStackUdpStream { (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { let mut ip_h = Ipv6Header { traffic_class: 0, - flow_label: 0.try_into().map_err(IpStackError::from)?, + flow_label: Ipv6FlowLabel::ZERO, payload_length: 0, - next_header: 17.into(), + next_header: IpNumber::UDP, hop_limit: ttl, source: dst.octets(), destination: src.octets(), }; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.payload_length = payload.len() as u16 + 8; // 8 is udp header size + ip_h.payload_length = (payload.len() + UHS) as u16; let udp_header = UdpHeader::with_ipv6_checksum( self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload, ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: etherparse::NetHeaders::Ipv6(ip_h, Ipv6Extensions::default()), transport: TransportHeader::Udp(udp_header), @@ -112,39 +107,44 @@ impl IpStackUdpStream { _ => unreachable!(), } } + pub fn local_addr(&self) -> SocketAddr { self.src_addr } + pub fn peer_addr(&self) -> SocketAddr { self.dst_addr } + + fn reset_timeout(&mut self) { + let deadline = tokio::time::Instant::now() + self.udp_timeout; + self.timeout.as_mut().reset(deadline); + } } impl AsyncRead for IpStackUdpStream { fn poll_read( mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, + cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, - ) -> task::Poll> { + ) -> std::task::Poll> { if let Some(p) = self.first_paload.take() { buf.put_slice(&p); - return Poll::Ready(Ok(())); + return std::task::Poll::Ready(Ok(())); } if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) { - return Poll::Ready(Ok(())); // todo: return timeout error + return std::task::Poll::Ready(Ok(())); // todo: return timeout error } - let udp_timeout = self.udp_timeout; + self.reset_timeout(); + match self.stream_receiver.poll_recv(cx) { - Poll::Ready(Some(p)) => { + std::task::Poll::Ready(Some(p)) => { buf.put_slice(&p.payload); - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + udp_timeout); - Poll::Ready(Ok(())) + std::task::Poll::Ready(Ok(())) } - Poll::Ready(None) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, + std::task::Poll::Ready(None) => std::task::Poll::Ready(Ok(())), + std::task::Poll::Pending => std::task::Poll::Pending, } } } @@ -152,32 +152,29 @@ impl AsyncRead for IpStackUdpStream { impl AsyncWrite for IpStackUdpStream { fn poll_write( mut self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, + _cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> task::Poll> { - let udp_timeout = self.udp_timeout; - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + udp_timeout); + ) -> std::task::Poll> { + self.reset_timeout(); let packet = self.create_rev_packet(TTL, buf.to_vec())?; let payload_len = packet.payload.len(); self.packet_sender .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + .map_err(|_| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) } fn poll_flush( self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } fn poll_shutdown( self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } } diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 72adc3e..61917bd 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -88,7 +88,7 @@ impl IpStackUnknownTransport { traffic_class: 0, flow_label: 0.try_into().map_err(crate::IpStackError::from)?, payload_length: 0, - next_header: 17.into(), + next_header: IpNumber::UDP, hop_limit: TTL, source: dst.octets(), destination: src.octets(), From a74475d240a6ccc95ae5bf9538c1d7cc56602c9a Mon Sep 17 00:00:00 2001 From: SajjadPourali Date: Sat, 2 Mar 2024 23:32:57 -0500 Subject: [PATCH 2/5] Remove unused variables --- src/stream/tcb.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index a5cf212..b1eaff2 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,8 +1,4 @@ -use std::{ - collections::BTreeMap, - pin::Pin, - time::{Duration, SystemTime}, -}; +use std::{collections::BTreeMap, pin::Pin, time::Duration}; use tokio::time::Sleep; @@ -42,7 +38,7 @@ pub(super) struct Tcb { state: TcpState, pub(super) avg_send_window: (u64, u64), pub(super) inflight_packets: Vec, - pub(super) unordered_packets: BTreeMap, + unordered_packets: BTreeMap, } impl Tcb { @@ -204,7 +200,7 @@ impl Tcb { pub struct InflightPacket { pub seq: u32, pub payload: Vec, - pub send_time: SystemTime, + // pub send_time: SystemTime, // todo } impl InflightPacket { @@ -212,7 +208,7 @@ impl InflightPacket { Self { seq, payload, - send_time: SystemTime::now(), + // send_time: SystemTime::now(), // todo } } pub(crate) fn contains(&self, seq: u32) -> bool { @@ -221,16 +217,16 @@ impl InflightPacket { } #[derive(Debug, Clone)] -pub struct UnorderedPacket { - pub payload: Vec, - pub recv_time: SystemTime, +struct UnorderedPacket { + payload: Vec, + // pub recv_time: SystemTime, // todo } impl UnorderedPacket { pub(crate) fn new(payload: Vec) -> Self { Self { payload, - recv_time: SystemTime::now(), + // recv_time: SystemTime::now(), // todo } } } From 3e9ea28dbcb31fbef88c44ff61424e5dfee5fb02 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 3 Mar 2024 12:56:52 +0800 Subject: [PATCH 3/5] env_logger for tun2 example --- Cargo.toml | 5 +++++ examples/tun2.rs | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index d84c725..4bcf7e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ tracing = { version = "0.1", default-features = false, features = [ [dev-dependencies] clap = { version = "4.5", features = ["derive"] } +env_logger = "0.11" udp-stream = { version = "0.0", default-features = false } tokio = { version = "1.36", features = [ "rt-multi-thread", @@ -54,3 +55,7 @@ debug-assertions = false # Remove assertions from the binary. incremental = false # Disable incremental compilation. overflow-checks = false # Disable overflow checks. strip = true # Automatically strip symbols from the binary. + +[[example]] +name = "tun2" +required-features = ["log"] diff --git a/examples/tun2.rs b/examples/tun2.rs index 9c9a6e0..a4b4b94 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -35,18 +35,37 @@ use udp_stream::UdpStream; // const MTU: u16 = 1500; const MTU: u16 = u16::MAX; +#[repr(C)] +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)] +pub enum ArgVerbosity { + Off = 0, + Error, + Warn, + #[default] + Info, + Debug, + Trace, +} + #[derive(Parser)] #[command(author, version, about = "Testing app for tun.", long_about = None)] struct Args { /// echo server address, likes `127.0.0.1:8080` #[arg(short, long, value_name = "IP:port")] server_addr: SocketAddr, + + /// Verbosity level + #[arg(short, long, value_name = "level", value_enum, default_value = "info")] + pub verbosity: ArgVerbosity, } #[tokio::main] async fn main() -> Result<(), Box> { let args = Args::parse(); + let default = format!("{:?}", args.verbosity); + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); + let ipv4 = Ipv4Addr::new(10, 0, 0, 33); let netmask = Ipv4Addr::new(255, 255, 255, 0); let gateway = Ipv4Addr::new(10, 0, 0, 1); From b19a1fa6d15e66cb800930269fddf9de1665366d Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 3 Mar 2024 13:15:24 +0800 Subject: [PATCH 4/5] reset_timeout for tcp --- src/stream/tcb.rs | 8 +++++--- src/stream/tcp.rs | 4 ++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index b1eaff2..6a18398 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -172,9 +172,6 @@ impl Tcb { } } pub(super) fn change_last_ack(&mut self, ack: u32) { - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + self.tcp_timeout); let distance = ack.wrapping_sub(self.last_ack); if matches!(self.state, TcpState::Established) { @@ -194,6 +191,11 @@ impl Tcb { pub fn is_send_buffer_full(&self) -> bool { self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK } + + pub(crate) fn reset_timeout(&mut self) { + let deadline = tokio::time::Instant::now() + self.tcp_timeout; + self.timeout.as_mut().reset(deadline); + } } #[derive(Debug, Clone)] diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index bcd206b..1cd2499 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -209,6 +209,8 @@ impl AsyncRead for IpStackTcpStream { return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } + self.tcb.reset_timeout(); + if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { let flags = tcp_flags::SYN | tcp_flags::ACK; self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); @@ -412,6 +414,8 @@ impl AsyncWrite for IpStackTcpStream { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { + self.tcb.reset_timeout(); + if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2 || self.tcb.is_send_buffer_full() { From 848ea0d9c63135f27b9c2a9657b258b2716d2aa9 Mon Sep 17 00:00:00 2001 From: SajjadPourali Date: Sun, 3 Mar 2024 00:42:09 -0500 Subject: [PATCH 5/5] Remove unnecessary conversion --- src/error.rs | 3 --- src/stream/tcp.rs | 10 +++++----- src/stream/unknown.rs | 6 ++++-- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/error.rs b/src/error.rs index 02750fc..360badd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,9 +10,6 @@ pub enum IpStackError { #[error("ValueTooBigError {0}")] ValueTooBigErrorU16(#[from] etherparse::err::ValueTooBigError), - #[error("ValueTooBigError {0}")] - ValueTooBigErrorU32(#[from] etherparse::err::ValueTooBigError), - #[error("ValueTooBigError {0}")] ValueTooBigErrorUsize(#[from] etherparse::err::ValueTooBigError), diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 1cd2499..077c7bc 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -4,7 +4,7 @@ use crate::{ stream::tcb::{Tcb, TcpState}, DROP_TTL, TTL, }; -use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions}; +use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6FlowLabel}; use std::{ cmp, future::Future, @@ -90,13 +90,13 @@ impl IpStackTcpStream { &self, flags: u8, ttl: u8, - seq: Option, + seq: impl Into>, mut payload: Vec, ) -> Result { let mut tcp_header = etherparse::TcpHeader::new( self.dst_addr.port(), self.src_addr.port(), - seq.unwrap_or(self.tcb.get_seq()), + seq.into().unwrap_or(self.tcb.get_seq()), self.tcb.get_recv_window(), ); @@ -134,7 +134,7 @@ impl IpStackTcpStream { (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { let mut ip_h = etherparse::Ipv6Header { traffic_class: 0, - flow_label: 0.try_into().map_err(IpStackError::from)?, + flow_label: Ipv6FlowLabel::ZERO, payload_length: 0, next_header: IpNumber::TCP, hop_limit: ttl, @@ -455,7 +455,7 @@ impl AsyncWrite for IpStackTcpStream { .and_then(|p| self.tcb.inflight_packets.get(p)) { let flags = tcp_flags::PSH | tcp_flags::ACK; - let packet = self.create_rev_packet(flags, TTL, Some(i.seq), i.payload.to_vec())?; + let packet = self.create_rev_packet(flags, TTL, i.seq, i.payload.to_vec())?; self.packet_sender .send(packet) diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 61917bd..15b6740 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -1,6 +1,8 @@ use std::{io::Error, mem, net::IpAddr}; -use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, NetHeaders}; +use etherparse::{ + IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6FlowLabel, Ipv6Header, NetHeaders, +}; use tokio::sync::mpsc::UnboundedSender; use crate::{ @@ -86,7 +88,7 @@ impl IpStackUnknownTransport { (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { let mut ip_h = Ipv6Header { traffic_class: 0, - flow_label: 0.try_into().map_err(crate::IpStackError::from)?, + flow_label: Ipv6FlowLabel::ZERO, payload_length: 0, next_header: IpNumber::UDP, hop_limit: TTL,