diff --git a/ee/localserver/krypto-ec-middleware.go b/ee/localserver/krypto-ec-middleware.go index ca9ed15ee..c7f28a0d6 100644 --- a/ee/localserver/krypto-ec-middleware.go +++ b/ee/localserver/krypto-ec-middleware.go @@ -220,22 +220,27 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { } // Check if the origin is in the allowed list. See https://github.com/kolide/k2/issues/9634 + origin := r.Header.Get("Origin") + // When loading images, the origin may not be set, but the referer will. We can accept that instead. + if origin == "" { + origin = strings.TrimSuffix(r.Header.Get("Referer"), "/") + } if len(cmdReq.AllowedOrigins) > 0 { allowed := false for _, ao := range cmdReq.AllowedOrigins { - if strings.EqualFold(ao, r.Header.Get("Origin")) { + if strings.EqualFold(ao, origin) { allowed = true break } } if !allowed { - span.SetAttributes(attribute.String("origin", r.Header.Get("Origin"))) - traces.SetError(span, fmt.Errorf("origin %s is not allowed", r.Header.Get("Origin"))) + span.SetAttributes(attribute.String("origin", origin)) + traces.SetError(span, fmt.Errorf("origin %s is not allowed", origin)) e.slogger.Log(r.Context(), slog.LevelError, "origin is not allowed", "allowlist", cmdReq.AllowedOrigins, - "origin", r.Header.Get("Origin"), + "origin", origin, ) w.WriteHeader(http.StatusUnauthorized) @@ -245,12 +250,12 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { e.slogger.Log(r.Context(), slog.LevelDebug, "origin matches allowlist", - "origin", r.Header.Get("Origin"), + "origin", origin, ) } else { e.slogger.Log(r.Context(), slog.LevelDebug, "origin is allowed by default, no allowlist", - "origin", r.Header.Get("Origin"), + "origin", origin, ) } @@ -281,6 +286,7 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { } newReq.Header.Set("Origin", r.Header.Get("Origin")) + newReq.Header.Set("Referer", r.Header.Get("Referer")) // setting the newReq context to the current request context // allows the trace to continue to the inner request, diff --git a/ee/localserver/krypto-ec-middleware_test.go b/ee/localserver/krypto-ec-middleware_test.go index e8b85a4d8..fa135795a 100644 --- a/ee/localserver/krypto-ec-middleware_test.go +++ b/ee/localserver/krypto-ec-middleware_test.go @@ -219,11 +219,12 @@ func Test_AllowedOrigin(t *testing.T) { challengeData := []byte(ulid.New()) var tests = []struct { - name string - requestOrigin string - allowedOrigins []string - logStr string - expectedStatus int + name string + requestOrigin string + requestReferrer string + allowedOrigins []string + logStr string + expectedStatus int }{ { name: "no allowed specified", @@ -277,6 +278,13 @@ func Test_AllowedOrigin(t *testing.T) { expectedStatus: http.StatusOK, logStr: "origin matches allowlist", }, + { + name: "no allowed specified origin, but acceptable referer is present", + allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"}, + requestReferrer: "https://auth.example.com", + expectedStatus: http.StatusOK, + logStr: "origin matches allowlist", + }, } for _, tt := range tests { @@ -320,6 +328,7 @@ func Test_AllowedOrigin(t *testing.T) { req := makeGetRequest(t, encodedChallenge) req.Header.Set("origin", tt.requestOrigin) + req.Header.Set("referer", tt.requestReferrer) rr := httptest.NewRecorder() h.ServeHTTP(rr, req)