Skip to content

Commit

Permalink
refactor(auth/server): simplify server-side auth (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
M0dEx authored Nov 24, 2023
1 parent 38e7da2 commit f9205c6
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 125 deletions.
30 changes: 15 additions & 15 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "quincy"
version = "0.4.2"
version = "0.4.3"
authors = ["Jakub Kubík <[email protected]>"]
license = "MIT"
description = "QUIC-based VPN"
Expand Down
108 changes: 41 additions & 67 deletions src/auth/server.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
use std::{net::IpAddr, sync::Arc, time::Duration};

use crate::constants::{AUTH_FAILED_MESSAGE, AUTH_MESSAGE_BUFFER_SIZE};
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use ipnet::IpNet;
use quinn::{Connection, RecvStream, SendStream, VarInt};
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncReadExt, sync::RwLock, time::timeout};
use tokio::{io::AsyncReadExt, time::timeout};

use super::{client::AuthClientMessage, user::UserDatabase};

/// Represents the internal authentication state for a session.
#[derive(Clone, Debug, PartialEq)]
pub enum AuthState {
Unauthenticated,
Authenticated(String),
}

/// Represents an authentication message sent by the server.
#[derive(Serialize, Deserialize)]
pub enum AuthServerMessage {
Expand All @@ -27,118 +21,98 @@ pub enum AuthServerMessage {
/// Represents an authentication server handling initial authentication and session management.
pub struct AuthServer {
user_database: Arc<UserDatabase>,
auth_state: RwLock<AuthState>,
username: Option<String>,
client_address: IpNet,
connection: Arc<Connection>,
send_stream: SendStream,
recv_stream: RecvStream,
auth_timeout: Duration,
}

impl AuthServer {
pub async fn new(
pub fn new(
user_database: Arc<UserDatabase>,
connection: Arc<Connection>,
client_address: IpNet,
auth_timeout: Duration,
) -> Result<Self> {
let (send_stream, recv_stream) = connection.accept_bi().await?;

Ok(Self {
) -> Self {
Self {
user_database,
auth_state: RwLock::new(AuthState::Unauthenticated),
username: None,
client_address,
connection,
send_stream,
recv_stream,
auth_timeout,
})
}
}

/// Handles authentication for a client.
pub async fn handle_authentication(&mut self) -> Result<()> {
let message: Option<AuthClientMessage> = timeout(self.auth_timeout, self.recv_message())
.await?
.ok()
.flatten();

let state = self.get_state().await;

match (state, message) {
(
AuthState::Unauthenticated,
Some(AuthClientMessage::Authentication(username, password)),
) => self.authenticate_user(username, password).await,
_ => self.handle_failure().await,
let (send_stream, mut recv_stream) = self.connection.accept_bi().await?;

let message = timeout(self.auth_timeout, Self::recv_message(&mut recv_stream)).await?;

if let Ok(AuthClientMessage::Authentication(username, password)) = message {
self.authenticate_user(send_stream, username, password)
.await
} else {
self.handle_failure(send_stream).await
}
}

/// Authenticates a user with the given username and password.
async fn authenticate_user(&mut self, username: String, password: String) -> Result<()> {
if self
.user_database
.authenticate(&username, password)
.await
.is_err()
{
self.close_connection("Invalid username or password")
.await?;

return Err(anyhow!("Invalid username or password"));
async fn authenticate_user(
&mut self,
mut send_stream: SendStream,
username: String,
password: String,
) -> Result<()> {
let auth_result = self.user_database.authenticate(&username, password).await;

if auth_result.is_err() {
return self.handle_failure(send_stream).await;
}

let response = AuthServerMessage::Authenticated(
self.client_address.addr(),
self.client_address.netmask(),
);

self.send_message(response).await?;
self.set_state(AuthState::Authenticated(username)).await;
Self::send_message(&mut send_stream, response).await?;
self.username.replace(username);

Ok(())
}

/// Handles a failure during authentication.
async fn handle_failure(&mut self) -> Result<()> {
self.close_connection("Authentication failed").await?;
async fn handle_failure(&self, send_stream: SendStream) -> Result<()> {
self.close_connection(send_stream, AUTH_FAILED_MESSAGE)
.await?;

Err(anyhow!("Authentication failed"))
Err(anyhow!(AUTH_FAILED_MESSAGE))
}

/// Closes the connection with the given reason.
async fn close_connection(&mut self, reason: &str) -> Result<()> {
self.send_message(AuthServerMessage::Failed).await?;
self.send_stream.finish().await?;
async fn close_connection(&self, mut send_stream: SendStream, reason: &str) -> Result<()> {
Self::send_message(&mut send_stream, AuthServerMessage::Failed).await?;
send_stream.finish().await?;

self.connection
.close(VarInt::from_u32(0x01), reason.as_bytes());

self.set_state(AuthState::Unauthenticated).await;

Ok(())
}

#[inline]
async fn send_message(&mut self, message: AuthServerMessage) -> Result<()> {
self.send_stream
async fn send_message(send_stream: &mut SendStream, message: AuthServerMessage) -> Result<()> {
send_stream
.write_all(&serde_json::to_vec(&message)?)
.await
.context("Failed to send AuthServerMessage")
}

#[inline]
async fn recv_message(&mut self) -> Result<Option<AuthClientMessage>> {
let mut buf = BytesMut::with_capacity(1024);
self.recv_stream.read_buf(&mut buf).await?;
async fn recv_message(recv_stream: &mut RecvStream) -> Result<AuthClientMessage> {
let mut buf = BytesMut::with_capacity(AUTH_MESSAGE_BUFFER_SIZE);
recv_stream.read_buf(&mut buf).await?;

serde_json::from_slice(&buf).context("Failed to parse AuthClientMessage")
}

pub async fn get_state(&self) -> AuthState {
self.auth_state.read().await.clone()
}

async fn set_state(&self, state: AuthState) {
*self.auth_state.write().await = state;
}
}
6 changes: 6 additions & 0 deletions src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ pub const QUIC_MTU_OVERHEAD: u16 = 42;
/// Represents the interval used by various cleanup tasks.
pub const CLEANUP_INTERVAL: Duration = Duration::from_secs(1);

/// Error message when authentication fails.
pub const AUTH_FAILED_MESSAGE: &str = "Authentication failed";

/// Buffer size for authentication messages.
pub const AUTH_MESSAGE_BUFFER_SIZE: usize = 1024;

/// Represents the size of the packet info header on UNIX systems.
#[cfg(target_os = "macos")]
pub const DARWIN_PI_HEADER_LENGTH: usize = 4;
Expand Down
Loading

0 comments on commit f9205c6

Please sign in to comment.