diff --git a/cmd/query/app/http_handler.go b/cmd/query/app/http_handler.go index e00472babb5..3eec204720e 100644 --- a/cmd/query/app/http_handler.go +++ b/cmd/query/app/http_handler.go @@ -18,7 +18,6 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gorilla/mux" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" @@ -132,7 +131,6 @@ func (aH *APIHandler) handleFunc( ) *mux.Route { route := aH.formatRoute(routeFmt, args...) var handler http.Handler = http.HandlerFunc(f) - handler = traceResponseHandler(handler) handler = otelhttp.WithRouteTag(route, handler) handler = spanNameHandler(route, handler) return router.HandleFunc(route, handler.ServeHTTP) @@ -517,21 +515,6 @@ func (aH *APIHandler) writeJSON(w http.ResponseWriter, r *http.Request, response } } -// Returns a handler that generates a traceresponse header. -// https://github.com/w3c/trace-context/blob/main/spec/21-http_response_header_format.md -func traceResponseHandler(handler http.Handler) http.Handler { - // We use the standard TraceContext propagator, since the formats are identical. - // But the propagator uses "traceparent" header name, so we inject it into a map - // `carrier` and then use the result to set the "tracereponse" header. - var prop propagation.TraceContext - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - carrier := make(map[string]string) - prop.Inject(r.Context(), propagation.MapCarrier(carrier)) - w.Header().Add("traceresponse", carrier["traceparent"]) - handler.ServeHTTP(w, r) - }) -} - func spanNameHandler(spanName string, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.SpanFromContext(r.Context()) diff --git a/cmd/query/app/server.go b/cmd/query/app/server.go index 543ede956ee..043fe978d3c 100644 --- a/cmd/query/app/server.go +++ b/cmd/query/app/server.go @@ -201,6 +201,7 @@ func initRouter( if tenancyMgr.Enabled { handler = tenancy.ExtractTenantHTTPHandler(tenancyMgr, handler) } + handler = traceResponseHandler(handler) return handler, staticHandlerCloser } diff --git a/cmd/query/app/trace_response_handler.go b/cmd/query/app/trace_response_handler.go new file mode 100644 index 00000000000..f01dc4a7315 --- /dev/null +++ b/cmd/query/app/trace_response_handler.go @@ -0,0 +1,25 @@ +// Copyright (c) 2024 The Jaeger Authors. +// SPDX-License-Identifier: Apache-2.0 + +package app + +import ( + "net/http" + + "go.opentelemetry.io/otel/propagation" +) + +// Returns a handler that generates a traceresponse header. +// https://github.com/w3c/trace-context/blob/main/spec/21-http_response_header_format.md +func traceResponseHandler(handler http.Handler) http.Handler { + // We use the standard TraceContext propagator, since the formats are identical. + // But the propagator uses "traceparent" header name, so we inject it into a map + // `carrier` and then use the result to set the "tracereponse" header. + var prop propagation.TraceContext + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + carrier := make(map[string]string) + prop.Inject(r.Context(), propagation.MapCarrier(carrier)) + w.Header().Add("traceresponse", carrier["traceparent"]) + handler.ServeHTTP(w, r) + }) +} diff --git a/cmd/query/app/trace_response_handler_test.go b/cmd/query/app/trace_response_handler_test.go new file mode 100644 index 00000000000..59f988f2a37 --- /dev/null +++ b/cmd/query/app/trace_response_handler_test.go @@ -0,0 +1,43 @@ +// Copyright (c) 2024 The Jaeger Authors. +// SPDX-License-Identifier: Apache-2.0 + +package app + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +func TestTraceResponseHandler(t *testing.T) { + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte{}) + }) + handler := traceResponseHandler(emptyHandler) + + exporter := tracetest.NewInMemoryExporter() + tracerProvider := trace.NewTracerProvider( + trace.WithSyncer(exporter), + trace.WithSampler(trace.AlwaysSample()), + ) + tracer := tracerProvider.Tracer("test-tracer") + ctx, span := tracer.Start(context.Background(), "test-span") + defer span.End() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + traceResponse := w.Header().Get("traceresponse") + parts := strings.Split(traceResponse, "-") + require.Len(t, parts, 4) +}