From 53527ee97fd3f1a7d8ab88855139883e6de52ecb Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 4 Jun 2020 14:42:49 -0700 Subject: [PATCH 1/7] WIP: STS Call Creds Implementation. --- credentials/sts/sts.go | 427 +++++++++++++++++++++++++ credentials/sts/sts_test.go | 617 ++++++++++++++++++++++++++++++++++++ 2 files changed, 1044 insertions(+) create mode 100644 credentials/sts/sts.go create mode 100644 credentials/sts/sts_test.go diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go new file mode 100644 index 000000000000..cd2f937402ec --- /dev/null +++ b/credentials/sts/sts.go @@ -0,0 +1,427 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package sts implements call credentials using STS (Security Token Service) as +// defined in https://tools.ietf.org/html/rfc8693. +// +// Experimental +// +// Notice: All APIs in this package are experimental and may be changed or +// removed in a later release. +package sts + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "sync" + "time" + + grpcbackoff "google.golang.org/grpc/backoff" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/backoff" +) + +const ( + // HTTP request timeout set on the http.Client used to make STS requests. + stsRequestTimeout = 5 * time.Second + // If lifetime left in a cached token is lesser than this value, we fetch a + // new one instead of returning the current one. + minCachedTokenLifetime = 300 * time.Second + + tokenExchangeGrantType = "urn:ietf:params:oauth:grant-type:token-exchange" + defaultCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" +) + +// For overriding in tests. +var ( + loadSystemCertPool = x509.SystemCertPool + makeBackoffStrategy = defaultBackoffStrategy + makeHTTPDoer = makeHTTPClient + readSubjectTokenFrom = ioutil.ReadFile + readActorTokenFrom = ioutil.ReadFile +) + +// Options configures the parameters used for an STS based token exchange. +type Options struct { + // TokenExchangeServiceURI is the address of the server which implements STS + // token exchange functionality. + TokenExchangeServiceURI string // Required. + + // Resource is a URI that indicates the target service or resource where the + // client intends to use the requested security token. + Resource string // Optional. + + // Audience is the logical name of the target service where the client + // intends to use the requested security token + Audience string // Optional. + + // Scope is a list of space-delimited, case-sensitive strings, that allow + // the client to specify the desired scope of the requested security token + // in the context of the service or resource where the token will be used. + // If this field is left unspecified, a default value of + // https://www.googleapis.com/auth/cloud-platform will be used. + Scope string // Optional. + + // RequestedTokenType is an identifier, as described in + // https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of + // the requested security token. + RequestedTokenType string // Optional. + + // SubjectTokenPath is a filesystem path which contains the security token + // that represents the identity of the party on behalf of whom the request + // is being made. + SubjectTokenPath string // Required. + + // SubjectTokenType is an identifier, as described in + // https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of + // the security token in the "subject_token_path" parameter. + SubjectTokenType string // Required. + + // ActorTokenPath is a security token that represents the identity of the + // acting party. + ActorTokenPath string // Optional. + + // ActorTokenType is an identifier, as described in + // https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of + // the the security token in the "actor_token_path" parameter. + ActorTokenType string // Optional. +} + +// NewCredentials returns a new PerRPCCredentials implementation, configured +// using opts, which performs token exchange using STS. +func NewCredentials(opts Options) (credentials.PerRPCCredentials, error) { + if err := validateOptions(opts); err != nil { + return nil, err + } + + // Load the system roots to validate the certificate presented by the STS + // endpoint during the TLS handshake. + roots, err := loadSystemCertPool() + if err != nil { + return nil, err + } + + return &callCreds{ + opts: opts, + client: makeHTTPDoer(roots), + bs: makeBackoffStrategy(), + }, nil +} + +// defaultBackoffStrategy returns an exponential backoff strategy based on the +// default exponential backoff config, but with a smaller BaseDelay (because the +// former is meant more for connection establishment). +func defaultBackoffStrategy() backoff.Strategy { + config := grpcbackoff.DefaultConfig + config.BaseDelay = 100 * time.Millisecond + return backoff.Exponential{Config: config} +} + +// httpDoer wraps the single method on the http.Client type that we use. This +// helps with overriding in unittests. +type httpDoer interface { + Do(req *http.Request) (*http.Response, error) +} + +func makeHTTPClient(roots *x509.CertPool) httpDoer { + return &http.Client{ + Timeout: stsRequestTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: roots, + }, + }, + } +} + +// validateOptions performs the following validation checks on opts: +// - tokenExchangeServiceURI is not empty +// - tokenExchangeServiceURI is a valid URI with a http(s) scheme +// - subjectTokenPath and subjectTokenType are not empty. +func validateOptions(opts Options) error { + if opts.TokenExchangeServiceURI == "" { + return errors.New("empty token_exchange_service_uri in options") + } + u, err := url.Parse(opts.TokenExchangeServiceURI) + if err != nil { + return err + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("scheme is not supported: %s. Only http(s) is supported", u.Scheme) + } + + if opts.SubjectTokenPath == "" { + return errors.New("required field SubjectTokenPath is not specified") + } + if opts.SubjectTokenType == "" { + return errors.New("required field SubjectTokenType is not specified") + } + return nil +} + +// callCreds provides the implementation of call credentials based on an STS +// token exchange. +type callCreds struct { + opts Options + client httpDoer + bs backoff.Strategy + + // Cached accessToken to avoid an STS token exchange for every call to + // GetRequestMetadata. + mu sync.Mutex + cachedToken *tokenInfo +} + +// GetRequestMetadata returns the cached accessToken, if available and valid, or +// fetches a new one by performing an STS token exchange. +func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { + if md := c.metadataFromCachedToken(); md != nil { + return md, nil + } + req, err := c.constructRequest() + if err != nil { + return nil, err + } + + // Send the request with exponential backoff and retry. Even though not + // retrying here is OK, as the connection attempt will be retried by the + // subConn, it is not very hard to perform some basic retries here. + respBody, err := c.sendRequestWithRetry(ctx, req) + if err != nil { + return nil, err + } + + ti, err := tokenInfoFromResponse(respBody) + if err != nil { + return nil, err + } + c.mu.Lock() + c.cachedToken = ti + c.mu.Unlock() + + return map[string]string{ + "Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token), + }, nil +} + +// metadataFromCachedToken returns the cached accessToken as request metadata, +// provided a cached accessToken exists and is not going to expire anytime soon. +func (c *callCreds) metadataFromCachedToken() map[string]string { + c.mu.Lock() + defer c.mu.Unlock() + + if c.cachedToken == nil { + return nil + } + + now := time.Now() + // If the cached token has not expired and the lifetime remaining on that + // token is greater than the minimum value we are willing to accept, go + // ahead and use it. + if c.cachedToken.expiryTime.After(now) && c.cachedToken.expiryTime.Sub(now) > minCachedTokenLifetime { + return map[string]string{ + "Authorization": fmt.Sprintf("%s %s", c.cachedToken.tokenType, c.cachedToken.token), + } + } + return nil +} + +// constructRequest creates the STS request body in JSON based on the options +// received when the credentials type was created. +// - Contents of the subjectToken are read from the file specified in +// options. If we encounter an error here, we bail out. +// - Contents of the actorToken are read from the file specified in options. +// If we encounter an error here, we ignore this field because this is +// optional. +// - Most of the other fields in the request come directly from options. +func (c *callCreds) constructRequest() (*http.Request, error) { + subToken, err := readSubjectTokenFrom(c.opts.SubjectTokenPath) + if err != nil { + return nil, err + } + reqScope := c.opts.Scope + if reqScope == "" { + reqScope = defaultCloudPlatformScope + } + reqParams := &RequestParameters{ + GrantType: tokenExchangeGrantType, + Resource: c.opts.Resource, + Audience: c.opts.Audience, + Scope: reqScope, + RequestedTokenType: c.opts.RequestedTokenType, + SubjectToken: string(subToken), + SubjectTokenType: c.opts.SubjectTokenType, + } + actorToken, err := readActorTokenFrom(c.opts.ActorTokenPath) + if err == nil { + reqParams.ActorToken = string(actorToken) + reqParams.ActorTokenType = c.opts.ActorTokenType + } + jsonBody, err := json.Marshal(reqParams) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", c.opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create http request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + return req, nil +} + +// sendRequestWithRetry sends the provided http.Request and retries with +// exponential backoff for certain types of errors. It takes care of closing the +// response body on success (returns the contents of the body) and retries. +func (c *callCreds) sendRequestWithRetry(ctx context.Context, req *http.Request) ([]byte, error) { + retries := 0 + for { + // Wait for one of exponential backoff or context deadline to expire. + if retries > 0 { + timer := time.NewTimer(c.bs.Backoff(retries)) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + + // http.Client returns a non-nil error only if it encounters an error + // caused by client policy (such as CheckRedirect), or failure to speak + // HTTP (such as a network connectivity problem). A non-2xx status code + // doesn't cause an error. + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + + // When the http.Client returns a non-nil error, it is the + // responsibility of the caller to read the response body till an EOF is + // encountered and to close it. + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + retries++ + continue + } + + if resp.StatusCode == http.StatusOK { + return body, nil + } + if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError { + // For 4xx errors, which are client errors, we do not retry. + return nil, fmt.Errorf("http status %d, body: %s", resp.StatusCode, string(body)) + } + retries++ + grpclog.Warningf("http status %d, body: %s", resp.StatusCode, string(body)) + } +} + +func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) { + respData := &ResponseParameters{} + if err := json.Unmarshal(respBody, respData); err != nil { + return nil, fmt.Errorf("json.Unmarshal(%v): %v", respBody, err) + } + if respData.AccessToken == "" { + return nil, fmt.Errorf("empty accessToken in response (%v)", string(respBody)) + } + return &tokenInfo{ + tokenType: respData.TokenType, + token: respData.AccessToken, + expiryTime: time.Now().Add(time.Duration(respData.ExpiresIn) * time.Second), + }, nil +} + +// RequireTransportSecurity indicates whether the credentials requires +// transport security. +func (c *callCreds) RequireTransportSecurity() bool { + return true +} + +// RequestParameters stores all STS request attributes defined in +// https://tools.ietf.org/html/rfc8693#section-2.1. +type RequestParameters struct { + // REQUIRED. The value "urn:ietf:params:oauth:grant-type:token-exchange" + // indicates that a token exchange is being performed. + GrantType string `json:"grant_type"` + // OPTIONAL. Indicates the location of the target service or resource where + // the client intends to use the requested security token. + Resource string `json:"resource"` + // OPTIONAL. The logical name of the target service where the client intends + // to use the requested security token. + Audience string `json:"audience"` + // OPTIONAL. A list of space-delimited, case-sensitive strings, that allow + // the client to specify the desired scope of the requested security token + // in the context of the service or Resource where the token will be used. + Scope string `json:"scope"` + // OPTIONAL. An identifier, for the type of the requested security token. + RequestedTokenType string `json:"requested_token_type"` + // REQUIRED. A security token that represents the identity of the party on + // behalf of whom the request is being made. + SubjectToken string `json:"subject_token"` + // REQUIRED. An identifier, that indicates the type of the security token in + // the "subject_token" parameter. + SubjectTokenType string `json:"subject_token_type"` + // OPTIONAL. A security token that represents the identity of the acting + // party. + ActorToken string `json:"actor_token"` + // An identifier, that indicates the type of the security token in the + // "actor_token" parameter. + ActorTokenType string `json:"actor_token_type"` +} + +// ResponseParameters stores all attributes sent as JSON in a successful STS +// response. These attributes are defined in +// https://tools.ietf.org/html/rfc8693#section-2.2.1. +type ResponseParameters struct { + // REQUIRED. The security token issued by the authorization server + // in response to the token exchange request. + AccessToken string `json:"access_token"` + // REQUIRED. An identifier, representation of the issued security token. + IssuedTokenType string `json:"issued_token_type"` + // REQUIRED. A case-insensitive value specifying the method of using the access + // token issued. It provides the client with information about how to utilize the + // access token to access protected resources. + TokenType string `json:"token_type"` + // RECOMMENDED. The validity lifetime, in seconds, of the token issued by the + // authorization server. + ExpiresIn int64 `json:"expires_in"` + // OPTIONAL, if the Scope of the issued security token is identical to the + // Scope requested by the client; otherwise, REQUIRED. + Scope string `json:"scope"` + // OPTIONAL. A refresh token will typically not be issued when the exchange is + // of one temporary credential (the subject_token) for a different temporary + // credential (the issued token) for use in some other context. + RefreshToken string `json:"refresh_token"` +} + +// tokenInfo wraps the information received in a successful STS response. +type tokenInfo struct { + tokenType string + token string + expiryTime time.Time +} diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go new file mode 100644 index 000000000000..34861dc4f03d --- /dev/null +++ b/credentials/sts/sts_test.go @@ -0,0 +1,617 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package sts + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/internal/testutils" +) + +const ( + subjectTokenContents = "subjectToken.jwt.contents" + actorTokenContents = "actorToken.jwt.contents" + accessTokenContents = "access_token" +) + +var ( + goodOptions = Options{ + TokenExchangeServiceURI: "http://localhost", + SubjectTokenPath: "/var/run/secrets/token.jwt", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + } + goodMetadata = map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents), + } +) + +// fakeHTTPClient helps mock out the HTTP calls made by the credentials code +// under test. It makes the http.Request made by the credentials available +// through a channel, and makes it possible to inject various responses. +type fakeHTTPClient struct { + reqCh *testutils.Channel + // When no retry is involve, only these two fields need to be populated. + firstResp *http.Response + firstErr error + // To test retry scenarios with a different response upon retry. + subsequentResp *http.Response + subsequentErr error + + numCalls int +} + +func (fc *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + fc.numCalls++ + fc.reqCh.Send(req) + if fc.numCalls > 1 { + return fc.subsequentResp, fc.subsequentErr + } + return fc.firstResp, fc.firstErr +} + +// fakeBackoff implements backoff.Strategy and pushes on a channel to indicate +// that a backoff was attempted. +type fakeBackoff struct { + boCh *testutils.Channel +} + +func (fb *fakeBackoff) Backoff(retries int) time.Duration { + fb.boCh.Send(retries) + return 0 +} + +// errReader implements the io.Reader interface and returns an error from the +// Read method. +type errReader struct{} + +func (r errReader) Read(b []byte) (n int, err error) { + return 0, errors.New("read error") +} + +// We need a function to construct the response instead of simply declaring it +// as a variable since the the response body will be consumed by the +// credentials, and therefore we will need a new one everytime. +func makeGoodResponse() *http.Response { + respJSON, _ := json.Marshal(ResponseParameters{ + AccessToken: accessTokenContents, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) + return &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: respBody, + } +} + +// Overrides the http.Client with a fakeClient which sends a good response. +func overrideHTTPClientGood() (*fakeHTTPClient, func()) { + fc := &fakeHTTPClient{ + reqCh: testutils.NewChannel(), + firstResp: makeGoodResponse(), + } + origMakeHTTPDoer := makeHTTPDoer + makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } + return fc, func() { makeHTTPDoer = origMakeHTTPDoer } +} + +// Overrides the subject token read to return a const which we can compare in +// our tests. +func overrideSubjectTokenGood() func() { + origReadSubjectTokenFrom := readSubjectTokenFrom + readSubjectTokenFrom = func(path string) ([]byte, error) { + return []byte(subjectTokenContents), nil + } + return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } +} + +// compareRequestWithRetry is run in a separate goroutine by tests to perform +// the following: +// - wait for a http request to be made by the credentials type and compare it +// with an expected one. +// - if the credentials is expected to retry, verify that a backoff was done +// before the retry. +// If any of the above steps fail, an error is pushed on the errCh. +func compareRequestWithRetry(errCh chan error, wantRetry bool, reqCh, boCh *testutils.Channel) { + val, err := reqCh.Receive() + if err != nil { + errCh <- err + return + } + req := val.(*http.Request) + if err := compareRequest(goodOptions, req); err != nil { + errCh <- err + return + } + + if wantRetry { + _, err := boCh.Receive() + if err != nil { + errCh <- err + return + } + } + errCh <- nil +} + +func compareRequest(opts Options, gotRequest *http.Request) error { + reqScope := opts.Scope + if reqScope == "" { + reqScope = defaultCloudPlatformScope + } + reqParams := &RequestParameters{ + GrantType: tokenExchangeGrantType, + Resource: opts.Resource, + Audience: opts.Audience, + Scope: reqScope, + RequestedTokenType: opts.RequestedTokenType, + SubjectToken: subjectTokenContents, + SubjectTokenType: opts.SubjectTokenType, + } + if opts.ActorTokenPath != "" { + reqParams.ActorToken = actorTokenContents + reqParams.ActorTokenType = opts.ActorTokenType + } + jsonBody, err := json.Marshal(reqParams) + if err != nil { + return err + } + wantReq, err := http.NewRequest("POST", opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody)) + if err != nil { + return fmt.Errorf("failed to create http request: %v", err) + } + wantReq.Header.Set("Content-Type", "application/json") + + wantR, err := httputil.DumpRequestOut(wantReq, true) + if err != nil { + return err + } + gotR, err := httputil.DumpRequestOut(gotRequest, true) + if err != nil { + return err + } + if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" { + return fmt.Errorf("sts request diff (-want +got):\n%s", diff) + } + return nil +} + +// TestGetRequestMetadataSuccess verifies the successful case of sending an +// token exchange request and processing the response. +func TestGetRequestMetadataSuccess(t *testing.T) { + defer overrideSubjectTokenGood()() + fc, cancel := overrideHTTPClientGood() + defer cancel() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go compareRequestWithRetry(errCh, false, fc.reqCh, nil) + + gotMetadata, err := creds.GetRequestMetadata(context.Background(), "") + if err != nil { + t.Fatalf("creds.GetRequestMetadata() = %v", err) + } + if !cmp.Equal(gotMetadata, goodMetadata) { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) + } + if err := <-errCh; err != nil { + t.Fatal(err) + } + + // Make another call to get request metadata and this should return contents + // from the cache. This will fail if the credentials tries to send a fresh + // request here since we have not configured our fakeClient to return any + // response on retries. + gotMetadata, err = creds.GetRequestMetadata(context.Background(), "") + if err != nil { + t.Fatalf("creds.GetRequestMetadata() = %v", err) + } + if !cmp.Equal(gotMetadata, goodMetadata) { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) + } +} + +// TestGetRequestMetadataCacheExpiry verifies the case where the cached access +// token has expired, and the credentials implementation will have to send a +// fresh token exchange request. +func TestGetRequestMetadataCacheExpiry(t *testing.T) { + const expiresInSecs = 1 + defer overrideSubjectTokenGood()() + respJSON, _ := json.Marshal(ResponseParameters{ + AccessToken: accessTokenContents, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: expiresInSecs, + }) + respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) + resp := &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: respBody, + } + fc := &fakeHTTPClient{ + reqCh: testutils.NewChannel(), + firstResp: resp, + subsequentResp: makeGoodResponse(), + } + origMakeHTTPDoer := makeHTTPDoer + makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } + defer func() { makeHTTPDoer = origMakeHTTPDoer }() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + // The fakeClient is configured to return an access_token with a one second + // expiry. So, in the second iteration, the credentials will find the cache + // entry, but that would have expired, and therefore we expect it to send + // out a fresh request. + for i := 0; i < 2; i++ { + errCh := make(chan error, 1) + go compareRequestWithRetry(errCh, false, fc.reqCh, nil) + + gotMetadata, err := creds.GetRequestMetadata(context.Background(), "") + if err != nil { + t.Fatalf("creds.GetRequestMetadata() = %v", err) + } + if !cmp.Equal(gotMetadata, goodMetadata) { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) + } + if err := <-errCh; err != nil { + t.Fatal(err) + } + time.Sleep(expiresInSecs * time.Second) + } +} + +// TestGetRequestMetadataBadResponses verifies the scenario where the token +// exchange server returns bad responses. +func TestGetRequestMetadataBadResponses(t *testing.T) { + tests := []struct { + name string + response *http.Response + }{ + { + name: "bad JSON", + response: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader("not JSON")), + }, + }, + { + name: "no access token", + response: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader("{}")), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer overrideSubjectTokenGood()() + + fc := &fakeHTTPClient{ + reqCh: testutils.NewChannel(), + firstResp: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader("not JSON")), + }, + } + origMakeHTTPDoer := makeHTTPDoer + makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } + defer func() { makeHTTPDoer = origMakeHTTPDoer }() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go compareRequestWithRetry(errCh, false, fc.reqCh, nil) + if _, err := creds.GetRequestMetadata(context.Background(), ""); err == nil { + t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") + } + if err := <-errCh; err != nil { + t.Fatal(err) + } + }) + } +} + +// TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the +// attempt to read the subjectToken fails. +func TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { + origReadSubjectTokenFrom := readSubjectTokenFrom + readSubjectTokenFrom = func(path string) ([]byte, error) { + return nil, errors.New("failed to read subject token") + } + defer func() { readSubjectTokenFrom = origReadSubjectTokenFrom }() + + fc, cancel := overrideHTTPClientGood() + defer cancel() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go func() { + if _, err := fc.reqCh.Receive(); err != testutils.ErrRecvTimeout { + errCh <- err + return + } + errCh <- nil + }() + + if _, err = creds.GetRequestMetadata(context.Background(), ""); err == nil { + t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") + } + if err := <-errCh; err != nil { + t.Fatal(err) + } +} + +// TestGetRequestMetadataRetry verifies various retry scenarios. +func TestGetRequestMetadataRetry(t *testing.T) { + tests := []struct { + name string + firstResp *http.Response + firstErr error + subsequentResp *http.Response + subsequentErr error + wantRetry bool + wantErr bool + wantMetadata map[string]string + }{ + { + name: "httpClient.Do error", + firstErr: errors.New("httpClient.Do() failed"), + wantErr: true, + }, + { + name: "bad response body first time", + firstResp: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(errReader{}), + }, + subsequentResp: makeGoodResponse(), + wantRetry: true, + wantMetadata: goodMetadata, + }, + { + name: "http client error status code", + firstResp: &http.Response{ + Status: "400 BadRequest", + StatusCode: http.StatusBadRequest, + Body: ioutil.NopCloser(&bytes.Reader{}), + }, + wantErr: true, + }, + { + name: "server error first time", + firstResp: &http.Response{ + Status: "400 BadRequest", + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(&bytes.Reader{}), + }, + subsequentResp: makeGoodResponse(), + wantRetry: true, + wantMetadata: goodMetadata, + }, + } + + // The test body performs the following steps: + // 1. Overrides the function to read subjectToken file and returns arbitrary + // data and nil error. + // 2. Overrides the function to return a http.Client and returns a fake + // client which is configured with response/error values to be returned. + // 3. Overrides the function to create the backoff strategy and returns a + // fake implementation which notifies the test through a channel when + // backoff is attempted. + // 4. Creates a new credentials type and invokes the GetRequestMetadata + // method on it. + // 5. Spawn a goroutine which verifies that the credentials sent out the + // expected http.Request, and performed a backoff when it encountered + // certain errors. + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer overrideSubjectTokenGood()() + + fc := &fakeHTTPClient{ + reqCh: testutils.NewChannel(), + firstResp: test.firstResp, + firstErr: test.firstErr, + subsequentResp: test.subsequentResp, + subsequentErr: test.subsequentErr, + } + origMakeHTTPDoer := makeHTTPDoer + makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } + + origBackoff := makeBackoffStrategy + fb := &fakeBackoff{boCh: testutils.NewChannel()} + makeBackoffStrategy = func() backoff.Strategy { return fb } + + defer func() { + makeHTTPDoer = origMakeHTTPDoer + makeBackoffStrategy = origBackoff + }() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go compareRequestWithRetry(errCh, test.wantRetry, fc.reqCh, fb.boCh) + + gotMetadata, err := creds.GetRequestMetadata(context.Background(), "") + if (err != nil) != test.wantErr { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", err, test.wantErr) + } + if !cmp.Equal(gotMetadata, test.wantMetadata, cmpopts.EquateEmpty()) { + t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, test.wantMetadata) + } + if err := <-errCh; err != nil { + t.Fatal(err) + } + }) + } +} + +func TestNewCredentials(t *testing.T) { + tests := []struct { + name string + opts Options + errSystemRoots bool + wantErr bool + }{ + { + name: "invalid options - empty subjectTokenPath", + opts: Options{ + TokenExchangeServiceURI: "http://localhost", + }, + wantErr: true, + }, + { + name: "invalid system root certs", + opts: goodOptions, + errSystemRoots: true, + wantErr: true, + }, + { + name: "good case", + opts: goodOptions, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.errSystemRoots { + oldSystemRoots := loadSystemCertPool + loadSystemCertPool = func() (*x509.CertPool, error) { + return nil, errors.New("failed to load system cert pool") + } + defer func() { + loadSystemCertPool = oldSystemRoots + }() + } + + creds, err := NewCredentials(test.opts) + if (err != nil) != test.wantErr { + t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr) + } + if err == nil { + if !creds.RequireTransportSecurity() { + t.Errorf("creds.RequireTransportSecurity() returned false") + } + } + }) + } +} + +func TestValidateOptions(t *testing.T) { + tests := []struct { + name string + opts Options + wantErrPrefix string + }{ + { + name: "empty token exchange service URI", + opts: Options{}, + wantErrPrefix: "empty token_exchange_service_uri in options", + }, + { + name: "invalid URI", + opts: Options{ + TokenExchangeServiceURI: "\tI'm a bad URI\n", + }, + wantErrPrefix: "invalid control character in URL", + }, + { + name: "unsupported scheme", + opts: Options{ + TokenExchangeServiceURI: "unix:///path/to/socket", + }, + wantErrPrefix: "scheme is not supported", + }, + { + name: "empty subjectTokenPath", + opts: Options{ + TokenExchangeServiceURI: "http://localhost", + }, + wantErrPrefix: "required field SubjectTokenPath is not specified", + }, + { + name: "empty subjectTokenType", + opts: Options{ + TokenExchangeServiceURI: "http://localhost", + SubjectTokenPath: "/var/run/secrets/token.jwt", + }, + wantErrPrefix: "required field SubjectTokenType is not specified", + }, + { + name: "good options", + opts: Options{ + TokenExchangeServiceURI: "http://localhost", + SubjectTokenPath: "/var/run/secrets/token.jwt", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := validateOptions(test.opts) + if (err != nil) != (test.wantErrPrefix != "") { + t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) + } + if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) { + t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) + } + }) + } +} From 558fca568fcdaf1acab272080812bff6c7ac9d06 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Fri, 19 Jun 2020 10:36:12 -0700 Subject: [PATCH 2/7] Add omitempty JSON tags to fields which can be empty. --- credentials/sts/sts.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index cd2f937402ec..0dfb6c7cc7fd 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -370,16 +370,16 @@ type RequestParameters struct { GrantType string `json:"grant_type"` // OPTIONAL. Indicates the location of the target service or resource where // the client intends to use the requested security token. - Resource string `json:"resource"` + Resource string `json:"resource,omitempty"` // OPTIONAL. The logical name of the target service where the client intends // to use the requested security token. - Audience string `json:"audience"` + Audience string `json:"audience,omitempty"` // OPTIONAL. A list of space-delimited, case-sensitive strings, that allow // the client to specify the desired scope of the requested security token // in the context of the service or Resource where the token will be used. - Scope string `json:"scope"` + Scope string `json:"scope,omitempty"` // OPTIONAL. An identifier, for the type of the requested security token. - RequestedTokenType string `json:"requested_token_type"` + RequestedTokenType string `json:"requested_token_type,omitempty"` // REQUIRED. A security token that represents the identity of the party on // behalf of whom the request is being made. SubjectToken string `json:"subject_token"` @@ -388,10 +388,10 @@ type RequestParameters struct { SubjectTokenType string `json:"subject_token_type"` // OPTIONAL. A security token that represents the identity of the acting // party. - ActorToken string `json:"actor_token"` + ActorToken string `json:"actor_token,omitempty"` // An identifier, that indicates the type of the security token in the // "actor_token" parameter. - ActorTokenType string `json:"actor_token_type"` + ActorTokenType string `json:"actor_token_type,omitempty"` } // ResponseParameters stores all attributes sent as JSON in a successful STS From 1053ce51f8591719144c333b8e2ffa6b5cfd5978 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Tue, 7 Jul 2020 16:00:19 -0700 Subject: [PATCH 3/7] Use credentials.CheckSecurityLevel(). --- credentials/sts/sts.go | 3 ++ credentials/sts/sts_test.go | 78 ++++++++++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index 0dfb6c7cc7fd..f7c6b82982e9 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -199,6 +199,9 @@ type callCreds struct { // GetRequestMetadata returns the cached accessToken, if available and valid, or // fetches a new one by performing an STS token exchange. func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { + if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err) + } if md := c.metadataFromCachedToken(); md != nil { return md, nil } diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go index 34861dc4f03d..75cc6a58d936 100644 --- a/credentials/sts/sts_test.go +++ b/credentials/sts/sts_test.go @@ -35,7 +35,10 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" ) @@ -56,6 +59,33 @@ var ( } ) +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// A struct that implements AuthInfo interface and implements CommonAuthInfo() +// method. +type testAuthInfo struct { + credentials.CommonAuthInfo +} + +func (ta testAuthInfo) AuthType() string { + return "testAuthInfo" +} + +func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context { + auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}} + ri := credentials.RequestInfo{ + Method: "testInfo", + AuthInfo: auth, + } + return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) +} + // fakeHTTPClient helps mock out the HTTP calls made by the credentials code // under test. It makes the http.Request made by the credentials available // through a channel, and makes it possible to inject various responses. @@ -211,7 +241,7 @@ func compareRequest(opts Options, gotRequest *http.Request) error { // TestGetRequestMetadataSuccess verifies the successful case of sending an // token exchange request and processing the response. -func TestGetRequestMetadataSuccess(t *testing.T) { +func (s) TestGetRequestMetadataSuccess(t *testing.T) { defer overrideSubjectTokenGood()() fc, cancel := overrideHTTPClientGood() defer cancel() @@ -224,7 +254,7 @@ func TestGetRequestMetadataSuccess(t *testing.T) { errCh := make(chan error, 1) go compareRequestWithRetry(errCh, false, fc.reqCh, nil) - gotMetadata, err := creds.GetRequestMetadata(context.Background(), "") + gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") if err != nil { t.Fatalf("creds.GetRequestMetadata() = %v", err) } @@ -239,7 +269,7 @@ func TestGetRequestMetadataSuccess(t *testing.T) { // from the cache. This will fail if the credentials tries to send a fresh // request here since we have not configured our fakeClient to return any // response on retries. - gotMetadata, err = creds.GetRequestMetadata(context.Background(), "") + gotMetadata, err = creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") if err != nil { t.Fatalf("creds.GetRequestMetadata() = %v", err) } @@ -248,10 +278,32 @@ func TestGetRequestMetadataSuccess(t *testing.T) { } } +// TestGetRequestMetadataBadSecurityLevel verifies the case where the +// securityLevel specified in the context passed to GetRequestMetadata is not +// sufficient. +func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) { + defer overrideSubjectTokenGood()() + fc, cancel := overrideHTTPClientGood() + defer cancel() + + creds, err := NewCredentials(goodOptions) + if err != nil { + t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) + } + + errCh := make(chan error, 1) + go compareRequestWithRetry(errCh, false, fc.reqCh, nil) + + gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.IntegrityOnly), "") + if err == nil { + t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata) + } +} + // TestGetRequestMetadataCacheExpiry verifies the case where the cached access // token has expired, and the credentials implementation will have to send a // fresh token exchange request. -func TestGetRequestMetadataCacheExpiry(t *testing.T) { +func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) { const expiresInSecs = 1 defer overrideSubjectTokenGood()() respJSON, _ := json.Marshal(ResponseParameters{ @@ -288,7 +340,7 @@ func TestGetRequestMetadataCacheExpiry(t *testing.T) { errCh := make(chan error, 1) go compareRequestWithRetry(errCh, false, fc.reqCh, nil) - gotMetadata, err := creds.GetRequestMetadata(context.Background(), "") + gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") if err != nil { t.Fatalf("creds.GetRequestMetadata() = %v", err) } @@ -304,7 +356,7 @@ func TestGetRequestMetadataCacheExpiry(t *testing.T) { // TestGetRequestMetadataBadResponses verifies the scenario where the token // exchange server returns bad responses. -func TestGetRequestMetadataBadResponses(t *testing.T) { +func (s) TestGetRequestMetadataBadResponses(t *testing.T) { tests := []struct { name string response *http.Response @@ -350,7 +402,7 @@ func TestGetRequestMetadataBadResponses(t *testing.T) { errCh := make(chan error, 1) go compareRequestWithRetry(errCh, false, fc.reqCh, nil) - if _, err := creds.GetRequestMetadata(context.Background(), ""); err == nil { + if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil { t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") } if err := <-errCh; err != nil { @@ -362,7 +414,7 @@ func TestGetRequestMetadataBadResponses(t *testing.T) { // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the // attempt to read the subjectToken fails. -func TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { +func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { origReadSubjectTokenFrom := readSubjectTokenFrom readSubjectTokenFrom = func(path string) ([]byte, error) { return nil, errors.New("failed to read subject token") @@ -386,7 +438,7 @@ func TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { errCh <- nil }() - if _, err = creds.GetRequestMetadata(context.Background(), ""); err == nil { + if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil { t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") } if err := <-errCh; err != nil { @@ -395,7 +447,7 @@ func TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { } // TestGetRequestMetadataRetry verifies various retry scenarios. -func TestGetRequestMetadataRetry(t *testing.T) { +func (s) TestGetRequestMetadataRetry(t *testing.T) { tests := []struct { name string firstResp *http.Response @@ -488,7 +540,7 @@ func TestGetRequestMetadataRetry(t *testing.T) { errCh := make(chan error, 1) go compareRequestWithRetry(errCh, test.wantRetry, fc.reqCh, fb.boCh) - gotMetadata, err := creds.GetRequestMetadata(context.Background(), "") + gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") if (err != nil) != test.wantErr { t.Fatalf("creds.GetRequestMetadata() = %v, want %v", err, test.wantErr) } @@ -502,7 +554,7 @@ func TestGetRequestMetadataRetry(t *testing.T) { } } -func TestNewCredentials(t *testing.T) { +func (s) TestNewCredentials(t *testing.T) { tests := []struct { name string opts Options @@ -553,7 +605,7 @@ func TestNewCredentials(t *testing.T) { } } -func TestValidateOptions(t *testing.T) { +func (s) TestValidateOptions(t *testing.T) { tests := []struct { name string opts Options From 9e607f58af40fc8e2f7cc3387347fd6a8f42434f Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Wed, 8 Jul 2020 15:05:07 -0700 Subject: [PATCH 4/7] Remove code to retry with exponential backoff. Also, refactor the code a bit for easier and better testing. --- credentials/sts/sts.go | 209 ++++++------- credentials/sts/sts_test.go | 593 +++++++++++++++++++++--------------- 2 files changed, 430 insertions(+), 372 deletions(-) diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index f7c6b82982e9..60ef89778d32 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -39,10 +39,8 @@ import ( "sync" "time" - grpcbackoff "google.golang.org/grpc/backoff" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/backoff" ) const ( @@ -59,7 +57,6 @@ const ( // For overriding in tests. var ( loadSystemCertPool = x509.SystemCertPool - makeBackoffStrategy = defaultBackoffStrategy makeHTTPDoer = makeHTTPClient readSubjectTokenFrom = ioutil.ReadFile readActorTokenFrom = ioutil.ReadFile @@ -128,17 +125,55 @@ func NewCredentials(opts Options) (credentials.PerRPCCredentials, error) { return &callCreds{ opts: opts, client: makeHTTPDoer(roots), - bs: makeBackoffStrategy(), }, nil } -// defaultBackoffStrategy returns an exponential backoff strategy based on the -// default exponential backoff config, but with a smaller BaseDelay (because the -// former is meant more for connection establishment). -func defaultBackoffStrategy() backoff.Strategy { - config := grpcbackoff.DefaultConfig - config.BaseDelay = 100 * time.Millisecond - return backoff.Exponential{Config: config} +// callCreds provides the implementation of call credentials based on an STS +// token exchange. +type callCreds struct { + opts Options + client httpDoer + + // Cached accessToken to avoid an STS token exchange for every call to + // GetRequestMetadata. + mu sync.Mutex + cachedToken *tokenInfo +} + +// GetRequestMetadata returns the cached accessToken, if available and valid, or +// fetches a new one by performing an STS token exchange. +func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { + if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err) + } + if md := c.metadataFromCachedToken(); md != nil { + return md, nil + } + req, err := constructRequest(ctx, c.opts) + if err != nil { + return nil, err + } + respBody, err := sendRequest(c.client, req) + if err != nil { + return nil, err + } + ti, err := tokenInfoFromResponse(respBody) + if err != nil { + return nil, err + } + c.mu.Lock() + c.cachedToken = ti + c.mu.Unlock() + + return map[string]string{ + "Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token), + }, nil +} + +// RequireTransportSecurity indicates whether the credentials requires +// transport security. +func (c *callCreds) RequireTransportSecurity() bool { + return true } // httpDoer wraps the single method on the http.Client type that we use. This @@ -183,54 +218,6 @@ func validateOptions(opts Options) error { return nil } -// callCreds provides the implementation of call credentials based on an STS -// token exchange. -type callCreds struct { - opts Options - client httpDoer - bs backoff.Strategy - - // Cached accessToken to avoid an STS token exchange for every call to - // GetRequestMetadata. - mu sync.Mutex - cachedToken *tokenInfo -} - -// GetRequestMetadata returns the cached accessToken, if available and valid, or -// fetches a new one by performing an STS token exchange. -func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { - if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { - return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err) - } - if md := c.metadataFromCachedToken(); md != nil { - return md, nil - } - req, err := c.constructRequest() - if err != nil { - return nil, err - } - - // Send the request with exponential backoff and retry. Even though not - // retrying here is OK, as the connection attempt will be retried by the - // subConn, it is not very hard to perform some basic retries here. - respBody, err := c.sendRequestWithRetry(ctx, req) - if err != nil { - return nil, err - } - - ti, err := tokenInfoFromResponse(respBody) - if err != nil { - return nil, err - } - c.mu.Lock() - c.cachedToken = ti - c.mu.Unlock() - - return map[string]string{ - "Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token), - }, nil -} - // metadataFromCachedToken returns the cached accessToken as request metadata, // provided a cached accessToken exists and is not going to expire anytime soon. func (c *callCreds) metadataFromCachedToken() map[string]string { @@ -253,42 +240,48 @@ func (c *callCreds) metadataFromCachedToken() map[string]string { return nil } -// constructRequest creates the STS request body in JSON based on the options -// received when the credentials type was created. +// constructRequest creates the STS request body in JSON based on the provided +// options. // - Contents of the subjectToken are read from the file specified in // options. If we encounter an error here, we bail out. // - Contents of the actorToken are read from the file specified in options. // If we encounter an error here, we ignore this field because this is // optional. // - Most of the other fields in the request come directly from options. -func (c *callCreds) constructRequest() (*http.Request, error) { - subToken, err := readSubjectTokenFrom(c.opts.SubjectTokenPath) +// +// A new HTTP request is created by calling http.NewRequestWithContext() and +// passing the provided context, thereby enforcing any timeouts specified in +// the latter. +func constructRequest(ctx context.Context, opts Options) (*http.Request, error) { + subToken, err := readSubjectTokenFrom(opts.SubjectTokenPath) if err != nil { return nil, err } - reqScope := c.opts.Scope + reqScope := opts.Scope if reqScope == "" { reqScope = defaultCloudPlatformScope } reqParams := &RequestParameters{ GrantType: tokenExchangeGrantType, - Resource: c.opts.Resource, - Audience: c.opts.Audience, + Resource: opts.Resource, + Audience: opts.Audience, Scope: reqScope, - RequestedTokenType: c.opts.RequestedTokenType, + RequestedTokenType: opts.RequestedTokenType, SubjectToken: string(subToken), - SubjectTokenType: c.opts.SubjectTokenType, + SubjectTokenType: opts.SubjectTokenType, } - actorToken, err := readActorTokenFrom(c.opts.ActorTokenPath) - if err == nil { - reqParams.ActorToken = string(actorToken) - reqParams.ActorTokenType = c.opts.ActorTokenType + if opts.ActorTokenPath != "" { + actorToken, err := readActorTokenFrom(opts.ActorTokenPath) + if err == nil { + reqParams.ActorToken = string(actorToken) + reqParams.ActorTokenType = opts.ActorTokenType + } } jsonBody, err := json.Marshal(reqParams) if err != nil { return nil, err } - req, err := http.NewRequest("POST", c.opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create http request: %v", err) } @@ -296,52 +289,30 @@ func (c *callCreds) constructRequest() (*http.Request, error) { return req, nil } -// sendRequestWithRetry sends the provided http.Request and retries with -// exponential backoff for certain types of errors. It takes care of closing the -// response body on success (returns the contents of the body) and retries. -func (c *callCreds) sendRequestWithRetry(ctx context.Context, req *http.Request) ([]byte, error) { - retries := 0 - for { - // Wait for one of exponential backoff or context deadline to expire. - if retries > 0 { - timer := time.NewTimer(c.bs.Backoff(retries)) - select { - case <-ctx.Done(): - timer.Stop() - return nil, ctx.Err() - case <-timer.C: - } - } - - // http.Client returns a non-nil error only if it encounters an error - // caused by client policy (such as CheckRedirect), or failure to speak - // HTTP (such as a network connectivity problem). A non-2xx status code - // doesn't cause an error. - resp, err := c.client.Do(req) - if err != nil { - return nil, err - } +func sendRequest(client httpDoer, req *http.Request) ([]byte, error) { + // http.Client returns a non-nil error only if it encounters an error + // caused by client policy (such as CheckRedirect), or failure to speak + // HTTP (such as a network connectivity problem). A non-2xx status code + // doesn't cause an error. + resp, err := client.Do(req) + if err != nil { + return nil, err + } - // When the http.Client returns a non-nil error, it is the - // responsibility of the caller to read the response body till an EOF is - // encountered and to close it. - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - retries++ - continue - } + // When the http.Client returns a non-nil error, it is the + // responsibility of the caller to read the response body till an EOF is + // encountered and to close it. + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } - if resp.StatusCode == http.StatusOK { - return body, nil - } - if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError { - // For 4xx errors, which are client errors, we do not retry. - return nil, fmt.Errorf("http status %d, body: %s", resp.StatusCode, string(body)) - } - retries++ - grpclog.Warningf("http status %d, body: %s", resp.StatusCode, string(body)) + if resp.StatusCode == http.StatusOK { + return body, nil } + grpclog.Warningf("http status %d, body: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("http status %d, body: %s", resp.StatusCode, string(body)) } func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) { @@ -359,12 +330,6 @@ func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) { }, nil } -// RequireTransportSecurity indicates whether the credentials requires -// transport security. -func (c *callCreds) RequireTransportSecurity() bool { - return true -} - // RequestParameters stores all STS request attributes defined in // https://tools.ietf.org/html/rfc8693#section-2.1. type RequestParameters struct { diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go index 75cc6a58d936..418e2668cc83 100644 --- a/credentials/sts/sts_test.go +++ b/credentials/sts/sts_test.go @@ -33,26 +33,43 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal" - "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" ) const ( - subjectTokenContents = "subjectToken.jwt.contents" + requestedTokenType = "urn:ietf:params:oauth:token-type:access-token" + actorTokenPath = "/var/run/secrets/token.jwt" + actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token" actorTokenContents = "actorToken.jwt.contents" accessTokenContents = "access_token" + subjectTokenPath = "/var/run/secrets/token.jwt" + subjectTokenType = "urn:ietf:params:oauth:token-type:id_token" + subjectTokenContents = "subjectToken.jwt.contents" + serviceURI = "http://localhost" + exampleResource = "https://backend.example.com/api" + exampleAudience = "example-backend-service" + testScope = "https://www.googleapis.com/auth/monitoring" ) var ( goodOptions = Options{ - TokenExchangeServiceURI: "http://localhost", - SubjectTokenPath: "/var/run/secrets/token.jwt", - SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + TokenExchangeServiceURI: serviceURI, + Audience: exampleAudience, + RequestedTokenType: requestedTokenType, + SubjectTokenPath: subjectTokenPath, + SubjectTokenType: subjectTokenType, + } + goodRequestParams = &RequestParameters{ + GrantType: tokenExchangeGrantType, + Audience: exampleAudience, + Scope: defaultCloudPlatformScope, + RequestedTokenType: requestedTokenType, + SubjectToken: subjectTokenContents, + SubjectTokenType: subjectTokenType, } goodMetadata = map[string]string{ "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents), @@ -67,8 +84,8 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } -// A struct that implements AuthInfo interface and implements CommonAuthInfo() -// method. +// A struct that implements AuthInfo interface and added to the context passed +// to GetRequestMetadata from tests. type testAuthInfo struct { credentials.CommonAuthInfo } @@ -86,41 +103,6 @@ func createTestContext(ctx context.Context, s credentials.SecurityLevel) context return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) } -// fakeHTTPClient helps mock out the HTTP calls made by the credentials code -// under test. It makes the http.Request made by the credentials available -// through a channel, and makes it possible to inject various responses. -type fakeHTTPClient struct { - reqCh *testutils.Channel - // When no retry is involve, only these two fields need to be populated. - firstResp *http.Response - firstErr error - // To test retry scenarios with a different response upon retry. - subsequentResp *http.Response - subsequentErr error - - numCalls int -} - -func (fc *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { - fc.numCalls++ - fc.reqCh.Send(req) - if fc.numCalls > 1 { - return fc.subsequentResp, fc.subsequentErr - } - return fc.firstResp, fc.firstErr -} - -// fakeBackoff implements backoff.Strategy and pushes on a channel to indicate -// that a backoff was attempted. -type fakeBackoff struct { - boCh *testutils.Channel -} - -func (fb *fakeBackoff) Backoff(retries int) time.Duration { - fb.boCh.Send(retries) - return 0 -} - // errReader implements the io.Reader interface and returns an error from the // Read method. type errReader struct{} @@ -147,17 +129,44 @@ func makeGoodResponse() *http.Response { } } +// fakeHTTPDoer helps mock out the http.Client.Do calls made by the credentials +// code under test. It makes the http.Request made by the credentials available +// through a channel, and makes it possible to inject various responses. +type fakeHTTPDoer struct { + reqCh *testutils.Channel + respCh *testutils.Channel + err error +} + +func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) { + fc.reqCh.Send(req) + val, err := fc.respCh.Receive() + if err != nil { + return nil, err + } + return val.(*http.Response), fc.err +} + // Overrides the http.Client with a fakeClient which sends a good response. -func overrideHTTPClientGood() (*fakeHTTPClient, func()) { - fc := &fakeHTTPClient{ - reqCh: testutils.NewChannel(), - firstResp: makeGoodResponse(), +func overrideHTTPClientGood() (*fakeHTTPDoer, func()) { + fc := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), } + fc.respCh.Send(makeGoodResponse()) + origMakeHTTPDoer := makeHTTPDoer makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } return fc, func() { makeHTTPDoer = origMakeHTTPDoer } } +// Overrides the http.Client with the provided fakeClient. +func overrideHTTPClient(fc *fakeHTTPDoer) func() { + origMakeHTTPDoer := makeHTTPDoer + makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } + return func() { makeHTTPDoer = origMakeHTTPDoer } +} + // Overrides the subject token read to return a const which we can compare in // our tests. func overrideSubjectTokenGood() func() { @@ -168,58 +177,42 @@ func overrideSubjectTokenGood() func() { return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } } -// compareRequestWithRetry is run in a separate goroutine by tests to perform -// the following: -// - wait for a http request to be made by the credentials type and compare it -// with an expected one. -// - if the credentials is expected to retry, verify that a backoff was done -// before the retry. -// If any of the above steps fail, an error is pushed on the errCh. -func compareRequestWithRetry(errCh chan error, wantRetry bool, reqCh, boCh *testutils.Channel) { - val, err := reqCh.Receive() - if err != nil { - errCh <- err - return - } - req := val.(*http.Request) - if err := compareRequest(goodOptions, req); err != nil { - errCh <- err - return +// Overrides the subject token read to always return an error. +func overrideSubjectTokenError() func() { + origReadSubjectTokenFrom := readSubjectTokenFrom + readSubjectTokenFrom = func(path string) ([]byte, error) { + return nil, errors.New("error reading subject token") } + return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } +} - if wantRetry { - _, err := boCh.Receive() - if err != nil { - errCh <- err - return - } +// Overrides the actor token read to return a const which we can compare in +// our tests. +func overrideActorTokenGood() func() { + origReadActorTokenFrom := readActorTokenFrom + readActorTokenFrom = func(path string) ([]byte, error) { + return []byte(actorTokenContents), nil } - errCh <- nil + return func() { readActorTokenFrom = origReadActorTokenFrom } } -func compareRequest(opts Options, gotRequest *http.Request) error { - reqScope := opts.Scope - if reqScope == "" { - reqScope = defaultCloudPlatformScope +// Overrides the actor token read to always return an error. +func overrideActorTokenError() func() { + origReadActorTokenFrom := readActorTokenFrom + readActorTokenFrom = func(path string) ([]byte, error) { + return nil, errors.New("error reading actor token") } - reqParams := &RequestParameters{ - GrantType: tokenExchangeGrantType, - Resource: opts.Resource, - Audience: opts.Audience, - Scope: reqScope, - RequestedTokenType: opts.RequestedTokenType, - SubjectToken: subjectTokenContents, - SubjectTokenType: opts.SubjectTokenType, - } - if opts.ActorTokenPath != "" { - reqParams.ActorToken = actorTokenContents - reqParams.ActorTokenType = opts.ActorTokenType - } - jsonBody, err := json.Marshal(reqParams) + return func() { readActorTokenFrom = origReadActorTokenFrom } +} + +// compareRequest compares the http.Request received in the test with the +// expected RequestParameters specified in wantReqParams. +func compareRequest(gotRequest *http.Request, wantReqParams *RequestParameters) error { + jsonBody, err := json.Marshal(wantReqParams) if err != nil { return err } - wantReq, err := http.NewRequest("POST", opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody)) + wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody)) if err != nil { return fmt.Errorf("failed to create http request: %v", err) } @@ -239,6 +232,25 @@ func compareRequest(opts Options, gotRequest *http.Request) error { return nil } +// receiveAndCompareRequest waits for a request to be sent out by the +// credentials implementation using the fakeHTTPClient and compares it to an +// expected goodRequest. This is expected to be called in a separate goroutine +// by the tests. So, any errors encountered are pushed to an error channel +// which is monitored by the test. +func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) { + val, err := reqCh.Receive() + if err != nil { + errCh <- err + return + } + req := val.(*http.Request) + if err := compareRequest(req, goodRequestParams); err != nil { + errCh <- err + return + } + errCh <- nil +} + // TestGetRequestMetadataSuccess verifies the successful case of sending an // token exchange request and processing the response. func (s) TestGetRequestMetadataSuccess(t *testing.T) { @@ -252,7 +264,7 @@ func (s) TestGetRequestMetadataSuccess(t *testing.T) { } errCh := make(chan error, 1) - go compareRequestWithRetry(errCh, false, fc.reqCh, nil) + go receiveAndCompareRequest(fc.reqCh, errCh) gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") if err != nil { @@ -283,17 +295,12 @@ func (s) TestGetRequestMetadataSuccess(t *testing.T) { // sufficient. func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) { defer overrideSubjectTokenGood()() - fc, cancel := overrideHTTPClientGood() - defer cancel() creds, err := NewCredentials(goodOptions) if err != nil { t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) } - errCh := make(chan error, 1) - go compareRequestWithRetry(errCh, false, fc.reqCh, nil) - gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.IntegrityOnly), "") if err == nil { t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata) @@ -306,26 +313,11 @@ func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) { func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) { const expiresInSecs = 1 defer overrideSubjectTokenGood()() - respJSON, _ := json.Marshal(ResponseParameters{ - AccessToken: accessTokenContents, - IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", - TokenType: "Bearer", - ExpiresIn: expiresInSecs, - }) - respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) - resp := &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: respBody, + fc := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), } - fc := &fakeHTTPClient{ - reqCh: testutils.NewChannel(), - firstResp: resp, - subsequentResp: makeGoodResponse(), - } - origMakeHTTPDoer := makeHTTPDoer - makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } - defer func() { makeHTTPDoer = origMakeHTTPDoer }() + defer overrideHTTPClient(fc)() creds, err := NewCredentials(goodOptions) if err != nil { @@ -338,7 +330,21 @@ func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) { // out a fresh request. for i := 0; i < 2; i++ { errCh := make(chan error, 1) - go compareRequestWithRetry(errCh, false, fc.reqCh, nil) + go receiveAndCompareRequest(fc.reqCh, errCh) + + respJSON, _ := json.Marshal(ResponseParameters{ + AccessToken: accessTokenContents, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: expiresInSecs, + }) + respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) + resp := &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: respBody, + } + fc.respCh.Send(resp) gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") if err != nil { @@ -383,17 +389,11 @@ func (s) TestGetRequestMetadataBadResponses(t *testing.T) { t.Run(test.name, func(t *testing.T) { defer overrideSubjectTokenGood()() - fc := &fakeHTTPClient{ - reqCh: testutils.NewChannel(), - firstResp: &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader("not JSON")), - }, + fc := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), } - origMakeHTTPDoer := makeHTTPDoer - makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } - defer func() { makeHTTPDoer = origMakeHTTPDoer }() + defer overrideHTTPClient(fc)() creds, err := NewCredentials(goodOptions) if err != nil { @@ -401,7 +401,9 @@ func (s) TestGetRequestMetadataBadResponses(t *testing.T) { } errCh := make(chan error, 1) - go compareRequestWithRetry(errCh, false, fc.reqCh, nil) + go receiveAndCompareRequest(fc.reqCh, errCh) + + fc.respCh.Send(test.response) if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil { t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") } @@ -415,12 +417,7 @@ func (s) TestGetRequestMetadataBadResponses(t *testing.T) { // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the // attempt to read the subjectToken fails. func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { - origReadSubjectTokenFrom := readSubjectTokenFrom - readSubjectTokenFrom = func(path string) ([]byte, error) { - return nil, errors.New("failed to read subject token") - } - defer func() { readSubjectTokenFrom = origReadSubjectTokenFrom }() - + defer overrideSubjectTokenError()() fc, cancel := overrideHTTPClientGood() defer cancel() @@ -446,114 +443,6 @@ func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { } } -// TestGetRequestMetadataRetry verifies various retry scenarios. -func (s) TestGetRequestMetadataRetry(t *testing.T) { - tests := []struct { - name string - firstResp *http.Response - firstErr error - subsequentResp *http.Response - subsequentErr error - wantRetry bool - wantErr bool - wantMetadata map[string]string - }{ - { - name: "httpClient.Do error", - firstErr: errors.New("httpClient.Do() failed"), - wantErr: true, - }, - { - name: "bad response body first time", - firstResp: &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(errReader{}), - }, - subsequentResp: makeGoodResponse(), - wantRetry: true, - wantMetadata: goodMetadata, - }, - { - name: "http client error status code", - firstResp: &http.Response{ - Status: "400 BadRequest", - StatusCode: http.StatusBadRequest, - Body: ioutil.NopCloser(&bytes.Reader{}), - }, - wantErr: true, - }, - { - name: "server error first time", - firstResp: &http.Response{ - Status: "400 BadRequest", - StatusCode: http.StatusInternalServerError, - Body: ioutil.NopCloser(&bytes.Reader{}), - }, - subsequentResp: makeGoodResponse(), - wantRetry: true, - wantMetadata: goodMetadata, - }, - } - - // The test body performs the following steps: - // 1. Overrides the function to read subjectToken file and returns arbitrary - // data and nil error. - // 2. Overrides the function to return a http.Client and returns a fake - // client which is configured with response/error values to be returned. - // 3. Overrides the function to create the backoff strategy and returns a - // fake implementation which notifies the test through a channel when - // backoff is attempted. - // 4. Creates a new credentials type and invokes the GetRequestMetadata - // method on it. - // 5. Spawn a goroutine which verifies that the credentials sent out the - // expected http.Request, and performed a backoff when it encountered - // certain errors. - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - defer overrideSubjectTokenGood()() - - fc := &fakeHTTPClient{ - reqCh: testutils.NewChannel(), - firstResp: test.firstResp, - firstErr: test.firstErr, - subsequentResp: test.subsequentResp, - subsequentErr: test.subsequentErr, - } - origMakeHTTPDoer := makeHTTPDoer - makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } - - origBackoff := makeBackoffStrategy - fb := &fakeBackoff{boCh: testutils.NewChannel()} - makeBackoffStrategy = func() backoff.Strategy { return fb } - - defer func() { - makeHTTPDoer = origMakeHTTPDoer - makeBackoffStrategy = origBackoff - }() - - creds, err := NewCredentials(goodOptions) - if err != nil { - t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) - } - - errCh := make(chan error, 1) - go compareRequestWithRetry(errCh, test.wantRetry, fc.reqCh, fb.boCh) - - gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "") - if (err != nil) != test.wantErr { - t.Fatalf("creds.GetRequestMetadata() = %v, want %v", err, test.wantErr) - } - if !cmp.Equal(gotMetadata, test.wantMetadata, cmpopts.EquateEmpty()) { - t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, test.wantMetadata) - } - if err := <-errCh; err != nil { - t.Fatal(err) - } - }) - } -} - func (s) TestNewCredentials(t *testing.T) { tests := []struct { name string @@ -564,7 +453,7 @@ func (s) TestNewCredentials(t *testing.T) { { name: "invalid options - empty subjectTokenPath", opts: Options{ - TokenExchangeServiceURI: "http://localhost", + TokenExchangeServiceURI: serviceURI, }, wantErr: true, }, @@ -633,25 +522,21 @@ func (s) TestValidateOptions(t *testing.T) { { name: "empty subjectTokenPath", opts: Options{ - TokenExchangeServiceURI: "http://localhost", + TokenExchangeServiceURI: serviceURI, }, wantErrPrefix: "required field SubjectTokenPath is not specified", }, { name: "empty subjectTokenType", opts: Options{ - TokenExchangeServiceURI: "http://localhost", - SubjectTokenPath: "/var/run/secrets/token.jwt", + TokenExchangeServiceURI: serviceURI, + SubjectTokenPath: subjectTokenPath, }, wantErrPrefix: "required field SubjectTokenType is not specified", }, { name: "good options", - opts: Options{ - TokenExchangeServiceURI: "http://localhost", - SubjectTokenPath: "/var/run/secrets/token.jwt", - SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", - }, + opts: goodOptions, }, } @@ -667,3 +552,211 @@ func (s) TestValidateOptions(t *testing.T) { }) } } + +func (s) TestConstructRequest(t *testing.T) { + tests := []struct { + name string + opts Options + subjectTokenReadErr bool + actorTokenReadErr bool + wantReqParams *RequestParameters + wantErr bool + }{ + { + name: "subject token read failure", + subjectTokenReadErr: true, + opts: goodOptions, + wantErr: true, + }, + { + name: "default cloud platform scope", + opts: goodOptions, + wantReqParams: goodRequestParams, + }, + { + name: "actor token read failure", + actorTokenReadErr: true, + opts: Options{ + TokenExchangeServiceURI: serviceURI, + Audience: exampleAudience, + RequestedTokenType: requestedTokenType, + SubjectTokenPath: subjectTokenPath, + SubjectTokenType: subjectTokenType, + ActorTokenPath: actorTokenPath, + ActorTokenType: actorTokenType, + }, + wantReqParams: goodRequestParams, + }, + { + name: "all good", + opts: Options{ + TokenExchangeServiceURI: serviceURI, + Resource: exampleResource, + Audience: exampleAudience, + Scope: testScope, + RequestedTokenType: requestedTokenType, + SubjectTokenPath: subjectTokenPath, + SubjectTokenType: subjectTokenType, + ActorTokenPath: actorTokenPath, + ActorTokenType: actorTokenType, + }, + wantReqParams: &RequestParameters{ + GrantType: tokenExchangeGrantType, + Resource: exampleResource, + Audience: exampleAudience, + Scope: testScope, + RequestedTokenType: requestedTokenType, + SubjectToken: subjectTokenContents, + SubjectTokenType: subjectTokenType, + ActorToken: actorTokenContents, + ActorTokenType: actorTokenType, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.subjectTokenReadErr { + defer overrideSubjectTokenError()() + } else { + defer overrideSubjectTokenGood()() + } + + if test.actorTokenReadErr { + defer overrideActorTokenError()() + } else { + defer overrideActorTokenGood()() + } + + gotRequest, err := constructRequest(context.Background(), test.opts) + if (err != nil) != test.wantErr { + t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr) + } + if test.wantErr { + return + } + if err := compareRequest(gotRequest, test.wantReqParams); err != nil { + t.Fatal(err) + } + }) + } +} + +func (s) TestSendRequest(t *testing.T) { + defer overrideSubjectTokenGood()() + req, err := constructRequest(context.Background(), goodOptions) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + resp *http.Response + respErr error + wantErr bool + }{ + { + name: "client error", + respErr: errors.New("http.Client.Do failed"), + wantErr: true, + }, + { + name: "bad response body", + resp: &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(errReader{}), + }, + wantErr: true, + }, + { + name: "nonOK status code", + resp: &http.Response{ + Status: "400 BadRequest", + StatusCode: http.StatusBadRequest, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + wantErr: true, + }, + { + name: "good case", + resp: makeGoodResponse(), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client := &fakeHTTPDoer{ + reqCh: testutils.NewChannel(), + respCh: testutils.NewChannel(), + err: test.respErr, + } + client.respCh.Send(test.resp) + _, err := sendRequest(client, req) + if (err != nil) != test.wantErr { + t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr) + } + }) + } +} + +func (s) TestTokenInfoFromResponse(t *testing.T) { + noAccessToken, _ := json.Marshal(ResponseParameters{ + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + goodResponse, _ := json.Marshal(ResponseParameters{ + IssuedTokenType: requestedTokenType, + AccessToken: accessTokenContents, + TokenType: "Bearer", + ExpiresIn: 3600, + }) + + tests := []struct { + name string + respBody []byte + wantTokenInfo *tokenInfo + wantErr bool + }{ + { + name: "bad JSON", + respBody: []byte("not JSON"), + wantErr: true, + }, + { + name: "empty response", + respBody: []byte(""), + wantErr: true, + }, + { + name: "non-empty response with no access token", + respBody: noAccessToken, + wantErr: true, + }, + { + name: "good response", + respBody: goodResponse, + wantTokenInfo: &tokenInfo{ + tokenType: "Bearer", + token: accessTokenContents, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotTokenInfo, err := tokenInfoFromResponse(test.respBody) + if (err != nil) != test.wantErr { + t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr) + } + if test.wantErr { + return + } + // Can't do a cmp.Equal on the whole struct since the expiryField + // is populated based on time.Now(). + if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token { + t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo) + } + }) + } +} From e02e28ea9de501cd537e816d361b9bbe4e59f5d3 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Wed, 8 Jul 2020 15:34:32 -0700 Subject: [PATCH 5/7] http.NewRequestWithContext was only added in go1.13 --- credentials/sts/sts.go | 2 ++ credentials/sts/sts_test.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index 60ef89778d32..4c361d46e296 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -1,3 +1,5 @@ +// +build go1.13 + /* * * Copyright 2020 gRPC authors. diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go index 418e2668cc83..c961c6563be2 100644 --- a/credentials/sts/sts_test.go +++ b/credentials/sts/sts_test.go @@ -1,3 +1,5 @@ +// +build go1.13 + /* * * Copyright 2020 gRPC authors. From 2482c17d4481e99cda7a5bcd8025cb883ee3d0ba Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 9 Jul 2020 08:22:37 -0700 Subject: [PATCH 6/7] Review comments. - Unexport [Request/Response]Parameters structs. - Hold lock for entire duration of sts request. - Don't ignore errors on attempts to read actor token. --- credentials/sts/sts.go | 47 ++++++++++++++++++++----------------- credentials/sts/sts_test.go | 28 +++++++++++----------- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index 4c361d46e296..c63eddb2228c 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -148,6 +148,12 @@ func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[st if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err) } + + // Holding the lock for the whole duration of the STS request and response + // processing ensures that concurrent RPCs don't end up in multiple + // requests being made. + c.mu.Lock() + defer c.mu.Unlock() if md := c.metadataFromCachedToken(); md != nil { return md, nil } @@ -163,13 +169,8 @@ func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[st if err != nil { return nil, err } - c.mu.Lock() c.cachedToken = ti - c.mu.Unlock() - - return map[string]string{ - "Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token), - }, nil + return tokenInfoToMetadata(ti), nil } // RequireTransportSecurity indicates whether the credentials requires @@ -222,10 +223,9 @@ func validateOptions(opts Options) error { // metadataFromCachedToken returns the cached accessToken as request metadata, // provided a cached accessToken exists and is not going to expire anytime soon. +// +// Caller must hold c.mu. func (c *callCreds) metadataFromCachedToken() map[string]string { - c.mu.Lock() - defer c.mu.Unlock() - if c.cachedToken == nil { return nil } @@ -235,9 +235,7 @@ func (c *callCreds) metadataFromCachedToken() map[string]string { // token is greater than the minimum value we are willing to accept, go // ahead and use it. if c.cachedToken.expiryTime.After(now) && c.cachedToken.expiryTime.Sub(now) > minCachedTokenLifetime { - return map[string]string{ - "Authorization": fmt.Sprintf("%s %s", c.cachedToken.tokenType, c.cachedToken.token), - } + return tokenInfoToMetadata(c.cachedToken) } return nil } @@ -263,7 +261,7 @@ func constructRequest(ctx context.Context, opts Options) (*http.Request, error) if reqScope == "" { reqScope = defaultCloudPlatformScope } - reqParams := &RequestParameters{ + reqParams := &requestParameters{ GrantType: tokenExchangeGrantType, Resource: opts.Resource, Audience: opts.Audience, @@ -274,10 +272,11 @@ func constructRequest(ctx context.Context, opts Options) (*http.Request, error) } if opts.ActorTokenPath != "" { actorToken, err := readActorTokenFrom(opts.ActorTokenPath) - if err == nil { - reqParams.ActorToken = string(actorToken) - reqParams.ActorTokenType = opts.ActorTokenType + if err != nil { + return nil, err } + reqParams.ActorToken = string(actorToken) + reqParams.ActorTokenType = opts.ActorTokenType } jsonBody, err := json.Marshal(reqParams) if err != nil { @@ -318,7 +317,7 @@ func sendRequest(client httpDoer, req *http.Request) ([]byte, error) { } func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) { - respData := &ResponseParameters{} + respData := &responseParameters{} if err := json.Unmarshal(respBody, respData); err != nil { return nil, fmt.Errorf("json.Unmarshal(%v): %v", respBody, err) } @@ -332,9 +331,15 @@ func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) { }, nil } -// RequestParameters stores all STS request attributes defined in +func tokenInfoToMetadata(ti *tokenInfo) map[string]string { + return map[string]string{ + "Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token), + } +} + +// requestParameters stores all STS request attributes defined in // https://tools.ietf.org/html/rfc8693#section-2.1. -type RequestParameters struct { +type requestParameters struct { // REQUIRED. The value "urn:ietf:params:oauth:grant-type:token-exchange" // indicates that a token exchange is being performed. GrantType string `json:"grant_type"` @@ -364,10 +369,10 @@ type RequestParameters struct { ActorTokenType string `json:"actor_token_type,omitempty"` } -// ResponseParameters stores all attributes sent as JSON in a successful STS +// nesponseParameters stores all attributes sent as JSON in a successful STS // response. These attributes are defined in // https://tools.ietf.org/html/rfc8693#section-2.2.1. -type ResponseParameters struct { +type responseParameters struct { // REQUIRED. The security token issued by the authorization server // in response to the token exchange request. AccessToken string `json:"access_token"` diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go index c961c6563be2..641bad5820bb 100644 --- a/credentials/sts/sts_test.go +++ b/credentials/sts/sts_test.go @@ -65,7 +65,7 @@ var ( SubjectTokenPath: subjectTokenPath, SubjectTokenType: subjectTokenType, } - goodRequestParams = &RequestParameters{ + goodRequestParams = &requestParameters{ GrantType: tokenExchangeGrantType, Audience: exampleAudience, Scope: defaultCloudPlatformScope, @@ -117,7 +117,7 @@ func (r errReader) Read(b []byte) (n int, err error) { // as a variable since the the response body will be consumed by the // credentials, and therefore we will need a new one everytime. func makeGoodResponse() *http.Response { - respJSON, _ := json.Marshal(ResponseParameters{ + respJSON, _ := json.Marshal(responseParameters{ AccessToken: accessTokenContents, IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", TokenType: "Bearer", @@ -208,8 +208,8 @@ func overrideActorTokenError() func() { } // compareRequest compares the http.Request received in the test with the -// expected RequestParameters specified in wantReqParams. -func compareRequest(gotRequest *http.Request, wantReqParams *RequestParameters) error { +// expected requestParameters specified in wantReqParams. +func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error { jsonBody, err := json.Marshal(wantReqParams) if err != nil { return err @@ -334,7 +334,7 @@ func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) { errCh := make(chan error, 1) go receiveAndCompareRequest(fc.reqCh, errCh) - respJSON, _ := json.Marshal(ResponseParameters{ + respJSON, _ := json.Marshal(responseParameters{ AccessToken: accessTokenContents, IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", TokenType: "Bearer", @@ -561,7 +561,7 @@ func (s) TestConstructRequest(t *testing.T) { opts Options subjectTokenReadErr bool actorTokenReadErr bool - wantReqParams *RequestParameters + wantReqParams *requestParameters wantErr bool }{ { @@ -570,11 +570,6 @@ func (s) TestConstructRequest(t *testing.T) { opts: goodOptions, wantErr: true, }, - { - name: "default cloud platform scope", - opts: goodOptions, - wantReqParams: goodRequestParams, - }, { name: "actor token read failure", actorTokenReadErr: true, @@ -587,6 +582,11 @@ func (s) TestConstructRequest(t *testing.T) { ActorTokenPath: actorTokenPath, ActorTokenType: actorTokenType, }, + wantErr: true, + }, + { + name: "default cloud platform scope", + opts: goodOptions, wantReqParams: goodRequestParams, }, { @@ -602,7 +602,7 @@ func (s) TestConstructRequest(t *testing.T) { ActorTokenPath: actorTokenPath, ActorTokenType: actorTokenType, }, - wantReqParams: &RequestParameters{ + wantReqParams: &requestParameters{ GrantType: tokenExchangeGrantType, Resource: exampleResource, Audience: exampleAudience, @@ -702,12 +702,12 @@ func (s) TestSendRequest(t *testing.T) { } func (s) TestTokenInfoFromResponse(t *testing.T) { - noAccessToken, _ := json.Marshal(ResponseParameters{ + noAccessToken, _ := json.Marshal(responseParameters{ IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", TokenType: "Bearer", ExpiresIn: 3600, }) - goodResponse, _ := json.Marshal(ResponseParameters{ + goodResponse, _ := json.Marshal(responseParameters{ IssuedTokenType: requestedTokenType, AccessToken: accessTokenContents, TokenType: "Bearer", From 43780431177fe821c4bd8672867577c23f0f67c8 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 9 Jul 2020 11:00:37 -0700 Subject: [PATCH 7/7] Cache computed metadata and expiry time instead of raw token. --- credentials/sts/sts.go | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index c63eddb2228c..f07c4c402e76 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -138,8 +138,9 @@ type callCreds struct { // Cached accessToken to avoid an STS token exchange for every call to // GetRequestMetadata. - mu sync.Mutex - cachedToken *tokenInfo + mu sync.Mutex + tokenMetadata map[string]string + tokenExpiry time.Time } // GetRequestMetadata returns the cached accessToken, if available and valid, or @@ -154,7 +155,8 @@ func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[st // requests being made. c.mu.Lock() defer c.mu.Unlock() - if md := c.metadataFromCachedToken(); md != nil { + + if md := c.cachedMetadata(); md != nil { return md, nil } req, err := constructRequest(ctx, c.opts) @@ -169,8 +171,9 @@ func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[st if err != nil { return nil, err } - c.cachedToken = ti - return tokenInfoToMetadata(ti), nil + c.tokenMetadata = map[string]string{"Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token)} + c.tokenExpiry = ti.expiryTime + return c.tokenMetadata, nil } // RequireTransportSecurity indicates whether the credentials requires @@ -221,21 +224,17 @@ func validateOptions(opts Options) error { return nil } -// metadataFromCachedToken returns the cached accessToken as request metadata, -// provided a cached accessToken exists and is not going to expire anytime soon. +// cachedMetadata returns the cached metadata provided it is not going to +// expire anytime soon. // // Caller must hold c.mu. -func (c *callCreds) metadataFromCachedToken() map[string]string { - if c.cachedToken == nil { - return nil - } - +func (c *callCreds) cachedMetadata() map[string]string { now := time.Now() // If the cached token has not expired and the lifetime remaining on that // token is greater than the minimum value we are willing to accept, go // ahead and use it. - if c.cachedToken.expiryTime.After(now) && c.cachedToken.expiryTime.Sub(now) > minCachedTokenLifetime { - return tokenInfoToMetadata(c.cachedToken) + if c.tokenExpiry.After(now) && c.tokenExpiry.Sub(now) > minCachedTokenLifetime { + return c.tokenMetadata } return nil } @@ -331,12 +330,6 @@ func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) { }, nil } -func tokenInfoToMetadata(ti *tokenInfo) map[string]string { - return map[string]string{ - "Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token), - } -} - // requestParameters stores all STS request attributes defined in // https://tools.ietf.org/html/rfc8693#section-2.1. type requestParameters struct {