From c5d78b4ded239073aed380d3fa0ac63a136b85e6 Mon Sep 17 00:00:00 2001 From: Alva Zhang <140438385+Alva8756@users.noreply.github.com> Date: Tue, 16 Jan 2024 09:55:45 -0800 Subject: [PATCH] support multi oauth (#12) --- cmd/serve.go | 42 +++++++--------- internal/httpsrv/server.go | 7 +-- internal/httpsrv/server_test.go | 16 ++++--- pkg/api/v1/router.go | 85 ++++++++++++++++----------------- pkg/api/v1/router_int_test.go | 14 +++--- sample_oidc_config.yaml | 15 ++++++ 6 files changed, 94 insertions(+), 85 deletions(-) create mode 100644 sample_oidc_config.yaml diff --git a/cmd/serve.go b/cmd/serve.go index a235b0c..ba9e893 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -47,17 +47,8 @@ func init() { // OIDC Flags serveCmd.Flags().Bool("oidc", true, "use oidc auth") - viperx.MustBindFlag(viper.GetViper(), "oidc.enabled", serveCmd.Flags().Lookup("oidc")) - serveCmd.Flags().String("oidc-aud", "", "expected audience on OIDC JWT") - viperx.MustBindFlag(viper.GetViper(), "oidc.audience", serveCmd.Flags().Lookup("oidc-aud")) - serveCmd.Flags().String("oidc-issuer", "", "expected issuer of OIDC JWT") - viperx.MustBindFlag(viper.GetViper(), "oidc.issuer", serveCmd.Flags().Lookup("oidc-issuer")) - serveCmd.Flags().String("oidc-jwksuri", "", "URI for JWKS listing for JWTs") - viperx.MustBindFlag(viper.GetViper(), "oidc.jwksuri", serveCmd.Flags().Lookup("oidc-jwksuri")) - serveCmd.Flags().String("oidc-roles-claim", "claim", "field containing the permissions of an OIDC JWT") - viperx.MustBindFlag(viper.GetViper(), "oidc.claims.roles", serveCmd.Flags().Lookup("oidc-roles-claim")) - serveCmd.Flags().String("oidc-username-claim", "", "additional fields to output in logs from the JWT token, ex (email)") - viperx.MustBindFlag(viper.GetViper(), "oidc.claims.username", serveCmd.Flags().Lookup("oidc-username-claim")) + ginjwt.BindFlagFromViperInst(viper.GetViper(), "oidc.enabled", serveCmd.Flags().Lookup("oidc")) + // DB Flags serveCmd.Flags().String("db-encryption-driver", "", "encryption driver uri; 32 byte base64 encoded string, (example: base64key://your-encoded-secret-key)") viperx.MustBindFlag(viper.GetViper(), "db.encryption_driver", serveCmd.Flags().Lookup("db-encryption-driver")) @@ -111,21 +102,22 @@ func serve(ctx context.Context) { "address", viper.GetString("listen"), ) + logger.Infow("oidc enabled", "oidc", viper.GetString("oidc")) + + var authCfgs []ginjwt.AuthConfig + if viper.GetViper().GetBool("oidc.enabled") { + authCfgs, err = ginjwt.GetAuthConfigsFromFlags(viper.GetViper()) + if err != nil { + logger.Fatal(err) + } + } + hs := &httpsrv.Server{ - Logger: logger.Desugar(), - Listen: viper.GetString("listen"), - Debug: config.AppConfig.Logging.Debug, - DB: db, - SecretsKeeper: keeper, - AuthConfig: ginjwt.AuthConfig{ - Enabled: viper.GetBool("oidc.enabled"), - Audience: viper.GetString("oidc.audience"), - Issuer: viper.GetString("oidc.issuer"), - JWKSURI: viper.GetString("oidc.jwksuri"), - LogFields: viper.GetStringSlice("oidc.log"), - RolesClaim: viper.GetString("oidc.claims.roles"), - UsernameClaim: viper.GetString("oidc.claims.username"), - }, + Logger: logger.Desugar(), + Listen: viper.GetString("listen"), + Debug: config.AppConfig.Logging.Debug, + DB: db, + AuthConfigs: authCfgs, } // init event stream - for now, only when nats.url is specified diff --git a/internal/httpsrv/server.go b/internal/httpsrv/server.go index 8a605a6..e1155a8 100644 --- a/internal/httpsrv/server.go +++ b/internal/httpsrv/server.go @@ -11,6 +11,7 @@ import ( "github.com/jmoiron/sqlx" ginprometheus "github.com/zsais/go-gin-prometheus" "go.hollow.sh/toolbox/events" + "go.hollow.sh/toolbox/ginauth" "go.hollow.sh/toolbox/ginjwt" "go.infratographer.com/x/versionx" "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" @@ -27,7 +28,7 @@ type Server struct { Listen string Debug bool DB *sqlx.DB - AuthConfig ginjwt.AuthConfig + AuthConfigs []ginjwt.AuthConfig SecretsKeeper *secrets.Keeper EventStream events.Stream } @@ -40,11 +41,11 @@ var ( func (s *Server) setup() *gin.Engine { var ( - authMW *ginjwt.Middleware + authMW *ginauth.MultiTokenMiddleware err error ) - authMW, err = ginjwt.NewAuthMiddleware(s.AuthConfig) + authMW, err = ginjwt.NewMultiTokenMiddlewareFromConfigs(s.AuthConfigs...) if err != nil { s.Logger.Sugar().Fatal("failed to initialize auth middleware: ", "error", err) } diff --git a/internal/httpsrv/server_test.go b/internal/httpsrv/server_test.go index 2f3e443..394909e 100644 --- a/internal/httpsrv/server_test.go +++ b/internal/httpsrv/server_test.go @@ -15,12 +15,14 @@ import ( "go.hollow.sh/serverservice/internal/httpsrv" ) -var serverAuthConfig = ginjwt.AuthConfig{ - Enabled: false, +var serverAuthConfig = []ginjwt.AuthConfig{ + { + Enabled: false, + }, } func TestUnknownRoute(t *testing.T) { - hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig} + hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig} s := hs.NewServer() router := s.Handler @@ -33,7 +35,7 @@ func TestUnknownRoute(t *testing.T) { } func TestHealthzRoute(t *testing.T) { - hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig} + hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig} s := hs.NewServer() router := s.Handler @@ -46,7 +48,7 @@ func TestHealthzRoute(t *testing.T) { } func TestLivenessRoute(t *testing.T) { - hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig} + hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig} s := hs.NewServer() router := s.Handler @@ -61,7 +63,7 @@ func TestLivenessRoute(t *testing.T) { func TestReadinessRouteDown(t *testing.T) { db, _ := sqlx.Open("postgres", "localhost:12341") - hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig, DB: db} + hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig, DB: db} s := hs.NewServer() router := s.Handler @@ -76,7 +78,7 @@ func TestReadinessRouteDown(t *testing.T) { func TestReadinessRouteUp(t *testing.T) { db := dbtools.DatabaseTest(t) - hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig, DB: db} + hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig, DB: db} s := hs.NewServer() router := s.Handler diff --git a/pkg/api/v1/router.go b/pkg/api/v1/router.go index d058782..d26ba90 100644 --- a/pkg/api/v1/router.go +++ b/pkg/api/v1/router.go @@ -12,7 +12,7 @@ import ( "github.com/pkg/errors" "github.com/volatiletech/sqlboiler/v4/boil" "go.hollow.sh/toolbox/events" - "go.hollow.sh/toolbox/ginjwt" + "go.hollow.sh/toolbox/ginauth" "go.uber.org/zap" "gocloud.dev/secrets" @@ -21,7 +21,7 @@ import ( // Router provides a router for the v1 API type Router struct { - AuthMW *ginjwt.Middleware + AuthMW *ginauth.MultiTokenMiddleware DB *sqlx.DB SecretsKeeper *secrets.Keeper Logger *zap.Logger @@ -32,57 +32,54 @@ type Router struct { func (r *Router) Routes(rg *gin.RouterGroup) { amw := r.AuthMW - // require all calls to have auth - rg.Use(amw.AuthRequired()) - // /servers srvs := rg.Group("/servers") { - srvs.GET("", amw.RequiredScopes(readScopes("server")), r.serverList) - srvs.POST("", amw.RequiredScopes(createScopes("server")), r.serverCreate) + srvs.GET("", amw.AuthRequired(readScopes("server")), r.serverList) + srvs.POST("", amw.AuthRequired(createScopes("server")), r.serverCreate) - srvs.GET("/components", amw.RequiredScopes(readScopes("server:component")), r.serverComponentList) + srvs.GET("/components", amw.AuthRequired(readScopes("server:component")), r.serverComponentList) // /servers/:uuid srv := srvs.Group("/:uuid") { - srv.GET("", amw.RequiredScopes(readScopes("server")), r.serverGet) - srv.PUT("", amw.RequiredScopes(updateScopes("server")), r.serverUpdate) - srv.DELETE("", amw.RequiredScopes(deleteScopes("server")), r.serverDelete) + srv.GET("", amw.AuthRequired(readScopes("server")), r.serverGet) + srv.PUT("", amw.AuthRequired(updateScopes("server")), r.serverUpdate) + srv.DELETE("", amw.AuthRequired(deleteScopes("server")), r.serverDelete) // /servers/:uuid/attributes srvAttrs := srv.Group("/attributes") { - srvAttrs.GET("", amw.RequiredScopes(readScopes("server", "server:attributes")), r.serverAttributesList) - srvAttrs.POST("", amw.RequiredScopes(createScopes("server", "server:attributes")), r.serverAttributesCreate) - srvAttrs.GET("/:namespace", amw.RequiredScopes(readScopes("server", "server:attributes")), r.serverAttributesGet) - srvAttrs.PUT("/:namespace", amw.RequiredScopes(updateScopes("server", "server:attributes")), r.serverAttributesUpdate) - srvAttrs.DELETE("/:namespace", amw.RequiredScopes(deleteScopes("server", "server:attributes")), r.serverAttributesDelete) + srvAttrs.GET("", amw.AuthRequired(readScopes("server", "server:attributes")), r.serverAttributesList) + srvAttrs.POST("", amw.AuthRequired(createScopes("server", "server:attributes")), r.serverAttributesCreate) + srvAttrs.GET("/:namespace", amw.AuthRequired(readScopes("server", "server:attributes")), r.serverAttributesGet) + srvAttrs.PUT("/:namespace", amw.AuthRequired(updateScopes("server", "server:attributes")), r.serverAttributesUpdate) + srvAttrs.DELETE("/:namespace", amw.AuthRequired(deleteScopes("server", "server:attributes")), r.serverAttributesDelete) } // /servers/:uuid/components srvComponents := srv.Group("/components") { - srvComponents.POST("", amw.RequiredScopes(createScopes("server", "server:component")), r.serverComponentsCreate) - srvComponents.GET("", amw.RequiredScopes(readScopes("server", "server:component")), r.serverComponentGet) - srvComponents.PUT("", amw.RequiredScopes(updateScopes("server", "server:component")), r.serverComponentUpdate) - srvComponents.DELETE("", amw.RequiredScopes(deleteScopes("server", "server:component")), r.serverComponentDelete) + srvComponents.POST("", amw.AuthRequired(createScopes("server", "server:component")), r.serverComponentsCreate) + srvComponents.GET("", amw.AuthRequired(readScopes("server", "server:component")), r.serverComponentGet) + srvComponents.PUT("", amw.AuthRequired(updateScopes("server", "server:component")), r.serverComponentUpdate) + srvComponents.DELETE("", amw.AuthRequired(deleteScopes("server", "server:component")), r.serverComponentDelete) } // /servers/:uuid/credentials/:slug svrCreds := srv.Group("credentials/:slug") { - svrCreds.GET("", amw.RequiredScopes([]string{"read:server:credentials"}), r.serverCredentialGet) - svrCreds.PUT("", amw.RequiredScopes([]string{"write:server:credentials"}), r.serverCredentialUpsert) - svrCreds.DELETE("", amw.RequiredScopes([]string{"write:server:credentials"}), r.serverCredentialDelete) + svrCreds.GET("", amw.AuthRequired([]string{"read:server:credentials"}), r.serverCredentialGet) + svrCreds.PUT("", amw.AuthRequired([]string{"write:server:credentials"}), r.serverCredentialUpsert) + svrCreds.DELETE("", amw.AuthRequired([]string{"write:server:credentials"}), r.serverCredentialDelete) } // /servers/:uuid/versioned-attributes srvVerAttrs := srv.Group("/versioned-attributes") { - srvVerAttrs.GET("", amw.RequiredScopes(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesList) - srvVerAttrs.POST("", amw.RequiredScopes(createScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesCreate) - srvVerAttrs.GET("/:namespace", amw.RequiredScopes(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesGet) + srvVerAttrs.GET("", amw.AuthRequired(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesList) + srvVerAttrs.POST("", amw.AuthRequired(createScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesCreate) + srvVerAttrs.GET("/:namespace", amw.AuthRequired(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesGet) } } } @@ -90,36 +87,36 @@ func (r *Router) Routes(rg *gin.RouterGroup) { // /server-component-types srvCmpntType := rg.Group("/server-component-types") { - srvCmpntType.GET("", amw.RequiredScopes(readScopes("server-component-types")), r.serverComponentTypeList) - srvCmpntType.POST("", amw.RequiredScopes(updateScopes("server-component-types")), r.serverComponentTypeCreate) + srvCmpntType.GET("", amw.AuthRequired(readScopes("server-component-types")), r.serverComponentTypeList) + srvCmpntType.POST("", amw.AuthRequired(updateScopes("server-component-types")), r.serverComponentTypeCreate) } // /server-component-firmwares srvCmpntFw := rg.Group("/server-component-firmwares") { - srvCmpntFw.GET("", amw.RequiredScopes(readScopes("server-component-firmwares")), r.serverComponentFirmwareList) - srvCmpntFw.POST("", amw.RequiredScopes(createScopes("server-component-firmwares")), r.serverComponentFirmwareCreate) - srvCmpntFw.GET("/:uuid", amw.RequiredScopes(readScopes("server-component-firmwares")), r.serverComponentFirmwareGet) - srvCmpntFw.PUT("/:uuid", amw.RequiredScopes(updateScopes("server-component-firmwares")), r.serverComponentFirmwareUpdate) - srvCmpntFw.DELETE("/:uuid", amw.RequiredScopes(deleteScopes("server-component-firmwares")), r.serverComponentFirmwareDelete) + srvCmpntFw.GET("", amw.AuthRequired(readScopes("server-component-firmwares")), r.serverComponentFirmwareList) + srvCmpntFw.POST("", amw.AuthRequired(createScopes("server-component-firmwares")), r.serverComponentFirmwareCreate) + srvCmpntFw.GET("/:uuid", amw.AuthRequired(readScopes("server-component-firmwares")), r.serverComponentFirmwareGet) + srvCmpntFw.PUT("/:uuid", amw.AuthRequired(updateScopes("server-component-firmwares")), r.serverComponentFirmwareUpdate) + srvCmpntFw.DELETE("/:uuid", amw.AuthRequired(deleteScopes("server-component-firmwares")), r.serverComponentFirmwareDelete) } // /server-credential-types srvCredentialTypes := rg.Group("/server-credential-types") { - srvCredentialTypes.GET("", amw.RequiredScopes(readScopes("server-credential-types")), r.serverCredentialTypesList) - srvCredentialTypes.POST("", amw.RequiredScopes(createScopes("server-credential-types")), r.serverCredentialTypesCreate) + srvCredentialTypes.GET("", amw.AuthRequired(readScopes("server-credential-types")), r.serverCredentialTypesList) + srvCredentialTypes.POST("", amw.AuthRequired(createScopes("server-credential-types")), r.serverCredentialTypesCreate) } // /server-component-firmware-sets srvCmpntFwSets := rg.Group("/server-component-firmware-sets") { - srvCmpntFwSets.GET("", amw.RequiredScopes(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetList) - srvCmpntFwSets.POST("", amw.RequiredScopes(createScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetCreate) - srvCmpntFwSets.GET("/:uuid", amw.RequiredScopes(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetGet) - srvCmpntFwSets.PUT("/:uuid", amw.RequiredScopes(updateScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetUpdate) - srvCmpntFwSets.DELETE("/:uuid", amw.RequiredScopes(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetDelete) - srvCmpntFwSets.POST("/:uuid/remove-firmware", amw.RequiredScopes(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetRemoveFirmware) + srvCmpntFwSets.GET("", amw.AuthRequired(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetList) + srvCmpntFwSets.POST("", amw.AuthRequired(createScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetCreate) + srvCmpntFwSets.GET("/:uuid", amw.AuthRequired(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetGet) + srvCmpntFwSets.PUT("/:uuid", amw.AuthRequired(updateScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetUpdate) + srvCmpntFwSets.DELETE("/:uuid", amw.AuthRequired(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetDelete) + srvCmpntFwSets.POST("/:uuid/remove-firmware", amw.AuthRequired(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetRemoveFirmware) } // /bill-of-materials @@ -128,19 +125,19 @@ func (r *Router) Routes(rg *gin.RouterGroup) { // /bill-of-materials/batch-boms-upload uploadFile := srvBoms.Group("/batch-upload") { - uploadFile.POST("", amw.RequiredScopes(createScopes("batch-upload")), r.bomsUpload) + uploadFile.POST("", amw.AuthRequired(createScopes("batch-upload")), r.bomsUpload) } // /bill-of-materials/aoc-mac-address srvBomByAocMacAddress := srvBoms.Group("/aoc-mac-address") { - srvBomByAocMacAddress.GET("/:aoc_mac_address", amw.RequiredScopes(readScopes("aoc-mac-address")), r.getBomFromAocMacAddress) + srvBomByAocMacAddress.GET("/:aoc_mac_address", amw.AuthRequired(readScopes("aoc-mac-address")), r.getBomFromAocMacAddress) } // /bill-of-materials/bmc-mac-address srvBomByBmcMacAddress := srvBoms.Group("/bmc-mac-address") { - srvBomByBmcMacAddress.GET("/:bmc_mac_address", amw.RequiredScopes(readScopes("bmc-mac-address")), r.getBomFromBmcMacAddress) + srvBomByBmcMacAddress.GET("/:bmc_mac_address", amw.AuthRequired(readScopes("bmc-mac-address")), r.getBomFromBmcMacAddress) } } } diff --git a/pkg/api/v1/router_int_test.go b/pkg/api/v1/router_int_test.go index 26c0055..b374e92 100644 --- a/pkg/api/v1/router_int_test.go +++ b/pkg/api/v1/router_int_test.go @@ -33,12 +33,14 @@ func serverTest(t *testing.T) *integrationServer { hs := httpsrv.Server{ Logger: l, DB: db, - AuthConfig: ginjwt.AuthConfig{ - Enabled: true, - Audience: "hollow.test", - Issuer: "hollow.test.issuer", - JWKSURI: jwksURI, - RolesClaim: "userPerms", + AuthConfigs: []ginjwt.AuthConfig{ + { + Enabled: true, + Audience: "hollow.test", + Issuer: "hollow.test.issuer", + JWKSURI: jwksURI, + RolesClaim: "userPerms", + }, }, SecretsKeeper: dbtools.TestSecretKeeper(t), } diff --git a/sample_oidc_config.yaml b/sample_oidc_config.yaml new file mode 100644 index 0000000..d6306c1 --- /dev/null +++ b/sample_oidc_config.yaml @@ -0,0 +1,15 @@ +oidc: + - audience: "https://audience1" + issuer: "https://issuer1" + jwksuri: "https://jwkuri1" + enabled: true + claims: + roles: "role1" + username: "user1" + - audience: "https://audience2" + issuer: "https://issuer2" + jwksuri: "https://jwkuri2" + enabled: true + claims: + roles: "role2" + username: "user2"