-
Notifications
You must be signed in to change notification settings - Fork 0
/
http_context.go
149 lines (127 loc) · 4.12 KB
/
http_context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package main
import (
"fmt"
"strings"
"github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm"
"github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types"
)
// HTTP status codes used when this filter answers the downstream client
// directly. They are untyped integer constants, so they can be used with any
// integer type at the call site (e.g. converted with uint32(...) for
// proxywasm.SendHttpResponse).
const (
	// StatusOK is HTTP 200 OK.
	StatusOK = 200
	// StatusBadRequest is HTTP 400 Bad Request.
	StatusBadRequest = 400
	// StatusTooManyRequests is HTTP 429 Too Many Requests.
	StatusTooManyRequests = 429
	// StatusServerError is HTTP 500 Internal Server Error.
	StatusServerError = 500
)
// HeaderFilter inspects request headers (name -> values). A non-nil error
// rejects the request; HTTPContext then uses the returned status as the HTTP
// status of the local response it sends to the client. The status is ignored
// when the error is nil.
type HeaderFilter interface {
	FilterHeaders(headers map[string][]string) (status int, err error)
}
// BodyFilter inspects the request body, which HTTPContext passes in once the
// whole body has been buffered. A non-nil error rejects the request, and the
// returned status becomes the HTTP status of the local response sent to the
// client. The status is ignored when the error is nil.
type BodyFilter interface {
	FilterBody(body []byte) (status int, err error)
}
// HTTPContext handles the proxy-wasm HTTP callbacks for a request and runs the
// registered header and body filters from OnHttpRequestHeaders and
// OnHttpRequestBody.
//
// It embeds types.DefaultHttpContext so that callbacks not overridden here
// fall back to the SDK's implementation (presumably no-op defaults — confirm
// against the SDK version in use).
type HTTPContext struct {
	types.DefaultHttpContext
	// hFilters run in registration order against the request headers.
	hFilters []HeaderFilter
	// bFilters run in registration order against the complete request body.
	bFilters []BodyFilter
}
// NewHTTPContext builds an HTTPContext wired with the filters described by
// config:
//   - header filters, in order: the filter from NewRateLimitFilter, then the
//     filter from NewSQLFilter;
//   - body filters: the same NewSQLFilter instance.
//
// config must be non-nil.
func NewHTTPContext(config *Config) *HTTPContext {
	limiter := NewRateLimitFilter(config.RateLimitRequests, config.RateLimitInterval)
	sqlKeywords := NewSQLFilter(config.SQLKeywords)

	hc := new(HTTPContext)
	hc.WithHeaderFilters(limiter, sqlKeywords)
	return hc.WithBodyFilters(sqlKeywords)
}
// WithBodyFilters appends f to the body-filter chain, after any filters that
// are already registered, and returns the receiver so calls can be chained.
func (h *HTTPContext) WithBodyFilters(f ...BodyFilter) *HTTPContext {
	chain := h.bFilters
	for _, filter := range f {
		chain = append(chain, filter)
	}
	h.bFilters = chain
	return h
}
// WithHeaderFilters appends f to the header-filter chain, after any filters
// that are already registered, and returns the receiver so calls can be
// chained.
func (h *HTTPContext) WithHeaderFilters(f ...HeaderFilter) *HTTPContext {
	chain := h.hFilters
	for _, filter := range f {
		chain = append(chain, filter)
	}
	h.hFilters = chain
	return h
}
// OnHttpRequestHeaders runs every registered HeaderFilter, in registration
// order, against the request headers. Both callback arguments are ignored.
//
// The first filter that returns a non-nil error stops the chain. Its error
// text is logged at info level and sent as the body of a local response whose
// status is the filter's returned code, and ActionPause is returned. If every
// filter succeeds, ActionContinue is returned.
//
// If the headers cannot be read, the error is logged and the request
// continues without any filtering.
func (h *HTTPContext) OnHttpRequestHeaders(_ int, _ bool) types.Action {
	headers, readErr := GetHTTPRequestHeaders()
	if readErr != nil {
		proxywasm.LogErrorf("failed to get request headers: %v", readErr)
		// NOTE(review): this fails open (rate limiting and SQL checks are
		// skipped), whereas OnHttpRequestBody rejects with 500 when it cannot
		// read the body — confirm fail-open is intended here.
		return types.ActionContinue
	}

	for i := range h.hFilters {
		code, filterErr := h.hFilters[i].FilterHeaders(headers)
		if filterErr == nil {
			continue
		}
		msg := filterErr.Error()
		proxywasm.LogInfo(msg)
		if sendErr := proxywasm.SendHttpResponse(uint32(code), nil, []byte(msg), -1); sendErr != nil {
			proxywasm.LogErrorf("failed to send HTTP response: %v", sendErr)
		}
		return types.ActionPause
	}
	return types.ActionContinue
}
// OnHttpRequestBody runs every registered BodyFilter, in registration order,
// against the complete request body.
//
// While eos (end of stream) is false it returns ActionPause, so the host keeps
// buffering chunks and no partial body is streamed upstream before the whole
// body has been inspected. Once eos is true, size is the length of the
// buffered body, and the body is read in a single call.
//
// If the body cannot be read, a 500 response carrying the read error is sent
// to the client and ActionPause is returned. Otherwise, the first filter that
// returns a non-nil error stops the chain: its error text is logged and sent
// as the body of a local response with the filter's status code, and
// ActionPause is returned. If every filter succeeds, ActionContinue is
// returned.
func (h *HTTPContext) OnHttpRequestBody(size int, eos bool) types.Action {
	proxywasm.LogDebugf("OnHttpRequestBody called: size = %d, eos = %t", size, eos)
	if !eos {
		// Keep buffering until the whole body has arrived so that no partial
		// data reaches the upstream before it has been inspected.
		return types.ActionPause
	}

	// A zero-length body at end of stream (e.g. an empty final DATA frame) is
	// legitimate. Asking the host for zero bytes yields no buffer, which the
	// SDK reports as an error (ErrorStatusNotFound) — that used to turn an
	// empty body into a 500. Skip the read and let the filters see an empty
	// body instead.
	var body []byte
	if size > 0 {
		var err error
		body, err = proxywasm.GetHttpRequestBody(0, size)
		if err != nil {
			proxywasm.LogErrorf("failed to get request body: %v", err)
			if sendErr := proxywasm.SendHttpResponse(uint32(StatusServerError), nil, []byte(err.Error()), -1); sendErr != nil {
				// This branch is a failure to *send* the 500, not to read the
				// body; the log message must say so.
				proxywasm.LogErrorf("failed to send HTTP response: %v", sendErr)
			}
			return types.ActionPause
		}
	}

	for _, f := range h.bFilters {
		status, err := f.FilterBody(body)
		if err != nil {
			proxywasm.LogInfo(err.Error())
			if sendErr := proxywasm.SendHttpResponse(uint32(status), nil, []byte(err.Error()), -1); sendErr != nil {
				proxywasm.LogErrorf("failed to send HTTP response: %v", sendErr)
			}
			return types.ActionPause
		}
	}
	return types.ActionContinue
}
func GetHTTPRequestHeaders() (map[string][]string, error) {
headers := make(map[string][]string)
rawHeaders, err := proxywasm.GetHttpRequestHeaders()
if err != nil {
return headers, fmt.Errorf("failed to get HTTP headers: %w", err)
}
for _, header := range rawHeaders {
rawValues := strings.Split(header[1], ",")
var values []string
if v, exists := headers[header[0]]; exists {
values = v
} else {
values = make([]string, 0, len(rawValues))
}
headers[header[0]] = append(values, rawValues...)
}
return headers, nil
}
// Retry calls f up to attempts times, stopping at the first call that returns
// nil. It returns nil on success, or the error from the final attempt if every
// call failed. There is no delay between attempts.
//
// When attempts <= 0, f is never called and nil is returned.
func Retry(attempts int, f func() error) error {
	var err error
	// The counter must advance each iteration; previously it never did, so a
	// persistently failing f looped forever.
	for n := 0; n < attempts; n++ {
		if err = f(); err == nil {
			return nil
		}
	}
	return err
}