From 9ffbb1f468c80fcb15bbcdd5ac3cf9dfe5d7e805 Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Fri, 6 Dec 2024 15:13:46 -0800 Subject: [PATCH] add ConcurrencyLimit to worker to enable dynamic tuning of concurrencies (#1410) What changed? [High Risk] replaced buffered channel with resizable semaphore to control task concurrency [Low Risk] added worker package for modularity added ConcurrencyLimit entity to worker removed unused methods of autoscaler interface Why? needed as first step to enable dynamic tuning of poller and task concurrencies How did you test it? Unit Test --- internal/common/autoscaler/autoscaler.go | 6 - internal/internal_poller_autoscaler.go | 26 +- internal/internal_poller_autoscaler_test.go | 12 +- internal/internal_worker_base.go | 40 +-- internal/worker/concurrency.go | 41 +++ internal/worker/resizable_permit.go | 120 +++++++++ internal/worker/resizable_permit_test.go | 283 ++++++++++++++++++++ 7 files changed, 482 insertions(+), 46 deletions(-) create mode 100644 internal/worker/concurrency.go create mode 100644 internal/worker/resizable_permit.go create mode 100644 internal/worker/resizable_permit_test.go diff --git a/internal/common/autoscaler/autoscaler.go b/internal/common/autoscaler/autoscaler.go index ecac9641c..c080917ce 100644 --- a/internal/common/autoscaler/autoscaler.go +++ b/internal/common/autoscaler/autoscaler.go @@ -24,12 +24,6 @@ package autoscaler type ( AutoScaler interface { Estimator - // Acquire X ResourceUnit of resource - Acquire(ResourceUnit) error - // Release X ResourceUnit of resource - Release(ResourceUnit) - // GetCurrent ResourceUnit of resource - GetCurrent() ResourceUnit // Start starts the autoscaler go routine that scales the ResourceUnit according to Estimator Start() // Stop stops the autoscaler if started or do nothing if not yet started diff --git a/internal/internal_poller_autoscaler.go b/internal/internal_poller_autoscaler.go index d88dd1f10..2dc81e7ba 100644 --- a/internal/internal_poller_autoscaler.go +++ b/internal/internal_poller_autoscaler.go @@ -26,11 +26,11 @@ import ( "sync" "time" - "github.com/marusama/semaphore/v2" "go.uber.org/atomic" "go.uber.org/zap" "go.uber.org/cadence/internal/common/autoscaler" + "go.uber.org/cadence/internal/worker" ) // defaultPollerScalerCooldownInSeconds @@ -53,7 +53,7 @@ type ( isDryRun bool cooldownTime time.Duration logger *zap.Logger - sem semaphore.Semaphore // resizable semaphore to control number of concurrent pollers + permit worker.Permit ctx context.Context cancel context.CancelFunc wg *sync.WaitGroup // graceful stop @@ -82,6 +82,7 @@ type ( func newPollerScaler( options pollerAutoScalerOptions, logger *zap.Logger, + permit worker.Permit, hooks ...func()) *pollerAutoScaler { if !options.Enabled { return nil @@ -91,7 +92,7 @@ func newPollerScaler( isDryRun: options.DryRun, cooldownTime: options.Cooldown, logger: logger, - sem: semaphore.New(options.InitCount), + permit: permit, wg: &sync.WaitGroup{}, ctx: ctx, cancel: cancel, @@ -107,21 +108,6 @@ func newPollerScaler( } } -// Acquire concurrent poll quota -func (p *pollerAutoScaler) Acquire(resource autoscaler.ResourceUnit) error { - return p.sem.Acquire(p.ctx, int(resource)) -} - -// Release concurrent poll quota -func (p *pollerAutoScaler) Release(resource autoscaler.ResourceUnit) { - p.sem.Release(int(resource)) -} - -// GetCurrent poll quota -func (p *pollerAutoScaler) GetCurrent() autoscaler.ResourceUnit { - return autoscaler.ResourceUnit(p.sem.GetLimit()) -} - // Start an auto-scaler go routine and returns a done to stop it func (p *pollerAutoScaler) Start() { logger := p.logger.Sugar() @@ -133,7 +119,7 @@ func (p *pollerAutoScaler) Start() { case <-p.ctx.Done(): return case <-time.After(p.cooldownTime): - currentResource := autoscaler.ResourceUnit(p.sem.GetLimit()) + currentResource := autoscaler.ResourceUnit(p.permit.Quota()) currentUsages, err := p.pollerUsageEstimator.Estimate() if err != nil { logger.Warnw("poller autoscaler skip due to estimator error", "error", err) @@ -146,7 +132,7 @@ func (p *pollerAutoScaler) Start() { "recommend", uint64(proposedResource), "isDryRun", p.isDryRun) if !p.isDryRun { - p.sem.SetLimit(int(proposedResource)) + p.permit.SetQuota(int(proposedResource)) } p.pollerUsageEstimator.Reset() diff --git a/internal/internal_poller_autoscaler_test.go b/internal/internal_poller_autoscaler_test.go index 68514602f..4a441b642 100644 --- a/internal/internal_poller_autoscaler_test.go +++ b/internal/internal_poller_autoscaler_test.go @@ -21,12 +21,14 @@ package internal import ( + "context" "math/rand" "sync" "testing" "time" "go.uber.org/cadence/internal/common/testlogger" + "go.uber.org/cadence/internal/worker" "github.com/stretchr/testify/assert" "go.uber.org/atomic" @@ -171,6 +173,7 @@ func Test_pollerAutoscaler(t *testing.T) { TargetUtilization: float64(tt.args.targetMilliUsage) / 1000, }, testlogger.NewZap(t), + worker.NewResizablePermit(tt.args.initialPollerCount), // hook function that collects number of iterations func() { autoscalerEpoch.Add(1) @@ -190,18 +193,19 @@ func Test_pollerAutoscaler(t *testing.T) { go func() { defer wg.Done() for pollResult := range pollChan { - pollerScaler.Acquire(1) + err := pollerScaler.permit.Acquire(context.Background()) + assert.NoError(t, err) pollerScaler.CollectUsage(pollResult) - pollerScaler.Release(1) + pollerScaler.permit.Release() } }() } assert.Eventually(t, func() bool { return autoscalerEpoch.Load() == uint64(tt.args.autoScalerEpoch) - }, tt.args.cooldownTime+20*time.Millisecond, 10*time.Millisecond) + }, tt.args.cooldownTime+100*time.Millisecond, 10*time.Millisecond) pollerScaler.Stop() - res := pollerScaler.GetCurrent() + res := pollerScaler.permit.Quota() - pollerScaler.permit.Count() assert.Equal(t, tt.want, int(res)) }) } diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index ba9da7818..b4bfb0ad6 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -33,6 +33,7 @@ import ( "time" "go.uber.org/cadence/internal/common/debug" + "go.uber.org/cadence/internal/worker" "github.com/uber-go/tally" "go.uber.org/zap" @@ -141,7 +142,7 @@ type ( logger *zap.Logger metricsScope tally.Scope - pollerRequestCh chan struct{} + concurrency *worker.ConcurrencyLimit pollerAutoScaler *pollerAutoScaler taskQueueCh chan interface{} sessionTokenBucket *sessionTokenBucket @@ -167,11 +168,17 @@ func createPollRetryPolicy() backoff.RetryPolicy { func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope tally.Scope, sessionTokenBucket *sessionTokenBucket) *baseWorker { ctx, cancel := context.WithCancel(context.Background()) + concurrency := &worker.ConcurrencyLimit{ + PollerPermit: worker.NewResizablePermit(options.pollerCount), + TaskPermit: worker.NewResizablePermit(options.maxConcurrentTask), + } + var pollerAS *pollerAutoScaler if pollerOptions := options.pollerAutoScaler; pollerOptions.Enabled { pollerAS = newPollerScaler( pollerOptions, logger, + concurrency.PollerPermit, ) } @@ -182,7 +189,7 @@ func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope t retrier: backoff.NewConcurrentRetrier(pollOperationRetryPolicy), logger: logger.With(zapcore.Field{Key: tagWorkerType, Type: zapcore.StringType, String: options.workerType}), metricsScope: tagScope(metricsScope, tagWorkerType, options.workerType), - pollerRequestCh: make(chan struct{}, options.maxConcurrentTask), + concurrency: concurrency, pollerAutoScaler: pollerAS, taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched. limiterContext: ctx, @@ -241,14 +248,19 @@ func (bw *baseWorker) runPoller() { bw.metricsScope.Counter(metrics.PollerStartCounter).Inc(1) for { + permitChannel, channelDone := bw.concurrency.TaskPermit.AcquireChan(bw.limiterContext) select { case <-bw.shutdownCh: + channelDone() return - case <-bw.pollerRequestCh: - bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(cap(bw.pollerRequestCh))) - // This metric is used to monitor how many poll requests have been allocated - // and can be used to approximate number of concurrent task running (not pinpoint accurate) - bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(cap(bw.pollerRequestCh) - len(bw.pollerRequestCh))) + case <-permitChannel: // don't poll unless there is a task permit + channelDone() + // TODO move to a centralized place inside the worker + // emit metrics on concurrent task permit quota and current task permit count + // NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process + // thus the metrics is only an estimated value of how many tasks are running concurrently + bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.concurrency.TaskPermit.Quota())) + bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.concurrency.TaskPermit.Count())) if bw.sessionTokenBucket != nil { bw.sessionTokenBucket.waitForAvailableToken() } @@ -260,10 +272,6 @@ func (bw *baseWorker) runPoller() { func (bw *baseWorker) runTaskDispatcher() { defer bw.shutdownWG.Done() - for i := 0; i < bw.options.maxConcurrentTask; i++ { - bw.pollerRequestCh <- struct{}{} - } - for { // wait for new task or shutdown select { @@ -294,10 +302,10 @@ func (bw *baseWorker) pollTask() { var task interface{} if bw.pollerAutoScaler != nil { - if pErr := bw.pollerAutoScaler.Acquire(1); pErr == nil { - defer bw.pollerAutoScaler.Release(1) + if pErr := bw.concurrency.PollerPermit.Acquire(bw.limiterContext); pErr == nil { + defer bw.concurrency.PollerPermit.Release() } else { - bw.logger.Warn("poller auto scaler acquire error", zap.Error(pErr)) + bw.logger.Warn("poller permit acquire error", zap.Error(pErr)) } } @@ -333,7 +341,7 @@ func (bw *baseWorker) pollTask() { case <-bw.shutdownCh: } } else { - bw.pollerRequestCh <- struct{}{} // poll failed, trigger a new poll + bw.concurrency.TaskPermit.Release() // poll failed, trigger a new poll by returning a task permit } } @@ -368,7 +376,7 @@ func (bw *baseWorker) processTask(task interface{}) { } if isPolledTask { - bw.pollerRequestCh <- struct{}{} + bw.concurrency.TaskPermit.Release() // task processed, trigger a new poll by returning a task permit } }() err := bw.options.taskWorker.ProcessTask(task) diff --git a/internal/worker/concurrency.go b/internal/worker/concurrency.go new file mode 100644 index 000000000..8d0771b91 --- /dev/null +++ b/internal/worker/concurrency.go @@ -0,0 +1,41 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package worker + +import "context" + +var _ Permit = (*resizablePermit)(nil) + +// ConcurrencyLimit contains synchronization primitives for dynamically controlling the concurrencies in workers +type ConcurrencyLimit struct { + PollerPermit Permit // controls concurrency of pollers + TaskPermit Permit // controls concurrency of task processing +} + +// Permit is an adaptive permit issuer to control concurrency +type Permit interface { + Acquire(context.Context) error + AcquireChan(context.Context) (channel <-chan struct{}, done func()) + Count() int + Quota() int + Release() + SetQuota(int) +} diff --git a/internal/worker/resizable_permit.go b/internal/worker/resizable_permit.go new file mode 100644 index 000000000..de785bfce --- /dev/null +++ b/internal/worker/resizable_permit.go @@ -0,0 +1,120 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package worker + +import ( + "context" + "fmt" + "sync" + + "github.com/marusama/semaphore/v2" +) + +type resizablePermit struct { + sem semaphore.Semaphore +} + +// NewResizablePermit creates a dynamic permit that's resizable +func NewResizablePermit(initCount int) Permit { + return &resizablePermit{sem: semaphore.New(initCount)} +} + +// Acquire is blocking until a permit is acquired or returns error after context is done +// Remember to call Release(count) to release the permit after usage +func (p *resizablePermit) Acquire(ctx context.Context) error { + if err := p.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire permit before context is done: %w", err) + } + return nil +} + +// Release release one permit +func (p *resizablePermit) Release() { + p.sem.Release(1) +} + +// Quota returns the maximum number of permits that can be acquired +func (p *resizablePermit) Quota() int { + return p.sem.GetLimit() +} + +// SetQuota sets the maximum number of permits that can be acquired +func (p *resizablePermit) SetQuota(c int) { + p.sem.SetLimit(c) +} + +// Count returns the number of permits available +func (p *resizablePermit) Count() int { + return p.sem.GetCount() +} + +// AcquireChan returns a channel that could be used to wait for the permit and a close function when done +// Notes: +// 1. avoid goroutine leak by calling the done function +// 2. if the read succeeded, release permit by calling permit.Release() +func (p *resizablePermit) AcquireChan(ctx context.Context) (<-chan struct{}, func()) { + ctx, cancel := context.WithCancel(ctx) + pc := &permitChannel{ + p: p, + c: make(chan struct{}), + ctx: ctx, + cancel: cancel, + wg: &sync.WaitGroup{}, + } + pc.Start() + return pc.C(), func() { + pc.Close() + } +} + +// permitChannel is an implementation to acquire a permit through channel +type permitChannel struct { + p Permit + c chan struct{} + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup +} + +func (ch *permitChannel) C() <-chan struct{} { + return ch.c +} + +func (ch *permitChannel) Start() { + ch.wg.Add(1) + go func() { + defer ch.wg.Done() + if err := ch.p.Acquire(ch.ctx); err != nil { + return + } + // avoid blocking on sending to the channel + select { + case ch.c <- struct{}{}: + case <-ch.ctx.Done(): // release if acquire is successful but fail to send to the channel + ch.p.Release() + } + }() +} + +func (ch *permitChannel) Close() { + ch.cancel() + ch.wg.Wait() +} diff --git a/internal/worker/resizable_permit_test.go b/internal/worker/resizable_permit_test.go new file mode 100644 index 000000000..141510bde --- /dev/null +++ b/internal/worker/resizable_permit_test.go @@ -0,0 +1,283 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package worker + +import ( + "context" + "sync" + "testing" + "time" + + "math/rand" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "go.uber.org/goleak" +) + +func TestPermit_Simulation(t *testing.T) { + tests := []struct { + name string + capacity []int // update every 50ms + goroutines int // each would block on acquiring 1 token for 100-150ms + maxTestDuration time.Duration + expectFailuresRange []int // range of failures, inclusive [min, max] + }{ + { + name: "enough permit, no blocking", + maxTestDuration: 200 * time.Millisecond, // at most need 150 ms, add 50 ms buffer + capacity: []int{10000}, + goroutines: 1000, + expectFailuresRange: []int{0, 0}, + }, + { + name: "not enough permit, blocking but all acquire", + maxTestDuration: 800 * time.Millisecond, // at most need 150ms * 1000 / 200 = 750ms to acquire all permit + capacity: []int{200}, + goroutines: 1000, + expectFailuresRange: []int{0, 0}, + }, + { + name: "not enough permit for some to acquire, fail some", + maxTestDuration: 250 * time.Millisecond, // at least need 100ms * 1000 / 200 = 500ms to acquire all permit + capacity: []int{200}, + goroutines: 1000, + expectFailuresRange: []int{400, 600}, // should at least pass some acquires + }, + { + name: "not enough permit at beginning but due to capacity change, blocking but all acquire", + maxTestDuration: 250 * time.Millisecond, + capacity: []int{200, 400, 600}, + goroutines: 1000, + expectFailuresRange: []int{0, 0}, + }, + { + name: "enough permit at beginning but due to capacity change, some would fail", + maxTestDuration: 250 * time.Millisecond, + capacity: []int{600, 400, 200}, + goroutines: 1000, + expectFailuresRange: []int{1, 500}, // the worst case with 200 capacity will at least pass 500 acquires + }, + { + name: "not enough permit for any acquire, fail all", + maxTestDuration: 300 * time.Millisecond, + capacity: []int{0}, + goroutines: 1000, + expectFailuresRange: []int{1000, 1000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer goleak.VerifyNone(t) + wg := &sync.WaitGroup{} + permit := NewResizablePermit(tt.capacity[0]) + wg.Add(1) + go func() { // update quota every 50ms + defer wg.Done() + for i := 1; i < len(tt.capacity); i++ { + time.Sleep(50 * time.Millisecond) + permit.SetQuota(tt.capacity[i]) + } + }() + failures := atomic.NewInt32(0) + ctx, cancel := context.WithTimeout(context.Background(), tt.maxTestDuration) + defer cancel() + + aquireChan := tt.goroutines / 2 + for i := 0; i < tt.goroutines-aquireChan; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := permit.Acquire(ctx); err != nil { + failures.Inc() + return + } + time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond) + permit.Release() + }() + } + for i := 0; i < aquireChan; i++ { + wg.Add(1) + go func() { + defer wg.Done() + permitChan, done := permit.AcquireChan(ctx) + select { + case <-permitChan: + time.Sleep(time.Duration(100+rand.Intn(50)) * time.Millisecond) + permit.Release() + case <-ctx.Done(): + failures.Inc() + } + done() + }() + } + + wg.Wait() + // sanity check + assert.Equal(t, 0, permit.Count(), "all permit should be released") + assert.Equal(t, tt.capacity[len(tt.capacity)-1], permit.Quota()) + + // expect failures in range + expectFailureMin := tt.expectFailuresRange[0] + expectFailureMax := tt.expectFailuresRange[1] + assert.GreaterOrEqual(t, int(failures.Load()), expectFailureMin) + assert.LessOrEqual(t, int(failures.Load()), expectFailureMax) + }) + } +} + +// Test_Permit_Acquire tests the basic acquire functionality +// before each acquire will wait 100ms +func Test_Permit_Acquire(t *testing.T) { + + t.Run("acquire 1 permit", func(t *testing.T) { + permit := NewResizablePermit(1) + err := permit.Acquire(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, permit.Count()) + }) + + t.Run("acquire timeout", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + time.Sleep(100 * time.Millisecond) + err := permit.Acquire(ctx) + assert.ErrorContains(t, err, "context deadline exceeded") + assert.Empty(t, permit.Count()) + }) + + t.Run("cancel acquire", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := permit.Acquire(ctx) + assert.ErrorContains(t, err, "canceled") + assert.Empty(t, permit.Count()) + }) + + t.Run("acquire more than quota", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + err := permit.Acquire(ctx) + assert.NoError(t, err) + err = permit.Acquire(ctx) + assert.ErrorContains(t, err, "failed to acquire permit") + assert.Equal(t, 1, permit.Count()) + }) +} + +func Test_Permit_Release(t *testing.T) { + for _, tt := range []struct { + name string + quota, acquire, release int + expectPanic bool + }{ + {"release all acquired permits", 10, 5, 5, false}, + {"release partial acquired permit", 10, 5, 1, false}, + {"release non acquired permit", 10, 5, 0, false}, + {"release more than acquired permit", 10, 5, 10, true}, + } { + t.Run(tt.name, func(t *testing.T) { + permit := NewResizablePermit(tt.quota) + for i := 0; i < tt.acquire; i++ { + err := permit.Acquire(context.Background()) + assert.NoError(t, err) + } + releaseOp := func() { + for i := 0; i < tt.release; i++ { + permit.Release() + } + } + + if tt.expectPanic { + assert.Panics(t, releaseOp) + } else { + assert.NotPanics(t, releaseOp) + assert.Equal(t, tt.acquire-tt.release, permit.Count()) + } + }) + } +} + +func Test_Permit_AcquireChan(t *testing.T) { + t.Run("acquire 1 permit", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + assert.Equal(t, 1, permit.Count()) + case <-ctx.Done(): + t.Errorf("permit not acquired") + } + }) + + t.Run("acquire timeout", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + time.Sleep(100 * time.Millisecond) + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + t.Errorf("permit acquired") + case <-ctx.Done(): + assert.Empty(t, permit.Count()) + } + }) + + t.Run("cancel acquire", func(t *testing.T) { + permit := NewResizablePermit(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + channel, done := permit.AcquireChan(ctx) + defer done() + select { + case <-channel: + t.Errorf("permit acquired") + case <-ctx.Done(): + assert.Empty(t, permit.Count()) + } + }) + + t.Run("acquire more than quota", func(t *testing.T) { + permit := NewResizablePermit(4) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + for i := 0; i < 10; i++ { + channel, done := permit.AcquireChan(ctx) + select { + case <-channel: + case <-ctx.Done(): + } + done() + } + + assert.Equal(t, 4, permit.Count()) + }) +}