Skip to content

Commit

Permalink
Merge pull request #41 from ssrlive/main
Browse files Browse the repository at this point in the history
Code refactor
  • Loading branch information
SajjadPourali authored Apr 24, 2024
2 parents a8afef8 + 6dbae4b commit 1eb6941
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 164 deletions.
70 changes: 49 additions & 21 deletions examples/tun2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! route add 1.2.3.4 mask 255.255.255.255 10.0.0.1 metric 100 # Windows
//! sudo route add 1.2.3.4/32 10.0.0.1 # macOS
//! ```
//!
//! Now you can test it with `nc 1.2.3.4 any_port` or `nc -u 1.2.3.4 any_port`.
//! You can watch the echo information in the `nc` console.
//! ```
Expand Down Expand Up @@ -55,6 +56,14 @@ struct Args {
#[arg(short, long, value_name = "IP:port")]
server_addr: SocketAddr,

/// tcp timeout
#[arg(long, value_name = "seconds", default_value = "60")]
tcp_timeout: u64,

/// udp timeout
#[arg(long, value_name = "seconds", default_value = "10")]
udp_timeout: u64,

/// Verbosity level
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
pub verbosity: ArgVerbosity,
Expand All @@ -69,64 +78,83 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let ipv4 = Ipv4Addr::new(10, 0, 0, 33);
let netmask = Ipv4Addr::new(255, 255, 255, 0);
let gateway = Ipv4Addr::new(10, 0, 0, 1);
let _gateway = Ipv4Addr::new(10, 0, 0, 1);

let mut config = tun2::Configuration::default();
config.address(ipv4).netmask(netmask).mtu(MTU).up();
config.destination(gateway);
let mut tun_config = tun2::Configuration::default();
tun_config.address(ipv4).netmask(netmask).mtu(MTU).up();
#[cfg(not(target_os = "windows"))]
tun_config.destination(_gateway); // avoid routing all traffic to tun on Windows platform

#[cfg(target_os = "linux")]
config.platform_config(|config| {
config.ensure_root_privileges(true);
tun_config.platform_config(|p_cfg| {
p_cfg.ensure_root_privileges(true);
});

#[cfg(target_os = "windows")]
config.platform_config(|config| {
config.device_guid(Some(12324323423423434234_u128));
tun_config.platform_config(|p_cfg| {
p_cfg.device_guid(Some(12324323423423434234_u128));
});

let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(MTU);
ipstack_config.tcp_timeout(std::time::Duration::from_secs(args.tcp_timeout));
ipstack_config.udp_timeout(std::time::Duration::from_secs(args.udp_timeout));

let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun2::create_as_async(&config)?);
let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun2::create_as_async(&tun_config)?);

let server_addr = args.server_addr;

let count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let serial_number = std::sync::atomic::AtomicUsize::new(0);

loop {
let count = count.clone();
let number = serial_number.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
match ip_stack.accept().await? {
IpStackStream::Tcp(mut tcp) => {
let mut s = match TcpStream::connect(server_addr).await {
Ok(s) => s,
Err(e) => {
println!("connect TCP server failed \"{}\"", e);
log::info!("connect TCP server failed \"{}\"", e);
continue;
}
};
println!("==== New TCP connection ====");
let c = count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
let number1 = number;
log::info!("#{number1} TCP connecting, session count {c}");
tokio::spawn(async move {
let _ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await;
println!("====== end tcp connection ======");
if let Err(err) = tokio::io::copy_bidirectional(&mut tcp, &mut s).await {
log::info!("#{number1} TCP error: {}", err);
}
let c = count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed) - 1;
log::info!("#{number1} TCP closed, session count {c}");
});
}
IpStackStream::Udp(mut udp) => {
let mut s = match UdpStream::connect(server_addr).await {
Ok(s) => s,
Err(e) => {
println!("connect UDP server failed \"{}\"", e);
log::info!("connect UDP server failed \"{}\"", e);
continue;
}
};
println!("==== New UDP connection ====");
let c = count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
let number2 = number;
log::info!("#{number2} UDP connecting, session count {c}");
tokio::spawn(async move {
let _ = tokio::io::copy_bidirectional(&mut udp, &mut s).await;
println!("==== end UDP connection ====");
if let Err(err) = tokio::io::copy_bidirectional(&mut udp, &mut s).await {
log::info!("#{number2} UDP error: {}", err);
}
let c = count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed) - 1;
log::info!("#{number2} UDP closed, session count {c}");
});
}
IpStackStream::UnknownTransport(u) => {
let n = number;
if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP {
let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?;
if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type {
println!("ICMPv4 echo");
log::info!("#{n} ICMPv4 echo");
let echo = IcmpEchoHeader {
id: req.id,
seq: req.seq,
Expand All @@ -137,15 +165,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
payload.extend_from_slice(req_payload);
u.send(payload)?;
} else {
println!("ICMPv4");
log::info!("#{n} ICMPv4");
}
continue;
}
println!("unknown transport - Ip Protocol {:?}", u.ip_protocol());
log::info!("#{n} unknown transport - Ip Protocol {:?}", u.ip_protocol());
continue;
}
IpStackStream::UnknownNetwork(pkt) => {
println!("unknown transport - {} bytes", pkt.len());
log::info!("#{number} unknown transport - {} bytes", pkt.len());
continue;
}
};
Expand Down
61 changes: 33 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ use tokio::{
task::JoinHandle,
};

pub(crate) type PacketSender = UnboundedSender<NetworkPacket>;
pub(crate) type PacketReceiver = UnboundedReceiver<NetworkPacket>;
pub(crate) type SessionCollection = AHashMap<NetworkTuple, PacketSender>;

mod error;
mod packet;
pub mod stream;
Expand Down Expand Up @@ -62,17 +66,21 @@ impl Default for IpStackConfig {
}

impl IpStackConfig {
pub fn tcp_timeout(&mut self, timeout: Duration) {
pub fn tcp_timeout(&mut self, timeout: Duration) -> &mut Self {
self.tcp_timeout = timeout;
self
}
pub fn udp_timeout(&mut self, timeout: Duration) {
pub fn udp_timeout(&mut self, timeout: Duration) -> &mut Self {
self.udp_timeout = timeout;
self
}
pub fn mtu(&mut self, mtu: u16) {
pub fn mtu(&mut self, mtu: u16) -> &mut Self {
self.mtu = mtu;
self
}
pub fn packet_information(&mut self, packet_information: bool) {
pub fn packet_information(&mut self, packet_information: bool) -> &mut Self {
self.packet_information = packet_information;
self
}
}

Expand Down Expand Up @@ -111,35 +119,32 @@ fn run<D>(
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let mut streams: AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>> = AHashMap::new();
let offset = if config.packet_information && cfg!(unix) {
4
} else {
0
};
let mut sessions: SessionCollection = AHashMap::new();
let pi = config.packet_information;
let offset = if pi && cfg!(unix) { 4 } else { 0 };
let mut buffer = [0_u8; u16::MAX as usize + 4];
let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

tokio::spawn(async move {
loop {
select! {
Ok(n) = device.read(&mut buffer) => {
if let Some(stream) = process_read(
if let Some(stream) = process_device_read(
&buffer[offset..n],
&mut streams,
&mut sessions,
&pkt_sender,
&config,
) {
accept_sender.send(stream)?;
}
}
Some(packet) = pkt_receiver.recv() => {
process_recv(
process_upstream_recv(
packet,
&mut streams,
&mut sessions,
&mut device,
#[cfg(unix)]
config.packet_information,
pi,
)
.await?;
}
Expand All @@ -148,10 +153,10 @@ where
})
}

fn process_read(
fn process_device_read(
data: &[u8],
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
pkt_sender: &UnboundedSender<NetworkPacket>,
sessions: &mut SessionCollection,
pkt_sender: &PacketSender,
config: &IpStackConfig,
) -> Option<IpStackStream> {
let Ok(packet) = NetworkPacket::parse(data) else {
Expand All @@ -171,7 +176,7 @@ fn process_read(
));
}

match streams.entry(packet.network_tuple()) {
match sessions.entry(packet.network_tuple()) {
Occupied(mut entry) => {
if let Err(e) = entry.get().send(packet) {
trace!("New stream because: {}", e);
Expand All @@ -193,8 +198,8 @@ fn process_read(
fn create_stream(
packet: NetworkPacket,
config: &IpStackConfig,
pkt_sender: &UnboundedSender<NetworkPacket>,
) -> Option<(UnboundedSender<NetworkPacket>, IpStackStream)> {
pkt_sender: &PacketSender,
) -> Option<(PacketSender, IpStackStream)> {
match packet.transport_protocol() {
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
Expand Down Expand Up @@ -233,33 +238,33 @@ fn create_stream(
}
}

async fn process_recv<D>(
async fn process_upstream_recv<D>(
packet: NetworkPacket,
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
sessions: &mut SessionCollection,
device: &mut D,
#[cfg(unix)] packet_information: bool,
) -> Result<()>
where
D: AsyncWrite + Unpin + 'static,
{
if packet.ttl() == 0 {
streams.remove(&packet.reverse_network_tuple());
sessions.remove(&packet.reverse_network_tuple());
return Ok(());
}
#[allow(unused_mut)]
let Ok(mut packet_byte) = packet.to_bytes() else {
let Ok(mut packet_bytes) = packet.to_bytes() else {
trace!("to_bytes error");
return Ok(());
};
#[cfg(unix)]
if packet_information {
if packet.src_addr().is_ipv4() {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
} else {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
}
}
device.write_all(&packet_byte).await?;
device.write_all(&packet_bytes).await?;
// device.flush().await.unwrap();

Ok(())
Expand Down
10 changes: 5 additions & 5 deletions src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub mod tcp_flags {

#[derive(Debug, Clone)]
pub(crate) enum IpStackPacketProtocol {
Tcp(TcpPacket),
Tcp(TcpHeaderWrapper),
Unknown,
Udp,
}
Expand Down Expand Up @@ -145,11 +145,11 @@ impl NetworkPacket {
}

#[derive(Debug, Clone)]
pub(super) struct TcpPacket {
pub(super) struct TcpHeaderWrapper {
header: TcpHeader,
}

impl TcpPacket {
impl TcpHeaderWrapper {
pub fn inner(&self) -> &TcpHeader {
&self.header
}
Expand Down Expand Up @@ -185,9 +185,9 @@ impl TcpPacket {
}
}

impl From<&TcpHeader> for TcpPacket {
impl From<&TcpHeader> for TcpHeaderWrapper {
fn from(header: &TcpHeader) -> Self {
TcpPacket {
TcpHeaderWrapper {
header: header.clone(),
}
}
Expand Down
Loading

0 comments on commit 1eb6941

Please sign in to comment.