Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove global state for audience marshalling #404

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ type Claims interface {
GetNotBefore() (*NumericDate, error)
GetIssuer() (string, error)
GetSubject() (string, error)
GetAudience() (ClaimStrings, error)
GetAudience() (*ClaimStrings, error)
}
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func ExampleNewWithClaims_customClaimsType() {
Issuer: "test",
Subject: "somebody",
ID: "1",
Audience: []string{"somebody_else"},
Audience: jwt.NewClaimStrings([]string{"somebody_else"}),
},
}

Expand Down
6 changes: 3 additions & 3 deletions map_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
}

// GetAudience implements the Claims interface.
func (m MapClaims) GetAudience() (ClaimStrings, error) {
func (m MapClaims) GetAudience() (*ClaimStrings, error) {
return m.parseClaimsString("aud")
}

Expand Down Expand Up @@ -66,7 +66,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {

// parseClaimsString tries to parse a key in the map claims type as a
// [ClaimsStrings] type, which can either be a string or an array of string.
func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
func (m MapClaims) parseClaimsString(key string) (*ClaimStrings, error) {
var cs []string
switch v := m[key].(type) {
case string:
Expand All @@ -83,7 +83,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
}
}

return cs, nil
return NewClaimStrings(cs), nil
}

// parseString tries to parse a key in the map claims type as a [string] type.
Expand Down
63 changes: 55 additions & 8 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"errors"
"fmt"
"reflect"
"slices"

Check failure on line 10 in parser_test.go

View workflow job for this annotation

GitHub Actions / build (1.20)

package slices is not in GOROOT (/opt/hostedtoolcache/go/1.20.14/x64/src/slices)
"testing"
"time"

Expand Down Expand Up @@ -360,7 +361,7 @@
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test"},
Audience: jwt.NewClaimStrings([]string{"test"}),
},
true,
nil,
Expand All @@ -372,7 +373,7 @@
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test", "test"},
Audience: jwt.NewClaimStrings([]string{"test", "test"}),
},
true,
nil,
Expand All @@ -384,7 +385,7 @@
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
},
false,
[]error{jwt.ErrTokenMalformed},
Expand All @@ -396,7 +397,7 @@
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
},
false,
[]error{jwt.ErrTokenMalformed},
Expand Down Expand Up @@ -449,6 +450,50 @@
return test.MakeSampleToken(claims, signingMethod, privateKey)
}

func claimsEqual(a, b jwt.Claims) error {
aExp, aErr := a.GetExpirationTime()
bExp, bErr := b.GetExpirationTime()
if !reflect.DeepEqual(aExp, bExp) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `exp`: expected %v vs. %v", aExp, bExp)
}

aIat, aErr := a.GetIssuedAt()
bIat, bErr := b.GetIssuedAt()
if !reflect.DeepEqual(aIat, bIat) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iat`: expected %v vs. %v", aIat, bIat)
}

aNbf, aErr := a.GetNotBefore()
bNbf, bErr := b.GetNotBefore()
if !reflect.DeepEqual(aNbf, bNbf) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `nbf`: expected %v vs. %v", aNbf, bNbf)
}

aIss, aErr := a.GetIssuer()
bIss, bErr := b.GetIssuer()
if aIss != bIss || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iss`: expected %v vs. %v", aIss, bIss)
}

aSub, aErr := a.GetSubject()
bSub, bErr := b.GetSubject()
if aSub != bSub || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `sub`: expected %v vs. %v", aSub, bSub)
}

aAud, aErr := a.GetAudience()
bAud, bErr := b.GetAudience()
if aAud != bAud {
if aAud == nil || bAud == nil {
return fmt.Errorf("mismatched `aud`: expected %v vs. %v", aAud, bAud)
}
if !slices.Equal(aAud.Claims(), bAud.Claims()) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `aud`: expected %v vs %v", aAud, bAud)
}
}
return nil
}

func TestParser_Parse(t *testing.T) {
// Iterate over test data set and run tests
for _, data := range jwtTestData {
Expand Down Expand Up @@ -476,8 +521,10 @@
}

// Verify result matches expectation
if data.claims != nil && !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
if data.claims != nil {
if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v: %v", data.name, data.claims, token.Claims, err)
}
}

if data.valid && err != nil {
Expand Down Expand Up @@ -557,8 +604,8 @@
}

// Verify result matches expectation
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v %v", data.name, data.claims, token.Claims, err)
}

if data.valid && err != nil {
Expand Down
6 changes: 3 additions & 3 deletions registered_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package jwt
//
// This type can be used on its own, but then additional private and
// public claims embedded in the JWT will not be parsed. The typical use-case
// therefore is to embedded this in a user-defined claim type.
// therefore is to embed this in a user-defined claim type.
//
// See examples for how to use this with your own claim types.
type RegisteredClaims struct {
Expand All @@ -17,7 +17,7 @@ type RegisteredClaims struct {
Subject string `json:"sub,omitempty"`

// the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
Audience ClaimStrings `json:"aud,omitempty"`
Audience *ClaimStrings `json:"aud,omitempty"`

// the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt *NumericDate `json:"exp,omitempty"`
Expand Down Expand Up @@ -48,7 +48,7 @@ func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
}

// GetAudience implements the Claims interface.
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
func (c RegisteredClaims) GetAudience() (*ClaimStrings, error) {
return c.Audience, nil
}

Expand Down
68 changes: 50 additions & 18 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ import (
// no fractional timestamps are generated.
var TimePrecision = time.Second

// MarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
var MarshalSingleStringAsArray = true

// NumericDate represents a JSON numeric date value, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-2.
type NumericDate struct {
Expand Down Expand Up @@ -100,10 +90,52 @@ func (date *NumericDate) UnmarshalJSON(b []byte) (err error) {
// ClaimStrings is basically just a slice of strings, but it can be either
// serialized from a string array or just a string. This type is necessary,
// since the "aud" claim can either be a single string or an array.
type ClaimStrings []string
type ClaimStrings struct {
claims []string
marshalSingleStringAsArray bool
}

type ClaimStringOption func(*ClaimStrings)

func NewClaimStrings(claims []string, opts ...ClaimStringOption) *ClaimStrings {
ret := ClaimStrings{
claims: claims,
marshalSingleStringAsArray: true,
}
for _, opt := range opts {
opt(&ret)
}
return &ret
}

// WithMarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
func WithMarshalSingleStringAsArray(marshalSingleStringAsArray bool) func(claims *ClaimStrings) {
return func(claims *ClaimStrings) {
claims.marshalSingleStringAsArray = marshalSingleStringAsArray
}
}

func (s *ClaimStrings) Len() int {
return len(s.claims)
}

func (s *ClaimStrings) Claims() []string {
return s.claims
}

func (s *ClaimStrings) String() string {
return fmt.Sprintf("%v", s.claims)
}

func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
var value interface{}
var value any

if err = json.Unmarshal(data, &value); err != nil {
return err
Expand All @@ -115,7 +147,7 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
case string:
aud = append(aud, v)
case []string:
aud = ClaimStrings(v)
aud = v
case []interface{}:
for _, vv := range v {
vs, ok := vv.(string)
Expand All @@ -130,20 +162,20 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
return ErrInvalidType
}

*s = aud
s.claims = aud

return
}

func (s ClaimStrings) MarshalJSON() (b []byte, err error) {
func (s *ClaimStrings) MarshalJSON() (b []byte, err error) {
// This handles a special case in the JWT RFC. If the string array, e.g.
// used by the "aud" field, only contains one element, it MAY be serialized
// as a single string. This may or may not be desired based on the ecosystem
// of other JWT library used, so we make it configurable by the variable
// MarshalSingleStringAsArray.
if len(s) == 1 && !MarshalSingleStringAsArray {
return json.Marshal(s[0])
if len(s.claims) == 1 && !s.marshalSingleStringAsArray {
return json.Marshal(s.claims[0])
}

return json.Marshal([]string(s))
return json.Marshal(s.claims)
}
40 changes: 32 additions & 8 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt_test

import (
"encoding/json"
"errors"
"math"
"testing"
"time"
Expand Down Expand Up @@ -34,27 +35,50 @@ func TestNumericDate(t *testing.T) {
jwt.TimePrecision = oldPrecision
}

func TestSingleArrayMarshal(t *testing.T) {
jwt.MarshalSingleStringAsArray = false

s := jwt.ClaimStrings{"test"}
expected := `"test"`
func TestClaimStrings(t *testing.T) {
s := jwt.NewClaimStrings([]string{"test"})
expected := `["test"]`

b, err := json.Marshal(s)

if err != nil {
t.Errorf("Unexpected error: %s", err)
}

if expected != string(b) {
t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b))
}
}

jwt.MarshalSingleStringAsArray = true
func TestClaimStringsInvalidType(t *testing.T) {
j := `1`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}

expected = `["test"]`
func TestClaimStringsMismatchedTypes(t *testing.T) {
j := `["test", 1]`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}

b, err = json.Marshal(s)
func TestSingleArrayMarshal(t *testing.T) {
s := jwt.NewClaimStrings([]string{"test"}, jwt.WithMarshalSingleStringAsArray(false))
expected := `"test"`

b, err := json.Marshal(s)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
Expand Down
4 changes: 2 additions & 2 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,15 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
return err
}

if len(aud) == 0 {
if aud.Len() == 0 {
return errorIfRequired(required, "aud")
}

// use a var here to keep constant time compare when looping over a number of claims
result := false

var stringClaims string
for _, a := range aud {
for _, a := range aud.Claims() {
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 {
result = true
}
Expand Down
Loading