diff --git a/go.mod b/go.mod index 08dc81f..424a258 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/sadlil/workgroup go 1.23.1 + +require github.com/avast/retry-go v3.0.0+incompatible // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..af7b582 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0= +github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= diff --git a/workgroup.go b/workgroup.go index 5345db3..af2adb5 100644 --- a/workgroup.go +++ b/workgroup.go @@ -25,6 +25,8 @@ import ( "context" "errors" "sync" + + "github.com/avast/retry-go" ) // FailureMode defines how the workgroup handles errors encountered @@ -51,6 +53,14 @@ func WithLimit(n int) Option { } } +// WithRetry sets the retry policy for individual goroutines +// within the workgroup. +func WithRetry(opts ...retry.Option) Option { + return func(g *Group) { + g.retryOptions = append(g.retryOptions, opts...) + } +} + // A Group is a collection of goroutines working on subtasks that are part of // the same overall task. // @@ -59,25 +69,35 @@ func WithLimit(n int) Option { // - Does not cancel on error (uses `Collect` failure mode). // - Does not retry on error. type Group struct { - err error - wg sync.WaitGroup - cancel func() - sem chan struct{} - failureMode FailureMode - errOnce sync.Once - errLock sync.Mutex + cancel func() + + err error + errOnce sync.Once + errLock sync.Mutex + + wg sync.WaitGroup + sem chan struct{} + + failureMode FailureMode + retryOptions []retry.Option } // New creates a new workgroup with the specified failure mode and options. // It returns a context that is derived from `ctx`. // The derived context is canceled when the workgroup finishes // or is canceled explicitly. +// If no Retry is specified, the default behavior is no retries. func New(ctx context.Context, mode FailureMode, opts ...Option) (context.Context, *Group) { ctx, cancel := context.WithCancel(ctx) g := &Group{ cancel: cancel, failureMode: mode, + retryOptions: []retry.Option{ + retry.Attempts(1), + retry.LastErrorOnly(true), + retry.Context(ctx), + }, } for _, opt := range opts { opt(g) @@ -95,7 +115,7 @@ func (g *Group) Go(ctx context.Context, fn func() error) { go func() { defer g.done() - err := fn() + err := retry.Do(fn, g.retryOptions...) if err != nil { g.errLock.Lock() defer g.errLock.Unlock() diff --git a/workgroup_test.go b/workgroup_test.go index 8084007..8185f7a 100644 --- a/workgroup_test.go +++ b/workgroup_test.go @@ -7,6 +7,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/avast/retry-go" ) var ( @@ -194,6 +196,61 @@ func TestGroup_Cancel(t *testing.T) { } } +func TestGroup_WithRetry(t *testing.T) { + tests := []struct { + name string + retryPolicy []retry.Option + fn func() error + wantCount int32 + wantErr error + }{ + { + name: "retry_happy_path", + retryPolicy: []retry.Option{ + retry.Attempts(3), + }, + wantCount: 3, + fn: func() error { return fmt.Errorf("retry_happy_path: %w", errInternal) }, + wantErr: errInternal, + }, + { + name: "retry_happy_success_no_retry", + retryPolicy: []retry.Option{ + retry.Attempts(100), + }, + wantCount: 1, + fn: func() error { return nil }, + }, + { + name: "no_retry_policy", + retryPolicy: nil, // No retry policy specified + wantCount: 1, + fn: func() error { return fmt.Errorf("no_retry_policy: %w", errInternal) }, + wantErr: errInternal, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx, g := New(context.Background(), Collect, WithRetry(tc.retryPolicy...)) + + var retryCount int32 + g.Go(ctx, func() error { + atomic.AddInt32(&retryCount, 1) + return tc.fn() + }) + + err := g.Wait() + if !errors.Is(err, tc.wantErr) { + t.Fatalf("group.Wait() = %v, want tempError", err) + } + if retryCount != tc.wantCount { + t.Errorf("expected %v retries, but got %d", tc.wantCount, retryCount) + } + }) + } +} + func BenchmarkGo(b *testing.B) { ctx, g := New(context.Background(), Collect) b.ResetTimer()