-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from lucasrod16/rate-limiting-and-tests
Add rate limiting and unit tests
- Loading branch information
Showing
12 changed files
with
299 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
name: Run Unit Tests | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
unit-tests: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 | ||
|
||
- name: Run Unit Tests | ||
run: make test-unit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
.PHONY: build-dev | ||
build-dev: | ||
npm --prefix ui run build && CGO_ENABLED=0 go build -o ./bin/api | ||
|
||
.PHONY: test-unit | ||
test-unit: | ||
go test -race -v -count=1 -failfast ./internal/... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package cache | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestCache(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
setData []byte | ||
expected []byte | ||
}{ | ||
{ | ||
name: "Set and Get data", | ||
setData: []byte("test data"), | ||
expected: []byte("test data"), | ||
}, | ||
{ | ||
name: "Set and Get empty data", | ||
setData: []byte(""), | ||
expected: []byte(""), | ||
}, | ||
{ | ||
name: "Set and Get nil data", | ||
setData: nil, | ||
expected: nil, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
cache := New() | ||
|
||
cache.Set(tt.setData) | ||
data, timestamp := cache.Get() | ||
|
||
require.Equal(t, tt.expected, data) | ||
require.False(t, timestamp.IsZero(), "timestamp should not be empty") | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,27 @@ | ||
package osscontribute | ||
package http | ||
|
||
import ( | ||
"net/http" | ||
|
||
"github.com/lucasrod16/oss-contribute/internal/cache" | ||
) | ||
|
||
func GetRepos(c *Cache) http.HandlerFunc { | ||
func GetRepos(c *cache.Cache) http.HandlerFunc { | ||
return func(w http.ResponseWriter, r *http.Request) { | ||
if r.Method != http.MethodGet { | ||
http.Error(w, "405 Method Not Allowed\n", http.StatusMethodNotAllowed) | ||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) | ||
return | ||
} | ||
|
||
data, timestamp := c.Get() | ||
if data == nil { | ||
http.Error(w, "No data found in cache\n", http.StatusInternalServerError) | ||
http.Error(w, "No data found in cache", http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
w.Header().Set("Content-Type", "application/json") | ||
w.Header().Set("Last-Modified", timestamp.Format(http.TimeFormat)) | ||
w.WriteHeader(http.StatusOK) | ||
w.Write(data) | ||
w.Write([]byte("\n")) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package http | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/lucasrod16/oss-contribute/internal/cache" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestGetRepos(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
method string | ||
expectedStatus int | ||
expectedBody string | ||
cacheData []byte | ||
}{ | ||
{ | ||
name: "valid request", | ||
method: http.MethodGet, | ||
expectedStatus: http.StatusOK, | ||
expectedBody: `{"data": "some data"}`, | ||
cacheData: []byte(`{"data": "some data"}`), | ||
}, | ||
{ | ||
name: "method not allowed", | ||
method: http.MethodPost, | ||
expectedStatus: http.StatusMethodNotAllowed, | ||
expectedBody: http.StatusText(http.StatusMethodNotAllowed) + "\n", | ||
}, | ||
{ | ||
name: "no data found in cache", | ||
method: http.MethodGet, | ||
expectedStatus: http.StatusInternalServerError, | ||
expectedBody: "No data found in cache\n", | ||
cacheData: nil, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
c := cache.New() | ||
if tt.cacheData != nil { | ||
c.Set(tt.cacheData) | ||
} | ||
|
||
req := httptest.NewRequest(tt.method, "/repos", nil) | ||
rr := httptest.NewRecorder() | ||
|
||
GetRepos(c).ServeHTTP(rr, req) | ||
|
||
require.Equal(t, tt.expectedStatus, rr.Code, "GetRepos handler returned wrong status code") | ||
require.Equal(t, tt.expectedBody, rr.Body.String(), "GetRepos handler returned unexpected body") | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package http | ||
|
||
import ( | ||
"log" | ||
"net" | ||
"net/http" | ||
"strings" | ||
"sync" | ||
"time" | ||
|
||
"golang.org/x/time/rate" | ||
) | ||
|
||
type client struct { | ||
limiter *rate.Limiter | ||
lastSeen time.Time | ||
mu sync.Mutex | ||
} | ||
|
||
// RateLimiter holds rate limiters per client IP address. | ||
type RateLimiter struct { | ||
mu sync.Mutex | ||
clients map[string]*client | ||
} | ||
|
||
func NewRateLimiter() *RateLimiter { | ||
rl := &RateLimiter{ | ||
clients: make(map[string]*client), | ||
} | ||
go rl.cleanupStaleClients(10) | ||
return rl | ||
} | ||
|
||
// Limit applies rate limiting to the given HTTP handler. | ||
func (rl *RateLimiter) Limit(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
cl := rl.getClientLimiter(getClientIP(r)) | ||
|
||
if !cl.limiter.Allow() { | ||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) | ||
return | ||
} | ||
|
||
cl.mu.Lock() | ||
cl.lastSeen = time.Now() | ||
cl.mu.Unlock() | ||
|
||
next.ServeHTTP(w, r) | ||
}) | ||
} | ||
|
||
func (rl *RateLimiter) getClientLimiter(ip string) *client { | ||
rl.mu.Lock() | ||
defer rl.mu.Unlock() | ||
|
||
cl, exists := rl.clients[ip] | ||
if !exists { | ||
cl = &client{ | ||
limiter: rate.NewLimiter(5, 10), | ||
lastSeen: time.Now(), | ||
} | ||
rl.clients[ip] = cl | ||
} | ||
return cl | ||
} | ||
|
||
// cleanupStaleClients removes clients that haven't requested in the last specified duration to conserve memory. | ||
func (rl *RateLimiter) cleanupStaleClients(minutes time.Duration) { | ||
ticker := time.NewTicker(time.Minute) | ||
defer ticker.Stop() | ||
|
||
for range ticker.C { | ||
rl.mu.Lock() | ||
for ip, cl := range rl.clients { | ||
if time.Since(cl.lastSeen) > minutes*time.Minute { | ||
delete(rl.clients, ip) | ||
} | ||
} | ||
rl.mu.Unlock() | ||
} | ||
} | ||
|
||
// getClientIP extracts the client's IP address from the request. | ||
func getClientIP(r *http.Request) string { | ||
// handle "X-Forwarded-For" header used by proxies and load balancers. | ||
xff := r.Header.Get("X-Forwarded-For") | ||
if xff != "" { | ||
ips := strings.Split(xff, ",") | ||
return strings.TrimSpace(ips[0]) | ||
} | ||
|
||
ip, _, err := net.SplitHostPort(r.RemoteAddr) | ||
if err != nil { | ||
log.Printf("could not determine client IP: %v\n", err) | ||
return "" | ||
} | ||
return ip | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package http | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"sync" | ||
"testing" | ||
|
||
"github.com/lucasrod16/oss-contribute/internal/cache" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestRateLimiter(t *testing.T) { | ||
rl := NewRateLimiter() | ||
c := cache.New() | ||
c.Set([]byte(`{"data": "some data"}`)) | ||
|
||
req := httptest.NewRequest(http.MethodGet, "/repos", nil) | ||
req.Header.Set("X-Forwarded-For", "192.168.1.1") | ||
|
||
var wg sync.WaitGroup | ||
var mu sync.Mutex | ||
successCount := 0 | ||
failCount := 0 | ||
|
||
// send 20 requests concurrently to trigger the rate limit | ||
for i := 0; i < 20; i++ { | ||
wg.Add(1) | ||
go func() { | ||
defer wg.Done() | ||
rr := httptest.NewRecorder() | ||
rl.Limit(GetRepos(c)).ServeHTTP(rr, req) | ||
|
||
if rr.Code == http.StatusOK { | ||
mu.Lock() | ||
successCount++ | ||
mu.Unlock() | ||
} else if rr.Code == http.StatusTooManyRequests { | ||
mu.Lock() | ||
failCount++ | ||
mu.Unlock() | ||
} | ||
}() | ||
} | ||
wg.Wait() | ||
|
||
require.Equal(t, 10, successCount, "Expected 10 successful requests") | ||
require.Equal(t, 10, failCount, "Expected 10 failed requests") | ||
} |
Oops, something went wrong.