Skip to content

Commit

Permalink
Add it to jws
Browse files Browse the repository at this point in the history
  • Loading branch information
lestrrat committed Oct 27, 2023
1 parent bf5110b commit f92b4dd
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 30 deletions.
21 changes: 21 additions & 0 deletions jws/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,9 @@ func parse(protected, payload, signature []byte) (*Message, error) {
return &msg, nil
}

type CustomDecoder = json.CustomDecoder
type CustomDecodeFunc = json.CustomDecodeFunc

// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. This option has
// a global effect.
Expand All @@ -744,6 +747,24 @@ func parse(protected, payload, signature []byte) (*Message, error) {
//
// var bday time.Time
// _ = hdr.Get(`x-birthday`, &bday)
//
// If you need a more fine-tuned control over the decoding process,
// you can register a `CustomDecoder`. For example, below shows
// how to register a decoder that can parse RFC1123 format string:
//
// jws.RegisterCustomField(`x-birthday`, jws.CustomDecodeFunc(func(data []byte) (interface{}, error) {
// return time.Parse(time.RFC1123, string(data))
// }))
//
// Please note that use of custom fields can be problematic if you
// are using a library that does not implement MarshalJSON/UnmarshalJSON
// and you try to roundtrip from an object to JSON, and then back to an object.
// For example, in the above example, you can _parse_ time values formatted
// in the format specified in RFC822, but when you convert an object into
// JSON, it will be formatted in RFC3339, because that's what `time.Time`
// likes to do. To avoid this, it's always better to use a custom type
// that wraps your desired type (in this case `time.Time`) and implement
// MarshalJSON and UnmashalJSON.
func RegisterCustomField(name string, object interface{}) {
registry.Register(name, object)
}
Expand Down
123 changes: 93 additions & 30 deletions jws/jws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1259,57 +1259,120 @@ func TestVerifySet(t *testing.T) {

func TestCustomField(t *testing.T) {
// XXX has global effect!!!
jws.RegisterCustomField(`x-birthday`, time.Time{})
defer jws.RegisterCustomField(`x-birthday`, nil)
const rfc3339Key = `x-test-rfc3339`
const rfc1123Key = `x-test-rfc1123`
jws.RegisterCustomField(rfc3339Key, time.Time{})
jws.RegisterCustomField(rfc1123Key, jws.CustomDecodeFunc(func(data []byte) (interface{}, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return time.Parse(time.RFC1123, s)
}))

defer jws.RegisterCustomField(rfc3339Key, nil)
defer jws.RegisterCustomField(rfc1123Key, nil)

expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC)
bdaybytes, _ := expected.MarshalText() // RFC3339
rfc3339bytes, _ := expected.MarshalText() // RFC3339
rfc1123bytes := expected.Format(time.RFC1123)

payload := "Hello, World!"
privkey, err := jwxtest.GenerateRsaJwk()
if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) {
return
}
plaintext := []byte("Hello, World!")
rsakey, err := jwxtest.GenerateRsaJwk()
require.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`)

hdrs := jws.NewHeaders()
hdrs.Set(`x-birthday`, string(bdaybytes))
t.Run("jws.Parse", func(t *testing.T) {
protected := jws.NewHeaders()
protected.Set(rfc3339Key, string(rfc3339bytes))
protected.Set(rfc1123Key, rfc1123bytes)

signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.RS256, privkey, jws.WithProtectedHeaders(hdrs)))
if !assert.NoError(t, err, `jws.Sign should succeed`) {
return
}

t.Run("jws.Parse + json.Unmarshal", func(t *testing.T) {
msg, err := jws.Parse(signed)
encrypted, err := jws.Sign(plaintext, jws.WithKey(jwa.RS256, rsakey, jws.WithProtectedHeaders(protected)))
require.NoError(t, err, `jws.Sign should succeed`)
msg, err := jws.Parse(encrypted)
if !assert.NoError(t, err, `jws.Parse should succeed`) {
t.Logf("%q", encrypted)
return
}

var v interface{}
require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`)
for _, key := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key)
require.Equal(t, expected, v, `values should match`)
}
})
t.Run("json.Unmarshal", func(t *testing.T) {
protected := jws.NewHeaders()
protected.Set(rfc3339Key, string(rfc3339bytes))
protected.Set(rfc1123Key, rfc1123bytes)

if !assert.Equal(t, expected, v, `values should match`) {
encrypted, err := jws.Sign(plaintext, jws.WithKey(jwa.RS256, rsakey, jws.WithProtectedHeaders(protected)), jws.WithJSON())
require.NoError(t, err, `jws.Sign should succeed`)
msg := jws.NewMessage()
if !assert.NoError(t, json.Unmarshal(encrypted, msg), `json.Unmarshal should succeed`) {
return
}

// Create JSON from jws.Message
buf, err := json.Marshal(msg)
if !assert.NoError(t, err, `json.Marshal should succeed`) {
return
for _, key := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key)
require.Equal(t, expected, v, `values should match`)
}
})

/*
// XXX has global effect!!!
jws.RegisterCustomField(`x-birthday`, time.Time{})
defer jws.RegisterCustomField(`x-birthday`, nil)
var msg2 jws.Message
if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) {
expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC)
bdaybytes, _ := expected.MarshalText() // RFC3339
payload := "Hello, World!"
privkey, err := jwxtest.GenerateRsaJwk()
if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) {
return
}
v = nil
require.NoError(t, msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`)
hdrs := jws.NewHeaders()
hdrs.Set(`x-birthday`, string(bdaybytes))
if !assert.Equal(t, expected, v, `values should match`) {
signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.RS256, privkey, jws.WithProtectedHeaders(hdrs)))
if !assert.NoError(t, err, `jws.Sign should succeed`) {
return
}
})
t.Run("jws.Parse + json.Unmarshal", func(t *testing.T) {
msg, err := jws.Parse(signed)
if !assert.NoError(t, err, `jws.Parse should succeed`) {
return
}
var v interface{}
require.NoError(t, msg.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`)
if !assert.Equal(t, expected, v, `values should match`) {
return
}
// Create JSON from jws.Message
buf, err := json.Marshal(msg)
if !assert.NoError(t, err, `json.Marshal should succeed`) {
return
}
var msg2 jws.Message
if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) {
return
}
v = nil
require.NoError(t, msg2.Signatures()[0].ProtectedHeaders().Get(`x-birthday`, &v), `msg2.Signatures()[0].ProtectedHeaders().Get("x-birthday") should succeed`)
if !assert.Equal(t, expected, v, `values should match`) {
return
}
})
*/
}

func TestWithMessage(t *testing.T) {
Expand Down

0 comments on commit f92b4dd

Please sign in to comment.