-
Notifications
You must be signed in to change notification settings - Fork 0
/
token_parser.go
211 lines (187 loc) · 5.86 KB
/
token_parser.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
package keycloak
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"sync"
"time"
"github.com/dgrijalva/jwt-go/v4"
)
// TokenParser resolves the key used to verify a JWT for a realm. Implementations
// advertise the JWT "alg" header values they handle through SupportedAlgorithms and
// are looked up by algorithm via GetTokenParser.
type TokenParser interface {
	// Parse must attempt to validate that token was signed using the mechanism
	// expected by the realm's issuer, returning the key to verify it with.
	Parse(ctx context.Context, client *APIClient, token *jwt.Token) (pk interface{}, err error)

	// SupportedAlgorithms lists the JWT "alg" header values this parser accepts.
	SupportedAlgorithms() []string
}
// Registry of TokenParsers keyed by JWT "alg" header value. tokenParsersMu guards
// tokenParsers.
//
// The map is allocated at declaration (package variables are initialized before any
// init function runs), so RegisterTokenParsers works no matter which file's init runs
// first. Earlier registrations are never discarded by re-allocating the map.
var (
	tokenParsersMu sync.RWMutex
	tokenParsers   = make(map[string]TokenParser)
)

// init registers the default X509TokenParser with a one-hour cache TTL.
func init() {
	RegisterTokenParsers(NewX509TokenParser(time.Hour))
}
// RegisterTokenParsers makes each parser retrievable through GetTokenParser under
// every algorithm name it reports from SupportedAlgorithms. When two parsers claim
// the same algorithm, the one registered last wins. The registry write lock is held
// for the whole call.
func RegisterTokenParsers(parsers ...TokenParser) {
	tokenParsersMu.Lock()
	defer tokenParsersMu.Unlock()

	for i := range parsers {
		current := parsers[i]
		algs := current.SupportedAlgorithms()
		for j := range algs {
			tokenParsers[algs[j]] = current
		}
	}
}
// GetTokenParser returns the TokenParser registered for the JWT "alg" value alg and
// reports whether one was found.
//
// The read lock is held until the map lookup has completed, so concurrent calls to
// RegisterTokenParsers do not race with the read.
func GetTokenParser(alg string) (TokenParser, bool) {
	tokenParsersMu.RLock()
	defer tokenParsersMu.RUnlock()

	parser, ok := tokenParsers[alg]
	return parser, ok
}
// X509TokenParser is a TokenParser for RSA (RS*, PS*) and ECDSA (ES*) signed tokens.
// It locates public keys by the token's "kid" header and caches them in the
// APIClient's cache backend.
type X509TokenParser struct {
	// mu guards the cache lookup / fetch / store sequence in Parse: lookups take the
	// read lock, cache misses take the write lock.
	mu sync.RWMutex
	// dttl is how long a key obtained from the legacy realm endpoint stays cached.
	// Keys read from a JWKS certificate are cached until the certificate's NotAfter.
	dttl time.Duration
}

// NewX509TokenParser returns a token parser capable of handling most RSA & ECDSA
// signed tokens and keys. cacheTTL becomes the cache lifetime of keys fetched from
// the legacy realm endpoint.
func NewX509TokenParser(cacheTTL time.Duration) *X509TokenParser {
	return &X509TokenParser{dttl: cacheTTL}
}
// Parse returns the public key that should be used to verify token's signature.
//
// The key is identified by the token's "kid" header and the realm extracted from the
// raw token by TokenSource. It is read from client.CacheBackend() and, on a miss,
// fetched with fetchPK and stored until the expiry fetchPK reports.
//
// The result is an *rsa.PublicKey for RSA / RSA-PSS signing methods or an
// *ecdsa.PublicKey for ECDSA signing methods. An error is returned when the token is
// nil or has no signing method, its algorithm is not in SupportedAlgorithms, its
// "kid" header is missing or not a string, TokenSource fails, the key cannot be
// loaded, or the key's type does not match the token's signing method.
func (tp *X509TokenParser) Parse(ctx context.Context, client *APIClient, token *jwt.Token) (interface{}, error) {
	if token == nil {
		return nil, errors.New("token is nil")
	}
	if token.Method == nil {
		return nil, errors.New("token has no signing method")
	}
	alg := token.Method.Alg()
	if !tp.supports(alg) {
		// No key has been loaded yet, so report the algorithm rather than a key type.
		return nil, fmt.Errorf("token signing alg %q is not supported", alg)
	}

	rawKID, found := token.Header["kid"]
	if !found {
		return nil, errors.New("unable to locate \"kid\" field in token header")
	}
	kid, isString := rawKID.(string)
	if !isString {
		return nil, fmt.Errorf("token header key \"kid\" has non-string value: %v (%[1]T)", rawKID)
	}

	_, realmName, err := TokenSource(token.Raw)
	if err != nil {
		return nil, err
	}

	cacheKey := buildPKCacheKey(client.AuthServerURL(), realmName, kid)
	pub, err := tp.loadPK(ctx, client, cacheKey, realmName, kid)
	if err != nil {
		return nil, err
	}

	// Only hand back keys whose type matches the token's signing method family.
	switch key := pub.(type) {
	case *rsa.PublicKey:
		switch token.Method.(type) {
		case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
			return key, nil
		}
	case *ecdsa.PublicKey:
		if _, isECDSA := token.Method.(*jwt.SigningMethodECDSA); isECDSA {
			return key, nil
		}
	}
	return nil, fmt.Errorf("cannot validate token with alg %q against public key of type %T", alg, pub)
}

// loadPK returns the public key cached under cacheKey, fetching it with fetchPK and
// caching it until its reported expiry on a miss.
//
// Lookups share tp.mu's read lock. A miss takes the write lock and checks the cache
// again, so a goroutine that waited behind another goroutine's fetch reuses the stored
// key instead of fetching it again (assuming the cache backend returns what
// StoreUntil stored).
func (tp *X509TokenParser) loadPK(ctx context.Context, client *APIClient, cacheKey, realmName, kid string) (interface{}, error) {
	var (
		pub interface{}
		ok  bool
	)

	tp.mu.RLock()
	pub, ok = client.CacheBackend().Load(cacheKey)
	tp.mu.RUnlock()
	if ok {
		return pub, nil
	}

	tp.mu.Lock()
	defer tp.mu.Unlock()
	if pub, ok = client.CacheBackend().Load(cacheKey); ok {
		return pub, nil
	}

	fetched, expires, err := tp.fetchPK(ctx, client, realmName, kid)
	if err != nil {
		return nil, fmt.Errorf("error loading public key: %w", err)
	}
	client.CacheBackend().StoreUntil(cacheKey, fetched, *expires)
	return fetched, nil
}
// SupportedAlgorithms reports the JWT "alg" values this parser can verify. A fresh
// slice is returned on every call, so callers may modify it freely.
func (*X509TokenParser) SupportedAlgorithms() []string {
	return []string{
		"RS256", "RS384", "RS512", // RSASSA-PKCS1-v1_5
		"PS256", "PS384", "PS512", // RSASSA-PSS
		"ES256", "ES384", "ES512", // ECDSA
	}
}

// supports reports whether alg is one of the names returned by SupportedAlgorithms.
// The comparison is exact and case-sensitive.
func (tp *X509TokenParser) supports(alg string) bool {
	known := tp.SupportedAlgorithms()
	for i := range known {
		if known[i] == alg {
			return true
		}
	}
	return false
}
func (tp *X509TokenParser) fetchPKLegacy(ctx context.Context, client *APIClient, realmName string) (interface{}, *time.Time, error) {
var (
conf *RealmIssuerConfiguration
b []byte
pub interface{}
err error
)
if conf, err = client.RealmIssuerConfiguration(ctx, realmName); err != nil {
return nil, nil, fmt.Errorf("error attempting to fetch public key from legacy realm info endpoint: %w", err)
}
if b, err = base64.StdEncoding.DecodeString(conf.PublicKey); err != nil {
return nil, nil, fmt.Errorf("error decoding public key from legacy realm info endpoint: %w", err)
}
if pub, err = x509.ParsePKIXPublicKey(b); err != nil {
return nil, nil, fmt.Errorf("error parsing public key from legacy realm info endpoint: %w", err)
}
exp := time.Now().Add(tp.dttl)
return pub, &exp, nil
}
func (tp *X509TokenParser) fetchPKByID(ctx context.Context, client *APIClient, realmName, kid string) (interface{}, *time.Time, error) {
var (
b []byte
jwks *JSONWebKeySet
jwk *JSONWebKey
cert *x509.Certificate
err error
)
if jwks, err = client.JSONWebKeys(ctx, realmName); err != nil {
return nil, nil, fmt.Errorf("error fetching json web keys: %w", err)
}
if jwk = jwks.KeychainByID(kid); jwk == nil {
return nil, nil, fmt.Errorf("issuer %q realm %q has no key with id %q", client.AuthServerURL(), realmName, kid)
}
// todo: use full chain
if len(jwk.X509CertificateChain) == 0 {
return nil, nil, errors.New("no certificates returned from json web keys endpoint")
}
if b, err = base64.StdEncoding.DecodeString(jwk.X509CertificateChain[0]); err != nil {
return nil, nil, fmt.Errorf("error decoding certificate %q: %w", kid, err)
}
if cert, err = x509.ParseCertificate(b); err != nil {
return nil, nil, fmt.Errorf("error parsing certificate %q: %w", kid, err)
}
return cert.PublicKey, &cert.NotAfter, nil
}
func (tp *X509TokenParser) fetchPK(ctx context.Context, client *APIClient, realmName, keyID string) (interface{}, *time.Time, error) {
env, err := client.RealmEnvironment(ctx, realmName)
if err != nil {
return nil, nil, err
}
if env.SupportsUMA2() {
if pk, deadline, err := tp.fetchPKByID(ctx, client, realmName, keyID); err == nil {
return pk, deadline, nil
}
}
return tp.fetchPKLegacy(ctx, client, realmName)
}