Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make OAuth library's AuthStyle configurable #271

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ OIDC Provider:
--providers.oidc.client-id= Client ID [$PROVIDERS_OIDC_CLIENT_ID]
--providers.oidc.client-secret= Client Secret [$PROVIDERS_OIDC_CLIENT_SECRET]
--providers.oidc.resource= Optional resource indicator [$PROVIDERS_OIDC_RESOURCE]
--providers.oidc.auth-type= Optionally choose the authentication type of the OAuth2 library
[auto-detect, header, params] (default: auto-detect) [$PROVIDERS_OIDC_AUTH_TYPE]

Generic OAuth2 Provider:
--providers.generic-oauth.auth-url= Auth/Login URL [$PROVIDERS_GENERIC_OAUTH_AUTH_URL]
Expand All @@ -188,6 +190,8 @@ Generic OAuth2 Provider:
--providers.generic-oauth.token-style=[header|query] How token is presented when querying the User URL (default: header)
[$PROVIDERS_GENERIC_OAUTH_TOKEN_STYLE]
--providers.generic-oauth.resource= Optional resource indicator [$PROVIDERS_GENERIC_OAUTH_RESOURCE]
--providers.generic-oauth.auth-type= Optionally choose the authentication type of the OAuth2 library
[auto-detect, header, params] (default: auto-detect) [$PROVIDERS_GENERIC_OAUTH_AUTH_TYPE]

Help Options:
-h, --help Show this help message
Expand Down
2 changes: 2 additions & 0 deletions internal/provider/generic_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type GenericOAuth struct {
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
Scopes []string `long:"scope" env:"SCOPE" env-delim:"," default:"profile" default:"email" description:"Scopes"`
AuthStyle string `long:"auth-style" env:"AUTH_STYLE" default:"auto-detect" choice:"auto-detect" choice:"header" choice:"params" description:"Authentication style to be used by the OAuth library"`
TokenStyle string `long:"token-style" env:"TOKEN_STYLE" default:"header" choice:"header" choice:"query" description:"How token is presented when querying the User URL"`

OAuthProvider
Expand All @@ -40,6 +41,7 @@ func (o *GenericOAuth) Setup() error {
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthStyle: parseAuthStyle(o.AuthStyle),
AuthURL: o.AuthURL,
TokenURL: o.TokenURL,
},
Expand Down
78 changes: 69 additions & 9 deletions internal/provider/generic_oauth_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package provider

import (
"golang.org/x/oauth2"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)

// Tests
Expand Down Expand Up @@ -37,6 +37,68 @@ func TestGenericOAuthSetup(t *testing.T) {
assert.Nil(err)
}

func TestGenericOAuthGetAuthStyleDefault(t *testing.T) {
assert := assert.New(t)
p := GenericOAuth{
AuthURL: "https://provider.com/oauth2/auth",
TokenURL: "https://provider.com/oauth2/token",
UserURL: "https://provider.com/oauth2/user",
ClientID: "idtest",
ClientSecret: "secret",
Scopes: []string{"scopetest"},
}

err := p.Setup()
if err != nil {
t.Fatal(err)
}

authStyle := parseAuthStyle(p.AuthStyle)
assert.Equal(oauth2.AuthStyleAutoDetect, authStyle)
}

func TestGenericOAuthGetAuthStyleHeader(t *testing.T) {
assert := assert.New(t)
p := GenericOAuth{
AuthStyle: "header",
AuthURL: "https://provider.com/oauth2/auth",
TokenURL: "https://provider.com/oauth2/token",
UserURL: "https://provider.com/oauth2/user",
ClientID: "idtest",
ClientSecret: "secret",
Scopes: []string{"scopetest"},
}

err := p.Setup()
if err != nil {
t.Fatal(err)
}

authStyle := parseAuthStyle(p.AuthStyle)
assert.Equal(oauth2.AuthStyleInHeader, authStyle)
}

func TestGenericOAuthGetAuthStyleParams(t *testing.T) {
assert := assert.New(t)
p := GenericOAuth{
AuthStyle: "params",
AuthURL: "https://provider.com/oauth2/auth",
TokenURL: "https://provider.com/oauth2/token",
UserURL: "https://provider.com/oauth2/user",
ClientID: "idtest",
ClientSecret: "secret",
Scopes: []string{"scopetest"},
}

err := p.Setup()
if err != nil {
t.Fatal(err)
}

authStyle := parseAuthStyle(p.AuthStyle)
assert.Equal(oauth2.AuthStyleInParams, authStyle)
}

func TestGenericOAuthGetLoginURL(t *testing.T) {
assert := assert.New(t)
p := GenericOAuth{
Expand Down Expand Up @@ -89,6 +151,9 @@ func TestGenericOAuthExchangeCode(t *testing.T) {

// Setup provider
p := GenericOAuth{
// We force AuthStyleInParams to prevent the test failure when the
// AuthStyleInHeader is attempted
AuthStyle: "params",
AuthURL: "https://provider.com/oauth2/auth",
TokenURL: serverURL.String() + "/token",
UserURL: "https://provider.com/oauth2/user",
Expand All @@ -100,10 +165,6 @@ func TestGenericOAuthExchangeCode(t *testing.T) {
t.Fatal(err)
}

// We force AuthStyleInParams to prevent the test failure when the
// AuthStyleInHeader is attempted
p.Config.Endpoint.AuthStyle = oauth2.AuthStyleInParams

token, err := p.ExchangeCode("http://example.com/_oauth", "code")
assert.Nil(err)
assert.Equal("123456789", token)
Expand All @@ -118,6 +179,9 @@ func TestGenericOAuthGetUser(t *testing.T) {

// Setup provider
p := GenericOAuth{
// We force AuthStyleInParams to prevent the test failure when the
// AuthStyleInHeader is attempted
AuthStyle: "params",
AuthURL: "https://provider.com/oauth2/auth",
TokenURL: "https://provider.com/oauth2/token",
UserURL: serverURL.String() + "/userinfo",
Expand All @@ -129,10 +193,6 @@ func TestGenericOAuthGetUser(t *testing.T) {
t.Fatal(err)
}

// We force AuthStyleInParams to prevent the test failure when the
// AuthStyleInHeader is attempted
p.Config.Endpoint.AuthStyle = oauth2.AuthStyleInParams

user, err := p.GetUser("123456789")
assert.Nil(err)

Expand Down
7 changes: 6 additions & 1 deletion internal/provider/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type OIDC struct {
IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"`
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
AuthStyle string `long:"auth-style" env:"AUTH_STYLE" default:"auto-detect" choice:"auto-detect" choice:"header" choice:"params" description:"Authentication style to be used by the OAuth library"`

OAuthProvider

Expand Down Expand Up @@ -45,7 +46,11 @@ func (o *OIDC) Setup() error {
o.Config = &oauth2.Config{
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
Endpoint: o.provider.Endpoint(),
Endpoint: oauth2.Endpoint{
AuthStyle: parseAuthStyle(o.AuthStyle),
AuthURL: o.provider.Endpoint().AuthURL,
TokenURL: o.provider.Endpoint().TokenURL,
},

// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
Expand Down
36 changes: 32 additions & 4 deletions internal/provider/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/rand"
"crypto/rsa"
"fmt"
"golang.org/x/oauth2"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand All @@ -18,6 +19,8 @@ import (

// Tests

var defaultAuthStyle = "auto-detect"

func TestOIDCName(t *testing.T) {
p := OIDC{}
assert.Equal(t, "oidc", p.Name())
Expand All @@ -33,10 +36,34 @@ func TestOIDCSetup(t *testing.T) {
}
}

func TestOIDCGetAuthStyleAutoDetect(t *testing.T) {
assert := assert.New(t)
provider, _, _, _ := setupOIDCTest(t, nil, defaultAuthStyle)

authStyle := parseAuthStyle(provider.AuthStyle)
assert.Equal(oauth2.AuthStyleAutoDetect, authStyle)
}

func TestOIDCGetAuthStyleHeader(t *testing.T) {
assert := assert.New(t)
provider, _, _, _ := setupOIDCTest(t, nil, "header")

authStyle := parseAuthStyle(provider.AuthStyle)
assert.Equal(oauth2.AuthStyleInHeader, authStyle)
}

func TestOIDCGetAuthStyleParams(t *testing.T) {
assert := assert.New(t)
provider, _, _, _ := setupOIDCTest(t, nil, "params")

authStyle := parseAuthStyle(provider.AuthStyle)
assert.Equal(oauth2.AuthStyleInParams, authStyle)
}

func TestOIDCGetLoginURL(t *testing.T) {
assert := assert.New(t)

provider, server, serverURL, _ := setupOIDCTest(t, nil)
provider, server, serverURL, _ := setupOIDCTest(t, nil, defaultAuthStyle)
defer server.Close()

// Check url
Expand Down Expand Up @@ -97,7 +124,7 @@ func TestOIDCExchangeCode(t *testing.T) {
"grant_type": "authorization_code",
"redirect_uri": "http://example.com/_oauth",
},
})
}, defaultAuthStyle)
defer server.Close()

token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
Expand All @@ -108,7 +135,7 @@ func TestOIDCExchangeCode(t *testing.T) {
func TestOIDCGetUser(t *testing.T) {
assert := assert.New(t)

provider, server, serverURL, key := setupOIDCTest(t, nil)
provider, server, serverURL, key := setupOIDCTest(t, nil, defaultAuthStyle)
defer server.Close()

// Generate JWT
Expand All @@ -130,7 +157,7 @@ func TestOIDCGetUser(t *testing.T) {
// Utils

// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string, authStyle string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
// Generate key
key, err := newRSAKey()
if err != nil {
Expand All @@ -154,6 +181,7 @@ func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC

// Setup provider
p := OIDC{
AuthStyle: authStyle,
ClientID: "idtest",
ClientSecret: "sectest",
IssuerURL: serverURL.String(),
Expand Down
11 changes: 11 additions & 0 deletions internal/provider/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ type OAuthProvider struct {
ctx context.Context
}

func parseAuthStyle(authStyle string) oauth2.AuthStyle {
switch authStyle {
case "header":
return oauth2.AuthStyleInHeader
case "params":
return oauth2.AuthStyleInParams
default:
return oauth2.AuthStyleAutoDetect
}
}

// ConfigCopy returns a copy of the oauth2 config with the given redirectURI
// which ensures the underlying config is not modified
func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config {
Expand Down