From 5b9a7f4e3295944d46a1b37986cd62d9ecca896b Mon Sep 17 00:00:00 2001 From: James Pickett Date: Wed, 15 Nov 2023 07:38:21 -0800 Subject: [PATCH] remove old logger from local server, use slog and span http, add span_id and trace_sampled attrs to log (#1457) --- cmd/launcher/launcher.go | 1 - ee/localserver/krypto-ec-middleware.go | 3 +- ee/localserver/logging-handler.go | 6 +- ee/localserver/request-controlservice.go | 2 +- ee/localserver/request-controlservice_test.go | 2 +- ee/localserver/request-id.go | 19 +-- ee/localserver/request-id_test.go | 14 ++- ee/localserver/request-query.go | 10 +- ee/localserver/request-query_test.go | 6 +- ee/localserver/server.go | 116 ++++++++++++------ pkg/log/multislogger/multislogger.go | 31 +++-- pkg/traces/traces.go | 27 +++- 12 files changed, 151 insertions(+), 86 deletions(-) diff --git a/cmd/launcher/launcher.go b/cmd/launcher/launcher.go index 1494b337f..6b9e9907c 100644 --- a/cmd/launcher/launcher.go +++ b/cmd/launcher/launcher.go @@ -381,7 +381,6 @@ func runLauncher(ctx context.Context, cancel func(), slogger, systemSlogger *mul if runLocalServer { ls, err := localserver.New( k, - localserver.WithLogger(logger), ) if err != nil { diff --git a/ee/localserver/krypto-ec-middleware.go b/ee/localserver/krypto-ec-middleware.go index 0ba1ed9f9..145ae33c5 100644 --- a/ee/localserver/krypto-ec-middleware.go +++ b/ee/localserver/krypto-ec-middleware.go @@ -138,8 +138,7 @@ func (e *kryptoEcMiddleware) sendCallback(req *http.Request, data *callbackDataS func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - spanCtx, span := traces.StartSpan(r.Context()) - r = r.WithContext(spanCtx) + r, span := traces.StartHttpRequestSpan(r) defer span.End() diff --git a/ee/localserver/logging-handler.go b/ee/localserver/logging-handler.go index f66fae9e8..935d5a608 100644 --- a/ee/localserver/logging-handler.go +++ b/ee/localserver/logging-handler.go @@ -1,10 +1,9 @@ package localserver import ( + "log/slog" "net/http" "time" - - "github.com/go-kit/kit/log/level" ) type statusRecorder struct { @@ -22,7 +21,8 @@ func (ls *localServer) requestLoggingHandler(next http.Handler) http.Handler { recorder := &statusRecorder{ResponseWriter: w, Status: 200} defer func(begin time.Time) { - level.Debug(ls.logger).Log( + ls.slogger.Log(r.Context(), slog.LevelInfo, + "request log", "path", r.URL.Path, "method", r.Method, "status", recorder.Status, diff --git a/ee/localserver/request-controlservice.go b/ee/localserver/request-controlservice.go index 7abc1f12a..a454b0031 100644 --- a/ee/localserver/request-controlservice.go +++ b/ee/localserver/request-controlservice.go @@ -15,7 +15,7 @@ func (ls *localServer) requestAccelerateControlHandler() http.Handler { } func (ls *localServer) requestAccelerateControlFunc(w http.ResponseWriter, r *http.Request) { - _, span := traces.StartSpan(r.Context(), "path", r.URL.Path) + r, span := traces.StartHttpRequestSpan(r, "path", r.URL.Path) defer span.End() if r.Body == nil { diff --git a/ee/localserver/request-controlservice_test.go b/ee/localserver/request-controlservice_test.go index e5a237e96..75de109ea 100644 --- a/ee/localserver/request-controlservice_test.go +++ b/ee/localserver/request-controlservice_test.go @@ -130,7 +130,7 @@ func Test_localServer_requestAccelerateControlFunc(t *testing.T) { } var logBytes bytes.Buffer - server := testServer(t, k, &logBytes) + server := testServer(t, k) req, err := http.NewRequest("", "", nil) if tt.body != nil { diff --git a/ee/localserver/request-id.go b/ee/localserver/request-id.go index c3223b6df..c5ce8b292 100644 --- a/ee/localserver/request-id.go +++ b/ee/localserver/request-id.go @@ -5,12 +5,12 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "net/http" "os/user" "runtime" "time" - "github.com/go-kit/kit/log/level" "github.com/kolide/kit/ulid" "github.com/kolide/launcher/ee/consoleuser" "github.com/kolide/launcher/pkg/backoff" @@ -68,8 +68,8 @@ func (ls *localServer) requestIdHandler() http.Handler { return http.HandlerFunc(ls.requestIdHandlerFunc) } -func (ls *localServer) requestIdHandlerFunc(res http.ResponseWriter, req *http.Request) { - _, span := traces.StartSpan(req.Context(), "path", req.URL.Path) +func (ls *localServer) requestIdHandlerFunc(w http.ResponseWriter, r *http.Request) { + r, span := traces.StartHttpRequestSpan(r, "path", r.URL.Path) defer span.End() response := requestIdsResponse{ @@ -81,10 +81,11 @@ func (ls *localServer) requestIdHandlerFunc(res http.ResponseWriter, req *http.R consoleUsers, err := consoleUsers() if err != nil { traces.SetError(span, err) - level.Error(ls.logger).Log( - "msg", "getting console users", + ls.slogger.Log(r.Context(), slog.LevelError, + "getting console users", "err", err, ) + response.ConsoleUsers = []*user.User{} } else { response.ConsoleUsers = consoleUsers @@ -93,11 +94,15 @@ func (ls *localServer) requestIdHandlerFunc(res http.ResponseWriter, req *http.R jsonBytes, err := json.Marshal(response) if err != nil { traces.SetError(span, err) - level.Info(ls.logger).Log("msg", "unable to marshal json", "err", err) + ls.slogger.Log(r.Context(), slog.LevelError, + "marshaling json", + "err", err, + ) + jsonBytes = []byte(fmt.Sprintf("unable to marshal json: %v", err)) } - res.Write(jsonBytes) + w.Write(jsonBytes) } func consoleUsers() ([]*user.User, error) { diff --git a/ee/localserver/request-id_test.go b/ee/localserver/request-id_test.go index 0e99d974a..c554352a7 100644 --- a/ee/localserver/request-id_test.go +++ b/ee/localserver/request-id_test.go @@ -3,6 +3,7 @@ package localserver import ( "bytes" "encoding/json" + "log/slog" "net/http" "net/http/httptest" "os" @@ -15,7 +16,6 @@ import ( storageci "github.com/kolide/launcher/pkg/agent/storage/ci" "github.com/kolide/launcher/pkg/agent/types" typesMocks "github.com/kolide/launcher/pkg/agent/types/mocks" - "github.com/kolide/launcher/pkg/log/multislogger" "github.com/kolide/launcher/pkg/osquery" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -27,10 +27,14 @@ func Test_localServer_requestIdHandler(t *testing.T) { mockKnapsack := typesMocks.NewKnapsack(t) mockKnapsack.On("ConfigStore").Return(storageci.NewStore(t, log.NewNopLogger(), storage.ConfigStore.String())) mockKnapsack.On("KolideServerURL").Return("localhost") - mockKnapsack.On("Slogger").Return(multislogger.New().Logger) var logBytes bytes.Buffer - server := testServer(t, mockKnapsack, &logBytes) + slogger := slog.New(slog.NewJSONHandler(&logBytes, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + mockKnapsack.On("Slogger").Return(slogger) + + server := testServer(t, mockKnapsack) req, err := http.NewRequest("", "", nil) require.NoError(t, err) @@ -61,10 +65,10 @@ func Test_localServer_requestIdHandler(t *testing.T) { assert.GreaterOrEqual(t, len(response.ConsoleUsers), 1, "should have at least one console user") } -func testServer(t *testing.T, k types.Knapsack, logBytes *bytes.Buffer) *localServer { +func testServer(t *testing.T, k types.Knapsack) *localServer { require.NoError(t, osquery.SetupLauncherKeys(k.ConfigStore())) - server, err := New(k, WithLogger(log.NewLogfmtLogger(logBytes))) + server, err := New(k) require.NoError(t, err) return server } diff --git a/ee/localserver/request-query.go b/ee/localserver/request-query.go index 15770a253..29932bf01 100644 --- a/ee/localserver/request-query.go +++ b/ee/localserver/request-query.go @@ -4,10 +4,10 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "net/http" "time" - "github.com/go-kit/kit/log/level" "github.com/kolide/launcher/pkg/backoff" "github.com/kolide/launcher/pkg/traces" "github.com/osquery/osquery-go/plugin/distributed" @@ -19,7 +19,7 @@ func (ls *localServer) requestQueryHandler() http.Handler { } func (ls *localServer) requestQueryHanlderFunc(w http.ResponseWriter, r *http.Request) { - _, span := traces.StartSpan(r.Context(), "path", r.URL.Path) + r, span := traces.StartHttpRequestSpan(r, "path", r.URL.Path) defer span.End() if r.Body == nil { @@ -61,7 +61,7 @@ func (ls *localServer) requestScheduledQueryHandler() http.Handler { // requestScheduledQueryHandlerFunc uses the name field in the request body to look up // an existing osquery scheduled query execute it, returning the results. func (ls *localServer) requestScheduledQueryHandlerFunc(w http.ResponseWriter, r *http.Request) { - _, span := traces.StartSpan(r.Context(), "path", r.URL.Path) + r, span := traces.StartHttpRequestSpan(r, "path", r.URL.Path) defer span.End() // The driver behind this is that the JS bridge has to use GET requests passing the query (in a nacl box) as a URL parameter. @@ -107,8 +107,8 @@ func (ls *localServer) requestScheduledQueryHandlerFunc(w http.ResponseWriter, r scheduledQueryResult, err := queryWithRetries(ls.querier, scheduledQuery["query"]) if err != nil { - level.Error(ls.logger).Log( - "msg", "running scheduled query on demand", + ls.slogger.Log(r.Context(), slog.LevelError, + "running scheduled query on demand", "err", err, "query", scheduledQuery["query"], "query_name", scheduledQuery["name"], diff --git a/ee/localserver/request-query_test.go b/ee/localserver/request-query_test.go index 0666f18fd..cc92084c1 100644 --- a/ee/localserver/request-query_test.go +++ b/ee/localserver/request-query_test.go @@ -68,8 +68,7 @@ func Test_localServer_requestQueryHandler(t *testing.T) { mockQuerier.On("Query", tt.query).Return(tt.mockQueryResult, nil).Once() } - var logBytes bytes.Buffer - server := testServer(t, mockKnapsack, &logBytes) + server := testServer(t, mockKnapsack) server.querier = mockQuerier jsonBytes, err := json.Marshal(map[string]string{ @@ -239,8 +238,7 @@ func Test_localServer_requestRunScheduledQueryHandler(t *testing.T) { } // set up test server - var logBytes bytes.Buffer - server := testServer(t, mockKnapsack, &logBytes) + server := testServer(t, mockKnapsack) server.querier = mockQuerier // make request body diff --git a/ee/localserver/server.go b/ee/localserver/server.go index f987ac76b..f51b12c07 100644 --- a/ee/localserver/server.go +++ b/ee/localserver/server.go @@ -9,13 +9,12 @@ import ( _ "embed" "errors" "fmt" + "log/slog" "net" "net/http" "strings" "time" - "github.com/go-kit/kit/log" - "github.com/go-kit/kit/log/level" "github.com/kolide/krypto" "github.com/kolide/krypto/pkg/echelper" "github.com/kolide/launcher/pkg/agent" @@ -40,7 +39,7 @@ type Querier interface { } type localServer struct { - logger log.Logger + slogger *slog.Logger knapsack types.Knapsack srv *http.Server identifiers identifiers @@ -62,17 +61,9 @@ const ( defaultRateBurst = 10 ) -type LocalServerOption func(*localServer) - -func WithLogger(logger log.Logger) LocalServerOption { - return func(s *localServer) { - s.logger = log.With(logger, "component", "localserver") - } -} - -func New(k types.Knapsack, opts ...LocalServerOption) (*localServer, error) { +func New(k types.Knapsack) (*localServer, error) { ls := &localServer{ - logger: log.NewNopLogger(), + slogger: k.Slogger().With("component", "localserver"), knapsack: k, limiter: rate.NewLimiter(defaultRateLimit, defaultRateBurst), kolideServer: k.KolideServerURL(), @@ -80,10 +71,6 @@ func New(k types.Knapsack, opts ...LocalServerOption) (*localServer, error) { myLocalHardwareSigner: agent.HardwareKeys(), } - for _, o := range opts { - o(ls) - } - // TODO: As there may be things that adjust the keys during runtime, we need to persist that across // restarts. We should load-old-state here. This is still pretty TBD, so don't angst too much. if err := ls.LoadDefaultKeyIfNotSet(); err != nil { @@ -155,17 +142,29 @@ func (ls *localServer) LoadDefaultKeyIfNotSet() error { serverRsaCertPem := k2RsaServerCert serverEccCertPem := k2EccServerCert + + ctx := context.TODO() + slogLevel := slog.LevelDebug + switch { case strings.HasPrefix(ls.kolideServer, "localhost"), strings.HasPrefix(ls.kolideServer, "127.0.0.1"), strings.Contains(ls.kolideServer, ".ngrok."): - level.Debug(ls.logger).Log("msg", "using developer certificates") + ls.slogger.Log(ctx, slogLevel, + "using developer certificates", + ) + serverRsaCertPem = localhostRsaServerCert serverEccCertPem = localhostEccServerCert case strings.HasSuffix(ls.kolideServer, ".herokuapp.com"): - level.Debug(ls.logger).Log("msg", "using review app certificates") + ls.slogger.Log(ctx, slogLevel, + "using review app certificates", + ) + serverRsaCertPem = reviewRsaServerCert serverEccCertPem = reviewEccServerCert default: - level.Debug(ls.logger).Log("msg", "using default/production certificates") + ls.slogger.Log(ctx, slogLevel, + "using default/production certificates", + ) } serverKeyRaw, err := krypto.KeyFromPem([]byte(serverRsaCertPem)) @@ -191,18 +190,21 @@ func (ls *localServer) LoadDefaultKeyIfNotSet() error { func (ls *localServer) runAsyncdWorkers() time.Time { success := true - level.Debug(ls.logger).Log("msg", "Starting an async worker run") + ctx := context.TODO() + ls.slogger.Log(ctx, slog.LevelDebug, + "starting async worker run", + ) if err := ls.updateIdFields(); err != nil { success = false - level.Info(ls.logger).Log( - "msg", "Got error updating id fields", + ls.slogger.Log(ctx, slog.LevelError, + "updating id fields", "err", err, ) } - level.Debug(ls.logger).Log( - "msg", "Completed async worker run", + ls.slogger.Log(ctx, slog.LevelDebug, + "completed async worker run", "success", success, ) @@ -226,12 +228,16 @@ func (ls *localServer) Start() error { go func() { var lastRun time.Time + ctx := context.TODO() + // note that this will trigger the check for the first time after pollInterval (not immediately) for range time.Tick(pollInterval) { if time.Since(lastRun) > recalculateInterval { lastRun = ls.runAsyncdWorkers() if lastRun.IsZero() { - level.Debug(ls.logger).Log("message", "runAsyncdWorkers unsuccessful, will retry in the future.") + ls.slogger.Log(ctx, slog.LevelDebug, + "runAsyncdWorkers unsuccessful, will retry in the future", + ) } } } @@ -242,26 +248,39 @@ func (ls *localServer) Start() error { return fmt.Errorf("starting listener: %w", err) } + ctx := context.TODO() + if ls.tlsCerts != nil && len(ls.tlsCerts) > 0 { - level.Debug(ls.logger).Log("message", "Using TLS") + ls.slogger.Log(ctx, slog.LevelDebug, + "using TLS", + ) tlsConfig := &tls.Config{Certificates: ls.tlsCerts} l = tls.NewListener(l, tlsConfig) } else { - level.Debug(ls.logger).Log("message", "No TLS") + ls.slogger.Log(ctx, slog.LevelDebug, + "not using TLS", + ) } return ls.srv.Serve(l) } func (ls *localServer) Stop() error { - level.Debug(ls.logger).Log("msg", "Stopping") - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx := context.TODO() + ls.slogger.Log(ctx, slog.LevelDebug, + "stopping", + ) + + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() if err := ls.srv.Shutdown(ctx); err != nil { - level.Info(ls.logger).Log("message", "got error shutting down", "error", err) + ls.slogger.Log(ctx, slog.LevelError, + "shutting down", + "err", err, + ) } // Consider calling srv.Stop as a more forceful shutdown? @@ -270,27 +289,44 @@ func (ls *localServer) Stop() error { } func (ls *localServer) Interrupt(_ error) { - level.Debug(ls.logger).Log("message", "Stopping due to interrupt") + ctx := context.TODO() + + ls.slogger.Log(ctx, slog.LevelDebug, + "stopping due to interrupt", + ) + if err := ls.Stop(); err != nil { - level.Info(ls.logger).Log("message", "got error interrupting", "error", err) + ls.slogger.Log(ctx, slog.LevelError, + "stopping", + "err", err, + ) } } func (ls *localServer) startListener() (net.Listener, error) { + ctx := context.TODO() + for _, p := range portList { - level.Debug(ls.logger).Log("msg", "Trying port", "port", p) + ls.slogger.Log(ctx, slog.LevelDebug, + "trying port", + "port", p, + ) l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", p)) if err != nil { - level.Debug(ls.logger).Log( - "message", "Unable to bind to port. Moving on", + ls.slogger.Log(ctx, slog.LevelDebug, + "unable to bind to port, moving on", "port", p, "err", err, ) + continue } - level.Info(ls.logger).Log("msg", "Got port", "port", p) + ls.slogger.Log(ctx, slog.LevelInfo, + "got port", + "port", p, + ) return l, nil } @@ -326,7 +362,11 @@ func (ls *localServer) rateLimitHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if ls.limiter.Allow() == false { http.Error(w, http.StatusText(429), http.StatusTooManyRequests) - level.Error(ls.logger).Log("msg", "Over rate limit") + + ls.slogger.Log(r.Context(), slog.LevelError, + "over rate limit", + ) + return } diff --git a/pkg/log/multislogger/multislogger.go b/pkg/log/multislogger/multislogger.go index 776ef983d..28f17e5d3 100644 --- a/pkg/log/multislogger/multislogger.go +++ b/pkg/log/multislogger/multislogger.go @@ -8,25 +8,29 @@ import ( slogmulti "github.com/samber/slog-multi" ) -type contextKey int +type contextKey string func (c contextKey) String() string { - switch c { - case KolideSessionIdKey: - return "kolide_session_id" - case SpanIdKey: - return "span_id" - default: - return "unknown" - } + return string(c) } const ( // KolideSessionIdKey this the also the saml session id - KolideSessionIdKey contextKey = iota - SpanIdKey + KolideSessionIdKey contextKey = "kolide_session_id" + SpanIdKey contextKey = "span_id" + TraceIdKey contextKey = "trace_id" + TraceSampledKey contextKey = "trace_sampled" ) +// ctxValueKeysToAdd is a list of context keys that will be +// added as log attributes +var ctxValueKeysToAdd = []contextKey{ + SpanIdKey, + TraceIdKey, + KolideSessionIdKey, + TraceSampledKey, +} + type MultiSlogger struct { *slog.Logger handlers []slog.Handler @@ -70,11 +74,6 @@ func utcTimeMiddleware(ctx context.Context, record slog.Record, next func(contex return next(ctx, record) } -var ctxValueKeysToAdd = []contextKey{ - SpanIdKey, - KolideSessionIdKey, -} - func ctxValuesMiddleWare(ctx context.Context, record slog.Record, next func(context.Context, slog.Record) error) error { for _, key := range ctxValueKeysToAdd { if v := ctx.Value(key); v != nil { diff --git a/pkg/traces/traces.go b/pkg/traces/traces.go index ab2776d21..2fe9e6688 100644 --- a/pkg/traces/traces.go +++ b/pkg/traces/traces.go @@ -3,6 +3,7 @@ package traces import ( "context" "fmt" + "net/http" "path/filepath" "runtime" @@ -20,20 +21,37 @@ const ( defaultAttributeNamespace = "unknown" ) +// StartHttpRequestSpan returns a copy of the request with a new context to include span info and span, +// including information about the calling function as appropriate. +// Standardizes the tracer name. The caller is always responsible for +// ending the span. `keyVals` should be a list of pairs, where the first in the pair is a +// string representing the attribute key and the second in the pair is the attribute value. +func StartHttpRequestSpan(r *http.Request, keyVals ...interface{}) (*http.Request, trace.Span) { + ctx, span := startSpan(r.Context(), keyVals...) + return r.WithContext(ctx), span +} + // StartSpan returns a new context and span, including information about the calling function // as appropriate. Standardizes the tracer name. The caller is always responsible for // ending the span. `keyVals` should be a list of pairs, where the first in the pair is a // string representing the attribute key and the second in the pair is the attribute value. func StartSpan(ctx context.Context, keyVals ...interface{}) (context.Context, trace.Span) { + return startSpan(ctx, keyVals...) +} + +// startSpan is the internal implementation of StartSpan and StartHttpRequestSpan with runtime.Caller(2) +// so that the caller of the wrapper function is used. +func startSpan(ctx context.Context, keyVals ...interface{}) (context.Context, trace.Span) { spanName := defaultSpanName opts := make([]trace.SpanStartOption, 0) // Extract information about the caller to set some standard attributes (code.filepath, // code.lineno, code.function) and to set more specific span and attribute names. - // runtime.Caller(0) would return information about `StartSpan` -- calling - // runtime.Caller(1) will return information about the function calling `StartSpan`. - programCounter, callerFile, callerLine, ok := runtime.Caller(1) + // runtime.Caller(0) would return information about `startSpan` -- calling + // runtime.Caller(1) will return information about the wrapper function calling `startSpan`. + // runtime.Caller(2) will return information about the function calling the wrapper function + programCounter, callerFile, callerLine, ok := runtime.Caller(2) if ok { opts = append(opts, trace.WithAttributes( semconv.CodeFilepath(callerFile), @@ -51,6 +69,9 @@ func StartSpan(ctx context.Context, keyVals ...interface{}) (context.Context, tr spanCtx, span := otel.Tracer(instrumentationPkg).Start(ctx, spanName, opts...) spanCtx = context.WithValue(spanCtx, multislogger.SpanIdKey, span.SpanContext().SpanID().String()) + spanCtx = context.WithValue(spanCtx, multislogger.TraceIdKey, span.SpanContext().TraceID().String()) + spanCtx = context.WithValue(spanCtx, multislogger.TraceSampledKey, span.SpanContext().IsSampled()) + return spanCtx, span }