diff --git a/executor.go b/executor.go
index ffe133b..789fc1d 100644
--- a/executor.go
+++ b/executor.go
@@ -54,13 +54,13 @@ func NewExecutorConfig(workerPoolSize, queueSize int) *ExecutorConfig {
 // Executor implements the [ExecutorService] interface.
 type Executor[T any] struct {
 	cancel context.CancelFunc
-	queue  chan job[T]
+	queue  chan executorJob[T]
 	status atomic.Uint32
 }
 
 var _ ExecutorService[any] = (*Executor[any])(nil)
 
-type job[T any] struct {
+type executorJob[T any] struct {
 	promise Promise[T]
 	task    func(context.Context) (T, error)
 }
@@ -70,23 +70,24 @@ func NewExecutor[T any](ctx context.Context, config *ExecutorConfig) *Executor[T
 	ctx, cancel := context.WithCancel(ctx)
 	executor := &Executor[T]{
 		cancel: cancel,
-		queue:  make(chan job[T], config.QueueSize),
+		queue:  make(chan executorJob[T], config.QueueSize),
 	}
+	// set the executor status to running explicitly
+	executor.status.Store(uint32(ExecutorStatusRunning))
+
 	// init the workers pool
 	go executor.startWorkers(ctx, config.WorkerPoolSize)
 
 	// set status to terminating when ctx is done
 	go executor.monitorCtx(ctx)
 
-	// set the executor status to running
-	executor.status.Store(uint32(ExecutorStatusRunning))
-
 	return executor
 }
 
 func (e *Executor[T]) monitorCtx(ctx context.Context) {
 	<-ctx.Done()
-	e.status.Store(uint32(ExecutorStatusTerminating))
+	_ = e.status.CompareAndSwap(uint32(ExecutorStatusRunning),
+		uint32(ExecutorStatusTerminating))
 }
 
 func (e *Executor[T]) startWorkers(ctx context.Context, poolSize int) {
@@ -130,7 +131,7 @@ func (e *Executor[T]) Submit(f func(context.Context) (T, error)) (Future[T], err
 	promise := NewPromise[T]()
 	if ExecutorStatus(e.status.Load()) == ExecutorStatusRunning {
 		select {
-		case e.queue <- job[T]{promise, f}:
+		case e.queue <- executorJob[T]{promise, f}:
 		default:
 			return nil, ErrExecutorQueueFull
 		}
diff --git a/executor_test.go b/executor_test.go
index 870cf7f..e27fed5 100644
--- a/executor_test.go
+++ b/executor_test.go
@@ -2,6 +2,7 @@ package async
 
 import (
 	"context"
+	"errors"
 	"runtime"
 	"testing"
 	"time"
@@ -44,7 +45,7 @@ func TestExecutor(t *testing.T) {
 	routines := runtime.NumGoroutine()
 
 	// shut down the executor
-	executor.Shutdown()
+	_ = executor.Shutdown()
 	time.Sleep(time.Millisecond)
 
 	// verify that submit fails after the executor was shut down
@@ -62,6 +63,26 @@ func TestExecutor(t *testing.T) {
 
 	assertFutureError(t, ErrExecutorShutdown, future5, future6)
 }
 
+func TestExecutor_context(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	executor := NewExecutor[int](ctx, NewExecutorConfig(2, 2))
+
+	job := func(_ context.Context) (int, error) {
+		return 0, errors.New("error")
+	}
+
+	future, err := executor.Submit(job)
+	assert.IsNil(t, err)
+
+	result, err := future.Join()
+	assert.Equal(t, result, 0)
+	assert.ErrorContains(t, err, "error")
+
+	cancel()
+	time.Sleep(5 * time.Millisecond)
+	assert.Equal(t, executor.Status(), ExecutorStatusShutdown)
+}
+
 func submitJob[T any](t *testing.T, executor ExecutorService[T],
 	f func(context.Context) (T, error)) Future[T] {
diff --git a/future_test.go b/future_test.go
index c1493e9..ae49866 100644
--- a/future_test.go
+++ b/future_test.go
@@ -19,6 +19,7 @@ func TestFuture(t *testing.T) {
 		time.Sleep(100 * time.Millisecond)
 		p.Success(true)
 	}()
+
 	res, err := p.Future().Join()
 
 	assert.Equal(t, true, res)
@@ -42,6 +43,7 @@ func TestFuture_Utils(t *testing.T) {
 		time.Sleep(300 * time.Millisecond)
 		p3.Failure(err3)
 	}()
+
 	arr := []Future[*int]{p1.Future(), p2.Future(), p3.Future()}
 	res := []any{res1, res2, err3}
 	futRes, _ := FutureSeq(arr).Join()
@@ -55,6 +57,7 @@ func TestFuture_FirstCompleted(t *testing.T) {
 		time.Sleep(100 * time.Millisecond)
 		p.Success(util.Ptr(true))
 	}()
+
 	timeout := FutureTimer[*bool](10 * time.Millisecond)
 	futRes, futErr := FutureFirstCompletedOf(p.Future(), timeout).Join()
 
@@ -68,6 +71,7 @@ func TestFuture_Transform(t *testing.T) {
 		time.Sleep(100 * time.Millisecond)
 		p1.Success(util.Ptr(1))
 	}()
+
 	future := p1.Future().Map(func(v *int) (*int, error) {
 		inc := *v + 1
 		return &inc, nil
 	})
@@ -96,6 +100,7 @@ func TestFuture_Recover(t *testing.T) {
 		time.Sleep(10 * time.Millisecond)
 		p2.Failure(errors.New("recover Future failure"))
 	}()
+
 	future := p1.Future().Map(func(_ int) (int, error) {
 		return 0, errors.New("map error")
 	}).FlatMap(func(_ int) (Future[int], error) {
@@ -116,17 +121,30 @@
 }
 
 func TestFuture_Failure(t *testing.T) {
-	p1 := NewPromise[*int]()
-	p2 := NewPromise[*int]()
+	p1 := NewPromise[int]()
+	p2 := NewPromise[int]()
+	p3 := NewPromise[int]()
+	err := errors.New("error")
 	go func() {
-		time.Sleep(10 * time.Millisecond)
-		p1.Failure(errors.New("Future error"))
-		time.Sleep(20 * time.Millisecond)
-		p2.Success(util.Ptr(2))
+		time.Sleep(5 * time.Millisecond)
+		p1.Failure(err)
+		time.Sleep(5 * time.Millisecond)
+		p2.Failure(err)
+		time.Sleep(5 * time.Millisecond)
+		p3.Success(2)
 	}()
 
-	res, _ := p1.Future().RecoverWith(p2.Future()).Join()
-	assert.Equal(t, 2, *res)
+	res, _ := p1.Future().
+		Map(func(_ int) (int, error) { return 0, err }).
+		FlatMap(func(_ int) (Future[int], error) { return p1.Future(), err }).
+		RecoverWith(p2.Future()).
+		RecoverWith(p3.Future()).
+		FlatMap(func(_ int) (Future[int], error) { return p2.Future(), err }).
+		RecoverWith(p3.Future()).
+		RecoverWith(p3.Future()).
+		Join()
+
+	assert.Equal(t, 2, res)
 }
 
 func TestFuture_Timeout(t *testing.T) {
@@ -135,6 +153,7 @@
 		time.Sleep(100 * time.Millisecond)
 		p.Success(true)
 	}()
+
 	future := p.Future()
 
 	ctx, cancel := context.WithTimeout(context.Background(),
@@ -150,7 +169,6 @@ func TestFuture_GoroutineLeak(t *testing.T) {
 	var wg sync.WaitGroup
-	fmt.Println(runtime.NumGoroutine())
 
 	numFuture := 100
 
@@ -176,6 +194,7 @@ func TestFuture_GoroutineLeak(t *testing.T) {
 	time.Sleep(10 * time.Millisecond)
 	numGoroutine := runtime.NumGoroutine()
 	fmt.Println(numGoroutine)
+
 	if numGoroutine > numFuture {
 		t.Fatalf("numGoroutine is %d", numGoroutine)
 	}