Skip to content

Commit

Permalink
Sarthak | Refactors WorkerGroup, loadGenerationResponseChannel is not…
Browse files Browse the repository at this point in the history
… closed after the workers are done, it will now be closed from an external client. The main goroutine in the Run method still waits for all the workers to finish using a done channel
  • Loading branch information
SarthakMakhija committed Aug 18, 2023
1 parent faa79db commit e8733c3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 41 deletions.
75 changes: 40 additions & 35 deletions tests/worker_group_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ func TestSendsRequestsWithSingleConnection(t *testing.T) {
defer server.stop()

concurrency, totalRequests := uint(10), uint(20)
loadGenerationResponseChannel := workers.
NewWorkerGroup(
workers.NewGroupOptions(
concurrency,
totalRequests,
[]byte("HelloWorld"),
"localhost:8080",
),
).Run()

workerGroup := workers.NewWorkerGroup(workers.NewGroupOptions(
concurrency,
totalRequests,
[]byte("HelloWorld"),
"localhost:8080",
))
loadGenerationResponseChannel := workerGroup.Run()

workerGroup.WaitTillDone()
close(loadGenerationResponseChannel)

totalRequestsSent := 0
for response := range loadGenerationResponseChannel {
Expand All @@ -46,16 +48,17 @@ func TestSendsRequestsWithMultipleConnections(t *testing.T) {
defer server.stop()

concurrency, connections, totalRequests := uint(20), uint(10), uint(40)
loadGenerationResponseChannel := workers.
NewWorkerGroup(
workers.NewGroupOptionsWithConnections(
concurrency,
connections,
totalRequests,
[]byte("HelloWorld"),
"localhost:8081",
),
).Run()
workerGroup := workers.NewWorkerGroup(workers.NewGroupOptionsWithConnections(
concurrency,
connections,
totalRequests,
[]byte("HelloWorld"),
"localhost:8081",
))
loadGenerationResponseChannel := workerGroup.Run()

workerGroup.WaitTillDone()
close(loadGenerationResponseChannel)

for response := range loadGenerationResponseChannel {
assert.Nil(t, response.Err)
Expand All @@ -78,18 +81,18 @@ func TestSendsARequestAndReadsResponseWithSingleConnection(t *testing.T) {
close(responseChannel)
}()

loadGenerationResponseChannel := workers.NewWorkerGroupWithResponseReader(
workerGroup := workers.NewWorkerGroupWithResponseReader(
workers.NewGroupOptions(
concurrency,
totalRequests,
[]byte("HelloWorld"),
"localhost:8082",
),
report.NewResponseReader(
responseSizeBytes,
responseChannel,
),
).Run()
), report.NewResponseReader(responseSizeBytes, responseChannel),
)
loadGenerationResponseChannel := workerGroup.Run()

workerGroup.WaitTillDone()
close(loadGenerationResponseChannel)

for response := range loadGenerationResponseChannel {
assert.Nil(t, response.Err)
Expand All @@ -112,15 +115,17 @@ func TestSendsAdditionalRequestsThanConfiguredWithSingleConnection(t *testing.T)
defer server.stop()

concurrency, totalRequests := uint(6), uint(20)
loadGenerationResponseChannel := workers.
NewWorkerGroup(
workers.NewGroupOptions(
concurrency,
totalRequests,
[]byte("HelloWorld"),
"localhost:8083",
),
).Run()

workerGroup := workers.NewWorkerGroup(workers.NewGroupOptions(
concurrency,
totalRequests,
[]byte("HelloWorld"),
"localhost:8083",
))
loadGenerationResponseChannel := workerGroup.Run()

workerGroup.WaitTillDone()
close(loadGenerationResponseChannel)

totalRequestsSent := 0
for response := range loadGenerationResponseChannel {
Expand Down
1 change: 1 addition & 0 deletions workers/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (worker Worker) sendRequests() {
if worker.options.requestsPerSecond > 0 {
<-throttle
}

worker.sendRequest()
}
}
Expand Down
16 changes: 10 additions & 6 deletions workers/worker_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
type WorkerGroup struct {
options GroupOptions
stopChannel chan struct{}
doneChannel chan struct{}
responseReader *report.ResponseReader
}

Expand All @@ -24,19 +25,24 @@ func NewWorkerGroupWithResponseReader(
return &WorkerGroup{
options: options,
stopChannel: make(chan struct{}, options.concurrency),
doneChannel: make(chan struct{}, 1),
responseReader: responseReader,
}
}

func (group *WorkerGroup) Run() chan report.LoadGenerationResponse {
if group.options.totalRequests%group.options.concurrency != 0 {
group.options.totalRequests = ((group.options.totalRequests / group.options.concurrency) + 1) * group.options.concurrency
}
loadGenerationResponseChannel := make(
chan report.LoadGenerationResponse,
group.options.totalRequests,
)

go func() {
group.runWorkers(loadGenerationResponseChannel)
group.finish(loadGenerationResponseChannel)
group.WaitTillDone()
return
}()
return loadGenerationResponseChannel
}
Expand Down Expand Up @@ -72,10 +78,11 @@ func (group *WorkerGroup) runWorkers(
group.runWorker(connection, &wg, loadGenerationResponseChannel)
}
wg.Wait()
group.doneChannel <- struct{}{}
}

func (group *WorkerGroup) finish(loadGenerationResponseChannel chan report.LoadGenerationResponse) {
close(loadGenerationResponseChannel)
func (group *WorkerGroup) WaitTillDone() {
<-group.doneChannel
}

func (group *WorkerGroup) newConnection() (net.Conn, error) {
Expand All @@ -92,9 +99,6 @@ func (group *WorkerGroup) runWorker(
loadGenerationResponseChannel chan report.LoadGenerationResponse,
) {
totalRequests := group.options.totalRequests
if group.options.totalRequests%group.options.concurrency != 0 {
totalRequests = ((group.options.totalRequests / group.options.concurrency) + 1) * group.options.concurrency
}
Worker{
connection: connection,
options: WorkerOptions{
Expand Down

0 comments on commit e8733c3

Please sign in to comment.