From ee6fcdfbd81bd8f8acfa508b761c35f97da5bb33 Mon Sep 17 00:00:00 2001 From: "Jens L." Date: Mon, 23 Dec 2024 23:42:42 +0100 Subject: [PATCH] internal: fix missing trailing slash in outpost websocket (#12470) Signed-off-by: Jens Langhammer --- internal/outpost/ak/api_ws.go | 14 ++++++++++---- internal/outpost/ak/api_ws_test.go | 18 ++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/internal/outpost/ak/api_ws.go b/internal/outpost/ak/api_ws.go index dfab5255ac92..d92941f760c6 100644 --- a/internal/outpost/ak/api_ws.go +++ b/internal/outpost/ak/api_ws.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "maps" "net/http" "net/url" "strconv" @@ -16,13 +17,16 @@ import ( "goauthentik.io/internal/constants" ) -func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string) *url.URL { +func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string, query url.Values) *url.URL { wsUrl := &url.URL{} wsUrl.Scheme = strings.ReplaceAll(akURL.Scheme, "http", "ws") wsUrl.Host = akURL.Host - _p, _ := url.JoinPath(akURL.Path, "ws/outpost/", outpostUUID) + _p, _ := url.JoinPath(akURL.Path, "ws/outpost/", outpostUUID, "/") wsUrl.Path = _p - wsUrl.RawQuery = akURL.Query().Encode() + v := url.Values{} + maps.Insert(v, maps.All(akURL.Query())) + maps.Insert(v, maps.All(query)) + wsUrl.RawQuery = v.Encode() return wsUrl } @@ -45,7 +49,9 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { }, } - ws, _, err := dialer.Dial(ac.getWebsocketURL(akURL, outpostUUID).String(), header) + wsu := ac.getWebsocketURL(akURL, outpostUUID, query).String() + ac.logger.WithField("url", wsu).Debug("connecting to websocket") + ws, _, err := dialer.Dial(wsu, header) if err != nil { ac.logger.WithError(err).Warning("failed to connect websocket") return err diff --git a/internal/outpost/ak/api_ws_test.go b/internal/outpost/ak/api_ws_test.go index 33284795c04f..f52c35375841 100644 --- a/internal/outpost/ak/api_ws_test.go +++ b/internal/outpost/ak/api_ws_test.go @@ -19,14 +19,24 @@ func TestWebsocketURL(t *testing.T) { u := URLMustParse("http://localhost:9000?foo=bar") uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" ac := &APIController{} - nu := ac.getWebsocketURL(*u, uuid) - assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77?foo=bar", nu.String()) + nu := ac.getWebsocketURL(*u, uuid, url.Values{}) + assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String()) +} + +func TestWebsocketURL_Query(t *testing.T) { + u := URLMustParse("http://localhost:9000?foo=bar") + uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" + ac := &APIController{} + v := url.Values{} + v.Set("bar", "baz") + nu := ac.getWebsocketURL(*u, uuid, v) + assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String()) } func TestWebsocketURL_Subpath(t *testing.T) { u := URLMustParse("http://localhost:9000/foo/bar/") uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" ac := &APIController{} - nu := ac.getWebsocketURL(*u, uuid) - assert.Equal(t, "ws://localhost:9000/foo/bar/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77", nu.String()) + nu := ac.getWebsocketURL(*u, uuid, url.Values{}) + assert.Equal(t, "ws://localhost:9000/foo/bar/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/", nu.String()) }