Skip to content

Commit

Permalink
support multi oauth (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alva8756 authored Jan 16, 2024
1 parent f905c9c commit c5d78b4
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 85 deletions.
42 changes: 17 additions & 25 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,8 @@ func init() {

// OIDC Flags
serveCmd.Flags().Bool("oidc", true, "use oidc auth")
viperx.MustBindFlag(viper.GetViper(), "oidc.enabled", serveCmd.Flags().Lookup("oidc"))
serveCmd.Flags().String("oidc-aud", "", "expected audience on OIDC JWT")
viperx.MustBindFlag(viper.GetViper(), "oidc.audience", serveCmd.Flags().Lookup("oidc-aud"))
serveCmd.Flags().String("oidc-issuer", "", "expected issuer of OIDC JWT")
viperx.MustBindFlag(viper.GetViper(), "oidc.issuer", serveCmd.Flags().Lookup("oidc-issuer"))
serveCmd.Flags().String("oidc-jwksuri", "", "URI for JWKS listing for JWTs")
viperx.MustBindFlag(viper.GetViper(), "oidc.jwksuri", serveCmd.Flags().Lookup("oidc-jwksuri"))
serveCmd.Flags().String("oidc-roles-claim", "claim", "field containing the permissions of an OIDC JWT")
viperx.MustBindFlag(viper.GetViper(), "oidc.claims.roles", serveCmd.Flags().Lookup("oidc-roles-claim"))
serveCmd.Flags().String("oidc-username-claim", "", "additional fields to output in logs from the JWT token, ex (email)")
viperx.MustBindFlag(viper.GetViper(), "oidc.claims.username", serveCmd.Flags().Lookup("oidc-username-claim"))
ginjwt.BindFlagFromViperInst(viper.GetViper(), "oidc.enabled", serveCmd.Flags().Lookup("oidc"))

// DB Flags
serveCmd.Flags().String("db-encryption-driver", "", "encryption driver uri; 32 byte base64 encoded string, (example: base64key://your-encoded-secret-key)")
viperx.MustBindFlag(viper.GetViper(), "db.encryption_driver", serveCmd.Flags().Lookup("db-encryption-driver"))
Expand Down Expand Up @@ -111,21 +102,22 @@ func serve(ctx context.Context) {
"address", viper.GetString("listen"),
)

logger.Infow("oidc enabled", "oidc", viper.GetString("oidc"))

var authCfgs []ginjwt.AuthConfig
if viper.GetViper().GetBool("oidc.enabled") {
authCfgs, err = ginjwt.GetAuthConfigsFromFlags(viper.GetViper())
if err != nil {
logger.Fatal(err)
}
}

hs := &httpsrv.Server{
Logger: logger.Desugar(),
Listen: viper.GetString("listen"),
Debug: config.AppConfig.Logging.Debug,
DB: db,
SecretsKeeper: keeper,
AuthConfig: ginjwt.AuthConfig{
Enabled: viper.GetBool("oidc.enabled"),
Audience: viper.GetString("oidc.audience"),
Issuer: viper.GetString("oidc.issuer"),
JWKSURI: viper.GetString("oidc.jwksuri"),
LogFields: viper.GetStringSlice("oidc.log"),
RolesClaim: viper.GetString("oidc.claims.roles"),
UsernameClaim: viper.GetString("oidc.claims.username"),
},
Logger: logger.Desugar(),
Listen: viper.GetString("listen"),
Debug: config.AppConfig.Logging.Debug,
DB: db,
AuthConfigs: authCfgs,
}

// init event stream - for now, only when nats.url is specified
Expand Down
7 changes: 4 additions & 3 deletions internal/httpsrv/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/jmoiron/sqlx"
ginprometheus "github.com/zsais/go-gin-prometheus"
"go.hollow.sh/toolbox/events"
"go.hollow.sh/toolbox/ginauth"
"go.hollow.sh/toolbox/ginjwt"
"go.infratographer.com/x/versionx"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
Expand All @@ -27,7 +28,7 @@ type Server struct {
Listen string
Debug bool
DB *sqlx.DB
AuthConfig ginjwt.AuthConfig
AuthConfigs []ginjwt.AuthConfig
SecretsKeeper *secrets.Keeper
EventStream events.Stream
}
Expand All @@ -40,11 +41,11 @@ var (

func (s *Server) setup() *gin.Engine {
var (
authMW *ginjwt.Middleware
authMW *ginauth.MultiTokenMiddleware
err error
)

authMW, err = ginjwt.NewAuthMiddleware(s.AuthConfig)
authMW, err = ginjwt.NewMultiTokenMiddlewareFromConfigs(s.AuthConfigs...)
if err != nil {
s.Logger.Sugar().Fatal("failed to initialize auth middleware: ", "error", err)
}
Expand Down
16 changes: 9 additions & 7 deletions internal/httpsrv/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
"go.hollow.sh/serverservice/internal/httpsrv"
)

var serverAuthConfig = ginjwt.AuthConfig{
Enabled: false,
var serverAuthConfig = []ginjwt.AuthConfig{
{
Enabled: false,
},
}

func TestUnknownRoute(t *testing.T) {
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig}
s := hs.NewServer()
router := s.Handler

Expand All @@ -33,7 +35,7 @@ func TestUnknownRoute(t *testing.T) {
}

func TestHealthzRoute(t *testing.T) {
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig}
s := hs.NewServer()
router := s.Handler

Expand All @@ -46,7 +48,7 @@ func TestHealthzRoute(t *testing.T) {
}

func TestLivenessRoute(t *testing.T) {
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig}
s := hs.NewServer()
router := s.Handler

Expand All @@ -61,7 +63,7 @@ func TestLivenessRoute(t *testing.T) {
func TestReadinessRouteDown(t *testing.T) {
db, _ := sqlx.Open("postgres", "localhost:12341")

hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig, DB: db}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig, DB: db}
s := hs.NewServer()
router := s.Handler

Expand All @@ -76,7 +78,7 @@ func TestReadinessRouteDown(t *testing.T) {
func TestReadinessRouteUp(t *testing.T) {
db := dbtools.DatabaseTest(t)

hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfig: serverAuthConfig, DB: db}
hs := httpsrv.Server{Logger: zap.NewNop(), AuthConfigs: serverAuthConfig, DB: db}
s := hs.NewServer()
router := s.Handler

Expand Down
85 changes: 41 additions & 44 deletions pkg/api/v1/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/pkg/errors"
"github.com/volatiletech/sqlboiler/v4/boil"
"go.hollow.sh/toolbox/events"
"go.hollow.sh/toolbox/ginjwt"
"go.hollow.sh/toolbox/ginauth"
"go.uber.org/zap"
"gocloud.dev/secrets"

Expand All @@ -21,7 +21,7 @@ import (

// Router provides a router for the v1 API
type Router struct {
AuthMW *ginjwt.Middleware
AuthMW *ginauth.MultiTokenMiddleware
DB *sqlx.DB
SecretsKeeper *secrets.Keeper
Logger *zap.Logger
Expand All @@ -32,94 +32,91 @@ type Router struct {
func (r *Router) Routes(rg *gin.RouterGroup) {
amw := r.AuthMW

// require all calls to have auth
rg.Use(amw.AuthRequired())

// /servers
srvs := rg.Group("/servers")
{
srvs.GET("", amw.RequiredScopes(readScopes("server")), r.serverList)
srvs.POST("", amw.RequiredScopes(createScopes("server")), r.serverCreate)
srvs.GET("", amw.AuthRequired(readScopes("server")), r.serverList)
srvs.POST("", amw.AuthRequired(createScopes("server")), r.serverCreate)

srvs.GET("/components", amw.RequiredScopes(readScopes("server:component")), r.serverComponentList)
srvs.GET("/components", amw.AuthRequired(readScopes("server:component")), r.serverComponentList)

// /servers/:uuid
srv := srvs.Group("/:uuid")
{
srv.GET("", amw.RequiredScopes(readScopes("server")), r.serverGet)
srv.PUT("", amw.RequiredScopes(updateScopes("server")), r.serverUpdate)
srv.DELETE("", amw.RequiredScopes(deleteScopes("server")), r.serverDelete)
srv.GET("", amw.AuthRequired(readScopes("server")), r.serverGet)
srv.PUT("", amw.AuthRequired(updateScopes("server")), r.serverUpdate)
srv.DELETE("", amw.AuthRequired(deleteScopes("server")), r.serverDelete)

// /servers/:uuid/attributes
srvAttrs := srv.Group("/attributes")
{
srvAttrs.GET("", amw.RequiredScopes(readScopes("server", "server:attributes")), r.serverAttributesList)
srvAttrs.POST("", amw.RequiredScopes(createScopes("server", "server:attributes")), r.serverAttributesCreate)
srvAttrs.GET("/:namespace", amw.RequiredScopes(readScopes("server", "server:attributes")), r.serverAttributesGet)
srvAttrs.PUT("/:namespace", amw.RequiredScopes(updateScopes("server", "server:attributes")), r.serverAttributesUpdate)
srvAttrs.DELETE("/:namespace", amw.RequiredScopes(deleteScopes("server", "server:attributes")), r.serverAttributesDelete)
srvAttrs.GET("", amw.AuthRequired(readScopes("server", "server:attributes")), r.serverAttributesList)
srvAttrs.POST("", amw.AuthRequired(createScopes("server", "server:attributes")), r.serverAttributesCreate)
srvAttrs.GET("/:namespace", amw.AuthRequired(readScopes("server", "server:attributes")), r.serverAttributesGet)
srvAttrs.PUT("/:namespace", amw.AuthRequired(updateScopes("server", "server:attributes")), r.serverAttributesUpdate)
srvAttrs.DELETE("/:namespace", amw.AuthRequired(deleteScopes("server", "server:attributes")), r.serverAttributesDelete)
}

// /servers/:uuid/components
srvComponents := srv.Group("/components")
{
srvComponents.POST("", amw.RequiredScopes(createScopes("server", "server:component")), r.serverComponentsCreate)
srvComponents.GET("", amw.RequiredScopes(readScopes("server", "server:component")), r.serverComponentGet)
srvComponents.PUT("", amw.RequiredScopes(updateScopes("server", "server:component")), r.serverComponentUpdate)
srvComponents.DELETE("", amw.RequiredScopes(deleteScopes("server", "server:component")), r.serverComponentDelete)
srvComponents.POST("", amw.AuthRequired(createScopes("server", "server:component")), r.serverComponentsCreate)
srvComponents.GET("", amw.AuthRequired(readScopes("server", "server:component")), r.serverComponentGet)
srvComponents.PUT("", amw.AuthRequired(updateScopes("server", "server:component")), r.serverComponentUpdate)
srvComponents.DELETE("", amw.AuthRequired(deleteScopes("server", "server:component")), r.serverComponentDelete)
}

// /servers/:uuid/credentials/:slug
svrCreds := srv.Group("credentials/:slug")
{
svrCreds.GET("", amw.RequiredScopes([]string{"read:server:credentials"}), r.serverCredentialGet)
svrCreds.PUT("", amw.RequiredScopes([]string{"write:server:credentials"}), r.serverCredentialUpsert)
svrCreds.DELETE("", amw.RequiredScopes([]string{"write:server:credentials"}), r.serverCredentialDelete)
svrCreds.GET("", amw.AuthRequired([]string{"read:server:credentials"}), r.serverCredentialGet)
svrCreds.PUT("", amw.AuthRequired([]string{"write:server:credentials"}), r.serverCredentialUpsert)
svrCreds.DELETE("", amw.AuthRequired([]string{"write:server:credentials"}), r.serverCredentialDelete)
}

// /servers/:uuid/versioned-attributes
srvVerAttrs := srv.Group("/versioned-attributes")
{
srvVerAttrs.GET("", amw.RequiredScopes(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesList)
srvVerAttrs.POST("", amw.RequiredScopes(createScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesCreate)
srvVerAttrs.GET("/:namespace", amw.RequiredScopes(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesGet)
srvVerAttrs.GET("", amw.AuthRequired(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesList)
srvVerAttrs.POST("", amw.AuthRequired(createScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesCreate)
srvVerAttrs.GET("/:namespace", amw.AuthRequired(readScopes("server", "server:versioned-attributes")), r.serverVersionedAttributesGet)
}
}
}

// /server-component-types
srvCmpntType := rg.Group("/server-component-types")
{
srvCmpntType.GET("", amw.RequiredScopes(readScopes("server-component-types")), r.serverComponentTypeList)
srvCmpntType.POST("", amw.RequiredScopes(updateScopes("server-component-types")), r.serverComponentTypeCreate)
srvCmpntType.GET("", amw.AuthRequired(readScopes("server-component-types")), r.serverComponentTypeList)
srvCmpntType.POST("", amw.AuthRequired(updateScopes("server-component-types")), r.serverComponentTypeCreate)
}

// /server-component-firmwares
srvCmpntFw := rg.Group("/server-component-firmwares")
{
srvCmpntFw.GET("", amw.RequiredScopes(readScopes("server-component-firmwares")), r.serverComponentFirmwareList)
srvCmpntFw.POST("", amw.RequiredScopes(createScopes("server-component-firmwares")), r.serverComponentFirmwareCreate)
srvCmpntFw.GET("/:uuid", amw.RequiredScopes(readScopes("server-component-firmwares")), r.serverComponentFirmwareGet)
srvCmpntFw.PUT("/:uuid", amw.RequiredScopes(updateScopes("server-component-firmwares")), r.serverComponentFirmwareUpdate)
srvCmpntFw.DELETE("/:uuid", amw.RequiredScopes(deleteScopes("server-component-firmwares")), r.serverComponentFirmwareDelete)
srvCmpntFw.GET("", amw.AuthRequired(readScopes("server-component-firmwares")), r.serverComponentFirmwareList)
srvCmpntFw.POST("", amw.AuthRequired(createScopes("server-component-firmwares")), r.serverComponentFirmwareCreate)
srvCmpntFw.GET("/:uuid", amw.AuthRequired(readScopes("server-component-firmwares")), r.serverComponentFirmwareGet)
srvCmpntFw.PUT("/:uuid", amw.AuthRequired(updateScopes("server-component-firmwares")), r.serverComponentFirmwareUpdate)
srvCmpntFw.DELETE("/:uuid", amw.AuthRequired(deleteScopes("server-component-firmwares")), r.serverComponentFirmwareDelete)
}

// /server-credential-types
srvCredentialTypes := rg.Group("/server-credential-types")
{
srvCredentialTypes.GET("", amw.RequiredScopes(readScopes("server-credential-types")), r.serverCredentialTypesList)
srvCredentialTypes.POST("", amw.RequiredScopes(createScopes("server-credential-types")), r.serverCredentialTypesCreate)
srvCredentialTypes.GET("", amw.AuthRequired(readScopes("server-credential-types")), r.serverCredentialTypesList)
srvCredentialTypes.POST("", amw.AuthRequired(createScopes("server-credential-types")), r.serverCredentialTypesCreate)
}

// /server-component-firmware-sets
srvCmpntFwSets := rg.Group("/server-component-firmware-sets")
{
srvCmpntFwSets.GET("", amw.RequiredScopes(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetList)
srvCmpntFwSets.POST("", amw.RequiredScopes(createScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetCreate)
srvCmpntFwSets.GET("/:uuid", amw.RequiredScopes(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetGet)
srvCmpntFwSets.PUT("/:uuid", amw.RequiredScopes(updateScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetUpdate)
srvCmpntFwSets.DELETE("/:uuid", amw.RequiredScopes(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetDelete)
srvCmpntFwSets.POST("/:uuid/remove-firmware", amw.RequiredScopes(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetRemoveFirmware)
srvCmpntFwSets.GET("", amw.AuthRequired(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetList)
srvCmpntFwSets.POST("", amw.AuthRequired(createScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetCreate)
srvCmpntFwSets.GET("/:uuid", amw.AuthRequired(readScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetGet)
srvCmpntFwSets.PUT("/:uuid", amw.AuthRequired(updateScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetUpdate)
srvCmpntFwSets.DELETE("/:uuid", amw.AuthRequired(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetDelete)
srvCmpntFwSets.POST("/:uuid/remove-firmware", amw.AuthRequired(deleteScopes("server-component-firmware-sets")), r.serverComponentFirmwareSetRemoveFirmware)
}

// /bill-of-materials
Expand All @@ -128,19 +125,19 @@ func (r *Router) Routes(rg *gin.RouterGroup) {
// /bill-of-materials/batch-boms-upload
uploadFile := srvBoms.Group("/batch-upload")
{
uploadFile.POST("", amw.RequiredScopes(createScopes("batch-upload")), r.bomsUpload)
uploadFile.POST("", amw.AuthRequired(createScopes("batch-upload")), r.bomsUpload)
}

// /bill-of-materials/aoc-mac-address
srvBomByAocMacAddress := srvBoms.Group("/aoc-mac-address")
{
srvBomByAocMacAddress.GET("/:aoc_mac_address", amw.RequiredScopes(readScopes("aoc-mac-address")), r.getBomFromAocMacAddress)
srvBomByAocMacAddress.GET("/:aoc_mac_address", amw.AuthRequired(readScopes("aoc-mac-address")), r.getBomFromAocMacAddress)
}

// /bill-of-materials/bmc-mac-address
srvBomByBmcMacAddress := srvBoms.Group("/bmc-mac-address")
{
srvBomByBmcMacAddress.GET("/:bmc_mac_address", amw.RequiredScopes(readScopes("bmc-mac-address")), r.getBomFromBmcMacAddress)
srvBomByBmcMacAddress.GET("/:bmc_mac_address", amw.AuthRequired(readScopes("bmc-mac-address")), r.getBomFromBmcMacAddress)
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/api/v1/router_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ func serverTest(t *testing.T) *integrationServer {
hs := httpsrv.Server{
Logger: l,
DB: db,
AuthConfig: ginjwt.AuthConfig{
Enabled: true,
Audience: "hollow.test",
Issuer: "hollow.test.issuer",
JWKSURI: jwksURI,
RolesClaim: "userPerms",
AuthConfigs: []ginjwt.AuthConfig{
{
Enabled: true,
Audience: "hollow.test",
Issuer: "hollow.test.issuer",
JWKSURI: jwksURI,
RolesClaim: "userPerms",
},
},
SecretsKeeper: dbtools.TestSecretKeeper(t),
}
Expand Down
15 changes: 15 additions & 0 deletions sample_oidc_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
oidc:
- audience: "https://audience1"
issuer: "https://issuer1"
jwksuri: "https://jwkuri1"
enabled: true
claims:
roles: "role1"
username: "user1"
- audience: "https://audience2"
issuer: "https://issuer2"
jwksuri: "https://jwkuri2"
enabled: true
claims:
roles: "role2"
username: "user2"

0 comments on commit c5d78b4

Please sign in to comment.