diff --git a/go.mod b/go.mod index 00d86328..8e87a050 100644 --- a/go.mod +++ b/go.mod @@ -67,6 +67,7 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect + github.com/go-chi/httprate v0.14.1 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.5.0 // indirect github.com/go-logr/logr v1.4.2 // indirect diff --git a/go.sum b/go.sum index c3b83096..b4c97452 100644 --- a/go.sum +++ b/go.sum @@ -217,6 +217,8 @@ github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4= +github.com/go-chi/httprate v0.14.1 h1:EKZHYEZ58Cg6hWcYzoZILsv7ppb46Wt4uQ738IRtpZs= +github.com/go-chi/httprate v0.14.1/go.mod h1:TUepLXaz/pCjmCtf/obgOQJ2Sz6rC8fSf5cAt5cnTt0= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= github.com/go-git/go-billy/v5 v5.5.0 h1:yEY4yhzCDuMGSv83oGxiBotRzhwhNr8VZyphhiu+mTU= diff --git a/pkg/http/types.go b/pkg/http/types.go index f2de3561..c7c93f37 100644 --- a/pkg/http/types.go +++ b/pkg/http/types.go @@ -1,9 +1,15 @@ package http type ServerOptions struct { - URL string - Host string - Port int + URL string + Host string + Port int + RateLimiter RateLimiterOptions +} + +type RateLimiterOptions struct { + RequestLimit int + WindowLength int } type ClientOptions struct { diff --git a/pkg/options/server.go b/pkg/options/server.go index 752de7f8..8658dfb0 100644 --- a/pkg/options/server.go +++ b/pkg/options/server.go @@ -9,9 +9,17 @@ import ( func GetDefaultServerOptions() http.ServerOptions { return http.ServerOptions{ - URL: GetDefaultServeOptionString("SERVER_URL", ""), - Host: GetDefaultServeOptionString("SERVER_HOST", "0.0.0.0"), - Port: GetDefaultServeOptionInt("SERVER_PORT", 8080), //nolint:gomnd + URL: GetDefaultServeOptionString("SERVER_URL", ""), + Host: GetDefaultServeOptionString("SERVER_HOST", "0.0.0.0"), + Port: GetDefaultServeOptionInt("SERVER_PORT", 8080), //nolint:gomnd + RateLimiter: GetDefaultRateLimiterOptions(), + } +} + +func GetDefaultRateLimiterOptions() http.RateLimiterOptions { + return http.RateLimiterOptions{ + RequestLimit: GetDefaultServeOptionInt("SERVER_RATE_REQUEST_LIMIT", 5), + WindowLength: GetDefaultServeOptionInt("SERVER_RATE_WINDOW_LENGTH", 10), } } @@ -28,6 +36,14 @@ func AddServerCliFlags(cmd *cobra.Command, serverOptions *http.ServerOptions) { &serverOptions.Port, "server-port", serverOptions.Port, `The port to bind the api server to (SERVER_PORT).`, ) + cmd.PersistentFlags().IntVar( + &serverOptions.RateLimiter.RequestLimit, "server-rate-request-limit", serverOptions.RateLimiter.RequestLimit, + `The max requests over the rate window length (SERVER_RATE_REQUEST_LIMIT).`, + ) + cmd.PersistentFlags().IntVar( + &serverOptions.RateLimiter.WindowLength, "server-rate-window-length", serverOptions.RateLimiter.WindowLength, + `The time window over which to limit in seconds (SERVER_RATE_WINDOW_LENGTH).`, + ) } func CheckServerOptions(options http.ServerOptions) error { diff --git a/pkg/solver/server.go b/pkg/solver/server.go index d30b86ab..eb856c5c 100644 --- a/pkg/solver/server.go +++ b/pkg/solver/server.go @@ -11,6 +11,7 @@ import ( "path/filepath" "time" + "github.com/go-chi/httprate" "github.com/gorilla/mux" "github.com/lilypad-tech/lilypad/pkg/data" "github.com/lilypad-tech/lilypad/pkg/http" @@ -65,6 +66,11 @@ func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system subrouter.Use(http.CorsMiddleware) subrouter.Use(otelmux.Middleware("solver", otelmux.WithTracerProvider(tracerProvider))) + subrouter.Use(httprate.Limit( + solverServer.options.RateLimiter.RequestLimit, + time.Duration(solverServer.options.RateLimiter.WindowLength)*time.Second, + httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint), + )) subrouter.HandleFunc("/job_offers", http.GetHandler(solverServer.getJobOffers)).Methods("GET") subrouter.HandleFunc("/job_offers", http.PostHandler(solverServer.addJobOffer)).Methods("POST") diff --git a/stack b/stack index d4374e92..0515e499 100755 --- a/stack +++ b/stack @@ -217,7 +217,7 @@ function solver() { load-local-env export WEB3_PRIVATE_KEY=${SOLVER_PRIVATE_KEY} export LOG_LEVEL=debug - go run . solver --network dev + go run . solver --network dev "$@" } function solver-docker-build() { diff --git a/test/ratelimit_test.go b/test/ratelimit_test.go new file mode 100644 index 00000000..d307464b --- /dev/null +++ b/test/ratelimit_test.go @@ -0,0 +1,86 @@ +package main + +import ( + "fmt" + "net/http" + "os" + "sync" + "testing" + "time" +) + +type rateResult struct { + path string + okCount int + limitedCount int +} + +// This test suite sends 100 requests over approximately half a second. +// We assume the solver uses the default rate limiting settings with +// a request limit of 5 and window length of 10 seconds. +func TestRateLimiter(t *testing.T) { + paths := []string{ + "/api/v1/resource_offers", + "/api/v1/job_offers", + "/api/v1/deals", + } + + var wg sync.WaitGroup + ch := make(chan rateResult, len(paths)) + + // Send off callers to run concurrently + for _, path := range paths { + wg.Add(1) + + go func() { + defer wg.Done() + makeCalls(t, path, ch) + }() + } + + wg.Wait() + close(ch) + + expectedOkCount := 5 + for result := range ch { + if result.okCount > expectedOkCount { + t.Errorf( + "%s allowed %d requests and limited %d requests, but expected limiting after %d requests\n", + result.path, result.okCount, result.limitedCount, expectedOkCount, + ) + } + } +} + +func makeCalls(t *testing.T, path string, ch chan rateResult) { + var okCount int + var limitedCount int + + // Make 100 requests + for range 100 { + requestURL := fmt.Sprintf("http://localhost:%d%s", 8080, path) + res, err := http.Get(requestURL) + + if err != nil { + t.Errorf("Get request failed on %s: %s\n", path, err) + os.Exit(1) + } + + if res.StatusCode == 200 { + okCount++ + } else if res.StatusCode == 429 { + limitedCount++ + } else { + t.Errorf("Expected a 200 or 429 status code, but received a %d\n", res.StatusCode) + } + + // Wait before making next call + time.Sleep(5 * time.Millisecond) + } + + ch <- rateResult{ + path: path, + okCount: okCount, + limitedCount: limitedCount, + } +}