Skip to content

Commit

Permalink
Remove global state for audience marshalling
Browse files Browse the repository at this point in the history
  • Loading branch information
hatstand committed Jul 30, 2024
1 parent 62e504c commit d7e0970
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 44 deletions.
2 changes: 1 addition & 1 deletion claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ type Claims interface {
GetNotBefore() (*NumericDate, error)
GetIssuer() (string, error)
GetSubject() (string, error)
GetAudience() (ClaimStrings, error)
GetAudience() (*ClaimStrings, error)
}
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func ExampleNewWithClaims_customClaimsType() {
Issuer: "test",
Subject: "somebody",
ID: "1",
Audience: []string{"somebody_else"},
Audience: jwt.NewClaimStrings([]string{"somebody_else"}),
},
}

Expand Down
6 changes: 3 additions & 3 deletions map_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
}

// GetAudience implements the Claims interface.
func (m MapClaims) GetAudience() (ClaimStrings, error) {
func (m MapClaims) GetAudience() (*ClaimStrings, error) {
return m.parseClaimsString("aud")
}

Expand Down Expand Up @@ -66,7 +66,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {

// parseClaimsString tries to parse a key in the map claims type as a
// [ClaimsStrings] type, which can either be a string or an array of string.
func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
func (m MapClaims) parseClaimsString(key string) (*ClaimStrings, error) {
var cs []string
switch v := m[key].(type) {
case string:
Expand All @@ -83,7 +83,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
}
}

return cs, nil
return NewClaimStrings(cs), nil
}

// parseString tries to parse a key in the map claims type as a [string] type.
Expand Down
63 changes: 55 additions & 8 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"reflect"
"slices"

Check failure on line 10 in parser_test.go

View workflow job for this annotation

GitHub Actions / build (1.20)

package slices is not in GOROOT (/opt/hostedtoolcache/go/1.20.14/x64/src/slices)
"testing"
"time"

Expand Down Expand Up @@ -360,7 +361,7 @@ var jwtTestData = []struct {
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test"},
Audience: jwt.NewClaimStrings([]string{"test"}),
},
true,
nil,
Expand All @@ -372,7 +373,7 @@ var jwtTestData = []struct {
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test", "test"},
Audience: jwt.NewClaimStrings([]string{"test", "test"}),
},
true,
nil,
Expand All @@ -384,7 +385,7 @@ var jwtTestData = []struct {
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
},
false,
[]error{jwt.ErrTokenMalformed},
Expand All @@ -396,7 +397,7 @@ var jwtTestData = []struct {
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
},
false,
[]error{jwt.ErrTokenMalformed},
Expand Down Expand Up @@ -449,6 +450,50 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string {
return test.MakeSampleToken(claims, signingMethod, privateKey)
}

func claimsEqual(a, b jwt.Claims) error {
aExp, aErr := a.GetExpirationTime()
bExp, bErr := b.GetExpirationTime()
if !reflect.DeepEqual(aExp, bExp) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `exp`: expected %v vs. %v", aExp, bExp)
}

aIat, aErr := a.GetIssuedAt()
bIat, bErr := b.GetIssuedAt()
if !reflect.DeepEqual(aIat, bIat) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iat`: expected %v vs. %v", aIat, bIat)
}

aNbf, aErr := a.GetNotBefore()
bNbf, bErr := b.GetNotBefore()
if !reflect.DeepEqual(aNbf, bNbf) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `nbf`: expected %v vs. %v", aNbf, bNbf)
}

aIss, aErr := a.GetIssuer()
bIss, bErr := b.GetIssuer()
if aIss != bIss || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iss`: expected %v vs. %v", aIss, bIss)
}

aSub, aErr := a.GetSubject()
bSub, bErr := b.GetSubject()
if aSub != bSub || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `sub`: expected %v vs. %v", aSub, bSub)
}

aAud, aErr := a.GetAudience()
bAud, bErr := b.GetAudience()
if aAud != bAud {
if aAud == nil || bAud == nil {
return fmt.Errorf("mismatched `aud`: expected %v vs. %v", aAud, bAud)
}
if !slices.Equal(aAud.Claims(), bAud.Claims()) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `aud`: expected %v vs %v", aAud, bAud)
}
}
return nil
}

func TestParser_Parse(t *testing.T) {
// Iterate over test data set and run tests
for _, data := range jwtTestData {
Expand Down Expand Up @@ -476,8 +521,10 @@ func TestParser_Parse(t *testing.T) {
}

// Verify result matches expectation
if data.claims != nil && !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
if data.claims != nil {
if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v: %v", data.name, data.claims, token.Claims, err)
}
}

if data.valid && err != nil {
Expand Down Expand Up @@ -557,8 +604,8 @@ func TestParser_ParseUnverified(t *testing.T) {
}

// Verify result matches expectation
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v %v", data.name, data.claims, token.Claims, err)
}

if data.valid && err != nil {
Expand Down
6 changes: 3 additions & 3 deletions registered_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package jwt
//
// This type can be used on its own, but then additional private and
// public claims embedded in the JWT will not be parsed. The typical use-case
// therefore is to embedded this in a user-defined claim type.
// therefore is to embed this in a user-defined claim type.
//
// See examples for how to use this with your own claim types.
type RegisteredClaims struct {
Expand All @@ -17,7 +17,7 @@ type RegisteredClaims struct {
Subject string `json:"sub,omitempty"`

// the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
Audience ClaimStrings `json:"aud,omitempty"`
Audience *ClaimStrings `json:"aud,omitempty"`

// the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt *NumericDate `json:"exp,omitempty"`
Expand Down Expand Up @@ -48,7 +48,7 @@ func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
}

// GetAudience implements the Claims interface.
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
func (c RegisteredClaims) GetAudience() (*ClaimStrings, error) {
return c.Audience, nil
}

Expand Down
68 changes: 50 additions & 18 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ import (
// no fractional timestamps are generated.
var TimePrecision = time.Second

// MarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
var MarshalSingleStringAsArray = true

// NumericDate represents a JSON numeric date value, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-2.
type NumericDate struct {
Expand Down Expand Up @@ -100,10 +90,52 @@ func (date *NumericDate) UnmarshalJSON(b []byte) (err error) {
// ClaimStrings is basically just a slice of strings, but it can be either
// serialized from a string array or just a string. This type is necessary,
// since the "aud" claim can either be a single string or an array.
type ClaimStrings []string
type ClaimStrings struct {
claims []string
marshalSingleStringAsArray bool
}

type ClaimStringOption func(*ClaimStrings)

func NewClaimStrings(claims []string, opts ...ClaimStringOption) *ClaimStrings {
ret := ClaimStrings{
claims: claims,
marshalSingleStringAsArray: true,
}
for _, opt := range opts {
opt(&ret)
}
return &ret
}

// WithMarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
func WithMarshalSingleStringAsArray(marshalSingleStringAsArray bool) func(claims *ClaimStrings) {
return func(claims *ClaimStrings) {
claims.marshalSingleStringAsArray = marshalSingleStringAsArray
}
}

func (s *ClaimStrings) Len() int {
return len(s.claims)
}

func (s *ClaimStrings) Claims() []string {
return s.claims
}

func (s *ClaimStrings) String() string {
return fmt.Sprintf("%v", s.claims)
}

func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
var value interface{}
var value any

if err = json.Unmarshal(data, &value); err != nil {
return err
Expand All @@ -115,7 +147,7 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
case string:
aud = append(aud, v)
case []string:
aud = ClaimStrings(v)
aud = v
case []interface{}:
for _, vv := range v {
vs, ok := vv.(string)
Expand All @@ -130,20 +162,20 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
return ErrInvalidType
}

*s = aud
s.claims = aud

return
}

func (s ClaimStrings) MarshalJSON() (b []byte, err error) {
func (s *ClaimStrings) MarshalJSON() (b []byte, err error) {
// This handles a special case in the JWT RFC. If the string array, e.g.
// used by the "aud" field, only contains one element, it MAY be serialized
// as a single string. This may or may not be desired based on the ecosystem
// of other JWT library used, so we make it configurable by the variable
// MarshalSingleStringAsArray.
if len(s) == 1 && !MarshalSingleStringAsArray {
return json.Marshal(s[0])
if len(s.claims) == 1 && !s.marshalSingleStringAsArray {
return json.Marshal(s.claims[0])
}

return json.Marshal([]string(s))
return json.Marshal(s.claims)
}
40 changes: 32 additions & 8 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt_test

import (
"encoding/json"
"errors"
"math"
"testing"
"time"
Expand Down Expand Up @@ -34,27 +35,50 @@ func TestNumericDate(t *testing.T) {
jwt.TimePrecision = oldPrecision
}

func TestSingleArrayMarshal(t *testing.T) {
jwt.MarshalSingleStringAsArray = false

s := jwt.ClaimStrings{"test"}
expected := `"test"`
func TestClaimStrings(t *testing.T) {
s := jwt.NewClaimStrings([]string{"test"})
expected := `["test"]`

b, err := json.Marshal(s)

if err != nil {
t.Errorf("Unexpected error: %s", err)
}

if expected != string(b) {
t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b))
}
}

jwt.MarshalSingleStringAsArray = true
func TestClaimStringsInvalidType(t *testing.T) {
j := `1`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}

expected = `["test"]`
func TestClaimStringsMismatchedTypes(t *testing.T) {
j := `["test", 1]`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}

b, err = json.Marshal(s)
func TestSingleArrayMarshal(t *testing.T) {
s := jwt.NewClaimStrings([]string{"test"}, jwt.WithMarshalSingleStringAsArray(false))
expected := `"test"`

b, err := json.Marshal(s)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
Expand Down
4 changes: 2 additions & 2 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,15 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
return err
}

if len(aud) == 0 {
if aud.Len() == 0 {
return errorIfRequired(required, "aud")
}

// use a var here to keep constant time compare when looping over a number of claims
result := false

var stringClaims string
for _, a := range aud {
for _, a := range aud.Claims() {
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 {
result = true
}
Expand Down

0 comments on commit d7e0970

Please sign in to comment.