From d7e0970ea13a427b1fce2ae5891ef205c8e53148 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Tue, 30 Jul 2024 11:54:03 +0100 Subject: [PATCH] Remove global state for audience marshalling --- claims.go | 2 +- example_test.go | 2 +- map_claims.go | 6 ++-- parser_test.go | 63 ++++++++++++++++++++++++++++++++++------ registered_claims.go | 6 ++-- types.go | 68 ++++++++++++++++++++++++++++++++------------ types_test.go | 40 ++++++++++++++++++++------ validator.go | 4 +-- 8 files changed, 147 insertions(+), 44 deletions(-) diff --git a/claims.go b/claims.go index d50ff3da..f48bee42 100644 --- a/claims.go +++ b/claims.go @@ -12,5 +12,5 @@ type Claims interface { GetNotBefore() (*NumericDate, error) GetIssuer() (string, error) GetSubject() (string, error) - GetAudience() (ClaimStrings, error) + GetAudience() (*ClaimStrings, error) } diff --git a/example_test.go b/example_test.go index 651841de..ea30e2c8 100644 --- a/example_test.go +++ b/example_test.go @@ -50,7 +50,7 @@ func ExampleNewWithClaims_customClaimsType() { Issuer: "test", Subject: "somebody", ID: "1", - Audience: []string{"somebody_else"}, + Audience: jwt.NewClaimStrings([]string{"somebody_else"}), }, } diff --git a/map_claims.go b/map_claims.go index b2b51a1f..2e842aa5 100644 --- a/map_claims.go +++ b/map_claims.go @@ -25,7 +25,7 @@ func (m MapClaims) GetIssuedAt() (*NumericDate, error) { } // GetAudience implements the Claims interface. -func (m MapClaims) GetAudience() (ClaimStrings, error) { +func (m MapClaims) GetAudience() (*ClaimStrings, error) { return m.parseClaimsString("aud") } @@ -66,7 +66,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) { // parseClaimsString tries to parse a key in the map claims type as a // [ClaimsStrings] type, which can either be a string or an array of string. -func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) { +func (m MapClaims) parseClaimsString(key string) (*ClaimStrings, error) { var cs []string switch v := m[key].(type) { case string: @@ -83,7 +83,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) { } } - return cs, nil + return NewClaimStrings(cs), nil } // parseString tries to parse a key in the map claims type as a [string] type. diff --git a/parser_test.go b/parser_test.go index c0f81711..88836732 100644 --- a/parser_test.go +++ b/parser_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "reflect" + "slices" "testing" "time" @@ -360,7 +361,7 @@ var jwtTestData = []struct { "", defaultKeyFunc, &jwt.RegisteredClaims{ - Audience: jwt.ClaimStrings{"test"}, + Audience: jwt.NewClaimStrings([]string{"test"}), }, true, nil, @@ -372,7 +373,7 @@ var jwtTestData = []struct { "", defaultKeyFunc, &jwt.RegisteredClaims{ - Audience: jwt.ClaimStrings{"test", "test"}, + Audience: jwt.NewClaimStrings([]string{"test", "test"}), }, true, nil, @@ -384,7 +385,7 @@ var jwtTestData = []struct { "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 } defaultKeyFunc, &jwt.RegisteredClaims{ - Audience: nil, // because of the unmarshal error, this will be empty + Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty }, false, []error{jwt.ErrTokenMalformed}, @@ -396,7 +397,7 @@ var jwtTestData = []struct { "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] } defaultKeyFunc, &jwt.RegisteredClaims{ - Audience: nil, // because of the unmarshal error, this will be empty + Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty }, false, []error{jwt.ErrTokenMalformed}, @@ -449,6 +450,50 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string { return test.MakeSampleToken(claims, signingMethod, privateKey) } +func claimsEqual(a, b jwt.Claims) error { + aExp, aErr := a.GetExpirationTime() + bExp, bErr := b.GetExpirationTime() + if !reflect.DeepEqual(aExp, bExp) || !reflect.DeepEqual(aErr, bErr) { + return fmt.Errorf("mismatched `exp`: expected %v vs. %v", aExp, bExp) + } + + aIat, aErr := a.GetIssuedAt() + bIat, bErr := b.GetIssuedAt() + if !reflect.DeepEqual(aIat, bIat) || !reflect.DeepEqual(aErr, bErr) { + return fmt.Errorf("mismatched `iat`: expected %v vs. %v", aIat, bIat) + } + + aNbf, aErr := a.GetNotBefore() + bNbf, bErr := b.GetNotBefore() + if !reflect.DeepEqual(aNbf, bNbf) || !reflect.DeepEqual(aErr, bErr) { + return fmt.Errorf("mismatched `nbf`: expected %v vs. %v", aNbf, bNbf) + } + + aIss, aErr := a.GetIssuer() + bIss, bErr := b.GetIssuer() + if aIss != bIss || !reflect.DeepEqual(aErr, bErr) { + return fmt.Errorf("mismatched `iss`: expected %v vs. %v", aIss, bIss) + } + + aSub, aErr := a.GetSubject() + bSub, bErr := b.GetSubject() + if aSub != bSub || !reflect.DeepEqual(aErr, bErr) { + return fmt.Errorf("mismatched `sub`: expected %v vs. %v", aSub, bSub) + } + + aAud, aErr := a.GetAudience() + bAud, bErr := b.GetAudience() + if aAud != bAud { + if aAud == nil || bAud == nil { + return fmt.Errorf("mismatched `aud`: expected %v vs. %v", aAud, bAud) + } + if !slices.Equal(aAud.Claims(), bAud.Claims()) || !reflect.DeepEqual(aErr, bErr) { + return fmt.Errorf("mismatched `aud`: expected %v vs %v", aAud, bAud) + } + } + return nil +} + func TestParser_Parse(t *testing.T) { // Iterate over test data set and run tests for _, data := range jwtTestData { @@ -476,8 +521,10 @@ func TestParser_Parse(t *testing.T) { } // Verify result matches expectation - if data.claims != nil && !reflect.DeepEqual(data.claims, token.Claims) { - t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) + if data.claims != nil { + if err := claimsEqual(data.claims, token.Claims); err != nil { + t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v: %v", data.name, data.claims, token.Claims, err) + } } if data.valid && err != nil { @@ -557,8 +604,8 @@ func TestParser_ParseUnverified(t *testing.T) { } // Verify result matches expectation - if !reflect.DeepEqual(data.claims, token.Claims) { - t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) + if err := claimsEqual(data.claims, token.Claims); err != nil { + t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v %v", data.name, data.claims, token.Claims, err) } if data.valid && err != nil { diff --git a/registered_claims.go b/registered_claims.go index 77951a53..6294785f 100644 --- a/registered_claims.go +++ b/registered_claims.go @@ -6,7 +6,7 @@ package jwt // // This type can be used on its own, but then additional private and // public claims embedded in the JWT will not be parsed. The typical use-case -// therefore is to embedded this in a user-defined claim type. +// therefore is to embed this in a user-defined claim type. // // See examples for how to use this with your own claim types. type RegisteredClaims struct { @@ -17,7 +17,7 @@ type RegisteredClaims struct { Subject string `json:"sub,omitempty"` // the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3 - Audience ClaimStrings `json:"aud,omitempty"` + Audience *ClaimStrings `json:"aud,omitempty"` // the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 ExpiresAt *NumericDate `json:"exp,omitempty"` @@ -48,7 +48,7 @@ func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) { } // GetAudience implements the Claims interface. -func (c RegisteredClaims) GetAudience() (ClaimStrings, error) { +func (c RegisteredClaims) GetAudience() (*ClaimStrings, error) { return c.Audience, nil } diff --git a/types.go b/types.go index b2655a9e..ead9462e 100644 --- a/types.go +++ b/types.go @@ -17,16 +17,6 @@ import ( // no fractional timestamps are generated. var TimePrecision = time.Second -// MarshalSingleStringAsArray modifies the behavior of the ClaimStrings type, -// especially its MarshalJSON function. -// -// If it is set to true (the default), it will always serialize the type as an -// array of strings, even if it just contains one element, defaulting to the -// behavior of the underlying []string. If it is set to false, it will serialize -// to a single string, if it contains one element. Otherwise, it will serialize -// to an array of strings. -var MarshalSingleStringAsArray = true - // NumericDate represents a JSON numeric date value, as referenced at // https://datatracker.ietf.org/doc/html/rfc7519#section-2. type NumericDate struct { @@ -100,10 +90,52 @@ func (date *NumericDate) UnmarshalJSON(b []byte) (err error) { // ClaimStrings is basically just a slice of strings, but it can be either // serialized from a string array or just a string. This type is necessary, // since the "aud" claim can either be a single string or an array. -type ClaimStrings []string +type ClaimStrings struct { + claims []string + marshalSingleStringAsArray bool +} + +type ClaimStringOption func(*ClaimStrings) + +func NewClaimStrings(claims []string, opts ...ClaimStringOption) *ClaimStrings { + ret := ClaimStrings{ + claims: claims, + marshalSingleStringAsArray: true, + } + for _, opt := range opts { + opt(&ret) + } + return &ret +} + +// WithMarshalSingleStringAsArray modifies the behavior of the ClaimStrings type, +// especially its MarshalJSON function. +// +// If it is set to true (the default), it will always serialize the type as an +// array of strings, even if it just contains one element, defaulting to the +// behavior of the underlying []string. If it is set to false, it will serialize +// to a single string, if it contains one element. Otherwise, it will serialize +// to an array of strings. +func WithMarshalSingleStringAsArray(marshalSingleStringAsArray bool) func(claims *ClaimStrings) { + return func(claims *ClaimStrings) { + claims.marshalSingleStringAsArray = marshalSingleStringAsArray + } +} + +func (s *ClaimStrings) Len() int { + return len(s.claims) +} + +func (s *ClaimStrings) Claims() []string { + return s.claims +} + +func (s *ClaimStrings) String() string { + return fmt.Sprintf("%v", s.claims) +} func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { - var value interface{} + var value any if err = json.Unmarshal(data, &value); err != nil { return err @@ -115,7 +147,7 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { case string: aud = append(aud, v) case []string: - aud = ClaimStrings(v) + aud = v case []interface{}: for _, vv := range v { vs, ok := vv.(string) @@ -130,20 +162,20 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { return ErrInvalidType } - *s = aud + s.claims = aud return } -func (s ClaimStrings) MarshalJSON() (b []byte, err error) { +func (s *ClaimStrings) MarshalJSON() (b []byte, err error) { // This handles a special case in the JWT RFC. If the string array, e.g. // used by the "aud" field, only contains one element, it MAY be serialized // as a single string. This may or may not be desired based on the ecosystem // of other JWT library used, so we make it configurable by the variable // MarshalSingleStringAsArray. - if len(s) == 1 && !MarshalSingleStringAsArray { - return json.Marshal(s[0]) + if len(s.claims) == 1 && !s.marshalSingleStringAsArray { + return json.Marshal(s.claims[0]) } - return json.Marshal([]string(s)) + return json.Marshal(s.claims) } diff --git a/types_test.go b/types_test.go index bd7b139f..d09a8b2a 100644 --- a/types_test.go +++ b/types_test.go @@ -2,6 +2,7 @@ package jwt_test import ( "encoding/json" + "errors" "math" "testing" "time" @@ -34,13 +35,12 @@ func TestNumericDate(t *testing.T) { jwt.TimePrecision = oldPrecision } -func TestSingleArrayMarshal(t *testing.T) { - jwt.MarshalSingleStringAsArray = false - - s := jwt.ClaimStrings{"test"} - expected := `"test"` +func TestClaimStrings(t *testing.T) { + s := jwt.NewClaimStrings([]string{"test"}) + expected := `["test"]` b, err := json.Marshal(s) + if err != nil { t.Errorf("Unexpected error: %s", err) } @@ -48,13 +48,37 @@ func TestSingleArrayMarshal(t *testing.T) { if expected != string(b) { t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b)) } +} - jwt.MarshalSingleStringAsArray = true +func TestClaimStringsInvalidType(t *testing.T) { + j := `1` + var s jwt.ClaimStrings + err := json.Unmarshal([]byte(j), &s) + if !errors.Is(err, jwt.ErrInvalidType) { + t.Errorf("expected `ErrInvalidType` but was: %v", err) + } + if s.Claims() != nil { + t.Errorf("expected claims to be nil but was: %v", err) + } +} - expected = `["test"]` +func TestClaimStringsMismatchedTypes(t *testing.T) { + j := `["test", 1]` + var s jwt.ClaimStrings + err := json.Unmarshal([]byte(j), &s) + if !errors.Is(err, jwt.ErrInvalidType) { + t.Errorf("expected `ErrInvalidType` but was: %v", err) + } + if s.Claims() != nil { + t.Errorf("expected claims to be nil but was: %v", err) + } +} - b, err = json.Marshal(s) +func TestSingleArrayMarshal(t *testing.T) { + s := jwt.NewClaimStrings([]string{"test"}, jwt.WithMarshalSingleStringAsArray(false)) + expected := `"test"` + b, err := json.Marshal(s) if err != nil { t.Errorf("Unexpected error: %s", err) } diff --git a/validator.go b/validator.go index 008ecd87..62e1ee7c 100644 --- a/validator.go +++ b/validator.go @@ -232,7 +232,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err return err } - if len(aud) == 0 { + if aud.Len() == 0 { return errorIfRequired(required, "aud") } @@ -240,7 +240,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err result := false var stringClaims string - for _, a := range aud { + for _, a := range aud.Claims() { if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { result = true }