Skip to content

Commit

Permalink
Refactor code + tests for cleanup, fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rharpavat committed Aug 25, 2023
1 parent 6d679d2 commit 91f1098
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 128 deletions.
53 changes: 53 additions & 0 deletions pkg/pop/authnscheme.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package pop

import (
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"
"time"

"github.com/google/uuid"
)

type PopAuthenticationScheme struct {
// host is the u claim we will add on the pop token
Host string
PoPKey PoPKey
}

func (as *PopAuthenticationScheme) TokenRequestParams() map[string]string {
return map[string]string{
"token_type": popTokenType,
"req_cnf": as.PoPKey.ReqCnf(),
}
}

func (as *PopAuthenticationScheme) KeyID() string {
return as.PoPKey.KeyID()
}

func (as *PopAuthenticationScheme) FormatAccessToken(accessToken string) (string, error) {
ts := time.Now().Unix()
nonce := uuid.New().String()
nonce = strings.ReplaceAll(nonce, "-", "")
header := fmt.Sprintf(`{"typ":"%s","alg":"%s","kid":"%s"}`, popTokenType, as.PoPKey.Alg(), as.PoPKey.KeyID())
headerB64 := base64.RawURLEncoding.EncodeToString([]byte(header))
payload := fmt.Sprintf(`{"at":"%s","ts":%d,"u":"%s","cnf":{"jwk":%s},"nonce":"%s"}`, accessToken, ts, as.Host, as.PoPKey.JWK(), nonce)
payloadB64 := base64.RawURLEncoding.EncodeToString([]byte(payload))
h256 := sha256.Sum256([]byte(headerB64 + "." + payloadB64))
signature, err := as.PoPKey.Sign(h256[:])
if err != nil {
return "", err
}
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)

return headerB64 + "." + payloadB64 + "." + signatureB64, nil
}

func (as *PopAuthenticationScheme) AccessTokenType() string {
return popTokenType
}
8 changes: 5 additions & 3 deletions pkg/pop/poptokenutils.go → pkg/pop/msal.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
)

func AcquirePoPTokenInteractive(
context context.Context,
popClaims map[string]string,
scopes []string,
authority,
Expand All @@ -22,7 +23,7 @@ func AcquirePoPTokenInteractive(
return "", -1, fmt.Errorf("unable to create public client: %w", err)
}
result, err := client.AcquireTokenInteractive(
context.Background(),
context,
scopes,
public.WithAuthenticationScheme(
&PopAuthenticationScheme{
Expand All @@ -39,6 +40,7 @@ func AcquirePoPTokenInteractive(
}

func AcquirePoPTokenConfidential(
context context.Context,
popClaims map[string]string,
scopes []string,
cred confidential.Credential,
Expand All @@ -59,14 +61,14 @@ func AcquirePoPTokenConfidential(
return "", -1, fmt.Errorf("unable to create confidential client: %w", err)
}
result, err := client.AcquireTokenSilent(
context.Background(),
context,
scopes,
confidential.WithAuthenticationScheme(authnScheme),
confidential.WithTenantID(tenantID),
)
if err != nil {
result, err = client.AcquireTokenByCredential(
context.Background(),
context,
scopes,
confidential.WithAuthenticationScheme(authnScheme),
confidential.WithTenantID(tenantID),
Expand Down
42 changes: 0 additions & 42 deletions pkg/pop/poptoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,54 +12,12 @@ import (
"fmt"
"log"
"math/big"
"strings"
"sync"
"time"

"github.com/google/uuid"
)

const popTokenType = "pop"

type PopAuthenticationScheme struct {
// host is the u claim we will add on the pop token
Host string
PoPKey PoPKey
}

func (as *PopAuthenticationScheme) TokenRequestParams() map[string]string {
return map[string]string{
"token_type": popTokenType,
"req_cnf": as.PoPKey.ReqCnf(),
}
}

func (as *PopAuthenticationScheme) KeyID() string {
return as.PoPKey.KeyID()
}

func (as *PopAuthenticationScheme) FormatAccessToken(accessToken string) (string, error) {
ts := time.Now().Unix()
nonce := uuid.New().String()
nonce = strings.ReplaceAll(nonce, "-", "")
header := fmt.Sprintf(`{"typ":"%s","alg":"%s","kid":"%s"}`, popTokenType, as.PoPKey.Alg(), as.PoPKey.KeyID())
headerB64 := base64.RawURLEncoding.EncodeToString([]byte(header))
payload := fmt.Sprintf(`{"at":"%s","ts":%d,"u":"%s","cnf":{"jwk":%s},"nonce":"%s"}`, accessToken, ts, as.Host, as.PoPKey.JWK(), nonce)
payloadB64 := base64.RawURLEncoding.EncodeToString([]byte(payload))
h256 := sha256.Sum256([]byte(headerB64 + "." + payloadB64))
signature, err := as.PoPKey.Sign(h256[:])
if err != nil {
return "", err
}
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)

return headerB64 + "." + payloadB64 + "." + signatureB64, nil
}

func (as *PopAuthenticationScheme) AccessTokenType() string {
return popTokenType
}

// PoPKey - generic interface for PoP key properties and methods
type PoPKey interface {
// encryption/signature algo
Expand Down
1 change: 1 addition & 0 deletions pkg/token/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (p *InteractiveToken) Token() (adal.Token, error) {
// If PoP token support is enabled and the correct u-claim is provided, use the MSAL
// token provider to acquire a new token
token, expirationTimeUnix, err = pop.AcquirePoPTokenInteractive(
context.Background(),
p.popClaims,
scopes,
p.oAuthConfig.AuthorityEndpoint.Host,
Expand Down
4 changes: 2 additions & 2 deletions pkg/token/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (o *Options) AddFlags(fs *pflag.FlagSet) {
fs.BoolVar(&o.UseAzureRMTerraformEnv, "use-azurerm-env-vars", o.UseAzureRMTerraformEnv,
"Use environment variable names of Terraform Azure Provider (ARM_CLIENT_ID, ARM_CLIENT_SECRET, ARM_CLIENT_CERTIFICATE_PATH, ARM_CLIENT_CERTIFICATE_PASSWORD, ARM_TENANT_ID)")
fs.BoolVar(&o.IsPoPTokenEnabled, "pop-enabled", o.IsPoPTokenEnabled, "set to true to use a PoP token for authentication or false to use a regular bearer token")
fs.StringVar(&o.PoPTokenClaims, "pop-claims", o.PoPTokenClaims, "contains a comma-separated list of claims to attach to the pop token in the format `key=val,key2=val2`. At minimum, specify the ARM ID of the connected cluster as `u=ARM_ID`")
fs.StringVar(&o.PoPTokenClaims, "pop-claims", o.PoPTokenClaims, "contains a comma-separated list of claims to attach to the pop token in the format `key=val,key2=val2`. At minimum, specify the ARM ID of the cluster as `u=ARM_ID`")
}

func (o *Options) Validate() error {
Expand Down Expand Up @@ -272,7 +272,7 @@ func parsePopClaims(popClaims string) (map[string]string, error) {
claimsMap[key] = val
}
if claimsMap["u"] == "" {
return nil, fmt.Errorf("required u-claim not provided for PoP token flow. Please provide the ARM ID of the connected cluster in the format `u=<ARM_ID>`")
return nil, fmt.Errorf("required u-claim not provided for PoP token flow. Please provide the ARM ID of the cluster in the format `u=<ARM_ID>`")
}
return claimsMap, nil
}
111 changes: 71 additions & 40 deletions pkg/token/options_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package token

import (
"fmt"
"path/filepath"
"strings"
"testing"
Expand Down Expand Up @@ -174,45 +175,75 @@ func TestOptionsWithEnvVars(t *testing.T) {
}

func TestParsePoPClaims(t *testing.T) {
t.Run("pop-claim parsing should fail on empty string", func(t *testing.T) {
popClaims := ""
if _, err := parsePopClaims(popClaims); err == nil || !strings.Contains(err.Error(), "no claims provided") {
t.Fatalf("parsing pop claims should return error if claims is an empty string. got: %s", err)
}
})

t.Run("pop-claim parsing should fail on whitespace-only string", func(t *testing.T) {
popClaims := " "
if _, err := parsePopClaims(popClaims); err == nil || !strings.Contains(err.Error(), "no claims provided") {
t.Fatalf("parsing pop claims should return error if claims is whitespace-only. got: %s", err)
}
})

t.Run("pop-claim parsing should fail if claims are not provided in key=value format", func(t *testing.T) {
popClaims := "claim1=val1,claim2"
if _, err := parsePopClaims(popClaims); err == nil || !strings.Contains(err.Error(), "Ensure the claims are formatted as `key=value`") {
t.Fatalf("parsing pop claims should return error if claims are not provided in key=value format. got: %s", err)
}
})

t.Run("pop-claim parsing should fail if claims are malformed", func(t *testing.T) {
popClaims := "claim1= "
if _, err := parsePopClaims(popClaims); err == nil || !strings.Contains(err.Error(), "Ensure the claims are formatted as `key=value`") {
t.Fatalf("parsing pop claims should return error if claims are malformed. got: %s", err)
}
})

t.Run("pop-claim parsing should fail if u-claim is not provided", func(t *testing.T) {
popClaims := "claim1=val1, claim2=val2"
if _, err := parsePopClaims(popClaims); err == nil || !strings.Contains(err.Error(), "required u-claim not provided") {
t.Fatalf("parsing pop claims should return error if u-claim is not provided. got: %s", err)
}
})
testCases := []struct {
name string
popClaims string
expectedError error
expectedClaims map[string]string
}{
{
name: "pop-claim parsing should fail on empty string",
popClaims: "",
expectedError: fmt.Errorf("error parsing PoP token claims: no claims provided"),
expectedClaims: nil,
},
{
name: "pop-claim parsing should fail on whitespace-only string",
popClaims: " ",
expectedError: fmt.Errorf("error parsing PoP token claims: no claims provided"),
expectedClaims: nil,
},
{
name: "pop-claim parsing should fail if claims are not provided in key=value format",
popClaims: "claim1=val1,claim2",
expectedError: fmt.Errorf("error parsing PoP token claims. Ensure the claims are formatted as `key=value` with no extra whitespace"),
expectedClaims: nil,
},
{
name: "pop-claim parsing should fail if claims are malformed",
popClaims: "claim1= ",
expectedError: fmt.Errorf("error parsing PoP token claims. Ensure the claims are formatted as `key=value` with no extra whitespace"),
expectedClaims: nil,
},
{
name: "pop-claim parsing should fail if claims are malformed/commas only",
popClaims: ",,,,,,,,",
expectedError: fmt.Errorf("error parsing PoP token claims. Ensure the claims are formatted as `key=value` with no extra whitespace"),
expectedClaims: nil,
},
{
name: "pop-claim parsing should fail if u-claim is not provided",
popClaims: "1=2,3=4",
expectedError: fmt.Errorf("required u-claim not provided for PoP token flow. Please provide the ARM ID of the cluster in the format `u=<ARM_ID>`"),
expectedClaims: nil,
},
{
name: "pop-claim parsing should succeed with u-claim and additional claims",
popClaims: "u=val1, claim2=val2, claim3=val3",
expectedError: nil,
expectedClaims: map[string]string{
"u": "val1",
"claim2": "val2",
"claim3": "val3",
},
},
}

t.Run("pop-claim parsing should succeed with u-claim and additional claims", func(t *testing.T) {
popClaims := "u=val1, claim2=val2, claim3=val3"
if _, err := parsePopClaims(popClaims); err != nil {
t.Fatalf("parsing pop claims should return successfully on valid claims. got: %s", err)
}
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
claimsMap, err := parsePopClaims(tc.popClaims)
if err != nil {
if !ErrorContains(err, tc.expectedError.Error()) {
t.Fatalf("expected error: %+v, got error: %+v", tc.expectedError, err)
}
} else {
if err != tc.expectedError {
t.Fatalf("expected error: %+v, got error: %+v", tc.expectedError, err)
}
}
if !cmp.Equal(claimsMap, tc.expectedClaims) {
t.Fatalf("expected claims map to be %s, got map: %s", tc.expectedClaims, claimsMap)
}
})
}
}
21 changes: 12 additions & 9 deletions pkg/token/serviceprincipaltoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ func (p *servicePrincipalToken) TokenWithOptions(options *azcore.ClientOptions)

// Request a new Azure token provider for service principal
if p.clientSecret != "" {
accessToken, expirationTimeUnix, err = p.getTokenWithClientSecret(options, scopes)
accessToken, expirationTimeUnix, err = p.getTokenWithClientSecret(options, context.Background(), scopes)
if err != nil {
return emptyToken, fmt.Errorf("failed to create service principal token using secret: %w", err)
}
} else if p.clientCert != "" {
accessToken, expirationTimeUnix, err = p.getTokenWithClientCert(options, scopes)
accessToken, expirationTimeUnix, err = p.getTokenWithClientCert(options, context.Background(), scopes)
if err != nil {
return emptyToken, fmt.Errorf("failed to create service principal token using certificate: %w", err)
}
Expand All @@ -110,10 +110,10 @@ func (p *servicePrincipalToken) TokenWithOptions(options *azcore.ClientOptions)
}, nil
}

func (p *servicePrincipalToken) getTokenWithClientSecret(options *azcore.ClientOptions, scopes []string) (string, int64, error) {
func (p *servicePrincipalToken) getTokenWithClientSecret(options *azcore.ClientOptions, context context.Context, scopes []string) (string, int64, error) {
if p.popClaims != nil && len(p.popClaims) > 0 {
// if PoP token support is enabled, use the PoP token flow to request the token
return p.getPoPTokenWithClientSecret(scopes)
return p.getPoPTokenWithClientSecret(context, scopes)
}

clientOptions := &azidentity.ClientSecretCredentialOptions{
Expand All @@ -135,21 +135,22 @@ func (p *servicePrincipalToken) getTokenWithClientSecret(options *azcore.ClientO
}

// Use the token provider to get a new token
spnAccessToken, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: scopes})
spnAccessToken, err := cred.GetToken(context, policy.TokenRequestOptions{Scopes: scopes})
if err != nil {
return "", -1, fmt.Errorf("failed to create service principal bearer token using secret: %w", err)
}

return spnAccessToken.Token, spnAccessToken.ExpiresOn.Unix(), nil
}

func (p *servicePrincipalToken) getPoPTokenWithClientSecret(scopes []string) (string, int64, error) {
func (p *servicePrincipalToken) getPoPTokenWithClientSecret(context context.Context, scopes []string) (string, int64, error) {
cred, err := confidential.NewCredFromSecret(p.clientSecret)
if err != nil {
return "", -1, fmt.Errorf("unable to create credential. Received: %w", err)
}

accessToken, expiresOn, err := pop.AcquirePoPTokenConfidential(
context,
p.popClaims,
scopes,
cred,
Expand All @@ -164,7 +165,7 @@ func (p *servicePrincipalToken) getPoPTokenWithClientSecret(scopes []string) (st
return accessToken, expiresOn, nil
}

func (p *servicePrincipalToken) getTokenWithClientCert(options *azcore.ClientOptions, scopes []string) (string, int64, error) {
func (p *servicePrincipalToken) getTokenWithClientCert(options *azcore.ClientOptions, context context.Context, scopes []string) (string, int64, error) {
clientOptions := &azidentity.ClientCertificateCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: p.cloud,
Expand All @@ -188,7 +189,7 @@ func (p *servicePrincipalToken) getTokenWithClientCert(options *azcore.ClientOpt
certArray := []*x509.Certificate{cert}
if p.popClaims != nil && len(p.popClaims) > 0 {
// if PoP token support is enabled, use the PoP token flow to request the token
return p.getPoPTokenWithClientCert(scopes, certArray, rsaPrivateKey)
return p.getPoPTokenWithClientCert(context, scopes, certArray, rsaPrivateKey)
}

cred, err := azidentity.NewClientCertificateCredential(
Expand All @@ -201,7 +202,7 @@ func (p *servicePrincipalToken) getTokenWithClientCert(options *azcore.ClientOpt
if err != nil {
return "", -1, fmt.Errorf("unable to create credential. Received: %v", err)
}
spnAccessToken, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{p.resourceID + "/.default"}})
spnAccessToken, err := cred.GetToken(context, policy.TokenRequestOptions{Scopes: scopes})
if err != nil {
return "", -1, fmt.Errorf("failed to create service principal token using cert: %s", err)
}
Expand All @@ -210,6 +211,7 @@ func (p *servicePrincipalToken) getTokenWithClientCert(options *azcore.ClientOpt
}

func (p *servicePrincipalToken) getPoPTokenWithClientCert(
context context.Context,
scopes []string,
certArray []*x509.Certificate,
rsaPrivateKey *rsa.PrivateKey,
Expand All @@ -220,6 +222,7 @@ func (p *servicePrincipalToken) getPoPTokenWithClientCert(
}

accessToken, expiresOn, err := pop.AcquirePoPTokenConfidential(
context,
p.popClaims,
scopes,
cred,
Expand Down
Loading

0 comments on commit 91f1098

Please sign in to comment.