From 8c898ba5b46b117d6ca9187251785d6537349c29 Mon Sep 17 00:00:00 2001 From: Rebecca Mahany-Horton Date: Wed, 8 Jan 2025 18:09:00 -0500 Subject: [PATCH] Improvements for osqueryinstance's errgroup (#2017) --- ee/errgroup/errgroup.go | 182 +++++++++++++++++++++ ee/errgroup/errgroup_test.go | 188 ++++++++++++++++++++++ pkg/osquery/runtime/osqueryinstance.go | 160 ++++++------------ pkg/osquery/runtime/runtime_posix_test.go | 2 + pkg/osquery/runtime/runtime_test.go | 8 + 5 files changed, 425 insertions(+), 115 deletions(-) create mode 100644 ee/errgroup/errgroup.go create mode 100644 ee/errgroup/errgroup_test.go diff --git a/ee/errgroup/errgroup.go b/ee/errgroup/errgroup.go new file mode 100644 index 000000000..2606c3fba --- /dev/null +++ b/ee/errgroup/errgroup.go @@ -0,0 +1,182 @@ +package errgroup + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" +) + +type LoggedErrgroup struct { + errgroup *errgroup.Group + cancel context.CancelFunc + doneCtx context.Context // nolint:containedctx + slogger *slog.Logger +} + +const ( + maxShutdownGoroutineDuration = 3 * time.Second +) + +func NewLoggedErrgroup(ctx context.Context, slogger *slog.Logger) *LoggedErrgroup { + ctx, cancel := context.WithCancel(ctx) + e, doneCtx := errgroup.WithContext(ctx) + + return &LoggedErrgroup{ + errgroup: e, + cancel: cancel, + doneCtx: doneCtx, + slogger: slogger, + } +} + +// StartGoroutine starts the given goroutine in the errgroup, ensuring that we log its start and exit. +func (l *LoggedErrgroup) StartGoroutine(ctx context.Context, goroutineName string, goroutine func() error) { + l.errgroup.Go(func() (err error) { + slogger := l.slogger.With("goroutine_name", goroutineName) + + // Catch any panicking goroutines and log them. We also want to make sure + // we return an error from this goroutine overall if it panics. + defer func() { + if r := recover(); r != nil { + slogger.Log(ctx, slog.LevelError, + "panic occurred in goroutine", + "err", r, + ) + if recoveredErr, ok := r.(error); ok { + slogger.Log(ctx, slog.LevelError, + "panic stack trace", + "stack_trace", fmt.Sprintf("%+v", errors.WithStack(recoveredErr)), + ) + err = recoveredErr + } + } + }() + + slogger.Log(ctx, slog.LevelInfo, + "starting goroutine in errgroup", + ) + + err = goroutine() + + slogger.Log(ctx, slog.LevelInfo, + "exiting goroutine in errgroup", + "goroutine_err", err, + ) + + return err + }) +} + +// StartRepeatedGoroutine starts the given goroutine in the errgroup, ensuring that we log its start and exit. +// If the delay is non-zero, the goroutine will not start until after the delay interval has elapsed. The goroutine +// will run on the given interval, and will continue to run until it returns an error or the errgroup shuts down. +func (l *LoggedErrgroup) StartRepeatedGoroutine(ctx context.Context, goroutineName string, interval time.Duration, delay time.Duration, goroutine func() error) { + l.StartGoroutine(ctx, goroutineName, func() error { + slogger := l.slogger.With("goroutine_name", goroutineName) + + if delay != 0*time.Second { + select { + case <-time.After(delay): + slogger.Log(ctx, slog.LevelDebug, + "exiting delay before starting repeated goroutine", + ) + case <-l.doneCtx.Done(): + return nil + } + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + // Run goroutine immediately + if err := goroutine(); err != nil { + slogger.Log(ctx, slog.LevelInfo, + "exiting repeated goroutine in errgroup due to error", + "goroutine_err", err, + ) + return err + } + + // Wait for next interval or for errgroup shutdown + select { + case <-l.doneCtx.Done(): + slogger.Log(ctx, slog.LevelInfo, + "exiting repeated goroutine in errgroup due to shutdown", + ) + return nil + case <-ticker.C: + continue + } + } + }) +} + +// AddShutdownGoroutine adds the given goroutine to the errgroup, ensuring that we log its start and exit. +// The goroutine will not execute until the errgroup has received a signal to exit. +func (l *LoggedErrgroup) AddShutdownGoroutine(ctx context.Context, goroutineName string, goroutine func() error) { + l.errgroup.Go(func() error { + slogger := l.slogger.With("goroutine_name", goroutineName) + + // Catch any panicking goroutines and log them. We do not want to return + // the error from this routine, as we do for StartGoroutine and StartRepeatedGoroutine -- + // shutdown goroutines should not return an error besides the errgroup's initial error. + defer func() { + if r := recover(); r != nil { + slogger.Log(ctx, slog.LevelError, + "panic occurred in shutdown goroutine", + "err", r, + ) + if err, ok := r.(error); ok { + slogger.Log(ctx, slog.LevelError, + "panic stack trace", + "stack_trace", fmt.Sprintf("%+v", errors.WithStack(err)), + ) + } + } + }() + + // Wait for errgroup to exit + <-l.doneCtx.Done() + + slogger.Log(ctx, slog.LevelInfo, + "starting shutdown goroutine in errgroup", + ) + + goroutineStart := time.Now() + err := goroutine() + elapsedTime := time.Since(goroutineStart) + + logLevel := slog.LevelInfo + if elapsedTime > maxShutdownGoroutineDuration || err != nil { + logLevel = slog.LevelWarn + } + slogger.Log(ctx, logLevel, + "exiting shutdown goroutine in errgroup", + "goroutine_name", goroutineName, + "goroutine_run_time", elapsedTime.String(), + "goroutine_err", err, + ) + + // We don't want to actually return the error here, to avoid causing an otherwise successful call + // to `Shutdown` => `Wait` to return an error. Shutdown routine errors don't matter for the success + // of the errgroup overall. + return l.doneCtx.Err() + }) +} + +func (l *LoggedErrgroup) Shutdown() { + l.cancel() +} + +func (l *LoggedErrgroup) Wait() error { + return l.errgroup.Wait() +} + +func (l *LoggedErrgroup) Exited() <-chan struct{} { + return l.doneCtx.Done() +} diff --git a/ee/errgroup/errgroup_test.go b/ee/errgroup/errgroup_test.go new file mode 100644 index 000000000..105d60487 --- /dev/null +++ b/ee/errgroup/errgroup_test.go @@ -0,0 +1,188 @@ +package errgroup + +import ( + "context" + "errors" + "fmt" + "strconv" + "testing" + "time" + + "github.com/kolide/launcher/pkg/log/multislogger" + "github.com/stretchr/testify/require" +) + +func TestWait(t *testing.T) { + t.Parallel() + + err1 := errors.New("errgroup_test: 1") + err2 := errors.New("errgroup_test: 2") + + for _, tt := range []struct { + testCaseName string + errs []error + expectedErr error + }{ + { + testCaseName: "no error on exit", + errs: []error{nil}, + expectedErr: nil, + }, + { + testCaseName: "only first routine has error on exit", + errs: []error{err1, nil}, + expectedErr: err1, + }, + { + testCaseName: "only second routine has error on exit", + errs: []error{nil, err2}, + expectedErr: err2, + }, + { + testCaseName: "multiple routines have error on exit", + errs: []error{err1, nil, err2}, + expectedErr: err1, + }, + } { + tt := tt + t.Run(tt.testCaseName, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger()) + + for i, err := range tt.errs { + err := err + eg.StartGoroutine(ctx, strconv.Itoa(i), func() error { return err }) + time.Sleep(500 * time.Millisecond) // try to enforce ordering of goroutines + } + + // We should get the expected error when we wait for the routines to exit + require.Equal(t, tt.expectedErr, eg.Wait(), "incorrect error returned by errgroup") + + // We expect that the errgroup shuts down + canceled := false + select { + case <-eg.Exited(): + canceled = true + default: + } + + require.True(t, canceled, "errgroup did not exit") + }) + } +} + +func TestShutdown(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger()) + + eg.StartGoroutine(ctx, "test_goroutine", func() error { + return nil + }) + + // We should get the expected error when we wait for the routines to exit + eg.Shutdown() + require.Nil(t, eg.Wait(), "should not have returned error on shutdown") + + // We expect that the errgroup shuts down + canceled := false + select { + case <-eg.Exited(): + canceled = true + default: + } + + require.True(t, canceled, "errgroup did not exit") +} + +func TestStartGoroutine_HandlesPanic(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger()) + + eg.StartGoroutine(ctx, "test_goroutine", func() error { + testArr := make([]int, 0) + fmt.Println(testArr[2]) // cause out-of-bounds panic + return nil + }) + + // We expect that the errgroup shuts itself down -- the test should not panic + require.Error(t, eg.Wait(), "should have returned error from panicking goroutine") + canceled := false + select { + case <-eg.Exited(): + canceled = true + default: + } + + require.True(t, canceled, "errgroup did not exit") +} + +func TestStartRepeatedGoroutine_HandlesPanic(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger()) + + eg.StartRepeatedGoroutine(ctx, "test_goroutine", 100*time.Millisecond, 50*time.Millisecond, func() error { + testArr := make([]int, 0) + fmt.Println(testArr[2]) // cause out-of-bounds panic + return nil + }) + + // Wait for long enough that the repeated goroutine executes at least once + time.Sleep(500 * time.Millisecond) + + // We expect that the errgroup shuts itself down -- the test should not panic + require.Error(t, eg.Wait(), "should have returned error from panicking goroutine") + canceled := false + select { + case <-eg.Exited(): + canceled = true + default: + } + + require.True(t, canceled, "errgroup did not exit") +} + +func TestAddShutdownGoroutine_HandlesPanic(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger()) + + eg.AddShutdownGoroutine(ctx, "test_goroutine", func() error { + testArr := make([]int, 0) + fmt.Println(testArr[2]) // cause out-of-bounds panic + return nil + }) + + // Call shutdown so the shutdown goroutine runs and the errgroup returns. + eg.Shutdown() + + // We expect that the errgroup shuts itself down -- the test should not panic. + // Since we called `Shutdown`, `Wait` should not return an error. + require.Nil(t, eg.Wait(), "should not returned error after call to Shutdown") + canceled := false + select { + case <-eg.Exited(): + canceled = true + default: + } + + require.True(t, canceled, "errgroup did not exit") +} diff --git a/pkg/osquery/runtime/osqueryinstance.go b/pkg/osquery/runtime/osqueryinstance.go index 7680bd098..8b1a7e416 100644 --- a/pkg/osquery/runtime/osqueryinstance.go +++ b/pkg/osquery/runtime/osqueryinstance.go @@ -17,6 +17,7 @@ import ( "github.com/kolide/kit/ulid" "github.com/kolide/launcher/ee/agent/types" + "github.com/kolide/launcher/ee/errgroup" "github.com/kolide/launcher/ee/gowrapper" kolidelog "github.com/kolide/launcher/ee/log/osquerylogs" "github.com/kolide/launcher/pkg/backoff" @@ -31,8 +32,6 @@ import ( osquerylogger "github.com/osquery/osquery-go/plugin/logger" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - - "golang.org/x/sync/errgroup" ) const ( @@ -91,10 +90,8 @@ type OsqueryInstance struct { // the following are instance artifacts that are created and held as a result // of launching an osqueryd process runId string // string identifier for this instance - errgroup *errgroup.Group + errgroup *errgroup.LoggedErrgroup saasExtension *launcherosq.Extension - doneCtx context.Context // nolint:containedctx - cancel context.CancelFunc cmd *exec.Cmd emsLock sync.RWMutex // Lock for extensionManagerServers extensionManagerServers []*osquery.ExtensionManagerServer @@ -184,9 +181,7 @@ func newInstance(registrationId string, knapsack types.Knapsack, serviceClient s opt(i) } - ctx, cancel := context.WithCancel(context.Background()) - i.cancel = cancel - i.errgroup, i.doneCtx = errgroup.WithContext(ctx) + i.errgroup = errgroup.NewLoggedErrgroup(context.Background(), i.slogger) i.startFunc = func(cmd *exec.Cmd) error { return cmd.Start() @@ -200,7 +195,7 @@ func (i *OsqueryInstance) BeginShutdown() { i.slogger.Log(context.TODO(), slog.LevelInfo, "instance shutdown requested", ) - i.cancel() + i.errgroup.Shutdown() } // WaitShutdown waits for the instance's errgroup routines to exit, then returns the @@ -226,7 +221,7 @@ func (i *OsqueryInstance) WaitShutdown() error { // Exited returns a channel to monitor for signal that instance has shut itself down func (i *OsqueryInstance) Exited() <-chan struct{} { - return i.doneCtx.Done() + return i.errgroup.Exited() } // Launch starts the osquery instance and its components. It will run until one of its @@ -355,7 +350,7 @@ func (i *OsqueryInstance) Launch() error { // This loop runs in the background when the process was // successfully started. ("successful" is independent of exit // code. eg: this runs if we could exec. Failure to exec is above.) - i.addGoroutineToErrgroup(ctx, "monitor_osquery_process", func() error { + i.errgroup.StartGoroutine(ctx, "monitor_osquery_process", func() error { err := i.cmd.Wait() switch { case err == nil, isExitOk(err): @@ -378,23 +373,24 @@ func (i *OsqueryInstance) Launch() error { }) // Kill osquery process on shutdown - i.addShutdownGoroutineToErrgroup(ctx, "kill_osquery_process", func() error { - if i.cmd.Process != nil { - // kill osqueryd and children - if err := killProcessGroup(i.cmd); err != nil { - if strings.Contains(err.Error(), "process already finished") || strings.Contains(err.Error(), "no such process") { - i.slogger.Log(ctx, slog.LevelDebug, - "tried to stop osquery, but process already gone", - ) - } else { - i.slogger.Log(ctx, slog.LevelWarn, - "error killing osquery process", - "err", err, - ) - } + i.errgroup.AddShutdownGoroutine(ctx, "kill_osquery_process", func() error { + if i.cmd.Process == nil { + return nil + } + + // kill osqueryd and children + if err := killProcessGroup(i.cmd); err != nil { + if strings.Contains(err.Error(), "process already finished") || strings.Contains(err.Error(), "no such process") { + i.slogger.Log(ctx, slog.LevelDebug, + "tried to stop osquery, but process already gone", + ) + return nil } + + return fmt.Errorf("killing osquery process: %w", err) } - return i.doneCtx.Err() + + return nil }) // Start an extension manager for the extensions that osquery @@ -434,43 +430,22 @@ func (i *OsqueryInstance) Launch() error { } // Health check on interval - i.addGoroutineToErrgroup(ctx, "healthcheck", func() error { - if i.knapsack.OsqueryHealthcheckStartupDelay() != 0*time.Second { - i.slogger.Log(ctx, slog.LevelDebug, - "entering delay before starting osquery healthchecks", - ) - select { - case <-time.After(i.knapsack.OsqueryHealthcheckStartupDelay()): - i.slogger.Log(ctx, slog.LevelDebug, - "exiting delay before starting osquery healthchecks", - ) - case <-i.doneCtx.Done(): - return i.doneCtx.Err() - } + i.errgroup.StartRepeatedGoroutine(ctx, "healthcheck", healthCheckInterval, i.knapsack.OsqueryHealthcheckStartupDelay(), func() error { + // If device is sleeping, we do not want to perform unnecessary healthchecks that + // may force an unnecessary restart. + if i.knapsack != nil && i.knapsack.InModernStandby() { + return nil } - ticker := time.NewTicker(healthCheckInterval) - defer ticker.Stop() - for { - select { - case <-i.doneCtx.Done(): - return i.doneCtx.Err() - case <-ticker.C: - // If device is sleeping, we do not want to perform unnecessary healthchecks that - // may force an unnecessary restart. - if i.knapsack != nil && i.knapsack.InModernStandby() { - break - } - - if err := i.healthcheckWithRetries(ctx, 5, 1*time.Second); err != nil { - return fmt.Errorf("health check failed: %w", err) - } - } + if err := i.healthcheckWithRetries(ctx, 5, 1*time.Second); err != nil { + return fmt.Errorf("health check failed: %w", err) } + + return nil }) // Clean up PID file on shutdown - i.addShutdownGoroutineToErrgroup(ctx, "remove_pid_file", func() error { + i.errgroup.AddShutdownGoroutine(ctx, "remove_pid_file", func() error { // We do a couple retries -- on Windows, the PID file may still be in use // and therefore unable to be removed. if err := backoff.WaitFor(func() error { @@ -479,17 +454,13 @@ func (i *OsqueryInstance) Launch() error { } return nil }, 5*time.Second, 500*time.Millisecond); err != nil { - i.slogger.Log(ctx, slog.LevelInfo, - "could not remove PID file, despite retries", - "pid_file", paths.pidfilePath, - "err", err, - ) + return fmt.Errorf("removing PID file %s failed with retries: %w", paths.pidfilePath, err) } - return i.doneCtx.Err() + return nil }) // Clean up socket file on shutdown - i.addShutdownGoroutineToErrgroup(ctx, "remove_socket_file", func() error { + i.errgroup.AddShutdownGoroutine(ctx, "remove_socket_file", func() error { // We do a couple retries -- on Windows, the socket file may still be in use // and therefore unable to be removed. if err := backoff.WaitFor(func() error { @@ -498,13 +469,9 @@ func (i *OsqueryInstance) Launch() error { } return nil }, 5*time.Second, 500*time.Millisecond); err != nil { - i.slogger.Log(ctx, slog.LevelInfo, - "could not remove socket file, despite retries", - "socket_file", paths.extensionSocketPath, - "err", err, - ) + return fmt.Errorf("removing socket file %s failed with retries: %w", paths.extensionSocketPath, err) } - return i.doneCtx.Err() + return nil }) return nil @@ -599,7 +566,7 @@ func (i *OsqueryInstance) startKolideSaasExtension(ctx context.Context) error { }) // Run extension - i.addGoroutineToErrgroup(ctx, "saas_extension_execute", func() error { + i.errgroup.StartGoroutine(ctx, "saas_extension_execute", func() error { if err := i.saasExtension.Execute(); err != nil { return fmt.Errorf("kolide_grpc extension returned error: %w", err) } @@ -607,52 +574,14 @@ func (i *OsqueryInstance) startKolideSaasExtension(ctx context.Context) error { }) // Register shutdown group for extension - i.addShutdownGoroutineToErrgroup(ctx, "saas_extension_cleanup", func() error { - i.saasExtension.Shutdown(i.doneCtx.Err()) - return i.doneCtx.Err() + i.errgroup.AddShutdownGoroutine(ctx, "saas_extension_cleanup", func() error { + i.saasExtension.Shutdown(nil) + return nil }) return nil } -// addGoroutineToErrgroup adds the given goroutine to the errgroup, ensuring that we log its start and exit. -func (i *OsqueryInstance) addGoroutineToErrgroup(ctx context.Context, goroutineName string, goroutine func() error) { - i.errgroup.Go(func() error { - defer i.slogger.Log(ctx, slog.LevelInfo, - "exiting goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - i.slogger.Log(ctx, slog.LevelInfo, - "starting goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - return goroutine() - }) -} - -// addShutdownGoroutineToErrgroup adds the given goroutine to the errgroup, ensuring that we log its start and exit. -// The goroutine will not execute until the instance has received a signal to exit. -func (i *OsqueryInstance) addShutdownGoroutineToErrgroup(ctx context.Context, goroutineName string, goroutine func() error) { - i.errgroup.Go(func() error { - defer i.slogger.Log(ctx, slog.LevelInfo, - "exiting shutdown goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - // Wait for errgroup to exit - <-i.doneCtx.Done() - - i.slogger.Log(ctx, slog.LevelInfo, - "starting shutdown goroutine in errgroup", - "goroutine_name", goroutineName, - ) - - return goroutine() - }) -} - // osqueryFilePaths is a struct which contains the relevant file paths needed to // launch an osqueryd instance. type osqueryFilePaths struct { @@ -857,7 +786,7 @@ func (i *OsqueryInstance) StartOsqueryExtensionManagerServer(name string, socket i.extensionManagerServers = append(i.extensionManagerServers, extensionManagerServer) // Start! - i.addGoroutineToErrgroup(context.TODO(), name, func() error { + i.errgroup.StartGoroutine(context.TODO(), name, func() error { if err := extensionManagerServer.Start(); err != nil { i.slogger.Log(context.TODO(), slog.LevelInfo, "extension manager server startup got error", @@ -871,15 +800,16 @@ func (i *OsqueryInstance) StartOsqueryExtensionManagerServer(name string, socket }) // register a shutdown routine - i.addShutdownGoroutineToErrgroup(context.TODO(), fmt.Sprintf("%s_cleanup", name), func() error { + i.errgroup.AddShutdownGoroutine(context.TODO(), fmt.Sprintf("%s_cleanup", name), func() error { if err := extensionManagerServer.Shutdown(context.TODO()); err != nil { + // Log error, but no need to bubble it up further i.slogger.Log(context.TODO(), slog.LevelInfo, "got error while shutting down extension server", "err", err, "extension_name", name, ) } - return i.doneCtx.Err() + return nil }) return nil diff --git a/pkg/osquery/runtime/runtime_posix_test.go b/pkg/osquery/runtime/runtime_posix_test.go index 1f64f50cb..24a3e3031 100644 --- a/pkg/osquery/runtime/runtime_posix_test.go +++ b/pkg/osquery/runtime/runtime_posix_test.go @@ -55,6 +55,7 @@ func TestOsquerySlowStart(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) runner := New(k, mockServiceClient(), WithStartFunc(func(cmd *exec.Cmd) error { @@ -102,6 +103,7 @@ func TestExtensionSocketPath(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) extensionSocketPath := filepath.Join(rootDirectory, "sock") diff --git a/pkg/osquery/runtime/runtime_test.go b/pkg/osquery/runtime/runtime_test.go index 3ed3555fd..b3fd2418c 100644 --- a/pkg/osquery/runtime/runtime_test.go +++ b/pkg/osquery/runtime/runtime_test.go @@ -178,6 +178,7 @@ func TestWithOsqueryFlags(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) runner := New(k, mockServiceClient()) @@ -212,6 +213,7 @@ func TestFlagsChanged(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) // Start the runner @@ -343,6 +345,7 @@ func TestSimplePath(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) runner := New(k, mockServiceClient()) @@ -379,6 +382,7 @@ func TestMultipleInstances(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) serviceClient := mockServiceClient() @@ -438,6 +442,7 @@ func TestRunnerHandlesImmediateShutdownWithMultipleInstances(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) serviceClient := mockServiceClient() @@ -489,6 +494,7 @@ func TestMultipleShutdowns(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) runner := New(k, mockServiceClient()) @@ -521,6 +527,7 @@ func TestOsqueryDies(t *testing.T) { k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) runner := New(k, mockServiceClient()) @@ -656,6 +663,7 @@ func setupOsqueryInstanceForTests(t *testing.T) (runner *Runner, logBytes *threa k.On("LogMaxBytesPerBatch").Return(0).Maybe() k.On("Transport").Return("jsonrpc").Maybe() k.On("ReadEnrollSecret").Return("", nil).Maybe() + k.On("InModernStandby").Return(false).Maybe() setUpMockStores(t, k) runner = New(k, mockServiceClient())