diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 6a9cfdce56d..5c02be4ae2c 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -220,10 +220,11 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { TxResultQueryMode: backend.IndexQueryModeExecutionNodesOnly.String(), // default to ENs only for now }, RestConfig: rest.Config{ - ListenAddress: "", - WriteTimeout: rest.DefaultWriteTimeout, - ReadTimeout: rest.DefaultReadTimeout, - IdleTimeout: rest.DefaultIdleTimeout, + ListenAddress: "", + WriteTimeout: rest.DefaultWriteTimeout, + ReadTimeout: rest.DefaultReadTimeout, + IdleTimeout: rest.DefaultIdleTimeout, + MaxRequestSize: routes.DefaultMaxRequestSize, }, MaxMsgSize: grpcutils.DefaultMaxMsgSize, CompressorName: grpcutils.NoCompressor, @@ -1190,6 +1191,10 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { defaultConfig.rpcConf.RestConfig.ReadTimeout, "timeout to use when reading REST request headers") flags.DurationVar(&builder.rpcConf.RestConfig.IdleTimeout, "rest-idle-timeout", defaultConfig.rpcConf.RestConfig.IdleTimeout, "idle timeout for REST connections") + flags.Int64Var(&builder.rpcConf.RestConfig.MaxRequestSize, + "rest-max-request-size", + defaultConfig.rpcConf.RestConfig.MaxRequestSize, + "the maximum request size in bytes for payload sent over REST server") flags.StringVarP(&builder.rpcConf.CollectionAddr, "static-collection-ingress-addr", "", @@ -1508,6 +1513,10 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { return errors.New("execution-data-indexing-enabled must be set if check-payer-balance is enabled") } + if builder.rpcConf.RestConfig.MaxRequestSize <= 0 { + return errors.New("rest-max-request-size must be greater than 0") + } + return nil }) } diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 98fc1fc701a..9bf08313a48 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -191,10 +191,11 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig { TxResultQueryMode: backend.IndexQueryModeExecutionNodesOnly.String(), // default to ENs only for now }, RestConfig: rest.Config{ - ListenAddress: "", - WriteTimeout: rest.DefaultWriteTimeout, - ReadTimeout: rest.DefaultReadTimeout, - IdleTimeout: rest.DefaultIdleTimeout, + ListenAddress: "", + WriteTimeout: rest.DefaultWriteTimeout, + ReadTimeout: rest.DefaultReadTimeout, + IdleTimeout: rest.DefaultIdleTimeout, + MaxRequestSize: routes.DefaultMaxRequestSize, }, MaxMsgSize: grpcutils.DefaultMaxMsgSize, CompressorName: grpcutils.NoCompressor, @@ -621,6 +622,10 @@ func (builder *ObserverServiceBuilder) extraFlags() { defaultConfig.rpcConf.RestConfig.ReadTimeout, "timeout to use when reading REST request headers") flags.DurationVar(&builder.rpcConf.RestConfig.IdleTimeout, "rest-idle-timeout", defaultConfig.rpcConf.RestConfig.IdleTimeout, "idle timeout for REST connections") + flags.Int64Var(&builder.rpcConf.RestConfig.MaxRequestSize, + "rest-max-request-size", + defaultConfig.rpcConf.RestConfig.MaxRequestSize, + "the maximum request size in bytes for payload sent over REST server") flags.UintVar(&builder.rpcConf.MaxMsgSize, "rpc-max-message-size", defaultConfig.rpcConf.MaxMsgSize, @@ -851,6 +856,10 @@ func (builder *ObserverServiceBuilder) extraFlags() { } } + if builder.rpcConf.RestConfig.MaxRequestSize <= 0 { + return errors.New("rest-max-request-size must be greater than 0") + } + return nil }) } diff --git a/engine/access/rest/routes/handler.go b/engine/access/rest/routes/handler.go index 2779fe32699..6c05266a8f5 100644 --- a/engine/access/rest/routes/handler.go +++ b/engine/access/rest/routes/handler.go @@ -36,12 +36,13 @@ func NewHandler( handlerFunc ApiHandlerFunc, generator models.LinkGenerator, chain flow.Chain, + maxRequestSize int64, ) *Handler { handler := &Handler{ backend: backend, apiHandlerFunc: handlerFunc, linkGenerator: generator, - HttpHandler: NewHttpHandler(logger, chain), + HttpHandler: NewHttpHandler(logger, chain, maxRequestSize), } return handler diff --git a/engine/access/rest/routes/http_handler.go b/engine/access/rest/routes/http_handler.go index f6a190ba0ad..35dbffd52fd 100644 --- a/engine/access/rest/routes/http_handler.go +++ b/engine/access/rest/routes/http_handler.go @@ -16,7 +16,7 @@ import ( "github.com/onflow/flow-go/model/flow" ) -const MaxRequestSize = 2 << 20 // 2MB +const DefaultMaxRequestSize = 2 << 20 // 2MB // HttpHandler is custom http handler implementing custom handler function. // HttpHandler function allows easier handling of errors and responses as it @@ -24,15 +24,19 @@ const MaxRequestSize = 2 << 20 // 2MB type HttpHandler struct { Logger zerolog.Logger Chain flow.Chain + + MaxRequestSize int64 } func NewHttpHandler( logger zerolog.Logger, chain flow.Chain, + maxRequestSize int64, ) *HttpHandler { return &HttpHandler{ - Logger: logger, - Chain: chain, + Logger: logger, + Chain: chain, + MaxRequestSize: maxRequestSize, } } @@ -43,7 +47,7 @@ func (h *HttpHandler) VerifyRequest(w http.ResponseWriter, r *http.Request) erro errLog := h.Logger.With().Str("request_url", r.URL.String()).Logger() // limit requested body size - r.Body = http.MaxBytesReader(w, r.Body, MaxRequestSize) + r.Body = http.MaxBytesReader(w, r.Body, h.MaxRequestSize) err := r.ParseForm() if err != nil { h.errorHandler(w, err, errLog) diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index 57e505d7497..90092c3c4c7 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -46,10 +46,14 @@ func NewRouterBuilder( } // AddRestRoutes adds rest routes to the router. -func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *RouterBuilder { +func (b *RouterBuilder) AddRestRoutes( + backend access.API, + chain flow.Chain, + maxRequestSize int64, +) *RouterBuilder { linkGenerator := models.NewLinkGeneratorImpl(b.v1SubRouter) for _, r := range Routes { - h := NewHandler(b.logger, backend, r.Handler, linkGenerator, chain) + h := NewHandler(b.logger, backend, r.Handler, linkGenerator, chain, maxRequestSize) b.v1SubRouter. Methods(r.Method). Path(r.Pattern). @@ -64,10 +68,11 @@ func (b *RouterBuilder) AddWsRoutes( stateStreamApi state_stream.API, chain flow.Chain, stateStreamConfig backend.Config, + maxRequestSize int64, ) *RouterBuilder { for _, r := range WSRoutes { - h := NewWSHandler(b.logger, stateStreamApi, r.Handler, chain, stateStreamConfig) + h := NewWSHandler(b.logger, stateStreamApi, r.Handler, chain, stateStreamConfig, maxRequestSize) b.v1SubRouter. Methods(r.Method). Path(r.Pattern). diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index feae66f5bf9..8053bf9d356 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -126,6 +126,7 @@ func executeRequest(req *http.Request, backend access.API) *httptest.ResponseRec ).AddRestRoutes( backend, flow.Testnet.Chain(), + DefaultMaxRequestSize, ).Build() rr := httptest.NewRecorder() @@ -144,7 +145,7 @@ func executeWsRequest(req *http.Request, stateStreamApi state_stream.API, respon router := NewRouterBuilder(unittest.Logger(), restCollector).AddWsRoutes( stateStreamApi, - chain, config).Build() + chain, config, DefaultMaxRequestSize).Build() router.ServeHTTP(responseRecorder, req) } diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index f2261baa76f..5e680e6e3ed 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -253,6 +253,7 @@ func NewWSHandler( subscribeFunc SubscribeHandlerFunc, chain flow.Chain, stateStreamConfig backend.Config, + maxRequestSize int64, ) *WSHandler { handler := &WSHandler{ subscribeFunc: subscribeFunc, @@ -261,7 +262,7 @@ func NewWSHandler( maxStreams: int32(stateStreamConfig.MaxGlobalStreams), defaultHeartbeatInterval: stateStreamConfig.HeartbeatInterval, activeStreamCount: atomic.NewInt32(0), - HttpHandler: NewHttpHandler(logger, chain), + HttpHandler: NewHttpHandler(logger, chain, maxRequestSize), } return handler diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 0d05fcd67cf..a33e2e24e58 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -27,10 +27,11 @@ const ( ) type Config struct { - ListenAddress string - WriteTimeout time.Duration - ReadTimeout time.Duration - IdleTimeout time.Duration + ListenAddress string + WriteTimeout time.Duration + ReadTimeout time.Duration + IdleTimeout time.Duration + MaxRequestSize int64 } // NewServer returns an HTTP server initialized with the REST API handler @@ -42,9 +43,9 @@ func NewServer(serverAPI access.API, stateStreamApi state_stream.API, stateStreamConfig backend.Config, ) (*http.Server, error) { - builder := routes.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain) + builder := routes.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain, config.MaxRequestSize) if stateStreamApi != nil { - builder.AddWsRoutes(stateStreamApi, chain, stateStreamConfig) + builder.AddWsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) } c := cors.New(cors.Options{ diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index 96c6aadf150..2f28c8bda48 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -424,7 +424,7 @@ func (suite *RestAPITestSuite) TestRequestSizeRestriction() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() // make a request of size larger than the max permitted size - requestBytes := make([]byte, routes.MaxRequestSize+1) + requestBytes := make([]byte, routes.DefaultMaxRequestSize+1) script := restclient.ScriptsBody{ Script: string(requestBytes), }