Skip to content

Commit

Permalink
refactor(db): return user on created and updated (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
ecrupper committed Aug 30, 2023
1 parent d2c5db6 commit 3547e2a
Show file tree
Hide file tree
Showing 25 changed files with 82 additions and 83 deletions.
4 changes: 2 additions & 2 deletions api/admin/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func UpdateUser(c *gin.Context) {
}

// send API call to update the user
err = database.FromContext(c).UpdateUser(input)
u, err := database.FromContext(c).UpdateUser(input)
if err != nil {
retErr := fmt.Errorf("unable to update user %d: %w", input.GetID(), err)

Expand All @@ -75,5 +75,5 @@ func UpdateUser(c *gin.Context) {
return
}

c.JSON(http.StatusOK, input)
c.JSON(http.StatusOK, u)
}
4 changes: 2 additions & 2 deletions api/auth/get_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func GetAuthToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to create the user in the database
err = database.FromContext(c).CreateUser(u)
_, err = database.FromContext(c).CreateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to create user %s: %w", u.GetName(), err)

Expand Down Expand Up @@ -154,7 +154,7 @@ func GetAuthToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to update the user in the database
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
2 changes: 1 addition & 1 deletion api/auth/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func Logout(c *gin.Context) {
u.SetRefreshToken("")

// send API call to update the user in the database
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
5 changes: 1 addition & 4 deletions api/user/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func CreateUser(c *gin.Context) {
}).Infof("creating new user %s", input.GetName())

// send API call to create the user
err = database.FromContext(c).CreateUser(input)
user, err := database.FromContext(c).CreateUser(input)
if err != nil {
retErr := fmt.Errorf("unable to create user: %w", err)

Expand All @@ -81,8 +81,5 @@ func CreateUser(c *gin.Context) {
return
}

// send API call to capture the created user
user, _ := database.FromContext(c).GetUserForName(input.GetName())

c.JSON(http.StatusCreated, user)
}
2 changes: 1 addition & 1 deletion api/user/create_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func CreateToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
2 changes: 1 addition & 1 deletion api/user/delete_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func DeleteToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
5 changes: 1 addition & 4 deletions api/user/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func UpdateUser(c *gin.Context) {
}

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
u, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", user, err)

Expand All @@ -117,8 +117,5 @@ func UpdateUser(c *gin.Context) {
return
}

// send API call to capture the updated user
u, _ = database.FromContext(c).GetUserForName(user)

c.JSON(http.StatusOK, u)
}
12 changes: 1 addition & 11 deletions api/user/update_current.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func UpdateCurrentUser(c *gin.Context) {
}

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
u, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand All @@ -91,15 +91,5 @@ func UpdateCurrentUser(c *gin.Context) {
return
}

// send API call to capture the updated user
u, err = database.FromContext(c).GetUserForName(u.GetName())
if err != nil {
retErr := fmt.Errorf("unable to get updated user %s: %w", u.GetName(), err)

util.HandleError(c, http.StatusNotFound, retErr)

return
}

c.JSON(http.StatusOK, u)
}
9 changes: 2 additions & 7 deletions database/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) {

// create the users
for _, user := range resources.Users {
err := db.CreateUser(user)
_, err := db.CreateUser(user)
if err != nil {
t.Errorf("unable to create user %d: %v", user.GetID(), err)
}
Expand Down Expand Up @@ -1711,16 +1711,11 @@ func testUsers(t *testing.T, db Interface, resources *Resources) {
// update the users
for _, user := range resources.Users {
user.SetActive(false)
err = db.UpdateUser(user)
got, err := db.UpdateUser(user)
if err != nil {
t.Errorf("unable to update user %d: %v", user.GetID(), err)
}

// lookup the user by ID
got, err := db.GetUser(user.GetID())
if err != nil {
t.Errorf("unable to get user %d by ID: %v", user.GetID(), err)
}
if !reflect.DeepEqual(got, user) {
t.Errorf("GetUser() is %v, want %v", got, user)
}
Expand Down
4 changes: 2 additions & 2 deletions database/user/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ func TestUser_Engine_CountUsers(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_userOne)
_, err := _sqlite.CreateUser(_userOne)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}

err = _sqlite.CreateUser(_userTwo)
_, err = _sqlite.CreateUser(_userTwo)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
19 changes: 12 additions & 7 deletions database/user/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

// CreateUser creates a new user in the database.
func (e *engine) CreateUser(u *library.User) error {
func (e *engine) CreateUser(u *library.User) (*library.User, error) {
e.logger.WithFields(logrus.Fields{
"user": u.GetName(),
}).Tracef("creating user %s in the database", u.GetName())
Expand All @@ -30,20 +30,25 @@ func (e *engine) CreateUser(u *library.User) error {
// https://pkg.go.dev/github.com/go-vela/types/database#User.Validate
err := user.Validate()
if err != nil {
return err
return nil, err
}

// encrypt the fields for the user
//
// https://pkg.go.dev/github.com/go-vela/types/database#User.Encrypt
err = user.Encrypt(e.config.EncryptionKey)
if err != nil {
return fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
return nil, fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
}

// send query to the database
return e.client.
Table(constants.TableUser).
Create(user).
Error
result := e.client.Table(constants.TableUser).Create(user)

// decrypt fields to return user
err = user.Decrypt(e.config.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt user %s: %w", u.GetName(), err)
}

return user.ToLibrary(), result.Error
}
7 changes: 6 additions & 1 deletion database/user/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package user

import (
"reflect"
"testing"

"github.com/DATA-DOG/go-sqlmock"
Expand Down Expand Up @@ -55,7 +56,7 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8) RETURNING "id"`).
// run tests
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := test.database.CreateUser(_user)
got, err := test.database.CreateUser(_user)

if test.failure {
if err == nil {
Expand All @@ -68,6 +69,10 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8) RETURNING "id"`).
if err != nil {
t.Errorf("CreateUser for %s returned err: %v", test.name, err)
}

if !reflect.DeepEqual(got, _user) {
t.Errorf("CreateUser for %s returned %s, want %s", test.name, got, _user)
}
})
}
}
2 changes: 1 addition & 1 deletion database/user/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestUser_Engine_DeleteUser(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion database/user/get_name_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestUser_Engine_GetUserForName(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion database/user/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestUser_Engine_GetUser(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions database/user/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type UserInterface interface {
// CountUsers defines a function that gets the count of all users.
CountUsers() (int64, error)
// CreateUser defines a function that creates a new user.
CreateUser(*library.User) error
CreateUser(*library.User) (*library.User, error)
// DeleteUser defines a function that deletes an existing user.
DeleteUser(*library.User) error
// GetUser defines a function that gets a user by ID.
Expand All @@ -41,5 +41,5 @@ type UserInterface interface {
// ListLiteUsers defines a function that gets a lite list of users.
ListLiteUsers(int, int) ([]*library.User, int64, error)
// UpdateUser defines a function that updates an existing user.
UpdateUser(*library.User) error
UpdateUser(*library.User) (*library.User, error)
}
4 changes: 2 additions & 2 deletions database/user/list_lite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func TestUser_Engine_ListLiteUsers(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_userOne)
_, err := _sqlite.CreateUser(_userOne)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}

err = _sqlite.CreateUser(_userTwo)
_, err = _sqlite.CreateUser(_userTwo)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions database/user/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func TestUser_Engine_ListUsers(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_userOne)
_, err := _sqlite.CreateUser(_userOne)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}

err = _sqlite.CreateUser(_userTwo)
_, err = _sqlite.CreateUser(_userTwo)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
19 changes: 12 additions & 7 deletions database/user/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

// UpdateUser updates an existing user in the database.
func (e *engine) UpdateUser(u *library.User) error {
func (e *engine) UpdateUser(u *library.User) (*library.User, error) {
e.logger.WithFields(logrus.Fields{
"user": u.GetName(),
}).Tracef("updating user %s in the database", u.GetName())
Expand All @@ -30,20 +30,25 @@ func (e *engine) UpdateUser(u *library.User) error {
// https://pkg.go.dev/github.com/go-vela/types/database#User.Validate
err := user.Validate()
if err != nil {
return err
return nil, err
}

// encrypt the fields for the user
//
// https://pkg.go.dev/github.com/go-vela/types/database#User.Encrypt
err = user.Encrypt(e.config.EncryptionKey)
if err != nil {
return fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
return nil, fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
}

// send query to the database
return e.client.
Table(constants.TableUser).
Save(user).
Error
result := e.client.Table(constants.TableUser).Save(user)

// decrypt fields to return user
err = user.Decrypt(e.config.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt user %s: %w", u.GetName(), err)
}

return user.ToLibrary(), result.Error
}
9 changes: 7 additions & 2 deletions database/user/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package user

import (
"reflect"
"testing"

"github.com/DATA-DOG/go-sqlmock"
Expand All @@ -31,7 +32,7 @@ WHERE "id" = $8`).
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand All @@ -57,7 +58,7 @@ WHERE "id" = $8`).
// run tests
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err = test.database.UpdateUser(_user)
got, err := test.database.UpdateUser(_user)

if test.failure {
if err == nil {
Expand All @@ -70,6 +71,10 @@ WHERE "id" = $8`).
if err != nil {
t.Errorf("UpdateUser for %s returned err: %v", test.name, err)
}

if !reflect.DeepEqual(got, _user) {
t.Errorf("UpdateUser for %s returned %s, want %s", test.name, got, _user)
}
})
}
}
Loading

0 comments on commit 3547e2a

Please sign in to comment.