Skip to content

Commit

Permalink
supports auto-choosing ports for client and server
Browse files Browse the repository at this point in the history
  • Loading branch information
neevek committed Apr 16, 2023
1 parent 0ca5066 commit f3b862a
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 73 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions src/access_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ impl AccessServer {
}
}

pub async fn bind(&mut self) -> Result<()> {
pub async fn bind(&mut self) -> Result<SocketAddr> {
info!("starting access server, addr: {}", self.addr);
self.tcp_listener = Some(Arc::new(TcpListener::bind(self.addr).await?));
let tcp_listener = TcpListener::bind(self.addr).await?;
let bound_addr = tcp_listener.local_addr().unwrap();
self.tcp_listener = Some(Arc::new(tcp_listener));
info!("started access server, addr: {}", self.addr);

Ok(())
Ok(bound_addr)
}

pub async fn start(&mut self) -> Result<()> {
Expand Down Expand Up @@ -81,8 +83,8 @@ impl AccessServer {
&self.addr
}

pub fn tcp_receiver_ref(&mut self) -> &mut Receiver<Option<TcpStream>> {
self.tcp_receiver.as_mut().unwrap()
pub async fn recv(&mut self) -> Option<TcpStream> {
self.tcp_receiver.as_mut().unwrap().recv().await?
}

pub fn take_tcp_receiver(&mut self) -> Receiver<Option<TcpStream>> {
Expand Down
51 changes: 28 additions & 23 deletions src/bin/rstunc.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use anyhow::{bail, Context, Result};
use clap::Parser;
use log::error;
use rs_utilities::log_and_bail;
use rstun::*;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

fn main() {
let args = RstuncArgs::parse();
rs_utilities::LogHelper::init_logger("rstunc", args.loglevel.as_ref());
if let Some(config) = parse_command_line_args(args) {
if let Ok(config) = parse_command_line_args(args) {
let mut client = Client::new(config);
// client.set_enable_on_info_report(true);
// client.set_on_info_listener(|s| {
Expand All @@ -15,18 +18,27 @@ fn main() {
}
}

fn parse_command_line_args(args: RstuncArgs) -> Option<ClientConfig> {
fn parse_command_line_args(args: RstuncArgs) -> Result<ClientConfig> {
let mut config = ClientConfig::default();
let addrs: Vec<&str> = args.addr_mapping.split('^').collect();
if addrs.len() != 2 {
error!("invalid address mapping: {}", args.addr_mapping);
return None;
let addr_mapping: Vec<&str> = args.addr_mapping.split('^').collect();
if addr_mapping.len() != 2 {
log_and_bail!("invalid address mapping: {}", args.addr_mapping);
}
let mut addrs: Vec<String> = addrs.iter().map(|s| s.to_string()).collect();

for addr in &mut addrs {
if !addr.contains(':') {
*addr = format!("127.0.0.1:{}", addr);
let mut addr_mapping: Vec<String> = addr_mapping.iter().map(|addr| addr.to_string()).collect();
let mut sock_addr_mapping: Vec<SocketAddr> = Vec::with_capacity(addr_mapping.len());

for addr in &mut addr_mapping {
if addr == "ANY" {
sock_addr_mapping.push(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0));
} else {
if !addr.contains(':') {
*addr = format!("127.0.0.1:{}", addr);
}
sock_addr_mapping.push(
addr.parse::<SocketAddr>()
.context(format!("invalid address mapping:[{}]", args.addr_mapping))?,
);
}
}

Expand All @@ -48,29 +60,21 @@ fn parse_command_line_args(args: RstuncArgs) -> Option<ClientConfig> {
TUNNEL_MODE_OUT
};

let local_access_server_addr;
config.login_msg = if args.mode == TUNNEL_MODE_IN {
local_access_server_addr = addrs[1].to_string();
config.local_access_server_addr = Some(sock_addr_mapping[1]);
Some(TunnelMessage::ReqInLogin(LoginInfo {
password: args.password,
access_server_addr: addrs[0].to_string(),
access_server_addr: sock_addr_mapping[0],
}))
} else {
local_access_server_addr = addrs[0].to_string();
config.local_access_server_addr = Some(sock_addr_mapping[0]);
Some(TunnelMessage::ReqOutLogin(LoginInfo {
password: args.password,
access_server_addr: addrs[1].to_string(),
access_server_addr: sock_addr_mapping[1],
}))
};

config.local_access_server_addr = Some(local_access_server_addr.parse().unwrap_or_else(|e| {
panic!(
"invalid local_access_server_addr: {}, {}",
local_access_server_addr, e
)
}));

Some(config)
Ok(config)
}

#[derive(Parser, Debug)]
Expand All @@ -89,6 +93,7 @@ struct RstuncArgs {
password: String,

/// LOCAL and REMOTE mapping in [ip:]port^[ip:]port format, e.g. 8080^0.0.0.0:9090
/// ANY^ANY means
#[clap(short = 'a', long, display_order = 4)]
addr_mapping: String,

Expand Down
14 changes: 10 additions & 4 deletions src/bin/rstund.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ fn main() {
}

async fn run(mut args: RstundArgs) -> Result<()> {
if args.addr.is_empty() {
args.addr = "0.0.0.0:0".to_string();
}

if !args.addr.contains(':') {
args.addr = format!("127.0.0.1:{}", args.addr);
}
Expand Down Expand Up @@ -64,16 +68,18 @@ async fn run(mut args: RstundArgs) -> Result<()> {
config.downstreams = downstreams;
config.max_idle_timeout_ms = args.max_idle_timeout_ms;

let server = Server::new(config);
let mut server = Server::new(config);
server.start().await?;
server.serve().await?;
Ok(())
}

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct RstundArgs {
/// Address ([ip:]port pair) to listen on
#[clap(short = 'l', long, display_order = 1)]
/// Address ([ip:]port pair) to listen on, a random port will be chosen
/// and binding to all network interfaces (0.0.0.0) if empty
#[clap(short = 'a', long, default_value = "", display_order = 1)]
addr: String,

/// Exposed downstreams as the receiving end of the tunnel, e.g. -d [ip:]port,
Expand Down Expand Up @@ -102,6 +108,6 @@ struct RstundArgs {
#[clap(short = 'w', long, default_value = "40000", display_order = 7)]
max_idle_timeout_ms: u64,

#[clap(short = 'L', long, possible_values = &["T", "D", "I", "W", "E"], default_value = "I", display_order = 8)]
#[clap(short = 'l', long, possible_values = &["T", "D", "I", "W", "E"], default_value = "I", display_order = 8)]
loglevel: String,
}
53 changes: 32 additions & 21 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use log::{debug, error, info, warn};
use quinn::{congestion, TransportConfig};
use quinn::{RecvStream, SendStream};
use quinn_proto::{IdleTimeout, VarInt};
use rs_utilities::{dns, log_and_bail, unwrap_or_continue};
use rs_utilities::{dns, log_and_bail};
use rustls::{client::ServerCertVerified, Certificate, RootCertStore, ServerName};
use rustls_platform_verifier::{self, Verifier};
use serde::Serialize;
Expand All @@ -21,7 +21,6 @@ use std::{
use tokio::net::TcpStream;
#[cfg(not(target_os = "windows"))]
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::mpsc::Receiver;
use tokio::time::Duration;
use x509_parser::prelude::{FromDer, X509Certificate};

Expand Down Expand Up @@ -57,6 +56,7 @@ impl Display for ClientState {

pub struct Client {
pub config: ClientConfig,
access_server: Option<AccessServer>,
remote_conn: Option<Arc<RwLock<quinn::Connection>>>,
ctrl_stream: Option<ControlStream>,
is_terminated: Arc<Mutex<bool>>,
Expand All @@ -71,6 +71,7 @@ impl Client {
pub fn new(config: ClientConfig) -> Self {
Client {
config,
access_server: None,
remote_conn: None,
ctrl_stream: None,
is_terminated: Arc::new(Mutex::new(false)),
Expand All @@ -89,39 +90,52 @@ impl Client {
.build()
.unwrap()
.block_on(async {
self.start_access_server()
.await
.map_err(|e| error!("failed to start access server: {}", e))
.unwrap();

self.connect_and_serve()
.await
.unwrap_or_else(|e| error!("connect failed: {}", e));
.unwrap_or_else(|e| error!("failed to connect: {}", e));
});
}

pub async fn connect_and_serve(&mut self) -> Result<()> {
info!(
"connecting, idle_timeout:{}, retry_timeout:{}, threads:{}",
self.config.max_idle_timeout_ms, self.config.wait_before_retry_ms, self.config.threads
);

pub async fn start_access_server(&mut self) -> Result<SocketAddr> {
self.post_tunnel_log("preparing...");
self.set_and_post_tunnel_state(ClientState::Preparing);

// create a local access server for 'out' tunnel
let mut access_server = None;
if self.config.mode == TUNNEL_MODE_OUT {
self.post_tunnel_log(
format!(
"starting access server for [Out] tunneling: {:?}",
"starting access server for [TunnelOut] tunneling: {:?}",
self.config.local_access_server_addr.unwrap()
)
.as_str(),
);

let mut tmp_access_server =
AccessServer::new(self.config.local_access_server_addr.unwrap());
tmp_access_server.bind().await?;
let bound_addr = tmp_access_server.bind().await?;
tmp_access_server.start().await?;
access_server = Some(tmp_access_server);
self.access_server = Some(tmp_access_server);

info!("==========================================================");
info!("[TunnelOut] access server bound to: {}", bound_addr);
info!("==========================================================");
return Ok(bound_addr);
}

bail!("call start_access_server() for TunnelOut mode only")
}

pub async fn connect_and_serve(&mut self) -> Result<()> {
info!(
"connecting, idle_timeout:{}, retry_timeout:{}, threads:{}",
self.config.max_idle_timeout_ms, self.config.wait_before_retry_ms, self.config.threads
);

let mut connect_retry_count = 0;
let connect_max_retry = self.config.connect_max_retry;
let wait_before_retry_ms = self.config.wait_before_retry_ms;
Expand All @@ -132,8 +146,7 @@ impl Client {
connect_retry_count = 0;

if self.config.mode == TUNNEL_MODE_OUT {
self.serve_outgoing(access_server.as_mut().unwrap().tcp_receiver_ref())
.await;
self.serve_outgoing().await;
} else {
self.serve_incoming().await.ok();
}
Expand Down Expand Up @@ -242,7 +255,7 @@ impl Client {
Ok(())
}

async fn serve_outgoing(&mut self, local_conn_receiver: &mut Receiver<Option<TcpStream>>) {
async fn serve_outgoing(&mut self) {
self.post_tunnel_log("start serving in [TunnelOut] mode...");

self.report_traffic_data_in_background();
Expand All @@ -251,9 +264,7 @@ impl Client {
let ref conn = remote_conn.read().unwrap();

// accept local connections and build a tunnel to remote
while let Some(tcp_stream) = local_conn_receiver.recv().await {
let tcp_stream = unwrap_or_continue!(tcp_stream);

while let Some(tcp_stream) = self.access_server.as_mut().unwrap().recv().await {
match conn.open_bi().await {
Ok(quic_stream) => {
debug!(
Expand Down Expand Up @@ -613,8 +624,8 @@ impl rustls::client::ServerCertVerifier for InsecureCertVerifier {
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
warn!("======================================= WARNING ======================================");
warn!("= Connecting to a server without verifying its certificate is DANGEROUS!!! =");
warn!("= Provide the self-signed certificate for verification or connect with a domain name =");
warn!("Connecting to a server without verifying its certificate is DANGEROUS!!!");
warn!("Provide the self-signed certificate for verification or connect with a domain name");
warn!("======================= Be cautious, this is for TEST only!!! ========================");
Ok(ServerCertVerified::assertion())
}
Expand Down
Loading

0 comments on commit f3b862a

Please sign in to comment.