Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

take multi-token auth config instead of individual variables #12

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 17 additions & 25 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,8 @@ func init() {

// OIDC Flags
serveCmd.Flags().Bool("oidc", true, "use oidc auth")
viperx.MustBindFlag(viper.GetViper(), "oidc.enabled", serveCmd.Flags().Lookup("oidc"))
serveCmd.Flags().String("oidc-aud", "", "expected audience on OIDC JWT")
viperx.MustBindFlag(viper.GetViper(), "oidc.audience", serveCmd.Flags().Lookup("oidc-aud"))
serveCmd.Flags().String("oidc-issuer", "", "expected issuer of OIDC JWT")
viperx.MustBindFlag(viper.GetViper(), "oidc.issuer", serveCmd.Flags().Lookup("oidc-issuer"))
serveCmd.Flags().String("oidc-jwksuri", "", "URI for JWKS listing for JWTs")
viperx.MustBindFlag(viper.GetViper(), "oidc.jwksuri", serveCmd.Flags().Lookup("oidc-jwksuri"))
serveCmd.Flags().String("oidc-roles-claim", "claim", "field containing the permissions of an OIDC JWT")
viperx.MustBindFlag(viper.GetViper(), "oidc.claims.roles", serveCmd.Flags().Lookup("oidc-roles-claim"))
serveCmd.Flags().String("oidc-username-claim", "", "additional fields to output in logs from the JWT token, ex (email)")
viperx.MustBindFlag(viper.GetViper(), "oidc.claims.username", serveCmd.Flags().Lookup("oidc-username-claim"))
ginjwt.BindFlagFromViperInst(viper.GetViper(), "oidc.enabled", serveCmd.Flags().Lookup("oidc"))

// DB Flags
serveCmd.Flags().String("db-encryption-driver", "", "encryption driver uri; 32 byte base64 encoded string, (example: base64key://your-encoded-secret-key)")
viperx.MustBindFlag(viper.GetViper(), "db.encryption_driver", serveCmd.Flags().Lookup("db-encryption-driver"))
Expand Down Expand Up @@ -111,21 +102,22 @@ func serve(ctx context.Context) {
"address", viper.GetString("listen"),
)

logger.Infow("oidc enabled", "oidc", viper.GetString("oidc"))

var authCfgs []ginjwt.AuthConfig
if viper.GetViper().GetBool("oidc.enabled") {
authCfgs, err = ginjwt.GetAuthConfigsFromFlags(viper.GetViper())
if err != nil {
logger.Fatal(err)
}
}

hs := &httpsrv.Server{
Logger: logger.Desugar(),
Listen: viper.GetString("listen"),
Debug: config.AppConfig.Logging.Debug,
DB: db,
SecretsKeeper: keeper,
AuthConfig: ginjwt.AuthConfig{
Enabled: viper.GetBool("oidc.enabled"),
Audience: viper.GetString("oidc.audience"),
Issuer: viper.GetString("oidc.issuer"),
JWKSURI: viper.GetString("oidc.jwksuri"),
LogFields: viper.GetStringSlice("oidc.log"),
RolesClaim: viper.GetString("oidc.claims.roles"),
UsernameClaim: viper.GetString("oidc.claims.username"),
},
Logger: logger.Desugar(),
Listen: viper.GetString("listen"),
Debug: config.AppConfig.Logging.Debug,
DB: db,
AuthConfigs: authCfgs,
}

// init event stream - for now, only when nats.url is specified
Expand Down
7 changes: 4 additions & 3 deletions internal/httpsrv/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/jmoiron/sqlx"
ginprometheus "github.com/zsais/go-gin-prometheus"
"go.hollow.sh/toolbox/events"
"go.hollow.sh/toolbox/ginauth"
"go.hollow.sh/toolbox/ginjwt"
"go.infratographer.com/x/versionx"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
Expand All @@ -27,7 +28,7 @@ type Server struct {
Listen string
Debug bool
DB *sqlx.DB
AuthConfig ginjwt.AuthConfig
AuthConfigs []ginjwt.AuthConfig
SecretsKeeper *secrets.Keeper
EventStream events.Stream
}
Expand All @@ -40,11 +41,11 @@ var (

func (s *Server) setup() *gin.Engine {
var (
authMW *ginjwt.Middleware
authMW *ginauth.MultiTokenMiddleware
err error
)

authMW, err = ginjwt.NewAuthMiddleware(s.AuthConfig)
authMW, err = ginjwt.NewMultiTokenMiddlewareFromConfigs(s.AuthConfigs...)
if err != nil {
s.Logger.Sugar().Fatal("failed to initialize auth middleware: ", "error", err)
}
Expand Down
16 changes: 9 additions & 7 deletions internal/httpsrv/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
"go.hollow.sh/serverservice/internal/httpsrv"
)

var serverAuthConfig = ginjwt.AuthConfig{
Enabled: false,
var serverAuthConfig = []ginjwt.AuthConfig{
{
Enabled: false,
},
}

func TestUnknownRoute(t *testing.T) {
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig}
s := hs.NewServer()
router := s.Handler

Expand All @@ -33,7 +35,7 @@ func TestUnknownRoute(t *testing.T) {
}

func TestHealthzRoute(t *testing.T) {
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig}
s := hs.NewServer()
router := s.Handler

Expand All @@ -46,7 +48,7 @@ func TestHealthzRoute(t *testing.T) {
}

func TestLivenessRoute(t *testing.T) {
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig}
s := hs.NewServer()
router := s.Handler

Expand All @@ -61,7 +63,7 @@ func TestLivenessRoute(t *testing.T) {
func TestReadinessRouteDown(t *testing.T) {
db, _ := sqlx.Open("postgres", "localhost:12341")

hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig, DB: db}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig, DB: db}
s := hs.NewServer()
router := s.Handler

Expand All @@ -76,7 +78,7 @@ func TestReadinessRouteDown(t *testing.T) {
func TestReadinessRouteUp(t *testing.T) {
db := dbtools.DatabaseTest(t)

hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig, DB: db}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig, DB: db}
s := hs.NewServer()
router := s.Handler

Expand Down
85 changes: 41 additions & 44 deletions pkg/api/v1/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/pkg/errors"
"github.com/volatiletech/sqlboiler/v4/boil"
"go.hollow.sh/toolbox/events"
"go.hollow.sh/toolbox/ginjwt"
"go.hollow.sh/toolbox/ginauth"
"go.uber.org/zap"
"gocloud.dev/secrets"

Expand All @@ -21,7 +21,7 @@ import (

// Router provides a router for the v1 API
type Router struct {
AuthMW *ginjwt.Middleware
AuthMW *ginauth.MultiTokenMiddleware
DB *sqlx.DB
SecretsKeeper *secrets.Keeper
Logger *zap.Logger
Expand All @@ -32,94 +32,91 @@ type Router struct {
func (r *Router) Routes(rg *gin.RouterGroup) {
amw := r.AuthMW

// require all calls to have auth
rg.Use(amw.AuthRequired())

// /servers
srvs := rg.Group("/servers")
{
srvs.GET("", amw.RequiredScopes(readScopes("server")), r.serverList)
srvs.POST("", amw.RequiredScopes(createScopes("server")), r.serverCreate)
srvs.GET("", amw.AuthRequired(readScopes("server")), r.serverList)
srvs.POST("", amw.AuthRequired(createScopes("server")), r.serverCreate)

srvs.GET("/components", amw.RequiredScopes(readScopes("server:component")), r.serverComponentList)
srvs.GET("/components", amw.AuthRequired(readScopes("server:component")), r.serverComponentList)

// /servers/:uuid
srv := srvs.Group("/:uuid")
{
srv.GET("", amw.RequiredScopes(readScopes("server")), r.serverGet)
srv.PUT("", amw.RequiredScopes(updateScopes("server")), r.serverUpdate)
srv.DELETE("", amw.RequiredScopes(deleteScopes("server")), r.serverDelete)
srv.GET("", amw.AuthRequired(readScopes("server")), r.serverGet)
srv.PUT("", amw.AuthRequired(updateScopes("server")), r.serverUpdate)
srv.DELETE("", amw.AuthRequired(deleteScopes("server")), r.serverDelete)

// /servers/:uuid/attributes
srvAttrs := srv.Group("/attributes")
{
srvAttrs.GET("", amw.RequiredScopes(readScopes("server", "server:attributes")), r.serverAttributesList)
srvAttrs.POST("", amw.RequiredScopes(createScopes("server", "server:attributes")), r.serverAttributesCreate)
srvAttrs.GET("/:namespace", amw.RequiredScopes(readScopes("server", "server:attributes")), r.serverAttributesGet)
srvAttrs.PUT("/:namespace", amw.RequiredScopes(updateScopes("server", "server:attributes")), r.serverAttributesUpdate)
srvAttrs.DELETE("/:namespace", amw.RequiredScopes(deleteScopes("server", "server:attributes")), r.serverAttributesDelete)
srvAttrs.GET("", amw.AuthRequired(readScopes("server", "server:attributes")), r.serverAttributesList)
srvAttrs.POST("", amw.AuthRequired(createScopes("server", "server:attributes")), r.serverAttributesCreate)
srvAttrs.GET("/:namespace", amw.AuthRequired(readScopes("server", "server:attributes")), r.serverAttributesGet)
srvAttrs.PUT("/:namespace", amw.AuthRequired(updateScopes("server", "server:attributes")), r.serverAttributesUpdate)
srvAttrs.DELETE("/:namespace", amw.AuthRequired(deleteScopes("server", "server:attributes")), r.serverAttributesDelete)
}

// /servers/:uuid/components
srvComponents := srv.Group("/components")
{
srvComponents.POST("", amw.RequiredScopes(createScopes("server", "server:component")), r.serverComponentsCreate)
srvComponents.GET("", amw.RequiredScopes(readScopes("server", "server:component")), r.serverComponentGet)
srvComponents.PUT("", amw.RequiredScopes(updateScopes("server", "server:component")), r.serverComponentUpdate)
srvComponents.DELETE("", amw.RequiredScopes(deleteScopes("server", "server:component")), r.serverComponentDelete)
srvComponents.POST("", amw.AuthRequired(createScopes("server", "server:component")), r.serverComponentsCreate)
srvComponents.GET("", amw.AuthRequired(readScopes("server", "server:component")), r.serverComponentGet)
srvComponents.PUT("", amw.AuthRequired(updateScopes("server", "server:component")), r.serverComponentUpdate)
srvComponents.DELETE("", amw.AuthRequired(deleteScopes("server", "server:component")), r.serverComponentDelete)
}

// /servers/:uuid/credentials/:slug
svrCreds := srv.Group("credentials/:slug")
{
svrCreds.GET("", amw.RequiredScopes([]string{"read:server:credentials"}), r.serverCredentialGet)
svrCreds.PUT("", amw.RequiredScopes([]string{"write:server:credentials"}), r.serverCredentialUpsert)
svrCreds.DELETE("", amw.RequiredScopes([]string{"write:server:credentials"}), r.serverCredentialDelete)
svrCreds.GET("", amw.AuthRequired([]string{"read:server:credentials"}), r.serverCredentialGet)
svrCreds.PUT("", amw.AuthRequired([]string{"write:server:credentials"}), r.serverCredentialUpsert)
svrCreds.DELETE("", amw.AuthRequired([]string{"write:server:credentials"}), r.serverCredentialDelete)
}

// /servers/:uuid/versioned-attributes
srvVerAttrs := srv.Group("/versioned-attributes")
{
srvVerAttrs.GET("", amw.RequiredScopes(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesList)
srvVerAttrs.POST("", amw.RequiredScopes(createScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesCreate)
srvVerAttrs.GET("/:namespace", amw.RequiredScopes(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesGet)
srvVerAttrs.GET("", amw.AuthRequired(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesList)
srvVerAttrs.POST("", amw.AuthRequired(createScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesCreate)
srvVerAttrs.GET("/:namespace", amw.AuthRequired(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesGet)
}
}
}

// /server-component-types
srvCmpntType := rg.Group("/server-component-types")
{
srvCmpntType.GET("", amw.RequiredScopes(readScopes("server-component-types")), r.serverComponentTypeList)
srvCmpntType.POST("", amw.RequiredScopes(updateScopes("server-component-types")), r.serverComponentTypeCreate)
srvCmpntType.GET("", amw.AuthRequired(readScopes("server-component-types")), r.serverComponentTypeList)
srvCmpntType.POST("", amw.AuthRequired(updateScopes("server-component-types")), r.serverComponentTypeCreate)
}

// /server-component-firmwares
srvCmpntFw := rg.Group("/server-component-firmwares")
{
srvCmpntFw.GET("", amw.RequiredScopes(readScopes("server-component-firmwares")), r.serverComponentFirmwareList)
srvCmpntFw.POST("", amw.RequiredScopes(createScopes("server-component-firmwares")), r.serverComponentFirmwareCreate)
srvCmpntFw.GET("/:uuid", amw.RequiredScopes(readScopes("server-component-firmwares")), r.serverComponentFirmwareGet)
srvCmpntFw.PUT("/:uuid", amw.RequiredScopes(updateScopes("server-component-firmwares")), r.serverComponentFirmwareUpdate)
srvCmpntFw.DELETE("/:uuid", amw.RequiredScopes(deleteScopes("server-component-firmwares")), r.serverComponentFirmwareDelete)
srvCmpntFw.GET("", amw.AuthRequired(readScopes("server-component-firmwares")), r.serverComponentFirmwareList)
srvCmpntFw.POST("", amw.AuthRequired(createScopes("server-component-firmwares")), r.serverComponentFirmwareCreate)
srvCmpntFw.GET("/:uuid", amw.AuthRequired(readScopes("server-component-firmwares")), r.serverComponentFirmwareGet)
srvCmpntFw.PUT("/:uuid", amw.AuthRequired(updateScopes("server-component-firmwares")), r.serverComponentFirmwareUpdate)
srvCmpntFw.DELETE("/:uuid", amw.AuthRequired(deleteScopes("server-component-firmwares")), r.serverComponentFirmwareDelete)
}

// /server-credential-types
srvCredentialTypes := rg.Group("/server-credential-types")
{
srvCredentialTypes.GET("", amw.RequiredScopes(readScopes("server-credential-types")), r.serverCredentialTypesList)
srvCredentialTypes.POST("", amw.RequiredScopes(createScopes("server-credential-types")), r.serverCredentialTypesCreate)
srvCredentialTypes.GET("", amw.AuthRequired(readScopes("server-credential-types")), r.serverCredentialTypesList)
srvCredentialTypes.POST("", amw.AuthRequired(createScopes("server-credential-types")), r.serverCredentialTypesCreate)
}

// /server-component-firmware-sets
srvCmpntFwSets := rg.Group("/server-component-firmware-sets")
{
srvCmpntFwSets.GET("", amw.RequiredScopes(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetList)
srvCmpntFwSets.POST("", amw.RequiredScopes(createScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetCreate)
srvCmpntFwSets.GET("/:uuid", amw.RequiredScopes(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetGet)
srvCmpntFwSets.PUT("/:uuid", amw.RequiredScopes(updateScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetUpdate)
srvCmpntFwSets.DELETE("/:uuid", amw.RequiredScopes(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetDelete)
srvCmpntFwSets.POST("/:uuid/remove-firmware", amw.RequiredScopes(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetRemoveFirmware)
srvCmpntFwSets.GET("", amw.AuthRequired(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetList)
srvCmpntFwSets.POST("", amw.AuthRequired(createScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetCreate)
srvCmpntFwSets.GET("/:uuid", amw.AuthRequired(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetGet)
srvCmpntFwSets.PUT("/:uuid", amw.AuthRequired(updateScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetUpdate)
srvCmpntFwSets.DELETE("/:uuid", amw.AuthRequired(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetDelete)
srvCmpntFwSets.POST("/:uuid/remove-firmware", amw.AuthRequired(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetRemoveFirmware)
}

// /bill-of-materials
Expand All @@ -128,19 +125,19 @@ func (r *Router) Routes(rg *gin.RouterGroup) {
// /bill-of-materials/batch-boms-upload
uploadFile := srvBoms.Group("/batch-upload")
{
uploadFile.POST("", amw.RequiredScopes(createScopes("batch-upload")), r.bomsUpload)
uploadFile.POST("", amw.AuthRequired(createScopes("batch-upload")), r.bomsUpload)
}

// /bill-of-materials/aoc-mac-address
srvBomByAocMacAddress := srvBoms.Group("/aoc-mac-address")
{
srvBomByAocMacAddress.GET("/:aoc_mac_address", amw.RequiredScopes(readScopes("aoc-mac-address")), r.getBomFromAocMacAddress)
srvBomByAocMacAddress.GET("/:aoc_mac_address", amw.AuthRequired(readScopes("aoc-mac-address")), r.getBomFromAocMacAddress)
}

// /bill-of-materials/bmc-mac-address
srvBomByBmcMacAddress := srvBoms.Group("/bmc-mac-address")
{
srvBomByBmcMacAddress.GET("/:bmc_mac_address", amw.RequiredScopes(readScopes("bmc-mac-address")), r.getBomFromBmcMacAddress)
srvBomByBmcMacAddress.GET("/:bmc_mac_address", amw.AuthRequired(readScopes("bmc-mac-address")), r.getBomFromBmcMacAddress)
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/api/v1/router_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ func serverTest(t *testing.T) *integrationServer {
hs := httpsrv.Server{
Logger: l,
DB: db,
AuthConfig: ginjwt.AuthConfig{
Enabled: true,
Audience: "hollow.test",
Issuer: "hollow.test.issuer",
JWKSURI: jwksURI,
RolesClaim: "userPerms",
AuthConfigs: []ginjwt.AuthConfig{
{
Enabled: true,
Audience: "hollow.test",
Issuer: "hollow.test.issuer",
JWKSURI: jwksURI,
RolesClaim: "userPerms",
},
},
SecretsKeeper: dbtools.TestSecretKeeper(t),
}
Expand Down
15 changes: 15 additions & 0 deletions sample_oidc_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
oidc:
- audience: "https://audience1"
issuer: "https://issuer1"
jwksuri: "https://jwkuri1"
enabled: true
claims:
roles: "role1"
username: "user1"
- audience: "https://audience2"
issuer: "https://issuer2"
jwksuri: "https://jwkuri2"
enabled: true
claims:
roles: "role2"
username: "user2"
Loading