From 763c917ff999967ef7ad31c131ad6efc5a0fedc9 Mon Sep 17 00:00:00 2001 From: Rebecca Mahany-Horton Date: Thu, 6 Jun 2024 16:55:56 -0400 Subject: [PATCH 1/2] Accept referer if origin is not available --- ee/localserver/krypto-ec-middleware.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ee/localserver/krypto-ec-middleware.go b/ee/localserver/krypto-ec-middleware.go index ca9ed15ee..c7f28a0d6 100644 --- a/ee/localserver/krypto-ec-middleware.go +++ b/ee/localserver/krypto-ec-middleware.go @@ -220,22 +220,27 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { } // Check if the origin is in the allowed list. See https://github.com/kolide/k2/issues/9634 + origin := r.Header.Get("Origin") + // When loading images, the origin may not be set, but the referer will. We can accept that instead. + if origin == "" { + origin = strings.TrimSuffix(r.Header.Get("Referer"), "/") + } if len(cmdReq.AllowedOrigins) > 0 { allowed := false for _, ao := range cmdReq.AllowedOrigins { - if strings.EqualFold(ao, r.Header.Get("Origin")) { + if strings.EqualFold(ao, origin) { allowed = true break } } if !allowed { - span.SetAttributes(attribute.String("origin", r.Header.Get("Origin"))) - traces.SetError(span, fmt.Errorf("origin %s is not allowed", r.Header.Get("Origin"))) + span.SetAttributes(attribute.String("origin", origin)) + traces.SetError(span, fmt.Errorf("origin %s is not allowed", origin)) e.slogger.Log(r.Context(), slog.LevelError, "origin is not allowed", "allowlist", cmdReq.AllowedOrigins, - "origin", r.Header.Get("Origin"), + "origin", origin, ) w.WriteHeader(http.StatusUnauthorized) @@ -245,12 +250,12 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { e.slogger.Log(r.Context(), slog.LevelDebug, "origin matches allowlist", - "origin", r.Header.Get("Origin"), + "origin", origin, ) } else { e.slogger.Log(r.Context(), slog.LevelDebug, "origin is allowed by default, no allowlist", - "origin", r.Header.Get("Origin"), + "origin", origin, ) } @@ -281,6 +286,7 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { } newReq.Header.Set("Origin", r.Header.Get("Origin")) + newReq.Header.Set("Referer", r.Header.Get("Referer")) // setting the newReq context to the current request context // allows the trace to continue to the inner request, From cb01108d792e5899597eba9fa1de6d46e467a1a2 Mon Sep 17 00:00:00 2001 From: Rebecca Mahany-Horton Date: Thu, 6 Jun 2024 17:01:40 -0400 Subject: [PATCH 2/2] Add test case --- ee/localserver/krypto-ec-middleware_test.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/ee/localserver/krypto-ec-middleware_test.go b/ee/localserver/krypto-ec-middleware_test.go index e8b85a4d8..fa135795a 100644 --- a/ee/localserver/krypto-ec-middleware_test.go +++ b/ee/localserver/krypto-ec-middleware_test.go @@ -219,11 +219,12 @@ func Test_AllowedOrigin(t *testing.T) { challengeData := []byte(ulid.New()) var tests = []struct { - name string - requestOrigin string - allowedOrigins []string - logStr string - expectedStatus int + name string + requestOrigin string + requestReferrer string + allowedOrigins []string + logStr string + expectedStatus int }{ { name: "no allowed specified", @@ -277,6 +278,13 @@ func Test_AllowedOrigin(t *testing.T) { expectedStatus: http.StatusOK, logStr: "origin matches allowlist", }, + { + name: "no allowed specified origin, but acceptable referer is present", + allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"}, + requestReferrer: "https://auth.example.com", + expectedStatus: http.StatusOK, + logStr: "origin matches allowlist", + }, } for _, tt := range tests { @@ -320,6 +328,7 @@ func Test_AllowedOrigin(t *testing.T) { req := makeGetRequest(t, encodedChallenge) req.Header.Set("origin", tt.requestOrigin) + req.Header.Set("referer", tt.requestReferrer) rr := httptest.NewRecorder() h.ServeHTTP(rr, req)