forked from auth0/go-jwt-middleware
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path jwtmiddleware_test.go
221 lines (198 loc) · 7.62 KB
/
jwtmiddleware_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
package jwtmiddleware
import (
"encoding/json"
"fmt"
"github.com/codegangsta/negroni"
"github.com/henningda/jwt-go"
"github.com/gorilla/context"
"github.com/gorilla/mux"
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
const (
	// defaultAuthorizationHeaderName is the request header in which the tests
	// send the bearer token.
	defaultAuthorizationHeaderName = "Authorization"

	// envVarClientSecretName names an environment variable intended to hold
	// the client secret. NOTE(review): not referenced in this file as shown.
	envVarClientSecretName = "CLIENT_SECRET_VAR_SHHH"

	// userPropertyName is the request-context key under which the JWT
	// middleware stores the parsed token (read back by protectedHandler).
	userPropertyName = "custom-user-property"
)

// privateKey holds the bytes read from the keys/sample-key file; it stays nil
// until a test or the middleware's key getter loads it. Despite its name, the
// tests use it as the shared HMAC secret for HS256-signed tokens.
// The key was generated with http://kjur.github.io/jsjws/tool_jwt.html.
var privateKey []byte
// TestUnauthenticatedRequest sends requests that carry no Authorization
// header: the public "/" route must still answer 200, while "/protected"
// must be rejected with 401.
func TestUnauthenticatedRequest(t *testing.T) {
	Convey("Simple unauthenticated request", t, func() {
		Convey("Unauthenticated GET to / path should return a 200 reponse", func() {
			rec := makeUnauthenticatedRequest("GET", "/")
			So(rec.Code, ShouldEqual, http.StatusOK)
		})

		Convey("Unauthenticated GET to /protected path should return a 401 reponse", func() {
			So(makeUnauthenticatedRequest("GET", "/protected").Code, ShouldEqual, http.StatusUnauthorized)
		})
	})
}
// TestAuthenticatedRequest sends requests carrying an HS256-signed bearer
// token whose claims are {"foo": "bar"} and checks that:
//   - the public "/" route answers 200;
//   - "/protected" answers 200 with body {"text":"bar"} when the middleware is
//     given no expected algorithm, or is told to expect HS256;
//   - "/protected" answers 401 with an algorithm-mismatch message when the
//     middleware is told to expect RS256.
func TestAuthenticatedRequest(t *testing.T) {
	// Tokens are signed with the package-level privateKey, so it must be
	// loaded before any request is built. Fail this test instead of
	// panicking the whole test binary.
	var err error
	if privateKey, err = readPrivateKey(); err != nil {
		t.Fatalf("reading private key: %v", err)
	}

	Convey("Simple authenticated request", t, func() {
		Convey("Authenticated GET to / path should return a 200 reponse", func() {
			w := makeAuthenticatedRequest("GET", "/", map[string]interface{}{"foo": "bar"}, nil)
			So(w.Code, ShouldEqual, http.StatusOK)
		})

		Convey("Authenticated GET to /protected path should return a 200 reponse if expected algorithm is not specified", func() {
			// Zero value (nil): the middleware is given no expected algorithm.
			var expectedAlgorithm jwt.SigningMethod
			w := makeAuthenticatedRequest("GET", "/protected", map[string]interface{}{"foo": "bar"}, expectedAlgorithm)
			So(w.Code, ShouldEqual, http.StatusOK)
			// protectedHandler echoes the "foo" claim back as JSON.
			So(w.Body.String(), ShouldEqual, `{"text":"bar"}`)
		})

		Convey("Authenticated GET to /protected path should return a 200 reponse if expected algorithm is correct", func() {
			w := makeAuthenticatedRequest("GET", "/protected", map[string]interface{}{"foo": "bar"}, jwt.SigningMethodHS256)
			So(w.Code, ShouldEqual, http.StatusOK)
			// protectedHandler echoes the "foo" claim back as JSON.
			So(w.Body.String(), ShouldEqual, `{"text":"bar"}`)
		})

		Convey("Authenticated GET to /protected path should return a 401 reponse if algorithm is not expected one", func() {
			w := makeAuthenticatedRequest("GET", "/protected", map[string]interface{}{"foo": "bar"}, jwt.SigningMethodRS256)
			So(w.Code, ShouldEqual, http.StatusUnauthorized)
			// The token is HS256-signed while RS256 is expected, so the body
			// must carry the middleware's algorithm-mismatch message.
			So(strings.TrimSpace(w.Body.String()), ShouldEqual, "Expected RS256 signing method but token specified HS256")
		})
	})
}
// makeUnauthenticatedRequest sends a request with no Authorization header
// through the test middleware stack and returns the recorded response.
func makeUnauthenticatedRequest(method string, url string) *httptest.ResponseRecorder {
	// nil claims: no token is attached. nil algorithm: the middleware is
	// given no expected signing method.
	var noClaims map[string]interface{}
	return makeAuthenticatedRequest(method, url, noClaims, nil)
}
// makeAuthenticatedRequest builds a request for method and url, runs it
// through the stack returned by createNegroniMiddleware, and returns the
// recorded response.
//
// When c is non-nil it becomes the claims of an HS256 token signed with
// privateKey, sent as "bearer <token>" in the Authorization header. When c is
// nil no header is set. expectedSignatureAlgorithm is passed through to the
// JWT middleware; the tests use nil for "no expected algorithm".
func makeAuthenticatedRequest(method string, url string, c map[string]interface{}, expectedSignatureAlgorithm jwt.SigningMethod) *httptest.ResponseRecorder {
	r, err := http.NewRequest(method, url, nil)
	if err != nil {
		// Only a malformed method or URL gets here, i.e. a broken test case.
		panic(err)
	}
	if c != nil {
		token := jwt.New(jwt.SigningMethodHS256)
		token.Claims = c
		// HS256 uses privateKey directly as the HMAC secret.
		signed, signErr := token.SignedString(privateKey)
		if signErr != nil {
			panic(signErr)
		}
		r.Header.Set(defaultAuthorizationHeaderName, fmt.Sprintf("bearer %v", signed))
	}
	w := httptest.NewRecorder()
	n := createNegroniMiddleware(expectedSignatureAlgorithm)
	n.ServeHTTP(w, r)
	return w
}
// createNegroniMiddleware assembles the complete handler stack the tests
// drive.
//
// "/" is served by a public router with no authentication. "/protected" (and
// any sub-path) goes through a negroni chain that runs the JWT middleware,
// built with expectedSignatureAlgorithm, before the protected router. Both
// chains are mounted on a main mux router, which is then wrapped in
// negroni.Classic().
func createNegroniMiddleware(expectedSignatureAlgorithm jwt.SigningMethod) *negroni.Negroni {
	// Public routes: no token required.
	public := mux.NewRouter().StrictSlash(true)
	public.Methods("GET").
		Path("/").
		Name("Index").
		Handler(http.HandlerFunc(indexHandler))
	publicChain := negroni.New()
	publicChain.UseHandler(public)

	// Protected routes: the JWT middleware must accept the request before
	// the router sees it.
	protected := mux.NewRouter().StrictSlash(true)
	protected.Methods("GET").
		Path("/protected").
		Name("Protected").
		Handler(http.HandlerFunc(protectedHandler))
	jwtMiddleware := JWT(expectedSignatureAlgorithm)
	protectedChain := negroni.New()
	protectedChain.Use(negroni.HandlerFunc(jwtMiddleware.HandlerWithNext))
	protectedChain.UseHandler(protected)

	// Main router dispatches by path. The {_dummy:.*} pattern is needed so
	// that sub-paths of /protected also reach the protected chain.
	root := mux.NewRouter().StrictSlash(true)
	root.Handle("/", publicChain)
	for _, p := range []string{"/protected", "/protected/{_dummy:.*}"} {
		root.Handle(p, protectedChain)
	}

	// Middleware that should apply to every request belongs on this outer
	// instance, for example:
	//   app.Use(gzip.Gzip(gzip.DefaultCompression))
	//   app.Use(negroni.HandlerFunc(SecurityMiddleware().HandlerFuncWithNext))
	app := negroni.Classic()
	app.UseHandler(root)
	return app
}
// JWT builds the JWTMiddleware under test. Debug is off, credentials are
// required, and the parsed token is stored under userPropertyName.
// expectedSignatureAlgorithm is forwarded as Options.SigningMethod.
//
// Tokens are validated against privateKey. If it has not been loaded yet, it
// is read lazily via readPrivateKey. A read failure is returned through the
// key getter's error result rather than panicking inside the request path.
func JWT(expectedSignatureAlgorithm jwt.SigningMethod) *JWTMiddleware {
	return New(Options{
		Debug:               false,
		CredentialsOptional: false,
		UserProperty:        userPropertyName,
		ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
			if privateKey == nil {
				key, err := readPrivateKey()
				if err != nil {
					return nil, err
				}
				privateKey = key
			}
			return privateKey, nil
		},
		SigningMethod: expectedSignatureAlgorithm,
	})
}
// readPrivateKey reads and returns the contents of keys/sample-key, resolved
// relative to the current working directory. It does NOT modify the
// package-level privateKey; callers assign the result themselves. On error
// the returned data is nil.
func readPrivateKey() ([]byte, error) {
	return ioutil.ReadFile("keys/sample-key")
}
// indexHandler serves the public "/" route: it writes a bare 200 OK status
// and no body. The request is not inspected.
func indexHandler(rw http.ResponseWriter, _ *http.Request) {
	const ok = http.StatusOK
	rw.WriteHeader(ok)
}
// protectedHandler serves the "/protected" route. It retrieves the *jwt.Token
// that the JWT middleware stored under userPropertyName (via gorilla/context)
// and responds with the token's "foo" claim as JSON, e.g. {"text":"bar"}.
//
// If the context value is missing or not a *jwt.Token, or the "foo" claim is
// missing or not a string, it replies 500 with a short message instead of
// panicking on an unchecked type assertion.
func protectedHandler(w http.ResponseWriter, r *http.Request) {
	token, ok := context.Get(r, userPropertyName).(*jwt.Token)
	if !ok || token == nil {
		http.Error(w, "no JWT token in request context", http.StatusInternalServerError)
		return
	}
	foo, ok := token.Claims["foo"].(string)
	if !ok {
		http.Error(w, `JWT claim "foo" missing or not a string`, http.StatusInternalServerError)
		return
	}
	respondJson(foo, w)
}
// Response is the minimal JSON payload written by respondJson. Its single
// field is serialized under the "text" key.
type Response struct {
	Text string `json:"text"` // encoded as "text"
}
// respondJson writes text to w as the JSON object {"text":"<text>"} with a
// Content-Type of application/json. If encoding fails, it replies 500 with
// the encoder's error message instead.
func respondJson(text string, w http.ResponseWriter) {
	body, err := json.Marshal(Response{Text: text})
	if err == nil {
		w.Header().Set("Content-Type", "application/json")
		// A write error here has no better recovery than the client seeing a
		// truncated response, so it is ignored.
		w.Write(body)
		return
	}
	http.Error(w, err.Error(), http.StatusInternalServerError)
}