Skip to content

Commit

Permalink
fix: initiate reply, add session link, provide port through server de…
Browse files Browse the repository at this point in the history
…tails
  • Loading branch information
jacobtread committed Sep 4, 2024
1 parent 6de69bd commit 0930682
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct RuntimeConfig {
pub dashboard: DashboardConfig,
pub tunnel: TunnelConfig,
pub api: APIConfig,
pub tunnel_port: u16,
}

/// Environment variable key to load the config from
Expand Down
9 changes: 5 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,29 @@ async fn main() {
qos: config.qos,
tunnel: config.tunnel,
api: config.api,
tunnel_port: config.tunnel_port,
};

debug!("QoS server: {:?}", &runtime_config.qos);

// This step may take longer than expected so its spawned instead of joined
tokio::spawn(logging::log_connection_urls(config.port));

let (db, retriever, signing_key, tunnel_service_v2) = join!(
let (db, retriever, signing_key) = join!(
database::init(&runtime_config),
Retriever::start(config.retriever),
SigningKey::global(),
create_tunnel_service(tunnel_addr),
);

let sessions = Arc::new(Sessions::new(signing_key));
let config = Arc::new(runtime_config);
let tunnel_service = Arc::new(TunnelService::default());
let tunnel_service_v2 = create_tunnel_service(sessions.clone(), tunnel_addr).await;

let game_manager = Arc::new(GameManager::new(
tunnel_service.clone(),
tunnel_service_v2.clone(),
config.clone(),
));
let sessions = Arc::new(Sessions::new(signing_key));
let retriever = Arc::new(retriever);

// Initialize session router
Expand Down
8 changes: 7 additions & 1 deletion src/routes/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,25 @@ pub struct ServerDetails {
version: &'static str,
/// Random association token for the client to use
association: String,
/// Port the tunnel server is running on
tunnel_port: u16,
}

/// GET /api/server
///
/// Handles providing the server details. The Pocket Relay client tool
/// uses this endpoint to validate that the provided host is a valid
/// Pocket Relay server.
pub async fn server_details(Extension(sessions): Extension<Arc<Sessions>>) -> Json<ServerDetails> {
pub async fn server_details(
Extension(sessions): Extension<Arc<Sessions>>,
Extension(config): Extension<Arc<RuntimeConfig>>,
) -> Json<ServerDetails> {
let association = sessions.create_assoc_token();
Json(ServerDetails {
ident: "POCKET_RELAY_SERVER",
version: VERSION,
association,
tunnel_port: config.tunnel_port,
})
}

Expand Down
47 changes: 38 additions & 9 deletions src/services/udp_tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
};

use codec::{MessageHeader, MessageReader, MessageWriter, TunnelMessage};
use log::error;
use log::{debug, error};
use parking_lot::RwLock;
use tokio::{
net::UdpSocket,
Expand All @@ -19,7 +19,7 @@ use uuid::Uuid;

use crate::utils::{hashing::IntHashMap, types::GameID};

use super::sessions::AssociationId;
use super::sessions::{AssociationId, Sessions};

/// The port bound on clients representing the host player within the socket pool
pub const _TUNNEL_HOST_LOCAL_PORT: u16 = 42132;
Expand All @@ -31,9 +31,14 @@ type PoolIndex = u8;
/// ID of a pool
type PoolId = GameID;

pub async fn create_tunnel_service(tunnel_addr: SocketAddr) -> Arc<TunnelServiceV2> {
pub async fn create_tunnel_service(
sessions: Arc<Sessions>,
tunnel_addr: SocketAddr,
) -> Arc<TunnelServiceV2> {
let socket = UdpSocket::bind(tunnel_addr).await.unwrap();
let service = Arc::new(TunnelService::new(socket));
let service = Arc::new(TunnelService::new(socket, sessions));

debug!("started tunneling server {tunnel_addr}");

// Spawn the task to handle accepting messages
tokio::spawn(accept_messages(service.clone()));
Expand Down Expand Up @@ -64,15 +69,21 @@ pub async fn accept_messages(service: Arc<TunnelService>) {
let header = match MessageHeader::read(&mut reader) {
Ok(value) => value,
Err(_err) => {
error!("invalid message header");
continue;
}
};

let message = match TunnelMessage::read(&mut reader) {
Ok(value) => value,
Err(_err) => continue,
Err(_err) => {
error!("invalid message");
continue;
}
};

debug!("got message: {:?}", message);

let tunnel_id = header.tunnel_id;

// Handle the message through a background task
Expand All @@ -83,7 +94,7 @@ pub async fn accept_messages(service: Arc<TunnelService>) {
}
}

const KEEP_ALIVE_DELAY: Duration = Duration::from_secs(10);
const KEEP_ALIVE_DELAY: Duration = Duration::from_secs(5);
const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(60);

pub async fn keep_alive(service: Arc<TunnelService>) {
Expand Down Expand Up @@ -153,6 +164,7 @@ pub struct TunnelService {
socket: UdpSocket,
next_tunnel_id: AtomicU32,
mappings: RwLock<TunnelMappings>,
sessions: Arc<Sessions>,
}

pub struct TunnelData {
Expand Down Expand Up @@ -288,11 +300,12 @@ impl TunnelMappings {
}

impl TunnelService {
pub fn new(socket: UdpSocket) -> Self {
pub fn new(socket: UdpSocket, sessions: Arc<Sessions>) -> Self {
Self {
socket,
next_tunnel_id: AtomicU32::new(0),
mappings: RwLock::new(TunnelMappings::default()),
sessions,
}
}

Expand Down Expand Up @@ -332,7 +345,7 @@ impl TunnelService {
async fn handle_message(&self, tunnel_id: u32, msg: TunnelMessage, addr: SocketAddr) {
match msg {
TunnelMessage::Initiate { association_token } => {
let association = match Uuid::parse_str(&association_token) {
let association = match self.sessions.verify_assoc_token(&association_token) {
Ok(value) => value,
Err(_err) => {
return;
Expand All @@ -355,6 +368,20 @@ impl TunnelService {
last_alive: Instant::now(),
},
);

let mut buffer = MessageWriter::default();

let header = MessageHeader {
tunnel_id,
version: 0,
};
let message = TunnelMessage::Initiated { tunnel_id: id };

// Write header and message
header.write(&mut buffer);
message.write(&mut buffer);

self.socket.send_to(&buffer.buffer, addr).await.unwrap();
}
TunnelMessage::Initiated { .. } => {
// Server shouldn't be receiving this message... ignore it
Expand Down Expand Up @@ -534,6 +561,7 @@ mod codec {
}
}

#[derive(Debug)]
pub enum TunnelMessage {
/// Client is requesting to initiate a connection
Initiate {
Expand Down Expand Up @@ -605,7 +633,7 @@ mod codec {
debug_assert!(association_token.len() < u16::MAX as usize);
buf.write_u8(MessageType::Initiate as u8);

buf.write_u32(association_token.len() as u32);
buf.write_u16(association_token.len() as u16);
buf.write_bytes(association_token.as_bytes());
}
TunnelMessage::Initiated { tunnel_id } => {
Expand All @@ -618,6 +646,7 @@ mod codec {

buf.write_u8(*index);
buf.write_u16(message.len() as u16);
buf.write_bytes(message);
}
TunnelMessage::KeepAlive => {
buf.write_u8(MessageType::KeepAlive as u8);
Expand Down

0 comments on commit 0930682

Please sign in to comment.