Skip to content

Commit

Permalink
feat(proxy): flag to keep forwarded headers (#706)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr authored Jul 18, 2023
1 parent 3b15b7c commit e77f73f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 27 deletions.
2 changes: 1 addition & 1 deletion jsonx/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"strconv"
"strings"

"github.com/evanphx/json-patch/v5"
jsonpatch "github.com/evanphx/json-patch/v5"

"github.com/ory/x/pointerx"
)
Expand Down
46 changes: 31 additions & 15 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (

type (
RespMiddleware func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error)
ReqMiddleware func(req *http.Request, config *HostConfig, body []byte) ([]byte, error)
HostMapper func(ctx context.Context, r *http.Request) (*HostConfig, error)
ReqMiddleware func(req *httputil.ProxyRequest, config *HostConfig, body []byte) ([]byte, error)
HostMapper func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error)
options struct {
hostMapper HostMapper
onResError func(*http.Response, error) error
Expand Down Expand Up @@ -52,6 +52,9 @@ type (
// PathPrefix is a prefix that is prepended on the original host,
// but removed before forwarding.
PathPrefix string
// TrustForwardedHosts is a flag that indicates whether the proxy should trust the
// X-Forwarded-* headers or not.
TrustForwardedHeaders bool
// originalHost the original hostname the request is coming from.
// This value will be maintained internally by the proxy.
originalHost string
Expand Down Expand Up @@ -96,16 +99,28 @@ func rewriter(o *options) func(*httputil.ProxyRequest) {
ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "x.proxy")
defer span.End()

c, err := o.getHostConfig(r.Out)
ctx, c, err := o.getHostConfig(ctx, r.In)
if err != nil {
o.onReqError(r.Out, err)
return
}

if c.TrustForwardedHeaders {
headers := []string{
"X-Forwarded-Host",
"X-Forwarded-Proto",
"X-Forwarded-For",
}
for _, h := range headers {
if v := r.In.Header.Get(h); v != "" {
r.Out.Header.Set(h, v)
}
}
}

c.setScheme(r)
c.setHost(r)

*r.Out = *r.Out.WithContext(context.WithValue(ctx, hostConfigKey, c))
headerRequestRewrite(r.Out, c)

var body []byte
Expand All @@ -120,7 +135,7 @@ func rewriter(o *options) func(*httputil.ProxyRequest) {
}

for _, m := range o.reqMiddlewares {
if body, err = m(r.Out, c, body); err != nil {
if body, err = m(r, c, body); err != nil {
o.onReqError(r.Out, err)
return
}
Expand All @@ -141,7 +156,7 @@ func rewriter(o *options) func(*httputil.ProxyRequest) {
// modifyResponse is a custom internal function for altering a http.Response
func modifyResponse(o *options) func(*http.Response) error {
return func(r *http.Response) error {
c, err := o.getHostConfig(r.Request)
_, c, err := o.getHostConfig(r.Request.Context(), r.Request)
if err != nil {
return err
}
Expand Down Expand Up @@ -209,24 +224,24 @@ func WithErrorHandler(eh func(w http.ResponseWriter, r *http.Request, err error)
}
}

func (o *options) getHostConfig(r *http.Request) (*HostConfig, error) {
if cached, ok := r.Context().Value(hostConfigKey).(*HostConfig); ok && cached != nil {
return cached, nil
func (o *options) getHostConfig(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
if cached, ok := ctx.Value(hostConfigKey).(*HostConfig); ok && cached != nil {
return ctx, cached, nil
}
c, err := o.hostMapper(r.Context(), r)
ctx, c, err := o.hostMapper(ctx, r)
if err != nil {
return nil, err
return nil, nil, err
}
// cache the host config in the request context
// this will be passed on to the request and response proxy functions
*r = *r.WithContext(context.WithValue(r.Context(), hostConfigKey, c))
return c, nil
ctx = context.WithValue(ctx, hostConfigKey, c)
return ctx, c, nil
}

func (o *options) beforeProxyMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
// get the hostmapper configurations before the request is proxied
c, err := o.getHostConfig(request)
ctx, c, err := o.getHostConfig(request.Context(), request)
if err != nil {
o.onReqError(request, err)
return
Expand All @@ -237,7 +252,8 @@ func (o *options) beforeProxyMiddleware(h http.Handler) http.Handler {
if c.CorsEnabled && c.CorsOptions != nil {
cors.New(*c.CorsOptions).HandlerFunc(writer, request)
}
h.ServeHTTP(writer, request)

h.ServeHTTP(writer, request.WithContext(ctx))
})
}

Expand Down
23 changes: 12 additions & 11 deletions proxy/proxy_full_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,12 @@ func TestFullIntegration(t *testing.T) {
onErrorResp := make(chan CustomErrorResp)

proxy := httptest.NewTLSServer(New(
func(_ context.Context, r *http.Request) (*HostConfig, error) {
return (<-hostMapper)(r)
func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
c, err := (<-hostMapper)(r)
return ctx, c, err
},
WithTransport(upstreamServer.Client().Transport),
WithReqMiddleware(func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) {
WithReqMiddleware(func(req *httputil.ProxyRequest, config *HostConfig, body []byte) ([]byte, error) {
f := <-reqMiddleware
if f == nil {
return body, nil
Expand Down Expand Up @@ -268,8 +269,8 @@ func TestFullIntegration(t *testing.T) {
assert.Equal(t, "OK", string(body))
assert.Equal(t, "1234", r.Header.Get("Some-Header"))
},
reqMiddleware: func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) {
req.Host = "noauth.example.com"
reqMiddleware: func(req *httputil.ProxyRequest, config *HostConfig, body []byte) ([]byte, error) {
req.Out.Host = "noauth.example.com"
return []byte("this is a new body"), nil
},
respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) {
Expand Down Expand Up @@ -538,8 +539,8 @@ func TestBetweenReverseProxies(t *testing.T) {
revProxyHandler := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(target.URL))
revProxy := httptest.NewServer(revProxyHandler)

thisProxy := httptest.NewServer(New(func(ctx context.Context, _ *http.Request) (*HostConfig, error) {
return &HostConfig{
thisProxy := httptest.NewServer(New(func(ctx context.Context, _ *http.Request) (context.Context, *HostConfig, error) {
return ctx, &HostConfig{
CookieDomain: "sh",
UpstreamHost: urlx.ParseOrPanic(revProxy.URL).Host,
UpstreamScheme: urlx.ParseOrPanic(revProxy.URL).Scheme,
Expand Down Expand Up @@ -635,8 +636,8 @@ func TestProxyProtoMix(t *testing.T) {
upstream.Transport = targetServer.Client().Transport
upstreamServer := upstreamServerFunc(upstream)

proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) {
return &HostConfig{
proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
return ctx, &HostConfig{
CookieDomain: exposedHost,
UpstreamHost: urlx.ParseOrPanic(upstreamServer.URL).Host,
UpstreamScheme: urlx.ParseOrPanic(upstreamServer.URL).Scheme,
Expand Down Expand Up @@ -760,8 +761,8 @@ func TestProxyWebsocketRequests(t *testing.T) {
}

setupProxy := func(targetServer *httptest.Server) *httptest.Server {
proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) {
return &HostConfig{
proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
return ctx, &HostConfig{
UpstreamHost: urlx.ParseOrPanic(targetServer.URL).Host,
UpstreamScheme: urlx.ParseOrPanic(targetServer.URL).Scheme,
TargetHost: urlx.ParseOrPanic(targetServer.URL).Host,
Expand Down

0 comments on commit e77f73f

Please sign in to comment.