diff --git a/Cargo.lock b/Cargo.lock index 08b1463..74abbc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -941,9 +941,9 @@ dependencies = [ [[package]] name = "rs-utilities" -version = "0.3.2" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b64a094d164941b2bc82d5434b3d539d11e099b54d4a3193af476d8795476c2f" +checksum = "c5b75e86f691f63827c3de486f91858efc4f68330d13e6a2d8bded7d3d1519d2" dependencies = [ "android_logger", "anyhow", diff --git a/src/access_server.rs b/src/access_server.rs index fc7271a..3f06ee6 100644 --- a/src/access_server.rs +++ b/src/access_server.rs @@ -26,12 +26,14 @@ impl AccessServer { } } - pub async fn bind(&mut self) -> Result<()> { + pub async fn bind(&mut self) -> Result { info!("starting access server, addr: {}", self.addr); - self.tcp_listener = Some(Arc::new(TcpListener::bind(self.addr).await?)); + let tcp_listener = TcpListener::bind(self.addr).await?; + let bound_addr = tcp_listener.local_addr().unwrap(); + self.tcp_listener = Some(Arc::new(tcp_listener)); info!("started access server, addr: {}", self.addr); - Ok(()) + Ok(bound_addr) } pub async fn start(&mut self) -> Result<()> { @@ -81,8 +83,8 @@ impl AccessServer { &self.addr } - pub fn tcp_receiver_ref(&mut self) -> &mut Receiver> { - self.tcp_receiver.as_mut().unwrap() + pub async fn recv(&mut self) -> Option { + self.tcp_receiver.as_mut().unwrap().recv().await? } pub fn take_tcp_receiver(&mut self) -> Receiver> { diff --git a/src/bin/rstunc.rs b/src/bin/rstunc.rs index 30af2f8..890444c 100644 --- a/src/bin/rstunc.rs +++ b/src/bin/rstunc.rs @@ -1,11 +1,14 @@ +use anyhow::{bail, Context, Result}; use clap::Parser; use log::error; +use rs_utilities::log_and_bail; use rstun::*; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; fn main() { let args = RstuncArgs::parse(); rs_utilities::LogHelper::init_logger("rstunc", args.loglevel.as_ref()); - if let Some(config) = parse_command_line_args(args) { + if let Ok(config) = parse_command_line_args(args) { let mut client = Client::new(config); // client.set_enable_on_info_report(true); // client.set_on_info_listener(|s| { @@ -15,18 +18,27 @@ fn main() { } } -fn parse_command_line_args(args: RstuncArgs) -> Option { +fn parse_command_line_args(args: RstuncArgs) -> Result { let mut config = ClientConfig::default(); - let addrs: Vec<&str> = args.addr_mapping.split('^').collect(); - if addrs.len() != 2 { - error!("invalid address mapping: {}", args.addr_mapping); - return None; + let addr_mapping: Vec<&str> = args.addr_mapping.split('^').collect(); + if addr_mapping.len() != 2 { + log_and_bail!("invalid address mapping: {}", args.addr_mapping); } - let mut addrs: Vec = addrs.iter().map(|s| s.to_string()).collect(); - for addr in &mut addrs { - if !addr.contains(':') { - *addr = format!("127.0.0.1:{}", addr); + let mut addr_mapping: Vec = addr_mapping.iter().map(|addr| addr.to_string()).collect(); + let mut sock_addr_mapping: Vec = Vec::with_capacity(addr_mapping.len()); + + for addr in &mut addr_mapping { + if addr == "ANY" { + sock_addr_mapping.push(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)); + } else { + if !addr.contains(':') { + *addr = format!("127.0.0.1:{}", addr); + } + sock_addr_mapping.push( + addr.parse::() + .context(format!("invalid address mapping:[{}]", args.addr_mapping))?, + ); } } @@ -48,29 +60,21 @@ fn parse_command_line_args(args: RstuncArgs) -> Option { TUNNEL_MODE_OUT }; - let local_access_server_addr; config.login_msg = if args.mode == TUNNEL_MODE_IN { - local_access_server_addr = addrs[1].to_string(); + config.local_access_server_addr = Some(sock_addr_mapping[1]); Some(TunnelMessage::ReqInLogin(LoginInfo { password: args.password, - access_server_addr: addrs[0].to_string(), + access_server_addr: sock_addr_mapping[0], })) } else { - local_access_server_addr = addrs[0].to_string(); + config.local_access_server_addr = Some(sock_addr_mapping[0]); Some(TunnelMessage::ReqOutLogin(LoginInfo { password: args.password, - access_server_addr: addrs[1].to_string(), + access_server_addr: sock_addr_mapping[1], })) }; - config.local_access_server_addr = Some(local_access_server_addr.parse().unwrap_or_else(|e| { - panic!( - "invalid local_access_server_addr: {}, {}", - local_access_server_addr, e - ) - })); - - Some(config) + Ok(config) } #[derive(Parser, Debug)] @@ -89,6 +93,7 @@ struct RstuncArgs { password: String, /// LOCAL and REMOTE mapping in [ip:]port^[ip:]port format, e.g. 8080^0.0.0.0:9090 + /// ANY^ANY means #[clap(short = 'a', long, display_order = 4)] addr_mapping: String, diff --git a/src/bin/rstund.rs b/src/bin/rstund.rs index 1dbbbad..7ce8d26 100644 --- a/src/bin/rstund.rs +++ b/src/bin/rstund.rs @@ -34,6 +34,10 @@ fn main() { } async fn run(mut args: RstundArgs) -> Result<()> { + if args.addr.is_empty() { + args.addr = "0.0.0.0:0".to_string(); + } + if !args.addr.contains(':') { args.addr = format!("127.0.0.1:{}", args.addr); } @@ -64,16 +68,18 @@ async fn run(mut args: RstundArgs) -> Result<()> { config.downstreams = downstreams; config.max_idle_timeout_ms = args.max_idle_timeout_ms; - let server = Server::new(config); + let mut server = Server::new(config); server.start().await?; + server.serve().await?; Ok(()) } #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct RstundArgs { - /// Address ([ip:]port pair) to listen on - #[clap(short = 'l', long, display_order = 1)] + /// Address ([ip:]port pair) to listen on, a random port will be chosen + /// and binding to all network interfaces (0.0.0.0) if empty + #[clap(short = 'a', long, default_value = "", display_order = 1)] addr: String, /// Exposed downstreams as the receiving end of the tunnel, e.g. -d [ip:]port, @@ -102,6 +108,6 @@ struct RstundArgs { #[clap(short = 'w', long, default_value = "40000", display_order = 7)] max_idle_timeout_ms: u64, - #[clap(short = 'L', long, possible_values = &["T", "D", "I", "W", "E"], default_value = "I", display_order = 8)] + #[clap(short = 'l', long, possible_values = &["T", "D", "I", "W", "E"], default_value = "I", display_order = 8)] loglevel: String, } diff --git a/src/client.rs b/src/client.rs index 4b05a99..9b475b1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,7 +8,7 @@ use log::{debug, error, info, warn}; use quinn::{congestion, TransportConfig}; use quinn::{RecvStream, SendStream}; use quinn_proto::{IdleTimeout, VarInt}; -use rs_utilities::{dns, log_and_bail, unwrap_or_continue}; +use rs_utilities::{dns, log_and_bail}; use rustls::{client::ServerCertVerified, Certificate, RootCertStore, ServerName}; use rustls_platform_verifier::{self, Verifier}; use serde::Serialize; @@ -21,7 +21,6 @@ use std::{ use tokio::net::TcpStream; #[cfg(not(target_os = "windows"))] use tokio::signal::unix::{signal, SignalKind}; -use tokio::sync::mpsc::Receiver; use tokio::time::Duration; use x509_parser::prelude::{FromDer, X509Certificate}; @@ -57,6 +56,7 @@ impl Display for ClientState { pub struct Client { pub config: ClientConfig, + access_server: Option, remote_conn: Option>>, ctrl_stream: Option, is_terminated: Arc>, @@ -71,6 +71,7 @@ impl Client { pub fn new(config: ClientConfig) -> Self { Client { config, + access_server: None, remote_conn: None, ctrl_stream: None, is_terminated: Arc::new(Mutex::new(false)), @@ -89,27 +90,26 @@ impl Client { .build() .unwrap() .block_on(async { + self.start_access_server() + .await + .map_err(|e| error!("failed to start access server: {}", e)) + .unwrap(); + self.connect_and_serve() .await - .unwrap_or_else(|e| error!("connect failed: {}", e)); + .unwrap_or_else(|e| error!("failed to connect: {}", e)); }); } - pub async fn connect_and_serve(&mut self) -> Result<()> { - info!( - "connecting, idle_timeout:{}, retry_timeout:{}, threads:{}", - self.config.max_idle_timeout_ms, self.config.wait_before_retry_ms, self.config.threads - ); - + pub async fn start_access_server(&mut self) -> Result { self.post_tunnel_log("preparing..."); self.set_and_post_tunnel_state(ClientState::Preparing); // create a local access server for 'out' tunnel - let mut access_server = None; if self.config.mode == TUNNEL_MODE_OUT { self.post_tunnel_log( format!( - "starting access server for [Out] tunneling: {:?}", + "starting access server for [TunnelOut] tunneling: {:?}", self.config.local_access_server_addr.unwrap() ) .as_str(), @@ -117,11 +117,25 @@ impl Client { let mut tmp_access_server = AccessServer::new(self.config.local_access_server_addr.unwrap()); - tmp_access_server.bind().await?; + let bound_addr = tmp_access_server.bind().await?; tmp_access_server.start().await?; - access_server = Some(tmp_access_server); + self.access_server = Some(tmp_access_server); + + info!("=========================================================="); + info!("[TunnelOut] access server bound to: {}", bound_addr); + info!("=========================================================="); + return Ok(bound_addr); } + bail!("call start_access_server() for TunnelOut mode only") + } + + pub async fn connect_and_serve(&mut self) -> Result<()> { + info!( + "connecting, idle_timeout:{}, retry_timeout:{}, threads:{}", + self.config.max_idle_timeout_ms, self.config.wait_before_retry_ms, self.config.threads + ); + let mut connect_retry_count = 0; let connect_max_retry = self.config.connect_max_retry; let wait_before_retry_ms = self.config.wait_before_retry_ms; @@ -132,8 +146,7 @@ impl Client { connect_retry_count = 0; if self.config.mode == TUNNEL_MODE_OUT { - self.serve_outgoing(access_server.as_mut().unwrap().tcp_receiver_ref()) - .await; + self.serve_outgoing().await; } else { self.serve_incoming().await.ok(); } @@ -242,7 +255,7 @@ impl Client { Ok(()) } - async fn serve_outgoing(&mut self, local_conn_receiver: &mut Receiver>) { + async fn serve_outgoing(&mut self) { self.post_tunnel_log("start serving in [TunnelOut] mode..."); self.report_traffic_data_in_background(); @@ -251,9 +264,7 @@ impl Client { let ref conn = remote_conn.read().unwrap(); // accept local connections and build a tunnel to remote - while let Some(tcp_stream) = local_conn_receiver.recv().await { - let tcp_stream = unwrap_or_continue!(tcp_stream); - + while let Some(tcp_stream) = self.access_server.as_mut().unwrap().recv().await { match conn.open_bi().await { Ok(quic_stream) => { debug!( @@ -613,8 +624,8 @@ impl rustls::client::ServerCertVerifier for InsecureCertVerifier { _now: SystemTime, ) -> Result { warn!("======================================= WARNING ======================================"); - warn!("= Connecting to a server without verifying its certificate is DANGEROUS!!! ="); - warn!("= Provide the self-signed certificate for verification or connect with a domain name ="); + warn!("Connecting to a server without verifying its certificate is DANGEROUS!!!"); + warn!("Provide the self-signed certificate for verification or connect with a domain name"); warn!("======================= Be cautious, this is for TEST only!!! ========================"); Ok(ServerCertVerified::assertion()) } diff --git a/src/server.rs b/src/server.rs index 0ec8191..54aa307 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,11 +1,11 @@ use crate::{AccessServer, ControlStream, ServerConfig, Tunnel, TunnelMessage, TunnelType}; use anyhow::{bail, Context, Result}; use log::{debug, error, info, warn}; -use quinn::{congestion, TransportConfig}; +use quinn::{congestion, Endpoint, TransportConfig}; use quinn_proto::{IdleTimeout, VarInt}; use rs_utilities::log_and_bail; use rustls::{Certificate, PrivateKey}; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use tokio::net::TcpStream; use tokio::sync::Mutex; @@ -15,6 +15,7 @@ use tokio::time::Duration; pub struct Server { config: ServerConfig, access_server_ports: Mutex>, + endpoint: Option, } impl Server { @@ -22,10 +23,11 @@ impl Server { Arc::new(Server { config, access_server_ports: Mutex::new(Vec::new()), + endpoint: None, }) } - pub async fn start(self: &Arc) -> Result<()> { + pub async fn start(mut self: &mut Arc) -> Result { let config = &self.config; let (cert, key) = Server::read_cert_and_key(config.cert_path.as_str(), config.key_path.as_str()) @@ -66,6 +68,17 @@ impl Server { config.max_idle_timeout_ms ); + Arc::get_mut(&mut self).unwrap().endpoint = Some(endpoint); + + Ok(addr) + } + + pub async fn serve(self: &Arc) -> Result<()> { + let endpoint = self + .endpoint + .as_ref() + .context("make sure start call succeeded!")?; + while let Some(client_conn) = endpoint.accept().await { let mut this = self.clone(); tokio::spawn(async move { @@ -132,19 +145,45 @@ impl Server { info!("received OutLogin request, addr: {}", remote_addr); Self::check_password(self.config.password.as_str(), login_info.password.as_str())?; - let downstream_addr = login_info.access_server_addr.parse().context(format!( - "invalid access server address: {}", - login_info.access_server_addr - ))?; - if !self.config.downstreams.is_empty() - && !self.config.downstreams.contains(&downstream_addr) - { - log_and_bail!("invalid addr: {}", downstream_addr); + let downstreams = &self.config.downstreams; + let access_server_addr = &login_info.access_server_addr; + + let is_local = match access_server_addr.ip() { + IpAddr::V4(ipv4) => { + ipv4.is_private() || ipv4.is_loopback() || ipv4.is_unspecified() + } + IpAddr::V6(ipv6) => { + ipv6.is_loopback() || ipv6.is_loopback() || ipv6.is_unspecified() + } + }; + if !is_local { + log_and_bail!( + "only local IPs are allowed for downstream, addr: {}", + access_server_addr + ); } + let downstream_addr = if access_server_addr.port() == 0 { + if downstreams.is_empty() { + log_and_bail!("explicit downstream address must be specified because there's no default set for the server"); + } + let addr = downstreams.first().unwrap(); + info!( + "will bind incoming TunnelIn request({}) to default address({})", + client_conn.remote_address(), + addr + ); + addr + } else { + if !downstreams.is_empty() && !downstreams.contains(access_server_addr) { + log_and_bail!("downstream address not set: {}", access_server_addr); + } + access_server_addr + }; + TunnelMessage::send(&mut quic_send, &TunnelMessage::RespSuccess).await?; - tunnel_type = TunnelType::Out((client_conn, downstream_addr)); + tunnel_type = TunnelType::Out((client_conn, *downstream_addr)); info!("sent response for OutLogin request, addr: {}", remote_addr); } @@ -152,9 +191,19 @@ impl Server { info!("received InLogin request, addr: {}", remote_addr); Self::check_password(self.config.password.as_str(), login_info.password.as_str())?; - let upstream_addr: SocketAddr = login_info.access_server_addr.parse().context( - format!("invalid address: {}", login_info.access_server_addr), - )?; + if login_info.access_server_addr.port() == 0 { + log_and_bail!( + "explicit access_server_addr for TunnelIn mode tunelling is required, addr: {:?}", login_info.access_server_addr + ); + } + if !login_info.access_server_addr.ip().is_unspecified() + && !login_info.access_server_addr.ip().is_loopback() + { + log_and_bail!( + "only loopback or unspecified IP is allowed for TunnelIn mode tunelling, addr: {:?}", login_info.access_server_addr + ); + } + let upstream_addr = login_info.access_server_addr; let mut guarded_access_server_ports = self.access_server_ports.lock().await; if guarded_access_server_ports.contains(&upstream_addr.port()) { @@ -305,8 +354,8 @@ impl Server { fn read_cert_and_key(cert_path: &str, key_path: &str) -> Result<(Certificate, PrivateKey)> { let (cert, key) = if cert_path.is_empty() { warn!("============================= WARNING =============================="); - warn!("= No valid certificate path is provided, a self-signed certificate ="); - warn!("= for the domain \"localhost\" is generated. ="); + warn!("No valid certificate path is provided, a self-signed certificate"); + warn!("for the domain \"localhost\" is generated."); warn!("============== Be cautious, this is for TEST only!!! ==============="); let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])?; let key = cert.serialize_private_key_der(); diff --git a/src/tunnel_message.rs b/src/tunnel_message.rs index 133e8e9..4a533fe 100644 --- a/src/tunnel_message.rs +++ b/src/tunnel_message.rs @@ -1,3 +1,5 @@ +use std::net::SocketAddr; + use anyhow::Result; use anyhow::{bail, Context}; use enum_as_inner::EnumAsInner; @@ -19,7 +21,7 @@ pub enum TunnelMessage { #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct LoginInfo { pub password: String, - pub access_server_addr: String, // ip:port tuple + pub access_server_addr: SocketAddr, } impl TunnelMessage {