From f92b4ddd8792e1919ec62308c1e3215aed501bac Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 27 Oct 2023 21:57:06 +0900 Subject: [PATCH] Add it to jws --- jws/jws.go | 21 +++++++++ jws/jws_test.go | 123 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 114 insertions(+), 30 deletions(-) diff --git a/jws/jws.go b/jws/jws.go index 1c7a3e7a3..2b1559850 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -726,6 +726,9 @@ func parse(protected, payload, signature []byte) (*Message, error) { return &msg, nil } +type CustomDecoder = json.CustomDecoder +type CustomDecodeFunc = json.CustomDecodeFunc + // RegisterCustomField allows users to specify that a private field // be decoded as an instance of the specified type. This option has // a global effect. @@ -744,6 +747,24 @@ func parse(protected, payload, signature []byte) (*Message, error) { // // var bday time.Time // _ = hdr.Get(`x-birthday`, &bday) +// +// If you need a more fine-tuned control over the decoding process, +// you can register a `CustomDecoder`. For example, below shows +// how to register a decoder that can parse RFC1123 format string: +// +// jws.RegisterCustomField(`x-birthday`, jws.CustomDecodeFunc(func(data []byte) (interface{}, error) { +// return time.Parse(time.RFC1123, string(data)) +// })) +// +// Please note that use of custom fields can be problematic if you +// are using a library that does not implement MarshalJSON/UnmarshalJSON +// and you try to roundtrip from an object to JSON, and then back to an object. +// For example, in the above example, you can _parse_ time values formatted +// in the format specified in RFC822, but when you convert an object into +// JSON, it will be formatted in RFC3339, because that's what `time.Time` +// likes to do. To avoid this, it's always better to use a custom type +// that wraps your desired type (in this case `time.Time`) and implement +// MarshalJSON and UnmashalJSON. func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jws/jws_test.go b/jws/jws_test.go index 3d770f8e3..08a4a8eb4 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -1259,57 +1259,120 @@ func TestVerifySet(t *testing.T) { func TestCustomField(t *testing.T) { // XXX has global effect!!! - jws.RegisterCustomField(`x-birthday`, time.Time{}) - defer jws.RegisterCustomField(`x-birthday`, nil) + const rfc3339Key = `x-test-rfc3339` + const rfc1123Key = `x-test-rfc1123` + jws.RegisterCustomField(rfc3339Key, time.Time{}) + jws.RegisterCustomField(rfc1123Key, jws.CustomDecodeFunc(func(data []byte) (interface{}, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return time.Parse(time.RFC1123, s) + })) + + defer jws.RegisterCustomField(rfc3339Key, nil) + defer jws.RegisterCustomField(rfc1123Key, nil) expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) - bdaybytes, _ := expected.MarshalText() // RFC3339 + rfc3339bytes, _ := expected.MarshalText() // RFC3339 + rfc1123bytes := expected.Format(time.RFC1123) - payload := "Hello, World!" - privkey, err := jwxtest.GenerateRsaJwk() - if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) { - return - } + plaintext := []byte("Hello, World!") + rsakey, err := jwxtest.GenerateRsaJwk() + require.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) - hdrs := jws.NewHeaders() - hdrs.Set(`x-birthday`, string(bdaybytes)) + t.Run("jws.Parse", func(t *testing.T) { + protected := jws.NewHeaders() + protected.Set(rfc3339Key, string(rfc3339bytes)) + protected.Set(rfc1123Key, rfc1123bytes) - signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.RS256, privkey, jws.WithProtectedHeaders(hdrs))) - if !assert.NoError(t, err, `jws.Sign should succeed`) { - return - } - - t.Run("jws.Parse + json.Unmarshal", func(t *testing.T) { - msg, err := jws.Parse(signed) + encrypted, err := jws.Sign(plaintext, jws.WithKey(jwa.RS256, rsakey, jws.WithProtectedHeaders(protected))) + require.NoError(t, err, `jws.Sign should succeed`) + msg, err := jws.Parse(encrypted) if !assert.NoError(t, err, `jws.Parse should succeed`) { + t.Logf("%q", encrypted) return } - var v interface{} - require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) + } + }) + t.Run("json.Unmarshal", func(t *testing.T) { + protected := jws.NewHeaders() + protected.Set(rfc3339Key, string(rfc3339bytes)) + protected.Set(rfc1123Key, rfc1123bytes) - if !assert.Equal(t, expected, v, `values should match`) { + encrypted, err := jws.Sign(plaintext, jws.WithKey(jwa.RS256, rsakey, jws.WithProtectedHeaders(protected)), jws.WithJSON()) + require.NoError(t, err, `jws.Sign should succeed`) + msg := jws.NewMessage() + if !assert.NoError(t, json.Unmarshal(encrypted, msg), `json.Unmarshal should succeed`) { return } - // Create JSON from jws.Message - buf, err := json.Marshal(msg) - if !assert.NoError(t, err, `json.Marshal should succeed`) { - return + for _, key := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key) + require.Equal(t, expected, v, `values should match`) } + }) + + /* + // XXX has global effect!!! + jws.RegisterCustomField(`x-birthday`, time.Time{}) + defer jws.RegisterCustomField(`x-birthday`, nil) - var msg2 jws.Message - if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) { + expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) + bdaybytes, _ := expected.MarshalText() // RFC3339 + + payload := "Hello, World!" + privkey, err := jwxtest.GenerateRsaJwk() + if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) { return } - v = nil - require.NoError(t, msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) + hdrs := jws.NewHeaders() + hdrs.Set(`x-birthday`, string(bdaybytes)) - if !assert.Equal(t, expected, v, `values should match`) { + signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.RS256, privkey, jws.WithProtectedHeaders(hdrs))) + if !assert.NoError(t, err, `jws.Sign should succeed`) { return } - }) + + t.Run("jws.Parse + json.Unmarshal", func(t *testing.T) { + msg, err := jws.Parse(signed) + if !assert.NoError(t, err, `jws.Parse should succeed`) { + return + } + + var v interface{} + require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) + + if !assert.Equal(t, expected, v, `values should match`) { + return + } + + // Create JSON from jws.Message + buf, err := json.Marshal(msg) + if !assert.NoError(t, err, `json.Marshal should succeed`) { + return + } + + var msg2 jws.Message + if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) { + return + } + + v = nil + require.NoError(t, msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`) + + if !assert.Equal(t, expected, v, `values should match`) { + return + } + }) + */ } func TestWithMessage(t *testing.T) {