Skip to content

Commit

Permalink
test: improve code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
reugn committed Aug 26, 2024
1 parent 56d4c67 commit af60676
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
17 changes: 9 additions & 8 deletions executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ func NewExecutorConfig(workerPoolSize, queueSize int) *ExecutorConfig {
// Executor implements the [ExecutorService] interface.
type Executor[T any] struct {
cancel context.CancelFunc
queue chan job[T]
queue chan executorJob[T]
status atomic.Uint32
}

var _ ExecutorService[any] = (*Executor[any])(nil)

type job[T any] struct {
type executorJob[T any] struct {
promise Promise[T]
task func(context.Context) (T, error)
}
Expand All @@ -70,23 +70,24 @@ func NewExecutor[T any](ctx context.Context, config *ExecutorConfig) *Executor[T
ctx, cancel := context.WithCancel(ctx)
executor := &Executor[T]{
cancel: cancel,
queue: make(chan job[T], config.QueueSize),
queue: make(chan executorJob[T], config.QueueSize),
}
// set the executor status to running explicitly
executor.status.Store(uint32(ExecutorStatusRunning))

// init the workers pool
go executor.startWorkers(ctx, config.WorkerPoolSize)

// set status to terminating when ctx is done
go executor.monitorCtx(ctx)

// set the executor status to running
executor.status.Store(uint32(ExecutorStatusRunning))

return executor
}

func (e *Executor[T]) monitorCtx(ctx context.Context) {
<-ctx.Done()
e.status.Store(uint32(ExecutorStatusTerminating))
_ = e.status.CompareAndSwap(uint32(ExecutorStatusRunning),
uint32(ExecutorStatusTerminating))
}

func (e *Executor[T]) startWorkers(ctx context.Context, poolSize int) {
Expand Down Expand Up @@ -130,7 +131,7 @@ func (e *Executor[T]) Submit(f func(context.Context) (T, error)) (Future[T], err
promise := NewPromise[T]()
if ExecutorStatus(e.status.Load()) == ExecutorStatusRunning {
select {
case e.queue <- job[T]{promise, f}:
case e.queue <- executorJob[T]{promise, f}:
default:
return nil, ErrExecutorQueueFull
}
Expand Down
23 changes: 22 additions & 1 deletion executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package async

import (
"context"
"errors"
"runtime"
"testing"
"time"
Expand Down Expand Up @@ -44,7 +45,7 @@ func TestExecutor(t *testing.T) {
routines := runtime.NumGoroutine()

// shut down the executor
executor.Shutdown()
_ = executor.Shutdown()
time.Sleep(time.Millisecond)

// verify that submit fails after the executor was shut down
Expand All @@ -62,6 +63,26 @@ func TestExecutor(t *testing.T) {
assertFutureError(t, ErrExecutorShutdown, future5, future6)
}

func TestExecutor_context(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
executor := NewExecutor[int](ctx, NewExecutorConfig(2, 2))

job := func(_ context.Context) (int, error) {
return 0, errors.New("error")
}

future, err := executor.Submit(job)
assert.IsNil(t, err)

result, err := future.Join()
assert.Equal(t, result, 0)
assert.ErrorContains(t, err, "error")

cancel()
time.Sleep(5 * time.Millisecond)
assert.Equal(t, executor.Status(), ExecutorStatusShutdown)
}

func submitJob[T any](t *testing.T, executor ExecutorService[T],
f func(context.Context) (T, error)) Future[T] {
future, err := executor.Submit(f)
Expand Down
37 changes: 28 additions & 9 deletions future_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func TestFuture(t *testing.T) {
time.Sleep(100 * time.Millisecond)
p.Success(true)
}()

res, err := p.Future().Join()

assert.Equal(t, true, res)
Expand All @@ -42,6 +43,7 @@ func TestFuture_Utils(t *testing.T) {
time.Sleep(300 * time.Millisecond)
p3.Failure(err3)
}()

arr := []Future[*int]{p1.Future(), p2.Future(), p3.Future()}
res := []any{res1, res2, err3}
futRes, _ := FutureSeq(arr).Join()
Expand All @@ -55,6 +57,7 @@ func TestFuture_FirstCompleted(t *testing.T) {
time.Sleep(100 * time.Millisecond)
p.Success(util.Ptr(true))
}()

timeout := FutureTimer[*bool](10 * time.Millisecond)
futRes, futErr := FutureFirstCompletedOf(p.Future(), timeout).Join()

Expand All @@ -68,6 +71,7 @@ func TestFuture_Transform(t *testing.T) {
time.Sleep(100 * time.Millisecond)
p1.Success(util.Ptr(1))
}()

future := p1.Future().Map(func(v *int) (*int, error) {
inc := *v + 1
return &inc, nil
Expand Down Expand Up @@ -96,6 +100,7 @@ func TestFuture_Recover(t *testing.T) {
time.Sleep(10 * time.Millisecond)
p2.Failure(errors.New("recover Future failure"))
}()

future := p1.Future().Map(func(_ int) (int, error) {
return 0, errors.New("map error")
}).FlatMap(func(_ int) (Future[int], error) {
Expand All @@ -116,17 +121,30 @@ func TestFuture_Recover(t *testing.T) {
}

func TestFuture_Failure(t *testing.T) {
p1 := NewPromise[*int]()
p2 := NewPromise[*int]()
p1 := NewPromise[int]()
p2 := NewPromise[int]()
p3 := NewPromise[int]()
err := errors.New("error")
go func() {
time.Sleep(10 * time.Millisecond)
p1.Failure(errors.New("Future error"))
time.Sleep(20 * time.Millisecond)
p2.Success(util.Ptr(2))
time.Sleep(5 * time.Millisecond)
p1.Failure(err)
time.Sleep(5 * time.Millisecond)
p2.Failure(err)
time.Sleep(5 * time.Millisecond)
p3.Success(2)
}()
res, _ := p1.Future().RecoverWith(p2.Future()).Join()

assert.Equal(t, 2, *res)
res, _ := p1.Future().
Map(func(_ int) (int, error) { return 0, err }).
FlatMap(func(_ int) (Future[int], error) { return p1.Future(), err }).
RecoverWith(p2.Future()).
RecoverWith(p3.Future()).
FlatMap(func(_ int) (Future[int], error) { return p2.Future(), err }).
RecoverWith(p3.Future()).
RecoverWith(p3.Future()).
Join()

assert.Equal(t, 2, res)
}

func TestFuture_Timeout(t *testing.T) {
Expand All @@ -135,6 +153,7 @@ func TestFuture_Timeout(t *testing.T) {
time.Sleep(100 * time.Millisecond)
p.Success(true)
}()

future := p.Future()

ctx, cancel := context.WithTimeout(context.Background(),
Expand All @@ -150,7 +169,6 @@ func TestFuture_Timeout(t *testing.T) {

func TestFuture_GoroutineLeak(t *testing.T) {
var wg sync.WaitGroup

fmt.Println(runtime.NumGoroutine())

numFuture := 100
Expand All @@ -176,6 +194,7 @@ func TestFuture_GoroutineLeak(t *testing.T) {
time.Sleep(10 * time.Millisecond)
numGoroutine := runtime.NumGoroutine()
fmt.Println(numGoroutine)

if numGoroutine > numFuture {
t.Fatalf("numGoroutine is %d", numGoroutine)
}
Expand Down

0 comments on commit af60676

Please sign in to comment.