Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(proxy): flag to keep forwarded headers #706

Merged
merged 6 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jsonx/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"strconv"
"strings"

"github.com/evanphx/json-patch/v5"
jsonpatch "github.com/evanphx/json-patch/v5"

"github.com/ory/x/pointerx"
)
Expand Down
46 changes: 31 additions & 15 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (

type (
RespMiddleware func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error)
ReqMiddleware func(req *http.Request, config *HostConfig, body []byte) ([]byte, error)
HostMapper func(ctx context.Context, r *http.Request) (*HostConfig, error)
ReqMiddleware func(req *httputil.ProxyRequest, config *HostConfig, body []byte) ([]byte, error)
HostMapper func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error)
Comment on lines +20 to +21
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's no longer possible to use the context properly because we don't know whether we get the context from req.In or req.Out now. So context needs to be explicit now.

options struct {
hostMapper HostMapper
onResError func(*http.Response, error) error
Expand Down Expand Up @@ -52,6 +52,9 @@ type (
// PathPrefix is a prefix that is prepended on the original host,
// but removed before forwarding.
PathPrefix string
// TrustForwardedHosts is a flag that indicates whether the proxy should trust the
// X-Forwarded-* headers or not.
TrustForwardedHeaders bool
// originalHost the original hostname the request is coming from.
// This value will be maintained internally by the proxy.
originalHost string
Expand Down Expand Up @@ -96,16 +99,28 @@ func rewriter(o *options) func(*httputil.ProxyRequest) {
ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "x.proxy")
defer span.End()

c, err := o.getHostConfig(r.Out)
ctx, c, err := o.getHostConfig(ctx, r.In)
if err != nil {
o.onReqError(r.Out, err)
return
}

if c.TrustForwardedHeaders {
headers := []string{
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes the regression

"X-Forwarded-Host",
"X-Forwarded-Proto",
"X-Forwarded-For",
}
for _, h := range headers {
if v := r.In.Header.Get(h); v != "" {
r.Out.Header.Set(h, v)
}
}
}

c.setScheme(r)
c.setHost(r)

*r.Out = *r.Out.WithContext(context.WithValue(ctx, hostConfigKey, c))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already done by getHostConfig

headerRequestRewrite(r.Out, c)

var body []byte
Expand All @@ -120,7 +135,7 @@ func rewriter(o *options) func(*httputil.ProxyRequest) {
}

for _, m := range o.reqMiddlewares {
if body, err = m(r.Out, c, body); err != nil {
if body, err = m(r, c, body); err != nil {
o.onReqError(r.Out, err)
return
}
Expand All @@ -141,7 +156,7 @@ func rewriter(o *options) func(*httputil.ProxyRequest) {
// modifyResponse is a custom internal function for altering a http.Response
func modifyResponse(o *options) func(*http.Response) error {
return func(r *http.Response) error {
c, err := o.getHostConfig(r.Request)
_, c, err := o.getHostConfig(r.Request.Context(), r.Request)
if err != nil {
return err
}
Expand Down Expand Up @@ -209,24 +224,24 @@ func WithErrorHandler(eh func(w http.ResponseWriter, r *http.Request, err error)
}
}

func (o *options) getHostConfig(r *http.Request) (*HostConfig, error) {
if cached, ok := r.Context().Value(hostConfigKey).(*HostConfig); ok && cached != nil {
return cached, nil
func (o *options) getHostConfig(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
if cached, ok := ctx.Value(hostConfigKey).(*HostConfig); ok && cached != nil {
return ctx, cached, nil
}
c, err := o.hostMapper(r.Context(), r)
ctx, c, err := o.hostMapper(ctx, r)
if err != nil {
return nil, err
return nil, nil, err
}
// cache the host config in the request context
// this will be passed on to the request and response proxy functions
*r = *r.WithContext(context.WithValue(r.Context(), hostConfigKey, c))
return c, nil
ctx = context.WithValue(ctx, hostConfigKey, c)
return ctx, c, nil
}

func (o *options) beforeProxyMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
// get the hostmapper configurations before the request is proxied
c, err := o.getHostConfig(request)
ctx, c, err := o.getHostConfig(request.Context(), request)
if err != nil {
o.onReqError(request, err)
return
Expand All @@ -237,7 +252,8 @@ func (o *options) beforeProxyMiddleware(h http.Handler) http.Handler {
if c.CorsEnabled && c.CorsOptions != nil {
cors.New(*c.CorsOptions).HandlerFunc(writer, request)
}
h.ServeHTTP(writer, request)

h.ServeHTTP(writer, request.WithContext(ctx))
})
}

Expand Down
23 changes: 12 additions & 11 deletions proxy/proxy_full_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,12 @@ func TestFullIntegration(t *testing.T) {
onErrorResp := make(chan CustomErrorResp)

proxy := httptest.NewTLSServer(New(
func(_ context.Context, r *http.Request) (*HostConfig, error) {
return (<-hostMapper)(r)
func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
c, err := (<-hostMapper)(r)
return ctx, c, err
},
WithTransport(upstreamServer.Client().Transport),
WithReqMiddleware(func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) {
WithReqMiddleware(func(req *httputil.ProxyRequest, config *HostConfig, body []byte) ([]byte, error) {
f := <-reqMiddleware
if f == nil {
return body, nil
Expand Down Expand Up @@ -268,8 +269,8 @@ func TestFullIntegration(t *testing.T) {
assert.Equal(t, "OK", string(body))
assert.Equal(t, "1234", r.Header.Get("Some-Header"))
},
reqMiddleware: func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) {
req.Host = "noauth.example.com"
reqMiddleware: func(req *httputil.ProxyRequest, config *HostConfig, body []byte) ([]byte, error) {
req.Out.Host = "noauth.example.com"
return []byte("this is a new body"), nil
},
respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) {
Expand Down Expand Up @@ -538,8 +539,8 @@ func TestBetweenReverseProxies(t *testing.T) {
revProxyHandler := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(target.URL))
revProxy := httptest.NewServer(revProxyHandler)

thisProxy := httptest.NewServer(New(func(ctx context.Context, _ *http.Request) (*HostConfig, error) {
return &HostConfig{
thisProxy := httptest.NewServer(New(func(ctx context.Context, _ *http.Request) (context.Context, *HostConfig, error) {
return ctx, &HostConfig{
CookieDomain: "sh",
UpstreamHost: urlx.ParseOrPanic(revProxy.URL).Host,
UpstreamScheme: urlx.ParseOrPanic(revProxy.URL).Scheme,
Expand Down Expand Up @@ -635,8 +636,8 @@ func TestProxyProtoMix(t *testing.T) {
upstream.Transport = targetServer.Client().Transport
upstreamServer := upstreamServerFunc(upstream)

proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) {
return &HostConfig{
proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
return ctx, &HostConfig{
CookieDomain: exposedHost,
UpstreamHost: urlx.ParseOrPanic(upstreamServer.URL).Host,
UpstreamScheme: urlx.ParseOrPanic(upstreamServer.URL).Scheme,
Expand Down Expand Up @@ -760,8 +761,8 @@ func TestProxyWebsocketRequests(t *testing.T) {
}

setupProxy := func(targetServer *httptest.Server) *httptest.Server {
proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) {
return &HostConfig{
proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
return ctx, &HostConfig{
UpstreamHost: urlx.ParseOrPanic(targetServer.URL).Host,
UpstreamScheme: urlx.ParseOrPanic(targetServer.URL).Scheme,
TargetHost: urlx.ParseOrPanic(targetServer.URL).Host,
Expand Down
Loading