From c4be9ab692a8d7dca67a06fe62fc6f9735cff8c2 Mon Sep 17 00:00:00 2001 From: Lucas Rodriguez Date: Sun, 29 Sep 2024 13:22:32 -0500 Subject: [PATCH] Use switch statement and errgroup in rate limit test Signed-off-by: Lucas Rodriguez --- go.mod | 2 +- internal/http/ratelimit_test.go | 30 ++++++++++++++++++------------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index 1efa745..6d8333d 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( cloud.google.com/go/storage v1.43.0 github.com/google/go-github/v64 v64.0.0 github.com/stretchr/testify v1.9.0 + golang.org/x/sync v0.7.0 golang.org/x/time v0.6.0 ) @@ -36,7 +37,6 @@ require ( golang.org/x/crypto v0.24.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect google.golang.org/api v0.187.0 // indirect diff --git a/internal/http/ratelimit_test.go b/internal/http/ratelimit_test.go index e620c0b..5630928 100644 --- a/internal/http/ratelimit_test.go +++ b/internal/http/ratelimit_test.go @@ -1,6 +1,7 @@ package http import ( + "fmt" "net/http" "net/http/httptest" "sync" @@ -8,6 +9,7 @@ import ( "github.com/lucasrod16/oss-contribute/internal/cache" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) func TestRateLimiter(t *testing.T) { @@ -18,31 +20,35 @@ func TestRateLimiter(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/repos", nil) req.Header.Set("X-Forwarded-For", "192.168.1.1") - var wg sync.WaitGroup + var g errgroup.Group var mu sync.Mutex + successCount := 0 failCount := 0 // send 20 requests concurrently to trigger the rate limit for i := 0; i < 20; i++ { - wg.Add(1) - go func() { - defer wg.Done() + g.Go(func() error { rr := httptest.NewRecorder() rl.Limit(GetRepos(c)).ServeHTTP(rr, req) - if rr.Code == http.StatusOK { - mu.Lock() + mu.Lock() + defer mu.Unlock() + + switch rr.Code { + case http.StatusOK: successCount++ - mu.Unlock() - } else if rr.Code == http.StatusTooManyRequests { - mu.Lock() + case http.StatusTooManyRequests: failCount++ - mu.Unlock() + default: + return fmt.Errorf("unexpected status code: %d", rr.Code) } - }() + return nil + }) + } + if err := g.Wait(); err != nil { + t.Fatal(err) } - wg.Wait() require.Equal(t, 10, successCount, "Expected 10 successful requests") require.Equal(t, 10, failCount, "Expected 10 failed requests")