diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 2459d069f..6da071516 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -24,6 +24,10 @@ import ( "golang.org/x/oauth2/internal" ) +const ( + ClientJWTAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" +) + // Config describes a 2-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. type Config struct { @@ -33,6 +37,9 @@ type Config struct { // ClientSecret is the application's secret. ClientSecret string + // ClientAssertionFn is a function to generate a client assertion value. + ClientAssertionFn func(ctx context.Context) (string, error) + // TokenURL is the resource server's token endpoint // URL. This is a constant specific to each server. TokenURL string @@ -107,6 +114,16 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { v[k] = p } + if c.conf.ClientAssertionFn != nil { + clientAssertion, err := c.conf.ClientAssertionFn(c.ctx) + if err != nil { + return nil, err + } + + v.Set("client_assertion", clientAssertion) + v.Set("client_assertion_type", ClientJWTAssertionType) + } + tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get()) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 078e75ec7..e89c2325f 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -12,16 +12,26 @@ import ( "net/http/httptest" "net/url" "testing" + + "golang.org/x/oauth2" ) -func newConf(serverURL string) *Config { - return &Config{ +func newConf(serverURL string, assertion bool) *Config { + conf := &Config{ ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", Scopes: []string{"scope1", "scope2"}, TokenURL: serverURL + "/token", EndpointParams: url.Values{"audience": {"audience1"}}, + AuthStyle: oauth2.AuthStyleInParams, + } + if assertion { + conf.ClientAssertionFn = func(ctx context.Context) (string, error) { + return "CLIENT_ASSERTION", nil + } + } else { + conf.ClientSecret = "CLIENT_SECRET" } + return conf } type mockTransport struct { @@ -69,45 +79,70 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) { } } +func assert(t *testing.T, want, got string) { + t.Helper() + if got != want { + t.Errorf("got %q; want %q", got, want) + } +} + func TestTokenRequest(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() != "/token" { t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token") } - headerAuth := r.Header.Get("Authorization") - if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { - t.Errorf("Unexpected authorization header, %v is found.", headerAuth) - } + if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { t.Errorf("Content-Type header = %q; want %q", got, want) } - body, err := ioutil.ReadAll(r.Body) - if err != nil { - r.Body.Close() - } - if err != nil { - t.Errorf("failed reading request body: %s.", err) - } - if string(body) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" { - t.Errorf("payload = %q; want %q", string(body), "grant_type=client_credentials&scope=scope1+scope2") + + assert(t, "audience1", r.FormValue("audience")) + assert(t, "CLIENT_ID", r.FormValue("client_id")) + assert(t, "client_credentials", r.FormValue("grant_type")) + assert(t, "scope1 scope2", r.FormValue("scope")) + if r.FormValue("client_secret") != "" { + assert(t, "CLIENT_SECRET", r.FormValue("client_secret")) + } else { + assert(t, "CLIENT_ASSERTION", r.FormValue("client_assertion")) + assert(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", r.FormValue("client_assertion_type")) } w.Header().Set("Content-Type", "application/x-www-form-urlencoded") w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) })) defer ts.Close() - conf := newConf(ts.URL) - tok, err := conf.Token(context.Background()) - if err != nil { - t.Error(err) - } - if !tok.Valid() { - t.Fatalf("token invalid. got: %#v", tok) + + type testCase struct { + name string + conf *Config } - if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { - t.Errorf("Access token = %q; want %q", tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c") + + tests := []testCase{ + { + name: "client id and client_secret", + conf: newConf(ts.URL, false), + }, + { + name: "client id and client_assertion", + conf: newConf(ts.URL, true), + }, } - if tok.TokenType != "bearer" { - t.Errorf("token type = %q; want %q", tok.TokenType, "bearer") + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tok, err := tc.conf.Token(context.Background()) + if err != nil { + t.Error(err) + } + if !tok.Valid() { + t.Fatalf("token invalid. got: %#v", tok) + } + if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { + t.Errorf("Access token = %q; want %q", tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c") + } + if tok.TokenType != "bearer" { + t.Errorf("token type = %q; want %q", tok.TokenType, "bearer") + } + }) } } @@ -132,7 +167,7 @@ func TestTokenRefreshRequest(t *testing.T) { io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newConf(ts.URL, false) c := conf.Client(context.Background()) c.Get(ts.URL + "/somethingelse") }