Skip to content

Commit

Permalink
Merge pull request #6961 from TheThingsNetwork/fix/redirect-loop
Browse files Browse the repository at this point in the history
Fix redirect loop when explicit ports are used
  • Loading branch information
adriansmares authored Feb 28, 2024
2 parents a0e6f48 + 2cb3ae5 commit 8449d7c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pkg/webmiddleware/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func (c RedirectConfiguration) build(url *url.URL) (*url.URL, bool) {
if c.HostName != nil {
hostname = c.HostName(hostname)
}
originalPortStr := portStr
if c.Port != nil {
port, _ := strconv.ParseUint(portStr, 10, 0)
port = uint64(c.Port(uint(port)))
Expand All @@ -62,12 +63,15 @@ func (c RedirectConfiguration) build(url *url.URL) (*url.URL, bool) {
host := hostname
if portStr != "" {
switch {
case portStr == "0":
case portStr == "0" && originalPortStr == "":
// Just use the hostname.
case portStr == "0" && originalPortStr != "":
// Maintain the original port, in order to avoid loops.
host = net.JoinHostPort(host, originalPortStr)
case target.Scheme == "http" && portStr == "80":
// This is the default. Just use the hostame.
// This is the default. Just use the hostname.
case target.Scheme == "https" && portStr == "443":
// This is the default. Just use the hostame.
// This is the default. Just use the hostname.
default:
host = net.JoinHostPort(host, portStr)
}
Expand Down
20 changes: 20 additions & 0 deletions pkg/webmiddleware/redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
)

func TestRedirect(t *testing.T) {
t.Parallel()

m := Redirect(RedirectConfiguration{
Scheme: func(s string) string { return SchemeHTTPS },
HostName: func(h string) string {
Expand All @@ -47,6 +49,7 @@ func TestRedirect(t *testing.T) {
})

t.Run("None", func(t *testing.T) {
t.Parallel()
a := assertions.New(t)
r := httptest.NewRequest(http.MethodGet, "https://dev.example.com/path?query=true", nil)
rec := httptest.NewRecorder()
Expand Down Expand Up @@ -86,7 +89,9 @@ func TestRedirect(t *testing.T) {
Redirect: "https://dev.example.com/path",
},
} {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
a := assertions.New(t)
r := httptest.NewRequest(http.MethodGet, tc.URL, nil)
rec := httptest.NewRecorder()
Expand All @@ -100,4 +105,19 @@ func TestRedirect(t *testing.T) {
a.So(res.Header.Get("Location"), should.Equal, tc.Redirect)
})
}

t.Run("ExplicitPort", func(t *testing.T) {
t.Parallel()
a := assertions.New(t)
r := httptest.NewRequest(http.MethodGet, "https://dev.example.com:8885/", nil)
rec := httptest.NewRecorder()
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})).ServeHTTP(rec, r)
res := rec.Result()
a.So(res.StatusCode, should.Equal, http.StatusOK)
body, _ := io.ReadAll(res.Body)
a.So(string(body), should.Equal, "OK")
})
}

0 comments on commit 8449d7c

Please sign in to comment.