Skip to content

Commit

Permalink
feat: separate tunnel socket from service, separate udp tunnel config…
Browse files Browse the repository at this point in the history
…, handle tunnel startup failure, handle disabled tunnel
  • Loading branch information
jacobtread committed Sep 10, 2024
1 parent 241dfab commit 2dd6015
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 37 deletions.
34 changes: 28 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ pub struct RuntimeConfig {
pub menu_message: String,
pub dashboard: DashboardConfig,
pub tunnel: TunnelConfig,
pub udp_tunnel: UdpTunnelConfig,
pub api: APIConfig,
pub tunnel_port: u16,
pub external_tunnel_port: Option<u16>,
}

/// Environment variable key to load the config from
Expand Down Expand Up @@ -73,8 +72,6 @@ pub fn load_config() -> Option<Config> {
pub struct Config {
pub host: IpAddr,
pub port: Port,
pub tunnel_port: Port,
pub external_tunnel_port: Option<Port>,
pub qos: QosServerConfig,
pub reverse_proxy: bool,
pub dashboard: DashboardConfig,
Expand All @@ -83,6 +80,7 @@ pub struct Config {
pub logging: LevelFilter,
pub retriever: RetrieverConfig,
pub tunnel: TunnelConfig,
pub udp_tunnel: UdpTunnelConfig,
pub api: APIConfig,
}

Expand All @@ -91,8 +89,6 @@ impl Default for Config {
Self {
host: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
port: 80,
tunnel_port: 9032,
external_tunnel_port: None,
qos: QosServerConfig::default(),
reverse_proxy: false,
dashboard: Default::default(),
Expand All @@ -101,12 +97,15 @@ impl Default for Config {
logging: LevelFilter::Info,
retriever: Default::default(),
tunnel: Default::default(),
udp_tunnel: Default::default(),
api: Default::default()
}
}
}

/// Configuration for how the server should use tunneling
///
/// This option applies to both the HTTP and UDP tunnels
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TunnelConfig {
Expand All @@ -122,6 +121,29 @@ pub enum TunnelConfig {
Disabled,
}

#[derive(Debug, Deserialize)]
pub struct UdpTunnelConfig {
/// Port to bind the UDP tunnel socket to, the socket is bound
/// using the same host as the server
pub port: Port,

/// External facing port, only needed when the port visible to users
/// is different to [UdpTunnelConfig::port]
///
/// For cases such as different exposed port in docker or usage behind
/// a reverse proxy such as NGINX
pub external_port: Option<Port>,
}

impl Default for UdpTunnelConfig {
fn default() -> Self {
Self {
port: 9032,
external_port: None,
}
}
}

/// Configuration for the server QoS setup
#[derive(Debug, Default, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
Expand Down
28 changes: 19 additions & 9 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use crate::{
utils::signing::SigningKey,
};
use axum::{self, Extension};
use config::load_config;
use config::{load_config, TunnelConfig};
use log::{debug, error, info, LevelFilter};
use services::udp_tunnel::create_tunnel_service;
use services::udp_tunnel::{start_udp_tunnel, UdpTunnelService};
use std::{net::SocketAddr, sync::Arc};
use tokio::{join, net::TcpListener, signal};
use utils::logging;
Expand Down Expand Up @@ -39,7 +39,10 @@ async fn main() {
let addr: SocketAddr = SocketAddr::new(config.host, config.port);

// Create tunnel server socket address
let tunnel_addr: SocketAddr = SocketAddr::new(config.host, config.tunnel_port);
let tunnel_addr: SocketAddr = SocketAddr::new(config.host, config.udp_tunnel.port);

// Check if the tunnel is enabled
let tunnel_enabled: bool = !matches!(config.tunnel, TunnelConfig::Disabled);

// Config data persisted to runtime
let runtime_config = RuntimeConfig {
Expand All @@ -50,8 +53,7 @@ async fn main() {
qos: config.qos,
tunnel: config.tunnel,
api: config.api,
tunnel_port: config.tunnel_port,
external_tunnel_port: config.external_tunnel_port,
udp_tunnel: config.udp_tunnel,
};

debug!("QoS server: {:?}", &runtime_config.qos);
Expand All @@ -67,15 +69,23 @@ async fn main() {
let sessions = Arc::new(Sessions::new(signing_key));
let config = Arc::new(runtime_config);
let tunnel_service = Arc::new(TunnelService::default());
let tunnel_service_v2 = create_tunnel_service(sessions.clone(), tunnel_addr).await;
let udp_tunnel_service = Arc::new(UdpTunnelService::new(sessions.clone()));

let game_manager = Arc::new(GameManager::new(
tunnel_service.clone(),
tunnel_service_v2.clone(),
udp_tunnel_service.clone(),
config.clone(),
));
let retriever = Arc::new(retriever);

// Start the tunnel server
if tunnel_enabled {
// Start the tunnel service server
if let Err(err) = start_udp_tunnel(tunnel_addr, udp_tunnel_service.clone()).await {
error!("failed to start udp tunnel server: {}", err);
}
}

// Initialize session router
let mut router = session::routes::router();

Expand All @@ -84,7 +94,7 @@ async fn main() {
router.add_extension(retriever);
router.add_extension(game_manager.clone());
router.add_extension(sessions.clone());
router.add_extension(tunnel_service_v2.clone());
router.add_extension(udp_tunnel_service.clone());

let router = router.build();

Expand All @@ -97,7 +107,7 @@ async fn main() {
.layer(Extension(game_manager))
.layer(Extension(sessions))
.layer(Extension(tunnel_service))
.layer(Extension(tunnel_service_v2))
.layer(Extension(udp_tunnel_service))
.into_make_service_with_connect_info::<SocketAddr>();

info!("Starting server on {} (v{})", addr, VERSION);
Expand Down
5 changes: 4 additions & 1 deletion src/routes/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ pub async fn server_details(
ident: "POCKET_RELAY_SERVER",
version: VERSION,
association,
tunnel_port: config.external_tunnel_port.unwrap_or(config.tunnel_port),
tunnel_port: config
.udp_tunnel
.external_port
.unwrap_or(config.udp_tunnel.port),
})
}

Expand Down
51 changes: 30 additions & 21 deletions src/services/udp_tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,32 @@ type PoolIndex = u8;
/// ID of a pool
type PoolId = GameID;

pub async fn create_tunnel_service(
sessions: Arc<Sessions>,
pub async fn start_udp_tunnel(
tunnel_addr: SocketAddr,
) -> Arc<UdpTunnelService> {
let socket = UdpSocket::bind(tunnel_addr).await.unwrap();
let service = Arc::new(UdpTunnelService::new(socket, sessions));
service: Arc<UdpTunnelService>,
) -> std::io::Result<()> {
let socket = UdpSocket::bind(tunnel_addr).await?;
let socket = Arc::new(socket);

debug!("started tunneling server {tunnel_addr}");

// Spawn the task to handle accepting messages
tokio::spawn(accept_messages(service.clone()));
tokio::spawn(accept_messages(service.clone(), socket.clone()));

// Spawn task to keep connections alive
tokio::spawn(keep_alive(service.clone()));
tokio::spawn(keep_alive(service, socket));

service
Ok(())
}

/// Reads inbound messages from the tunnel service
pub async fn accept_messages(service: Arc<UdpTunnelService>) {
pub async fn accept_messages(service: Arc<UdpTunnelService>, socket: Arc<UdpSocket>) {
// Buffer to recv messages
let mut buffer = [0; u16::MAX as usize];

loop {
// Receive the message bytes
let (size, addr) = match service.socket.recv_from(&mut buffer).await {
let (size, addr) = match socket.recv_from(&mut buffer).await {
Ok(value) => value,
Err(err) => {
if let Some(error_code) = err.raw_os_error() {
Expand Down Expand Up @@ -83,13 +83,13 @@ pub async fn accept_messages(service: Arc<UdpTunnelService>) {

let tunnel_id = packet.header.tunnel_id;

// Handle the message through a background task
let service = service.clone();
let socket = socket.clone();

// Handle the message in its own task
tokio::spawn(async move {
service
.handle_message(tunnel_id, packet.message, addr)
.handle_message(socket, tunnel_id, packet.message, addr)
.await;
});
}
Expand All @@ -104,7 +104,7 @@ const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(KEEP_ALIVE_DELAY.as_sec

/// Background task that sends out keep alive messages to all the sockets connected
/// to the tunnel system. Removes inactive and dead connections
pub async fn keep_alive(service: Arc<UdpTunnelService>) {
pub async fn keep_alive(service: Arc<UdpTunnelService>, socket: Arc<UdpSocket>) {
// Task set for keep alive tasks
let mut send_task_set = JoinSet::new();

Expand Down Expand Up @@ -150,9 +150,9 @@ pub async fn keep_alive(service: Arc<UdpTunnelService>) {

// Spawn the task to send the keep-alive message
send_task_set.spawn({
let service = service.clone();
let socket = socket.clone();

async move { service.socket.send_to(&buffer, addr).await }
async move { socket.send_to(&buffer, addr).await }
});
}

Expand All @@ -170,10 +170,14 @@ pub async fn keep_alive(service: Arc<UdpTunnelService>) {
}
}

/// UDP tunneling service
pub struct UdpTunnelService {
socket: UdpSocket,
/// Next available tunnel ID
next_tunnel_id: AtomicU32,
/// Tunneling mapping data
mappings: RwLock<TunnelMappings>,
/// Access to the session service for exchanging
/// association tokens
sessions: Arc<Sessions>,
}

Expand Down Expand Up @@ -322,9 +326,8 @@ impl TunnelMappings {
}

impl UdpTunnelService {
pub fn new(socket: UdpSocket, sessions: Arc<Sessions>) -> Self {
pub fn new(sessions: Arc<Sessions>) -> Self {
Self {
socket,
next_tunnel_id: AtomicU32::new(0),
mappings: RwLock::new(TunnelMappings::default()),
sessions,
Expand Down Expand Up @@ -364,7 +367,13 @@ impl UdpTunnelService {
}

/// Handles processing a message received through the tunnel
async fn handle_message(&self, tunnel_id: u32, msg: TunnelMessage, addr: SocketAddr) {
async fn handle_message(
&self,
socket: Arc<UdpSocket>,
tunnel_id: u32,
msg: TunnelMessage,
addr: SocketAddr,
) {
// Only process tunnels with known IDs
if tunnel_id != u32::MAX {
// Store the updated tunnel address
Expand Down Expand Up @@ -426,7 +435,7 @@ impl UdpTunnelService {

let buffer = serialize_message(tunnel_id, &TunnelMessage::Initiated { tunnel_id });

self.socket.send_to(&buffer, addr).await.unwrap();
_ = socket.send_to(&buffer, addr).await;
}
TunnelMessage::Initiated { .. } => {
// Server shouldn't be receiving this message... ignore it
Expand All @@ -444,7 +453,7 @@ impl UdpTunnelService {
let buffer =
serialize_message(tunnel_id, &TunnelMessage::Forward { index, message });

self.socket.send_to(&buffer, target_addr).await.unwrap();
_ = socket.send_to(&buffer, target_addr).await;
}
TunnelMessage::KeepAlive => {
// Ack keep alive
Expand Down

0 comments on commit 2dd6015

Please sign in to comment.