diff --git a/report/reporter.go b/report/reporter.go index 4b583c8..f8bd4a6 100644 --- a/report/reporter.go +++ b/report/reporter.go @@ -42,8 +42,8 @@ type Reporter struct { func NewLoadGenerationMetricsCollectingReporter( loadGenerationChannel chan LoadGenerationResponse, -) Reporter { - return Reporter{ +) *Reporter { + return &Reporter{ report: &Report{ Load: LoadMetrics{ ErrorCountByType: make(map[string]uint), @@ -53,14 +53,15 @@ func NewLoadGenerationMetricsCollectingReporter( }, }, loadGenerationChannel: loadGenerationChannel, + responseChannel: nil, } } func NewResponseMetricsCollectingReporter( loadGenerationChannel chan LoadGenerationResponse, responseChannel chan SubjectServerResponse, -) Reporter { - return Reporter{ +) *Reporter { + return &Reporter{ report: &Report{ Load: LoadMetrics{ ErrorCountByType: make(map[string]uint), diff --git a/report/response_reader.go b/report/response_reader.go index 5a5d89b..b742c43 100644 --- a/report/response_reader.go +++ b/report/response_reader.go @@ -4,6 +4,7 @@ import ( "errors" "io" "net" + "sync/atomic" "time" ) @@ -21,19 +22,23 @@ type SubjectServerResponse struct { } type ResponseReader struct { - responseSizeBytes int64 - stopChannel chan struct{} - responseChannel chan SubjectServerResponse + responseSizeBytes int64 + totalResponsesToRead uint32 + readResponses atomic.Uint32 + stopChannel chan struct{} + responseChannel chan SubjectServerResponse } func NewResponseReader( responseSizeBytes int64, + totalResponsesToRead uint, responseChannel chan SubjectServerResponse, ) *ResponseReader { return &ResponseReader{ - responseSizeBytes: responseSizeBytes, - stopChannel: make(chan struct{}), - responseChannel: responseChannel, + responseSizeBytes: responseSizeBytes, + totalResponsesToRead: uint32(totalResponsesToRead), + stopChannel: make(chan struct{}), + responseChannel: responseChannel, } } @@ -69,6 +74,7 @@ func (responseReader *ResponseReader) StartReading(connection net.Conn) { PayloadLengthBytes: int64(len(buffer)), } } + responseReader.readResponses.Add(1) } } }(connection) @@ -77,3 +83,7 @@ func (responseReader *ResponseReader) StartReading(connection net.Conn) { func (responseReader *ResponseReader) close() { close(responseReader.stopChannel) } + +func (responseReader *ResponseReader) TotalResponsesRead() uint32 { + return responseReader.readResponses.Load() +} diff --git a/tests/response_reader_integration_test.go b/tests/response_reader_integration_test.go index 8eb33fb..1b3bb0f 100644 --- a/tests/response_reader_integration_test.go +++ b/tests/response_reader_integration_test.go @@ -32,6 +32,7 @@ func TestReadsResponseFromASingleConnection(t *testing.T) { responseReader := report.NewResponseReader( payloadSizeBytes, + 1, responseChannel, ) responseReader.StartReading(connection) @@ -65,6 +66,7 @@ func TestReadsResponseFromTwoConnections(t *testing.T) { responseReader := report.NewResponseReader( payloadSizeBytes, + 2, responseChannel, ) responseReader.StartReading(connection) @@ -75,6 +77,35 @@ func TestReadsResponseFromTwoConnections(t *testing.T) { assert.Equal(t, []byte("HelloWorld"), responses[1]) } +func TestTracksTheNumberOfResponsesRead(t *testing.T) { + payloadSizeBytes := int64(10) + server, err := NewEchoServer("tcp", "localhost:9092", payloadSizeBytes) + assert.Nil(t, err) + + server.accept(t) + + connection := connectTo(t, "localhost:9092") + writeTo(t, connection, []byte("HelloWorld")) + + responseChannel := make(chan report.SubjectServerResponse) + + defer func() { + server.stop() + close(responseChannel) + _ = connection.Close() + }() + + responseReader := report.NewResponseReader( + payloadSizeBytes, + 1, + responseChannel, + ) + responseReader.StartReading(connection) + + _ = <-responseChannel + assert.Equal(t, uint32(1), responseReader.TotalResponsesRead()) +} + func connectTo(t *testing.T, address string) net.Conn { connection, err := net.Dial("tcp", address) assert.Nil(t, err) diff --git a/tests/worker_group_integration_test.go b/tests/worker_group_integration_test.go index 3bd2ed8..f938711 100644 --- a/tests/worker_group_integration_test.go +++ b/tests/worker_group_integration_test.go @@ -84,6 +84,7 @@ func TestSendsARequestAndReadsResponseWithSingleConnection(t *testing.T) { ), report.NewResponseReader( responseSizeBytes, + totalRequests, responseChannel, ), ).Run()