Skip to content

Commit

Permalink
refactor tcp server
Browse files Browse the repository at this point in the history
  • Loading branch information
neevek committed Sep 15, 2024
1 parent ed4e99e commit 212921f
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 154 deletions.
58 changes: 27 additions & 31 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
access_server::ChannelMessage,
pem_util, socket_addr_with_unspecified_ip_port,
tcp::tcp_server::ChannelMessage,
tunnel_info_bridge::{TunnelInfo, TunnelInfoBridge, TunnelInfoType, TunnelTraffic},
AccessServer, ClientConfig, ControlStream, SelectedCipherSuite, Tunnel, TunnelMessage,
ClientConfig, ControlStream, SelectedCipherSuite, TcpServer, Tunnel, TunnelMessage,
TUNNEL_MODE_OUT,
};
use anyhow::{bail, Context, Result};
Expand Down Expand Up @@ -66,7 +66,7 @@ impl Display for ClientState {
struct ThreadSafeState {
remote_conn: Option<Arc<tokio::sync::RwLock<quinn::Connection>>>,
ctrl_stream: Option<ControlStream>,
access_server: Option<AccessServer>,
tcp_server: Option<TcpServer>,
client_state: ClientState,
channel_message_sender: Option<Sender<Option<ChannelMessage>>>,
total_traffic_data: TunnelTraffic,
Expand All @@ -79,7 +79,7 @@ impl ThreadSafeState {
fn new() -> Arc<Mutex<Self>> {
Arc::new(Mutex::new(Self {
ctrl_stream: None,
access_server: None,
tcp_server: None,
remote_conn: None,
client_state: ClientState::Idle,
channel_message_sender: None,
Expand Down Expand Up @@ -125,32 +125,28 @@ impl Client {
.block_on(async { self.connect_and_serve().await });
}

pub async fn start_access_server(self: &Arc<Self>) -> Result<SocketAddr> {
pub async fn start_tcp_server(self: &Arc<Self>) -> Result<SocketAddr> {
self.post_tunnel_log("preparing...");
self.set_and_post_tunnel_state(ClientState::Preparing);

// create a local access server for 'out' tunnel
// create a local tcp server for 'out' tunnel
if self.config.mode == TUNNEL_MODE_OUT {
let mut access_server =
AccessServer::new(self.config.local_access_server_addr.unwrap());
let bound_addr = access_server.bind().await?;

access_server.start().await?;
let tcp_server =
TcpServer::bind_and_start(self.config.local_tcp_server_addr.unwrap()).await?;
let addr = tcp_server.addr();

info!("==========================================================");
info!("[TunnelOut] access server bound to: {bound_addr}");
info!("[TunnelOut] tcp server bound to: {addr}");
info!("==========================================================");

self.post_tunnel_log(
format!("Tunnel access server for [TunnelOut] bound to: {bound_addr}").as_str(),
);
self.post_tunnel_log(format!("tcp server for [TunnelOut] bound to: {addr}").as_str());

inner_state!(self, channel_message_sender) = Some(access_server.clone_tcp_sender());
inner_state!(self, access_server) = Some(access_server);
return Ok(bound_addr);
inner_state!(self, channel_message_sender) = Some(tcp_server.clone_tcp_sender());
inner_state!(self, tcp_server) = Some(tcp_server);
return Ok(addr);
}

bail!("call start_access_server() for TunnelOut mode only")
bail!("call start_tcp_server() for TunnelOut mode only")
}

pub fn get_config(self: &Arc<Self>) -> ClientConfig {
Expand All @@ -165,7 +161,7 @@ impl Client {
});
}
None => {
log_and_bail!("access server not started");
log_and_bail!("tcp server not started");
}
};

Expand Down Expand Up @@ -298,16 +294,16 @@ impl Client {
async fn serve_outgoing(self: &Arc<Self>) -> Result<()> {
self.post_tunnel_log("start serving in [TunnelOut] mode...");
self.report_traffic_data_in_background().await;
if inner_state!(self, access_server).is_none() {
self.start_access_server().await?;
if inner_state!(self, tcp_server).is_none() {
self.start_tcp_server().await?;
}
let mut access_server = inner_state!(self, access_server).take().unwrap();
let mut tcp_server = inner_state!(self, tcp_server).take().unwrap();
let remote_conn = inner_state!(self, remote_conn).clone().unwrap();
let remote_conn = remote_conn.read().await;
access_server.set_drop_conn(false);
tcp_server.set_active(true);

// accept local connections and build a tunnel to remote
while let Some(ChannelMessage::Request(tcp_stream)) = access_server.recv().await {
while let Some(ChannelMessage::Request(tcp_stream)) = tcp_server.recv().await {
match remote_conn.open_bi().await {
Ok(quic_stream) => Tunnel::new().start(true, tcp_stream, quic_stream),
Err(e) => {
Expand All @@ -320,9 +316,9 @@ impl Client {
}
}

// the access server will be reused when tunnel reconnects
access_server.set_drop_conn(true);
inner_state!(self, access_server) = Some(access_server);
// the tcp server will be reused when tunnel reconnects
tcp_server.set_active(false);
inner_state!(self, tcp_server) = Some(tcp_server);

let stats = remote_conn.stats();
let data = &mut inner_state!(self, total_traffic_data);
Expand All @@ -345,12 +341,12 @@ impl Client {
let remote_conn = inner_state!(self, remote_conn).clone().unwrap();
let remote_conn = remote_conn.read().await;
while let Ok(quic_stream) = remote_conn.accept_bi().await {
match TcpStream::connect(self.config.local_access_server_addr.unwrap()).await {
match TcpStream::connect(self.config.local_tcp_server_addr.unwrap()).await {
Ok(tcp_stream) => Tunnel::new().start(false, tcp_stream, quic_stream),
Err(e) => {
error!(
"failed to connect to access server: {e}, {}",
self.config.local_access_server_addr.unwrap(),
"failed to connect to tcp server: {e}, {}",
self.config.local_tcp_server_addr.unwrap(),
);
}
}
Expand Down
16 changes: 8 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
mod access_server;
mod client;
mod pem_util;
mod server;
mod tcp;
mod tunnel;
mod tunnel_info_bridge;
mod tunnel_message;

pub use access_server::AccessServer;
use anyhow::{bail, Context, Result};
use byte_pool::BytePool;
pub use client::Client;
Expand All @@ -21,6 +20,7 @@ use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::{net::SocketAddr, ops::Deref};
pub use tcp::tcp_server::TcpServer;
pub use tunnel::Tunnel;
pub use tunnel_message::{LoginInfo, TunnelMessage};

Expand Down Expand Up @@ -100,7 +100,7 @@ impl Deref for SelectedCipherSuite {
#[derive(Debug)]
pub enum TunnelType {
Out((quinn::Connection, SocketAddr)),
In((quinn::Connection, AccessServer, ControlStream)),
In((quinn::Connection, TcpServer, ControlStream)),
}

#[derive(Debug)]
Expand All @@ -111,7 +111,7 @@ pub struct ControlStream {

#[derive(Debug, Default, Clone)]
pub struct ClientConfig {
pub local_access_server_addr: Option<SocketAddr>,
pub local_tcp_server_addr: Option<SocketAddr>,
pub cert_path: String,
pub cipher: String,
pub server_addr: String,
Expand Down Expand Up @@ -198,19 +198,19 @@ impl ClientConfig {
};

config.login_msg = if mode == TUNNEL_MODE_IN {
config.local_access_server_addr = sock_addr_mapping[1];
config.local_tcp_server_addr = sock_addr_mapping[1];
Some(TunnelMessage::ReqInLogin(LoginInfo {
password: password.to_string(),
access_server_addr: sock_addr_mapping[0],
tcp_server_addr: sock_addr_mapping[0],
}))
} else {
if sock_addr_mapping[0].is_none() {
log_and_bail!("'ANY' is not allowed as local access server for OUT tunneling");
}
config.local_access_server_addr = sock_addr_mapping[0];
config.local_tcp_server_addr = sock_addr_mapping[0];
Some(TunnelMessage::ReqOutLogin(LoginInfo {
password: password.to_string(),
access_server_addr: sock_addr_mapping[1],
tcp_server_addr: sock_addr_mapping[1],
}))
};

Expand Down
Loading

0 comments on commit 212921f

Please sign in to comment.