diff --git a/Cargo.toml b/Cargo.toml index d84c725..4bcf7e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ tracing = { version = "0.1", default-features = false, features = [ [dev-dependencies] clap = { version = "4.5", features = ["derive"] } +env_logger = "0.11" udp-stream = { version = "0.0", default-features = false } tokio = { version = "1.36", features = [ "rt-multi-thread", @@ -54,3 +55,7 @@ debug-assertions = false # Remove assertions from the binary. incremental = false # Disable incremental compilation. overflow-checks = false # Disable overflow checks. strip = true # Automatically strip symbols from the binary. + +[[example]] +name = "tun2" +required-features = ["log"] diff --git a/examples/tun2.rs b/examples/tun2.rs index 91b03b5..a4b4b94 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -26,7 +26,7 @@ //! use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber}; use ipstack::stream::IpStackStream; use std::net::{Ipv4Addr, SocketAddr}; use tokio::net::TcpStream; @@ -35,18 +35,37 @@ use udp_stream::UdpStream; // const MTU: u16 = 1500; const MTU: u16 = u16::MAX; +#[repr(C)] +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)] +pub enum ArgVerbosity { + Off = 0, + Error, + Warn, + #[default] + Info, + Debug, + Trace, +} + #[derive(Parser)] #[command(author, version, about = "Testing app for tun.", long_about = None)] struct Args { /// echo server address, likes `127.0.0.1:8080` #[arg(short, long, value_name = "IP:port")] server_addr: SocketAddr, + + /// Verbosity level + #[arg(short, long, value_name = "level", value_enum, default_value = "info")] + pub verbosity: ArgVerbosity, } #[tokio::main] async fn main() -> Result<(), Box> { let args = Args::parse(); + let default = format!("{:?}", args.verbosity); + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); + let ipv4 = Ipv4Addr::new(10, 0, 0, 33); let netmask = Ipv4Addr::new(255, 255, 255, 0); let gateway = Ipv4Addr::new(10, 0, 0, 1); @@ -103,7 +122,7 @@ async fn main() -> Result<(), Box> { }); } IpStackStream::UnknownTransport(u) => { - if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() { + if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { println!("ICMPv4 echo"); diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index 76f4616..397cc85 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -1,7 +1,7 @@ use std::net::{Ipv4Addr, SocketAddr}; use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber}; use ipstack::stream::IpStackStream; use tokio::net::TcpStream; use udp_stream::UdpStream; @@ -82,7 +82,7 @@ async fn main() -> Result<(), Box> { }); } IpStackStream::UnknownTransport(u) => { - if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() { + if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { println!("ICMPv4 echo"); diff --git a/src/error.rs b/src/error.rs index 4d15246..360badd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,9 +10,6 @@ pub enum IpStackError { #[error("ValueTooBigError {0}")] ValueTooBigErrorU16(#[from] etherparse::err::ValueTooBigError), - #[error("From> {0}")] - ValueTooBigErrorU32(#[from] etherparse::err::ValueTooBigError), - #[error("ValueTooBigError {0}")] ValueTooBigErrorUsize(#[from] etherparse::err::ValueTooBigError), diff --git a/src/lib.rs b/src/lib.rs index 02b0e44..5c90d2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,22 +109,9 @@ impl IpStack { match streams.entry(packet.network_tuple()){ Occupied(entry) =>{ - // let t = packet.transport_protocol(); if let Err(_x) = entry.get().send(packet){ #[cfg(feature = "log")] trace!("{}", _x); - // match t{ - // IpStackPacketProtocol::Tcp(_t) => { - // // dbg!(t.flags()); - // } - // IpStackPacketProtocol::Udp => { - // // dbg!("udp"); - // } - // IpStackPacketProtocol::Unknown => { - // // dbg!("unknown"); - // } - // } - } } Vacant(entry) => { @@ -183,11 +170,11 @@ impl IpStack { IpStack { accept_receiver } } + pub async fn accept(&mut self) -> Result { - if let Some(s) = self.accept_receiver.recv().await { - Ok(s) - } else { - Err(IpStackError::AcceptError) - } + self.accept_receiver + .recv() + .await + .ok_or(IpStackError::AcceptError) } } diff --git a/src/packet.rs b/src/packet.rs index 9685594..1a4da70 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -4,7 +4,7 @@ use etherparse::{NetHeaders, PacketHeaders, TcpHeader, UdpHeader}; use crate::error::IpStackError; -#[derive(Eq, Hash, PartialEq, Debug)] +#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)] pub struct NetworkTuple { pub src: SocketAddr, pub dst: SocketAddr, @@ -21,18 +21,21 @@ pub mod tcp_flags { pub const FIN: u8 = 0b00000001; } +#[derive(Debug, Clone)] pub(crate) enum IpStackPacketProtocol { Tcp(TcpPacket), Unknown, Udp, } +#[derive(Debug, Clone)] pub(crate) enum TransportHeader { Tcp(TcpHeader), Udp(UdpHeader), Unknown, } +#[derive(Debug, Clone)] pub struct NetworkPacket { pub(crate) ip: NetHeaders, pub(crate) transport: TransportHeader, @@ -130,6 +133,7 @@ impl NetworkPacket { } } +#[derive(Debug, Clone)] pub(super) struct TcpPacket { header: TcpHeader, } diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index f3adbe2..6a18398 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,8 +1,4 @@ -use std::{ - collections::BTreeMap, - pin::Pin, - time::{Duration, SystemTime}, -}; +use std::{collections::BTreeMap, pin::Pin, time::Duration}; use tokio::time::Sleep; @@ -29,6 +25,7 @@ pub(super) enum PacketStatus { KeepAlive, } +#[derive(Debug)] pub(super) struct Tcb { pub(super) seq: u32, pub(super) retransmission: Option, @@ -41,21 +38,20 @@ pub(super) struct Tcb { state: TcpState, pub(super) avg_send_window: (u64, u64), pub(super) inflight_packets: Vec, - pub(super) unordered_packets: BTreeMap, + unordered_packets: BTreeMap, } impl Tcb { pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb { let seq = 100; + let deadline = tokio::time::Instant::now() + tcp_timeout; Tcb { seq, retransmission: None, ack, last_ack: seq, tcp_timeout, - timeout: Box::pin(tokio::time::sleep_until( - tokio::time::Instant::now() + tcp_timeout, - )), + timeout: Box::pin(tokio::time::sleep_until(deadline)), send_window: u16::MAX, recv_window: 0, state: TcpState::SynReceived(false), @@ -176,9 +172,6 @@ impl Tcb { } } pub(super) fn change_last_ack(&mut self, ack: u32) { - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + self.tcp_timeout); let distance = ack.wrapping_sub(self.last_ack); if matches!(self.state, TcpState::Established) { @@ -198,12 +191,18 @@ impl Tcb { pub fn is_send_buffer_full(&self) -> bool { self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK } + + pub(crate) fn reset_timeout(&mut self) { + let deadline = tokio::time::Instant::now() + self.tcp_timeout; + self.timeout.as_mut().reset(deadline); + } } +#[derive(Debug, Clone)] pub struct InflightPacket { pub seq: u32, pub payload: Vec, - pub send_time: SystemTime, + // pub send_time: SystemTime, // todo } impl InflightPacket { @@ -211,7 +210,7 @@ impl InflightPacket { Self { seq, payload, - send_time: SystemTime::now(), + // send_time: SystemTime::now(), // todo } } pub(crate) fn contains(&self, seq: u32) -> bool { @@ -219,16 +218,17 @@ impl InflightPacket { } } -pub struct UnorderedPacket { - pub payload: Vec, - pub recv_time: SystemTime, +#[derive(Debug, Clone)] +struct UnorderedPacket { + payload: Vec, + // pub recv_time: SystemTime, // todo } impl UnorderedPacket { pub(crate) fn new(payload: Vec) -> Self { Self { payload, - recv_time: SystemTime::now(), + // recv_time: SystemTime::now(), // todo } } } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 714f5b9..077c7bc 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -4,7 +4,7 @@ use crate::{ stream::tcb::{Tcb, TcpState}, DROP_TTL, TTL, }; -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions}; +use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6FlowLabel}; use std::{ cmp, future::Future, @@ -28,6 +28,7 @@ use crate::packet::NetworkPacket; use super::tcb::PacketStatus; +#[derive(Debug)] pub struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, @@ -65,38 +66,37 @@ impl IpStackTcpStream { write_notify: None, }; if !tcp.inner().syn { + let flags = tcp_flags::RST | tcp_flags::ACK; pkt_sender - .send(stream.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) + .send(stream.create_rev_packet(flags, TTL, None, Vec::new())?) .map_err(|_| IpStackError::InvalidTcpPacket)?; stream.tcb.change_state(TcpState::Closed); } Ok(stream) } + pub(crate) fn stream_sender(&self) -> UnboundedSender { self.stream_sender.clone() } + fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), self.mtu.saturating_sub(ip_header_size + tcp_header_size), ) } + fn create_rev_packet( &self, flags: u8, ttl: u8, - seq: Option, + seq: impl Into>, mut payload: Vec, ) -> Result { let mut tcp_header = etherparse::TcpHeader::new( self.dst_addr.port(), self.src_addr.port(), - seq.unwrap_or(self.tcb.get_seq()), + seq.into().unwrap_or(self.tcb.get_seq()), self.tcb.get_recv_window(), ); @@ -119,7 +119,7 @@ impl IpStackTcpStream { let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 6.into(), dst.octets(), src.octets()) + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()) .map_err(IpStackError::from)?; let payload_len = self.calculate_payload_len( ip_h.header_len() as u16, @@ -134,9 +134,9 @@ impl IpStackTcpStream { (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { let mut ip_h = etherparse::Ipv6Header { traffic_class: 0, - flow_label: 0.try_into().map_err(IpStackError::from)?, + flow_label: Ipv6FlowLabel::ZERO, payload_length: 0, - next_header: 6.into(), + next_header: IpNumber::TCP, hop_limit: ttl, source: dst.octets(), destination: src.octets(), @@ -171,9 +171,11 @@ impl IpStackTcpStream { payload, }) } + pub fn local_addr(&self) -> SocketAddr { self.src_addr } + pub fn peer_addr(&self) -> SocketAddr { self.dst_addr } @@ -200,24 +202,18 @@ impl AsyncRead for IpStackTcpStream { ) { #[cfg(feature = "log")] trace!("timeout reached for {:?}", self.dst_addr); + let flags = tcp_flags::RST | tcp_flags::ACK; self.packet_sender - .send(self.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) + .send(self.create_rev_packet(flags, TTL, None, Vec::new())?) .map_err(|_| ErrorKind::UnexpectedEof)?; return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } + self.tcb.reset_timeout(); + if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::SYN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::SYN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::SynReceived(true)); } @@ -243,12 +239,8 @@ impl AsyncRead for IpStackTcpStream { } if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) { self.tcb.change_state(TcpState::FinWait1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); continue; } match self.stream_receiver.poll_recv(cx) { @@ -357,12 +349,9 @@ impl AsyncRead for IpStackTcpStream { } if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { self.tcb.add_ack(1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = + Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait2(true)); continue; @@ -398,12 +387,9 @@ impl AsyncRead for IpStackTcpStream { } } else if matches!(self.tcb.get_state(), TcpState::FinWait1) { if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); + let flags = tcp_flags::ACK; + self.packet_to_send = + Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait2(false)); @@ -428,6 +414,8 @@ impl AsyncWrite for IpStackTcpStream { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { + self.tcb.reset_timeout(); + if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2 || self.tcb.is_send_buffer_full() { @@ -442,8 +430,8 @@ impl AsyncWrite for IpStackTcpStream { } } - let packet = - self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?; + let flags = tcp_flags::PSH | tcp_flags::ACK; + let packet = self.create_rev_packet(flags, TTL, None, buf.to_vec())?; let seq = self.tcb.seq; let payload_len = packet.payload.len(); let payload = packet.payload.clone(); @@ -466,12 +454,8 @@ impl AsyncWrite for IpStackTcpStream { .and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s)) .and_then(|p| self.tcb.inflight_packets.get(p)) { - let packet = self.create_rev_packet( - tcp_flags::PSH | tcp_flags::ACK, - TTL, - Some(i.seq), - i.payload.to_vec(), - )?; + let flags = tcp_flags::PSH | tcp_flags::ACK; + let packet = self.create_rev_packet(flags, TTL, i.seq, i.payload.to_vec())?; self.packet_sender .send(packet) diff --git a/src/stream/udp.rs b/src/stream/udp.rs index 7c144dd..a2d5469 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -1,25 +1,18 @@ -use core::task; -use std::{ - future::Future, - io::{self, Error, ErrorKind}, - net::SocketAddr, - pin::Pin, - task::Poll, - time::Duration, +use crate::{ + packet::{NetworkPacket, TransportHeader}, + IpStackError, TTL, }; - -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, UdpHeader}; +use etherparse::{ + IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6FlowLabel, Ipv6Header, UdpHeader, +}; +use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, time::Sleep, }; -// use crate::packet::TransportHeader; -use crate::{ - packet::{NetworkPacket, TransportHeader}, - IpStackError, TTL, -}; +#[derive(Debug)] pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, @@ -37,44 +30,46 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - pkt_sender: UnboundedSender, + packet_sender: UnboundedSender, mtu: u16, udp_timeout: Duration, ) -> Self { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); + let deadline = tokio::time::Instant::now() + udp_timeout; IpStackUdpStream { src_addr, dst_addr, stream_sender, stream_receiver, - packet_sender: pkt_sender.clone(), + packet_sender, first_paload: Some(payload), - timeout: Box::pin(tokio::time::sleep_until( - tokio::time::Instant::now() + udp_timeout, - )), + timeout: Box::pin(tokio::time::sleep_until(deadline)), udp_timeout, mtu, } } + pub(crate) fn stream_sender(&self) -> UnboundedSender { self.stream_sender.clone() } - fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> Result { + + fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> std::io::Result { + const UHS: usize = 8; // udp header size is 8 match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 17.into(), dst.octets(), src.octets()) + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()) .map_err(IpStackError::from)?; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.set_payload_len(payload.len() + 8) - .map_err(IpStackError::from)?; // 8 is udp header size + ip_h.set_payload_len(payload.len() + UHS) + .map_err(IpStackError::from)?; let udp_header = UdpHeader::with_ipv4_checksum( self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload, ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: etherparse::NetHeaders::Ipv4(ip_h, Ipv4Extensions::default()), transport: TransportHeader::Udp(udp_header), @@ -84,25 +79,25 @@ impl IpStackUdpStream { (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { let mut ip_h = Ipv6Header { traffic_class: 0, - flow_label: 0.try_into().map_err(IpStackError::from)?, + flow_label: Ipv6FlowLabel::ZERO, payload_length: 0, - next_header: 17.into(), + next_header: IpNumber::UDP, hop_limit: ttl, source: dst.octets(), destination: src.octets(), }; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.payload_length = payload.len() as u16 + 8; // 8 is udp header size + ip_h.payload_length = (payload.len() + UHS) as u16; let udp_header = UdpHeader::with_ipv6_checksum( self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload, ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: etherparse::NetHeaders::Ipv6(ip_h, Ipv6Extensions::default()), transport: TransportHeader::Udp(udp_header), @@ -112,39 +107,44 @@ impl IpStackUdpStream { _ => unreachable!(), } } + pub fn local_addr(&self) -> SocketAddr { self.src_addr } + pub fn peer_addr(&self) -> SocketAddr { self.dst_addr } + + fn reset_timeout(&mut self) { + let deadline = tokio::time::Instant::now() + self.udp_timeout; + self.timeout.as_mut().reset(deadline); + } } impl AsyncRead for IpStackUdpStream { fn poll_read( mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, + cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, - ) -> task::Poll> { + ) -> std::task::Poll> { if let Some(p) = self.first_paload.take() { buf.put_slice(&p); - return Poll::Ready(Ok(())); + return std::task::Poll::Ready(Ok(())); } if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) { - return Poll::Ready(Ok(())); // todo: return timeout error + return std::task::Poll::Ready(Ok(())); // todo: return timeout error } - let udp_timeout = self.udp_timeout; + self.reset_timeout(); + match self.stream_receiver.poll_recv(cx) { - Poll::Ready(Some(p)) => { + std::task::Poll::Ready(Some(p)) => { buf.put_slice(&p.payload); - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + udp_timeout); - Poll::Ready(Ok(())) + std::task::Poll::Ready(Ok(())) } - Poll::Ready(None) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, + std::task::Poll::Ready(None) => std::task::Poll::Ready(Ok(())), + std::task::Poll::Pending => std::task::Poll::Pending, } } } @@ -152,32 +152,29 @@ impl AsyncRead for IpStackUdpStream { impl AsyncWrite for IpStackUdpStream { fn poll_write( mut self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, + _cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> task::Poll> { - let udp_timeout = self.udp_timeout; - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + udp_timeout); + ) -> std::task::Poll> { + self.reset_timeout(); let packet = self.create_rev_packet(TTL, buf.to_vec())?; let payload_len = packet.payload.len(); self.packet_sender .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + .map_err(|_| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) } fn poll_flush( self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } fn poll_shutdown( self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } } diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 72adc3e..15b6740 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -1,6 +1,8 @@ use std::{io::Error, mem, net::IpAddr}; -use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, NetHeaders}; +use etherparse::{ + IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6FlowLabel, Ipv6Header, NetHeaders, +}; use tokio::sync::mpsc::UnboundedSender; use crate::{ @@ -86,9 +88,9 @@ impl IpStackUnknownTransport { (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { let mut ip_h = Ipv6Header { traffic_class: 0, - flow_label: 0.try_into().map_err(crate::IpStackError::from)?, + flow_label: Ipv6FlowLabel::ZERO, payload_length: 0, - next_header: 17.into(), + next_header: IpNumber::UDP, hop_limit: TTL, source: dst.octets(), destination: src.octets(),