diff --git a/interceptor/request_forwarder.go b/interceptor/request_forwarder.go index 1ceef2f4..86dfbc86 100644 --- a/interceptor/request_forwarder.go +++ b/interceptor/request_forwarder.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "strings" ) func forwardRequest( @@ -20,8 +21,9 @@ func forwardRequest( req.Host = fwdSvcURL.Host req.URL.Path = r.URL.Path req.URL.RawQuery = r.URL.RawQuery - // delete the incoming X-Forwarded-For header so the proxy - // puts its own in. This is also important to prevent IP spoofing + // delete the incoming X-Forwarded-For header so the + // proxy puts its own in. This is also important to + // prevent IP spoofing req.Header.Del("X-Forwarded-For ") } proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { @@ -33,5 +35,39 @@ func forwardRequest( w.Write([]byte(errMsg)) } + proxy.ModifyResponse = func(resp *http.Response) error { + locHdr := resp.Header.Get("Location") + if locHdr == "" { + // if there's no location header, no action + // required + return nil + } + if !strings.HasPrefix(locHdr, "http://") { + // ensure that the location header + // starts with the scheme so that + // url.Parse parses it properly + locHdr = "http://" + locHdr + } + // if there is a location header, and it has + // no host in it, we need to rewrite that + // a host in it, we need to rewrite that + // host to the originally requested one + locURL, err := url.Parse(locHdr) + if err != nil { + return err + } + + // if the location header has no host, + // add the incoming host to it + if locURL.Hostname() == "" { + locURL.Host = r.Host + locURL.Scheme = "" + loc := strings.TrimPrefix(locURL.String(), "//") + resp.Header.Set("Location", loc) + } + + return nil + } + proxy.ServeHTTP(w, r) } diff --git a/interceptor/request_forwarder_redirect_test.go b/interceptor/request_forwarder_redirect_test.go new file mode 100644 index 00000000..51db8c72 --- /dev/null +++ b/interceptor/request_forwarder_redirect_test.go @@ -0,0 +1,129 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + kedanet "github.com/kedacore/http-add-on/pkg/net" + "github.com/stretchr/testify/require" +) + +func TestForwardRequestRedirectAndHeaders(t *testing.T) { + r := require.New(t) + hdrs := map[string]string{ + "Content-Type": "text/html; charset=utf-8", + "X-Custom-Header": "somethingcustom", + "Location": "abc123.com/def", + } + hdl := standardHandler(301, "hello", hdrs) + // issue the request with a host. + // expect that host to not be set on the + // response header Location header. the host on + // that header should instead be what's listed + // on hdrs["Location"] + const host = "myhost.com" + const path = "/abc" + res, _, err := issueRequest(hdl, host, path) + r.NoError(err) + r.Equal(301, res.Code) + r.Equal("hello", res.Body.String()) + for k, v := range hdrs { + // ensure that the all response headers show up exactly + // as they were returned from the origin (hdl) + r.Equal(v, res.Header().Get(k)) + } +} + +func TestForwardRequestPathOnlyRedirectAndHeaders(t *testing.T) { + r := require.New(t) + hdrs := map[string]string{ + "Content-Type": "text/html; charset=utf-8", + "X-Custom-Header": "somethingcustom", + "Location": "/xyz", + } + hdl := standardHandler(301, "hello", hdrs) + // issue the request with a host. + // expect that host to not be set on the + // response header Location header. the host on + // that header should instead be what's listed + // on hdrs["Location"] + const host = "myhost.com" + const path = "/abc" + res, _, err := issueRequest(hdl, host, path) + r.NoError(err) + r.Equal(301, res.Code) + r.Equal("hello", res.Body.String()) + for k, v := range hdrs { + if k == "Location" { + // the "Location" header should be the + // path returned by the origin (hdl) + // but it should also have the host + // that was passed to issueRequest + expected := fmt.Sprintf( + "%s%s", + host, + v, + ) + r.Equal(expected, res.Header().Get(k)) + } else { + // ensure that all response headers except for + // the "Location" header to show up exactly + // as they were returned from the origin (hdl) + r.Equal(v, res.Header().Get(k)) + } + } +} + +func standardHandler( + code int, + body string, + headers map[string]string, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + for k, v := range headers { + w.Header().Set(k, v) + } + w.WriteHeader(code) + w.Write([]byte(body)) + } +} + +func issueRequest( + hdl http.HandlerFunc, + inHost, + inPath string, +) (*httptest.ResponseRecorder, *http.Request, error) { + srv, srvURL, err := kedanet.StartTestServer( + kedanet.NewTestHTTPHandlerWrapper( + http.HandlerFunc(hdl), + ), + ) + if err != nil { + return nil, nil, err + } + defer srv.Close() + + timeouts := defaultTimeouts() + timeouts.Connect = 10 * time.Millisecond + timeouts.ResponseHeader = 10 * time.Millisecond + backoff := timeouts.Backoff(2, 2, 1) + dialCtxFunc := retryDialContextFunc(timeouts, backoff) + + res, req, err := reqAndRes(inPath) + if err != nil { + return nil, nil, err + } + req.Host = inHost + + forwardRequest( + res, + req, + newRoundTripper(dialCtxFunc, timeouts.ResponseHeader), + srvURL, + ) + return res, req, nil + +} diff --git a/interceptor/request_forwarder_test.go b/interceptor/request_forwarder_test.go index fe51c55d..b6109074 100644 --- a/interceptor/request_forwarder_test.go +++ b/interceptor/request_forwarder_test.go @@ -237,40 +237,3 @@ func TestForwarderConnectionRetryAndTimeout(t *testing.T) { ) r.Contains(res.Body.String(), "error on backend") } - -func TestForwardRequestRedirectAndHeaders(t *testing.T) { - r := require.New(t) - - srv, srvURL, err := kedanet.StartTestServer( - kedanet.NewTestHTTPHandlerWrapper( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Header().Set("X-Custom-Header", "somethingcustom") - w.Header().Set("Location", "abc123.com") - w.WriteHeader(301) - w.Write([]byte("Hello from srv")) - }), - ), - ) - r.NoError(err) - defer srv.Close() - - timeouts := defaultTimeouts() - timeouts.Connect = 10 * time.Millisecond - timeouts.ResponseHeader = 10 * time.Millisecond - backoff := timeouts.Backoff(2, 2, 1) - dialCtxFunc := retryDialContextFunc(timeouts, backoff) - res, req, err := reqAndRes("/testfwd") - r.NoError(err) - forwardRequest( - res, - req, - newRoundTripper(dialCtxFunc, timeouts.ResponseHeader), - srvURL, - ) - r.Equal(301, res.Code) - r.Equal("abc123.com", res.Header().Get("Location")) - r.Equal("text/html; charset=utf-8", res.Header().Get("Content-Type")) - r.Equal("somethingcustom", res.Header().Get("X-Custom-Header")) - r.Equal("Hello from srv", res.Body.String()) -}