From 7ba5e72e6a5bf92dd17050f5eaecc5eedd1eecba Mon Sep 17 00:00:00 2001 From: Meng Zhang <meng@tabbyml.com> Date: Thu, 22 Feb 2024 11:30:42 -0800 Subject: [PATCH] refactor: cleanup service level error as we're handling input validation at graphql level. (#1513) * refactor out TokenAuthError * refactor out RefreshTokenError * refactor out RegisterError * refactor out PasswordResetError * fix test * fix typos --- ee/tabby-webserver/src/schema/auth.rs | 217 +++++++++++------------ ee/tabby-webserver/src/schema/mod.rs | 37 ++-- ee/tabby-webserver/src/service/auth.rs | 227 ++++--------------------- 3 files changed, 159 insertions(+), 322 deletions(-) diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 09445f216a59..dd39d9e9e7fb 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -1,12 +1,10 @@ -use std::fmt::Debug; +use std::{borrow::Cow, fmt::Debug}; use anyhow::Result; use async_trait::async_trait; use chrono::{DateTime, Utc}; use jsonwebtoken as jwt; -use juniper::{ - FieldError, GraphQLEnum, GraphQLInputObject, GraphQLObject, IntoFieldError, ScalarValue, ID, -}; +use juniper::{GraphQLEnum, GraphQLInputObject, GraphQLObject, ID}; use juniper_axum::relay; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; @@ -15,9 +13,8 @@ use thiserror::Error; use tokio::task::JoinHandle; use tracing::{error, warn}; use uuid::Uuid; -use validator::{Validate, ValidationErrors}; +use validator::Validate; -use super::from_validation_errors; use crate::schema::Context; lazy_static! { @@ -80,33 +77,6 @@ impl RegisterResponse { } } -#[derive(Error, Debug)] -pub enum RegisterError { - #[error("Invalid input parameters")] - InvalidInput(#[from] ValidationErrors), - - #[error("Invitation code is not valid")] - InvalidInvitationCode, - - #[error("Email is already registered")] - DuplicateEmail, - - #[error(transparent)] - Other(#[from] anyhow::Error), - - #[error("Unknown error")] - Unknown, -} - -impl<S: ScalarValue> IntoFieldError<S> for RegisterError { - fn into_field_error(self) -> FieldError<S> { - match self { - Self::InvalidInput(errors) => from_validation_errors(errors), - _ => self.into(), - } - } -} - #[derive(Debug, GraphQLObject)] pub struct TokenAuthResponse { access_token: String, @@ -122,46 +92,67 @@ impl TokenAuthResponse { } } -#[derive(Error, Debug)] -pub enum TokenAuthError { - #[error("Invalid input parameters")] - InvalidInput(#[from] ValidationErrors), - - #[error("User not found")] - UserNotFound, - - #[error("Password is not valid")] - InvalidPassword, - - #[error("User is disabled")] - UserDisabled, - - #[error(transparent)] - Other(#[from] anyhow::Error), - - #[error("Unknown error")] - Unknown, -} - -#[derive(Error, Debug)] -pub enum PasswordResetError { - #[error("Invalid code")] - InvalidCode, - #[error("Invalid password")] - InvalidInput(#[from] ValidationErrors), - #[error(transparent)] - Other(#[from] anyhow::Error), - #[error("Unknown error")] - Unknown, +/// Input parameters for token_auth mutation +/// See `RegisterInput` for `validate` attribute usage +#[derive(Validate)] +pub struct TokenAuthInput { + #[validate(email(code = "email", message = "Email is invalid"))] + #[validate(length( + max = 128, + code = "email", + message = "Email must be at most 128 characters" + ))] + pub email: String, + #[validate(length( + min = 8, + code = "password", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "password", + message = "Password must be at most 20 characters" + ))] + pub password: String, } -impl<S: ScalarValue> IntoFieldError<S> for PasswordResetError { - fn into_field_error(self) -> FieldError<S> { - match self { - Self::InvalidInput(errors) => from_validation_errors(errors), - _ => self.into(), - } - } +/// Input parameters for register mutation +/// `validate` attribute is used to validate the input parameters +/// - `code` argument specifies which parameter causes the failure +/// - `message` argument provides client friendly error message +/// +#[derive(Validate)] +pub struct RegisterInput { + #[validate(email(code = "email", message = "Email is invalid"))] + #[validate(length( + max = 128, + code = "email", + message = "Email must be at most 128 characters" + ))] + pub email: String, + #[validate(length( + min = 8, + code = "password1", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "password1", + message = "Password must be at most 20 characters" + ))] + #[validate(custom = "validate_password")] + pub password1: String, + #[validate(must_match( + code = "password2", + message = "Passwords do not match", + other = "password1" + ))] + #[validate(length( + max = 20, + code = "password2", + message = "Password must be at most 20 characters" + ))] + pub password2: String, } #[derive(Default, Serialize)] @@ -191,42 +182,6 @@ pub enum OAuthError { Unknown, } -impl<S: ScalarValue> IntoFieldError<S> for TokenAuthError { - fn into_field_error(self) -> FieldError<S> { - match self { - Self::InvalidInput(errors) => from_validation_errors(errors), - _ => self.into(), - } - } -} - -#[derive(Error, Debug)] -pub enum RefreshTokenError { - #[error("Invalid refresh token")] - InvalidRefreshToken, - - #[error("Expired refresh token")] - ExpiredRefreshToken, - - #[error("User not found")] - UserNotFound, - - #[error("User is disabled")] - UserDisabled, - - #[error(transparent)] - Other(#[from] anyhow::Error), - - #[error("Unknown error")] - Unknown, -} - -impl<S: ScalarValue> IntoFieldError<S> for RefreshTokenError { - fn into_field_error(self) -> FieldError<S> { - self.into() - } -} - #[derive(Debug, GraphQLObject)] pub struct RefreshTokenResponse { pub access_token: String, @@ -412,21 +367,13 @@ pub trait AuthenticationService: Send + Sync { &self, email: String, password1: String, - password2: String, invitation_code: Option<String>, - ) -> std::result::Result<RegisterResponse, RegisterError>; + ) -> Result<RegisterResponse>; async fn allow_self_signup(&self) -> Result<bool>; - async fn token_auth( - &self, - email: String, - password: String, - ) -> std::result::Result<TokenAuthResponse, TokenAuthError>; + async fn token_auth(&self, email: String, password: String) -> Result<TokenAuthResponse>; - async fn refresh_token( - &self, - refresh_token: String, - ) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>; + async fn refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse>; async fn delete_expired_token(&self) -> Result<()>; async fn delete_expired_password_resets(&self) -> Result<()>; async fn verify_access_token(&self, access_token: &str) -> Result<JWTPayload>; @@ -438,7 +385,7 @@ pub trait AuthenticationService: Send + Sync { async fn delete_invitation(&self, id: &ID) -> Result<ID>; async fn reset_user_auth_token(&self, email: &str) -> Result<()>; - async fn password_reset(&self, code: &str, password: &str) -> Result<(), PasswordResetError>; + async fn password_reset(&self, code: &str, password: &str) -> Result<()>; async fn request_password_reset_email(&self, email: String) -> Result<Option<JoinHandle<()>>>; async fn list_users( @@ -477,6 +424,38 @@ pub trait AuthenticationService: Send + Sync { async fn update_user_role(&self, id: &ID, is_admin: bool) -> Result<()>; } +fn validate_password(value: &str) -> Result<(), validator::ValidationError> { + let make_validation_error = |message: &'static str| { + let mut err = validator::ValidationError::new("password1"); + err.message = Some(Cow::Borrowed(message)); + Err(err) + }; + + let contains_lowercase = value.chars().any(|x| x.is_ascii_lowercase()); + if !contains_lowercase { + return make_validation_error("Password should contain at least one lowercase character"); + } + + let contains_uppercase = value.chars().any(|x| x.is_ascii_uppercase()); + if !contains_uppercase { + return make_validation_error("Password should contain at least one uppercase character"); + } + + let contains_digit = value.chars().any(|x| x.is_ascii_digit()); + if !contains_digit { + return make_validation_error("Password should contain at least one numeric character"); + } + + let contains_special_char = value.chars().any(|x| x.is_ascii_punctuation()); + if !contains_special_char { + return make_validation_error( + "Password should contain at least one special character, e.g @#$%^&{}", + ); + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index 5e09b52ca0f4..f05b32f73be9 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -9,8 +9,8 @@ pub mod worker; use std::sync::Arc; use auth::{ - validate_jwt, AuthenticationService, Invitation, RefreshTokenError, RefreshTokenResponse, - RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, User, + validate_jwt, AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, + TokenAuthResponse, User, }; use job::{JobRun, JobService}; use juniper::{ @@ -383,19 +383,33 @@ impl Mutation { password1: String, password2: String, invitation_code: Option<String>, - ) -> Result<RegisterResponse, RegisterError> { - ctx.locator + ) -> Result<RegisterResponse> { + let input = auth::RegisterInput { + email, + password1, + password2, + }; + input.validate()?; + + Ok(ctx + .locator .auth() - .register(email, password1, password2, invitation_code) - .await + .register(input.email, input.password1, invitation_code) + .await?) } async fn token_auth( ctx: &Context, email: String, password: String, - ) -> Result<TokenAuthResponse, TokenAuthError> { - ctx.locator.auth().token_auth(email, password).await + ) -> Result<TokenAuthResponse> { + let input = auth::TokenAuthInput { email, password }; + input.validate()?; + Ok(ctx + .locator + .auth() + .token_auth(input.email, input.password) + .await?) } async fn verify_token(ctx: &Context, token: String) -> Result<bool> { @@ -403,11 +417,8 @@ impl Mutation { Ok(true) } - async fn refresh_token( - ctx: &Context, - refresh_token: String, - ) -> Result<RefreshTokenResponse, RefreshTokenError> { - ctx.locator.auth().refresh_token(refresh_token).await + async fn refresh_token(ctx: &Context, refresh_token: String) -> Result<RefreshTokenResponse> { + Ok(ctx.locator.auth().refresh_token(refresh_token).await?) } async fn create_invitation(ctx: &Context, email: String) -> Result<ID> { diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 45efd725d714..ed22d0450381 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, sync::Arc}; +use std::sync::Arc; use anyhow::{anyhow, Context, Result}; use argon2::{ @@ -12,7 +12,6 @@ use juniper::ID; use tabby_db::{DbConn, InvitationDAO}; use tokio::task::JoinHandle; use tracing::warn; -use validator::{Validate, ValidationError}; use super::{graphql_pagination_to_filter, AsID, AsRowid}; use crate::{ @@ -21,8 +20,7 @@ use crate::{ auth::{ generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, Invitation, JWTPayload, OAuthCredential, OAuthError, OAuthProvider, OAuthResponse, - PasswordResetError, RefreshTokenError, RefreshTokenResponse, RegisterError, - RegisterResponse, RequestInvitationInput, TokenAuthError, TokenAuthResponse, + RefreshTokenResponse, RegisterResponse, RequestInvitationInput, TokenAuthResponse, UpdateOAuthCredentialInput, User, }, email::{EmailService, SendEmailError}, @@ -30,45 +28,6 @@ use crate::{ }, }; -/// Input parameters for register mutation -/// `validate` attribute is used to validate the input parameters -/// - `code` argument specifies which parameter causes the failure -/// - `message` argument provides client friendly error message -/// -#[derive(Validate)] -struct RegisterInput { - #[validate(email(code = "email", message = "Email is invalid"))] - #[validate(length( - max = 128, - code = "email", - message = "Email must be at most 128 characters" - ))] - email: String, - #[validate(length( - min = 8, - code = "password1", - message = "Password must be at least 8 characters" - ))] - #[validate(length( - max = 20, - code = "password1", - message = "Password must be at most 20 characters" - ))] - #[validate(custom = "validate_password")] - password1: String, - #[validate(must_match( - code = "password2", - message = "Passwords do not match", - other = "password1" - ))] - #[validate(length( - max = 20, - code = "password2", - message = "Password must be at most 20 characters" - ))] - password2: String, -} - #[derive(Clone)] struct AuthenticationServiceImpl { db: DbConn, @@ -82,119 +41,31 @@ pub fn new_authentication_service( AuthenticationServiceImpl { db, mail } } -impl std::fmt::Debug for RegisterInput { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RegisterInput") - .field("email", &self.email) - .field("password1", &"********") - .field("password2", &"********") - .finish() - } -} - -fn validate_password(value: &str) -> Result<(), ValidationError> { - let make_validation_error = |message: &'static str| { - let mut err = ValidationError::new("password1"); - err.message = Some(Cow::Borrowed(message)); - Err(err) - }; - - let contains_lowercase = value.chars().any(|x| x.is_ascii_lowercase()); - if !contains_lowercase { - return make_validation_error("Password should contains at least one lowercase character"); - } - - let contains_uppercase = value.chars().any(|x| x.is_ascii_uppercase()); - if !contains_uppercase { - return make_validation_error("Password should contains at least one uppercase character"); - } - - let contains_digit = value.chars().any(|x| x.is_ascii_digit()); - if !contains_digit { - return make_validation_error("Password should contains at least one numeric character"); - } - - let contains_special_char = value.chars().any(|x| x.is_ascii_punctuation()); - if !contains_special_char { - return make_validation_error( - "Password should contains at least one special character, e.g @#$%^&{}", - ); - } - - Ok(()) -} - -/// Input parameters for token_auth mutation -/// See `RegisterInput` for `validate` attribute usage -#[derive(Validate)] -struct TokenAuthInput { - #[validate(email(code = "email", message = "Email is invalid"))] - #[validate(length( - max = 128, - code = "email", - message = "Email must be at most 128 characters" - ))] - email: String, - #[validate(length( - min = 8, - code = "password", - message = "Password must be at least 8 characters" - ))] - #[validate(length( - max = 20, - code = "password", - message = "Password must be at most 20 characters" - ))] - password: String, -} - -impl std::fmt::Debug for TokenAuthInput { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TokenAuthInput") - .field("email", &self.email) - .field("password", &"********") - .finish() - } -} - #[async_trait] impl AuthenticationService for AuthenticationServiceImpl { async fn register( &self, email: String, - password1: String, - password2: String, + password: String, invitation_code: Option<String>, - ) -> std::result::Result<RegisterResponse, RegisterError> { - let input = RegisterInput { - email, - password1, - password2, - }; - input.validate()?; - + ) -> Result<RegisterResponse> { let is_admin_initialized = self.is_admin_initialized().await?; - let invitation = check_invitation( - &self.db, - is_admin_initialized, - invitation_code, - &input.email, - ) - .await?; + let invitation = + check_invitation(&self.db, is_admin_initialized, invitation_code, &email).await?; // check if email exists - if self.db.get_user_by_email(&input.email).await?.is_some() { - return Err(RegisterError::DuplicateEmail); + if self.db.get_user_by_email(&email).await?.is_some() { + return Err(anyhow!("Email is already registered")); } - let Ok(pwd_hash) = password_hash(&input.password1) else { - return Err(RegisterError::Unknown); + let Ok(pwd_hash) = password_hash(&password) else { + return Err(anyhow!("Unknown error")); }; let id = if let Some(invitation) = invitation { self.db .create_user_with_invitation( - input.email.clone(), + email.clone(), pwd_hash, !is_admin_initialized, invitation.id, @@ -202,7 +73,7 @@ impl AuthenticationService for AuthenticationServiceImpl { .await? } else { self.db - .create_user(input.email.clone(), pwd_hash, !is_admin_initialized) + .create_user(email.clone(), pwd_hash, !is_admin_initialized) .await? }; @@ -213,7 +84,7 @@ impl AuthenticationService for AuthenticationServiceImpl { let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) else { - return Err(RegisterError::Unknown); + return Err(anyhow!("Unknown error")); }; let resp = RegisterResponse::new(access_token, refresh_token); @@ -254,9 +125,8 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(Some(handle)) } - async fn password_reset(&self, code: &str, password: &str) -> Result<(), PasswordResetError> { - let password_encrypted = - password_hash(password).map_err(|_| PasswordResetError::Unknown)?; + async fn password_reset(&self, code: &str, password: &str) -> Result<()> { + let password_encrypted = password_hash(password).map_err(|_| anyhow!("Unknown error"))?; let user_id = self.db.verify_password_reset(code).await?; self.db.delete_password_reset_by_user_id(user_id).await?; @@ -266,24 +136,17 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(()) } - async fn token_auth( - &self, - email: String, - password: String, - ) -> std::result::Result<TokenAuthResponse, TokenAuthError> { - let input = TokenAuthInput { email, password }; - input.validate()?; - - let Some(user) = self.db.get_user_by_email(&input.email).await? else { - return Err(TokenAuthError::UserNotFound); + async fn token_auth(&self, email: String, password: String) -> Result<TokenAuthResponse> { + let Some(user) = self.db.get_user_by_email(&email).await? else { + return Err(anyhow!("User not found")); }; if !user.active { - return Err(TokenAuthError::UserDisabled); + return Err(anyhow!("User is disabled")); } - if !password_verify(&input.password, &user.password_encrypted) { - return Err(TokenAuthError::InvalidPassword); + if !password_verify(&password, &user.password_encrypted) { + return Err(anyhow!("Password is not valid")); } let refresh_token = generate_refresh_token(); @@ -293,29 +156,26 @@ impl AuthenticationService for AuthenticationServiceImpl { let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) else { - return Err(TokenAuthError::Unknown); + return Err(anyhow!("Unknown error")); }; let resp = TokenAuthResponse::new(access_token, refresh_token); Ok(resp) } - async fn refresh_token( - &self, - token: String, - ) -> std::result::Result<RefreshTokenResponse, RefreshTokenError> { + async fn refresh_token(&self, token: String) -> Result<RefreshTokenResponse> { let Some(refresh_token) = self.db.get_refresh_token(&token).await? else { - return Err(RefreshTokenError::InvalidRefreshToken); + return Err(anyhow!("Invalid refresh token")); }; if refresh_token.is_expired() { - return Err(RefreshTokenError::ExpiredRefreshToken); + return Err(anyhow!("Expired refresh token")); } let Some(user) = self.db.get_user(refresh_token.user_id).await? else { - return Err(RefreshTokenError::UserNotFound); + return Err(anyhow!("User not found")); }; if !user.active { - return Err(RefreshTokenError::UserDisabled); + return Err(anyhow!("User is disabled")); } let new_token = generate_refresh_token(); @@ -324,7 +184,7 @@ impl AuthenticationService for AuthenticationServiceImpl { // refresh token update is done, generate new access token based on user info let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) else { - return Err(RefreshTokenError::Unknown); + return Err(anyhow!("Unknown error")); }; let resp = RefreshTokenResponse::new(access_token, new_token, refresh_token.expires_at); @@ -561,13 +421,13 @@ async fn check_invitation( is_admin_initialized: bool, invitation_code: Option<String>, email: &str, -) -> Result<Option<InvitationDAO>, RegisterError> { +) -> Result<Option<InvitationDAO>> { if !is_admin_initialized { // Creating the admin user, no invitation required return Ok(None); } - let err = Err(RegisterError::InvalidInvitationCode); + let err = Err(anyhow!("Invitation code is not valid")); let Some(invitation_code) = invitation_code else { return err; }; @@ -652,12 +512,7 @@ mod tests { async fn register_admin_user(service: &AuthenticationServiceImpl) -> RegisterResponse { service - .register( - ADMIN_EMAIL.to_owned(), - ADMIN_PASSWORD.to_owned(), - ADMIN_PASSWORD.to_owned(), - None, - ) + .register(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned(), None) .await .unwrap() } @@ -669,7 +524,7 @@ mod tests { service .token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned()) .await, - Err(TokenAuthError::UserNotFound) + Err(_) ); register_admin_user(&service).await; @@ -678,7 +533,7 @@ mod tests { service .token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned()) .await, - Err(TokenAuthError::InvalidPassword) + Err(_) ); let resp1 = service @@ -712,14 +567,9 @@ mod tests { // Admin initialized, registeration requires a invitation code; assert_matches!( service - .register( - email.to_owned(), - password.to_owned(), - password.to_owned(), - None - ) + .register(email.to_owned(), password.to_owned(), None) .await, - Err(RegisterError::InvalidInvitationCode) + Err(_) ); // Invalid invitation code won't work. @@ -728,11 +578,10 @@ mod tests { .register( email.to_owned(), password.to_owned(), - password.to_owned(), Some("abc".to_owned()) ) .await, - Err(RegisterError::InvalidInvitationCode) + Err(_) ); // Register success. @@ -740,7 +589,6 @@ mod tests { .register( email.to_owned(), password.to_owned(), - password.to_owned(), Some(invitation.code.clone()), ) .await @@ -752,11 +600,10 @@ mod tests { .register( email.to_owned(), password.to_owned(), - password.to_owned(), Some(invitation.code.clone()) ) .await, - Err(RegisterError::InvalidInvitationCode) + Err(_) ); // Used invitation should have been deleted, following delete attempt should fail.