Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

meshca: credentials/sts: PerRPCCreds Implementation #3696

Merged
merged 7 commits into from
Jul 9, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 89 additions & 119 deletions credentials/sts/sts.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// +build go1.13

/*
*
* Copyright 2020 gRPC authors.
Expand Down Expand Up @@ -39,10 +41,8 @@ import (
"sync"
"time"

grpcbackoff "google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff"
)

const (
Expand All @@ -59,7 +59,6 @@ const (
// For overriding in tests.
var (
loadSystemCertPool = x509.SystemCertPool
makeBackoffStrategy = defaultBackoffStrategy
makeHTTPDoer = makeHTTPClient
readSubjectTokenFrom = ioutil.ReadFile
readActorTokenFrom = ioutil.ReadFile
Expand Down Expand Up @@ -128,17 +127,55 @@ func NewCredentials(opts Options) (credentials.PerRPCCredentials, error) {
return &callCreds{
opts: opts,
client: makeHTTPDoer(roots),
bs: makeBackoffStrategy(),
}, nil
}

// defaultBackoffStrategy returns an exponential backoff strategy based on the
// default exponential backoff config, but with a smaller BaseDelay (because the
// former is meant more for connection establishment).
func defaultBackoffStrategy() backoff.Strategy {
config := grpcbackoff.DefaultConfig
config.BaseDelay = 100 * time.Millisecond
return backoff.Exponential{Config: config}
// callCreds provides the implementation of call credentials based on an STS
// token exchange.
type callCreds struct {
opts Options
client httpDoer

// Cached accessToken to avoid an STS token exchange for every call to
// GetRequestMetadata.
mu sync.Mutex
cachedToken *tokenInfo
}

// GetRequestMetadata returns the cached accessToken, if available and valid, or
// fetches a new one by performing an STS token exchange.
func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err)
}
if md := c.metadataFromCachedToken(); md != nil {
easwars marked this conversation as resolved.
Show resolved Hide resolved
return md, nil
}
req, err := constructRequest(ctx, c.opts)
dfawley marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}
respBody, err := sendRequest(c.client, req)
if err != nil {
return nil, err
}
ti, err := tokenInfoFromResponse(respBody)
if err != nil {
return nil, err
}
c.mu.Lock()
c.cachedToken = ti
c.mu.Unlock()

return map[string]string{
"Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token),
}, nil
}

// RequireTransportSecurity indicates whether the credentials requires
// transport security.
func (c *callCreds) RequireTransportSecurity() bool {
return true
}

// httpDoer wraps the single method on the http.Client type that we use. This
Expand Down Expand Up @@ -183,51 +220,6 @@ func validateOptions(opts Options) error {
return nil
}

// callCreds provides the implementation of call credentials based on an STS
// token exchange.
type callCreds struct {
opts Options
client httpDoer
bs backoff.Strategy

// Cached accessToken to avoid an STS token exchange for every call to
// GetRequestMetadata.
mu sync.Mutex
cachedToken *tokenInfo
}

// GetRequestMetadata returns the cached accessToken, if available and valid, or
// fetches a new one by performing an STS token exchange.
func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
if md := c.metadataFromCachedToken(); md != nil {
return md, nil
}
req, err := c.constructRequest()
if err != nil {
return nil, err
}

// Send the request with exponential backoff and retry. Even though not
// retrying here is OK, as the connection attempt will be retried by the
// subConn, it is not very hard to perform some basic retries here.
respBody, err := c.sendRequestWithRetry(ctx, req)
if err != nil {
return nil, err
}

ti, err := tokenInfoFromResponse(respBody)
if err != nil {
return nil, err
}
c.mu.Lock()
c.cachedToken = ti
c.mu.Unlock()

return map[string]string{
"Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token),
}, nil
}

// metadataFromCachedToken returns the cached accessToken as request metadata,
// provided a cached accessToken exists and is not going to expire anytime soon.
func (c *callCreds) metadataFromCachedToken() map[string]string {
Expand All @@ -250,95 +242,79 @@ func (c *callCreds) metadataFromCachedToken() map[string]string {
return nil
}

// constructRequest creates the STS request body in JSON based on the options
// received when the credentials type was created.
// constructRequest creates the STS request body in JSON based on the provided
// options.
// - Contents of the subjectToken are read from the file specified in
// options. If we encounter an error here, we bail out.
// - Contents of the actorToken are read from the file specified in options.
// If we encounter an error here, we ignore this field because this is
// optional.
// - Most of the other fields in the request come directly from options.
func (c *callCreds) constructRequest() (*http.Request, error) {
subToken, err := readSubjectTokenFrom(c.opts.SubjectTokenPath)
//
// A new HTTP request is created by calling http.NewRequestWithContext() and
// passing the provided context, thereby enforcing any timeouts specified in
// the latter.
func constructRequest(ctx context.Context, opts Options) (*http.Request, error) {
subToken, err := readSubjectTokenFrom(opts.SubjectTokenPath)
if err != nil {
return nil, err
}
reqScope := c.opts.Scope
reqScope := opts.Scope
if reqScope == "" {
reqScope = defaultCloudPlatformScope
}
reqParams := &RequestParameters{
GrantType: tokenExchangeGrantType,
Resource: c.opts.Resource,
Audience: c.opts.Audience,
Resource: opts.Resource,
Audience: opts.Audience,
Scope: reqScope,
RequestedTokenType: c.opts.RequestedTokenType,
RequestedTokenType: opts.RequestedTokenType,
SubjectToken: string(subToken),
SubjectTokenType: c.opts.SubjectTokenType,
SubjectTokenType: opts.SubjectTokenType,
}
actorToken, err := readActorTokenFrom(c.opts.ActorTokenPath)
if err == nil {
reqParams.ActorToken = string(actorToken)
reqParams.ActorTokenType = c.opts.ActorTokenType
if opts.ActorTokenPath != "" {
actorToken, err := readActorTokenFrom(opts.ActorTokenPath)
if err == nil {
dfawley marked this conversation as resolved.
Show resolved Hide resolved
reqParams.ActorToken = string(actorToken)
reqParams.ActorTokenType = opts.ActorTokenType
}
}
jsonBody, err := json.Marshal(reqParams)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", c.opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create http request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
return req, nil
}

// sendRequestWithRetry sends the provided http.Request and retries with
// exponential backoff for certain types of errors. It takes care of closing the
// response body on success (returns the contents of the body) and retries.
func (c *callCreds) sendRequestWithRetry(ctx context.Context, req *http.Request) ([]byte, error) {
retries := 0
for {
// Wait for one of exponential backoff or context deadline to expire.
if retries > 0 {
timer := time.NewTimer(c.bs.Backoff(retries))
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-timer.C:
}
}

// http.Client returns a non-nil error only if it encounters an error
// caused by client policy (such as CheckRedirect), or failure to speak
// HTTP (such as a network connectivity problem). A non-2xx status code
// doesn't cause an error.
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
func sendRequest(client httpDoer, req *http.Request) ([]byte, error) {
// http.Client returns a non-nil error only if it encounters an error
// caused by client policy (such as CheckRedirect), or failure to speak
// HTTP (such as a network connectivity problem). A non-2xx status code
// doesn't cause an error.
resp, err := client.Do(req)
if err != nil {
return nil, err
}

// When the http.Client returns a non-nil error, it is the
// responsibility of the caller to read the response body till an EOF is
// encountered and to close it.
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
retries++
continue
}
// When the http.Client returns a non-nil error, it is the
// responsibility of the caller to read the response body till an EOF is
// encountered and to close it.
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, err
}

if resp.StatusCode == http.StatusOK {
return body, nil
}
if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError {
// For 4xx errors, which are client errors, we do not retry.
return nil, fmt.Errorf("http status %d, body: %s", resp.StatusCode, string(body))
}
retries++
grpclog.Warningf("http status %d, body: %s", resp.StatusCode, string(body))
if resp.StatusCode == http.StatusOK {
return body, nil
}
grpclog.Warningf("http status %d, body: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("http status %d, body: %s", resp.StatusCode, string(body))
}

func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) {
Expand All @@ -356,12 +332,6 @@ func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) {
}, nil
}

// RequireTransportSecurity indicates whether the credentials requires
// transport security.
func (c *callCreds) RequireTransportSecurity() bool {
return true
}

// RequestParameters stores all STS request attributes defined in
// https://tools.ietf.org/html/rfc8693#section-2.1.
type RequestParameters struct {
Expand Down
Loading