Skip to content

Commit

Permalink
Merge pull request #29 from 3v0k4/x-forwarded
Browse files Browse the repository at this point in the history
Set X-Forwarded-* headers
  • Loading branch information
kevinmcconnell authored May 8, 2024
2 parents c19ae78 + 3a9cdc7 commit ec8c5f6
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## Unreleased

* Set in the outbound request `X-Forwarded-For` (to the client IP address), `X-Forwarded-Host` (to the host name requested by the client), and `X-Forwarded-Proto` (to "http" or "https" depending on whether the inbound request was made on a TLS-enabled connection) (#29)

## v0.1.4 / 2024-04-26

* [BREAKING] Rename the `SSL_DOMAIN` env to `TLS_DOMAIN` (#13)
Expand Down
44 changes: 44 additions & 0 deletions internal/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,50 @@ func TestHandlerMaxRequestBody(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
}

func TestHandlerPreserveInboundHostHeaderWhenProxying(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "example.org", r.Host)
}))
defer upstream.Close()

h := NewHandler(handlerOptions(upstream.URL))

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.org", nil)
h.ServeHTTP(w, r)
}

func TestHandlerAppendInboundXFFHeaderWhenProxying(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "0.0.0.0, 0.0.0.1", r.Header.Get("X-Forwarded-For"))
}))
defer upstream.Close()

h := NewHandler(handlerOptions(upstream.URL))

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.org", nil)
r.RemoteAddr = "0.0.0.1:1234"
r.Header.Set("X-Forwarded-For", "0.0.0.0")
h.ServeHTTP(w, r)
}

func TestHandlerXForwardedHeadersWhenProxying(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For"))
assert.Equal(t, "example.org", r.Header.Get("X-Forwarded-Host"))
assert.Equal(t, "https", r.Header.Get("X-Forwarded-Proto"))
}))
defer upstream.Close()

h := NewHandler(handlerOptions(upstream.URL))

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "https://example.org", nil)
r.RemoteAddr = "1.2.3.4:1234"
h.ServeHTTP(w, r)
}

// Helpers

func handlerOptions(targetUrl string) HandlerOptions {
Expand Down
15 changes: 10 additions & 5 deletions internal/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@ import (
)

func NewProxyHandler(targetUrl *url.URL, badGatewayPage string) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
proxy.ErrorHandler = ProxyErrorHandler(badGatewayPage)
proxy.Transport = createProxyTransport()

return proxy
return &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(targetUrl)
r.Out.Host = r.In.Host
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
r.SetXForwarded()
},
ErrorHandler: ProxyErrorHandler(badGatewayPage),
Transport: createProxyTransport(),
}
}

func ProxyErrorHandler(badGatewayPage string) func(w http.ResponseWriter, r *http.Request, err error) {
Expand Down

0 comments on commit ec8c5f6

Please sign in to comment.