diff --git a/.generator/templates/client.mustache b/.generator/templates/client.mustache index 07ce6426d..b345b3ad9 100644 --- a/.generator/templates/client.mustache +++ b/.generator/templates/client.mustache @@ -69,6 +69,254 @@ type service struct { client *APIClient } +type Authorization interface { + Authorize() error +} + +type SSWSAuth struct { + token string + req *http.Request +} + +func NewSSWSAuth(token string, req *http.Request) *SSWSAuth { + return &SSWSAuth{token: token, req: req} +} + +func (a *SSWSAuth) Authorize() error { + a.req.Header.Add("Authorization", "SSWS "+a.token) + return nil +} + +type BearerAuth struct { + token string + req *http.Request +} + +func NewBearerAuth(token string, req *http.Request) *BearerAuth { + return &BearerAuth{token: token, req: req} +} + +func (a *BearerAuth) Authorize() error { + a.req.Header.Add("Authorization", "Bearer "+a.token) + return nil +} + +type PrivateKeyAuth struct { + tokenCache *goCache.Cache + httpClient *http.Client + privateKeySigner jose.Signer + privateKey string + privateKeyId string + clientId string + orgURL string + scopes []string + maxRetries int32 + maxBackoff int64 + req *http.Request +} + +type PrivateKeyAuthConfig struct { + TokenCache *goCache.Cache + HttpClient *http.Client + PrivateKeySigner jose.Signer + PrivateKey string + PrivateKeyId string + ClientId string + OrgURL string + Scopes []string + MaxRetries int32 + MaxBackoff int64 + Req *http.Request +} + +func NewPrivateKeyAuth(config PrivateKeyAuthConfig) *PrivateKeyAuth { + return &PrivateKeyAuth{ + tokenCache: config.TokenCache, + httpClient: config.HttpClient, + privateKeySigner: config.PrivateKeySigner, + privateKey: config.PrivateKey, + privateKeyId: config.PrivateKeyId, + clientId: config.ClientId, + orgURL: config.OrgURL, + scopes: config.Scopes, + maxRetries: config.MaxRetries, + maxBackoff: config.MaxBackoff, + req: config.Req, + } +} + +func (a *PrivateKeyAuth) Authorize() error { + accessToken, hasToken := a.tokenCache.Get(AccessTokenCacheKey) + if hasToken { + a.req.Header.Add("Authorization", "Bearer "+accessToken.(string)) + } else { + if a.privateKeySigner == nil { + var err error + a.privateKeySigner, err = createKeySigner(a.privateKey, a.privateKeyId) + if err != nil { + return err + } + } + + clientAssertion, err := createClientAssertion(a.orgURL, a.clientId, a.privateKeySigner) + if err != nil { + return err + } + + accessToken, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.scopes, a.maxRetries, a.maxBackoff) + if err != nil { + return err + } + + a.req.Header.Set("Authorization", "Bearer "+accessToken.AccessToken) + + // Trim a couple of seconds off calculated expiry so cache expiry + // occures before Okta server side expiry. + expiration := accessToken.ExpiresIn - 2 + a.tokenCache.Set(AccessTokenCacheKey, accessToken.AccessToken, time.Second*time.Duration(expiration)) + } + return nil +} + +type JWTAuth struct { + tokenCache *goCache.Cache + httpClient *http.Client + orgURL string + scopes []string + clientAssertion string + maxRetries int32 + maxBackoff int64 + req *http.Request +} + +type JWTAuthConfig struct { + TokenCache *goCache.Cache + HttpClient *http.Client + OrgURL string + Scopes []string + ClientAssertion string + MaxRetries int32 + MaxBackoff int64 + Req *http.Request +} + +func NewJWTAuth(config JWTAuthConfig) *JWTAuth { + return &JWTAuth{ + tokenCache: config.TokenCache, + httpClient: config.HttpClient, + orgURL: config.OrgURL, + scopes: config.Scopes, + clientAssertion: config.ClientAssertion, + maxRetries: config.MaxRetries, + maxBackoff: config.MaxBackoff, + req: config.Req, + } +} + +func (a *JWTAuth) Authorize() error { + accessToken, hasToken := a.tokenCache.Get(AccessTokenCacheKey) + if hasToken { + a.req.Header.Add("Authorization", "Bearer "+accessToken.(string)) + } else { + accessToken, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.scopes, a.maxRetries, a.maxBackoff) + if err != nil { + return err + } + a.req.Header.Set("Authorization", "Bearer "+accessToken.AccessToken) + + // Trim a couple of seconds off calculated expiry so cache expiry + // occures before Okta server side expiry. + expiration := accessToken.ExpiresIn - 2 + a.tokenCache.Set(AccessTokenCacheKey, accessToken.AccessToken, time.Second*time.Duration(expiration)) + } + return nil +} + +func createKeySigner(privateKey, privateKeyID string) (jose.Signer, error) { + priv := []byte(strings.ReplaceAll(privateKey, `\n`, "\n")) + + privPem, _ := pem.Decode(priv) + if privPem == nil { + return nil, errors.New("invalid private key") + } + if privPem.Type != "RSA PRIVATE KEY" { + return nil, fmt.Errorf("RSA private key is of the wrong type") + } + + parsedKey, err := x509.ParsePKCS1PrivateKey(privPem.Bytes) + if err != nil { + return nil, err + } + + var signerOptions *jose.SignerOptions + if privateKeyID != "" { + signerOptions = (&jose.SignerOptions{}).WithHeader("kid", privateKeyID) + } + + return jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: parsedKey}, signerOptions) +} + +func createClientAssertion(orgURL, clientID string, privateKeySinger jose.Signer) (clientAssertion string, err error) { + claims := ClientAssertionClaims{ + Subject: clientID, + IssuedAt: jwt.NewNumericDate(time.Now()), + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))), + Issuer: clientID, + Audience: orgURL + "/oauth2/v1/token", + } + jwtBuilder := jwt.Signed(privateKeySinger).Claims(claims) + return jwtBuilder.CompactSerialize() +} + +func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion string, scopes []string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, error) { + var tokenRequestBuff io.ReadWriter + query := url.Values{} + tokenRequestURL := orgURL + "/oauth2/v1/token" + + query.Add("grant_type", "client_credentials") + query.Add("scope", strings.Join(scopes, " ")) + query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + query.Add("client_assertion", clientAssertion) + tokenRequestURL += "?" + query.Encode() + tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff) + if err != nil { + return nil, err + } + + tokenRequest.Header.Add("Accept", "application/json") + tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + bOff := &oktaBackoff{ + ctx: context.TODO(), + maxRetries: maxRetries, + backoffDuration: time.Duration(maxBackoff), + } + var tokenResponse *http.Response + operation := func() error { + tokenResponse, err = httpClient.Do(tokenRequest) + bOff.retryCount++ + return err + } + err = backoff.Retry(operation, bOff) + if err != nil { + return nil, err + } + + respBody, err := io.ReadAll(tokenResponse.Body) + if err != nil { + return nil, err + } + origResp := io.NopCloser(bytes.NewBuffer(respBody)) + tokenResponse.Body = origResp + var accessToken *RequestAccessToken + + _, err = buildResponse(tokenResponse, nil, &accessToken) + if err != nil { + return nil, err + } + return accessToken, nil +} + // NewAPIClient creates a new API client. Requires a userAgent string describing your application. // optionally a custom http.Client to allow for advanced features such as caching. func NewAPIClient(cfg *Configuration) *APIClient { @@ -423,110 +671,43 @@ func (c *APIClient) prepareRequest( } // This will override the auth in context - if c.cfg.Okta.Client.AuthorizationMode == "SSWS" { - localVarRequest.Header.Set("Authorization", "SSWS "+c.cfg.Okta.Client.Token) - } - - if c.cfg.Okta.Client.AuthorizationMode == "Bearer" { - localVarRequest.Header.Set("Authorization", "Bearer "+c.cfg.Okta.Client.Token) + var auth Authorization + switch c.cfg.Okta.Client.AuthorizationMode { + case "SSWS": + auth = NewSSWSAuth(c.cfg.Okta.Client.Token, localVarRequest) + case "Bearer": + auth = NewBearerAuth(c.cfg.Okta.Client.Token, localVarRequest) + case "PrivateKey": + auth = NewPrivateKeyAuth(PrivateKeyAuthConfig{ + TokenCache: c.tokenCache, + HttpClient: c.cfg.HTTPClient, + PrivateKeySigner: c.cfg.PrivateKeySigner, + PrivateKey: c.cfg.Okta.Client.PrivateKey, + PrivateKeyId: c.cfg.Okta.Client.PrivateKeyId, + ClientId: c.cfg.Okta.Client.ClientId, + OrgURL: c.cfg.Okta.Client.OrgUrl, + Scopes: c.cfg.Okta.Client.Scopes, + MaxRetries: c.cfg.Okta.Client.RateLimit.MaxRetries, + MaxBackoff: c.cfg.Okta.Client.RateLimit.MaxBackoff, + Req: localVarRequest, + }) + case "JWT": + auth = NewJWTAuth(JWTAuthConfig{ + TokenCache: c.tokenCache, + HttpClient: c.cfg.HTTPClient, + OrgURL: c.cfg.Okta.Client.OrgUrl, + Scopes: c.cfg.Okta.Client.Scopes, + ClientAssertion: c.cfg.Okta.Client.ClientAssertion, + MaxRetries: c.cfg.Okta.Client.RateLimit.MaxRetries, + MaxBackoff: c.cfg.Okta.Client.RateLimit.MaxBackoff, + Req: localVarRequest, + }) + default: + return nil, fmt.Errorf("unknown authorization mode %v", c.cfg.Okta.Client.AuthorizationMode) } - - if c.cfg.Okta.Client.AuthorizationMode == "PrivateKey" { - cachedToken, hasToken := c.tokenCache.Get(AccessTokenCacheKey) - if hasToken { - localVarRequest.Header.Set("Authorization", "Bearer "+cachedToken.(string)) - } else { - if c.cfg.PrivateKeySigner == nil { - priv := []byte(strings.ReplaceAll(c.cfg.Okta.Client.PrivateKey, `\n`, "\n")) - - privPem, _ := pem.Decode(priv) - if privPem == nil { - return nil, errors.New("invalid private key") - } - if privPem.Type != "RSA PRIVATE KEY" { - return nil, fmt.Errorf("RSA private key is of the wrong type") - } - - parsedKey, err := x509.ParsePKCS1PrivateKey(privPem.Bytes) - if err != nil { - return nil, err - } - - var signerOptions *jose.SignerOptions - if c.cfg.Okta.Client.PrivateKeyId != "" { - signerOptions = (&jose.SignerOptions{}).WithHeader("kid", c.cfg.Okta.Client.PrivateKeyId) - } - - c.cfg.PrivateKeySigner, err = jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: parsedKey}, signerOptions) - if err != nil { - return nil, err - } - } - - claims := ClientAssertionClaims{ - Subject: c.cfg.Okta.Client.ClientId, - IssuedAt: jwt.NewNumericDate(time.Now()), - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))), - Issuer: c.cfg.Okta.Client.ClientId, - Audience: c.cfg.Okta.Client.OrgUrl + "/oauth2/v1/token", - } - - jwtBuilder := jwt.Signed(c.cfg.PrivateKeySigner).Claims(claims) - clientAssertion, err := jwtBuilder.CompactSerialize() - if err != nil { - return nil, err - } - - var tokenRequestBuff io.ReadWriter - query := url.Values{} - tokenRequestURL := c.cfg.Okta.Client.OrgUrl + "/oauth2/v1/token" - - query.Add("grant_type", "client_credentials") - query.Add("scope", strings.Join(c.cfg.Okta.Client.Scopes, " ")) - query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") - query.Add("client_assertion", clientAssertion) - tokenRequestURL += "?" + query.Encode() - tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff) - if err != nil { - return nil, err - } - - tokenRequest.Header.Add("Accept", "application/json") - tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - bOff := &oktaBackoff{ - ctx: context.TODO(), - maxRetries: c.cfg.Okta.Client.RateLimit.MaxRetries, - backoffDuration: time.Duration(c.cfg.Okta.Client.RateLimit.MaxBackoff), - } - var tokenResponse *http.Response - operation := func() error { - tokenResponse, err = c.cfg.HTTPClient.Do(tokenRequest) - bOff.retryCount++ - return err - } - err = backoff.Retry(operation, bOff) - if err != nil { - return nil, err - } - - respBody, err := io.ReadAll(tokenResponse.Body) - if err != nil { - return nil, err - } - origResp := io.NopCloser(bytes.NewBuffer(respBody)) - tokenResponse.Body = origResp - var accessToken *RequestAccessToken - - _, err = buildResponse(tokenResponse, nil, &accessToken) - if err != nil { - return nil, err - } - localVarRequest.Header.Set("Authorization", "Bearer "+accessToken.AccessToken) - - expiration := accessToken.ExpiresIn - 2 - c.tokenCache.Set(AccessTokenCacheKey, accessToken.AccessToken, time.Second*time.Duration(expiration)) - } + err = auth.Authorize() + if err != nil { + return nil, err } for header, value := range c.cfg.DefaultHeader { diff --git a/.generator/templates/configuration.mustache b/.generator/templates/configuration.mustache index 97beb7adc..13ec3b729 100644 --- a/.generator/templates/configuration.mustache +++ b/.generator/templates/configuration.mustache @@ -149,6 +149,7 @@ type Configuration struct { Token string `yaml:"token" envconfig:"OKTA_CLIENT_TOKEN"` AuthorizationMode string `yaml:"authorizationMode" envconfig:"OKTA_CLIENT_AUTHORIZATIONMODE"` ClientId string `yaml:"clientId" envconfig:"OKTA_CLIENT_CLIENTID"` + ClientAssertion string `yaml:"clientAssertion" envconfig:"OKTA_CLIENT_CLIENTASSERTION"` Scopes []string `yaml:"scopes" envconfig:"OKTA_CLIENT_SCOPES"` PrivateKey string `yaml:"privateKey" envconfig:"OKTA_CLIENT_PRIVATEKEY"` PrivateKeyId string `yaml:"privateKeyId" envconfig:"OKTA_CLIENT_PRIVATEKEYID"` diff --git a/.generator/templates/private_key_test.go b/.generator/templates/private_key_test.go index 29cc1e7e7..a4c11a065 100644 --- a/.generator/templates/private_key_test.go +++ b/.generator/templates/private_key_test.go @@ -21,3 +21,22 @@ func Test_Private_Key_Request_Can_Create_User(t *testing.T) { require.NoError(t, err, "Creating a new user should not error") assert.NotNil(t, user, "User should not be nil") } + +func Test_JWT_Request_Can_Create_User(t *testing.T) { + if os.Getenv("OKTA_TRAVIS_CI") != "yes" { + t.Skip("Skipping testing not in CI environment") + } + configuration := NewConfiguration(WithAuthorizationMode("JWT"), WithScopes([]string{"okta.users.manage"})) + privateKeySigner, err := createKeySigner(configuration.Okta.Client.PrivateKey, configuration.Okta.Client.PrivateKeyId) + require.NoError(t, err) + clientAssertion, err := createClientAssertion(configuration.Okta.Client.OrgUrl, configuration.Okta.Client.ClientId, privateKeySigner) + require.NoError(t, err) + configuration.Okta.Client.ClientAssertion = clientAssertion + client := NewAPIClient(configuration) + uc := testFactory.NewValidTestUserCredentialsWithPassword() + profile := testFactory.NewValidTestUserProfile() + body := CreateUserRequest{Credentials: uc, Profile: profile} + user, _, err := client.UserApi.CreateUser(apiClient.cfg.Context).Body(body).Execute() + require.NoError(t, err, "Creating a new user should not error") + assert.NotNil(t, user, "User should not be nil") +} diff --git a/README.md b/README.md index f37a3a3c3..098a65a07 100644 --- a/README.md +++ b/README.md @@ -859,8 +859,9 @@ The client is configured with a configuration setter object passed to the `NewCl | WithRequestTimeout(requestTimeout int64) | HTTP request time out in seconds | | WithRateLimitMaxRetries(maxRetries int32) | Number of request retries when http request times out | | WithRateLimitMaxBackOff(maxBackoff int64) | Max amount of time to wait on request back off | -| WithAuthorizationMode(authzMode string) | Okta API auth mode, `SSWS` (Okta based) or `PrivateKey` (OAuth app based) | +| WithAuthorizationMode(authzMode string) | Okta API auth mode, `SSWS` (Okta based), `PrivateKey` (OAuth app based) or `JWT` (OAuth app based) | | WithClientId(clientId string) | Okta App client id, used with `PrivateKey` OAuth auth mode | +| WithClientAssertion(clientAssertion string) | Okta App client assertion, used with `JWT` OAuth auth mode | | WithScopes(scopes []string) | Okta API app scopes | | WithPrivateKey(privateKey string) | Private key value | | WithPrivateKeyId(privateKeyId string) | Private key id (kid) value | @@ -1055,6 +1056,43 @@ ctx, client, err := okta.NewClient(ctx, ``` +### OAuth 2.0 With JWT Key +Okta allows you to interact with Okta APIs using scoped OAuth 2.0 access +tokens. Each access token enables the bearer to perform specific actions on +specific Okta endpoints, with that ability controlled by which scopes the +access token contains. + +Access Tokens are always cached and respect the `expires_in` value of an access +token response. + +This SDK supports this feature only for service-to-service applications. Check +out [our +guides](https://developer.okta.com/docs/guides/implement-oauth-for-okta/overview/) +to learn more about how to register a new service application using a private +and public key pair. Otherwise, follow the example steps at the end of this +topic. + +When using this approach you won't need an API Token because the SDK will +request an access token for you. In order to use OAuth 2.0, construct a client +instance by passing the following parameters: + +```go +ctx := context.TODO() +ctx, client, err := okta.NewClient(ctx, + okta.WithOrgUrl("https://{yourOktaDomain}"), + okta.WithAuthorizationMode("JWT"), + okta.WithClientAssertion("{{clientAssertion}}"), + okta.WithScopes(([]string{"okta.users.manage"})), +) +if err != nil { + fmt.Printf("Error: %v\n", err) +} + +fmt.Printf("Context: %+v\n Client: %+v\n\n",ctx, client) +``` + +This is very similar to PrivateKey Authorization Mode with a caveat, instead of providing public/privatekey pair, you can use a pre-signed JWT instead + ### OAuth 2.0 With Bearer Token Okta SDK supports authorization using a `Bearer` token. A bearer token is an diff --git a/okta/config.go b/okta/config.go index b44739b9f..9567a698b 100644 --- a/okta/config.go +++ b/okta/config.go @@ -55,6 +55,7 @@ type config struct { Scopes []string `yaml:"scopes" envconfig:"OKTA_CLIENT_SCOPES"` PrivateKey string `yaml:"privateKey" envconfig:"OKTA_CLIENT_PRIVATEKEY"` PrivateKeyId string `yaml:"privateKeyId" envconfig:"OKTA_CLIENT_PRIVATEKEYID"` + ClientAssertion string `yaml:"clientAssertion" envconfig:"OKTA_CLIENT_CLIENTASSERTION"` } `yaml:"client"` Testing struct { DisableHttpsCheck bool `yaml:"disableHttpsCheck" envconfig:"OKTA_TESTING_DISABLE_HTTPS_CHECK"` @@ -189,6 +190,12 @@ func WithClientId(clientId string) ConfigSetter { } } +func WithClientAssertion(clientAssertion string) ConfigSetter { + return func(c *config) { + c.Okta.Client.ClientAssertion = clientAssertion + } +} + func WithScopes(scopes []string) ConfigSetter { return func(c *config) { c.Okta.Client.Scopes = scopes diff --git a/okta/okta.go b/okta/okta.go index 3d7862ff8..1d260ac44 100644 --- a/okta/okta.go +++ b/okta/okta.go @@ -150,6 +150,19 @@ func (c *Client) GetConfig() *config { return c.config } +func (c *Client) SetConfig(conf ...ConfigSetter) (err error) { + config := c.config + for _, confSetter := range conf { + confSetter(config) + } + _, err = validateConfig(config) + if err != nil { + return + } + c.config = config + return +} + // GetRequestExecutor returns underlying request executor // Deprecated: please use CloneRequestExecutor() to avoid race conditions func (c *Client) GetRequestExecutor() *RequestExecutor { diff --git a/okta/requestExecutor.go b/okta/requestExecutor.go index 70568f80e..230566b0a 100644 --- a/okta/requestExecutor.go +++ b/okta/requestExecutor.go @@ -71,6 +71,254 @@ type RequestAccessToken struct { Scope string `json:"scope,omitempty"` } +type Authorization interface { + Authorize() error +} + +type SSWSAuth struct { + token string + req *http.Request +} + +func NewSSWSAuth(token string, req *http.Request) *SSWSAuth { + return &SSWSAuth{token: token, req: req} +} + +func (a *SSWSAuth) Authorize() error { + a.req.Header.Add("Authorization", "SSWS "+a.token) + return nil +} + +type BearerAuth struct { + token string + req *http.Request +} + +func NewBearerAuth(token string, req *http.Request) *BearerAuth { + return &BearerAuth{token: token, req: req} +} + +func (a *BearerAuth) Authorize() error { + a.req.Header.Add("Authorization", "Bearer "+a.token) + return nil +} + +type PrivateKeyAuth struct { + tokenCache *goCache.Cache + httpClient *http.Client + privateKeySigner jose.Signer + privateKey string + privateKeyId string + clientId string + orgURL string + scopes []string + maxRetries int32 + maxBackoff int64 + req *http.Request +} + +type PrivateKeyAuthConfig struct { + TokenCache *goCache.Cache + HttpClient *http.Client + PrivateKeySigner jose.Signer + PrivateKey string + PrivateKeyId string + ClientId string + OrgURL string + Scopes []string + MaxRetries int32 + MaxBackoff int64 + Req *http.Request +} + +func NewPrivateKeyAuth(config PrivateKeyAuthConfig) *PrivateKeyAuth { + return &PrivateKeyAuth{ + tokenCache: config.TokenCache, + httpClient: config.HttpClient, + privateKeySigner: config.PrivateKeySigner, + privateKey: config.PrivateKey, + privateKeyId: config.PrivateKeyId, + clientId: config.ClientId, + orgURL: config.OrgURL, + scopes: config.Scopes, + maxRetries: config.MaxRetries, + maxBackoff: config.MaxBackoff, + req: config.Req, + } +} + +func (a *PrivateKeyAuth) Authorize() error { + accessToken, hasToken := a.tokenCache.Get(AccessTokenCacheKey) + if hasToken { + a.req.Header.Add("Authorization", "Bearer "+accessToken.(string)) + } else { + if a.privateKeySigner == nil { + var err error + a.privateKeySigner, err = CreateKeySigner(a.privateKey, a.privateKeyId) + if err != nil { + return err + } + } + + clientAssertion, err := CreateClientAssertion(a.orgURL, a.clientId, a.privateKeySigner) + if err != nil { + return err + } + + accessToken, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.scopes, a.maxRetries, a.maxBackoff) + if err != nil { + return err + } + + a.req.Header.Add("Authorization", "Bearer "+accessToken.AccessToken) + + // Trim a couple of seconds off calculated expiry so cache expiry + // occures before Okta server side expiry. + expiration := accessToken.ExpiresIn - 2 + a.tokenCache.Set(AccessTokenCacheKey, accessToken.AccessToken, time.Second*time.Duration(expiration)) + } + return nil +} + +type JWTAuth struct { + tokenCache *goCache.Cache + httpClient *http.Client + orgURL string + scopes []string + clientAssertion string + maxRetries int32 + maxBackoff int64 + req *http.Request +} + +type JWTAuthConfig struct { + TokenCache *goCache.Cache + HttpClient *http.Client + OrgURL string + Scopes []string + ClientAssertion string + MaxRetries int32 + MaxBackoff int64 + Req *http.Request +} + +func NewJWTAuth(config JWTAuthConfig) *JWTAuth { + return &JWTAuth{ + tokenCache: config.TokenCache, + httpClient: config.HttpClient, + orgURL: config.OrgURL, + scopes: config.Scopes, + clientAssertion: config.ClientAssertion, + maxRetries: config.MaxRetries, + maxBackoff: config.MaxBackoff, + req: config.Req, + } +} + +func (a *JWTAuth) Authorize() error { + accessToken, hasToken := a.tokenCache.Get(AccessTokenCacheKey) + if hasToken { + a.req.Header.Add("Authorization", "Bearer "+accessToken.(string)) + } else { + accessToken, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.scopes, a.maxRetries, a.maxBackoff) + if err != nil { + return err + } + a.req.Header.Add("Authorization", "Bearer "+accessToken.AccessToken) + + // Trim a couple of seconds off calculated expiry so cache expiry + // occures before Okta server side expiry. + expiration := accessToken.ExpiresIn - 2 + a.tokenCache.Set(AccessTokenCacheKey, accessToken.AccessToken, time.Second*time.Duration(expiration)) + } + return nil +} + +func CreateKeySigner(privateKey, privateKeyID string) (jose.Signer, error) { + priv := []byte(strings.ReplaceAll(privateKey, `\n`, "\n")) + + privPem, _ := pem.Decode(priv) + if privPem == nil { + return nil, errors.New("invalid private key") + } + if privPem.Type != "RSA PRIVATE KEY" { + return nil, fmt.Errorf("RSA private key is of the wrong type") + } + + parsedKey, err := x509.ParsePKCS1PrivateKey(privPem.Bytes) + if err != nil { + return nil, err + } + + var signerOptions *jose.SignerOptions + if privateKeyID != "" { + signerOptions = (&jose.SignerOptions{}).WithHeader("kid", privateKeyID) + } + + return jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: parsedKey}, signerOptions) +} + +func CreateClientAssertion(orgURL, clientID string, privateKeySinger jose.Signer) (clientAssertion string, err error) { + claims := ClientAssertionClaims{ + Subject: clientID, + IssuedAt: jwt.NewNumericDate(time.Now()), + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))), + Issuer: clientID, + Audience: orgURL + "/oauth2/v1/token", + } + jwtBuilder := jwt.Signed(privateKeySinger).Claims(claims) + return jwtBuilder.CompactSerialize() +} + +func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion string, scopes []string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, error) { + var tokenRequestBuff io.ReadWriter + query := urlpkg.Values{} + tokenRequestURL := orgURL + "/oauth2/v1/token" + + query.Add("grant_type", "client_credentials") + query.Add("scope", strings.Join(scopes, " ")) + query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + query.Add("client_assertion", clientAssertion) + tokenRequestURL += "?" + query.Encode() + tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff) + if err != nil { + return nil, err + } + + tokenRequest.Header.Add("Accept", "application/json") + tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + bOff := &oktaBackoff{ + ctx: context.TODO(), + maxRetries: maxRetries, + backoffDuration: time.Duration(maxBackoff), + } + var tokenResponse *http.Response + operation := func() error { + tokenResponse, err = httpClient.Do(tokenRequest) + bOff.retryCount++ + return err + } + err = backoff.Retry(operation, bOff) + if err != nil { + return nil, err + } + + respBody, err := io.ReadAll(tokenResponse.Body) + if err != nil { + return nil, err + } + origResp := io.NopCloser(bytes.NewBuffer(respBody)) + tokenResponse.Body = origResp + var accessToken *RequestAccessToken + + _, err = buildResponse(tokenResponse, nil, &accessToken) + if err != nil { + return nil, err + } + return accessToken, nil +} + func NewRequestExecutor(httpClient *http.Client, cache cache.Cache, config *config) *RequestExecutor { re := RequestExecutor{ tokenCache: goCache.New(5*time.Minute, 10*time.Minute), @@ -121,114 +369,47 @@ func (re *RequestExecutor) NewRequest(method string, url string, body interface{ return nil, err } - if re.config.Okta.Client.AuthorizationMode == "SSWS" { - req.Header.Add("Authorization", "SSWS "+re.config.Okta.Client.Token) + var auth Authorization + + switch re.config.Okta.Client.AuthorizationMode { + case "SSWS": + auth = NewSSWSAuth(re.config.Okta.Client.Token, req) + case "Bearer": + auth = NewBearerAuth(re.config.Okta.Client.Token, req) + case "PrivateKey": + auth = NewPrivateKeyAuth(PrivateKeyAuthConfig{ + TokenCache: re.tokenCache, + HttpClient: re.httpClient, + PrivateKeySigner: re.config.PrivateKeySigner, + PrivateKey: re.config.Okta.Client.PrivateKey, + PrivateKeyId: re.config.Okta.Client.PrivateKeyId, + ClientId: re.config.Okta.Client.ClientId, + OrgURL: re.config.Okta.Client.OrgUrl, + Scopes: re.config.Okta.Client.Scopes, + MaxRetries: re.config.Okta.Client.RateLimit.MaxRetries, + MaxBackoff: re.config.Okta.Client.RateLimit.MaxBackoff, + Req: req, + }) + case "JWT": + auth = NewJWTAuth(JWTAuthConfig{ + TokenCache: re.tokenCache, + HttpClient: re.httpClient, + OrgURL: re.config.Okta.Client.OrgUrl, + Scopes: re.config.Okta.Client.Scopes, + ClientAssertion: re.config.Okta.Client.ClientAssertion, + MaxRetries: re.config.Okta.Client.RateLimit.MaxRetries, + MaxBackoff: re.config.Okta.Client.RateLimit.MaxBackoff, + Req: req, + }) + default: + return nil, fmt.Errorf("unknown authorization mode %v", re.config.Okta.Client.AuthorizationMode) } - if re.config.Okta.Client.AuthorizationMode == "Bearer" { - req.Header.Add("Authorization", "Bearer "+re.config.Okta.Client.Token) + err = auth.Authorize() + if err != nil { + return nil, err } - if re.config.Okta.Client.AuthorizationMode == "PrivateKey" { - // OAuth tokens are always cached in a dedicated cache regardless of - // what SDK cache manager the request executor is initialized with - accessToken, hasToken := re.tokenCache.Get(AccessTokenCacheKey) - if hasToken { - req.Header.Add("Authorization", "Bearer "+accessToken.(string)) - } else { - if re.config.PrivateKeySigner == nil { - priv := []byte(strings.ReplaceAll(re.config.Okta.Client.PrivateKey, `\n`, "\n")) - - privPem, _ := pem.Decode(priv) - if privPem == nil { - return nil, errors.New("invalid private key") - } - if privPem.Type != "RSA PRIVATE KEY" { - return nil, fmt.Errorf("RSA private key is of the wrong type") - } - - parsedKey, err := x509.ParsePKCS1PrivateKey(privPem.Bytes) - if err != nil { - return nil, err - } - - var signerOptions *jose.SignerOptions - if re.config.Okta.Client.PrivateKeyId != "" { - signerOptions = (&jose.SignerOptions{}).WithHeader("kid", re.config.Okta.Client.PrivateKeyId) - } - - re.config.PrivateKeySigner, err = jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: parsedKey}, signerOptions) - if err != nil { - return nil, err - } - } - - claims := ClientAssertionClaims{ - Subject: re.config.Okta.Client.ClientId, - IssuedAt: jwt.NewNumericDate(time.Now()), - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))), - Issuer: re.config.Okta.Client.ClientId, - Audience: re.config.Okta.Client.OrgUrl + "/oauth2/v1/token", - } - jwtBuilder := jwt.Signed(re.config.PrivateKeySigner).Claims(claims) - clientAssertion, err := jwtBuilder.CompactSerialize() - if err != nil { - return nil, err - } - - var tokenRequestBuff io.ReadWriter - query := urlpkg.Values{} - tokenRequestURL := re.config.Okta.Client.OrgUrl + "/oauth2/v1/token" - - query.Add("grant_type", "client_credentials") - query.Add("scope", strings.Join(re.config.Okta.Client.Scopes, " ")) - query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") - query.Add("client_assertion", clientAssertion) - tokenRequestURL += "?" + query.Encode() - tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff) - if err != nil { - return nil, err - } - - tokenRequest.Header.Add("Accept", "application/json") - tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - bOff := &oktaBackoff{ - ctx: context.TODO(), - maxRetries: re.config.Okta.Client.RateLimit.MaxRetries, - backoffDuration: time.Duration(re.config.Okta.Client.RateLimit.MaxBackoff), - } - var tokenResponse *http.Response - operation := func() error { - tokenResponse, err = re.httpClient.Do(tokenRequest) - bOff.retryCount++ - return err - } - err = backoff.Retry(operation, bOff) - if err != nil { - return nil, err - } - - respBody, err := io.ReadAll(tokenResponse.Body) - if err != nil { - return nil, err - } - origResp := io.NopCloser(bytes.NewBuffer(respBody)) - tokenResponse.Body = origResp - var accessToken *RequestAccessToken - - _, err = buildResponse(tokenResponse, nil, &accessToken) - if err != nil { - return nil, err - } - req.Header.Add("Authorization", "Bearer "+accessToken.AccessToken) - - // Trim a couple of seconds off calculated expiry so cache expiry - // occures before Okta server side expiry. - expiration := accessToken.ExpiresIn - 2 - re.tokenCache.Set(AccessTokenCacheKey, accessToken.AccessToken, time.Second*time.Duration(expiration)) - } - } req.Header.Add("User-Agent", NewUserAgent(re.config).String()) req.Header.Add("Accept", re.headerAccept) @@ -240,7 +421,6 @@ func (re *RequestExecutor) NewRequest(method string, url string, body interface{ re.binary = false re.headerAccept = "application/json" re.headerContentType = "application/json" - return req, nil } diff --git a/okta/validator.go b/okta/validator.go index cd42ac2aa..b41cb15fb 100644 --- a/okta/validator.go +++ b/okta/validator.go @@ -67,8 +67,9 @@ func validateAPIToken(c *config) error { func validateAuthorization(c *config) error { if c.Okta.Client.AuthorizationMode != "SSWS" && c.Okta.Client.AuthorizationMode != "PrivateKey" && - c.Okta.Client.AuthorizationMode != "Bearer" { - return errors.New("the AuthorizaitonMode config option must be one of [SSWS, Bearer, PrivateKey]. You provided the SDK with " + c.Okta.Client.AuthorizationMode) + c.Okta.Client.AuthorizationMode != "Bearer" && + c.Okta.Client.AuthorizationMode != "JWT" { + return errors.New("the AuthorizaitonMode config option must be one of [SSWS, Bearer, PrivateKey, JWT]. You provided the SDK with " + c.Okta.Client.AuthorizationMode) } if c.Okta.Client.AuthorizationMode == "PrivateKey" && @@ -79,5 +80,9 @@ func validateAuthorization(c *config) error { return errors.New("when using AuthorizationMode 'PrivateKey', you must supply 'ClientId', 'Scopes', and 'PrivateKey' or 'PrivateKeySigner'") } + if c.Okta.Client.AuthorizationMode == "JWT" && (c.Okta.Client.Scopes == nil || c.Okta.Client.ClientAssertion == "") { + return errors.New("when using AuthorizationMode 'JWT', you must supply 'Scopes', 'ClientAssertion'") + } + return nil } diff --git a/openapi/generator/templates/okta.go.hbs b/openapi/generator/templates/okta.go.hbs index 4382491b5..dc4c2a7f4 100644 --- a/openapi/generator/templates/okta.go.hbs +++ b/openapi/generator/templates/okta.go.hbs @@ -84,6 +84,19 @@ func (c *Client) GetConfig() *config { return c.config } +func (c *Client) SetConfig(conf ...ConfigSetter) (err error) { + config := c.config + for _, confSetter := range conf { + confSetter(config) + } + _, err = validateConfig(config) + if err != nil { + return + } + c.config = config + return +} + // GetRequestExecutor returns underlying request executor // Deprecated: please use CloneRequestExecutor() to avoid race conditions func (c *Client) GetRequestExecutor() *RequestExecutor { diff --git a/tests/integration/request_test.go b/tests/integration/request_test.go index 7c809416e..2c52ecffb 100644 --- a/tests/integration/request_test.go +++ b/tests/integration/request_test.go @@ -40,6 +40,27 @@ func Test_private_key_request_contains_bearer_token(t *testing.T) { assert.Contains(t, request.Header.Get("Authorization"), "Bearer", "does not contain a bearer token for the request") } +func Test_jwt_request_contains_bearer_token(t *testing.T) { + var buff io.ReadWriter + + _, client, err := tests.NewClient(context.TODO()) + require.NoError(t, err) + + privateKeySigner, err := okta.CreateKeySigner(client.GetConfig().Okta.Client.PrivateKey, client.GetConfig().Okta.Client.PrivateKeyId) + require.NoError(t, err) + + clientAssertion, err := okta.CreateClientAssertion(client.GetConfig().Okta.Client.OrgUrl, client.GetConfig().Okta.Client.ClientId, privateKeySigner) + require.NoError(t, err) + + err = client.SetConfig(okta.WithAuthorizationMode("JWT"), okta.WithScopes(([]string{"okta.users.manage"})), okta.WithClientAssertion(clientAssertion)) + require.NoError(t, err) + + request, err := client.CloneRequestExecutor().NewRequest("GET", "https://example.com/", buff) + require.NoError(t, err) + + assert.Contains(t, request.Header.Get("Authorization"), "Bearer", "does not contain a bearer token for the request") +} + func Test_private_key_request_can_create_a_user(t *testing.T) { ctx, client, err := tests.NewClient(context.TODO(), okta.WithAuthorizationMode("PrivateKey"), okta.WithScopes(([]string{"okta.users.manage"}))) require.NoError(t, err)