From a16c8b1563cdb4f1e9ae3deaec471720ab7b2544 Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Mon, 29 Jul 2024 14:20:24 +0100 Subject: [PATCH] Allow logging custom request & response headers You can now specify headers to log from the request and/or the response, using a combination of `--log-request-header` and `--log-response-header`. To add multiple headers, you can specify the flags multiple times, or use a comma-delimited list. For example: --log-response-header=Cache-Control,X-Revision Header names are converted to snake case in the output, and prefixed with `req` or `resp`. So the headers from the above example will be logged as `resp_cache_control` and `resp_x_revision`. --- internal/cmd/deploy.go | 3 + internal/server/logging_middleware.go | 77 ++++++++++++++-------- internal/server/logging_middleware_test.go | 24 ++++--- internal/server/service.go | 9 +-- internal/server/target.go | 28 +++++--- 5 files changed, 88 insertions(+), 53 deletions(-) diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index 7e846df..3ec9d72 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -48,6 +48,9 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxRequestBodySize, "max-request-body", server.DefaultMaxRequestBodySize, "Max size of request body when buffering (default of 0 means unlimited)") deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxResponseBodySize, "max-response-body", server.DefaultMaxResponseBodySize, "Max size of response body when buffering (default of 0 means unlimited)") + deployCommand.cmd.Flags().StringSliceVar(&deployCommand.args.TargetOptions.LogRequestHeaders, "log-request-header", nil, "Additional request header to log (may be specified multiple times)") + deployCommand.cmd.Flags().StringSliceVar(&deployCommand.args.TargetOptions.LogResponseHeaders, "log-response-header", nil, "Additional response header to log (may be specified multiple times)") + deployCommand.cmd.MarkFlagRequired("target") return deployCommand diff --git a/internal/server/logging_middleware.go b/internal/server/logging_middleware.go index f3ea2b0..3cf42ea 100644 --- a/internal/server/logging_middleware.go +++ b/internal/server/logging_middleware.go @@ -7,16 +7,23 @@ import ( "log/slog" "net" "net/http" + "strings" "time" ) type contextKey string var ( - contextKeyService = contextKey("service") - contextKeyTarget = contextKey("target") + contextKeyRequestContext = contextKey("request-context") ) +type loggingRequestContext struct { + Service string + Target string + RequestHeaders []string + ResponseHeaders []string +} + type LoggingMiddleware struct { logger *slog.Logger next http.Handler @@ -29,44 +36,62 @@ func WithLoggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler } } +func LoggingRequestContext(r *http.Request) *loggingRequestContext { + lrc, ok := r.Context().Value(contextKeyRequestContext).(*loggingRequestContext) + if !ok { + return &loggingRequestContext{} + } + return lrc +} + func (h *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { writer := newLoggerResponseWriter(w) - var service string - var target string - ctx := context.WithValue(r.Context(), contextKeyService, &service) - ctx = context.WithValue(ctx, contextKeyTarget, &target) + var loggingRequestContext loggingRequestContext + ctx := context.WithValue(r.Context(), contextKeyRequestContext, &loggingRequestContext) r = r.WithContext(ctx) started := time.Now() h.next.ServeHTTP(writer, r) elapsed := time.Since(started) - userAgent := r.Header.Get("User-Agent") - reqContent := r.Header.Get("Content-Type") - respContent := writer.Header().Get("Content-Type") remoteAddr := r.Header.Get("X-Forwarded-For") - requestID := r.Header.Get("X-Request-ID") if remoteAddr == "" { remoteAddr = r.RemoteAddr } - h.logger.Info("Request", - "host", r.Host, - "path", r.URL.Path, - "request_id", requestID, - "status", writer.statusCode, - "service", service, - "target", target, - "duration", elapsed.Nanoseconds(), - "method", r.Method, - "req_content_length", r.ContentLength, - "req_content_type", reqContent, - "resp_content_length", writer.bytesWritten, - "resp_content_type", respContent, - "remote_addr", remoteAddr, - "user_agent", userAgent, - "query", r.URL.RawQuery) + attrs := []slog.Attr{ + slog.String("host", r.Host), + slog.String("path", r.URL.Path), + slog.String("request_id", r.Header.Get("X-Request-ID")), + slog.Int("status", writer.statusCode), + slog.String("service", loggingRequestContext.Service), + slog.String("target", loggingRequestContext.Target), + slog.Int64("duration", elapsed.Nanoseconds()), + slog.String("method", r.Method), + slog.Int64("req_content_length", r.ContentLength), + slog.String("req_content_type", r.Header.Get("Content-Type")), + slog.Int64("resp_content_length", writer.bytesWritten), + slog.String("resp_content_type", writer.Header().Get("Content-Type")), + slog.String("remote_addr", remoteAddr), + slog.String("user_agent", r.Header.Get("User-Agent")), + slog.String("query", r.URL.RawQuery), + } + + attrs = append(attrs, h.retrieveCustomHeaders(loggingRequestContext.RequestHeaders, r.Header, "req")...) + attrs = append(attrs, h.retrieveCustomHeaders(loggingRequestContext.ResponseHeaders, writer.Header(), "resp")...) + + h.logger.LogAttrs(nil, slog.LevelInfo, "Request", attrs...) +} + +func (h *LoggingMiddleware) retrieveCustomHeaders(headerNames []string, header http.Header, prefix string) []slog.Attr { + attrs := []slog.Attr{} + for _, headerName := range headerNames { + name := prefix + "_" + strings.ReplaceAll(strings.ToLower(headerName), "-", "_") + value := strings.Join(header[headerName], ",") + attrs = append(attrs, slog.String(name, value)) + } + return attrs } type loggerResponseWriter struct { diff --git a/internal/server/logging_middleware_test.go b/internal/server/logging_middleware_test.go index bb780c5..5b72ce3 100644 --- a/internal/server/logging_middleware_test.go +++ b/internal/server/logging_middleware_test.go @@ -18,17 +18,14 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) { out := &strings.Builder{} logger := slog.New(slog.NewJSONHandler(out, nil)) middleware := WithLoggingMiddleware(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Record a value for the `service` and `target` context keys - service, ok := r.Context().Value(contextKeyService).(*string) - if ok { - *service = "myapp" - } - target, ok := r.Context().Value(contextKeyTarget).(*string) - if ok { - *target = "upstream:3000" - } + LoggingRequestContext(r).Service = "myapp" + LoggingRequestContext(r).Target = "upstream:3000" + LoggingRequestContext(r).RequestHeaders = []string{"X-Custom"} + LoggingRequestContext(r).ResponseHeaders = []string{"Cache-Control", "X-Custom"} w.Header().Set("Content-Type", "text/html") + w.Header().Set("Cache-Control", "public, max-age=3600") + w.Header().Set("X-Custom", "goodbye") w.WriteHeader(http.StatusCreated) fmt.Fprintln(w, "goodbye") })) @@ -39,6 +36,9 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) { req.Header.Set("User-Agent", "Robot/1") req.Header.Set("Content-Type", "application/json") + // Ensure non-canonicalised headers are logged too + req.Header.Set("x-custom", "hello") + middleware.ServeHTTP(httptest.NewRecorder(), req) logline := struct { @@ -58,6 +58,9 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) { Query string `json:"query"` Service string `json:"service"` Target string `json:"target"` + ReqXCustom string `json:"req_x_custom"` + RespCacheControl string `json:"resp_cache_control"` + RespXCustom string `json:"resp_x_custom"` }{} err := json.NewDecoder(strings.NewReader(out.String())).Decode(&logline) @@ -79,4 +82,7 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) { assert.Equal(t, int64(8), logline.RespContentLength) assert.Equal(t, "upstream:3000", logline.Target) assert.Equal(t, "myapp", logline.Service) + assert.Equal(t, "hello", logline.ReqXCustom) + assert.Equal(t, "public, max-age=3600", logline.RespCacheControl) + assert.Equal(t, "goodbye", logline.RespXCustom) } diff --git a/internal/server/service.go b/internal/server/service.go index 47da342..f21e1aa 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -126,7 +126,7 @@ func (s *Service) SetActiveTarget(target *Target, drainTimeout time.Duration) { } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.recordServiceNameForRequest(r) + LoggingRequestContext(r).Service = s.name if s.options.RequireTLS() && r.TLS == nil { s.redirectToHTTPS(w, r) @@ -258,13 +258,6 @@ func (s *Service) initialize() { s.certManager = s.createCertManager() } -func (s *Service) recordServiceNameForRequest(req *http.Request) { - serviceIdentifer, ok := req.Context().Value(contextKeyService).(*string) - if ok { - *serviceIdentifer = s.name - } -} - func (s *Service) createCertManager() *autocert.Manager { if s.options.TLSHostname == "" { return nil diff --git a/internal/server/target.go b/internal/server/target.go index 5bc0e41..1e7ff41 100644 --- a/internal/server/target.go +++ b/internal/server/target.go @@ -57,6 +57,17 @@ type TargetOptions struct { MaxMemoryBufferSize int64 `json:"max_memory_buffer_size"` MaxRequestBodySize int64 `json:"max_request_body_size"` MaxResponseBodySize int64 `json:"max_response_body_size"` + LogRequestHeaders []string `json:"log_request_headers"` + LogResponseHeaders []string `json:"log_response_headers"` +} + +func (to *TargetOptions) CanonicalizeLogHeaders() { + for i, header := range to.LogRequestHeaders { + to.LogRequestHeaders[i] = http.CanonicalHeaderKey(header) + } + for i, header := range to.LogResponseHeaders { + to.LogResponseHeaders[i] = http.CanonicalHeaderKey(header) + } } type Target struct { @@ -78,6 +89,8 @@ func NewTarget(targetURL string, options TargetOptions) (*Target, error) { return nil, err } + options.CanonicalizeLogHeaders() + target := &Target{ targetURL: uri, options: options, @@ -120,12 +133,14 @@ func (t *Target) StartRequest(req *http.Request) (*http.Request, error) { } func (t *Target) SendRequest(w http.ResponseWriter, req *http.Request) { - defer t.endInflightRequest(req) - t.recordTargetNameForRequest(req) + LoggingRequestContext(req).Target = t.Target() + LoggingRequestContext(req).RequestHeaders = t.options.LogRequestHeaders + LoggingRequestContext(req).ResponseHeaders = t.options.LogResponseHeaders inflightRequest := t.getInflightRequest(req) - tw := newTargetResponseWriter(w, inflightRequest) + defer t.endInflightRequest(req) + tw := newTargetResponseWriter(w, inflightRequest) t.proxyHandler.ServeHTTP(tw, req) } @@ -243,13 +258,6 @@ func (t *Target) HealthCheckCompleted(success bool) { // Private -func (t *Target) recordTargetNameForRequest(req *http.Request) { - targetIdentifer, ok := req.Context().Value(contextKeyTarget).(*string) - if ok { - *targetIdentifer = t.Target() - } -} - func (t *Target) createProxyHandler() http.Handler { bufferPool := NewBufferPool(ProxyBufferSize)