-
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
cors.go
320 lines (278 loc) · 10.9 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
package cors
import (
"errors"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/kataras/iris/v12/context"
)
// init maps handler names matching the "iris/middleware/cors.*" pattern to the
// short name "iris.cors" through context.SetHandlerName (presumably used when
// iris prints or logs registered handlers — confirm against the iris docs).
func init() {
	context.SetHandlerName("iris/middleware/cors.*", "iris.cors")
}
// ErrOriginNotAllowed is handed to the error handler whenever a request is
// rejected because its origin could not be extracted or was not allowed.
var ErrOriginNotAllowed = errors.New("origin not allowed")

// AllowAnyOrigin is an AllowOriginFunc that lets every origin through.
var AllowAnyOrigin = func(_ *context.Context, _ string) bool {
	return true
}

// DefaultErrorHandler is the default HandleErrorFunc: it ignores the error
// value and stops the request with a 403 Forbidden status.
var DefaultErrorHandler = func(ctx *context.Context, _ error) {
	ctx.StopWithStatus(http.StatusForbidden)
}

// DefaultOriginExtractor is the default ExtractOriginFunc. It returns the
// value of the request's "Origin" header and always reports true, so a
// request without that header still reaches the allow-origin check with an
// empty origin.
var DefaultOriginExtractor = func(ctx *context.Context) (string, bool) {
	return ctx.GetHeader(originRequestHeader), true
}

// StrictOriginExtractor is an ExtractOriginFunc that is stricter than
// DefaultOriginExtractor: it reports false for a missing or empty "Origin"
// header, which makes the middleware reject the request instead of executing
// the next handler(s).
var StrictOriginExtractor = func(ctx *context.Context) (string, bool) {
	origin := ctx.GetHeader(originRequestHeader)
	if origin == "" {
		return origin, false
	}
	return origin, true
}
// ExtractOriginFunc returns the request's origin, or false when the request
// must be rejected before its origin is even checked.
type ExtractOriginFunc = func(ctx *context.Context) (string, bool)

// AllowOriginFunc decides whether the request's origin may pass through the
// middleware.
type AllowOriginFunc = func(ctx *context.Context, origin string) bool

// HandleErrorFunc is invoked when a request with a specific (or empty) origin
// is rejected by the middleware.
type HandleErrorFunc = func(ctx *context.Context, err error)

// CORS holds the configuration of the CORS middleware. Create one with New,
// customize it through its chainable methods and obtain the middleware with
// Handler.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS.
type CORS struct {
	// Hooks called for every request.
	extractOriginFunc ExtractOriginFunc
	allowOriginFunc   AllowOriginFunc
	errorHandler      HandleErrorFunc

	// Response header values, stored as ready-to-write strings.
	allowCredentialsValue string
	exposeHeadersValue    string
	allowHeadersValue     string
	allowMethodsValue     string
	maxAgeSecondsValue    string
	referrerPolicyValue   string
}
// New returns a CORS middleware builder with the default configuration:
// DefaultOriginExtractor, AllowAnyOrigin and DefaultErrorHandler as the
// per-request hooks, credentials allowed ("true"), exposed headers
// "*, Authorization, X-Authorization", allowed headers "*", a max age of
// "86400" seconds and the NoReferrerWhenDowngrade referrer policy.
//
// For a more advanced type of protection middleware with more options
// please refer to: https://github.com/iris-contrib/middleware repository instead.
//
// Example Code:
//
//	import "github.com/kataras/iris/v12/middleware/cors"
//	import "github.com/kataras/iris/v12/x/errors"
//
//	app.UseRouter(cors.New().
//		HandleErrorFunc(func(ctx iris.Context, err error) {
//			errors.FailedPrecondition.Err(ctx, err)
//		}).
//		ExtractOriginFunc(cors.StrictOriginExtractor).
//		ReferrerPolicy(cors.NoReferrerWhenDowngrade).
//		AllowOrigin("domain1.com,domain2.com,domain3.com").
//		Handler())
func New() *CORS {
	c := new(CORS)

	// Per-request hooks.
	c.extractOriginFunc = DefaultOriginExtractor
	c.allowOriginFunc = AllowAnyOrigin
	c.errorHandler = DefaultErrorHandler

	// Response header values.
	c.allowCredentialsValue = "true"
	c.exposeHeadersValue = "*, Authorization, X-Authorization"
	c.allowHeadersValue = "*"
	// allowMethodsValue has no public setter on purpose: the HTTP verbs are
	// controlled per handler by other means.
	c.allowMethodsValue = "*"
	c.maxAgeSecondsValue = "86400"
	c.referrerPolicyValue = NoReferrerWhenDowngrade.String()

	return c
}
// ExtractOriginFunc replaces the function that reads the request's origin.
// When that function reports false, the request is rejected through the
// configured error handler. The receiver is returned for chaining.
func (c *CORS) ExtractOriginFunc(fn ExtractOriginFunc) *CORS {
	c.extractOriginFunc = fn
	return c
}
// AllowOriginFunc replaces the function that decides, for every request,
// whether the extracted origin (domain) may continue. A false result makes
// the middleware call the error handler with ErrOriginNotAllowed.
// The receiver is returned for chaining.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-origin.
func (c *CORS) AllowOriginFunc(fn AllowOriginFunc) *CORS {
	c.allowOriginFunc = fn
	return c
}
// AllowOrigin is a convenience wrapper around AllowOrigins: it splits
// "originLine" on commas (e.g. "domain1.com,domain2.com") and registers the
// resulting origins as the only allowed ones. The pieces are passed through
// untouched; use AllowOrigins directly to supply a slice of strings.
func (c *CORS) AllowOrigin(originLine string) *CORS {
	origins := strings.Split(originLine, ",")
	return c.AllowOrigins(origins...)
}
// AllowOriginMatcherFunc is like AllowOriginFunc but takes a matcher that
// only sees the origin string (no iris.Context), e.g. a regular expression's
// MatchString method.
func (c *CORS) AllowOriginMatcherFunc(fn func(origin string) bool) *CORS {
	adapter := func(_ *context.Context, origin string) bool {
		return fn(origin)
	}
	return c.AllowOriginFunc(adapter)
}
// AllowOriginRegex registers, via AllowOriginFunc, a check that accepts a
// request when its origin matches at least one of "regexpLines".
//
// Each pattern is compiled with regexp.MustCompile, so an invalid pattern
// panics here, at configuration time. Matching uses MatchString, which is
// not anchored: anchor patterns with ^ and $ to avoid accepting origins that
// merely contain an allowed domain. Calling it with no patterns rejects
// every origin.
func (c *CORS) AllowOriginRegex(regexpLines ...string) *CORS {
	patterns := make([]*regexp.Regexp, len(regexpLines))
	for i, line := range regexpLines {
		patterns[i] = regexp.MustCompile(line)
	}

	return c.AllowOriginFunc(func(_ *context.Context, origin string) bool {
		for _, re := range patterns {
			if re.MatchString(origin) {
				return true
			}
		}
		return false
	})
}
// AllowOrigins registers, via AllowOriginFunc, a check that accepts a request
// only when its origin exactly (case-sensitively) equals one of "origins".
// Leading and trailing whitespace of every entry is ignored.
//
// If any entry is the wildcard "*" (after trimming), every origin is allowed,
// exactly as with AllowOriginFunc(AllowAnyOrigin); prefer that call directly.
//
// Note that an empty entry (for example, produced by a trailing comma passed
// to AllowOrigin) adds the empty origin to the set, so requests whose
// extracted origin is empty are accepted as well.
func (c *CORS) AllowOrigins(origins ...string) *CORS {
	allowOrigins := make(map[string]struct{}, len(origins)) // read-only at serve time.
	for _, origin := range origins {
		// Trim first so " *" (e.g. from "a.com, *") is recognized as the wildcard.
		origin = strings.TrimSpace(origin)
		if origin == "*" {
			// A wildcard is a misuse of this method (AllowAnyOrigin should be
			// used instead); rather than panicking, allow every origin.
			return c.AllowOriginFunc(AllowAnyOrigin)
		}
		allowOrigins[origin] = struct{}{}
	}

	return c.AllowOriginFunc(func(_ *context.Context, origin string) bool {
		_, allow := allowOrigins[origin]
		return allow
	})
}
// HandleErrorFunc replaces the function the middleware calls with
// ErrOriginNotAllowed when a request's origin cannot be extracted or is not
// allowed. The receiver is returned for chaining.
func (c *CORS) HandleErrorFunc(fn HandleErrorFunc) *CORS {
	c.errorHandler = fn
	return c
}
// DisallowCredentials switches the "Access-Control-Allow-Credentials"
// response header value from its default "true" to "false".
// The receiver is returned for chaining.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-credentials.
func (c *CORS) DisallowCredentials() *CORS {
	const disallowed = "false"
	c.allowCredentialsValue = disallowed
	return c
}
// ExposeHeaders sets the "Access-Control-Expose-Headers" response header
// value to the given header names joined by ", ". Calling it with no
// arguments stores an empty value. The receiver is returned for chaining.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-expose-headers.
func (c *CORS) ExposeHeaders(headers ...string) *CORS {
	joined := strings.Join(headers, ", ")
	c.exposeHeadersValue = joined
	return c
}
// AllowHeaders sets the "Access-Control-Allow-Headers" response header value
// (written by Handler on OPTIONS requests) to the given header names joined
// by ", ". Calling it with no arguments stores an empty value.
// The receiver is returned for chaining.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-headers.
func (c *CORS) AllowHeaders(headers ...string) *CORS {
	joined := strings.Join(headers, ", ")
	c.allowHeadersValue = joined
	return c
}
// ReferrerPolicy is a value for the "Referrer-Policy" response header.
type ReferrerPolicy string

// Every referrer policy value defined for the "Referrer-Policy" header.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy.
const (
	NoReferrer                  ReferrerPolicy = "no-referrer"
	NoReferrerWhenDowngrade     ReferrerPolicy = "no-referrer-when-downgrade"
	Origin                      ReferrerPolicy = "origin"
	OriginWhenCrossOrigin       ReferrerPolicy = "origin-when-cross-origin"
	SameOrigin                  ReferrerPolicy = "same-origin"
	StrictOrigin                ReferrerPolicy = "strict-origin"
	StrictOriginWhenCrossOrigin ReferrerPolicy = "strict-origin-when-cross-origin"
	UnsafeURL                   ReferrerPolicy = "unsafe-url"
)

// String returns the policy as the plain text written into the header.
func (r ReferrerPolicy) String() string {
	text := string(r)
	return text
}
// ReferrerPolicy sets the "Referrer-Policy" response header value written by
// Handler. New initializes it to NoReferrerWhenDowngrade
// ("no-referrer-when-downgrade"). The receiver is returned for chaining.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
// and https://developer.mozilla.org/en-US/docs/Web/Security/Referer_header:_privacy_and_security_concerns.
func (c *CORS) ReferrerPolicy(referrerPolicy ReferrerPolicy) *CORS {
	value := referrerPolicy.String()
	c.referrerPolicyValue = value
	return c
}
// MaxAge sets the "Access-Control-Max-Age" response header value, written by
// Handler on OPTIONS requests. The receiver is returned for chaining.
//
// The header carries an integer number of seconds, so "d" is truncated to
// whole seconds (e.g. 1500ms becomes "1") and negative durations are clamped
// to "0". A decimal integer is always produced: formatting the float with
// exponent notation (e.g. "8.64E+04" for 24h) is not a valid header value.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-max-age.
func (c *CORS) MaxAge(d time.Duration) *CORS {
	seconds := int64(d / time.Second)
	if seconds < 0 {
		seconds = 0
	}
	c.maxAgeSecondsValue = strconv.FormatInt(seconds, 10)
	return c
}
// Request header names: the origin is read from originRequestHeader, and all
// three are listed as "Vary" response header values.
const (
	originRequestHeader  = "Origin"
	requestMethodHeader  = "Access-Control-Request-Method"
	requestHeadersHeader = "Access-Control-Request-Headers"
)

// Response header names set by the middleware.
const (
	allowOriginHeader      = "Access-Control-Allow-Origin"
	allowCredentialsHeader = "Access-Control-Allow-Credentials"
	allowMethodsHeader     = "Access-Control-Allow-Methods"
	allowHeadersHeader     = "Access-Control-Allow-Headers"
	exposeHeadersHeader    = "Access-Control-Expose-Headers"
	maxAgeHeader           = "Access-Control-Max-Age"
	referrerPolicyHeader   = "Referrer-Policy"
	varyHeader             = "Vary"
)

// allowAllMethodsValue is the wildcard value for the
// "Access-Control-Allow-Methods" header.
const allowAllMethodsValue = "*"
// addVaryHeaders writes "Vary" response header values: "Origin" for every
// request and, for OPTIONS requests, also "Access-Control-Request-Method"
// and "Access-Control-Request-Headers".
// NOTE(review): this relies on ctx.Header adding a value rather than
// replacing the previous one — confirm against iris' Context.Header.
func (c *CORS) addVaryHeaders(ctx *context.Context) {
	ctx.Header(varyHeader, originRequestHeader)
	if ctx.Method() != http.MethodOptions {
		return
	}
	ctx.Header(varyHeader, requestMethodHeader)
	ctx.Header(varyHeader, requestHeadersHeader)
}
// Handler returns the Iris middleware that applies this CORS configuration.
//
// For every request it:
//  1. adds the "Vary" headers, even when the request ends up rejected;
//  2. extracts the origin with the configured ExtractOriginFunc and checks it
//     with the configured AllowOriginFunc; if either fails, the error handler
//     is called with ErrOriginNotAllowed and the next handler is not run;
//  3. writes Access-Control-Allow-Origin (the request's origin, or "*" when
//     it is empty), Access-Control-Allow-Credentials, Referrer-Policy and
//     Access-Control-Expose-Headers;
//  4. for OPTIONS requests, additionally writes Access-Control-Allow-Methods,
//     Access-Control-Allow-Headers and Access-Control-Max-Age and answers
//     with 204 No Content without calling the next handler;
//  5. otherwise continues the chain with ctx.Next().
//
// The returned handler reads the CORS fields on every request, so the caller
// should NOT modify the CORS instance afterwards.
func (c *CORS) Handler() context.Handler {
	return func(ctx *context.Context) {
		c.addVaryHeaders(ctx) // added in every case, including rejections.

		origin, ok := c.extractOriginFunc(ctx)
		if !ok || !c.allowOriginFunc(ctx, origin) {
			c.errorHandler(ctx, ErrOriginNotAllowed)
			return
		}

		if origin == "" {
			// Only reachable when the extractor and the allow func accept
			// empty origins; answer those with the wildcard.
			origin = "*"
		}

		ctx.Header(allowOriginHeader, origin)
		ctx.Header(allowCredentialsHeader, c.allowCredentialsValue)
		// Mozilla updated the Referrer-Policy document on 08 July 2021:
		// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
		ctx.Header(referrerPolicyHeader, c.referrerPolicyValue)
		ctx.Header(exposeHeadersHeader, c.exposeHeadersValue)

		if ctx.Method() != http.MethodOptions {
			ctx.Next()
			return
		}

		// OPTIONS: answer directly without running the rest of the chain.
		// allowMethodsValue is fixed to "*" by New and has no public setter.
		ctx.Header(allowMethodsHeader, c.allowMethodsValue)
		ctx.Header(allowHeadersHeader, c.allowHeadersValue)
		ctx.Header(maxAgeHeader, c.maxAgeSecondsValue)
		ctx.StatusCode(http.StatusNoContent)
	}
}