diff --git a/once_test.go b/once_test.go index 39d518d..7b99713 100644 --- a/once_test.go +++ b/once_test.go @@ -23,7 +23,7 @@ func TestOnce(t *testing.T) { func TestOnceConcurrent(t *testing.T) { var once Once[int32] - var count int32 + var count atomic.Int32 var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -31,14 +31,14 @@ func TestOnceConcurrent(t *testing.T) { go func() { defer wg.Done() result, _ := once.Do(func() (int32, error) { - newCount := atomic.AddInt32(&count, 1) + newCount := count.Add(1) return newCount, nil }) - atomic.StoreInt32(&count, result) + count.Store(result) }() } wg.Wait() - assert.Equal(t, count, 1) + assert.Equal(t, int(count.Load()), 1) } func TestOncePanic(t *testing.T) { diff --git a/wait_group_context.go b/wait_group_context.go index 8880b3c..39a3743 100644 --- a/wait_group_context.go +++ b/wait_group_context.go @@ -13,8 +13,8 @@ import ( type WaitGroupContext struct { ctx context.Context done chan struct{} - counter int32 - state int32 + counter atomic.Int32 + state atomic.Int32 } // NewWaitGroupContext returns a new WaitGroupContext with Context ctx. @@ -29,10 +29,10 @@ func NewWaitGroupContext(ctx context.Context) *WaitGroupContext { // If the counter becomes zero, all goroutines blocked on Wait are released. // If the counter goes negative, Add panics. func (wgc *WaitGroupContext) Add(delta int) { - counter := atomic.AddInt32(&wgc.counter, int32(delta)) - if counter == 0 && atomic.CompareAndSwapInt32(&wgc.state, 0, 1) { + counter := wgc.counter.Add(int32(delta)) + if counter == 0 && wgc.state.CompareAndSwap(0, 1) { wgc.release() - } else if counter < 0 && atomic.LoadInt32(&wgc.state) == 0 { + } else if counter < 0 && wgc.state.Load() == 0 { panic("async: negative WaitGroupContext counter") } } diff --git a/wait_group_context_test.go b/wait_group_context_test.go index 2b3b69e..f8de1eb 100644 --- a/wait_group_context_test.go +++ b/wait_group_context_test.go @@ -2,6 +2,7 @@ package async import ( "context" + "sync/atomic" "testing" "time" @@ -9,37 +10,37 @@ import ( ) func TestWaitGroupContext(t *testing.T) { - result := 0 + var result atomic.Int32 wgc := NewWaitGroupContext(context.Background()) wgc.Add(2) go func() { defer wgc.Done() time.Sleep(time.Millisecond * 10) - result++ + result.Add(1) }() go func() { defer wgc.Done() time.Sleep(time.Millisecond * 20) - result += 2 + result.Add(2) }() go func() { wgc.Wait() - result += 3 + result.Add(3) }() wgc.Wait() time.Sleep(time.Millisecond * 10) - assert.Equal(t, result, 6) + assert.Equal(t, int(result.Load()), 6) } func TestWaitGroupContextCanceled(t *testing.T) { - result := 0 + var result atomic.Int32 ctx, cancelFunc := context.WithCancel(context.Background()) go func() { time.Sleep(time.Millisecond * 100) - result += 10 + result.Add(10) cancelFunc() }() wgc := NewWaitGroupContext(ctx) @@ -48,22 +49,22 @@ func TestWaitGroupContextCanceled(t *testing.T) { go func() { defer wgc.Done() time.Sleep(time.Millisecond * 10) - result++ + result.Add(1) }() go func() { defer wgc.Done() time.Sleep(time.Millisecond * 300) - result += 2 + result.Add(2) }() go func() { wgc.Wait() - result += 100 + result.Add(100) }() wgc.Wait() time.Sleep(time.Millisecond * 10) - assert.Equal(t, result, 111) + assert.Equal(t, int(result.Load()), 111) } func TestWaitGroupContextPanic(t *testing.T) {