diff --git a/Cargo.lock b/Cargo.lock index 9ff8869..f70dc8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,9 +232,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.84" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f8e7c90afad890484a21653d08b6e209ae34770fb5ee298f9c699fcc1e5c856" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "libc", ] @@ -434,9 +434,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "hashbrown" @@ -708,7 +708,7 @@ dependencies = [ [[package]] name = "quincy" -version = "0.4.2" +version = "0.4.3" dependencies = [ "anyhow", "argon2", @@ -940,9 +940,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustls" -version = "0.21.8" +version = "0.21.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" +checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" dependencies = [ "log", "ring 0.17.5", @@ -1037,18 +1037,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", @@ -1353,9 +1353,9 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.4" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f751112709b4e791d8ce53e32c4ed2d353565a795ce84da2285393f41557bdf2" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ "log", "once_cell", @@ -1364,9 +1364,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "matchers", "nu-ansi-term", diff --git a/Cargo.toml b/Cargo.toml index 5f11f67..e106775 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "quincy" -version = "0.4.2" +version = "0.4.3" authors = ["Jakub KubĂ­k "] license = "MIT" description = "QUIC-based VPN" diff --git a/src/auth/server.rs b/src/auth/server.rs index 027f3a3..7171367 100644 --- a/src/auth/server.rs +++ b/src/auth/server.rs @@ -1,21 +1,15 @@ use std::{net::IpAddr, sync::Arc, time::Duration}; +use crate::constants::{AUTH_FAILED_MESSAGE, AUTH_MESSAGE_BUFFER_SIZE}; use anyhow::{anyhow, Context, Result}; use bytes::BytesMut; use ipnet::IpNet; use quinn::{Connection, RecvStream, SendStream, VarInt}; use serde::{Deserialize, Serialize}; -use tokio::{io::AsyncReadExt, sync::RwLock, time::timeout}; +use tokio::{io::AsyncReadExt, time::timeout}; use super::{client::AuthClientMessage, user::UserDatabase}; -/// Represents the internal authentication state for a session. -#[derive(Clone, Debug, PartialEq)] -pub enum AuthState { - Unauthenticated, - Authenticated(String), -} - /// Represents an authentication message sent by the server. #[derive(Serialize, Deserialize)] pub enum AuthServerMessage { @@ -27,64 +21,53 @@ pub enum AuthServerMessage { /// Represents an authentication server handling initial authentication and session management. pub struct AuthServer { user_database: Arc, - auth_state: RwLock, + username: Option, client_address: IpNet, connection: Arc, - send_stream: SendStream, - recv_stream: RecvStream, auth_timeout: Duration, } impl AuthServer { - pub async fn new( + pub fn new( user_database: Arc, connection: Arc, client_address: IpNet, auth_timeout: Duration, - ) -> Result { - let (send_stream, recv_stream) = connection.accept_bi().await?; - - Ok(Self { + ) -> Self { + Self { user_database, - auth_state: RwLock::new(AuthState::Unauthenticated), + username: None, client_address, connection, - send_stream, - recv_stream, auth_timeout, - }) + } } /// Handles authentication for a client. pub async fn handle_authentication(&mut self) -> Result<()> { - let message: Option = timeout(self.auth_timeout, self.recv_message()) - .await? - .ok() - .flatten(); - - let state = self.get_state().await; - - match (state, message) { - ( - AuthState::Unauthenticated, - Some(AuthClientMessage::Authentication(username, password)), - ) => self.authenticate_user(username, password).await, - _ => self.handle_failure().await, + let (send_stream, mut recv_stream) = self.connection.accept_bi().await?; + + let message = timeout(self.auth_timeout, Self::recv_message(&mut recv_stream)).await?; + + if let Ok(AuthClientMessage::Authentication(username, password)) = message { + self.authenticate_user(send_stream, username, password) + .await + } else { + self.handle_failure(send_stream).await } } /// Authenticates a user with the given username and password. - async fn authenticate_user(&mut self, username: String, password: String) -> Result<()> { - if self - .user_database - .authenticate(&username, password) - .await - .is_err() - { - self.close_connection("Invalid username or password") - .await?; - - return Err(anyhow!("Invalid username or password")); + async fn authenticate_user( + &mut self, + mut send_stream: SendStream, + username: String, + password: String, + ) -> Result<()> { + let auth_result = self.user_database.authenticate(&username, password).await; + + if auth_result.is_err() { + return self.handle_failure(send_stream).await; } let response = AuthServerMessage::Authenticated( @@ -92,53 +75,44 @@ impl AuthServer { self.client_address.netmask(), ); - self.send_message(response).await?; - self.set_state(AuthState::Authenticated(username)).await; + Self::send_message(&mut send_stream, response).await?; + self.username.replace(username); Ok(()) } /// Handles a failure during authentication. - async fn handle_failure(&mut self) -> Result<()> { - self.close_connection("Authentication failed").await?; + async fn handle_failure(&self, send_stream: SendStream) -> Result<()> { + self.close_connection(send_stream, AUTH_FAILED_MESSAGE) + .await?; - Err(anyhow!("Authentication failed")) + Err(anyhow!(AUTH_FAILED_MESSAGE)) } /// Closes the connection with the given reason. - async fn close_connection(&mut self, reason: &str) -> Result<()> { - self.send_message(AuthServerMessage::Failed).await?; - self.send_stream.finish().await?; + async fn close_connection(&self, mut send_stream: SendStream, reason: &str) -> Result<()> { + Self::send_message(&mut send_stream, AuthServerMessage::Failed).await?; + send_stream.finish().await?; self.connection .close(VarInt::from_u32(0x01), reason.as_bytes()); - self.set_state(AuthState::Unauthenticated).await; - Ok(()) } #[inline] - async fn send_message(&mut self, message: AuthServerMessage) -> Result<()> { - self.send_stream + async fn send_message(send_stream: &mut SendStream, message: AuthServerMessage) -> Result<()> { + send_stream .write_all(&serde_json::to_vec(&message)?) .await .context("Failed to send AuthServerMessage") } #[inline] - async fn recv_message(&mut self) -> Result> { - let mut buf = BytesMut::with_capacity(1024); - self.recv_stream.read_buf(&mut buf).await?; + async fn recv_message(recv_stream: &mut RecvStream) -> Result { + let mut buf = BytesMut::with_capacity(AUTH_MESSAGE_BUFFER_SIZE); + recv_stream.read_buf(&mut buf).await?; serde_json::from_slice(&buf).context("Failed to parse AuthClientMessage") } - - pub async fn get_state(&self) -> AuthState { - self.auth_state.read().await.clone() - } - - async fn set_state(&self, state: AuthState) { - *self.auth_state.write().await = state; - } } diff --git a/src/constants.rs b/src/constants.rs index c6c5310..f0bde64 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -20,6 +20,12 @@ pub const QUIC_MTU_OVERHEAD: u16 = 42; /// Represents the interval used by various cleanup tasks. pub const CLEANUP_INTERVAL: Duration = Duration::from_secs(1); +/// Error message when authentication fails. +pub const AUTH_FAILED_MESSAGE: &str = "Authentication failed"; + +/// Buffer size for authentication messages. +pub const AUTH_MESSAGE_BUFFER_SIZE: usize = 1024; + /// Represents the size of the packet info header on UNIX systems. #[cfg(target_os = "macos")] pub const DARWIN_PI_HEADER_LENGTH: usize = 4; diff --git a/src/server/connection.rs b/src/server/connection.rs index 82fc3cf..8707d29 100644 --- a/src/server/connection.rs +++ b/src/server/connection.rs @@ -1,4 +1,4 @@ -use crate::auth::server::{AuthServer, AuthState}; +use crate::auth::server::AuthServer; use crate::auth::user::UserDatabase; use crate::config::ConnectionConfig; use crate::utils::tasks::join_or_abort_task; @@ -12,14 +12,13 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::RwLock; use tokio::task::JoinHandle; use tracing::{debug, error}; /// Represents a Quincy connection encapsulating authentication and IO. pub struct QuincyConnection { connection: Arc, - auth_server: Arc>, + auth_server: AuthServer, tun_queue: Arc>, tasks: Vec>>, } @@ -33,28 +32,27 @@ impl QuincyConnection { /// - `user_database` - the user database /// - `auth_timeout` - the authentication timeout /// - `client_address` - the assigned client address - pub async fn new( + pub fn new( connection: Connection, connection_config: &ConnectionConfig, tun_queue: Arc>, user_database: Arc, client_address: IpNet, - ) -> Result { + ) -> Self { let connection = Arc::new(connection); let auth_server = AuthServer::new( user_database, connection.clone(), client_address, connection_config.timeout, - ) - .await?; + ); - Ok(Self { + Self { connection, - auth_server: Arc::new(RwLock::new(auth_server)), + auth_server, tun_queue, tasks: Vec::new(), - }) + } } /// Starts the tasks for this instance of Quincy connection. @@ -65,10 +63,12 @@ impl QuincyConnection { )); } + self.auth_server.handle_authentication().await?; + + // Additional authentication checks are not needed if initial authentication succeeded self.tasks.push(tokio::spawn(Self::process_incoming_data( self.connection.clone(), self.tun_queue.clone(), - self.auth_server.clone(), ))); Ok(()) @@ -99,17 +99,8 @@ impl QuincyConnection { /// /// ### Arguments /// - `data` - the data to be sent + #[inline] pub async fn send_datagram(&self, data: Bytes) -> Result<()> { - match self.auth_server.read().await.get_state().await { - AuthState::Authenticated(_) => (), - _ => { - return Err(anyhow!( - "Attempted to send datagram to unauthenticated client {:?}", - self.connection.remote_address(), - )) - } - } - self.connection.send_datagram(data)?; Ok(()) @@ -127,25 +118,11 @@ impl QuincyConnection { /// ### Arguments /// - `connection` - a reference to the underlying QUIC connection /// - `tun_queue` - a sender of an unbounded queue used by the tunnel worker to receive data - /// - `auth_server` - a reference to the authentication server async fn process_incoming_data( connection: Arc, tun_queue: Arc>, - auth_server: Arc>, ) -> Result<()> { - Self::handle_authentication(&auth_server).await?; - loop { - match auth_server.read().await.get_state().await { - AuthState::Authenticated(_) => (), - _ => { - return Err(anyhow!( - "Connection {:?} not authenticated, dropping incoming data", - connection.remote_address(), - )) - } - } - let data = connection.read_datagram().await?; debug!( "Received {} bytes from {:?}", @@ -156,9 +133,4 @@ impl QuincyConnection { tun_queue.send(data)?; } } - - async fn handle_authentication(auth_server: &Arc>) -> Result<()> { - let mut auth_server = auth_server.write().await; - auth_server.handle_authentication().await - } } diff --git a/src/server/tunnel.rs b/src/server/tunnel.rs index 0b76cee..aa6cf41 100644 --- a/src/server/tunnel.rs +++ b/src/server/tunnel.rs @@ -214,8 +214,7 @@ impl QuincyTunnel { write_queue_sender.clone(), user_database.clone(), client_tun_ip, - ) - .await?; + ); connection.start().await?; info!(