From 4cd8816555d82371961796e5ea1995e3ff172da9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frederik=20H=C3=B8rgreen?= Date: Wed, 2 Nov 2022 19:44:09 +0100 Subject: [PATCH] main: add role info - add list roles command - add roles associated with users in list users --- README.md | 3 ++- cmd/list.go | 5 +++-- cmd/roles.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++ database/query.go | 37 ++++++++++++++++++++++++++++++--- 4 files changed, 91 insertions(+), 6 deletions(-) create mode 100644 cmd/roles.go diff --git a/README.md b/README.md index 0a753a7..eba77d0 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,9 @@ Features - Deleting existing users - Extending the validity of existing users - Resetting passwords of existing users -- Listing of existing users +- Listing existing users and their associated roles - Listing all configured hosts +- Listing existing roles ## Installation diff --git a/cmd/list.go b/cmd/list.go index 9b87f41..61a7982 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -5,6 +5,7 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "os" + "strings" ) var listCmd = &cobra.Command{ @@ -36,11 +37,11 @@ func listUsersForConnection(conn *database.DBConn) { } table := tablewriter.NewWriter(os.Stdout) - table.SetHeader([]string{"username", "valid until"}) + table.SetHeader([]string{"username", "valid until", "roles"}) for _, u := range users { table.Append([]string{ - u.Username, *u.ValidUntil, + u.Username, *u.ValidUntil, strings.Join(u.Roles, ", "), }) } diff --git a/cmd/roles.go b/cmd/roles.go new file mode 100644 index 0000000..f6bc3ba --- /dev/null +++ b/cmd/roles.go @@ -0,0 +1,52 @@ +package cmd + +import ( + "github.com/hiperdk/pg_user/database" + "github.com/olekukonko/tablewriter" + "github.com/spf13/cobra" + "os" +) + +var rolesCmd = &cobra.Command{ + Use: "roles [host]", + Short: "List database roles", + Long: `List database roles for a specific database`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + cmd.Println(cmd.UsageString()) + os.Exit(1) + } + + host := args[0] + + conn := database.GetDatabaseEntryConnection(host) + if conn == nil { + cmd.Println("no hosts found by that name") + os.Exit(1) + } + + listRolesForConnection(conn) + }, +} + +func listRolesForConnection(conn *database.DBConn) { + roles, err := conn.GetAllRoles() + if err != nil { + panic(err) + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"rolename"}) + + for _, r := range roles { + table.Append([]string{ + r, + }) + } + + table.Render() +} + +func init() { + rootCmd.AddCommand(rolesCmd) +} diff --git a/database/query.go b/database/query.go index dfb7f88..2776588 100644 --- a/database/query.go +++ b/database/query.go @@ -1,14 +1,17 @@ package database import ( + "encoding/json" "fmt" "github.com/jmoiron/sqlx" "time" ) type User struct { - Username string `db:"usename"` - ValidUntil *string `db:"valuntil"` + Username string `db:"usename"` + ValidUntil *string `db:"valuntil"` + RolesJson string `db:"roles"` + Roles []string `db:"-"` } func (u *User) ParseValidUntil() (*string, error) { @@ -30,7 +33,16 @@ func (u *User) ParseValidUntil() (*string, error) { func (conn *DBConn) GetAllUsers() ([]User, error) { var users []User - err := conn.db.Select(&users, "SELECT usename, valuntil FROM pg_catalog.pg_user WHERE usesuper = false ORDER BY valuntil, usename") + err := conn.db.Select(&users, ` + SELECT a.rolname AS usename, + a.rolvaliduntil AS valuntil, + json_agg(c.rolname) AS roles + FROM pg_roles a + INNER JOIN pg_auth_members b ON a.oid = b.member + INNER JOIN pg_roles c ON b.roleid = c.oid + GROUP BY 1, 2 + ORDER BY 2, 1 + `) if err != nil { return nil, err } @@ -42,11 +54,30 @@ func (conn *DBConn) GetAllUsers() ([]User, error) { } users[i].ValidUntil = validUntil + + // convert postgres json aggregation to string slice + var roles []string + err = json.Unmarshal([]byte(users[i].RolesJson), &roles) + if err != nil { + return nil, err + } + + users[i].Roles = roles } return users, nil } +func (conn *DBConn) GetAllRoles() ([]string, error) { + var roles []string + err := conn.db.Select(&roles, "SELECT rolname FROM pg_roles ORDER BY rolname") + if err != nil { + return nil, err + } + + return roles, nil +} + func (conn *DBConn) CreateUser(username string, validDuration time.Duration, roles []string) (string, time.Time, error) { tx := conn.db.MustBegin()