diff --git a/examples/jwt_get_claims_example_test.go b/examples/jwt_get_claims_example_test.go index 32cd10895..d53d6cb6a 100644 --- a/examples/jwt_get_claims_example_test.go +++ b/examples/jwt_get_claims_example_test.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwt" ) @@ -15,6 +16,7 @@ func ExampleJWT_GetClaims() { Subject(`example`). Claim(`claim1`, `value1`). Claim(`claim2`, `2022-05-16T07:35:56+00:00`). + Claim(`claim3`, `{"kty": "oct", "alg":"A128KW", "k":"GawgguFyGrWKav7AX4VKUg"}`). Build() if err != nil { fmt.Printf("failed to build token: %s\n", err) @@ -43,6 +45,7 @@ func ExampleJWT_GetClaims() { var dummy interface{} _ = tok.Get(`claim1`, &dummy) _ = tok.Get(`claim2`, &dummy) + _ = tok.Get(`claim3`, &dummy) // However, it is possible to globally specify that a private // claim should be parsed into a custom type. @@ -62,5 +65,24 @@ func ExampleJWT_GetClaims() { return } + // It's also possible to specify a custom decoder for a private claim. + // For example, in the case of `claim3`, it needs to call `jwk.ParseKey` + // which returns an interface that can't be instantiated like the + // `time.Time` value for `claim2`. + jwt.RegisterCustomField(`claim3`, jwt.CustomDecodeFunc(func(data []byte) (interface{}, error) { + return jwk.ParseKey(data) + })) + + tok = jwt.New() + if err := json.Unmarshal([]byte(`{"claim3": {"kty": "oct", "alg":"A128KW", "k":"GawgguFyGrWKav7AX4VKUg"}}`), tok); err != nil { + fmt.Printf(`failed to parse token: %s`, err) + return + } + var claim3 jwk.Key + if err := tok.Get(`claim3`, &claim3); err != nil { + fmt.Printf("failed to get private claim \"claim3\": %s\n", err) + return + } + // OUTPUT: } diff --git a/internal/json/registry.go b/internal/json/registry.go index 4830e86de..7e337271f 100644 --- a/internal/json/registry.go +++ b/internal/json/registry.go @@ -6,15 +6,43 @@ import ( "sync" ) +// CustomDecoder is the interface we expect from RegisterCustomField in jws, jwe, jwk, and jwt packages. +type CustomDecoder interface { + // Decode takes a JSON encoded byte slice and returns the desired + // decoded value,which will be used as the value for that field + // registered through RegisterCustomField + Decode([]byte) (interface{}, error) +} + +// CustomDecodeFunc is a stateless, function-based implementation of CustomDecoder +type CustomDecodeFunc func([]byte) (interface{}, error) + +func (fn CustomDecodeFunc) Decode(data []byte) (interface{}, error) { + return fn(data) +} + +type objectTypeDecoder struct { + typ reflect.Type + name string +} + +func (dec *objectTypeDecoder) Decode(data []byte) (interface{}, error) { + ptr := reflect.New(dec.typ).Interface() + if err := Unmarshal(data, ptr); err != nil { + return nil, fmt.Errorf(`failed to decode field %s: %w`, dec.name, err) + } + return reflect.ValueOf(ptr).Elem().Interface(), nil +} + type Registry struct { mu *sync.RWMutex - data map[string]reflect.Type + ctrs map[string]CustomDecoder } func NewRegistry() *Registry { return &Registry{ mu: &sync.RWMutex{}, - data: make(map[string]reflect.Type), + ctrs: make(map[string]CustomDecoder), } } @@ -22,26 +50,36 @@ func (r *Registry) Register(name string, object interface{}) { if object == nil { r.mu.Lock() defer r.mu.Unlock() - delete(r.data, name) + delete(r.ctrs, name) return } - typ := reflect.TypeOf(object) r.mu.Lock() defer r.mu.Unlock() - r.data[name] = typ + if ctr, ok := object.(CustomDecoder); ok { + r.ctrs[name] = ctr + } else { + r.ctrs[name] = &objectTypeDecoder{ + typ: reflect.TypeOf(object), + name: name, + } + } } func (r *Registry) Decode(dec *Decoder, name string) (interface{}, error) { r.mu.RLock() defer r.mu.RUnlock() - if typ, ok := r.data[name]; ok { - ptr := reflect.New(typ).Interface() - if err := dec.Decode(ptr); err != nil { + if ctr, ok := r.ctrs[name]; ok { + var raw RawMessage + if err := dec.Decode(&raw); err != nil { + return nil, fmt.Errorf(`failed to decode field %s: %w`, name, err) + } + v, err := ctr.Decode([]byte(raw)) + if err != nil { return nil, fmt.Errorf(`failed to decode field %s: %w`, name, err) } - return reflect.ValueOf(ptr).Elem().Interface(), nil + return v, nil } var decoded interface{} diff --git a/jwe/jwe.go b/jwe/jwe.go index f3b5edf83..760c733af 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -840,6 +840,9 @@ func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) { return m, nil } +type CustomDecoder = json.CustomDecoder +type CustomDecodeFunc = json.CustomDecodeFunc + // RegisterCustomField allows users to specify that a private field // be decoded as an instance of the specified type. This option has // a global effect. @@ -858,6 +861,24 @@ func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) { // // var bday time.Time // _ = hdr.Get(`x-birthday`, &bday) +// +// If you need a more fine-tuned control over the decoding process, +// you can register a `CustomDecoder`. For example, below shows +// how to register a decoder that can parse RFC1123 format string: +// +// jwe.RegisterCustomField(`x-birthday`, jwe.CustomDecodeFunc(func(data []byte) (interface{}, error) { +// return time.Parse(time.RFC1123, string(data)) +// })) +// +// Please note that use of custom fields can be problematic if you +// are using a library that does not implement MarshalJSON/UnmarshalJSON +// and you try to roundtrip from an object to JSON, and then back to an object. +// For example, in the above example, you can _parse_ time values formatted +// in the format specified in RFC822, but when you convert an object into +// JSON, it will be formatted in RFC3339, because that's what `time.Time` +// likes to do. To avoid this, it's always better to use a custom type +// that wraps your desired type (in this case `time.Time`) and implement +// MarshalJSON and UnmashalJSON. func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jwe/jwe_test.go b/jwe/jwe_test.go index 33c6ebdcd..814667df8 100644 --- a/jwe/jwe_test.go +++ b/jwe/jwe_test.go @@ -689,57 +689,66 @@ func TestReadFile(t *testing.T) { func TestCustomField(t *testing.T) { // XXX has global effect!!! - jwe.RegisterCustomField(`x-birthday`, time.Time{}) - defer jwe.RegisterCustomField(`x-birthday`, nil) + const rfc3339Key = `x-test-rfc3339` + const rfc1123Key = `x-test-rfc1123` + jwe.RegisterCustomField(rfc3339Key, time.Time{}) + jwe.RegisterCustomField(rfc1123Key, jwe.CustomDecodeFunc(func(data []byte) (interface{}, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return time.Parse(time.RFC1123, s) + })) + + defer jwe.RegisterCustomField(rfc3339Key, nil) + defer jwe.RegisterCustomField(rfc1123Key, nil) expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) - bdaybytes, _ := expected.MarshalText() // RFC3339 + rfc3339bytes, _ := expected.MarshalText() // RFC3339 + rfc1123bytes := expected.Format(time.RFC1123) plaintext := []byte("Hello, World!") rsakey, err := jwxtest.GenerateRsaJwk() - if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) { - return - } - pubkey, err := jwk.PublicKeyOf(rsakey) - if !assert.NoError(t, err, `jwk.PublicKeyOf() should succeed`) { - return - } + require.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) - protected := jwe.NewHeaders() - protected.Set(`x-birthday`, string(bdaybytes)) + pubkey, err := jwk.PublicKeyOf(rsakey) + require.NoError(t, err, `jwk.PublicKeyOf() should succeed`) - encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected)) - if !assert.NoError(t, err, `jwe.Encrypt should succeed`) { - return - } + t.Run("jwe.Parse", func(t *testing.T) { + protected := jwe.NewHeaders() + protected.Set(rfc3339Key, string(rfc3339bytes)) + protected.Set(rfc1123Key, rfc1123bytes) - t.Run("jwe.Parse + json.Unmarshal", func(t *testing.T) { + encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected)) + require.NoError(t, err, `jwe.Encrypt should succeed`) msg, err := jwe.Parse(encrypted) if !assert.NoError(t, err, `jwe.Parse should succeed`) { + t.Logf("%q", encrypted) return } - var v time.Time - require.NoError(t, msg.ProtectedHeaders().Get(`x-birthday`, &v), `msg.ProtectedHeaders().Get("x-birthday") should succeed`) - if !assert.Equal(t, expected, v, `values should match`) { - return - } - - // Create JSON from jwe.Message - buf, err := json.Marshal(msg) - if !assert.NoError(t, err, `json.Marshal should succeed`) { - return + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, msg.ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) } - - var msg2 jwe.Message - if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) { + }) + t.Run("json.Unmarshal", func(t *testing.T) { + protected := jwe.NewHeaders() + protected.Set(rfc3339Key, string(rfc3339bytes)) + protected.Set(rfc1123Key, rfc1123bytes) + + encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected), jwe.WithJSON()) + require.NoError(t, err, `jwe.Encrypt should succeed`) + msg := jwe.NewMessage() + if !assert.NoError(t, json.Unmarshal(encrypted, msg), `json.Unmarshal should succeed`) { return } - v = time.Time{} // reset - require.NoError(t, msg2.ProtectedHeaders().Get(`x-birthday`, &v), `msg2.ProtectedHeaders().Get("x-birthday") should succeed`) - if !assert.Equal(t, expected, v, `values should match`) { - return + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, msg.ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) } }) } diff --git a/jwk/jwk.go b/jwk/jwk.go index 4cdf589a2..2c97d7d7f 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -734,6 +734,9 @@ func asnEncode(key Key) (string, []byte, error) { } } +type CustomDecoder = json.CustomDecoder +type CustomDecodeFunc = json.CustomDecodeFunc + // RegisterCustomField allows users to specify that a private field // be decoded as an instance of the specified type. This option has // a global effect. @@ -752,6 +755,24 @@ func asnEncode(key Key) (string, []byte, error) { // // var bday time.Time // _ = key.Get(`x-birthday`, &bday) +// +// If you need a more fine-tuned control over the decoding process, +// you can register a `CustomDecoder`. For example, below shows +// how to register a decoder that can parse RFC1123 format string: +// +// jwk.RegisterCustomField(`x-birthday`, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) { +// return time.Parse(time.RFC1123, string(data)) +// })) +// +// Please note that use of custom fields can be problematic if you +// are using a library that does not implement MarshalJSON/UnmarshalJSON +// and you try to roundtrip from an object to JSON, and then back to an object. +// For example, in the above example, you can _parse_ time values formatted +// in the format specified in RFC822, but when you convert an object into +// JSON, it will be formatted in RFC3339, because that's what `time.Time` +// likes to do. To avoid this, it's always better to use a custom type +// that wraps your desired type (in this case `time.Time`) and implement +// MarshalJSON and UnmashalJSON. func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index d711d5649..1c35b9a17 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -1393,16 +1393,34 @@ func TestOKP(t *testing.T) { } func TestCustomField(t *testing.T) { + const rfc3339Key = `x-rfc3339-key` + const rfc1123Key = `x-rfc1123-key` + // XXX has global effect!!! - jwk.RegisterCustomField(`x-birthday`, time.Time{}) - defer jwk.RegisterCustomField(`x-birthday`, nil) + jwk.RegisterCustomField(rfc3339Key, time.Time{}) + jwk.RegisterCustomField(rfc1123Key, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return time.Parse(time.RFC1123, s) + })) + defer jwk.RegisterCustomField(rfc3339Key, nil) + defer jwk.RegisterCustomField(rfc1123Key, nil) expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) - bdaybytes, _ := expected.MarshalText() // RFC3339 + rfc3339bytes, _ := expected.MarshalText() // RFC3339 + rfc1123bytes := expected.Format(time.RFC1123) var b strings.Builder - b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","x-birthday":"`) - b.Write(bdaybytes) + b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","`) + b.WriteString(rfc3339Key) + b.WriteString(`":"`) + b.Write(rfc3339bytes) + b.WriteString(`","`) + b.WriteString(rfc1123Key) + b.WriteString(`":"`) + b.WriteString(rfc1123bytes) b.WriteString(`"}`) src := b.String() @@ -1412,11 +1430,10 @@ func TestCustomField(t *testing.T) { return } - var v interface{} - require.NoError(t, key.Get(`x-birthday`, &v), `key.Get("x-birthday") should succeed`) - - if !assert.Equal(t, expected, v, `values should match`) { - return + for _, name := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, key.Get(name, &v), `key.Get(%q) should succeed`, name) + require.Equal(t, expected, v, `values should match`) } }) } diff --git a/jws/jws.go b/jws/jws.go index 8ff76fa50..2c8745e81 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -772,6 +772,9 @@ func parse(protected, payload, signature []byte) (*Message, error) { return &msg, nil } +type CustomDecoder = json.CustomDecoder +type CustomDecodeFunc = json.CustomDecodeFunc + // RegisterCustomField allows users to specify that a private field // be decoded as an instance of the specified type. This option has // a global effect. @@ -790,6 +793,24 @@ func parse(protected, payload, signature []byte) (*Message, error) { // // var bday time.Time // _ = hdr.Get(`x-birthday`, &bday) +// +// If you need a more fine-tuned control over the decoding process, +// you can register a `CustomDecoder`. For example, below shows +// how to register a decoder that can parse RFC1123 format string: +// +// jws.RegisterCustomField(`x-birthday`, jws.CustomDecodeFunc(func(data []byte) (interface{}, error) { +// return time.Parse(time.RFC1123, string(data)) +// })) +// +// Please note that use of custom fields can be problematic if you +// are using a library that does not implement MarshalJSON/UnmarshalJSON +// and you try to roundtrip from an object to JSON, and then back to an object. +// For example, in the above example, you can _parse_ time values formatted +// in the format specified in RFC822, but when you convert an object into +// JSON, it will be formatted in RFC3339, because that's what `time.Time` +// likes to do. To avoid this, it's always better to use a custom type +// that wraps your desired type (in this case `time.Time`) and implement +// MarshalJSON and UnmashalJSON. func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jws/jws_test.go b/jws/jws_test.go index ae8fbc7ce..0a28be15e 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -1032,45 +1032,120 @@ func TestVerifySet(t *testing.T) { func TestCustomField(t *testing.T) { // XXX has global effect!!! - jws.RegisterCustomField(`x-birthday`, time.Time{}) - defer jws.RegisterCustomField(`x-birthday`, nil) + const rfc3339Key = `x-test-rfc3339` + const rfc1123Key = `x-test-rfc1123` + jws.RegisterCustomField(rfc3339Key, time.Time{}) + jws.RegisterCustomField(rfc1123Key, jws.CustomDecodeFunc(func(data []byte) (interface{}, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return time.Parse(time.RFC1123, s) + })) + + defer jws.RegisterCustomField(rfc3339Key, nil) + defer jws.RegisterCustomField(rfc1123Key, nil) expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) - bdaybytes, _ := expected.MarshalText() // RFC3339 + rfc3339bytes, _ := expected.MarshalText() // RFC3339 + rfc1123bytes := expected.Format(time.RFC1123) - payload := "Hello, World!" - privkey, err := jwxtest.GenerateRsaJwk() + plaintext := []byte("Hello, World!") + rsakey, err := jwxtest.GenerateRsaJwk() require.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) - hdrs := jws.NewHeaders() - hdrs.Set(`x-birthday`, string(bdaybytes)) + t.Run("jws.Parse", func(t *testing.T) { + protected := jws.NewHeaders() + protected.Set(rfc3339Key, string(rfc3339bytes)) + protected.Set(rfc1123Key, rfc1123bytes) - signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.RS256, privkey, jws.WithProtectedHeaders(hdrs))) - require.NoError(t, err, `jws.Sign should succeed`) + encrypted, err := jws.Sign(plaintext, jws.WithKey(jwa.RS256, rsakey, jws.WithProtectedHeaders(protected))) + require.NoError(t, err, `jws.Sign should succeed`) + msg, err := jws.Parse(encrypted) + if !assert.NoError(t, err, `jws.Parse should succeed`) { + t.Logf("%q", encrypted) + return + } + + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) + } + }) + t.Run("json.Unmarshal", func(t *testing.T) { + protected := jws.NewHeaders() + protected.Set(rfc3339Key, string(rfc3339bytes)) + protected.Set(rfc1123Key, rfc1123bytes) + + encrypted, err := jws.Sign(plaintext, jws.WithKey(jwa.RS256, rsakey, jws.WithProtectedHeaders(protected)), jws.WithJSON()) + require.NoError(t, err, `jws.Sign should succeed`) + msg := jws.NewMessage() + if !assert.NoError(t, json.Unmarshal(encrypted, msg), `json.Unmarshal should succeed`) { + return + } - t.Run("jws.Parse + json.Unmarshal", func(t *testing.T) { - msg, err := jws.Parse(signed) - require.NoError(t, err, `jws.Parse should succeed`) + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) + } + }) - var v interface{} - require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) - require.Equal(t, expected, v, `values should match`) + /* + // XXX has global effect!!! + jws.RegisterCustomField(`x-birthday`, time.Time{}) + defer jws.RegisterCustomField(`x-birthday`, nil) - // Create JSON from jws.Message - buf, err := json.Marshal(msg) - require.NoError(t, err, `json.Marshal should succeed`) + expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) + bdaybytes, _ := expected.MarshalText() // RFC3339 - var msg2 jws.Message - require.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) + payload := "Hello, World!" + privkey, err := jwxtest.GenerateRsaJwk() + if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) { + return + } - v = nil - require.NoError(t, msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) - require.Equal(t, expected, v, `values should match`) + hdrs := jws.NewHeaders() + hdrs.Set(`x-birthday`, string(bdaybytes)) - if !assert.Equal(t, expected, v, `values should match`) { + signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.RS256, privkey, jws.WithProtectedHeaders(hdrs))) + if !assert.NoError(t, err, `jws.Sign should succeed`) { return } - }) + + t.Run("jws.Parse + json.Unmarshal", func(t *testing.T) { + msg, err := jws.Parse(signed) + if !assert.NoError(t, err, `jws.Parse should succeed`) { + return + } + + var v interface{} + require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) + + if !assert.Equal(t, expected, v, `values should match`) { + return + } + + // Create JSON from jws.Message + buf, err := json.Marshal(msg) + if !assert.NoError(t, err, `json.Marshal should succeed`) { + return + } + + var msg2 jws.Message + if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) { + return + } + + v = nil + require.NoError(t, msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) + + if !assert.Equal(t, expected, v, `values should match`) { + return + } + }) + */ } func TestWithMessage(t *testing.T) { diff --git a/jwt/jwt.go b/jwt/jwt.go index d472700c0..f21b35228 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -477,6 +477,9 @@ func (t *stdToken) Clone() (Token, error) { return dst, nil } +type CustomDecoder = json.CustomDecoder +type CustomDecodeFunc = json.CustomDecodeFunc + // RegisterCustomField allows users to specify that a private field // be decoded as an instance of the specified type. This option has // a global effect. @@ -487,7 +490,7 @@ func (t *stdToken) Clone() (Token, error) { // // In such case you would register a custom field as follows // -// jwt.RegisterCustomField(`x-birthday`, time.Time) +// jwt.RegisterCustomField(`x-birthday`, time.Time{}) // // Then you can use a `time.Time` variable to extract the value // of `x-birthday` field, instead of having to use `interface{}` @@ -495,6 +498,24 @@ func (t *stdToken) Clone() (Token, error) { // // var bday time.Time // _ = token.Get(`x-birthday`, &bday) +// +// If you need a more fine-tuned control over the decoding process, +// you can register a `CustomDecoder`. For example, below shows +// how to register a decoder that can parse RFC822 format string: +// +// jwt.RegisterCustomField(`x-birthday`, jwt.CustomDecodeFunc(func(data []byte) (interface{}, error) { +// return time.Parse(time.RFC822, string(data)) +// })) +// +// Please note that use of custom fields can be problematic if you +// are using a library that does not implement MarshalJSON/UnmarshalJSON +// and you try to roundtrip from an object to JSON, and then back to an object. +// For example, in the above example, you can _parse_ time values formatted +// in the format specified in RFC822, but when you convert an object into +// JSON, it will be formatted in RFC3339, because that's what `time.Time` +// likes to do. To avoid this, it's always better to use a custom type +// that wraps your desired type (in this case `time.Time`) and implement +// MarshalJSON and UnmashalJSON. func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 6ccad6de5..07cbda4c2 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -743,15 +743,33 @@ func TestReadFile(t *testing.T) { func TestCustomField(t *testing.T) { // XXX has global effect!!! - jwt.RegisterCustomField(`x-birthday`, time.Time{}) - defer jwt.RegisterCustomField(`x-birthday`, nil) + const rfc3339Key = `x-test-rfc3339` + const rfc1123Key = `x-test-rfc1123` + jwt.RegisterCustomField(rfc3339Key, time.Time{}) + jwt.RegisterCustomField(rfc1123Key, jwt.CustomDecodeFunc(func(data []byte) (interface{}, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return time.Parse(time.RFC1123, s) + })) + + defer jwt.RegisterCustomField(rfc3339Key, nil) + defer jwt.RegisterCustomField(rfc1123Key, nil) expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) - bdaybytes, _ := expected.MarshalText() // RFC3339 + rfc3339bytes, _ := expected.MarshalText() // RFC3339 + rfc1123bytes := expected.Format(time.RFC1123) var b strings.Builder - b.WriteString(`{"iss": "github.com/lesstrrat-go/jwx", "x-birthday": "`) - b.Write(bdaybytes) + b.WriteString(`{"iss": "github.com/lesstrrat-go/jwx", "`) + b.WriteString(rfc3339Key) + b.WriteString(`": "`) + b.Write(rfc3339bytes) + b.WriteString(`", "`) + b.WriteString(rfc1123Key) + b.WriteString(`": "`) + b.WriteString(rfc1123bytes) b.WriteString(`"}`) src := b.String() @@ -762,11 +780,10 @@ func TestCustomField(t *testing.T) { return } - var v time.Time - require.NoError(t, token.Get(`x-birthday`, &v), `token.Get("x-birthday") should succeed`) - - if !assert.Equal(t, expected, v, `values should match`) { - return + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, token.Get(key, &v), `token.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) } }) t.Run("json.Unmarshal", func(t *testing.T) { @@ -775,11 +792,10 @@ func TestCustomField(t *testing.T) { return } - var v time.Time - require.NoError(t, token.Get(`x-birthday`, &v), `token.Get("x-birthday") should succeed`) - - if !assert.Equal(t, expected, v, `values should match`) { - return + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, token.Get(key, &v), `token.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) } }) }