Skip to content

Commit

Permalink
reading code
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Apr 18, 2024
1 parent f5ae7c8 commit 3db90b1
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 59 deletions.
52 changes: 28 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use tokio::{
task::JoinHandle,
};

pub(crate) type PacketSender = UnboundedSender<NetworkPacket>;
pub(crate) type PacketReceiver = UnboundedReceiver<NetworkPacket>;

mod error;
mod packet;
pub mod stream;
Expand Down Expand Up @@ -62,17 +65,21 @@ impl Default for IpStackConfig {
}

impl IpStackConfig {
pub fn tcp_timeout(&mut self, timeout: Duration) {
pub fn tcp_timeout(&mut self, timeout: Duration) -> &mut Self {
self.tcp_timeout = timeout;
self
}
pub fn udp_timeout(&mut self, timeout: Duration) {
pub fn udp_timeout(&mut self, timeout: Duration) -> &mut Self {
self.udp_timeout = timeout;
self
}
pub fn mtu(&mut self, mtu: u16) {
pub fn mtu(&mut self, mtu: u16) -> &mut Self {
self.mtu = mtu;
self
}
pub fn packet_information(&mut self, packet_information: bool) {
pub fn packet_information(&mut self, packet_information: bool) -> &mut Self {
self.packet_information = packet_information;
self
}
}

Expand Down Expand Up @@ -111,20 +118,17 @@ fn run<D>(
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let mut streams: AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>> = AHashMap::new();
let offset = if config.packet_information && cfg!(unix) {
4
} else {
0
};
let mut streams: AHashMap<NetworkTuple, PacketSender> = AHashMap::new();
let pi = config.packet_information;
let offset = if pi && cfg!(unix) { 4 } else { 0 };
let mut buffer = [0_u8; u16::MAX as usize + 4];
let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

tokio::spawn(async move {
loop {
select! {
Ok(n) = device.read(&mut buffer) => {
if let Some(stream) = process_read(
if let Some(stream) = process_device_read(
&buffer[offset..n],
&mut streams,
&pkt_sender,
Expand All @@ -134,12 +138,12 @@ where
}
}
Some(packet) = pkt_receiver.recv() => {
process_recv(
process_upstream_recv(
packet,
&mut streams,
&mut device,
#[cfg(unix)]
config.packet_information,
pi,
)
.await?;
}
Expand All @@ -148,10 +152,10 @@ where
})
}

fn process_read(
fn process_device_read(
data: &[u8],
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
pkt_sender: &UnboundedSender<NetworkPacket>,
streams: &mut AHashMap<NetworkTuple, PacketSender>,
pkt_sender: &PacketSender,
config: &IpStackConfig,
) -> Option<IpStackStream> {
let Ok(packet) = NetworkPacket::parse(data) else {
Expand Down Expand Up @@ -193,8 +197,8 @@ fn process_read(
fn create_stream(
packet: NetworkPacket,
config: &IpStackConfig,
pkt_sender: &UnboundedSender<NetworkPacket>,
) -> Option<(UnboundedSender<NetworkPacket>, IpStackStream)> {
pkt_sender: &PacketSender,
) -> Option<(PacketSender, IpStackStream)> {
match packet.transport_protocol() {
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
Expand Down Expand Up @@ -233,9 +237,9 @@ fn create_stream(
}
}

async fn process_recv<D>(
async fn process_upstream_recv<D>(
packet: NetworkPacket,
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
streams: &mut AHashMap<NetworkTuple, PacketSender>,
device: &mut D,
#[cfg(unix)] packet_information: bool,
) -> Result<()>
Expand All @@ -247,19 +251,19 @@ where
return Ok(());
}
#[allow(unused_mut)]
let Ok(mut packet_byte) = packet.to_bytes() else {
let Ok(mut packet_bytes) = packet.to_bytes() else {
trace!("to_bytes error");
return Ok(());
};
#[cfg(unix)]
if packet_information {
if packet.src_addr().is_ipv4() {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
} else {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
}
}
device.write_all(&packet_byte).await?;
device.write_all(&packet_bytes).await?;
// device.flush().await.unwrap();

Ok(())
Expand Down
6 changes: 3 additions & 3 deletions src/stream/tcb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ impl Tcb {
if matches!(self.state, TcpState::Established) {
if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) {
let mut inflight_packet = self.inflight_packets.remove(i);
let distance = ack.wrapping_sub(inflight_packet.seq);
if (distance as usize) < inflight_packet.payload.len() {
inflight_packet.payload.drain(0..distance as usize);
let distance = ack.wrapping_sub(inflight_packet.seq) as usize;
if distance < inflight_packet.payload.len() {
inflight_packet.payload.drain(0..distance);
inflight_packet.seq = ack;
self.inflight_packets.push(inflight_packet);
}
Expand Down
15 changes: 6 additions & 9 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
IpStackPacketProtocol, TcpHeaderWrapper, TransportHeader,
},
stream::tcb::{Tcb, TcpState},
DROP_TTL, TTL,
PacketReceiver, PacketSender, DROP_TTL, TTL,
};
use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel};
use std::{
Expand All @@ -18,10 +18,7 @@ use std::{
task::{Context, Poll, Waker},
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
};
use tokio::io::{AsyncRead, AsyncWrite};

use log::{trace, warn};

Expand Down Expand Up @@ -53,8 +50,8 @@ impl Shutdown {
pub(crate) struct IpStackTcpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_receiver: UnboundedReceiver<NetworkPacket>,
packet_sender: UnboundedSender<NetworkPacket>,
stream_receiver: PacketReceiver,
packet_sender: PacketSender,
packet_to_send: Option<NetworkPacket>,
tcb: Tcb,
mtu: u16,
Expand All @@ -67,8 +64,8 @@ impl IpStackTcpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
tcp: TcpHeaderWrapper,
pkt_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
pkt_sender: PacketSender,
stream_receiver: PacketReceiver,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
Expand Down
14 changes: 5 additions & 9 deletions src/stream/tcp_wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
use super::tcp::IpStackTcpStream as IpStackTcpStreamInner;
use crate::{
packet::{NetworkPacket, TcpHeaderWrapper},
IpStackError,
IpStackError, PacketSender,
};
use std::{net::SocketAddr, pin::Pin, time::Duration};
use tokio::{
io::AsyncWriteExt,
sync::mpsc::{self, UnboundedSender},
time::timeout,
};
use tokio::{io::AsyncWriteExt, sync::mpsc, time::timeout};

pub struct IpStackTcpStream {
inner: Option<Box<IpStackTcpStreamInner>>,
peer_addr: SocketAddr,
local_addr: SocketAddr,
stream_sender: mpsc::UnboundedSender<NetworkPacket>,
stream_sender: PacketSender,
}

impl IpStackTcpStream {
pub(crate) fn new(
local_addr: SocketAddr,
peer_addr: SocketAddr,
tcp: TcpHeaderWrapper,
pkt_sender: UnboundedSender<NetworkPacket>,
pkt_sender: PacketSender,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
Expand Down Expand Up @@ -50,7 +46,7 @@ impl IpStackTcpStream {
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr
}
pub fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
pub fn stream_sender(&self) -> PacketSender {
self.stream_sender.clone()
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/stream/udp.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
use crate::{
packet::{IpHeader, NetworkPacket, TransportHeader},
IpStackError, TTL,
IpStackError, PacketReceiver, PacketSender, TTL,
};
use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader};
use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
sync::mpsc,
time::Sleep,
};

#[derive(Debug)]
pub struct IpStackUdpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
packet_sender: UnboundedSender<NetworkPacket>,
first_paload: Option<Vec<u8>>,
stream_sender: PacketSender,
stream_receiver: PacketReceiver,
packet_sender: PacketSender,
first_payload: Option<Vec<u8>>,
timeout: Pin<Box<Sleep>>,
udp_timeout: Duration,
mtu: u16,
Expand All @@ -28,7 +28,7 @@ impl IpStackUdpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
payload: Vec<u8>,
packet_sender: UnboundedSender<NetworkPacket>,
packet_sender: PacketSender,
mtu: u16,
udp_timeout: Duration,
) -> Self {
Expand All @@ -40,14 +40,14 @@ impl IpStackUdpStream {
stream_sender,
stream_receiver,
packet_sender,
first_paload: Some(payload),
first_payload: Some(payload),
timeout: Box::pin(tokio::time::sleep_until(deadline)),
udp_timeout,
mtu,
}
}

pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
pub(crate) fn stream_sender(&self) -> PacketSender {
self.stream_sender.clone()
}

Expand Down Expand Up @@ -126,7 +126,7 @@ impl AsyncRead for IpStackUdpStream {
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
if let Some(p) = self.first_paload.take() {
if let Some(p) = self.first_payload.take() {
buf.put_slice(&p);
return std::task::Poll::Ready(Ok(()));
}
Expand Down
7 changes: 3 additions & 4 deletions src/stream/unknown.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use crate::{
packet::{IpHeader, NetworkPacket, TransportHeader},
TTL,
PacketSender, TTL,
};
use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header};
use std::{io::Error, mem, net::IpAddr};
use tokio::sync::mpsc::UnboundedSender;

pub struct IpStackUnknownTransport {
src_addr: IpAddr,
dst_addr: IpAddr,
payload: Vec<u8>,
protocol: IpNumber,
mtu: u16,
packet_sender: UnboundedSender<NetworkPacket>,
packet_sender: PacketSender,
}

impl IpStackUnknownTransport {
Expand All @@ -22,7 +21,7 @@ impl IpStackUnknownTransport {
payload: Vec<u8>,
ip: &IpHeader,
mtu: u16,
packet_sender: UnboundedSender<NetworkPacket>,
packet_sender: PacketSender,
) -> Self {
let protocol = match ip {
IpHeader::Ipv4(ip) => ip.protocol,
Expand Down

0 comments on commit 3db90b1

Please sign in to comment.