diff --git a/blast/blast.go b/blast/blast.go new file mode 100644 index 0000000..2fbd5c4 --- /dev/null +++ b/blast/blast.go @@ -0,0 +1,97 @@ +package blast + +import ( + "io" + "os" + "time" + + "blast/report" + "blast/workers" +) + +var OutputStream io.Writer = os.Stdout + +type ResponseOptions struct { + responsePayloadSizeBytes int64 + totalResponsesToRead uint + totalSuccessfulResponsesToRead uint +} + +type Blast struct { + reporter *report.Reporter + groupOptions workers.GroupOptions + workerGroup *workers.WorkerGroup + loadGenerationResponseChannel chan report.LoadGenerationResponse + doneChannel chan struct{} + maxRunDuration time.Duration +} + +func NewBlastWithoutResponseReading( + workerGroupOptions workers.GroupOptions, + maxRunDuration time.Duration, +) { + startLoad := func() (*workers.WorkerGroup, chan report.LoadGenerationResponse) { + workerGroup := workers.NewWorkerGroup(workerGroupOptions) + return workerGroup, workerGroup.Run() + } + + startReporter := func(loadGenerationResponseChannel chan report.LoadGenerationResponse) *report.Reporter { + reporter := report. + NewLoadGenerationMetricsCollectingReporter(loadGenerationResponseChannel) + + reporter.Run() + return reporter + } + + setUpBlast := func() Blast { + workerGroup, loadGenerationResponseChannel := startLoad() + reporter := startReporter(loadGenerationResponseChannel) + + return Blast{ + reporter: reporter, + groupOptions: workerGroupOptions, + workerGroup: workerGroup, + loadGenerationResponseChannel: loadGenerationResponseChannel, + doneChannel: make(chan struct{}), + maxRunDuration: maxRunDuration, + } + } + + blast := setUpBlast() + blast.start() + + <-blast.doneChannel + blast.reporter.PrintReport(OutputStream) +} + +func (blast Blast) start() { + loadReportedInspectionTimer := time.NewTicker(5 * time.Millisecond) + maxRunTimer := time.NewTimer(blast.maxRunDuration) + + go func() { + stopAll := func() { + blast.workerGroup.Close() + loadReportedInspectionTimer.Stop() + maxRunTimer.Stop() + close(blast.loadGenerationResponseChannel) + close(blast.doneChannel) + } + + for { + select { + case <-blast.workerGroup.DoneChannel(): + println("load completed") + case <-loadReportedInspectionTimer.C: + if blast.reporter.TotalLoadReportedTillNow() >= uint64( + blast.groupOptions.TotalRequests(), + ) { + stopAll() + return + } + case <-maxRunTimer.C: + stopAll() + return + } + } + }() +} diff --git a/report/reporter.go b/report/reporter.go index f5a801d..0efd27d 100644 --- a/report/reporter.go +++ b/report/reporter.go @@ -124,7 +124,7 @@ func (reporter *Reporter) collectLoadMetrics() { } } startTime := reporter.report.Load.EarliestLoadSendTime - timeToCompleteLoad := time.Now().Sub(startTime) + timeToCompleteLoad := reporter.report.Load.LatestLoadSendTime.Sub(startTime) reporter.report.Load.TotalTime = timeToCompleteLoad reporter.report.Load.TotalRequests = totalGeneratedLoad @@ -162,11 +162,11 @@ func (reporter *Reporter) collectResponseMetrics() { ) { reporter.report.Response.LatestResponseReceivedTime = response.ResponseTime } - - timeToCompleteResponses := time.Now(). - Sub(reporter.report.Response.EarliestResponseReceivedTime) - reporter.report.Response.TotalTime = timeToCompleteResponses } reporter.report.Response.TotalResponses = uint(totalResponses) + + timeToCompleteResponses := reporter.report.Response.LatestResponseReceivedTime. + Sub(reporter.report.Response.EarliestResponseReceivedTime) + reporter.report.Response.TotalTime = timeToCompleteResponses }() } diff --git a/tests/blast_integration_test.go b/tests/blast_integration_test.go new file mode 100644 index 0000000..52a2f49 --- /dev/null +++ b/tests/blast_integration_test.go @@ -0,0 +1,78 @@ +package tests + +import ( + "bytes" + "regexp" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "blast/blast" + "blast/workers" +) + +func TestBlastWithLoadGeneration(t *testing.T) { + payloadSizeBytes := int64(10) + server, err := NewEchoServer("tcp", "localhost:10001", payloadSizeBytes) + assert.Nil(t, err) + + server.accept(t) + defer server.stop() + + concurrency, totalRequests := uint(10), uint(20) + + groupOptions := workers.NewGroupOptions( + concurrency, + totalRequests, + []byte("HelloWorld"), + "localhost:10001", + ) + buffer := &bytes.Buffer{} + blast.OutputStream = buffer + blast.NewBlastWithoutResponseReading(groupOptions, 5*time.Minute) + + output := string(buffer.Bytes()) + assert.True(t, strings.Contains(output, "TotalRequests: 20")) + assert.True(t, strings.Contains(output, "SuccessCount: 20")) + assert.True(t, strings.Contains(output, "ErrorCount: 0")) + assert.True(t, strings.Contains(output, "TotalPayloadSize: 200 B")) + assert.True(t, strings.Contains(output, "AveragePayloadSize: 10 B")) +} + +func TestBlastWithLoadGenerationForMaximumDuration(t *testing.T) { + payloadSizeBytes := int64(10) + server, err := NewEchoServer("tcp", "localhost:10002", payloadSizeBytes) + assert.Nil(t, err) + + server.accept(t) + defer server.stop() + + concurrency, totalRequests := uint(1000), uint(2_00_000) + + groupOptions := workers.NewGroupOptionsWithConnections( + concurrency, + 10, + totalRequests, + []byte("HelloWorld"), + "localhost:10002", + ) + buffer := &bytes.Buffer{} + blast.OutputStream = buffer + blast.NewBlastWithoutResponseReading(groupOptions, 10*time.Millisecond) + + output := string(buffer.Bytes()) + assert.True(t, strings.Contains(output, "TotalRequests")) + assert.True(t, strings.Contains(output, "ErrorCount: 0")) + + regexp := regexp.MustCompile("TotalRequests.*") + totalRequestsString := regexp.Find(buffer.Bytes()) + totalRequestsMade, _ := strconv.Atoi(strings.Trim( + strings.ReplaceAll(string(totalRequestsString), "TotalRequests:", ""), + " ", + )) + + assert.True(t, totalRequestsMade < 2_00_000) +} diff --git a/workers/options.go b/workers/options.go index 9c48cd2..0cf35dd 100644 --- a/workers/options.go +++ b/workers/options.go @@ -72,3 +72,7 @@ func NewGroupOptionsFullyLoaded( requestsPerSecond: requestsPerSecond, } } + +func (groupOptions GroupOptions) TotalRequests() uint { + return groupOptions.totalRequests +} diff --git a/workers/worker.go b/workers/worker.go index d2e254d..9d88e90 100644 --- a/workers/worker.go +++ b/workers/worker.go @@ -46,6 +46,9 @@ func (worker Worker) sendRequests() { } func (worker Worker) sendRequest() { + defer func() { + _ = recover() + }() if worker.connection != nil { _, err := worker.connection.Write(worker.options.payload) worker.options.loadGenerationResponse <- report.LoadGenerationResponse{ diff --git a/workers/worker_group.go b/workers/worker_group.go index 8b4dc4c..b843804 100644 --- a/workers/worker_group.go +++ b/workers/worker_group.go @@ -81,6 +81,10 @@ func (group *WorkerGroup) WaitTillDone() { <-group.doneChannel } +func (group *WorkerGroup) DoneChannel() chan struct{} { + return group.doneChannel +} + func (group *WorkerGroup) newConnection() (net.Conn, error) { connection, err := net.Dial("tcp", group.options.targetAddress) if err != nil {