Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow registering a constructor for RegisterCustomField #1006

Merged
merged 20 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/jwt_get_claims_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"time"

"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
)

Expand All @@ -15,6 +16,7 @@ func ExampleJWT_GetClaims() {
Subject(`example`).
Claim(`claim1`, `value1`).
Claim(`claim2`, `2022-05-16T07:35:56+00:00`).
Claim(`claim3`, `{"kty": "oct", "alg":"A128KW", "k":"GawgguFyGrWKav7AX4VKUg"}`).
Build()
if err != nil {
fmt.Printf("failed to build token: %s\n", err)
Expand Down Expand Up @@ -43,6 +45,7 @@ func ExampleJWT_GetClaims() {
var dummy interface{}
_ = tok.Get(`claim1`, &dummy)
_ = tok.Get(`claim2`, &dummy)
_ = tok.Get(`claim3`, &dummy)

// However, it is possible to globally specify that a private
// claim should be parsed into a custom type.
Expand All @@ -62,5 +65,24 @@ func ExampleJWT_GetClaims() {
return
}

// It's also possible to specify a custom decoder for a private claim.
// For example, in the case of `claim3`, it needs to call `jwk.ParseKey`
// which returns an interface that can't be instantiated like the
// `time.Time` value for `claim2`.
jwt.RegisterCustomField(`claim3`, jwt.CustomDecodeFunc(func(data []byte) (interface{}, error) {
return jwk.ParseKey(data)
}))

tok = jwt.New()
if err := json.Unmarshal([]byte(`{"claim3": {"kty": "oct", "alg":"A128KW", "k":"GawgguFyGrWKav7AX4VKUg"}}`), tok); err != nil {
fmt.Printf(`failed to parse token: %s`, err)
return
}
var claim3 jwk.Key
if err := tok.Get(`claim3`, &claim3); err != nil {
fmt.Printf("failed to get private claim \"claim3\": %s\n", err)
return
}

// OUTPUT:
}
56 changes: 47 additions & 9 deletions internal/json/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,80 @@ import (
"sync"
)

// CustomDecoder is the interface we expect from RegisterCustomField in jws, jwe, jwk, and jwt packages.
type CustomDecoder interface {
// Decode takes a JSON encoded byte slice and returns the desired
// decoded value,which will be used as the value for that field
// registered through RegisterCustomField
Decode([]byte) (interface{}, error)
}

// CustomDecodeFunc is a stateless, function-based implementation of CustomDecoder
type CustomDecodeFunc func([]byte) (interface{}, error)

func (fn CustomDecodeFunc) Decode(data []byte) (interface{}, error) {
return fn(data)
}

type objectTypeDecoder struct {
typ reflect.Type
name string
}

func (dec *objectTypeDecoder) Decode(data []byte) (interface{}, error) {
ptr := reflect.New(dec.typ).Interface()
if err := Unmarshal(data, ptr); err != nil {
return nil, fmt.Errorf(`failed to decode field %s: %w`, dec.name, err)
}
return reflect.ValueOf(ptr).Elem().Interface(), nil
}

type Registry struct {
mu *sync.RWMutex
data map[string]reflect.Type
ctrs map[string]CustomDecoder
}

func NewRegistry() *Registry {
return &Registry{
mu: &sync.RWMutex{},
data: make(map[string]reflect.Type),
ctrs: make(map[string]CustomDecoder),
}
}

func (r *Registry) Register(name string, object interface{}) {
if object == nil {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.data, name)
delete(r.ctrs, name)
return
}

typ := reflect.TypeOf(object)
r.mu.Lock()
defer r.mu.Unlock()
r.data[name] = typ
if ctr, ok := object.(CustomDecoder); ok {
r.ctrs[name] = ctr
} else {
r.ctrs[name] = &objectTypeDecoder{
typ: reflect.TypeOf(object),
name: name,
}
}
}

func (r *Registry) Decode(dec *Decoder, name string) (interface{}, error) {
r.mu.RLock()
defer r.mu.RUnlock()

if typ, ok := r.data[name]; ok {
ptr := reflect.New(typ).Interface()
if err := dec.Decode(ptr); err != nil {
if ctr, ok := r.ctrs[name]; ok {
var raw RawMessage
if err := dec.Decode(&raw); err != nil {
return nil, fmt.Errorf(`failed to decode field %s: %w`, name, err)
}
v, err := ctr.Decode([]byte(raw))
if err != nil {
return nil, fmt.Errorf(`failed to decode field %s: %w`, name, err)
}
return reflect.ValueOf(ptr).Elem().Interface(), nil
return v, nil
}

var decoded interface{}
Expand Down
21 changes: 21 additions & 0 deletions jwe/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,9 @@ func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
return m, nil
}

type CustomDecoder = json.CustomDecoder
type CustomDecodeFunc = json.CustomDecodeFunc

// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. This option has
// a global effect.
Expand All @@ -803,6 +806,24 @@ func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
//
// var bday time.Time
// _ = hdr.Get(`x-birthday`, &bday)
//
// If you need a more fine-tuned control over the decoding process,
// you can register a `CustomDecoder`. For example, below shows
// how to register a decoder that can parse RFC1123 format string:
//
// jwe.RegisterCustomField(`x-birthday`, jwe.CustomDecodeFunc(func(data []byte) (interface{}, error) {
// return time.Parse(time.RFC1123, string(data))
// }))
//
// Please note that use of custom fields can be problematic if you
// are using a library that does not implement MarshalJSON/UnmarshalJSON
// and you try to roundtrip from an object to JSON, and then back to an object.
// For example, in the above example, you can _parse_ time values formatted
// in the format specified in RFC822, but when you convert an object into
// JSON, it will be formatted in RFC3339, because that's what `time.Time`
// likes to do. To avoid this, it's always better to use a custom type
// that wraps your desired type (in this case `time.Time`) and implement
// MarshalJSON and UnmashalJSON.
func RegisterCustomField(name string, object interface{}) {
registry.Register(name, object)
}
77 changes: 43 additions & 34 deletions jwe/jwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,57 +689,66 @@ func TestReadFile(t *testing.T) {

func TestCustomField(t *testing.T) {
// XXX has global effect!!!
jwe.RegisterCustomField(`x-birthday`, time.Time{})
defer jwe.RegisterCustomField(`x-birthday`, nil)
const rfc3339Key = `x-test-rfc3339`
const rfc1123Key = `x-test-rfc1123`
jwe.RegisterCustomField(rfc3339Key, time.Time{})
jwe.RegisterCustomField(rfc1123Key, jwe.CustomDecodeFunc(func(data []byte) (interface{}, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return time.Parse(time.RFC1123, s)
}))

defer jwe.RegisterCustomField(rfc3339Key, nil)
defer jwe.RegisterCustomField(rfc1123Key, nil)

expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC)
bdaybytes, _ := expected.MarshalText() // RFC3339
rfc3339bytes, _ := expected.MarshalText() // RFC3339
rfc1123bytes := expected.Format(time.RFC1123)

plaintext := []byte("Hello, World!")
rsakey, err := jwxtest.GenerateRsaJwk()
if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) {
return
}
pubkey, err := jwk.PublicKeyOf(rsakey)
if !assert.NoError(t, err, `jwk.PublicKeyOf() should succeed`) {
return
}
require.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`)

protected := jwe.NewHeaders()
protected.Set(`x-birthday`, string(bdaybytes))
pubkey, err := jwk.PublicKeyOf(rsakey)
require.NoError(t, err, `jwk.PublicKeyOf() should succeed`)

encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected))
if !assert.NoError(t, err, `jwe.Encrypt should succeed`) {
return
}
t.Run("jwe.Parse", func(t *testing.T) {
protected := jwe.NewHeaders()
protected.Set(rfc3339Key, string(rfc3339bytes))
protected.Set(rfc1123Key, rfc1123bytes)

t.Run("jwe.Parse + json.Unmarshal", func(t *testing.T) {
encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected))
require.NoError(t, err, `jwe.Encrypt should succeed`)
msg, err := jwe.Parse(encrypted)
if !assert.NoError(t, err, `jwe.Parse should succeed`) {
t.Logf("%q", encrypted)
return
}

var v time.Time
require.NoError(t, msg.ProtectedHeaders().Get(`x-birthday`, &v), `msg.ProtectedHeaders().Get("x-birthday") should succeed`)
if !assert.Equal(t, expected, v, `values should match`) {
return
}

// Create JSON from jwe.Message
buf, err := json.Marshal(msg)
if !assert.NoError(t, err, `json.Marshal should succeed`) {
return
for _, key := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, msg.ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key)
require.Equal(t, expected, v, `values should match`)
}

var msg2 jwe.Message
if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) {
})
t.Run("json.Unmarshal", func(t *testing.T) {
protected := jwe.NewHeaders()
protected.Set(rfc3339Key, string(rfc3339bytes))
protected.Set(rfc1123Key, rfc1123bytes)

encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected), jwe.WithJSON())
require.NoError(t, err, `jwe.Encrypt should succeed`)
msg := jwe.NewMessage()
if !assert.NoError(t, json.Unmarshal(encrypted, msg), `json.Unmarshal should succeed`) {
return
}

v = time.Time{} // reset
require.NoError(t, msg2.ProtectedHeaders().Get(`x-birthday`, &v), `msg2.ProtectedHeaders().Get("x-birthday") should succeed`)
if !assert.Equal(t, expected, v, `values should match`) {
return
for _, key := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, msg.ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key)
require.Equal(t, expected, v, `values should match`)
}
})
}
Expand Down
21 changes: 21 additions & 0 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,9 @@ func asnEncode(key Key) (string, []byte, error) {
}
}

type CustomDecoder = json.CustomDecoder
type CustomDecodeFunc = json.CustomDecodeFunc

// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. This option has
// a global effect.
Expand All @@ -752,6 +755,24 @@ func asnEncode(key Key) (string, []byte, error) {
//
// var bday time.Time
// _ = key.Get(`x-birthday`, &bday)
//
// If you need a more fine-tuned control over the decoding process,
// you can register a `CustomDecoder`. For example, below shows
// how to register a decoder that can parse RFC1123 format string:
//
// jwk.RegisterCustomField(`x-birthday`, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) {
// return time.Parse(time.RFC1123, string(data))
// }))
//
// Please note that use of custom fields can be problematic if you
// are using a library that does not implement MarshalJSON/UnmarshalJSON
// and you try to roundtrip from an object to JSON, and then back to an object.
// For example, in the above example, you can _parse_ time values formatted
// in the format specified in RFC822, but when you convert an object into
// JSON, it will be formatted in RFC3339, because that's what `time.Time`
// likes to do. To avoid this, it's always better to use a custom type
// that wraps your desired type (in this case `time.Time`) and implement
// MarshalJSON and UnmashalJSON.
func RegisterCustomField(name string, object interface{}) {
registry.Register(name, object)
}
Expand Down
37 changes: 27 additions & 10 deletions jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1393,16 +1393,34 @@ func TestOKP(t *testing.T) {
}

func TestCustomField(t *testing.T) {
const rfc3339Key = `x-rfc3339-key`
const rfc1123Key = `x-rfc1123-key`

// XXX has global effect!!!
jwk.RegisterCustomField(`x-birthday`, time.Time{})
defer jwk.RegisterCustomField(`x-birthday`, nil)
jwk.RegisterCustomField(rfc3339Key, time.Time{})
jwk.RegisterCustomField(rfc1123Key, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return time.Parse(time.RFC1123, s)
}))
defer jwk.RegisterCustomField(rfc3339Key, nil)
defer jwk.RegisterCustomField(rfc1123Key, nil)

expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC)
bdaybytes, _ := expected.MarshalText() // RFC3339
rfc3339bytes, _ := expected.MarshalText() // RFC3339
rfc1123bytes := expected.Format(time.RFC1123)

var b strings.Builder
b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","x-birthday":"`)
b.Write(bdaybytes)
b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","`)
b.WriteString(rfc3339Key)
b.WriteString(`":"`)
b.Write(rfc3339bytes)
b.WriteString(`","`)
b.WriteString(rfc1123Key)
b.WriteString(`":"`)
b.WriteString(rfc1123bytes)
b.WriteString(`"}`)
src := b.String()

Expand All @@ -1412,11 +1430,10 @@ func TestCustomField(t *testing.T) {
return
}

var v interface{}
require.NoError(t, key.Get(`x-birthday`, &v), `key.Get("x-birthday") should succeed`)

if !assert.Equal(t, expected, v, `values should match`) {
return
for _, name := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, key.Get(name, &v), `key.Get(%q) should succeed`, name)
require.Equal(t, expected, v, `values should match`)
}
})
}
Expand Down
Loading
Loading