From f829194e17d6fa5cd0d59caf634a731d9e411f25 Mon Sep 17 00:00:00 2001 From: neevek Date: Fri, 24 May 2024 18:13:37 +0800 Subject: [PATCH] Revert "refactor bytes transfering and log output" This reverts commit 3b4ff9478ebc91e8b19472dbc5f23dfdf0abcb18. This commit causes unexpected connection reset, which needs to be resolved to be merged back. --- src/access_server.rs | 2 +- src/bin/rstund.rs | 6 +- src/client.rs | 47 +++++++++--- src/lib.rs | 26 +++++-- src/server.rs | 104 ++++++++++++++++--------- src/tunnel.rs | 175 ++++++++++++------------------------------- 6 files changed, 177 insertions(+), 183 deletions(-) diff --git a/src/access_server.rs b/src/access_server.rs index d220d5b..c0f6cc9 100644 --- a/src/access_server.rs +++ b/src/access_server.rs @@ -72,7 +72,7 @@ impl AccessServer { continue; } - debug!("received conn : {addr}"); + debug!("received new local connection, addr: {addr}"); match tcp_sender .send_timeout( Some(ChannelMessage::Request(socket)), diff --git a/src/bin/rstund.rs b/src/bin/rstund.rs index 166ddf2..73db53b 100644 --- a/src/bin/rstund.rs +++ b/src/bin/rstund.rs @@ -30,7 +30,7 @@ fn main() { run(args) .await .map_err(|e| { - error!("{e}"); + error!("{}", e); }) .ok(); }) @@ -53,13 +53,13 @@ async fn run(mut args: RstundArgs) -> Result<()> { } if !d.contains(':') { - *d = format!("127.0.0.1:{d}"); + *d = format!("127.0.0.1:{}", d); } if let Ok(addr) = d.parse() { upstreams.push(addr); } else { - log_and_bail!("invalid upstreams address: {d}"); + log_and_bail!("invalid upstreams address: {}", d); } } diff --git a/src/client.rs b/src/client.rs index ef4da6c..6b5e0a7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -192,16 +192,22 @@ impl Client { } Err(e) => { - error!("connect failed, err: {e}"); + error!("connect failed, err: {}", e); if connect_max_retry > 0 { connect_retry_count += 1; if connect_retry_count >= connect_max_retry { - info!("quit after having retried for {connect_retry_count} times"); + info!( + "quit after having retried for {} times", + connect_retry_count + ); break; } } - debug!("will wait for {wait_before_retry_ms}ms before retrying..."); + debug!( + "will wait for {}ms before retrying...", + wait_before_retry_ms + ); tokio::time::sleep(Duration::from_millis(wait_before_retry_ms)).await; } } @@ -262,7 +268,7 @@ impl Client { let connection = endpoint.connect(remote_addr, domain.as_str())?.await?; self.set_and_post_tunnel_state(ClientState::Connected); - self.post_tunnel_log(format!("connected to server: {remote_addr:?}").as_str()); + self.post_tunnel_log(format!("connected to server: {:?}", remote_addr).as_str()); let (mut quic_send, mut quic_recv) = connection .open_bi() @@ -301,7 +307,16 @@ impl Client { // accept local connections and build a tunnel to remote while let Some(ChannelMessage::Request(tcp_stream)) = access_server.recv().await { match remote_conn.open_bi().await { - Ok(quic_stream) => Tunnel::new().start(true, tcp_stream, quic_stream), + Ok(quic_stream) => { + debug!( + "[TunnelOut] open stream for conn, {} -> {}", + quic_stream.0.id().index(), + remote_conn.remote_address(), + ); + + let tcp_stream = tcp_stream.into_split(); + Tunnel::new().start(tcp_stream, quic_stream).await; + } Err(e) => { error!("failed to open_bi on remote connection: {e}"); self.post_tunnel_log( @@ -338,7 +353,16 @@ impl Client { let remote_conn = remote_conn.read().await; while let Ok(quic_stream) = remote_conn.accept_bi().await { match TcpStream::connect(self.config.local_access_server_addr.unwrap()).await { - Ok(tcp_stream) => Tunnel::new().start(false, tcp_stream, quic_stream), + Ok(tcp_stream) => { + debug!( + "[TunnelIn] open stream for conn, {} <- {}", + quic_stream.0.id().index(), + remote_conn.remote_address(), + ); + + let tcp_stream = tcp_stream.into_split(); + Tunnel::new().start(tcp_stream, quic_stream).await; + } Err(e) => { error!( "failed to connect to access server: {e}, {}", @@ -455,7 +479,7 @@ impl Client { let cert = certs.first().context("certificate is not in PEM format")?; let mut roots = RootCertStore::empty(); - roots.add(cert).context(format!( + roots.add(&cert).context(format!( "failed to add certificate: {}", self.config.cert_path ))?; @@ -560,7 +584,7 @@ impl Client { return Ok(SocketAddr::new(ip, port)); } - bail!("failed to resolve domain: {domain}"); + bail!("failed to resolve domain: {}", domain); } async fn lookup_server_ip( @@ -583,7 +607,7 @@ impl Client { }; let ip = resolver.await.lookup_first(domain).await?; - info!("resolved {domain} to {ip}"); + info!("resolved {} to {}", domain, ip); Ok(ip) } @@ -592,8 +616,9 @@ impl Client { self.post_tunnel_info(TunnelInfo::new( TunnelInfoType::TunnelLog, Box::new(format!( - "{} {log}", - chrono::Local::now().format(TIME_FORMAT) + "{} {}", + chrono::Local::now().format(TIME_FORMAT), + log )), )); } diff --git a/src/lib.rs b/src/lib.rs index edf0cdf..34959ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,6 +140,11 @@ pub struct ServerConfig { pub dashboard_server_credential: String, } +pub(crate) enum ReadResult { + Succeeded, + Eof, +} + impl ClientConfig { pub fn create( mode: &str, @@ -156,7 +161,7 @@ impl ClientConfig { let addr_mapping_str = addr_mapping; let addr_mapping: Vec<&str> = addr_mapping_str.split('^').collect(); if addr_mapping.len() != 2 { - log_and_bail!("invalid address mapping: {addr_mapping_str}"); + log_and_bail!("invalid address mapping: {}", addr_mapping_str); } let mut addr_mapping: Vec = @@ -168,12 +173,12 @@ impl ClientConfig { sock_addr_mapping.push(None); } else { if !addr.contains(':') { - *addr = format!("127.0.0.1:{addr}"); + *addr = format!("127.0.0.1:{}", addr); } - sock_addr_mapping - .push(Some(addr.parse::().context(format!( - "invalid address mapping:[{addr_mapping_str}]" - ))?)); + sock_addr_mapping.push(Some( + addr.parse::() + .context(format!("invalid address mapping:[{}]", addr_mapping_str))?, + )); } } @@ -201,7 +206,7 @@ impl ClientConfig { access_server_addr: sock_addr_mapping[0], })) } else { - if sock_addr_mapping[0].is_none() { + if sock_addr_mapping[0] == None { log_and_bail!("'ANY' is not allowed as local access server for OUT tunneling"); } config.local_access_server_addr = sock_addr_mapping[0]; @@ -215,6 +220,13 @@ impl ClientConfig { } } +impl ReadResult { + #![allow(dead_code)] + pub fn is_eof(&self) -> bool { + matches!(self, Self::Eof) + } +} + pub fn socket_addr_with_unspecified_ip_port(ipv6: bool) -> SocketAddr { if ipv6 { SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) diff --git a/src/server.rs b/src/server.rs index f3f4402..1e35735 100644 --- a/src/server.rs +++ b/src/server.rs @@ -39,7 +39,10 @@ impl Server { let quinn_server_cfg = Self::load_quinn_server_config(&self.config)?; let endpoint = quinn::Endpoint::server(quinn_server_cfg, addr).map_err(|e| { - error!("failed to bind tunnel server on address: {addr}, err: {e}"); + error!( + "failed to bind tunnel server on address: {}, error: {}", + addr, e + ); e })?; @@ -116,13 +119,14 @@ impl Server { match tun_type { TunnelType::Out((client_conn, addr)) => { info!( - "start tunnel streaming in TunnelOut mode, {} ↔ {addr}", + "start tunnel streaming in OUT mode, {} -> {}", client_conn.remote_address(), + addr ); this.process_out_connection(client_conn, addr) .await - .map_err(|e| error!("process_out_connection failed: {e}")) + .map_err(|e| error!("process_out_connection failed: {}", e)) .ok(); } @@ -135,7 +139,7 @@ impl Server { this.process_in_connection(client_conn, access_server, ctrl_stream) .await - .map_err(|e| error!("process_in_connection failed: {e}")) + .map_err(|e| error!("process_in_connection failed: {}", e)) .ok(); } } @@ -155,17 +159,21 @@ impl Server { ) -> Result { let remote_addr = &client_conn.remote_address(); - info!("received connection, authenticating... addr:{remote_addr}"); - let (mut quic_send, mut quic_recv) = client_conn - .accept_bi() - .await - .context(format!("login request not received in time: {remote_addr}"))?; + info!( + "received connection, authenticating... addr:{}", + remote_addr + ); - info!("received bi_stream request: {remote_addr}"); + let (mut quic_send, mut quic_recv) = client_conn.accept_bi().await.context(format!( + "login request not received in time, addr: {}", + remote_addr + ))?; + + info!("received bi_stream request, addr: {}", remote_addr); let tunnel_type; match TunnelMessage::recv(&mut quic_recv).await? { TunnelMessage::ReqOutLogin(login_info) => { - info!("received OutLogin request: {remote_addr}"); + info!("received OutLogin request, addr: {}", remote_addr); Self::check_password(self.config.password.as_str(), login_info.password.as_str())?; @@ -184,7 +192,10 @@ impl Server { } }; if !is_local { - log_and_bail!("only local IPs are allowed for upstream: {access_server_addr}"); + log_and_bail!( + "only local IPs are allowed for upstream, addr: {}", + access_server_addr + ); } let upstream_addr = if access_server_addr.port() == 0 { @@ -193,24 +204,25 @@ impl Server { } let addr = upstreams.first().unwrap(); info!( - "will bind incoming TunnelIn request({}) to default address({addr})", - client_conn.remote_address() + "will bind incoming TunnelIn request({}) to default address({})", + client_conn.remote_address(), + addr ); addr } else { if !upstreams.is_empty() && !upstreams.contains(&access_server_addr) { - log_and_bail!("upstream address not set: {access_server_addr}"); + log_and_bail!("upstream address not set: {}", access_server_addr); } &access_server_addr }; TunnelMessage::send(&mut quic_send, &TunnelMessage::RespSuccess).await?; tunnel_type = TunnelType::Out((client_conn, *upstream_addr)); - info!("sent response for OutLogin request: {remote_addr}"); + info!("sent response for OutLogin request, addr: {}", remote_addr); } TunnelMessage::ReqInLogin(login_info) => { - info!("received InLogin request: {remote_addr}"); + info!("received InLogin request, addr: {}", remote_addr); Self::check_password(self.config.password.as_str(), login_info.password.as_str())?; let access_server_addr = match login_info.access_server_addr { @@ -219,13 +231,15 @@ impl Server { }; if access_server_addr.port() == 0 { log_and_bail!( - "explicit access_server_addr for TunnelIn mode tunelling is required: {access_server_addr:?}"); + "explicit access_server_addr for TunnelIn mode tunelling is required, addr: {:?}", access_server_addr + ); } if !access_server_addr.ip().is_unspecified() && !access_server_addr.ip().is_loopback() { log_and_bail!( - "only loopback or unspecified IP is allowed for TunnelIn mode tunelling: {access_server_addr:?}"); + "only loopback or unspecified IP is allowed for TunnelIn mode tunelling, addr: {:?}", access_server_addr + ); } let upstream_addr = access_server_addr; @@ -270,7 +284,7 @@ impl Server { guarded_access_server_ports.push(upstream_addr.port()); - info!("sent response for InLogin request: {remote_addr}"); + info!("sent response for InLogin request, addr: {}", remote_addr); } _ => { @@ -278,7 +292,7 @@ impl Server { } } - info!("connection authenticated! addr: {remote_addr}"); + info!("connection authenticated! addr: {}", remote_addr); Ok(tunnel_type) } @@ -293,20 +307,36 @@ impl Server { loop { match client_conn.accept_bi().await { Err(quinn::ConnectionError::TimedOut { .. }) => { - info!("connection timeout: {remote_addr}"); + info!("connection timeout, addr: {}", remote_addr); return Ok(()); } Err(quinn::ConnectionError::ApplicationClosed { .. }) => { - debug!("connection closed: {remote_addr}"); + debug!("connection closed, addr: {}", remote_addr); return Ok(()); } Err(e) => { - log_and_bail!("failed to open accpet_bi: {remote_addr}, err: {e}"); + log_and_bail!( + "failed to open bi_streams, addr: {}, err: {}", + remote_addr, + e + ); } Ok(quic_stream) => tokio::spawn(async move { match TcpStream::connect(&upstream_addr).await { - Ok(tcp_stream) => Tunnel::new().start(true, tcp_stream, quic_stream), - Err(e) => error!("failed to connect to {upstream_addr}, err: {e}"), + Ok(tcp_stream) => { + debug!( + "[Out] open stream for conn, {} -> {}", + quic_stream.0.id().index(), + upstream_addr, + ); + + let tcp_stream = tcp_stream.into_split(); + Tunnel::new().start(tcp_stream, quic_stream).await; + } + + Err(e) => { + error!("failed to connect to {}, err: {}", upstream_addr, e); + } } }), }; @@ -321,17 +351,23 @@ impl Server { ) -> Result<()> { let tcp_sender = access_server.clone_tcp_sender(); tokio::spawn(async move { - TunnelMessage::recv(&mut ctrl_stream.quic_recv).await.ok(); - // send None to signify exit - tcp_sender.send(None).await.ok(); - Ok::<(), anyhow::Error>(()) + match TunnelMessage::recv(&mut ctrl_stream.quic_recv).await { + _ => { + // send None to signify exit + tcp_sender.send(None).await.ok(); + Ok::<(), anyhow::Error>(()) + } + } }); access_server.set_drop_conn(false); let mut tcp_receiver = access_server.take_tcp_receiver(); while let Some(Some(ChannelMessage::Request(tcp_stream))) = tcp_receiver.recv().await { match client_conn.open_bi().await { - Ok(quic_stream) => Tunnel::new().start(false, tcp_stream, quic_stream), + Ok(quic_stream) => { + let tcp_stream = tcp_stream.into_split(); + Tunnel::new().start(tcp_stream, quic_stream).await; + } _ => { log_and_bail!("failed to open bi_streams to client, quit"); } @@ -349,7 +385,7 @@ impl Server { access_server.shutdown(tcp_receiver).await.ok(); - info!("access server quit: {addr}"); + info!("access server quit: {}", addr); Ok(()) } @@ -370,9 +406,9 @@ impl Server { (vec![Certificate(cert)], PrivateKey(key)) } else { let certs = pem_util::load_certificates_from_pem(cert_path) - .context(format!("failed to read cert file: {cert_path}"))?; + .context(format!("failed to read cert file: {}", cert_path))?; let key = pem_util::load_private_key_from_pem(key_path) - .context(format!("failed to read key file: {key_path}"))?; + .context(format!("failed to read key file: {}", key_path))?; (certs, key) }; diff --git a/src/tunnel.rs b/src/tunnel.rs index 5944c7c..267a8d3 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -1,156 +1,77 @@ -use std::time::Duration; - +use crate::ReadResult; use crate::BUFFER_POOL; use anyhow::Result; -use log::debug; +use log::info; use quinn::{RecvStream, SendStream}; -use tokio::io::AsyncRead; -use tokio::io::AsyncWrite; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio::time::error::Elapsed; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; pub struct Tunnel {} -#[derive(Debug, PartialEq, Eq)] -pub enum TransferError { - Timeout, - InternalError, -} - impl Tunnel { pub fn new() -> Self { Tunnel {} } - pub fn start( + pub async fn start( &self, - tunnel_out: bool, - tcp_stream: TcpStream, + tcp_stream: (OwnedReadHalf, OwnedWriteHalf), quic_stream: (SendStream, RecvStream), ) { - tokio::spawn(async move { - Self::run(tunnel_out, tcp_stream, quic_stream).await.ok(); - }); - } - - async fn run( - tunnel_out: bool, - mut tcp_stream: TcpStream, - quic_stream: (SendStream, RecvStream), - ) -> Result<(), TransferError> { - let (mut tcp_read, mut tcp_write) = tcp_stream.split(); + let (mut tcp_read, mut tcp_write) = tcp_stream; let (mut quic_send, mut quic_recv) = quic_stream; - let tag = if tunnel_out { "OUT" } else { "IN" }; - let index = quic_send.id().index(); - let in_addr = tcp_read - .peer_addr() - .map_err(|_| TransferError::InternalError)?; - - debug!("[{tag}] tunnel start : {index:<3} ↔ {in_addr:<20}"); - - const BUFFER_SIZE: usize = 8192; - let mut inbound_buffer = BUFFER_POOL.alloc_and_fill(BUFFER_SIZE); - let mut outbound_buffer = BUFFER_POOL.alloc_and_fill(BUFFER_SIZE); + info!( + "built tunnel for local conn, {} <=> {}", + quic_send.id().index(), + tcp_read.peer_addr().unwrap(), + ); - let mut tx_bytes = 0u64; - let mut rx_bytes = 0u64; - let mut tcp_stream_eos = false; - let mut quic_stream_eos = false; - let mut loop_count = 0; - - loop { - loop_count += 1; - let result = if !tcp_stream_eos && !quic_stream_eos { - tokio::select! { - result = Self::transfer_data_with_timeout( - &mut tcp_read, - &mut quic_send, - &mut inbound_buffer, - &mut tx_bytes, - &mut tcp_stream_eos) => result, - result = Self::transfer_data_with_timeout( - &mut quic_recv, - &mut tcp_write, - &mut outbound_buffer, - &mut rx_bytes, - &mut quic_stream_eos) => result, + tokio::spawn(async move { + loop { + let result = Self::tcp_to_quic(&mut tcp_read, &mut quic_send).await; + if let Ok(ReadResult::Eof) | Err(_) = result { + break; } - } else if !quic_stream_eos { - Self::transfer_data_with_timeout( - &mut quic_recv, - &mut tcp_write, - &mut outbound_buffer, - &mut rx_bytes, - &mut quic_stream_eos, - ) - .await - } else { - Self::transfer_data_with_timeout( - &mut tcp_read, - &mut quic_send, - &mut inbound_buffer, - &mut tx_bytes, - &mut tcp_stream_eos, - ) - .await - }; + } + }); - match result { - Ok(0) => { - if tcp_stream_eos && quic_stream_eos { - break; - } - } - Err(TransferError::Timeout) => { - debug!("[{tag}] tunnel timeout: {index:<3} ↔ {in_addr:<20} | ⟳ {loop_count:<8}| ↑ {tx_bytes:<10} ↓ {rx_bytes:<10}"); + tokio::spawn(async move { + loop { + let result = Self::quic_to_tcp(&mut tcp_write, &mut quic_recv).await; + if let Ok(ReadResult::Eof) | Err(_) = result { break; } - Err(_) => break, - Ok(_) => {} } - } - - debug!("[{tag}] tunnel end : {index:<3} ↔ {in_addr:<20} | ⟳ {loop_count:<8}| ↑ {tx_bytes:<10} ↓ {rx_bytes:<10}"); + }); + } - Ok(()) + async fn tcp_to_quic( + tcp_read: &mut OwnedReadHalf, + quic_send: &mut SendStream, //local_read: &mut OwnedReadHalf, + ) -> Result { + let mut buffer = BUFFER_POOL.alloc_and_fill(8192); + let len_read = tcp_read.read(&mut buffer[..]).await?; + if len_read > 0 { + quic_send.write_all(&buffer[..len_read]).await?; + Ok(ReadResult::Succeeded) + } else { + quic_send.finish().await?; + Ok(ReadResult::Eof) + } } - async fn transfer_data_with_timeout( - reader: &mut R, - writer: &mut W, - buffer: &mut [u8], - out_bytes: &mut u64, - eos_flag: &mut bool, - ) -> Result - where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, - { - match tokio::time::timeout(Duration::from_secs(15), reader.read(buffer)) - .await - .map_err(|_: Elapsed| TransferError::Timeout)? - { - Ok(0) => { - if !*eos_flag { - *eos_flag = true; - writer - .shutdown() - .await - .map_err(|_| TransferError::InternalError)?; - } - Ok(0) - } - Ok(n) => { - *out_bytes += n as u64; - writer - .write_all(&buffer[..n]) - .await - .map_err(|_| TransferError::InternalError)?; - Ok(n) - } - Err(_) => Err(TransferError::InternalError), // Connection mostly reset by peer + async fn quic_to_tcp( + tcp_write: &mut OwnedWriteHalf, + quic_recv: &mut RecvStream, + ) -> Result { + let mut buffer = BUFFER_POOL.alloc_and_fill(8192); + let result = quic_recv.read(&mut buffer[..]).await?; + if let Some(len_read) = result { + tcp_write.write_all(&buffer[..len_read]).await?; + Ok(ReadResult::Succeeded) + } else { + Ok(ReadResult::Eof) } } }