Skip to content

Commit

Permalink
Merge pull request #3 from madser123/add-alter-command
Browse files Browse the repository at this point in the history
Add alter command
  • Loading branch information
frederikhs authored Aug 29, 2024
2 parents de6e3ec + dfc4da3 commit 08d06b0
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
95 changes: 95 additions & 0 deletions cmd/alter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package cmd

import (
"encoding/json"
"fmt"
"os"
"strings"

"github.com/spf13/cobra"
)

var alterCmd = &cobra.Command{
Use: "alter [username]",
Short: "Alter a database users roles",
Long: "Alter a database users roles in a specific database",
Run: func(cmd *cobra.Command, args []string) {
username, conn := givenUserModification(cmd, 1, args, true)

add_roles, err := cmd.Flags().GetStringSlice("add")
if err != nil {
cmd.Println(fmt.Errorf("could not alter user: %v", err))
os.Exit(1)
}

remove_roles, err := cmd.Flags().GetStringSlice("remove")
if err != nil {
cmd.Println(fmt.Errorf("could not alter user: %v", err))
os.Exit(1)
}

if len(add_roles) == 0 && len(remove_roles) == 0 {
cmd.Println("at least one role must be added or removed")
os.Exit(1)
}

all_roles := append(add_roles, remove_roles...)

for _, role := range all_roles {
exists, err := conn.RoleExist(role)

if err != nil {
cmd.Println(fmt.Errorf("could not alter user: %v", err))
os.Exit(1)
}

if !exists {
cmd.Println(fmt.Sprintf("role %s does not exist", role))
os.Exit(1)
}
}

tx := conn.BeginTransaction()

conn.AddRole(tx, username, add_roles)

Check failure on line 54 in cmd/alter.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `conn.AddRole` is not checked (errcheck)
conn.RemoveRole(tx, username, remove_roles)

Check failure on line 55 in cmd/alter.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `conn.RemoveRole` is not checked (errcheck)

tx.Commit()

Check failure on line 57 in cmd/alter.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `tx.Commit` is not checked (errcheck)

output := getOutputType(cmd)
if output == OutputTypeJson {
outputAlterJson(cmd, add_roles, remove_roles)
} else if output == OutputTypeTable {
outputAlterTable(cmd, add_roles, remove_roles, username, conn.Config.Host)
}
},
}

func outputAlterTable(cmd *cobra.Command, add_roles, remove_roles []string, username, host string) {
cmd.Println(fmt.Sprintf("successfully altered user %s in %s", username, host))
cmd.Println("Added roles:", add_roles)
cmd.Println("Removed roles:", remove_roles)
}

func outputAlterJson(cmd *cobra.Command, add_roles, remove_roles []string) {
b, err := json.MarshalIndent(struct {
Added []string `json:"added"`
Removed []string `json:"removed"`
}{
Added: add_roles,
Removed: remove_roles,
}, "", strings.Repeat(" ", 4))
if err != nil {
panic(err)
}

cmd.Println(string(b))
}

func init() {
rootCmd.AddCommand(alterCmd)
addRequiredHostFlag(alterCmd)

alterCmd.Flags().StringSlice("add", []string{}, "--add=roleA,roleB (optional)")
alterCmd.Flags().StringSlice("remove", []string{}, "--remove=roleA,roleB (optional)")
}
31 changes: 30 additions & 1 deletion database/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package database
import (
"encoding/json"
"fmt"
"github.com/jmoiron/sqlx"
"time"

"github.com/jmoiron/sqlx"
)

type User struct {
Expand Down Expand Up @@ -180,6 +181,10 @@ func (conn *DBConn) GetUser(username string) (*User, error) {
return &user, err
}

func (conn *DBConn) BeginTransaction() *sqlx.Tx {
return conn.db.MustBegin()
}

func (conn *DBConn) AddRole(tx *sqlx.Tx, username string, roles []string) error {
for _, role := range roles {

Expand All @@ -204,6 +209,30 @@ func (conn *DBConn) AddRole(tx *sqlx.Tx, username string, roles []string) error
return nil
}

func (conn *DBConn) RemoveRole(tx *sqlx.Tx, username string, roles []string) error {
for _, role := range roles {

roleExists, err := conn.RoleExist(role)
if err != nil {
err2 := tx.Rollback()
if err2 != nil {
return err2
}

return err
}

if !roleExists {
return fmt.Errorf("role does not exist: %s", role)
}

sql := fmt.Sprintf("REVOKE \"%s\" from \"%s\"", role, username)
tx.MustExec(sql)
}

return nil
}

func (conn *DBConn) RoleExist(role string) (bool, error) {
var exists bool

Expand Down

0 comments on commit 08d06b0

Please sign in to comment.