From 28617b09b9b395e95eb3ca42b2db380ff9c45b4f Mon Sep 17 00:00:00 2001 From: dreth Date: Fri, 16 Aug 2024 02:07:01 +0200 Subject: [PATCH] fix modify user reauth bug --- backend/auth/auth.go | 61 +++++++++++++++++++++++++++++++++----------- backend/auth/user.go | 8 +++--- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/backend/auth/auth.go b/backend/auth/auth.go index 752ace4..8db0238 100644 --- a/backend/auth/auth.go +++ b/backend/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "crypto/rand" + "database/sql" "encoding/hex" "fmt" "hbd/encryption" @@ -145,8 +146,14 @@ func Register(c *gin.Context) { // As the user was successfully created, send a telegram message through the bot and ID to confirm the registration telegram.SendTelegramMessage(req.TelegramBotAPIKey, req.TelegramUserID, fmt.Sprintf("🎂 Your user has been successfully registered, through this bot and user ID you'll receive your birthday reminders (if there's any) at %s (Timezone: %s).\n\nIf you encounter any issues using the app or want to give any feedback to us. Please open an issue here: https://github.com/dreth/hbd/issues, thanks and we hope you find the application useful!", req.ReminderTime, req.Timezone)) - // Return the token and the user's details - token, err := GenerateJWT(req.Email, 720) + // Get the JWT duration from the header or use the default + jwtDuration, err := GetJWTDurationFromHeader(c, 720) + if err != nil { + jwtDuration = 720 + } + + // Generate JWT token + token, err := GenerateJWT(req.Email, jwtDuration) if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) { return } else { @@ -213,16 +220,23 @@ func Login(c *gin.Context) { emailHash := encryption.HashStringWithSHA256(req.Email) passwordHash := encryption.HashStringWithSHA256(req.Password) - // Fetch the user with the given email hash and password hash from the database _, err := models.Users( qm.Where("email_hash = ?", emailHash), qm.Where("password_hash = ?", passwordHash), ).One(c.Request.Context(), boil.GetContextDB()) - if err != nil { + + // If no user is found, return a 401 Unauthorized + if err == sql.ErrNoRows { c.JSON(http.StatusUnauthorized, structs.Error{Error: "invalid email or password"}) return } + // Handle other errors separately + if err != nil { + c.JSON(http.StatusInternalServerError, structs.Error{Error: "an unexpected error occurred"}) + return + } + // Set the user email in the context c.Set("Email", req.Email) @@ -238,7 +252,6 @@ func Login(c *gin.Context) { } // Generate JWT token - println(jwtDuration) token, err := GenerateJWT(req.Email, jwtDuration) if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) { return @@ -287,7 +300,7 @@ func Me(c *gin.Context) { // @x-order 4 func ModifyUser(c *gin.Context) { // Retrieve the user from the database - user, err := GetUserByEmail(c) + user, originalEmail, err := GetUserByEmail(c) if helper.HE(c, err, http.StatusUnauthorized, "invalid email", false) { return } @@ -388,12 +401,6 @@ func ModifyUser(c *gin.Context) { return } - // After committing the transaction, emit another JWT token with the new email - token, err := GenerateJWT(req.NewEmail, 720) - if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) { - return - } - // Get user data post-changes userData, err := GetUserData(c) if helper.HE(c, err, http.StatusInternalServerError, "invalid email or password", true) { @@ -401,8 +408,32 @@ func ModifyUser(c *gin.Context) { } // Update the email in the context - if req.NewEmail != "" { - c.Set("Email", req.NewEmail) + if (req.NewEmail != "") || (req.NewPassword != "" && req.NewEmail == "") { + // Possible scenarios + // 1. New email is empty, but password is not + // 2. New email is not empty (regardless of password) + // Case 1: New email is empty, but password is not + if req.NewPassword != "" && req.NewEmail == "" { + req.NewEmail = originalEmail + } + + // Case 2: New email is not empty (regardless of password) + if req.NewEmail != "" { + c.Set("Email", req.NewEmail) + } + + // After committing the transaction, emit another JWT token with the new email + // Get the JWT duration from the header or use the default + jwtDuration, err := GetJWTDurationFromHeader(c, 720) + if err != nil { + jwtDuration = 720 + } + + // Generate JWT token + token, err := GenerateJWT(req.NewEmail, jwtDuration) + if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) { + return + } // Return the new token with the new user data c.JSON(http.StatusOK, structs.LoginSuccess{ @@ -433,7 +464,7 @@ func ModifyUser(c *gin.Context) { // @x-order 5 func DeleteUser(c *gin.Context) { // Retrieve the user from the database - user, err := GetUserByEmail(c) + user, _, err := GetUserByEmail(c) if helper.HE(c, err, http.StatusUnauthorized, "invalid email", false) { return } diff --git a/backend/auth/user.go b/backend/auth/user.go index c2f31ad..04eb3b9 100644 --- a/backend/auth/user.go +++ b/backend/auth/user.go @@ -28,7 +28,7 @@ import ( // 2. Hashes the email using SHA-256. // 3. Queries the database for a user with the given email hash. // 4. Returns the user object or an error if the user is not found or an error occurs. -func GetUserByEmail(c *gin.Context) (*models.User, error) { +func GetUserByEmail(c *gin.Context) (*models.User, string, error) { // Get the email from the context email := c.GetString("Email") @@ -40,10 +40,10 @@ func GetUserByEmail(c *gin.Context) (*models.User, error) { qm.Where("email_hash = ?", emailHash), ).One(c.Request.Context(), boil.GetContextDB()) if err != nil { - return nil, errors.New("invalid email") + return nil, email, errors.New("invalid email") } - return user, nil + return user, email, nil } // GetUserData fetches and returns user data including decrypted Telegram bot API key and user ID, @@ -69,7 +69,7 @@ func GetUserByEmail(c *gin.Context) (*models.User, error) { // - (*structs.UserData, error): A pointer to the UserData struct containing user details and birthdays, or an error. func GetUserData(c *gin.Context) (*structs.UserData, error) { // Get the user by its email - user, err := GetUserByEmail(c) + user, _, err := GetUserByEmail(c) if err != nil { return nil, errors.New("invalid email") }