diff --git a/examples/tun2.rs b/examples/tun2.rs index 91b03b5..9c9a6e0 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -26,7 +26,7 @@ //! use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber}; use ipstack::stream::IpStackStream; use std::net::{Ipv4Addr, SocketAddr}; use tokio::net::TcpStream; @@ -103,7 +103,7 @@ async fn main() -> Result<(), Box> { }); } IpStackStream::UnknownTransport(u) => { - if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() { + if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { println!("ICMPv4 echo"); diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index 76f4616..397cc85 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -1,7 +1,7 @@ use std::net::{Ipv4Addr, SocketAddr}; use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber}; use ipstack::stream::IpStackStream; use tokio::net::TcpStream; use udp_stream::UdpStream; @@ -82,7 +82,7 @@ async fn main() -> Result<(), Box> { }); } IpStackStream::UnknownTransport(u) => { - if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() { + if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { println!("ICMPv4 echo"); diff --git a/src/error.rs b/src/error.rs index 4d15246..02750fc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,7 +10,7 @@ pub enum IpStackError { #[error("ValueTooBigError {0}")] ValueTooBigErrorU16(#[from] etherparse::err::ValueTooBigError), - #[error("From> {0}")] + #[error("ValueTooBigError {0}")] ValueTooBigErrorU32(#[from] etherparse::err::ValueTooBigError), #[error("ValueTooBigError {0}")] diff --git a/src/lib.rs b/src/lib.rs index 02b0e44..5c90d2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,22 +109,9 @@ impl IpStack { match streams.entry(packet.network_tuple()){ Occupied(entry) =>{ - // let t = packet.transport_protocol(); if let Err(_x) = entry.get().send(packet){ #[cfg(feature = "log")] trace!("{}", _x); - // match t{ - // IpStackPacketProtocol::Tcp(_t) => { - // // dbg!(t.flags()); - // } - // IpStackPacketProtocol::Udp => { - // // dbg!("udp"); - // } - // IpStackPacketProtocol::Unknown => { - // // dbg!("unknown"); - // } - // } - } } Vacant(entry) => { @@ -183,11 +170,11 @@ impl IpStack { IpStack { accept_receiver } } + pub async fn accept(&mut self) -> Result { - if let Some(s) = self.accept_receiver.recv().await { - Ok(s) - } else { - Err(IpStackError::AcceptError) - } + self.accept_receiver + .recv() + .await + .ok_or(IpStackError::AcceptError) } } diff --git a/src/packet.rs b/src/packet.rs index 9685594..1a4da70 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -4,7 +4,7 @@ use etherparse::{NetHeaders, PacketHeaders, TcpHeader, UdpHeader}; use crate::error::IpStackError; -#[derive(Eq, Hash, PartialEq, Debug)] +#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)] pub struct NetworkTuple { pub src: SocketAddr, pub dst: SocketAddr, @@ -21,18 +21,21 @@ pub mod tcp_flags { pub const FIN: u8 = 0b00000001; } +#[derive(Debug, Clone)] pub(crate) enum IpStackPacketProtocol { Tcp(TcpPacket), Unknown, Udp, } +#[derive(Debug, Clone)] pub(crate) enum TransportHeader { Tcp(TcpHeader), Udp(UdpHeader), Unknown, } +#[derive(Debug, Clone)] pub struct NetworkPacket { pub(crate) ip: NetHeaders, pub(crate) transport: TransportHeader, @@ -130,6 +133,7 @@ impl NetworkPacket { } } +#[derive(Debug, Clone)] pub(super) struct TcpPacket { header: TcpHeader, } diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index f3adbe2..a5cf212 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -29,6 +29,7 @@ pub(super) enum PacketStatus { KeepAlive, } +#[derive(Debug)] pub(super) struct Tcb { pub(super) seq: u32, pub(super) retransmission: Option, @@ -47,15 +48,14 @@ pub(super) struct Tcb { impl Tcb { pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb { let seq = 100; + let deadline = tokio::time::Instant::now() + tcp_timeout; Tcb { seq, retransmission: None, ack, last_ack: seq, tcp_timeout, - timeout: Box::pin(tokio::time::sleep_until( - tokio::time::Instant::now() + tcp_timeout, - )), + timeout: Box::pin(tokio::time::sleep_until(deadline)), send_window: u16::MAX, recv_window: 0, state: TcpState::SynReceived(false), @@ -200,6 +200,7 @@ impl Tcb { } } +#[derive(Debug, Clone)] pub struct InflightPacket { pub seq: u32, pub payload: Vec, @@ -219,6 +220,7 @@ impl InflightPacket { } } +#[derive(Debug, Clone)] pub struct UnorderedPacket { pub payload: Vec, pub recv_time: SystemTime, diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 714f5b9..bcd206b 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -4,7 +4,7 @@ use crate::{ stream::tcb::{Tcb, TcpState}, DROP_TTL, TTL, }; -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions}; +use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions}; use std::{ cmp, future::Future, @@ -28,6 +28,7 @@ use crate::packet::NetworkPacket; use super::tcb::PacketStatus; +#[derive(Debug)] pub struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, @@ -65,27 +66,26 @@ impl IpStackTcpStream { write_notify: None, }; if !tcp.inner().syn { + let flags = tcp_flags::RST | tcp_flags::ACK; pkt_sender - .send(stream.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) + .send(stream.create_rev_packet(flags, TTL, None, Vec::new())?) .map_err(|_| IpStackError::InvalidTcpPacket)?; stream.tcb.change_state(TcpState::Closed); } Ok(stream) } + pub(crate) fn stream_sender(&self) -> UnboundedSender { self.stream_sender.clone() } + fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), self.mtu.saturating_sub(ip_header_size + tcp_header_size), ) } + fn create_rev_packet( &self, flags: u8, @@ -119,7 +119,7 @@ impl IpStackTcpStream { let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 6.into(), dst.octets(), src.octets()) + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()) .map_err(IpStackError::from)?; let payload_len = self.calculate_payload_len( ip_h.header_len() as u16, @@ -136,7 +136,7 @@ impl IpStackTcpStream { traffic_class: 0, flow_label: 0.try_into().map_err(IpStackError::from)?, payload_length: 0, - next_header: 6.into(), + next_header: IpNumber::TCP, hop_limit: ttl, source: dst.octets(), destination: src.octets(), @@ -171,9 +171,11 @@ impl IpStackTcpStream { payload, }) } + pub fn local_addr(&self) -> SocketAddr { self.src_addr } + pub fn peer_addr(&self) -> SocketAddr { self.dst_addr } @@ -200,24 +202,16 @@ impl AsyncRead for IpStackTcpStream { ) { #[cfg(feature = "log")] trace!("timeout reached for {:?}", self.dst_addr); + let flags = tcp_flags::RST | tcp_flags::ACK; self.packet_sender - .send(self.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) + .send(self.create_rev_packet(flags, TTL, None, Vec::new())?) .map_err(|_| ErrorKind::UnexpectedEof)?; return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::SYN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::SYN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::SynReceived(true)); } @@ -243,12 +237,8 @@ impl AsyncRead for IpStackTcpStream { } if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) { self.tcb.change_state(TcpState::FinWait1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); continue; } match self.stream_receiver.poll_recv(cx) { @@ -357,12 +347,9 @@ impl AsyncRead for IpStackTcpStream { } if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { self.tcb.add_ack(1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = + Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait2(true)); continue; @@ -398,12 +385,9 @@ impl AsyncRead for IpStackTcpStream { } } else if matches!(self.tcb.get_state(), TcpState::FinWait1) { if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::ACK; + self.packet_to_send = + Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait2(false)); @@ -442,8 +426,8 @@ impl AsyncWrite for IpStackTcpStream { } } - let packet = - self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?; + let flags = tcp_flags::PSH | tcp_flags::ACK; + let packet = self.create_rev_packet(flags, TTL, None, buf.to_vec())?; let seq = self.tcb.seq; let payload_len = packet.payload.len(); let payload = packet.payload.clone(); @@ -466,12 +450,8 @@ impl AsyncWrite for IpStackTcpStream { .and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s)) .and_then(|p| self.tcb.inflight_packets.get(p)) { - let packet = self.create_rev_packet( - tcp_flags::PSH | tcp_flags::ACK, - TTL, - Some(i.seq), - i.payload.to_vec(), - )?; + let flags = tcp_flags::PSH | tcp_flags::ACK; + let packet = self.create_rev_packet(flags, TTL, Some(i.seq), i.payload.to_vec())?; self.packet_sender .send(packet) diff --git a/src/stream/udp.rs b/src/stream/udp.rs index 7c144dd..21c8d23 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -1,25 +1,18 @@ -use core::task; -use std::{ - future::Future, - io::{self, Error, ErrorKind}, - net::SocketAddr, - pin::Pin, - task::Poll, - time::Duration, +use crate::{ + packet::{NetworkPacket, TransportHeader}, + IpStackError, TTL, }; - -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, UdpHeader}; +use etherparse::{ + IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6FlowLabel, Ipv6Header, UdpHeader, +}; +use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, time::Sleep, }; -// use crate::packet::TransportHeader; -use crate::{ - packet::{NetworkPacket, TransportHeader}, - IpStackError, TTL, -}; +#[derive(Debug)] pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, @@ -37,7 +30,7 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - pkt_sender: UnboundedSender, + packet_sender: UnboundedSender, mtu: u16, udp_timeout: Duration, ) -> Self { @@ -47,7 +40,7 @@ impl IpStackUdpStream { dst_addr, stream_sender, stream_receiver, - packet_sender: pkt_sender.clone(), + packet_sender, first_paload: Some(payload), timeout: Box::pin(tokio::time::sleep_until( tokio::time::Instant::now() + udp_timeout, @@ -56,25 +49,28 @@ impl IpStackUdpStream { mtu, } } + pub(crate) fn stream_sender(&self) -> UnboundedSender { self.stream_sender.clone() } - fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> Result { + + fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> std::io::Result { + const UHS: usize = 8; // udp header size is 8 match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 17.into(), dst.octets(), src.octets()) + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()) .map_err(IpStackError::from)?; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.set_payload_len(payload.len() + 8) - .map_err(IpStackError::from)?; // 8 is udp header size + ip_h.set_payload_len(payload.len() + UHS) + .map_err(IpStackError::from)?; let udp_header = UdpHeader::with_ipv4_checksum( self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload, ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: etherparse::NetHeaders::Ipv4(ip_h, Ipv4Extensions::default()), transport: TransportHeader::Udp(udp_header), @@ -84,25 +80,25 @@ impl IpStackUdpStream { (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { let mut ip_h = Ipv6Header { traffic_class: 0, - flow_label: 0.try_into().map_err(IpStackError::from)?, + flow_label: Ipv6FlowLabel::ZERO, payload_length: 0, - next_header: 17.into(), + next_header: IpNumber::UDP, hop_limit: ttl, source: dst.octets(), destination: src.octets(), }; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.payload_length = payload.len() as u16 + 8; // 8 is udp header size + ip_h.payload_length = (payload.len() + UHS) as u16; let udp_header = UdpHeader::with_ipv6_checksum( self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload, ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: etherparse::NetHeaders::Ipv6(ip_h, Ipv6Extensions::default()), transport: TransportHeader::Udp(udp_header), @@ -112,9 +108,11 @@ impl IpStackUdpStream { _ => unreachable!(), } } + pub fn local_addr(&self) -> SocketAddr { self.src_addr } + pub fn peer_addr(&self) -> SocketAddr { self.dst_addr } @@ -123,28 +121,28 @@ impl IpStackUdpStream { impl AsyncRead for IpStackUdpStream { fn poll_read( mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, + cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, - ) -> task::Poll> { + ) -> std::task::Poll> { if let Some(p) = self.first_paload.take() { buf.put_slice(&p); - return Poll::Ready(Ok(())); + return std::task::Poll::Ready(Ok(())); } if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) { - return Poll::Ready(Ok(())); // todo: return timeout error + return std::task::Poll::Ready(Ok(())); // todo: return timeout error } let udp_timeout = self.udp_timeout; match self.stream_receiver.poll_recv(cx) { - Poll::Ready(Some(p)) => { + std::task::Poll::Ready(Some(p)) => { buf.put_slice(&p.payload); self.timeout .as_mut() .reset(tokio::time::Instant::now() + udp_timeout); - Poll::Ready(Ok(())) + std::task::Poll::Ready(Ok(())) } - Poll::Ready(None) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, + std::task::Poll::Ready(None) => std::task::Poll::Ready(Ok(())), + std::task::Poll::Pending => std::task::Poll::Pending, } } } @@ -152,9 +150,9 @@ impl AsyncRead for IpStackUdpStream { impl AsyncWrite for IpStackUdpStream { fn poll_write( mut self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, + _cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> task::Poll> { + ) -> std::task::Poll> { let udp_timeout = self.udp_timeout; self.timeout .as_mut() @@ -163,21 +161,21 @@ impl AsyncWrite for IpStackUdpStream { let payload_len = packet.payload.len(); self.packet_sender .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + .map_err(|_| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) } fn poll_flush( self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } fn poll_shutdown( self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } } diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 72adc3e..61917bd 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -88,7 +88,7 @@ impl IpStackUnknownTransport { traffic_class: 0, flow_label: 0.try_into().map_err(crate::IpStackError::from)?, payload_length: 0, - next_header: 17.into(), + next_header: IpNumber::UDP, hop_limit: TTL, source: dst.octets(), destination: src.octets(),