diff --git a/tests/worker_group_integration_test.go b/tests/worker_group_integration_test.go index 0801e51..fa2e2b4 100644 --- a/tests/worker_group_integration_test.go +++ b/tests/worker_group_integration_test.go @@ -18,15 +18,17 @@ func TestSendsRequestsWithSingleConnection(t *testing.T) { defer server.stop() concurrency, totalRequests := uint(10), uint(20) - loadGenerationResponseChannel := workers. - NewWorkerGroup( - workers.NewGroupOptions( - concurrency, - totalRequests, - []byte("HelloWorld"), - "localhost:8080", - ), - ).Run() + + workerGroup := workers.NewWorkerGroup(workers.NewGroupOptions( + concurrency, + totalRequests, + []byte("HelloWorld"), + "localhost:8080", + )) + loadGenerationResponseChannel := workerGroup.Run() + + workerGroup.WaitTillDone() + close(loadGenerationResponseChannel) totalRequestsSent := 0 for response := range loadGenerationResponseChannel { @@ -46,16 +48,17 @@ func TestSendsRequestsWithMultipleConnections(t *testing.T) { defer server.stop() concurrency, connections, totalRequests := uint(20), uint(10), uint(40) - loadGenerationResponseChannel := workers. - NewWorkerGroup( - workers.NewGroupOptionsWithConnections( - concurrency, - connections, - totalRequests, - []byte("HelloWorld"), - "localhost:8081", - ), - ).Run() + workerGroup := workers.NewWorkerGroup(workers.NewGroupOptionsWithConnections( + concurrency, + connections, + totalRequests, + []byte("HelloWorld"), + "localhost:8081", + )) + loadGenerationResponseChannel := workerGroup.Run() + + workerGroup.WaitTillDone() + close(loadGenerationResponseChannel) for response := range loadGenerationResponseChannel { assert.Nil(t, response.Err) @@ -78,18 +81,18 @@ func TestSendsARequestAndReadsResponseWithSingleConnection(t *testing.T) { close(responseChannel) }() - loadGenerationResponseChannel := workers.NewWorkerGroupWithResponseReader( + workerGroup := workers.NewWorkerGroupWithResponseReader( workers.NewGroupOptions( concurrency, totalRequests, []byte("HelloWorld"), "localhost:8082", - ), - report.NewResponseReader( - responseSizeBytes, - responseChannel, - ), - ).Run() + ), report.NewResponseReader(responseSizeBytes, responseChannel), + ) + loadGenerationResponseChannel := workerGroup.Run() + + workerGroup.WaitTillDone() + close(loadGenerationResponseChannel) for response := range loadGenerationResponseChannel { assert.Nil(t, response.Err) @@ -112,15 +115,17 @@ func TestSendsAdditionalRequestsThanConfiguredWithSingleConnection(t *testing.T) defer server.stop() concurrency, totalRequests := uint(6), uint(20) - loadGenerationResponseChannel := workers. - NewWorkerGroup( - workers.NewGroupOptions( - concurrency, - totalRequests, - []byte("HelloWorld"), - "localhost:8083", - ), - ).Run() + + workerGroup := workers.NewWorkerGroup(workers.NewGroupOptions( + concurrency, + totalRequests, + []byte("HelloWorld"), + "localhost:8083", + )) + loadGenerationResponseChannel := workerGroup.Run() + + workerGroup.WaitTillDone() + close(loadGenerationResponseChannel) totalRequestsSent := 0 for response := range loadGenerationResponseChannel { diff --git a/workers/worker.go b/workers/worker.go index 37e42b7..bf1871d 100644 --- a/workers/worker.go +++ b/workers/worker.go @@ -36,6 +36,7 @@ func (worker Worker) sendRequests() { if worker.options.requestsPerSecond > 0 { <-throttle } + worker.sendRequest() } } diff --git a/workers/worker_group.go b/workers/worker_group.go index ea8340c..c958bbc 100644 --- a/workers/worker_group.go +++ b/workers/worker_group.go @@ -10,6 +10,7 @@ import ( type WorkerGroup struct { options GroupOptions stopChannel chan struct{} + doneChannel chan struct{} responseReader *report.ResponseReader } @@ -24,11 +25,15 @@ func NewWorkerGroupWithResponseReader( return &WorkerGroup{ options: options, stopChannel: make(chan struct{}, options.concurrency), + doneChannel: make(chan struct{}, 1), responseReader: responseReader, } } func (group *WorkerGroup) Run() chan report.LoadGenerationResponse { + if group.options.totalRequests%group.options.concurrency != 0 { + group.options.totalRequests = ((group.options.totalRequests / group.options.concurrency) + 1) * group.options.concurrency + } loadGenerationResponseChannel := make( chan report.LoadGenerationResponse, group.options.totalRequests, @@ -36,7 +41,8 @@ func (group *WorkerGroup) Run() chan report.LoadGenerationResponse { go func() { group.runWorkers(loadGenerationResponseChannel) - group.finish(loadGenerationResponseChannel) + group.WaitTillDone() + return }() return loadGenerationResponseChannel } @@ -72,10 +78,11 @@ func (group *WorkerGroup) runWorkers( group.runWorker(connection, &wg, loadGenerationResponseChannel) } wg.Wait() + group.doneChannel <- struct{}{} } -func (group *WorkerGroup) finish(loadGenerationResponseChannel chan report.LoadGenerationResponse) { - close(loadGenerationResponseChannel) +func (group *WorkerGroup) WaitTillDone() { + <-group.doneChannel } func (group *WorkerGroup) newConnection() (net.Conn, error) { @@ -92,9 +99,6 @@ func (group *WorkerGroup) runWorker( loadGenerationResponseChannel chan report.LoadGenerationResponse, ) { totalRequests := group.options.totalRequests - if group.options.totalRequests%group.options.concurrency != 0 { - totalRequests = ((group.options.totalRequests / group.options.concurrency) + 1) * group.options.concurrency - } Worker{ connection: connection, options: WorkerOptions{