Skip to content

Commit

Permalink
reading code
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Mar 3, 2024
1 parent 8187e9a commit d1d76a3
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 131 deletions.
4 changes: 2 additions & 2 deletions examples/tun2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
//!
use clap::Parser;
use etherparse::{IcmpEchoHeader, Icmpv4Header};
use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber};
use ipstack::stream::IpStackStream;
use std::net::{Ipv4Addr, SocketAddr};
use tokio::net::TcpStream;
Expand Down Expand Up @@ -103,7 +103,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});
}
IpStackStream::UnknownTransport(u) => {
if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() {
if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP {
let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?;
if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type {
println!("ICMPv4 echo");
Expand Down
4 changes: 2 additions & 2 deletions examples/tun_wintun.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::net::{Ipv4Addr, SocketAddr};

use clap::Parser;
use etherparse::{IcmpEchoHeader, Icmpv4Header};
use etherparse::{IcmpEchoHeader, Icmpv4Header, IpNumber};
use ipstack::stream::IpStackStream;
use tokio::net::TcpStream;
use udp_stream::UdpStream;
Expand Down Expand Up @@ -82,7 +82,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});
}
IpStackStream::UnknownTransport(u) => {
if u.src_addr().is_ipv4() && u.ip_protocol() == 1.into() {
if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP {
let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?;
if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type {
println!("ICMPv4 echo");
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub enum IpStackError {
#[error("ValueTooBigError<u16> {0}")]
ValueTooBigErrorU16(#[from] etherparse::err::ValueTooBigError<u16>),

#[error("From<ValueTooBigError<u32>> {0}")]
#[error("ValueTooBigError<u32> {0}")]
ValueTooBigErrorU32(#[from] etherparse::err::ValueTooBigError<u32>),

#[error("ValueTooBigError<usize> {0}")]
Expand Down
23 changes: 5 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,9 @@ impl IpStack {

match streams.entry(packet.network_tuple()){
Occupied(entry) =>{
// let t = packet.transport_protocol();
if let Err(_x) = entry.get().send(packet){
#[cfg(feature = "log")]
trace!("{}", _x);
// match t{
// IpStackPacketProtocol::Tcp(_t) => {
// // dbg!(t.flags());
// }
// IpStackPacketProtocol::Udp => {
// // dbg!("udp");
// }
// IpStackPacketProtocol::Unknown => {
// // dbg!("unknown");
// }
// }

}
}
Vacant(entry) => {
Expand Down Expand Up @@ -183,11 +170,11 @@ impl IpStack {

IpStack { accept_receiver }
}

pub async fn accept(&mut self) -> Result<IpStackStream, IpStackError> {
if let Some(s) = self.accept_receiver.recv().await {
Ok(s)
} else {
Err(IpStackError::AcceptError)
}
self.accept_receiver
.recv()
.await
.ok_or(IpStackError::AcceptError)
}
}
6 changes: 5 additions & 1 deletion src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use etherparse::{NetHeaders, PacketHeaders, TcpHeader, UdpHeader};

use crate::error::IpStackError;

#[derive(Eq, Hash, PartialEq, Debug)]
#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)]
pub struct NetworkTuple {
pub src: SocketAddr,
pub dst: SocketAddr,
Expand All @@ -21,18 +21,21 @@ pub mod tcp_flags {
pub const FIN: u8 = 0b00000001;
}

#[derive(Debug, Clone)]
pub(crate) enum IpStackPacketProtocol {
Tcp(TcpPacket),
Unknown,
Udp,
}

#[derive(Debug, Clone)]
pub(crate) enum TransportHeader {
Tcp(TcpHeader),
Udp(UdpHeader),
Unknown,
}

#[derive(Debug, Clone)]
pub struct NetworkPacket {
pub(crate) ip: NetHeaders,
pub(crate) transport: TransportHeader,
Expand Down Expand Up @@ -130,6 +133,7 @@ impl NetworkPacket {
}
}

#[derive(Debug, Clone)]
pub(super) struct TcpPacket {
header: TcpHeader,
}
Expand Down
8 changes: 5 additions & 3 deletions src/stream/tcb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub(super) enum PacketStatus {
KeepAlive,
}

#[derive(Debug)]
pub(super) struct Tcb {
pub(super) seq: u32,
pub(super) retransmission: Option<u32>,
Expand All @@ -47,15 +48,14 @@ pub(super) struct Tcb {
impl Tcb {
pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb {
let seq = 100;
let deadline = tokio::time::Instant::now() + tcp_timeout;
Tcb {
seq,
retransmission: None,
ack,
last_ack: seq,
tcp_timeout,
timeout: Box::pin(tokio::time::sleep_until(
tokio::time::Instant::now() + tcp_timeout,
)),
timeout: Box::pin(tokio::time::sleep_until(deadline)),
send_window: u16::MAX,
recv_window: 0,
state: TcpState::SynReceived(false),
Expand Down Expand Up @@ -200,6 +200,7 @@ impl Tcb {
}
}

#[derive(Debug, Clone)]
pub struct InflightPacket {
pub seq: u32,
pub payload: Vec<u8>,
Expand All @@ -219,6 +220,7 @@ impl InflightPacket {
}
}

#[derive(Debug, Clone)]
pub struct UnorderedPacket {
pub payload: Vec<u8>,
pub recv_time: SystemTime,
Expand Down
74 changes: 27 additions & 47 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
stream::tcb::{Tcb, TcpState},
DROP_TTL, TTL,
};
use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions};
use etherparse::{IpNumber, Ipv4Extensions, Ipv4Header, Ipv6Extensions};
use std::{
cmp,
future::Future,
Expand All @@ -28,6 +28,7 @@ use crate::packet::NetworkPacket;

use super::tcb::PacketStatus;

#[derive(Debug)]
pub struct IpStackTcpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
Expand Down Expand Up @@ -65,27 +66,26 @@ impl IpStackTcpStream {
write_notify: None,
};
if !tcp.inner().syn {
let flags = tcp_flags::RST | tcp_flags::ACK;
pkt_sender
.send(stream.create_rev_packet(
tcp_flags::RST | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?)
.send(stream.create_rev_packet(flags, TTL, None, Vec::new())?)
.map_err(|_| IpStackError::InvalidTcpPacket)?;
stream.tcb.change_state(TcpState::Closed);
}
Ok(stream)
}

pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}

fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 {
cmp::min(
self.tcb.get_send_window(),
self.mtu.saturating_sub(ip_header_size + tcp_header_size),
)
}

fn create_rev_packet(
&self,
flags: u8,
Expand Down Expand Up @@ -119,7 +119,7 @@ impl IpStackTcpStream {

let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) {
(std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
let mut ip_h = Ipv4Header::new(0, ttl, 6.into(), dst.octets(), src.octets())
let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets())
.map_err(IpStackError::from)?;
let payload_len = self.calculate_payload_len(
ip_h.header_len() as u16,
Expand All @@ -136,7 +136,7 @@ impl IpStackTcpStream {
traffic_class: 0,
flow_label: 0.try_into().map_err(IpStackError::from)?,
payload_length: 0,
next_header: 6.into(),
next_header: IpNumber::TCP,
hop_limit: ttl,
source: dst.octets(),
destination: src.octets(),
Expand Down Expand Up @@ -171,9 +171,11 @@ impl IpStackTcpStream {
payload,
})
}

pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}

pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
Expand All @@ -200,24 +202,16 @@ impl AsyncRead for IpStackTcpStream {
) {
#[cfg(feature = "log")]
trace!("timeout reached for {:?}", self.dst_addr);
let flags = tcp_flags::RST | tcp_flags::ACK;
self.packet_sender
.send(self.create_rev_packet(
tcp_flags::RST | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?)
.send(self.create_rev_packet(flags, TTL, None, Vec::new())?)
.map_err(|_| ErrorKind::UnexpectedEof)?;
return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut)));
}

if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) {
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::SYN | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
let flags = tcp_flags::SYN | tcp_flags::ACK;
self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::SynReceived(true));
}
Expand All @@ -243,12 +237,8 @@ impl AsyncRead for IpStackTcpStream {
}
if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) {
self.tcb.change_state(TcpState::FinWait1);
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::FIN | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
continue;
}
match self.stream_receiver.poll_recv(cx) {
Expand Down Expand Up @@ -357,12 +347,9 @@ impl AsyncRead for IpStackTcpStream {
}
if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
self.tcb.add_ack(1);
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::FIN | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait2(true));
continue;
Expand Down Expand Up @@ -398,12 +385,9 @@ impl AsyncRead for IpStackTcpStream {
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait1) {
if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
let flags = tcp_flags::ACK;
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait2(false));
Expand Down Expand Up @@ -442,8 +426,8 @@ impl AsyncWrite for IpStackTcpStream {
}
}

let packet =
self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?;
let flags = tcp_flags::PSH | tcp_flags::ACK;
let packet = self.create_rev_packet(flags, TTL, None, buf.to_vec())?;
let seq = self.tcb.seq;
let payload_len = packet.payload.len();
let payload = packet.payload.clone();
Expand All @@ -466,12 +450,8 @@ impl AsyncWrite for IpStackTcpStream {
.and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s))
.and_then(|p| self.tcb.inflight_packets.get(p))
{
let packet = self.create_rev_packet(
tcp_flags::PSH | tcp_flags::ACK,
TTL,
Some(i.seq),
i.payload.to_vec(),
)?;
let flags = tcp_flags::PSH | tcp_flags::ACK;
let packet = self.create_rev_packet(flags, TTL, Some(i.seq), i.payload.to_vec())?;

self.packet_sender
.send(packet)
Expand Down
Loading

0 comments on commit d1d76a3

Please sign in to comment.