Skip to content

Commit

Permalink
Improvements for osqueryinstance's errgroup (kolide#2017)
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany authored Jan 8, 2025
1 parent 2605c9c commit 8c898ba
Show file tree
Hide file tree
Showing 5 changed files with 425 additions and 115 deletions.
182 changes: 182 additions & 0 deletions ee/errgroup/errgroup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package errgroup

import (
"context"
"fmt"
"log/slog"
"time"

"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
)

type LoggedErrgroup struct {
errgroup *errgroup.Group
cancel context.CancelFunc
doneCtx context.Context // nolint:containedctx
slogger *slog.Logger
}

const (
maxShutdownGoroutineDuration = 3 * time.Second
)

func NewLoggedErrgroup(ctx context.Context, slogger *slog.Logger) *LoggedErrgroup {
ctx, cancel := context.WithCancel(ctx)
e, doneCtx := errgroup.WithContext(ctx)

return &LoggedErrgroup{
errgroup: e,
cancel: cancel,
doneCtx: doneCtx,
slogger: slogger,
}
}

// StartGoroutine starts the given goroutine in the errgroup, ensuring that we log its start and exit.
func (l *LoggedErrgroup) StartGoroutine(ctx context.Context, goroutineName string, goroutine func() error) {
l.errgroup.Go(func() (err error) {
slogger := l.slogger.With("goroutine_name", goroutineName)

// Catch any panicking goroutines and log them. We also want to make sure
// we return an error from this goroutine overall if it panics.
defer func() {
if r := recover(); r != nil {
slogger.Log(ctx, slog.LevelError,
"panic occurred in goroutine",
"err", r,
)
if recoveredErr, ok := r.(error); ok {
slogger.Log(ctx, slog.LevelError,
"panic stack trace",
"stack_trace", fmt.Sprintf("%+v", errors.WithStack(recoveredErr)),
)
err = recoveredErr
}
}
}()

slogger.Log(ctx, slog.LevelInfo,
"starting goroutine in errgroup",
)

err = goroutine()

slogger.Log(ctx, slog.LevelInfo,
"exiting goroutine in errgroup",
"goroutine_err", err,
)

return err
})
}

// StartRepeatedGoroutine starts the given goroutine in the errgroup, ensuring that we log its start and exit.
// If the delay is non-zero, the goroutine will not start until after the delay interval has elapsed. The goroutine
// will run on the given interval, and will continue to run until it returns an error or the errgroup shuts down.
func (l *LoggedErrgroup) StartRepeatedGoroutine(ctx context.Context, goroutineName string, interval time.Duration, delay time.Duration, goroutine func() error) {
l.StartGoroutine(ctx, goroutineName, func() error {
slogger := l.slogger.With("goroutine_name", goroutineName)

if delay != 0*time.Second {
select {
case <-time.After(delay):
slogger.Log(ctx, slog.LevelDebug,
"exiting delay before starting repeated goroutine",
)
case <-l.doneCtx.Done():
return nil
}
}

ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
// Run goroutine immediately
if err := goroutine(); err != nil {
slogger.Log(ctx, slog.LevelInfo,
"exiting repeated goroutine in errgroup due to error",
"goroutine_err", err,
)
return err
}

// Wait for next interval or for errgroup shutdown
select {
case <-l.doneCtx.Done():
slogger.Log(ctx, slog.LevelInfo,
"exiting repeated goroutine in errgroup due to shutdown",
)
return nil
case <-ticker.C:
continue
}
}
})
}

// AddShutdownGoroutine adds the given goroutine to the errgroup, ensuring that we log its start and exit.
// The goroutine will not execute until the errgroup has received a signal to exit.
func (l *LoggedErrgroup) AddShutdownGoroutine(ctx context.Context, goroutineName string, goroutine func() error) {
l.errgroup.Go(func() error {
slogger := l.slogger.With("goroutine_name", goroutineName)

// Catch any panicking goroutines and log them. We do not want to return
// the error from this routine, as we do for StartGoroutine and StartRepeatedGoroutine --
// shutdown goroutines should not return an error besides the errgroup's initial error.
defer func() {
if r := recover(); r != nil {
slogger.Log(ctx, slog.LevelError,
"panic occurred in shutdown goroutine",
"err", r,
)
if err, ok := r.(error); ok {
slogger.Log(ctx, slog.LevelError,
"panic stack trace",
"stack_trace", fmt.Sprintf("%+v", errors.WithStack(err)),
)
}
}
}()

// Wait for errgroup to exit
<-l.doneCtx.Done()

slogger.Log(ctx, slog.LevelInfo,
"starting shutdown goroutine in errgroup",
)

goroutineStart := time.Now()
err := goroutine()
elapsedTime := time.Since(goroutineStart)

logLevel := slog.LevelInfo
if elapsedTime > maxShutdownGoroutineDuration || err != nil {
logLevel = slog.LevelWarn
}
slogger.Log(ctx, logLevel,
"exiting shutdown goroutine in errgroup",
"goroutine_name", goroutineName,
"goroutine_run_time", elapsedTime.String(),
"goroutine_err", err,
)

// We don't want to actually return the error here, to avoid causing an otherwise successful call
// to `Shutdown` => `Wait` to return an error. Shutdown routine errors don't matter for the success
// of the errgroup overall.
return l.doneCtx.Err()
})
}

func (l *LoggedErrgroup) Shutdown() {
l.cancel()
}

func (l *LoggedErrgroup) Wait() error {
return l.errgroup.Wait()
}

func (l *LoggedErrgroup) Exited() <-chan struct{} {
return l.doneCtx.Done()
}
188 changes: 188 additions & 0 deletions ee/errgroup/errgroup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package errgroup

import (
"context"
"errors"
"fmt"
"strconv"
"testing"
"time"

"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/stretchr/testify/require"
)

func TestWait(t *testing.T) {
t.Parallel()

err1 := errors.New("errgroup_test: 1")
err2 := errors.New("errgroup_test: 2")

for _, tt := range []struct {
testCaseName string
errs []error
expectedErr error
}{
{
testCaseName: "no error on exit",
errs: []error{nil},
expectedErr: nil,
},
{
testCaseName: "only first routine has error on exit",
errs: []error{err1, nil},
expectedErr: err1,
},
{
testCaseName: "only second routine has error on exit",
errs: []error{nil, err2},
expectedErr: err2,
},
{
testCaseName: "multiple routines have error on exit",
errs: []error{err1, nil, err2},
expectedErr: err1,
},
} {
tt := tt
t.Run(tt.testCaseName, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger())

for i, err := range tt.errs {
err := err
eg.StartGoroutine(ctx, strconv.Itoa(i), func() error { return err })
time.Sleep(500 * time.Millisecond) // try to enforce ordering of goroutines
}

// We should get the expected error when we wait for the routines to exit
require.Equal(t, tt.expectedErr, eg.Wait(), "incorrect error returned by errgroup")

// We expect that the errgroup shuts down
canceled := false
select {
case <-eg.Exited():
canceled = true
default:
}

require.True(t, canceled, "errgroup did not exit")
})
}
}

func TestShutdown(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger())

eg.StartGoroutine(ctx, "test_goroutine", func() error {
return nil
})

// We should get the expected error when we wait for the routines to exit
eg.Shutdown()
require.Nil(t, eg.Wait(), "should not have returned error on shutdown")

// We expect that the errgroup shuts down
canceled := false
select {
case <-eg.Exited():
canceled = true
default:
}

require.True(t, canceled, "errgroup did not exit")
}

func TestStartGoroutine_HandlesPanic(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger())

eg.StartGoroutine(ctx, "test_goroutine", func() error {
testArr := make([]int, 0)
fmt.Println(testArr[2]) // cause out-of-bounds panic
return nil
})

// We expect that the errgroup shuts itself down -- the test should not panic
require.Error(t, eg.Wait(), "should have returned error from panicking goroutine")
canceled := false
select {
case <-eg.Exited():
canceled = true
default:
}

require.True(t, canceled, "errgroup did not exit")
}

func TestStartRepeatedGoroutine_HandlesPanic(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger())

eg.StartRepeatedGoroutine(ctx, "test_goroutine", 100*time.Millisecond, 50*time.Millisecond, func() error {
testArr := make([]int, 0)
fmt.Println(testArr[2]) // cause out-of-bounds panic
return nil
})

// Wait for long enough that the repeated goroutine executes at least once
time.Sleep(500 * time.Millisecond)

// We expect that the errgroup shuts itself down -- the test should not panic
require.Error(t, eg.Wait(), "should have returned error from panicking goroutine")
canceled := false
select {
case <-eg.Exited():
canceled = true
default:
}

require.True(t, canceled, "errgroup did not exit")
}

func TestAddShutdownGoroutine_HandlesPanic(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger())

eg.AddShutdownGoroutine(ctx, "test_goroutine", func() error {
testArr := make([]int, 0)
fmt.Println(testArr[2]) // cause out-of-bounds panic
return nil
})

// Call shutdown so the shutdown goroutine runs and the errgroup returns.
eg.Shutdown()

// We expect that the errgroup shuts itself down -- the test should not panic.
// Since we called `Shutdown`, `Wait` should not return an error.
require.Nil(t, eg.Wait(), "should not returned error after call to Shutdown")
canceled := false
select {
case <-eg.Exited():
canceled = true
default:
}

require.True(t, canceled, "errgroup did not exit")
}
Loading

0 comments on commit 8c898ba

Please sign in to comment.