Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for osqueryinstance's errgroup #2017

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions ee/errgroup/errgroup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package errgroup

import (
"context"
"log/slog"
"time"

"golang.org/x/sync/errgroup"
)

type LoggedErrgroup struct {
errgroup *errgroup.Group
cancel context.CancelFunc
doneCtx context.Context // nolint:containedctx
slogger *slog.Logger
}

const (
maxShutdownGoroutineDuration = 3 * time.Second
)

func NewLoggedErrgroup(ctx context.Context, slogger *slog.Logger) *LoggedErrgroup {
ctx, cancel := context.WithCancel(ctx)
e, doneCtx := errgroup.WithContext(ctx)

return &LoggedErrgroup{
errgroup: e,
cancel: cancel,
doneCtx: doneCtx,
slogger: slogger,
}
}

// StartGoroutine starts the given goroutine in the errgroup, ensuring that we log its start and exit.
func (l *LoggedErrgroup) StartGoroutine(ctx context.Context, goroutineName string, goroutine func() error) {
l.errgroup.Go(func() error {
l.slogger.Log(ctx, slog.LevelInfo,
"starting goroutine in errgroup",
"goroutine_name", goroutineName,
)

err := goroutine()

l.slogger.Log(ctx, slog.LevelInfo,
"exiting goroutine in errgroup",
"goroutine_name", goroutineName,
"goroutine_err", err,
)

return err
})
}

// StartRepeatedGoroutine starts the given goroutine in the errgroup, ensuring that we log its start and exit.
// If the delay is non-zero, the goroutine will not start until after the delay interval has elapsed. The goroutine
// will run on the given interval, and will continue to run until it returns an error or the errgroup shuts down.
func (l *LoggedErrgroup) StartRepeatedGoroutine(ctx context.Context, goroutineName string, interval time.Duration, delay time.Duration, goroutine func() error) {
l.errgroup.Go(func() error {
l.slogger.Log(ctx, slog.LevelInfo,
"starting repeated goroutine in errgroup",
"goroutine_name", goroutineName,
"goroutine_interval", interval.String(),
"goroutine_start_delay", delay.String(),
)

if delay != 0*time.Second {
select {
case <-time.After(delay):
l.slogger.Log(ctx, slog.LevelDebug,
"exiting delay before starting repeated goroutine",
"goroutine_name", goroutineName,
)
case <-l.doneCtx.Done():
return nil
}
}

ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
select {
case <-l.doneCtx.Done():
l.slogger.Log(ctx, slog.LevelInfo,
"exiting repeated goroutine in errgroup due to shutdown",
"goroutine_name", goroutineName,
)
return nil
case <-ticker.C:
if err := goroutine(); err != nil {
l.slogger.Log(ctx, slog.LevelInfo,
"exiting repeated goroutine in errgroup due to error",
"goroutine_name", goroutineName,
"goroutine_err", err,
)
return err
}
}
}
})
}

// AddShutdownGoroutine adds the given goroutine to the errgroup, ensuring that we log its start and exit.
// The goroutine will not execute until the errgroup has received a signal to exit.
func (l *LoggedErrgroup) AddShutdownGoroutine(ctx context.Context, goroutineName string, goroutine func() error) {
l.errgroup.Go(func() error {
// Wait for errgroup to exit
<-l.doneCtx.Done()

l.slogger.Log(ctx, slog.LevelInfo,
"starting shutdown goroutine in errgroup",
"goroutine_name", goroutineName,
)

goroutineStart := time.Now()
err := goroutine()
elapsedTime := time.Since(goroutineStart)

logLevel := slog.LevelInfo
if elapsedTime > maxShutdownGoroutineDuration {
logLevel = slog.LevelWarn
}

l.slogger.Log(ctx, logLevel,
"exiting shutdown goroutine in errgroup",
"goroutine_name", goroutineName,
"goroutine_run_time", elapsedTime.String(),
"goroutine_err", err,
)

// We don't want to actually return the error here, to avoid causing an otherwise successful call
// to `Shutdown` => `Wait` to return an error. Shutdown routine errors don't matter for the success
// of the errgroup overall.
return l.doneCtx.Err()
})
}

func (l *LoggedErrgroup) Shutdown() {
l.cancel()
}

func (l *LoggedErrgroup) Wait() error {
return l.errgroup.Wait()
}

func (l *LoggedErrgroup) Exited() <-chan struct{} {
return l.doneCtx.Done()
}
102 changes: 102 additions & 0 deletions ee/errgroup/errgroup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package errgroup

import (
"context"
"errors"
"strconv"
"testing"
"time"

"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/stretchr/testify/require"
)

func TestWait(t *testing.T) {
t.Parallel()

err1 := errors.New("errgroup_test: 1")
err2 := errors.New("errgroup_test: 2")

for _, tt := range []struct {
testCaseName string
errs []error
expectedErr error
}{
{
testCaseName: "no error on exit",
errs: []error{nil},
expectedErr: nil,
},
{
testCaseName: "only first routine has error on exit",
errs: []error{err1, nil},
expectedErr: err1,
},
{
testCaseName: "only second routine has error on exit",
errs: []error{nil, err2},
expectedErr: err2,
},
{
testCaseName: "multiple routines have error on exit",
errs: []error{err1, nil, err2},
expectedErr: err1,
},
} {
tt := tt
t.Run(tt.testCaseName, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger())

for i, err := range tt.errs {
err := err
eg.StartGoroutine(ctx, strconv.Itoa(i), func() error { return err })
time.Sleep(500 * time.Millisecond) // try to enforce ordering of goroutines
}

// We should get the expected error when we wait for the routines to exit
require.Equal(t, tt.expectedErr, eg.Wait(), "incorrect error returned by errgroup")

// We expect that the errgroup shuts down
canceled := false
select {
case <-eg.Exited():
canceled = true
default:
}

require.True(t, canceled, "errgroup did not exit")
})
}
}

func TestShutdown(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

eg := NewLoggedErrgroup(ctx, multislogger.NewNopLogger())

eg.StartGoroutine(ctx, "test_goroutine", func() error {
return nil
})

// We should get the expected error when we wait for the routines to exit
eg.Shutdown()
require.Nil(t, eg.Wait(), "should not have returned error on shutdown")

// We expect that the errgroup shuts down
canceled := false
select {
case <-eg.Exited():
canceled = true
default:
}

require.True(t, canceled, "errgroup did not exit")
}
Loading
Loading