Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensuring that the proper host is set on all Location headers #404

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions interceptor/request_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"strings"
)

func forwardRequest(
Expand All @@ -20,8 +21,9 @@ func forwardRequest(
req.Host = fwdSvcURL.Host
req.URL.Path = r.URL.Path
req.URL.RawQuery = r.URL.RawQuery
// delete the incoming X-Forwarded-For header so the proxy
// puts its own in. This is also important to prevent IP spoofing
// delete the incoming X-Forwarded-For header so the
// proxy puts its own in. This is also important to
// prevent IP spoofing
req.Header.Del("X-Forwarded-For ")
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
Expand All @@ -33,5 +35,39 @@ func forwardRequest(
w.Write([]byte(errMsg))
}

proxy.ModifyResponse = func(resp *http.Response) error {
locHdr := resp.Header.Get("Location")
if locHdr == "" {
// if there's no location header, no action
// required
return nil
}
if !strings.HasPrefix(locHdr, "http://") {
// ensure that the location header
// starts with the scheme so that
// url.Parse parses it properly
locHdr = "http://" + locHdr
}
// if there is a location header, and it has
// no host in it, we need to rewrite that
// a host in it, we need to rewrite that
// host to the originally requested one
locURL, err := url.Parse(locHdr)
if err != nil {
return err
}

// if the location header has no host,
// add the incoming host to it
if locURL.Hostname() == "" {
locURL.Host = r.Host
locURL.Scheme = ""
loc := strings.TrimPrefix(locURL.String(), "//")
resp.Header.Set("Location", loc)
}

return nil
}

proxy.ServeHTTP(w, r)
}
129 changes: 129 additions & 0 deletions interceptor/request_forwarder_redirect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package main

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

kedanet "github.com/kedacore/http-add-on/pkg/net"
"github.com/stretchr/testify/require"
)

func TestForwardRequestRedirectAndHeaders(t *testing.T) {
r := require.New(t)
hdrs := map[string]string{
"Content-Type": "text/html; charset=utf-8",
"X-Custom-Header": "somethingcustom",
"Location": "abc123.com/def",
}
hdl := standardHandler(301, "hello", hdrs)
// issue the request with a host.
// expect that host to not be set on the
// response header Location header. the host on
// that header should instead be what's listed
// on hdrs["Location"]
const host = "myhost.com"
const path = "/abc"
res, _, err := issueRequest(hdl, host, path)
r.NoError(err)
r.Equal(301, res.Code)
r.Equal("hello", res.Body.String())
for k, v := range hdrs {
// ensure that the all response headers show up exactly
// as they were returned from the origin (hdl)
r.Equal(v, res.Header().Get(k))
}
}

func TestForwardRequestPathOnlyRedirectAndHeaders(t *testing.T) {
r := require.New(t)
hdrs := map[string]string{
"Content-Type": "text/html; charset=utf-8",
"X-Custom-Header": "somethingcustom",
"Location": "/xyz",
}
hdl := standardHandler(301, "hello", hdrs)
// issue the request with a host.
// expect that host to not be set on the
// response header Location header. the host on
// that header should instead be what's listed
// on hdrs["Location"]
const host = "myhost.com"
const path = "/abc"
res, _, err := issueRequest(hdl, host, path)
r.NoError(err)
r.Equal(301, res.Code)
r.Equal("hello", res.Body.String())
for k, v := range hdrs {
if k == "Location" {
// the "Location" header should be the
// path returned by the origin (hdl)
// but it should also have the host
// that was passed to issueRequest
expected := fmt.Sprintf(
"%s%s",
host,
v,
)
r.Equal(expected, res.Header().Get(k))
} else {
// ensure that all response headers except for
// the "Location" header to show up exactly
// as they were returned from the origin (hdl)
r.Equal(v, res.Header().Get(k))
}
}
}

func standardHandler(
code int,
body string,
headers map[string]string,
) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
for k, v := range headers {
w.Header().Set(k, v)
}
w.WriteHeader(code)
w.Write([]byte(body))
}
}

func issueRequest(
hdl http.HandlerFunc,
inHost,
inPath string,
) (*httptest.ResponseRecorder, *http.Request, error) {
srv, srvURL, err := kedanet.StartTestServer(
kedanet.NewTestHTTPHandlerWrapper(
http.HandlerFunc(hdl),
),
)
if err != nil {
return nil, nil, err
}
defer srv.Close()

timeouts := defaultTimeouts()
timeouts.Connect = 10 * time.Millisecond
timeouts.ResponseHeader = 10 * time.Millisecond
backoff := timeouts.Backoff(2, 2, 1)
dialCtxFunc := retryDialContextFunc(timeouts, backoff)

res, req, err := reqAndRes(inPath)
if err != nil {
return nil, nil, err
}
req.Host = inHost

forwardRequest(
res,
req,
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
srvURL,
)
return res, req, nil

}
37 changes: 0 additions & 37 deletions interceptor/request_forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,40 +237,3 @@ func TestForwarderConnectionRetryAndTimeout(t *testing.T) {
)
r.Contains(res.Body.String(), "error on backend")
}

func TestForwardRequestRedirectAndHeaders(t *testing.T) {
r := require.New(t)

srv, srvURL, err := kedanet.StartTestServer(
kedanet.NewTestHTTPHandlerWrapper(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("X-Custom-Header", "somethingcustom")
w.Header().Set("Location", "abc123.com")
w.WriteHeader(301)
w.Write([]byte("Hello from srv"))
}),
),
)
r.NoError(err)
defer srv.Close()

timeouts := defaultTimeouts()
timeouts.Connect = 10 * time.Millisecond
timeouts.ResponseHeader = 10 * time.Millisecond
backoff := timeouts.Backoff(2, 2, 1)
dialCtxFunc := retryDialContextFunc(timeouts, backoff)
res, req, err := reqAndRes("/testfwd")
r.NoError(err)
forwardRequest(
res,
req,
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
srvURL,
)
r.Equal(301, res.Code)
r.Equal("abc123.com", res.Header().Get("Location"))
r.Equal("text/html; charset=utf-8", res.Header().Get("Content-Type"))
r.Equal("somethingcustom", res.Header().Get("X-Custom-Header"))
r.Equal("Hello from srv", res.Body.String())
}