Skip to content

Commit

Permalink
Allow logging custom request & response headers
Browse files Browse the repository at this point in the history
You can now specify headers to log from the request and/or the response,
using a combination of `--log-request-header` and
`--log-response-header`.

To add multiple headers, you can specify the flags multiple times, or
use a comma-delimited list. For example:

    --log-response-header=Cache-Control,X-Revision

Header names are converted to snake case in the output, and prefixed
with `req` or `resp`. So the headers from the above example will be
logged as `resp_cache_control` and `resp_x_revision`.
  • Loading branch information
kevinmcconnell committed Jul 29, 2024
1 parent 3838748 commit a16c8b1
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 53 deletions.
3 changes: 3 additions & 0 deletions internal/cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ func newDeployCommand() *deployCommand {
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxRequestBodySize, "max-request-body", server.DefaultMaxRequestBodySize, "Max size of request body when buffering (default of 0 means unlimited)")
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.TargetOptions.MaxResponseBodySize, "max-response-body", server.DefaultMaxResponseBodySize, "Max size of response body when buffering (default of 0 means unlimited)")

deployCommand.cmd.Flags().StringSliceVar(&deployCommand.args.TargetOptions.LogRequestHeaders, "log-request-header", nil, "Additional request header to log (may be specified multiple times)")
deployCommand.cmd.Flags().StringSliceVar(&deployCommand.args.TargetOptions.LogResponseHeaders, "log-response-header", nil, "Additional response header to log (may be specified multiple times)")

deployCommand.cmd.MarkFlagRequired("target")

return deployCommand
Expand Down
77 changes: 51 additions & 26 deletions internal/server/logging_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@ import (
"log/slog"
"net"
"net/http"
"strings"
"time"
)

type contextKey string

var (
contextKeyService = contextKey("service")
contextKeyTarget = contextKey("target")
contextKeyRequestContext = contextKey("request-context")
)

type loggingRequestContext struct {
Service string
Target string
RequestHeaders []string
ResponseHeaders []string
}

type LoggingMiddleware struct {
logger *slog.Logger
next http.Handler
Expand All @@ -29,44 +36,62 @@ func WithLoggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler
}
}

func LoggingRequestContext(r *http.Request) *loggingRequestContext {
lrc, ok := r.Context().Value(contextKeyRequestContext).(*loggingRequestContext)
if !ok {
return &loggingRequestContext{}
}
return lrc
}

func (h *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writer := newLoggerResponseWriter(w)

var service string
var target string
ctx := context.WithValue(r.Context(), contextKeyService, &service)
ctx = context.WithValue(ctx, contextKeyTarget, &target)
var loggingRequestContext loggingRequestContext
ctx := context.WithValue(r.Context(), contextKeyRequestContext, &loggingRequestContext)
r = r.WithContext(ctx)

started := time.Now()
h.next.ServeHTTP(writer, r)
elapsed := time.Since(started)

userAgent := r.Header.Get("User-Agent")
reqContent := r.Header.Get("Content-Type")
respContent := writer.Header().Get("Content-Type")
remoteAddr := r.Header.Get("X-Forwarded-For")
requestID := r.Header.Get("X-Request-ID")
if remoteAddr == "" {
remoteAddr = r.RemoteAddr
}

h.logger.Info("Request",
"host", r.Host,
"path", r.URL.Path,
"request_id", requestID,
"status", writer.statusCode,
"service", service,
"target", target,
"duration", elapsed.Nanoseconds(),
"method", r.Method,
"req_content_length", r.ContentLength,
"req_content_type", reqContent,
"resp_content_length", writer.bytesWritten,
"resp_content_type", respContent,
"remote_addr", remoteAddr,
"user_agent", userAgent,
"query", r.URL.RawQuery)
attrs := []slog.Attr{
slog.String("host", r.Host),
slog.String("path", r.URL.Path),
slog.String("request_id", r.Header.Get("X-Request-ID")),
slog.Int("status", writer.statusCode),
slog.String("service", loggingRequestContext.Service),
slog.String("target", loggingRequestContext.Target),
slog.Int64("duration", elapsed.Nanoseconds()),
slog.String("method", r.Method),
slog.Int64("req_content_length", r.ContentLength),
slog.String("req_content_type", r.Header.Get("Content-Type")),
slog.Int64("resp_content_length", writer.bytesWritten),
slog.String("resp_content_type", writer.Header().Get("Content-Type")),
slog.String("remote_addr", remoteAddr),
slog.String("user_agent", r.Header.Get("User-Agent")),
slog.String("query", r.URL.RawQuery),
}

attrs = append(attrs, h.retrieveCustomHeaders(loggingRequestContext.RequestHeaders, r.Header, "req")...)
attrs = append(attrs, h.retrieveCustomHeaders(loggingRequestContext.ResponseHeaders, writer.Header(), "resp")...)

h.logger.LogAttrs(nil, slog.LevelInfo, "Request", attrs...)
}

func (h *LoggingMiddleware) retrieveCustomHeaders(headerNames []string, header http.Header, prefix string) []slog.Attr {
attrs := []slog.Attr{}
for _, headerName := range headerNames {
name := prefix + "_" + strings.ReplaceAll(strings.ToLower(headerName), "-", "_")
value := strings.Join(header[headerName], ",")
attrs = append(attrs, slog.String(name, value))
}
return attrs
}

type loggerResponseWriter struct {
Expand Down
24 changes: 15 additions & 9 deletions internal/server/logging_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) {
out := &strings.Builder{}
logger := slog.New(slog.NewJSONHandler(out, nil))
middleware := WithLoggingMiddleware(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Record a value for the `service` and `target` context keys
service, ok := r.Context().Value(contextKeyService).(*string)
if ok {
*service = "myapp"
}
target, ok := r.Context().Value(contextKeyTarget).(*string)
if ok {
*target = "upstream:3000"
}
LoggingRequestContext(r).Service = "myapp"
LoggingRequestContext(r).Target = "upstream:3000"
LoggingRequestContext(r).RequestHeaders = []string{"X-Custom"}
LoggingRequestContext(r).ResponseHeaders = []string{"Cache-Control", "X-Custom"}

w.Header().Set("Content-Type", "text/html")
w.Header().Set("Cache-Control", "public, max-age=3600")
w.Header().Set("X-Custom", "goodbye")
w.WriteHeader(http.StatusCreated)
fmt.Fprintln(w, "goodbye")
}))
Expand All @@ -39,6 +36,9 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) {
req.Header.Set("User-Agent", "Robot/1")
req.Header.Set("Content-Type", "application/json")

// Ensure non-canonicalised headers are logged too
req.Header.Set("x-custom", "hello")

middleware.ServeHTTP(httptest.NewRecorder(), req)

logline := struct {
Expand All @@ -58,6 +58,9 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) {
Query string `json:"query"`
Service string `json:"service"`
Target string `json:"target"`
ReqXCustom string `json:"req_x_custom"`
RespCacheControl string `json:"resp_cache_control"`
RespXCustom string `json:"resp_x_custom"`
}{}

err := json.NewDecoder(strings.NewReader(out.String())).Decode(&logline)
Expand All @@ -79,4 +82,7 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) {
assert.Equal(t, int64(8), logline.RespContentLength)
assert.Equal(t, "upstream:3000", logline.Target)
assert.Equal(t, "myapp", logline.Service)
assert.Equal(t, "hello", logline.ReqXCustom)
assert.Equal(t, "public, max-age=3600", logline.RespCacheControl)
assert.Equal(t, "goodbye", logline.RespXCustom)
}
9 changes: 1 addition & 8 deletions internal/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (s *Service) SetActiveTarget(target *Target, drainTimeout time.Duration) {
}

func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.recordServiceNameForRequest(r)
LoggingRequestContext(r).Service = s.name

if s.options.RequireTLS() && r.TLS == nil {
s.redirectToHTTPS(w, r)
Expand Down Expand Up @@ -258,13 +258,6 @@ func (s *Service) initialize() {
s.certManager = s.createCertManager()
}

func (s *Service) recordServiceNameForRequest(req *http.Request) {
serviceIdentifer, ok := req.Context().Value(contextKeyService).(*string)
if ok {
*serviceIdentifer = s.name
}
}

func (s *Service) createCertManager() *autocert.Manager {
if s.options.TLSHostname == "" {
return nil
Expand Down
28 changes: 18 additions & 10 deletions internal/server/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ type TargetOptions struct {
MaxMemoryBufferSize int64 `json:"max_memory_buffer_size"`
MaxRequestBodySize int64 `json:"max_request_body_size"`
MaxResponseBodySize int64 `json:"max_response_body_size"`
LogRequestHeaders []string `json:"log_request_headers"`
LogResponseHeaders []string `json:"log_response_headers"`
}

func (to *TargetOptions) CanonicalizeLogHeaders() {
for i, header := range to.LogRequestHeaders {
to.LogRequestHeaders[i] = http.CanonicalHeaderKey(header)
}
for i, header := range to.LogResponseHeaders {
to.LogResponseHeaders[i] = http.CanonicalHeaderKey(header)
}
}

type Target struct {
Expand All @@ -78,6 +89,8 @@ func NewTarget(targetURL string, options TargetOptions) (*Target, error) {
return nil, err
}

options.CanonicalizeLogHeaders()

target := &Target{
targetURL: uri,
options: options,
Expand Down Expand Up @@ -120,12 +133,14 @@ func (t *Target) StartRequest(req *http.Request) (*http.Request, error) {
}

func (t *Target) SendRequest(w http.ResponseWriter, req *http.Request) {
defer t.endInflightRequest(req)
t.recordTargetNameForRequest(req)
LoggingRequestContext(req).Target = t.Target()
LoggingRequestContext(req).RequestHeaders = t.options.LogRequestHeaders
LoggingRequestContext(req).ResponseHeaders = t.options.LogResponseHeaders

inflightRequest := t.getInflightRequest(req)
tw := newTargetResponseWriter(w, inflightRequest)
defer t.endInflightRequest(req)

tw := newTargetResponseWriter(w, inflightRequest)
t.proxyHandler.ServeHTTP(tw, req)
}

Expand Down Expand Up @@ -243,13 +258,6 @@ func (t *Target) HealthCheckCompleted(success bool) {

// Private

func (t *Target) recordTargetNameForRequest(req *http.Request) {
targetIdentifer, ok := req.Context().Value(contextKeyTarget).(*string)
if ok {
*targetIdentifer = t.Target()
}
}

func (t *Target) createProxyHandler() http.Handler {
bufferPool := NewBufferPool(ProxyBufferSize)

Expand Down

0 comments on commit a16c8b1

Please sign in to comment.