Skip to content

Commit

Permalink
Add rate limiting and unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Rodriguez <[email protected]>
  • Loading branch information
lucasrod16 committed Sep 29, 2024
1 parent 77e5632 commit 1080464
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ jobs:
- name: Checkout
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0

- name: Run Tests
run: make test-unit

- name: Google Cloud Auth
uses: google-github-actions/auth@62cf5bd3e4211a0a0b51f2c6d6a37129d828611d # v2.1.5
with:
Expand Down
20 changes: 20 additions & 0 deletions .github/workflows/test-unit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: Run Unit Tests

on:
push:
branches:
- main
pull_request:

permissions:
contents: read

jobs:
unit-tests:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0

- name: Run Unit Tests
run: make test-unit
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
.PHONY: build-dev
build-dev:
npm --prefix ui run build && CGO_ENABLED=0 go build -o ./bin/api

.PHONY: test-unit
test-unit:
go test -race -v -count=1 -failfast ./internal/...
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ go 1.23.1
require (
cloud.google.com/go/storage v1.43.0
github.com/google/go-github/v64 v64.0.0
github.com/stretchr/testify v1.9.0
golang.org/x/time v0.6.0
)

require (
Expand All @@ -13,6 +15,7 @@ require (
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
cloud.google.com/go/iam v1.1.8 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
Expand All @@ -23,6 +26,7 @@ require (
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
Expand All @@ -35,11 +39,11 @@ require (
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/api v0.187.0 // indirect
google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
google.golang.org/grpc v1.64.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
5 changes: 3 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
Expand Down Expand Up @@ -169,6 +169,7 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
4 changes: 2 additions & 2 deletions pkg/cache.go → internal/cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package osscontribute
package cache

import (
"context"
Expand All @@ -22,7 +22,7 @@ type Cache struct {
mu sync.RWMutex
}

func NewCache() *Cache {
func New() *Cache {
return &Cache{}
}

Expand Down
43 changes: 43 additions & 0 deletions internal/cache/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package cache

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestCache(t *testing.T) {
tests := []struct {
name string
setData []byte
expected []byte
}{
{
name: "Set and Get data",
setData: []byte("test data"),
expected: []byte("test data"),
},
{
name: "Set and Get empty data",
setData: []byte(""),
expected: []byte(""),
},
{
name: "Set and Get nil data",
setData: nil,
expected: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := New()

cache.Set(tt.setData)
data, timestamp := cache.Get()

require.Equal(t, tt.expected, data)
require.False(t, timestamp.IsZero(), "timestamp should not be empty")
})
}
}
11 changes: 6 additions & 5 deletions pkg/handler.go → internal/http/handler.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
package osscontribute
package http

import (
"net/http"

"github.com/lucasrod16/oss-contribute/internal/cache"
)

func GetRepos(c *Cache) http.HandlerFunc {
func GetRepos(c *cache.Cache) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "405 Method Not Allowed\n", http.StatusMethodNotAllowed)
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}

data, timestamp := c.Get()
if data == nil {
http.Error(w, "No data found in cache\n", http.StatusInternalServerError)
http.Error(w, "No data found in cache", http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Last-Modified", timestamp.Format(http.TimeFormat))
w.WriteHeader(http.StatusOK)
w.Write(data)
w.Write([]byte("\n"))
}
}
58 changes: 58 additions & 0 deletions internal/http/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package http

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/lucasrod16/oss-contribute/internal/cache"
"github.com/stretchr/testify/require"
)

func TestGetRepos(t *testing.T) {
tests := []struct {
name string
method string
expectedStatus int
expectedBody string
cacheData []byte
}{
{
name: "valid request",
method: http.MethodGet,
expectedStatus: http.StatusOK,
expectedBody: `{"data": "some data"}`,
cacheData: []byte(`{"data": "some data"}`),
},
{
name: "method not allowed",
method: http.MethodPost,
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: http.StatusText(http.StatusMethodNotAllowed) + "\n",
},
{
name: "no data found in cache",
method: http.MethodGet,
expectedStatus: http.StatusInternalServerError,
expectedBody: "No data found in cache\n",
cacheData: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cache.New()
if tt.cacheData != nil {
c.Set(tt.cacheData)
}

req := httptest.NewRequest(tt.method, "/repos", nil)
rr := httptest.NewRecorder()

GetRepos(c).ServeHTTP(rr, req)

require.Equal(t, tt.expectedStatus, rr.Code, "GetRepos handler returned wrong status code")
require.Equal(t, tt.expectedBody, rr.Body.String(), "GetRepos handler returned unexpected body")
})
}
}
98 changes: 98 additions & 0 deletions internal/http/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package http

import (
"log"
"net"
"net/http"
"strings"
"sync"
"time"

"golang.org/x/time/rate"
)

type client struct {
limiter *rate.Limiter
lastSeen time.Time
mu sync.Mutex
}

// RateLimiter holds rate limiters per client IP address.
type RateLimiter struct {
mu sync.Mutex
clients map[string]*client
}

func NewRateLimiter() *RateLimiter {
rl := &RateLimiter{
clients: make(map[string]*client),
}
go rl.cleanupStaleClients(10)
return rl
}

// Limit applies rate limiting to the given HTTP handler.
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cl := rl.getClientLimiter(getClientIP(r))

if !cl.limiter.Allow() {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}

cl.mu.Lock()
cl.lastSeen = time.Now()
cl.mu.Unlock()

next.ServeHTTP(w, r)
})
}

func (rl *RateLimiter) getClientLimiter(ip string) *client {
rl.mu.Lock()
defer rl.mu.Unlock()

cl, exists := rl.clients[ip]
if !exists {
cl = &client{
limiter: rate.NewLimiter(5, 10),
lastSeen: time.Now(),
}
rl.clients[ip] = cl
}
return cl
}

// cleanupStaleClients removes clients that haven't requested in the last specified duration to conserve memory.
func (rl *RateLimiter) cleanupStaleClients(minutes time.Duration) {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()

for range ticker.C {
rl.mu.Lock()
for ip, cl := range rl.clients {
if time.Since(cl.lastSeen) > minutes*time.Minute {
delete(rl.clients, ip)
}
}
rl.mu.Unlock()
}
}

// getClientIP extracts the client's IP address from the request.
func getClientIP(r *http.Request) string {
// handle "X-Forwarded-For" header used by proxies and load balancers.
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}

ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
log.Printf("could not determine client IP: %v\n", err)
return ""
}
return ip
}
49 changes: 49 additions & 0 deletions internal/http/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package http

import (
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/lucasrod16/oss-contribute/internal/cache"
"github.com/stretchr/testify/require"
)

func TestRateLimiter(t *testing.T) {
rl := NewRateLimiter()
c := cache.New()
c.Set([]byte(`{"data": "some data"}`))

req := httptest.NewRequest(http.MethodGet, "/repos", nil)
req.Header.Set("X-Forwarded-For", "192.168.1.1")

var wg sync.WaitGroup
var mu sync.Mutex
successCount := 0
failCount := 0

// send 20 requests concurrently to trigger the rate limit
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
rr := httptest.NewRecorder()
rl.Limit(GetRepos(c)).ServeHTTP(rr, req)

if rr.Code == http.StatusOK {
mu.Lock()
successCount++
mu.Unlock()
} else if rr.Code == http.StatusTooManyRequests {
mu.Lock()
failCount++
mu.Unlock()
}
}()
}
wg.Wait()

require.Equal(t, 10, successCount, "Expected 10 successful requests")
require.Equal(t, 10, failCount, "Expected 10 failed requests")
}
Loading

0 comments on commit 1080464

Please sign in to comment.