diff --git a/CHANGELOG.md b/CHANGELOG.md index 05b60aa..5b38570 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## Unreleased + +* Set in the outbound request `X-Forwarded-For` (to the client IP address), `X-Forwarded-Host` (to the host name requested by the client), and `X-Forwarded-Proto` (to "http" or "https" depending on whether the inbound request was made on a TLS-enabled connection) (#29) + ## v0.1.4 / 2024-04-26 * [BREAKING] Rename the `SSL_DOMAIN` env to `TLS_DOMAIN` (#13) diff --git a/internal/handler_test.go b/internal/handler_test.go index ef2d947..7519433 100644 --- a/internal/handler_test.go +++ b/internal/handler_test.go @@ -155,6 +155,35 @@ func TestHandlerMaxRequestBody(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) } +func TestHandlerPreserveInboundHostHeaderWhenProxying(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "example.org", r.Host) + })) + defer upstream.Close() + + h := NewHandler(handlerOptions(upstream.URL)) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "http://example.org", nil) + h.ServeHTTP(w, r) +} + +func TestHandlerXForwardedHeadersWhenProxying(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For")) + assert.Equal(t, "example.org", r.Header.Get("X-Forwarded-Host")) + assert.Equal(t, "https", r.Header.Get("X-Forwarded-Proto")) + })) + defer upstream.Close() + + h := NewHandler(handlerOptions(upstream.URL)) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://example.org", nil) + r.RemoteAddr = "1.2.3.4:1234" + h.ServeHTTP(w, r) +} + // Helpers func handlerOptions(targetUrl string) HandlerOptions { diff --git a/internal/proxy_handler.go b/internal/proxy_handler.go index a550848..6c7434a 100644 --- a/internal/proxy_handler.go +++ b/internal/proxy_handler.go @@ -10,11 +10,15 @@ import ( ) func NewProxyHandler(targetUrl *url.URL, badGatewayPage string) http.Handler { - proxy := httputil.NewSingleHostReverseProxy(targetUrl) - proxy.ErrorHandler = ProxyErrorHandler(badGatewayPage) - proxy.Transport = createProxyTransport() - - return proxy + return &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(targetUrl) + r.Out.Host = r.In.Host + r.SetXForwarded() + }, + ErrorHandler: ProxyErrorHandler(badGatewayPage), + Transport: createProxyTransport(), + } } func ProxyErrorHandler(badGatewayPage string) func(w http.ResponseWriter, r *http.Request, err error) {