From 3ef9c7c98c07d9905578222127358344276913d5 Mon Sep 17 00:00:00 2001 From: benthecarman Date: Thu, 19 Dec 2024 13:05:05 -0600 Subject: [PATCH] Better ban handling --- src/auth.rs | 42 ++++++++++++++++++++++++++++++++++++++++ src/main.rs | 55 +---------------------------------------------------- 2 files changed, 43 insertions(+), 54 deletions(-) diff --git a/src/auth.rs b/src/auth.rs index b1bbbca..a7e2a98 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -66,6 +66,44 @@ impl IntoResponse for AuthError { } } +fn banned_domains() -> Vec { + let mut domains = vec![]; + let file = std::fs::read_to_string("faucet_config/banned_domains.txt"); + if let Ok(file) = file { + for line in file.lines() { + let line = line.trim(); + if !line.is_empty() { + domains.push(line.to_string()); + } + } + } + domains +} + +fn get_banned_users() -> Vec { + let mut banned_users = vec![]; + let file = std::fs::read_to_string("faucet_config/banned_users.txt"); + if let Ok(file) = file { + for line in file.lines() { + let line = line.trim(); + if !line.is_empty() { + banned_users.push(line.to_string()); + } + } + } + banned_users +} + +fn is_banned(email: &String) -> bool { + let domains = banned_domains(); + let user_host = email.split('@').last().unwrap_or(""); + if domains.contains(&user_host.to_lowercase()) { + return true; + } + let banned_users = get_banned_users(); + banned_users.contains(email) +} + // Middleware extractor for authenticated users #[derive(Debug, Clone)] pub struct AuthUser { @@ -101,6 +139,10 @@ pub async fn auth_middleware( return Err(AuthError::TokenExpired); } + if is_banned(&token_data.claims.sub) { + return Err(AuthError::TokenExpired); + } + // Add AuthUser to request extensions request.extensions_mut().insert(AuthUser { username: token_data.claims.sub, diff --git a/src/main.rs b/src/main.rs index f742848..01d0f06 100644 --- a/src/main.rs +++ b/src/main.rs @@ -165,44 +165,6 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -fn banned_domains() -> Vec { - let mut domains = vec![]; - let file = std::fs::read_to_string("faucet_config/banned_domains.txt"); - if let Ok(file) = file { - for line in file.lines() { - let line = line.trim(); - if !line.is_empty() { - domains.push(line.to_string()); - } - } - } - domains -} - -fn get_banned_users() -> Vec { - let mut banned_users = vec![]; - let file = std::fs::read_to_string("faucet_config/banned_users.txt"); - if let Ok(file) = file { - for line in file.lines() { - let line = line.trim(); - if !line.is_empty() { - banned_users.push(line.to_string()); - } - } - } - banned_users -} - -fn is_banned(user: &AuthUser) -> bool { - let domains = banned_domains(); - let user_host = user.username.split('@').last().unwrap_or(""); - if domains.contains(&user_host.to_lowercase()) { - return true; - } - let banned_users = get_banned_users(); - banned_users.contains(&user.username) -} - #[axum::debug_handler] async fn github_auth(Extension(state): Extension) -> Result { let redirect_url = format!( @@ -292,11 +254,8 @@ async fn github_callback( #[axum::debug_handler] async fn auth_check( Extension(_state): Extension, - Extension(user): Extension, + Extension(_user): Extension, ) -> Result, AppError> { - if is_banned(&user) { - return Err(AppError::new("You are banned")); - } Ok(Json(json!({"status": "OK"}))) } @@ -307,10 +266,6 @@ async fn onchain_handler( headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { - if is_banned(&user) { - return Err(AppError::new("You are banned")); - } - // Extract the X-Forwarded-For header let x_forwarded_for = headers .get("x-forwarded-for") @@ -341,10 +296,6 @@ async fn lightning_handler( headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { - if is_banned(&user) { - return Err(AppError::new("You are banned")); - } - // Extract the X-Forwarded-For header let x_forwarded_for = headers .get("x-forwarded-for") @@ -427,10 +378,6 @@ async fn channel_handler( headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { - if is_banned(&user) { - return Err(AppError::new("You are banned")); - } - // Extract the X-Forwarded-For header let x_forwarded_for = headers .get("x-forwarded-for")