diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go new file mode 100644 index 000000000000..f07c4c402e76 --- /dev/null +++ b/credentials/sts/sts.go @@ -0,0 +1,395 @@ +// +build go1.13 + +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package sts implements call credentials using STS (Security Token Service) as +// defined in https://tools.ietf.org/html/rfc8693. +// +// Experimental +// +// Notice: All APIs in this package are experimental and may be changed or +// removed in a later release. +package sts + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "sync" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" +) + +const ( + // HTTP request timeout set on the http.Client used to make STS requests. + stsRequestTimeout = 5 * time.Second + // If lifetime left in a cached token is lesser than this value, we fetch a + // new one instead of returning the current one. + minCachedTokenLifetime = 300 * time.Second + + tokenExchangeGrantType = "urn:ietf:params:oauth:grant-type:token-exchange" + defaultCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" +) + +// For overriding in tests. +var ( + loadSystemCertPool = x509.SystemCertPool + makeHTTPDoer = makeHTTPClient + readSubjectTokenFrom = ioutil.ReadFile + readActorTokenFrom = ioutil.ReadFile +) + +// Options configures the parameters used for an STS based token exchange. +type Options struct { + // TokenExchangeServiceURI is the address of the server which implements STS + // token exchange functionality. + TokenExchangeServiceURI string // Required. + + // Resource is a URI that indicates the target service or resource where the + // client intends to use the requested security token. + Resource string // Optional. + + // Audience is the logical name of the target service where the client + // intends to use the requested security token + Audience string // Optional. + + // Scope is a list of space-delimited, case-sensitive strings, that allow + // the client to specify the desired scope of the requested security token + // in the context of the service or resource where the token will be used. + // If this field is left unspecified, a default value of + // https://www.googleapis.com/auth/cloud-platform will be used. + Scope string // Optional. + + // RequestedTokenType is an identifier, as described in + // https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of + // the requested security token. + RequestedTokenType string // Optional. + + // SubjectTokenPath is a filesystem path which contains the security token + // that represents the identity of the party on behalf of whom the request + // is being made. + SubjectTokenPath string // Required. + + // SubjectTokenType is an identifier, as described in + // https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of + // the security token in the "subject_token_path" parameter. + SubjectTokenType string // Required. + + // ActorTokenPath is a security token that represents the identity of the + // acting party. + ActorTokenPath string // Optional. + + // ActorTokenType is an identifier, as described in + // https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of + // the the security token in the "actor_token_path" parameter. + ActorTokenType string // Optional. +} + +// NewCredentials returns a new PerRPCCredentials implementation, configured +// using opts, which performs token exchange using STS. +func NewCredentials(opts Options) (credentials.PerRPCCredentials, error) { + if err := validateOptions(opts); err != nil { + return nil, err + } + + // Load the system roots to validate the certificate presented by the STS + // endpoint during the TLS handshake. + roots, err := loadSystemCertPool() + if err != nil { + return nil, err + } + + return &callCreds{ + opts: opts, + client: makeHTTPDoer(roots), + }, nil +} + +// callCreds provides the implementation of call credentials based on an STS +// token exchange. +type callCreds struct { + opts Options + client httpDoer + + // Cached accessToken to avoid an STS token exchange for every call to + // GetRequestMetadata. + mu sync.Mutex + tokenMetadata map[string]string + tokenExpiry time.Time +} + +// GetRequestMetadata returns the cached accessToken, if available and valid, or +// fetches a new one by performing an STS token exchange. +func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { + if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err) + } + + // Holding the lock for the whole duration of the STS request and response + // processing ensures that concurrent RPCs don't end up in multiple + // requests being made. + c.mu.Lock() + defer c.mu.Unlock() + + if md := c.cachedMetadata(); md != nil { + return md, nil + } + req, err := constructRequest(ctx, c.opts) + if err != nil { + return nil, err + } + respBody, err := sendRequest(c.client, req) + if err != nil { + return nil, err + } + ti, err := tokenInfoFromResponse(respBody) + if err != nil { + return nil, err + } + c.tokenMetadata = map[string]string{"Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token)} + c.tokenExpiry = ti.expiryTime + return c.tokenMetadata, nil +} + +// RequireTransportSecurity indicates whether the credentials requires +// transport security. +func (c *callCreds) RequireTransportSecurity() bool { + return true +} + +// httpDoer wraps the single method on the http.Client type that we use. This +// helps with overriding in unittests. +type httpDoer interface { + Do(req *http.Request) (*http.Response, error) +} + +func makeHTTPClient(roots *x509.CertPool) httpDoer { + return &http.Client{ + Timeout: stsRequestTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: roots, + }, + }, + } +} + +// validateOptions performs the following validation checks on opts: +// - tokenExchangeServiceURI is not empty +// - tokenExchangeServiceURI is a valid URI with a http(s) scheme +// - subjectTokenPath and subjectTokenType are not empty. +func validateOptions(opts Options) error { + if opts.TokenExchangeServiceURI == "" { + return errors.New("empty token_exchange_service_uri in options") + } + u, err := url.Parse(opts.TokenExchangeServiceURI) + if err != nil { + return err + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("scheme is not supported: %s. Only http(s) is supported", u.Scheme) + } + + if opts.SubjectTokenPath == "" { + return errors.New("required field SubjectTokenPath is not specified") + } + if opts.SubjectTokenType == "" { + return errors.New("required field SubjectTokenType is not specified") + } + return nil +} + +// cachedMetadata returns the cached metadata provided it is not going to +// expire anytime soon. +// +// Caller must hold c.mu. +func (c *callCreds) cachedMetadata() map[string]string { + now := time.Now() + // If the cached token has not expired and the lifetime remaining on that + // token is greater than the minimum value we are willing to accept, go + // ahead and use it. + if c.tokenExpiry.After(now) && c.tokenExpiry.Sub(now) > minCachedTokenLifetime { + return c.tokenMetadata + } + return nil +} + +// constructRequest creates the STS request body in JSON based on the provided +// options. +// - Contents of the subjectToken are read from the file specified in +// options. If we encounter an error here, we bail out. +// - Contents of the actorToken are read from the file specified in options. +// If we encounter an error here, we ignore this field because this is +// optional. +// - Most of the other fields in the request come directly from options. +// +// A new HTTP request is created by calling http.NewRequestWithContext() and +// passing the provided context, thereby enforcing any timeouts specified in +// the latter. +func constructRequest(ctx context.Context, opts Options) (*http.Request, error) { + subToken, err := readSubjectTokenFrom(opts.SubjectTokenPath) + if err != nil { + return nil, err + } + reqScope := opts.Scope + if reqScope == "" { + reqScope = defaultCloudPlatformScope + } + reqParams := &requestParameters{ + GrantType: tokenExchangeGrantType, + Resource: opts.Resource, + Audience: opts.Audience, + Scope: reqScope, + RequestedTokenType: opts.RequestedTokenType, + SubjectToken: string(subToken), + SubjectTokenType: opts.SubjectTokenType, + } + if opts.ActorTokenPath != "" { + actorToken, err := readActorTokenFrom(opts.ActorTokenPath) + if err != nil { + return nil, err + } + reqParams.ActorToken = string(actorToken) + reqParams.ActorTokenType = opts.ActorTokenType + } + jsonBody, err := json.Marshal(reqParams) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, "POST", opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create http request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + return req, nil +} + +func sendRequest(client httpDoer, req *http.Request) ([]byte, error) { + // http.Client returns a non-nil error only if it encounters an error + // caused by client policy (such as CheckRedirect), or failure to speak + // HTTP (such as a network connectivity problem). A non-2xx status code + // doesn't cause an error. + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + // When the http.Client returns a non-nil error, it is the + // responsibility of the caller to read the response body till an EOF is + // encountered and to close it. + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusOK { + return body, nil + } + grpclog.Warningf("http status %d, body: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("http status %d, body: %s", resp.StatusCode, string(body)) +} + +func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) { + respData := &responseParameters{} + if err := json.Unmarshal(respBody, respData); err != nil { + return nil, fmt.Errorf("json.Unmarshal(%v): %v", respBody, err) + } + if respData.AccessToken == "" { + return nil, fmt.Errorf("empty accessToken in response (%v)", string(respBody)) + } + return &tokenInfo{ + tokenType: respData.TokenType, + token: respData.AccessToken, + expiryTime: time.Now().Add(time.Duration(respData.ExpiresIn) * time.Second), + }, nil +} + +// requestParameters stores all STS request attributes defined in +// https://tools.ietf.org/html/rfc8693#section-2.1. +type requestParameters struct { + // REQUIRED. The value "urn:ietf:params:oauth:grant-type:token-exchange" + // indicates that a token exchange is being performed. + GrantType string `json:"grant_type"` + // OPTIONAL. Indicates the location of the target service or resource where + // the client intends to use the requested security token. + Resource string `json:"resource,omitempty"` + // OPTIONAL. The logical name of the target service where the client intends + // to use the requested security token. + Audience string `json:"audience,omitempty"` + // OPTIONAL. A list of space-delimited, case-sensitive strings, that allow + // the client to specify the desired scope of the requested security token + // in the context of the service or Resource where the token will be used. + Scope string `json:"scope,omitempty"` + // OPTIONAL. An identifier, for the type of the requested security token. + RequestedTokenType string `json:"requested_token_type,omitempty"` + // REQUIRED. A security token that represents the identity of the party on + // behalf of whom the request is being made. + SubjectToken string `json:"subject_token"` + // REQUIRED. An identifier, that indicates the type of the security token in + // the "subject_token" parameter. + SubjectTokenType string `json:"subject_token_type"` + // OPTIONAL. A security token that represents the identity of the acting + // party. + ActorToken string `json:"actor_token,omitempty"` + // An identifier, that indicates the type of the security token in the + // "actor_token" parameter. + ActorTokenType string `json:"actor_token_type,omitempty"` +} + +// nesponseParameters stores all attributes sent as JSON in a successful STS +// response. These attributes are defined in +// https://tools.ietf.org/html/rfc8693#section-2.2.1. +type responseParameters struct { + // REQUIRED. The security token issued by the authorization server + // in response to the token exchange request. + AccessToken string `json:"access_token"` + // REQUIRED. An identifier, representation of the issued security token. + IssuedTokenType string `json:"issued_token_type"` + // REQUIRED. A case-insensitive value specifying the method of using the access + // token issued. It provides the client with information about how to utilize the + // access token to access protected resources. + TokenType string `json:"token_type"` + // RECOMMENDED. The validity lifetime, in seconds, of the token issued by the + // authorization server. + ExpiresIn int64 `json:"expires_in"` + // OPTIONAL, if the Scope of the issued security token is identical to the + // Scope requested by the client; otherwise, REQUIRED. + Scope string `json:"scope"` + // OPTIONAL. A refresh token will typically not be issued when the exchange is + // of one temporary credential (the subject_token) for a different temporary + // credential (the issued token) for use in some other context. + RefreshToken string `json:"refresh_token"` +} + +// tokenInfo wraps the information received in a successful STS response. +type tokenInfo struct { + tokenType string + token string + expiryTime time.Time +} diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go new file mode 100644 index 000000000000..641bad5820bb --- /dev/null +++ b/credentials/sts/sts_test.go @@ -0,0 +1,764 @@ +// +build go1.13 + +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package sts + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" +) + +const ( + requestedTokenType = "urn:ietf:params:oauth:token-type:access-token" + actorTokenPath = "/var/run/secrets/token.jwt" + actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token" + actorTokenContents = "actorToken.jwt.contents" + accessTokenContents = "access_token" + subjectTokenPath = "/var/run/secrets/token.jwt" + subjectTokenType = "urn:ietf:params:oauth:token-type:id_token" + subjectTokenContents = "subjectToken.jwt.contents" + serviceURI = "http://localhost" + exampleResource = "https://backend.example.com/api" + exampleAudience = "example-backend-service" + testScope = "https://www.googleapis.com/auth/monitoring" +) + +var ( + goodOptions = Options{ + TokenExchangeServiceURI: serviceURI, + Audience: exampleAudience, + RequestedTokenType: requestedTokenType, + SubjectTokenPath: subjectTokenPath, + SubjectTokenType: subjectTokenType, + } + goodRequestParams = &requestParameters{ + GrantType: tokenExchangeGrantType, + Audience: exampleAudience, + Scope: defaultCloudPlatformScope, + RequestedTokenType: requestedTokenType, + SubjectToken: subjectTokenContents, + SubjectTokenType: subjectTokenType, + } + goodMetadata = map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents), + } +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// A struct that implements AuthInfo interface and added to the context passed +// to GetRequestMetadata from tests. +type testAuthInfo struct { + credentials.CommonAuthInfo +} + +func (ta testAuthInfo) AuthType() string { + return "testAuthInfo" +} + +func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context { + auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}} + ri := credentials.RequestInfo{ + Method: "testInfo", + AuthInfo: auth, + } + return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) +} + +// errReader implements the io.Reader interface and returns an error from the +// Read method. +type errReader struct{} + +func (r errReader) Read(b []byte) (n int, err error) { + return 0, errors.New("read error") +} + +// We need a function to construct the response instead of simply declaring it +// as a variable since the the response body will be consumed by the +// credentials, and therefore we will need a new one everytime. +func makeGoodResponse() *http.Response { + respJSON, _ := json.Marshal(responseParameters{ + AccessToken: accessTokenContents, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) + return &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: respBody, + } +} + +// fakeHTTPDoer helps mock out the http.Client.Do calls made by the credentials +// code under test. It makes the http.Request made by the credentials available +// through a channel, and makes it possible to inject various responses. +type fakeHTTPDoer struct { + reqCh *testutils.Channel + respCh *testutils.Channel + err error +} + +func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) { + fc.reqCh.Send(req) + val, err := fc.respCh.Receive() + if err != nil { + return nil, err + } + return val.(*http.Response), fc.err +} + +// Overrides the http.Client with a fakeClient which sends a good response. +func overrideHTTPClientGood() (*fakeHTTPDoer, func()) { + fc := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), + } + fc.respCh.Send(makeGoodResponse()) + + origMakeHTTPDoer := makeHTTPDoer + makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } + return fc, func() { makeHTTPDoer = origMakeHTTPDoer } +} + +// Overrides the http.Client with the provided fakeClient. +func overrideHTTPClient(fc *fakeHTTPDoer) func() { + origMakeHTTPDoer := makeHTTPDoer + makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } + return func() { makeHTTPDoer = origMakeHTTPDoer } +} + +// Overrides the subject token read to return a const which we can compare in +// our tests. +func overrideSubjectTokenGood() func() { + origReadSubjectTokenFrom := readSubjectTokenFrom + readSubjectTokenFrom = func(path string) ([]byte, error) { + return []byte(subjectTokenContents), nil + } + return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } +} + +// Overrides the subject token read to always return an error. +func overrideSubjectTokenError() func() { + origReadSubjectTokenFrom := readSubjectTokenFrom + readSubjectTokenFrom = func(path string) ([]byte, error) { + return nil, errors.New("error reading subject token") + } + return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } +} + +// Overrides the actor token read to return a const which we can compare in +// our tests. +func overrideActorTokenGood() func() { + origReadActorTokenFrom := readActorTokenFrom + readActorTokenFrom = func(path string) ([]byte, error) { + return []byte(actorTokenContents), nil + } + return func() { readActorTokenFrom = origReadActorTokenFrom } +} + +// Overrides the actor token read to always return an error. +func overrideActorTokenError() func() { + origReadActorTokenFrom := readActorTokenFrom + readActorTokenFrom = func(path string) ([]byte, error) { + return nil, errors.New("error reading actor token") + } + return func() { readActorTokenFrom = origReadActorTokenFrom } +} + +// compareRequest compares the http.Request received in the test with the +// expected requestParameters specified in wantReqParams. +func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error { + jsonBody, err := json.Marshal(wantReqParams) + if err != nil { + return err + } + wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody)) + if err != nil { + return fmt.Errorf("failed to create http request: %v", err) + } + wantReq.Header.Set("Content-Type", "application/json") + + wantR, err := httputil.DumpRequestOut(wantReq, true) + if err != nil { + return err + } + gotR, err := httputil.DumpRequestOut(gotRequest, true) + if err != nil { + return err + } + if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" { + return fmt.Errorf("sts request diff (-want +got):\n%s", diff) + } + return nil +} + +// receiveAndCompareRequest waits for a request to be sent out by the +// credentials implementation using the fakeHTTPClient and compares it to an +// expected goodRequest. This is expected to be called in a separate goroutine +// by the tests. So, any errors encountered are pushed to an error channel +// which is monitored by the test. +func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) { + val, err := reqCh.Receive() + if err != nil { + errCh <- err + return + } + req := val.(*http.Request) + if err := compareRequest(req, goodRequestParams); err != nil { + errCh <- err + return + } + errCh <- nil +} + +// TestGetRequestMetadataSuccess verifies the successful case of sending an +// token exchange request and processing the response. +func (s) TestGetRequestMetadataSuccess(t *testing.T) { + defer overrideSubjectTokenGood()() + fc, cancel := overrideHTTPClientGood() + defer cancel() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go receiveAndCompareRequest(fc.reqCh, errCh) + + gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") + if err != nil { + t.Fatalf("creds.GetRequestMetadata() = %v", err) + } + if !cmp.Equal(gotMetadata, goodMetadata) { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) + } + if err := <-errCh; err != nil { + t.Fatal(err) + } + + // Make another call to get request metadata and this should return contents + // from the cache. This will fail if the credentials tries to send a fresh + // request here since we have not configured our fakeClient to return any + // response on retries. + gotMetadata, err = creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") + if err != nil { + t.Fatalf("creds.GetRequestMetadata() = %v", err) + } + if !cmp.Equal(gotMetadata, goodMetadata) { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) + } +} + +// TestGetRequestMetadataBadSecurityLevel verifies the case where the +// securityLevel specified in the context passed to GetRequestMetadata is not +// sufficient. +func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) { + defer overrideSubjectTokenGood()() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.IntegrityOnly), "") + if err == nil { + t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata) + } +} + +// TestGetRequestMetadataCacheExpiry verifies the case where the cached access +// token has expired, and the credentials implementation will have to send a +// fresh token exchange request. +func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) { + const expiresInSecs = 1 + defer overrideSubjectTokenGood()() + fc := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), + } + defer overrideHTTPClient(fc)() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + // The fakeClient is configured to return an access_token with a one second + // expiry. So, in the second iteration, the credentials will find the cache + // entry, but that would have expired, and therefore we expect it to send + // out a fresh request. + for i := 0; i < 2; i++ { + errCh := make(chan error, 1) + go receiveAndCompareRequest(fc.reqCh, errCh) + + respJSON, _ := json.Marshal(responseParameters{ + AccessToken: accessTokenContents, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: expiresInSecs, + }) + respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) + resp := &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: respBody, + } + fc.respCh.Send(resp) + + gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") + if err != nil { + t.Fatalf("creds.GetRequestMetadata() = %v", err) + } + if !cmp.Equal(gotMetadata, goodMetadata) { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) + } + if err := <-errCh; err != nil { + t.Fatal(err) + } + time.Sleep(expiresInSecs * time.Second) + } +} + +// TestGetRequestMetadataBadResponses verifies the scenario where the token +// exchange server returns bad responses. +func (s) TestGetRequestMetadataBadResponses(t *testing.T) { + tests := []struct { + name string + response *http.Response + }{ + { + name: "bad JSON", + response: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader("not JSON")), + }, + }, + { + name: "no access token", + response: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader("{}")), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer overrideSubjectTokenGood()() + + fc := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), + } + defer overrideHTTPClient(fc)() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go receiveAndCompareRequest(fc.reqCh, errCh) + + fc.respCh.Send(test.response) + if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil { + t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") + } + if err := <-errCh; err != nil { + t.Fatal(err) + } + }) + } +} + +// TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the +// attempt to read the subjectToken fails. +func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { + defer overrideSubjectTokenError()() + fc, cancel := overrideHTTPClientGood() + defer cancel() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go func() { + if _, err := fc.reqCh.Receive(); err != testutils.ErrRecvTimeout { + errCh <- err + return + } + errCh <- nil + }() + + if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil { + t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") + } + if err := <-errCh; err != nil { + t.Fatal(err) + } +} + +func (s) TestNewCredentials(t *testing.T) { + tests := []struct { + name string + opts Options + errSystemRoots bool + wantErr bool + }{ + { + name: "invalid options - empty subjectTokenPath", + opts: Options{ + TokenExchangeServiceURI: serviceURI, + }, + wantErr: true, + }, + { + name: "invalid system root certs", + opts: goodOptions, + errSystemRoots: true, + wantErr: true, + }, + { + name: "good case", + opts: goodOptions, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.errSystemRoots { + oldSystemRoots := loadSystemCertPool + loadSystemCertPool = func() (*x509.CertPool, error) { + return nil, errors.New("failed to load system cert pool") + } + defer func() { + loadSystemCertPool = oldSystemRoots + }() + } + + creds, err := NewCredentials(test.opts) + if (err != nil) != test.wantErr { + t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr) + } + if err == nil { + if !creds.RequireTransportSecurity() { + t.Errorf("creds.RequireTransportSecurity() returned false") + } + } + }) + } +} + +func (s) TestValidateOptions(t *testing.T) { + tests := []struct { + name string + opts Options + wantErrPrefix string + }{ + { + name: "empty token exchange service URI", + opts: Options{}, + wantErrPrefix: "empty token_exchange_service_uri in options", + }, + { + name: "invalid URI", + opts: Options{ + TokenExchangeServiceURI: "\tI'm a bad URI\n", + }, + wantErrPrefix: "invalid control character in URL", + }, + { + name: "unsupported scheme", + opts: Options{ + TokenExchangeServiceURI: "unix:///path/to/socket", + }, + wantErrPrefix: "scheme is not supported", + }, + { + name: "empty subjectTokenPath", + opts: Options{ + TokenExchangeServiceURI: serviceURI, + }, + wantErrPrefix: "required field SubjectTokenPath is not specified", + }, + { + name: "empty subjectTokenType", + opts: Options{ + TokenExchangeServiceURI: serviceURI, + SubjectTokenPath: subjectTokenPath, + }, + wantErrPrefix: "required field SubjectTokenType is not specified", + }, + { + name: "good options", + opts: goodOptions, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := validateOptions(test.opts) + if (err != nil) != (test.wantErrPrefix != "") { + t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) + } + if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) { + t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) + } + }) + } +} + +func (s) TestConstructRequest(t *testing.T) { + tests := []struct { + name string + opts Options + subjectTokenReadErr bool + actorTokenReadErr bool + wantReqParams *requestParameters + wantErr bool + }{ + { + name: "subject token read failure", + subjectTokenReadErr: true, + opts: goodOptions, + wantErr: true, + }, + { + name: "actor token read failure", + actorTokenReadErr: true, + opts: Options{ + TokenExchangeServiceURI: serviceURI, + Audience: exampleAudience, + RequestedTokenType: requestedTokenType, + SubjectTokenPath: subjectTokenPath, + SubjectTokenType: subjectTokenType, + ActorTokenPath: actorTokenPath, + ActorTokenType: actorTokenType, + }, + wantErr: true, + }, + { + name: "default cloud platform scope", + opts: goodOptions, + wantReqParams: goodRequestParams, + }, + { + name: "all good", + opts: Options{ + TokenExchangeServiceURI: serviceURI, + Resource: exampleResource, + Audience: exampleAudience, + Scope: testScope, + RequestedTokenType: requestedTokenType, + SubjectTokenPath: subjectTokenPath, + SubjectTokenType: subjectTokenType, + ActorTokenPath: actorTokenPath, + ActorTokenType: actorTokenType, + }, + wantReqParams: &requestParameters{ + GrantType: tokenExchangeGrantType, + Resource: exampleResource, + Audience: exampleAudience, + Scope: testScope, + RequestedTokenType: requestedTokenType, + SubjectToken: subjectTokenContents, + SubjectTokenType: subjectTokenType, + ActorToken: actorTokenContents, + ActorTokenType: actorTokenType, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.subjectTokenReadErr { + defer overrideSubjectTokenError()() + } else { + defer overrideSubjectTokenGood()() + } + + if test.actorTokenReadErr { + defer overrideActorTokenError()() + } else { + defer overrideActorTokenGood()() + } + + gotRequest, err := constructRequest(context.Background(), test.opts) + if (err != nil) != test.wantErr { + t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr) + } + if test.wantErr { + return + } + if err := compareRequest(gotRequest, test.wantReqParams); err != nil { + t.Fatal(err) + } + }) + } +} + +func (s) TestSendRequest(t *testing.T) { + defer overrideSubjectTokenGood()() + req, err := constructRequest(context.Background(), goodOptions) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + resp *http.Response + respErr error + wantErr bool + }{ + { + name: "client error", + respErr: errors.New("http.Client.Do failed"), + wantErr: true, + }, + { + name: "bad response body", + resp: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(errReader{}), + }, + wantErr: true, + }, + { + name: "nonOK status code", + resp: &http.Response{ + Status: "400 BadRequest", + StatusCode: http.StatusBadRequest, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + wantErr: true, + }, + { + name: "good case", + resp: makeGoodResponse(), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), + err: test.respErr, + } + client.respCh.Send(test.resp) + _, err := sendRequest(client, req) + if (err != nil) != test.wantErr { + t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr) + } + }) + } +} + +func (s) TestTokenInfoFromResponse(t *testing.T) { + noAccessToken, _ := json.Marshal(responseParameters{ + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + goodResponse, _ := json.Marshal(responseParameters{ + IssuedTokenType: requestedTokenType, + AccessToken: accessTokenContents, + TokenType: "Bearer", + ExpiresIn: 3600, + }) + + tests := []struct { + name string + respBody []byte + wantTokenInfo *tokenInfo + wantErr bool + }{ + { + name: "bad JSON", + respBody: []byte("not JSON"), + wantErr: true, + }, + { + name: "empty response", + respBody: []byte(""), + wantErr: true, + }, + { + name: "non-empty response with no access token", + respBody: noAccessToken, + wantErr: true, + }, + { + name: "good response", + respBody: goodResponse, + wantTokenInfo: &tokenInfo{ + tokenType: "Bearer", + token: accessTokenContents, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotTokenInfo, err := tokenInfoFromResponse(test.respBody) + if (err != nil) != test.wantErr { + t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr) + } + if test.wantErr { + return + } + // Can't do a cmp.Equal on the whole struct since the expiryField + // is populated based on time.Now(). + if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token { + t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo) + } + }) + } +}