From bf4cc9a27868da214f6af2df3bbea22440c39c87 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 8 Nov 2024 09:28:17 +0200 Subject: [PATCH 01/10] Add new websocket handler and skeleton for its deps * Added websocket controller * Added mock block provider * Added data provider factory * Added websocket handler * Added websocket config * Added a tiny POC test for websocket handler --- cmd/observer/node_builder/observer_builder.go | 3 + cmd/util/cmd/run-script/cmd.go | 2 + .../access/handle_irrecoverable_state_test.go | 2 + .../integration_unsecure_grpc_server_test.go | 2 + engine/access/rest/router/router.go | 26 ++- .../access/rest/router/router_test_helpers.go | 4 +- engine/access/rest/server.go | 6 +- engine/access/rest/websockets/config.go | 19 ++ engine/access/rest/websockets/controller.go | 166 ++++++++++++++++++ .../rest/websockets/data_provider/blocks.go | 61 +++++++ .../rest/websockets/data_provider/factory.go | 33 ++++ .../rest/websockets/data_provider/provider.go | 12 ++ engine/access/rest/websockets/handler.go | 63 +++++++ engine/access/rest/websockets/handler_test.go | 85 +++++++++ .../legacy/routes/subscribe_events_test.go | 10 +- engine/access/rest/websockets/models.go | 59 +++++++ .../access/rest/websockets/threadsafe_map.go | 55 ++++++ engine/access/rest_api_test.go | 2 + engine/access/rpc/engine.go | 25 ++- engine/access/rpc/rate_limit_test.go | 2 + engine/access/secure_grpcr_test.go | 2 + 21 files changed, 620 insertions(+), 19 deletions(-) create mode 100644 engine/access/rest/websockets/config.go create mode 100644 engine/access/rest/websockets/controller.go create mode 100644 engine/access/rest/websockets/data_provider/blocks.go create mode 100644 engine/access/rest/websockets/data_provider/factory.go create mode 100644 engine/access/rest/websockets/data_provider/provider.go create mode 100644 engine/access/rest/websockets/handler.go create mode 100644 engine/access/rest/websockets/handler_test.go create mode 100644 engine/access/rest/websockets/models.go create mode 100644 engine/access/rest/websockets/threadsafe_map.go diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 63721725711..4033e18830b 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -44,6 +44,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest" restapiproxy "github.com/onflow/flow-go/engine/access/rest/apiproxy" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" rpcConnection "github.com/onflow/flow-go/engine/access/rpc/connection" @@ -167,6 +168,7 @@ type ObserverServiceConfig struct { registerCacheSize uint programCacheSize uint registerDBPruneThreshold uint64 + websocketConfig websockets.Config } // DefaultObserverServiceConfig defines all the default values for the ObserverServiceConfig @@ -250,6 +252,7 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig { registerCacheSize: 0, programCacheSize: 0, registerDBPruneThreshold: pruner.DefaultThreshold, + websocketConfig: *websockets.NewDefaultWebsocketConfig(), } } diff --git a/cmd/util/cmd/run-script/cmd.go b/cmd/util/cmd/run-script/cmd.go index 1f24d2599c2..dc4d6e381a0 100644 --- a/cmd/util/cmd/run-script/cmd.go +++ b/cmd/util/cmd/run-script/cmd.go @@ -16,6 +16,7 @@ import ( "github.com/onflow/flow-go/cmd/util/ledger/util" "github.com/onflow/flow-go/cmd/util/ledger/util/registers" "github.com/onflow/flow-go/engine/access/rest" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/engine/access/subscription" "github.com/onflow/flow-go/engine/execution/computation" @@ -169,6 +170,7 @@ func run(*cobra.Command, []string) { metrics.NewNoopCollector(), nil, backend.Config{}, + *websockets.NewDefaultWebsocketConfig(), ) if err != nil { log.Fatal().Err(err).Msg("failed to create server") diff --git a/engine/access/handle_irrecoverable_state_test.go b/engine/access/handle_irrecoverable_state_test.go index 466e94090aa..e9db308e86c 100644 --- a/engine/access/handle_irrecoverable_state_test.go +++ b/engine/access/handle_irrecoverable_state_test.go @@ -22,6 +22,7 @@ import ( accessmock "github.com/onflow/flow-go/engine/access/mock" "github.com/onflow/flow-go/engine/access/rest" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -108,6 +109,7 @@ func (suite *IrrecoverableStateTestSuite) SetupTest() { RestConfig: rest.Config{ ListenAddress: unittest.DefaultAddress, }, + WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/integration_unsecure_grpc_server_test.go b/engine/access/integration_unsecure_grpc_server_test.go index f99805687ba..98de205ad66 100644 --- a/engine/access/integration_unsecure_grpc_server_test.go +++ b/engine/access/integration_unsecure_grpc_server_test.go @@ -21,6 +21,7 @@ import ( "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/engine/access/index" accessmock "github.com/onflow/flow-go/engine/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" "github.com/onflow/flow-go/engine/access/state_stream" @@ -138,6 +139,7 @@ func (suite *SameGRPCPortTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, + WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), } blockCount := 5 diff --git a/engine/access/rest/router/router.go b/engine/access/rest/router/router.go index c623669d916..74f34f8ff7f 100644 --- a/engine/access/rest/router/router.go +++ b/engine/access/rest/router/router.go @@ -2,6 +2,7 @@ package router import ( "fmt" + "net/http" "regexp" "strings" @@ -10,8 +11,9 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/common/middleware" - "github.com/onflow/flow-go/engine/access/rest/http" + flowhttp "github.com/onflow/flow-go/engine/access/rest/http" "github.com/onflow/flow-go/engine/access/rest/http/models" + "github.com/onflow/flow-go/engine/access/rest/websockets" legacyws "github.com/onflow/flow-go/engine/access/rest/websockets/legacy" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -50,7 +52,7 @@ func NewRouterBuilder( func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *RouterBuilder { linkGenerator := models.NewLinkGeneratorImpl(b.v1SubRouter) for _, r := range Routes { - h := http.NewHandler(b.logger, backend, r.Handler, linkGenerator, chain) + h := flowhttp.NewHandler(b.logger, backend, r.Handler, linkGenerator, chain) b.v1SubRouter. Methods(r.Method). Path(r.Pattern). @@ -60,8 +62,8 @@ func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *Rou return b } -// AddWsRoutes adds WebSocket routes to the router. -func (b *RouterBuilder) AddWsRoutes( +// AddLegacyWebsocketsRoutes adds WebSocket routes to the router. +func (b *RouterBuilder) AddLegacyWebsocketsRoutes( stateStreamApi state_stream.API, chain flow.Chain, stateStreamConfig backend.Config, @@ -79,6 +81,22 @@ func (b *RouterBuilder) AddWsRoutes( return b } +func (b *RouterBuilder) AddWebsocketsRoute( + chain flow.Chain, + config *websockets.Config, + streamApi state_stream.API, + streamConfig backend.Config, +) *RouterBuilder { + handler := websockets.NewWebSocketHandler(b.logger, config, chain, streamApi, streamConfig) + b.v1SubRouter. + Methods(http.MethodGet). + Path("/ws"). + Name("ws"). + Handler(handler) + + return b +} + func (b *RouterBuilder) Build() *mux.Router { return b.router } diff --git a/engine/access/rest/router/router_test_helpers.go b/engine/access/rest/router/router_test_helpers.go index 0256e529457..68c46df34f3 100644 --- a/engine/access/rest/router/router_test_helpers.go +++ b/engine/access/rest/router/router_test_helpers.go @@ -133,7 +133,7 @@ func ExecuteRequest(req *http.Request, backend access.API) *httptest.ResponseRec return rr } -func ExecuteWsRequest(req *http.Request, stateStreamApi state_stream.API, responseRecorder *TestHijackResponseRecorder, chain flow.Chain) { +func ExecuteLegacyWsRequest(req *http.Request, stateStreamApi state_stream.API, responseRecorder *TestHijackResponseRecorder, chain flow.Chain) { restCollector := metrics.NewNoopCollector() config := backend.Config{ @@ -145,7 +145,7 @@ func ExecuteWsRequest(req *http.Request, stateStreamApi state_stream.API, respon router := NewRouterBuilder( unittest.Logger(), restCollector, - ).AddWsRoutes( + ).AddLegacyWebsocketsRoutes( stateStreamApi, chain, config, ).Build() diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index caed80a27c3..d74c8e361ef 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -9,6 +9,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/model/flow" @@ -41,12 +42,15 @@ func NewServer(serverAPI access.API, restCollector module.RestMetrics, stateStreamApi state_stream.API, stateStreamConfig backend.Config, + wsConfig websockets.Config, ) (*http.Server, error) { builder := router.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain) if stateStreamApi != nil { - builder.AddWsRoutes(stateStreamApi, chain, stateStreamConfig) + builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig) } + builder.AddWebsocketsRoute(chain, &wsConfig, stateStreamApi, stateStreamConfig) + c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedHeaders: []string{"*"}, diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go new file mode 100644 index 00000000000..8354ca4c11d --- /dev/null +++ b/engine/access/rest/websockets/config.go @@ -0,0 +1,19 @@ +package websockets + +import ( + "time" +) + +type Config struct { + MaxSubscriptionsPerConnection uint64 + MaxResponsesPerSecond uint64 + SendMessageTimeout time.Duration +} + +func NewDefaultWebsocketConfig() *Config { + return &Config{ + MaxSubscriptionsPerConnection: 1000, + MaxResponsesPerSecond: 1000, + SendMessageTimeout: 10 * time.Second, + } +} diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go new file mode 100644 index 00000000000..6d1c6f6416b --- /dev/null +++ b/engine/access/rest/websockets/controller.go @@ -0,0 +1,166 @@ +package websockets + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/rs/zerolog" + + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" +) + +type Controller struct { + ctx context.Context + logger zerolog.Logger + config *Config + conn *websocket.Conn + communicationChannel chan interface{} + dataProviders *ThreadSafeMap[uuid.UUID, dp.DataProvider] + dataProvidersFactory *dp.Factory +} + +func NewWebSocketController( + ctx context.Context, + logger zerolog.Logger, + config *Config, + streamApi state_stream.API, + streamConfig backend.Config, + conn *websocket.Conn, +) *Controller { + return &Controller{ + ctx: ctx, + logger: logger.With().Str("component", "websocket-controller").Logger(), + config: config, + conn: conn, + communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? + dataProviders: NewThreadSafeMap[uuid.UUID, dp.DataProvider](), + dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), + } +} + +// HandleConnection manages the WebSocket connection, adding context and error handling. +func (c *Controller) HandleConnection() { + //TODO: configure the connection with ping-pong and deadlines + + go c.readMessagesFromClient(c.ctx) + go c.writeMessagesToClient(c.ctx) +} + +func (c *Controller) writeMessagesToClient(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case msg := <-c.communicationChannel: + // TODO: handle 'response per second' limits + c.conn.WriteJSON(msg) + } + } +} + +func (c *Controller) readMessagesFromClient(ctx context.Context) { + defer close(c.communicationChannel) + defer c.conn.Close() + + for { + select { + case <-ctx.Done(): + c.logger.Info().Msg("context canceled, stopping read message loop") + return + default: + msg, err := c.readMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { + return + } + c.logger.Warn().Err(err).Msg("error reading message from client") + return + } + + baseMsg, err := c.parseMessage(msg) + if err != nil { + c.logger.Warn().Err(err).Msg("error parsing base message") + return + } + + if err := c.dispatchAction(baseMsg.Action, msg); err != nil { + c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") + } + } + } +} + +func (c *Controller) readMessage() (json.RawMessage, error) { + var message json.RawMessage + if err := c.conn.ReadJSON(&message); err != nil { + return nil, fmt.Errorf("error reading JSON from client: %w", err) + } + return message, nil +} + +func (c *Controller) parseMessage(message json.RawMessage) (BaseMessageRequest, error) { + var baseMsg BaseMessageRequest + if err := json.Unmarshal(message, &baseMsg); err != nil { + return BaseMessageRequest{}, fmt.Errorf("error unmarshalling base message: %w", err) + } + return baseMsg, nil +} + +// dispatchAction routes the action to the appropriate handler based on the action type. +func (c *Controller) dispatchAction(action string, message json.RawMessage) error { + switch action { + case "subscribe": + var subscribeMsg SubscribeMessageRequest + if err := json.Unmarshal(message, &subscribeMsg); err != nil { + return fmt.Errorf("error unmarshalling subscribe message: %w", err) + } + c.handleSubscribe(subscribeMsg) + + case "unsubscribe": + var unsubscribeMsg UnsubscribeMessageRequest + if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { + return fmt.Errorf("error unmarshalling unsubscribe message: %w", err) + } + c.handleUnsubscribe(unsubscribeMsg) + + case "list_subscriptions": + var listMsg ListSubscriptionsMessageRequest + if err := json.Unmarshal(message, &listMsg); err != nil { + return fmt.Errorf("error unmarshalling list subscriptions message: %w", err) + } + c.handleListSubscriptions(listMsg) + + default: + c.logger.Warn().Str("action", action).Msg("unknown action type") + return fmt.Errorf("unknown action type: %s", action) + } + return nil +} + +func (c *Controller) handleSubscribe(msg SubscribeMessageRequest) { + dp := c.dataProvidersFactory.NewDataProvider(c.ctx, c.communicationChannel, msg.Topic) + c.dataProviders.Insert(dp.ID(), dp) + dp.Run() +} + +func (c *Controller) handleUnsubscribe(msg UnsubscribeMessageRequest) { + id, err := uuid.Parse(msg.ID) + if err != nil { + c.logger.Warn().Err(err).Str("topic", msg.Topic).Msg("error parsing message ID") + return + } + + dp, ok := c.dataProviders.Get(id) + if ok { + dp.Close() + c.dataProviders.Remove(id) + } +} + +func (c *Controller) handleListSubscriptions(msg ListSubscriptionsMessageRequest) {} diff --git a/engine/access/rest/websockets/data_provider/blocks.go b/engine/access/rest/websockets/data_provider/blocks.go new file mode 100644 index 00000000000..4c23bd4b587 --- /dev/null +++ b/engine/access/rest/websockets/data_provider/blocks.go @@ -0,0 +1,61 @@ +package data_provider + +import ( + "context" + + "github.com/google/uuid" + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/state_stream" +) + +type MockBlockProvider struct { + id uuid.UUID + ch chan<- interface{} + topic string + logger zerolog.Logger + ctx context.Context + stopProviderFunc context.CancelFunc + streamApi state_stream.API +} + +func NewMockBlockProvider( + ctx context.Context, + ch chan<- interface{}, + topic string, + logger zerolog.Logger, + streamApi state_stream.API, +) *MockBlockProvider { + ctx, cancel := context.WithCancel(ctx) + return &MockBlockProvider{ + id: uuid.New(), + ch: ch, + topic: topic, + logger: logger.With().Str("component", "block-provider").Logger(), + ctx: ctx, + stopProviderFunc: cancel, + streamApi: streamApi, + } +} + +func (p *MockBlockProvider) Run() { + select { + case <-p.ctx.Done(): + return + default: + p.ch <- "hello" + p.ch <- "world" + } +} + +func (p *MockBlockProvider) ID() uuid.UUID { + return p.id +} + +func (p *MockBlockProvider) Topic() string { + return p.topic +} + +func (p *MockBlockProvider) Close() { + p.stopProviderFunc() +} diff --git a/engine/access/rest/websockets/data_provider/factory.go b/engine/access/rest/websockets/data_provider/factory.go new file mode 100644 index 00000000000..86d69475377 --- /dev/null +++ b/engine/access/rest/websockets/data_provider/factory.go @@ -0,0 +1,33 @@ +package data_provider + +import ( + "context" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" +) + +type Factory struct { + logger zerolog.Logger + streamApi state_stream.API + streamConfig backend.Config +} + +func NewDataProviderFactory(logger zerolog.Logger, streamApi state_stream.API, streamConfig backend.Config) *Factory { + return &Factory{ + logger: logger, + streamApi: streamApi, + streamConfig: streamConfig, + } +} + +func (f *Factory) NewDataProvider(ctx context.Context, ch chan<- interface{}, topic string) DataProvider { + switch topic { + case "blocks": + return NewMockBlockProvider(ctx, ch, topic, f.logger, f.streamApi) + default: + return nil + } +} diff --git a/engine/access/rest/websockets/data_provider/provider.go b/engine/access/rest/websockets/data_provider/provider.go new file mode 100644 index 00000000000..e919af590b6 --- /dev/null +++ b/engine/access/rest/websockets/data_provider/provider.go @@ -0,0 +1,12 @@ +package data_provider + +import ( + "github.com/google/uuid" +) + +type DataProvider interface { + Run() + ID() uuid.UUID + Topic() string + Close() +} diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go new file mode 100644 index 00000000000..09fd537bb02 --- /dev/null +++ b/engine/access/rest/websockets/handler.go @@ -0,0 +1,63 @@ +package websockets + +import ( + "context" + "net/http" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/model/flow" +) + +type Handler struct { + *common.BaseHttpHandler + + logger zerolog.Logger + websocketConfig *Config + streamApi state_stream.API + streamConfig backend.Config +} + +var _ http.Handler = (*Handler)(nil) + +func NewWebSocketHandler(logger zerolog.Logger, config *Config, chain flow.Chain, streamApi state_stream.API, streamConfig backend.Config) *Handler { + return &Handler{ + BaseHttpHandler: common.NewHttpHandler(logger, chain), + websocketConfig: config, + logger: logger, + streamApi: streamApi, + streamConfig: streamConfig, + } +} +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + //TODO: change to accept topic instead of URL + logger := h.BaseHttpHandler.Logger.With().Str("websocket_subscribe_url", r.URL.String()).Logger() + + err := h.BaseHttpHandler.VerifyRequest(w, r) + if err != nil { + // VerifyRequest sets the response error before returning + logger.Warn().Err(err).Msg("error validating websocket request") + return + } + + upgrader := websocket.Upgrader{ + // allow all origins by default, operators can override using a proxy + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.BaseHttpHandler.ErrorHandler(w, common.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) + return + } + + ctx := context.Background() + controller := NewWebSocketController(ctx, logger, h.websocketConfig, h.streamApi, h.streamConfig, conn) + controller.HandleConnection() +} diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go new file mode 100644 index 00000000000..124fd2f8a00 --- /dev/null +++ b/engine/access/rest/websockets/handler_test.go @@ -0,0 +1,85 @@ +package websockets_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/websockets" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" + + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +var ( + chainID = flow.Testnet +) + +type WsHandlerSuite struct { + suite.Suite + + logger zerolog.Logger + handler *websockets.Handler +} + +func (s *WsHandlerSuite) SetupTest() { + s.logger = unittest.Logger() + wsConfig := websockets.NewDefaultWebsocketConfig() + streamApi := streammock.NewAPI(s.T()) + streamConfig := backend.Config{} + s.handler = websockets.NewWebSocketHandler(s.logger, wsConfig, chainID.Chain(), streamApi, streamConfig) +} + +func TestWsHandlerSuite(t *testing.T) { + suite.Run(t, new(WsHandlerSuite)) +} + +func ClientConnection(url string) (*websocket.Conn, *http.Response, error) { + wsURL := "ws" + strings.TrimPrefix(url, "http") + return websocket.DefaultDialer.Dial(wsURL, nil) +} + +func (s *WsHandlerSuite) TestSubscribeRequest() { + s.Run("Happy path", func() { + server := httptest.NewServer(s.handler) + defer server.Close() + + conn, _, err := ClientConnection(server.URL) + require.NoError(s.T(), err) + + args := map[string]interface{}{ + "start_block_height": 10, + } + body := websockets.SubscribeMessageRequest{ + BaseMessageRequest: websockets.BaseMessageRequest{Action: "subscribe"}, + Topic: "blocks", + Arguments: args, + } + bodyJSON, err := json.Marshal(body) + require.NoError(s.T(), err) + + err = conn.WriteMessage(websocket.TextMessage, bodyJSON) + require.NoError(s.T(), err) + + _, msg, err := conn.ReadMessage() + require.NoError(s.T(), err) + + actualMsg := strings.Trim(string(msg), "\n\"\\ ") + require.Equal(s.T(), "hello", actualMsg) + + _, msg, err = conn.ReadMessage() + require.NoError(s.T(), err) + + actualMsg = strings.Trim(string(msg), "\n\"\\ ") + require.Equal(s.T(), "world", actualMsg) + }) +} diff --git a/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go b/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go index c4353cecae2..a423bd4622f 100644 --- a/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go +++ b/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go @@ -252,7 +252,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { time.Sleep(1 * time.Second) respRecorder.Close() }() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireResponse(s.T(), respRecorder, expectedEventsResponses) }) } @@ -264,7 +264,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "can only provide either block ID or start height") }) @@ -289,7 +289,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "stream encountered an error: subscription error") }) @@ -298,7 +298,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, []string{"foo"}, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "invalid event type format") }) @@ -323,7 +323,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "subscription channel closed") }) } diff --git a/engine/access/rest/websockets/models.go b/engine/access/rest/websockets/models.go new file mode 100644 index 00000000000..42abb8b7241 --- /dev/null +++ b/engine/access/rest/websockets/models.go @@ -0,0 +1,59 @@ +package websockets + +// BaseMessageRequest represents a base structure for incoming messages. +type BaseMessageRequest struct { + Action string `json:"action"` // Action type of the request +} + +// BaseMessageResponse represents a base structure for outgoing messages. +type BaseMessageResponse struct { + Action string `json:"action,omitempty"` // Action type of the response + Success bool `json:"success"` // Indicates success or failure + ErrorMessage string `json:"error_message,omitempty"` // Error message, if any +} + +// SubscribeMessageRequest represents a request to subscribe to a topic. +type SubscribeMessageRequest struct { + BaseMessageRequest + Topic string `json:"topic"` // Topic to subscribe to + Arguments map[string]interface{} `json:"arguments"` // Additional arguments for subscription +} + +// SubscribeMessageResponse represents the response to a subscription request. +type SubscribeMessageResponse struct { + BaseMessageResponse + Topic string `json:"topic"` // Topic of the subscription + ID string `json:"id"` // Unique subscription ID +} + +// UnsubscribeMessageRequest represents a request to unsubscribe from a topic. +type UnsubscribeMessageRequest struct { + BaseMessageRequest + Topic string `json:"topic"` // Topic to unsubscribe from + ID string `json:"id"` // Unique subscription ID +} + +// UnsubscribeMessageResponse represents the response to an unsubscription request. +type UnsubscribeMessageResponse struct { + BaseMessageResponse + Topic string `json:"topic"` // Topic of the unsubscription + ID string `json:"id"` // Unique subscription ID +} + +// ListSubscriptionsMessageRequest represents a request to list active subscriptions. +type ListSubscriptionsMessageRequest struct { + BaseMessageRequest +} + +// SubscriptionEntry represents an active subscription entry. +type SubscriptionEntry struct { + Topic string `json:"topic,omitempty"` // Topic of the subscription + ID string `json:"id,omitempty"` // Unique subscription ID +} + +// ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. +// It contains a list of active subscriptions for the current WebSocket connection. +type ListSubscriptionsMessageResponse struct { + BaseMessageResponse + Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` +} diff --git a/engine/access/rest/websockets/threadsafe_map.go b/engine/access/rest/websockets/threadsafe_map.go new file mode 100644 index 00000000000..3ab265a40fb --- /dev/null +++ b/engine/access/rest/websockets/threadsafe_map.go @@ -0,0 +1,55 @@ +package websockets + +import ( + "sync" +) + +// ThreadSafeMap is a thread-safe map with read-write locking. +type ThreadSafeMap[K comparable, V any] struct { + mu sync.RWMutex + m map[K]V +} + +// NewThreadSafeMap initializes a new ThreadSafeMap. +func NewThreadSafeMap[K comparable, V any]() *ThreadSafeMap[K, V] { + return &ThreadSafeMap[K, V]{ + m: make(map[K]V), + } +} + +// Get retrieves a value for a key, returning the value and a boolean indicating if the key exists. +func (s *ThreadSafeMap[K, V]) Get(key K) (V, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + value, ok := s.m[key] + return value, ok +} + +// Insert inserts or updates a value for a key. +func (s *ThreadSafeMap[K, V]) Insert(key K, value V) { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = value +} + +// Remove removes a key and its value from the map. +func (s *ThreadSafeMap[K, V]) Remove(key K) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.m, key) +} + +// Exists checks if a key exists in the map. +func (s *ThreadSafeMap[K, V]) Exists(key K) bool { + s.mu.RLock() + defer s.mu.RUnlock() + _, ok := s.m[key] + return ok +} + +// Len returns the number of elements in the map. +func (s *ThreadSafeMap[K, V]) Len() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.m) +} diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index 5d48f6091e4..6c68d3c0553 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -23,6 +23,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest" "github.com/onflow/flow-go/engine/access/rest/common" "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -136,6 +137,7 @@ func (suite *RestAPITestSuite) SetupTest() { RestConfig: rest.Config{ ListenAddress: unittest.DefaultAddress, }, + WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/rpc/engine.go b/engine/access/rpc/engine.go index 145e3d62143..37b60b1a4d3 100644 --- a/engine/access/rpc/engine.go +++ b/engine/access/rpc/engine.go @@ -14,6 +14,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/consensus/hotstuff/model" "github.com/onflow/flow-go/engine/access/rest" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc/backend" "github.com/onflow/flow-go/engine/access/state_stream" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -38,10 +39,11 @@ type Config struct { CollectionAddr string // the address of the upstream collection node HistoricalAccessAddrs string // the list of all access nodes from previous spork - BackendConfig backend.Config // configurable options for creating Backend - RestConfig rest.Config // the REST server configuration - MaxMsgSize uint // GRPC max message size - CompressorName string // GRPC compressor name + BackendConfig backend.Config // configurable options for creating Backend + RestConfig rest.Config // the REST server configuration + MaxMsgSize uint // GRPC max message size + CompressorName string // GRPC compressor name + WebSocketConfig websockets.Config } // Engine exposes the server with a simplified version of the Access API. @@ -75,7 +77,8 @@ type Engine struct { type Option func(*RPCEngineBuilder) // NewBuilder returns a new RPC engine builder. -func NewBuilder(log zerolog.Logger, +func NewBuilder( + log zerolog.Logger, state protocol.State, config Config, chainID flow.ChainID, @@ -240,8 +243,16 @@ func (e *Engine) serveREST(ctx irrecoverable.SignalerContext, ready component.Re e.log.Info().Str("rest_api_address", e.config.RestConfig.ListenAddress).Msg("starting REST server on address") - r, err := rest.NewServer(e.restHandler, e.config.RestConfig, e.log, e.chain, e.restCollector, e.stateStreamBackend, - e.stateStreamConfig) + r, err := rest.NewServer( + e.restHandler, + e.config.RestConfig, + e.log, + e.chain, + e.restCollector, + e.stateStreamBackend, + e.stateStreamConfig, + e.config.WebSocketConfig, + ) if err != nil { e.log.Err(err).Msg("failed to initialize the REST server") ctx.Throw(err) diff --git a/engine/access/rpc/rate_limit_test.go b/engine/access/rpc/rate_limit_test.go index 622b06e3f54..e4f923e98fb 100644 --- a/engine/access/rpc/rate_limit_test.go +++ b/engine/access/rpc/rate_limit_test.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc/status" accessmock "github.com/onflow/flow-go/engine/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/model/flow" @@ -115,6 +116,7 @@ func (suite *RateLimitTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, + WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/secure_grpcr_test.go b/engine/access/secure_grpcr_test.go index cc1d1a75cc8..aa92c5db052 100644 --- a/engine/access/secure_grpcr_test.go +++ b/engine/access/secure_grpcr_test.go @@ -19,6 +19,7 @@ import ( "github.com/onflow/crypto" accessmock "github.com/onflow/flow-go/engine/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -110,6 +111,7 @@ func (suite *SecureGRPCTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, + WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server From b76c811c6def23c11700085b1bddb650ab1c4838 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 8 Nov 2024 10:39:33 +0200 Subject: [PATCH 02/10] fix issue after merge --- engine/access/rest/server.go | 2 +- engine/access/rest/websockets/config.go | 2 ++ engine/access/rest/websockets/handler_test.go | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 9fa07e63ff4..f23a683f39d 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -50,7 +50,7 @@ func NewServer(serverAPI access.API, builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) } - builder.AddWebsocketsRoute(chain, &wsConfig, stateStreamApi, stateStreamConfig) + builder.AddWebsocketsRoute(chain, &wsConfig, stateStreamApi, stateStreamConfig, config.MaxRequestSize) c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go index 8354ca4c11d..13138e54539 100644 --- a/engine/access/rest/websockets/config.go +++ b/engine/access/rest/websockets/config.go @@ -8,6 +8,7 @@ type Config struct { MaxSubscriptionsPerConnection uint64 MaxResponsesPerSecond uint64 SendMessageTimeout time.Duration + MaxRequestSize int64 } func NewDefaultWebsocketConfig() *Config { @@ -15,5 +16,6 @@ func NewDefaultWebsocketConfig() *Config { MaxSubscriptionsPerConnection: 1000, MaxResponsesPerSecond: 1000, SendMessageTimeout: 10 * time.Second, + MaxRequestSize: 1024, } } diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go index 41079159e40..396cfd00dac 100644 --- a/engine/access/rest/websockets/handler_test.go +++ b/engine/access/rest/websockets/handler_test.go @@ -15,7 +15,6 @@ import ( "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream/backend" streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) From f88cf9bb827b2d8da8de3525294ab6f040be1ad9 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 8 Nov 2024 11:11:33 +0200 Subject: [PATCH 03/10] generate mocks. add graceful shutdown for controller --- Makefile | 1 + engine/access/rest/websockets/controller.go | 20 ++++- .../data_provider/mock/data_provider.go | 75 +++++++++++++++++++ engine/access/rest/websockets/handler_test.go | 15 ++-- .../access/rest/websockets/threadsafe_map.go | 15 ++++ 5 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 engine/access/rest/websockets/data_provider/mock/data_provider.go diff --git a/Makefile b/Makefile index d0557991462..bc14ed50f9f 100644 --- a/Makefile +++ b/Makefile @@ -203,6 +203,7 @@ generate-mocks: install-mock-generators mockery --name 'API' --dir="./engine/protocol" --case=underscore --output="./engine/protocol/mock" --outpkg="mock" mockery --name '.*' --dir="./engine/access/state_stream" --case=underscore --output="./engine/access/state_stream/mock" --outpkg="mock" mockery --name 'BlockTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" + mockery --name 'DataProvider' --dir="./engine/access/rest/websockets/data_provider" --case=underscore --output="./engine/access/rest/websockets/data_provider/mock" --outpkg="mock" mockery --name 'ExecutionDataTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'ConnectionFactory' --dir="./engine/access/rpc/connection" --case=underscore --output="./engine/access/rpc/connection/mock" --outpkg="mock" mockery --name 'Communicator' --dir="./engine/access/rpc/backend" --case=underscore --output="./engine/access/rpc/backend/mock" --outpkg="mock" diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 6d1c6f6416b..751e4ec82a1 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -59,14 +59,17 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) { return case msg := <-c.communicationChannel: // TODO: handle 'response per second' limits - c.conn.WriteJSON(msg) + + err := c.conn.WriteJSON(msg) + if err != nil { + c.logger.Error().Err(err).Msg("error writing to connection") + } } } } func (c *Controller) readMessagesFromClient(ctx context.Context) { - defer close(c.communicationChannel) - defer c.conn.Close() + defer c.shutdownConnection() for { select { @@ -164,3 +167,14 @@ func (c *Controller) handleUnsubscribe(msg UnsubscribeMessageRequest) { } func (c *Controller) handleListSubscriptions(msg ListSubscriptionsMessageRequest) {} + +func (c *Controller) shutdownConnection() { + defer c.conn.Close() + defer close(c.communicationChannel) + + c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) { + dp.Close() + }) + + c.dataProviders.Clear() +} diff --git a/engine/access/rest/websockets/data_provider/mock/data_provider.go b/engine/access/rest/websockets/data_provider/mock/data_provider.go new file mode 100644 index 00000000000..6a4aab6e130 --- /dev/null +++ b/engine/access/rest/websockets/data_provider/mock/data_provider.go @@ -0,0 +1,75 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + uuid "github.com/google/uuid" + mock "github.com/stretchr/testify/mock" +) + +// DataProvider is an autogenerated mock type for the DataProvider type +type DataProvider struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *DataProvider) Close() { + _m.Called() +} + +// ID provides a mock function with given fields: +func (_m *DataProvider) ID() uuid.UUID { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ID") + } + + var r0 uuid.UUID + if rf, ok := ret.Get(0).(func() uuid.UUID); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(uuid.UUID) + } + } + + return r0 +} + +// Run provides a mock function with given fields: +func (_m *DataProvider) Run() { + _m.Called() +} + +// Topic provides a mock function with given fields: +func (_m *DataProvider) Topic() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Topic") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// NewDataProvider creates a new instance of DataProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDataProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *DataProvider { + mock := &DataProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go index 396cfd00dac..80fd0cdb377 100644 --- a/engine/access/rest/websockets/handler_test.go +++ b/engine/access/rest/websockets/handler_test.go @@ -26,16 +26,19 @@ var ( type WsHandlerSuite struct { suite.Suite - logger zerolog.Logger - handler *websockets.Handler + logger zerolog.Logger + handler *websockets.Handler + wsConfig *websockets.Config + streamApi *streammock.API + streamConfig *backend.Config } func (s *WsHandlerSuite) SetupTest() { s.logger = unittest.Logger() - wsConfig := websockets.NewDefaultWebsocketConfig() - streamApi := streammock.NewAPI(s.T()) - streamConfig := backend.Config{} - s.handler = websockets.NewWebSocketHandler(s.logger, wsConfig, chainID.Chain(), streamApi, streamConfig, 1024) + s.wsConfig = websockets.NewDefaultWebsocketConfig() + s.streamApi = streammock.NewAPI(s.T()) + s.streamConfig = &backend.Config{} + s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, *s.streamConfig, 1024) } func TestWsHandlerSuite(t *testing.T) { diff --git a/engine/access/rest/websockets/threadsafe_map.go b/engine/access/rest/websockets/threadsafe_map.go index 3ab265a40fb..2c3f3438e40 100644 --- a/engine/access/rest/websockets/threadsafe_map.go +++ b/engine/access/rest/websockets/threadsafe_map.go @@ -53,3 +53,18 @@ func (s *ThreadSafeMap[K, V]) Len() int { defer s.mu.RUnlock() return len(s.m) } + +// ForEach applies a function to each key-value pair in the map. +func (s *ThreadSafeMap[K, V]) ForEach(f func(K, V)) { + s.mu.RLock() + defer s.mu.RUnlock() + for k, v := range s.m { + f(k, v) + } +} + +func (s *ThreadSafeMap[K, V]) Clear() { + s.mu.Lock() + defer s.mu.Unlock() + s.m = make(map[K]V) +} From 29380d0ab1466620e0a7f56f149d9a0c7aff3135 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 8 Nov 2024 11:29:11 +0200 Subject: [PATCH 04/10] check err when closing conn --- Makefile | 1 - engine/access/handle_irrecoverable_state_test.go | 2 +- engine/access/rest/websockets/controller.go | 8 +++++--- engine/access/rest/websockets/handler_test.go | 5 ++++- engine/access/rest_api_test.go | 2 +- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index bc14ed50f9f..e79dfc571f6 100644 --- a/Makefile +++ b/Makefile @@ -207,7 +207,6 @@ generate-mocks: install-mock-generators mockery --name 'ExecutionDataTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'ConnectionFactory' --dir="./engine/access/rpc/connection" --case=underscore --output="./engine/access/rpc/connection/mock" --outpkg="mock" mockery --name 'Communicator' --dir="./engine/access/rpc/backend" --case=underscore --output="./engine/access/rpc/backend/mock" --outpkg="mock" - mockery --name '.*' --dir=model/fingerprint --case=underscore --output="./model/fingerprint/mock" --outpkg="mock" mockery --name 'ExecForkActor' --structname 'ExecForkActorMock' --dir=module/mempool/consensus/mock/ --case=underscore --output="./module/mempool/consensus/mock/" --outpkg="mock" mockery --name '.*' --dir=engine/verification/fetcher/ --case=underscore --output="./engine/verification/fetcher/mock" --outpkg="mockfetcher" diff --git a/engine/access/handle_irrecoverable_state_test.go b/engine/access/handle_irrecoverable_state_test.go index 78a5cd33931..9ab1e58870b 100644 --- a/engine/access/handle_irrecoverable_state_test.go +++ b/engine/access/handle_irrecoverable_state_test.go @@ -22,8 +22,8 @@ import ( accessmock "github.com/onflow/flow-go/engine/access/mock" "github.com/onflow/flow-go/engine/access/rest" - "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 751e4ec82a1..a8c5e1275fc 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -7,7 +7,6 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" - "github.com/rs/zerolog" dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" @@ -169,12 +168,15 @@ func (c *Controller) handleUnsubscribe(msg UnsubscribeMessageRequest) { func (c *Controller) handleListSubscriptions(msg ListSubscriptionsMessageRequest) {} func (c *Controller) shutdownConnection() { - defer c.conn.Close() defer close(c.communicationChannel) + defer func(conn *websocket.Conn) { + if err := c.conn.Close(); err != nil { + c.logger.Error().Err(err).Msg("error closing connection") + } + }(c.conn) c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) { dp.Close() }) - c.dataProviders.Clear() } diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go index 80fd0cdb377..a7fb4f109ff 100644 --- a/engine/access/rest/websockets/handler_test.go +++ b/engine/access/rest/websockets/handler_test.go @@ -56,7 +56,10 @@ func (s *WsHandlerSuite) TestSubscribeRequest() { defer server.Close() conn, _, err := ClientConnection(server.URL) - defer conn.Close() + defer func(conn *websocket.Conn) { + err := conn.Close() + require.NoError(s.T(), err) + }(conn) require.NoError(s.T(), err) args := map[string]interface{}{ diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index 56a8dde971d..c0f50e9a2b1 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -23,8 +23,8 @@ import ( "github.com/onflow/flow-go/engine/access/rest" "github.com/onflow/flow-go/engine/access/rest/common" "github.com/onflow/flow-go/engine/access/rest/http/request" - "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" From b08370d2f86adcbfdb79b6a2f19595a57773d4b5 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Mon, 11 Nov 2024 17:51:12 +0200 Subject: [PATCH 05/10] Fixed comments * Use contexts as function arguments * Move models to folder for consistency * Change parse msg function * Simplify mock block provider to remove dedlock --- cmd/observer/node_builder/observer_builder.go | 2 +- cmd/util/cmd/run-script/cmd.go | 2 +- .../access/handle_irrecoverable_state_test.go | 2 +- .../integration_unsecure_grpc_server_test.go | 2 +- engine/access/rest/router/router.go | 2 +- engine/access/rest/server.go | 2 +- engine/access/rest/websockets/config.go | 4 +- engine/access/rest/websockets/controller.go | 98 +++++++++++-------- .../rest/websockets/data_provider/blocks.go | 5 +- .../rest/websockets/data_provider/provider.go | 4 +- engine/access/rest/websockets/handler.go | 9 +- engine/access/rest/websockets/handler_test.go | 21 ++-- engine/access/rest/websockets/models.go | 59 ----------- .../rest/websockets/models/base_message.go | 13 +++ .../websockets/models/list_subscriptions.go | 13 +++ .../rest/websockets/models/subscribe.go | 15 +++ .../websockets/models/subscription_entry.go | 7 ++ .../rest/websockets/models/unsubscribe.go | 13 +++ .../access/rest/websockets/threadsafe_map.go | 70 ------------- engine/access/rest_api_test.go | 2 +- engine/access/rpc/rate_limit_test.go | 2 +- engine/access/secure_grpcr_test.go | 2 +- engine/common/worker/worker_builder_test.go | 3 +- .../test/gossipsub/scoring/ihave_spam_test.go | 7 +- .../p2p/connection/connection_gater_test.go | 9 +- network/p2p/node/libp2pNode_test.go | 5 +- network/test/cohort1/network_test.go | 3 +- .../concurrent_map.go} | 30 +++--- 28 files changed, 181 insertions(+), 225 deletions(-) delete mode 100644 engine/access/rest/websockets/models.go create mode 100644 engine/access/rest/websockets/models/base_message.go create mode 100644 engine/access/rest/websockets/models/list_subscriptions.go create mode 100644 engine/access/rest/websockets/models/subscribe.go create mode 100644 engine/access/rest/websockets/models/subscription_entry.go create mode 100644 engine/access/rest/websockets/models/unsubscribe.go delete mode 100644 engine/access/rest/websockets/threadsafe_map.go rename utils/{unittest/protected_map.go => concurrentmap/concurrent_map.go} (58%) diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 9ee3b1bd124..1bb6a8c04bb 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -254,7 +254,7 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig { registerCacheSize: 0, programCacheSize: 0, registerDBPruneThreshold: pruner.DefaultThreshold, - websocketConfig: *websockets.NewDefaultWebsocketConfig(), + websocketConfig: websockets.NewDefaultWebsocketConfig(), } } diff --git a/cmd/util/cmd/run-script/cmd.go b/cmd/util/cmd/run-script/cmd.go index dc4d6e381a0..171f97e76b7 100644 --- a/cmd/util/cmd/run-script/cmd.go +++ b/cmd/util/cmd/run-script/cmd.go @@ -170,7 +170,7 @@ func run(*cobra.Command, []string) { metrics.NewNoopCollector(), nil, backend.Config{}, - *websockets.NewDefaultWebsocketConfig(), + websockets.NewDefaultWebsocketConfig(), ) if err != nil { log.Fatal().Err(err).Msg("failed to create server") diff --git a/engine/access/handle_irrecoverable_state_test.go b/engine/access/handle_irrecoverable_state_test.go index 9ab1e58870b..456c5cd97fd 100644 --- a/engine/access/handle_irrecoverable_state_test.go +++ b/engine/access/handle_irrecoverable_state_test.go @@ -110,7 +110,7 @@ func (suite *IrrecoverableStateTestSuite) SetupTest() { RestConfig: rest.Config{ ListenAddress: unittest.DefaultAddress, }, - WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/integration_unsecure_grpc_server_test.go b/engine/access/integration_unsecure_grpc_server_test.go index 98de205ad66..3c4aeca97d4 100644 --- a/engine/access/integration_unsecure_grpc_server_test.go +++ b/engine/access/integration_unsecure_grpc_server_test.go @@ -139,7 +139,7 @@ func (suite *SameGRPCPortTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, - WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } blockCount := 5 diff --git a/engine/access/rest/router/router.go b/engine/access/rest/router/router.go index d5b37781f7c..a2d81cb0a58 100644 --- a/engine/access/rest/router/router.go +++ b/engine/access/rest/router/router.go @@ -88,7 +88,7 @@ func (b *RouterBuilder) AddLegacyWebsocketsRoutes( func (b *RouterBuilder) AddWebsocketsRoute( chain flow.Chain, - config *websockets.Config, + config websockets.Config, streamApi state_stream.API, streamConfig backend.Config, maxRequestSize int64, diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index f23a683f39d..0e582d0bee4 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -50,7 +50,7 @@ func NewServer(serverAPI access.API, builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) } - builder.AddWebsocketsRoute(chain, &wsConfig, stateStreamApi, stateStreamConfig, config.MaxRequestSize) + builder.AddWebsocketsRoute(chain, wsConfig, stateStreamApi, stateStreamConfig, config.MaxRequestSize) c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go index 13138e54539..7f563ba94b9 100644 --- a/engine/access/rest/websockets/config.go +++ b/engine/access/rest/websockets/config.go @@ -11,8 +11,8 @@ type Config struct { MaxRequestSize int64 } -func NewDefaultWebsocketConfig() *Config { - return &Config{ +func NewDefaultWebsocketConfig() Config { + return Config{ MaxSubscriptionsPerConnection: 1000, MaxResponsesPerSecond: 1000, SendMessageTimeout: 10 * time.Second, diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index a8c5e1275fc..87ceae35b7a 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -10,45 +10,44 @@ import ( "github.com/rs/zerolog" dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/utils/concurrentmap" ) type Controller struct { - ctx context.Context logger zerolog.Logger - config *Config + config Config conn *websocket.Conn communicationChannel chan interface{} - dataProviders *ThreadSafeMap[uuid.UUID, dp.DataProvider] + dataProviders *concurrentmap.ConcurrentMap[uuid.UUID, dp.DataProvider] dataProvidersFactory *dp.Factory } func NewWebSocketController( - ctx context.Context, logger zerolog.Logger, - config *Config, + config Config, streamApi state_stream.API, streamConfig backend.Config, conn *websocket.Conn, ) *Controller { return &Controller{ - ctx: ctx, logger: logger.With().Str("component", "websocket-controller").Logger(), config: config, conn: conn, communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? - dataProviders: NewThreadSafeMap[uuid.UUID, dp.DataProvider](), + dataProviders: concurrentmap.NewConcurrentMap[uuid.UUID, dp.DataProvider](), dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), } } // HandleConnection manages the WebSocket connection, adding context and error handling. -func (c *Controller) HandleConnection() { +func (c *Controller) HandleConnection(ctx context.Context) { //TODO: configure the connection with ping-pong and deadlines - - go c.readMessagesFromClient(c.ctx) - go c.writeMessagesToClient(c.ctx) + //TODO: spin up a response limit tracker routine + go c.readMessagesFromClient(ctx) + go c.writeMessagesToClient(ctx) } func (c *Controller) writeMessagesToClient(ctx context.Context) { @@ -85,13 +84,13 @@ func (c *Controller) readMessagesFromClient(ctx context.Context) { return } - baseMsg, err := c.parseMessage(msg) + baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) if err != nil { - c.logger.Warn().Err(err).Msg("error parsing base message") + c.logger.Debug().Err(err).Msg("error parsing and validating client message") return } - if err := c.dispatchAction(baseMsg.Action, msg); err != nil { + if err := c.handleAction(ctx, baseMsg.Action, validatedMsg); err != nil { c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") } } @@ -106,55 +105,68 @@ func (c *Controller) readMessage() (json.RawMessage, error) { return message, nil } -func (c *Controller) parseMessage(message json.RawMessage) (BaseMessageRequest, error) { - var baseMsg BaseMessageRequest +func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.BaseMessageRequest, interface{}, error) { + var baseMsg models.BaseMessageRequest if err := json.Unmarshal(message, &baseMsg); err != nil { - return BaseMessageRequest{}, fmt.Errorf("error unmarshalling base message: %w", err) + return models.BaseMessageRequest{}, nil, fmt.Errorf("error unmarshalling base message: %w", err) } - return baseMsg, nil -} -// dispatchAction routes the action to the appropriate handler based on the action type. -func (c *Controller) dispatchAction(action string, message json.RawMessage) error { - switch action { + var validatedMsg interface{} + switch baseMsg.Action { case "subscribe": - var subscribeMsg SubscribeMessageRequest + var subscribeMsg models.SubscribeMessageRequest if err := json.Unmarshal(message, &subscribeMsg); err != nil { - return fmt.Errorf("error unmarshalling subscribe message: %w", err) + return baseMsg, nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) } - c.handleSubscribe(subscribeMsg) + //TODO: add validation logic for `topic` field + validatedMsg = subscribeMsg case "unsubscribe": - var unsubscribeMsg UnsubscribeMessageRequest + var unsubscribeMsg models.UnsubscribeMessageRequest if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { - return fmt.Errorf("error unmarshalling unsubscribe message: %w", err) + return baseMsg, nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) } - c.handleUnsubscribe(unsubscribeMsg) + validatedMsg = unsubscribeMsg case "list_subscriptions": - var listMsg ListSubscriptionsMessageRequest + var listMsg models.ListSubscriptionsMessageRequest if err := json.Unmarshal(message, &listMsg); err != nil { - return fmt.Errorf("error unmarshalling list subscriptions message: %w", err) + return baseMsg, nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) } - c.handleListSubscriptions(listMsg) + validatedMsg = listMsg + + default: + c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") + return baseMsg, nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) + } + + return baseMsg, validatedMsg, nil +} +func (c *Controller) handleAction(ctx context.Context, action string, message interface{}) error { + switch action { + case "subscribe": + c.handleSubscribe(ctx, message.(models.SubscribeMessageRequest)) + case "unsubscribe": + c.handleUnsubscribe(ctx, message.(models.UnsubscribeMessageRequest)) + case "list_subscriptions": + c.handleListSubscriptions(ctx, message.(models.ListSubscriptionsMessageRequest)) default: - c.logger.Warn().Str("action", action).Msg("unknown action type") return fmt.Errorf("unknown action type: %s", action) } return nil } -func (c *Controller) handleSubscribe(msg SubscribeMessageRequest) { - dp := c.dataProvidersFactory.NewDataProvider(c.ctx, c.communicationChannel, msg.Topic) - c.dataProviders.Insert(dp.ID(), dp) - dp.Run() +func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { + dp := c.dataProvidersFactory.NewDataProvider(ctx, c.communicationChannel, msg.Topic) + c.dataProviders.Add(dp.ID(), dp) + dp.Run(ctx) } -func (c *Controller) handleUnsubscribe(msg UnsubscribeMessageRequest) { +func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { id, err := uuid.Parse(msg.ID) if err != nil { - c.logger.Warn().Err(err).Str("topic", msg.Topic).Msg("error parsing message ID") + c.logger.Debug().Err(err).Msg("error parsing message ID") return } @@ -165,7 +177,8 @@ func (c *Controller) handleUnsubscribe(msg UnsubscribeMessageRequest) { } } -func (c *Controller) handleListSubscriptions(msg ListSubscriptionsMessageRequest) {} +func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { +} func (c *Controller) shutdownConnection() { defer close(c.communicationChannel) @@ -175,8 +188,13 @@ func (c *Controller) shutdownConnection() { } }(c.conn) - c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) { + err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { dp.Close() + return nil }) + if err != nil { + c.logger.Error().Err(err).Msg("error closing data provider") + } + c.dataProviders.Clear() } diff --git a/engine/access/rest/websockets/data_provider/blocks.go b/engine/access/rest/websockets/data_provider/blocks.go index 4c23bd4b587..7ec83e30fcd 100644 --- a/engine/access/rest/websockets/data_provider/blocks.go +++ b/engine/access/rest/websockets/data_provider/blocks.go @@ -38,13 +38,12 @@ func NewMockBlockProvider( } } -func (p *MockBlockProvider) Run() { +func (p *MockBlockProvider) Run(_ context.Context) { select { case <-p.ctx.Done(): return default: - p.ch <- "hello" - p.ch <- "world" + p.ch <- "hello world" } } diff --git a/engine/access/rest/websockets/data_provider/provider.go b/engine/access/rest/websockets/data_provider/provider.go index e919af590b6..ce2914140ba 100644 --- a/engine/access/rest/websockets/data_provider/provider.go +++ b/engine/access/rest/websockets/data_provider/provider.go @@ -1,11 +1,13 @@ package data_provider import ( + "context" + "github.com/google/uuid" ) type DataProvider interface { - Run() + Run(ctx context.Context) ID() uuid.UUID Topic() string Close() diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index 7bc381349f9..ff385f826ef 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -17,7 +17,7 @@ type Handler struct { *common.HttpHandler logger zerolog.Logger - websocketConfig *Config + websocketConfig Config streamApi state_stream.API streamConfig backend.Config } @@ -26,7 +26,7 @@ var _ http.Handler = (*Handler)(nil) func NewWebSocketHandler( logger zerolog.Logger, - config *Config, + config Config, chain flow.Chain, streamApi state_stream.API, streamConfig backend.Config, @@ -64,7 +64,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ctx := context.Background() - controller := NewWebSocketController(ctx, logger, h.websocketConfig, h.streamApi, h.streamConfig, conn) - controller.HandleConnection() + controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, conn) + controller.HandleConnection(context.TODO()) } diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go index a7fb4f109ff..ebc83b00bdd 100644 --- a/engine/access/rest/websockets/handler_test.go +++ b/engine/access/rest/websockets/handler_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/onflow/flow-go/engine/access/rest/websockets" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" "github.com/onflow/flow-go/engine/access/state_stream/backend" streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" @@ -28,17 +29,17 @@ type WsHandlerSuite struct { logger zerolog.Logger handler *websockets.Handler - wsConfig *websockets.Config + wsConfig websockets.Config streamApi *streammock.API - streamConfig *backend.Config + streamConfig backend.Config } func (s *WsHandlerSuite) SetupTest() { s.logger = unittest.Logger() s.wsConfig = websockets.NewDefaultWebsocketConfig() s.streamApi = streammock.NewAPI(s.T()) - s.streamConfig = &backend.Config{} - s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, *s.streamConfig, 1024) + s.streamConfig = backend.Config{} + s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, s.streamConfig, 1024) } func TestWsHandlerSuite(t *testing.T) { @@ -65,8 +66,8 @@ func (s *WsHandlerSuite) TestSubscribeRequest() { args := map[string]interface{}{ "start_block_height": 10, } - body := websockets.SubscribeMessageRequest{ - BaseMessageRequest: websockets.BaseMessageRequest{Action: "subscribe"}, + body := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, Topic: "blocks", Arguments: args, } @@ -80,12 +81,6 @@ func (s *WsHandlerSuite) TestSubscribeRequest() { require.NoError(s.T(), err) actualMsg := strings.Trim(string(msg), "\n\"\\ ") - require.Equal(s.T(), "hello", actualMsg) - - _, msg, err = conn.ReadMessage() - require.NoError(s.T(), err) - - actualMsg = strings.Trim(string(msg), "\n\"\\ ") - require.Equal(s.T(), "world", actualMsg) + require.Equal(s.T(), "hello world", actualMsg) }) } diff --git a/engine/access/rest/websockets/models.go b/engine/access/rest/websockets/models.go deleted file mode 100644 index 42abb8b7241..00000000000 --- a/engine/access/rest/websockets/models.go +++ /dev/null @@ -1,59 +0,0 @@ -package websockets - -// BaseMessageRequest represents a base structure for incoming messages. -type BaseMessageRequest struct { - Action string `json:"action"` // Action type of the request -} - -// BaseMessageResponse represents a base structure for outgoing messages. -type BaseMessageResponse struct { - Action string `json:"action,omitempty"` // Action type of the response - Success bool `json:"success"` // Indicates success or failure - ErrorMessage string `json:"error_message,omitempty"` // Error message, if any -} - -// SubscribeMessageRequest represents a request to subscribe to a topic. -type SubscribeMessageRequest struct { - BaseMessageRequest - Topic string `json:"topic"` // Topic to subscribe to - Arguments map[string]interface{} `json:"arguments"` // Additional arguments for subscription -} - -// SubscribeMessageResponse represents the response to a subscription request. -type SubscribeMessageResponse struct { - BaseMessageResponse - Topic string `json:"topic"` // Topic of the subscription - ID string `json:"id"` // Unique subscription ID -} - -// UnsubscribeMessageRequest represents a request to unsubscribe from a topic. -type UnsubscribeMessageRequest struct { - BaseMessageRequest - Topic string `json:"topic"` // Topic to unsubscribe from - ID string `json:"id"` // Unique subscription ID -} - -// UnsubscribeMessageResponse represents the response to an unsubscription request. -type UnsubscribeMessageResponse struct { - BaseMessageResponse - Topic string `json:"topic"` // Topic of the unsubscription - ID string `json:"id"` // Unique subscription ID -} - -// ListSubscriptionsMessageRequest represents a request to list active subscriptions. -type ListSubscriptionsMessageRequest struct { - BaseMessageRequest -} - -// SubscriptionEntry represents an active subscription entry. -type SubscriptionEntry struct { - Topic string `json:"topic,omitempty"` // Topic of the subscription - ID string `json:"id,omitempty"` // Unique subscription ID -} - -// ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. -// It contains a list of active subscriptions for the current WebSocket connection. -type ListSubscriptionsMessageResponse struct { - BaseMessageResponse - Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` -} diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go new file mode 100644 index 00000000000..f56d62fda8f --- /dev/null +++ b/engine/access/rest/websockets/models/base_message.go @@ -0,0 +1,13 @@ +package models + +// BaseMessageRequest represents a base structure for incoming messages. +type BaseMessageRequest struct { + Action string `json:"action"` // Action type of the request +} + +// BaseMessageResponse represents a base structure for outgoing messages. +type BaseMessageResponse struct { + Action string `json:"action,omitempty"` // Action type of the response + Success bool `json:"success"` // Indicates success or failure + ErrorMessage string `json:"error_message,omitempty"` // Error message, if any +} diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go new file mode 100644 index 00000000000..26174869585 --- /dev/null +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -0,0 +1,13 @@ +package models + +// ListSubscriptionsMessageRequest represents a request to list active subscriptions. +type ListSubscriptionsMessageRequest struct { + BaseMessageRequest +} + +// ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. +// It contains a list of active subscriptions for the current WebSocket connection. +type ListSubscriptionsMessageResponse struct { + BaseMessageResponse + Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` +} diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe.go new file mode 100644 index 00000000000..993bd63b811 --- /dev/null +++ b/engine/access/rest/websockets/models/subscribe.go @@ -0,0 +1,15 @@ +package models + +// SubscribeMessageRequest represents a request to subscribe to a topic. +type SubscribeMessageRequest struct { + BaseMessageRequest + Topic string `json:"topic"` // Topic to subscribe to + Arguments map[string]interface{} `json:"arguments"` // Additional arguments for subscription +} + +// SubscribeMessageResponse represents the response to a subscription request. +type SubscribeMessageResponse struct { + BaseMessageResponse + Topic string `json:"topic"` // Topic of the subscription + ID string `json:"id"` // Unique subscription ID +} diff --git a/engine/access/rest/websockets/models/subscription_entry.go b/engine/access/rest/websockets/models/subscription_entry.go new file mode 100644 index 00000000000..d3f2b352bb7 --- /dev/null +++ b/engine/access/rest/websockets/models/subscription_entry.go @@ -0,0 +1,7 @@ +package models + +// SubscriptionEntry represents an active subscription entry. +type SubscriptionEntry struct { + Topic string `json:"topic,omitempty"` // Topic of the subscription + ID string `json:"id,omitempty"` // Unique subscription ID +} diff --git a/engine/access/rest/websockets/models/unsubscribe.go b/engine/access/rest/websockets/models/unsubscribe.go new file mode 100644 index 00000000000..2024bb922e0 --- /dev/null +++ b/engine/access/rest/websockets/models/unsubscribe.go @@ -0,0 +1,13 @@ +package models + +// UnsubscribeMessageRequest represents a request to unsubscribe from a topic. +type UnsubscribeMessageRequest struct { + BaseMessageRequest + ID string `json:"id"` // Unique subscription ID +} + +// UnsubscribeMessageResponse represents the response to an unsubscription request. +type UnsubscribeMessageResponse struct { + BaseMessageResponse + ID string `json:"id"` // Unique subscription ID +} diff --git a/engine/access/rest/websockets/threadsafe_map.go b/engine/access/rest/websockets/threadsafe_map.go deleted file mode 100644 index 2c3f3438e40..00000000000 --- a/engine/access/rest/websockets/threadsafe_map.go +++ /dev/null @@ -1,70 +0,0 @@ -package websockets - -import ( - "sync" -) - -// ThreadSafeMap is a thread-safe map with read-write locking. -type ThreadSafeMap[K comparable, V any] struct { - mu sync.RWMutex - m map[K]V -} - -// NewThreadSafeMap initializes a new ThreadSafeMap. -func NewThreadSafeMap[K comparable, V any]() *ThreadSafeMap[K, V] { - return &ThreadSafeMap[K, V]{ - m: make(map[K]V), - } -} - -// Get retrieves a value for a key, returning the value and a boolean indicating if the key exists. -func (s *ThreadSafeMap[K, V]) Get(key K) (V, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - value, ok := s.m[key] - return value, ok -} - -// Insert inserts or updates a value for a key. -func (s *ThreadSafeMap[K, V]) Insert(key K, value V) { - s.mu.Lock() - defer s.mu.Unlock() - s.m[key] = value -} - -// Remove removes a key and its value from the map. -func (s *ThreadSafeMap[K, V]) Remove(key K) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.m, key) -} - -// Exists checks if a key exists in the map. -func (s *ThreadSafeMap[K, V]) Exists(key K) bool { - s.mu.RLock() - defer s.mu.RUnlock() - _, ok := s.m[key] - return ok -} - -// Len returns the number of elements in the map. -func (s *ThreadSafeMap[K, V]) Len() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.m) -} - -// ForEach applies a function to each key-value pair in the map. -func (s *ThreadSafeMap[K, V]) ForEach(f func(K, V)) { - s.mu.RLock() - defer s.mu.RUnlock() - for k, v := range s.m { - f(k, v) - } -} - -func (s *ThreadSafeMap[K, V]) Clear() { - s.mu.Lock() - defer s.mu.Unlock() - s.m = make(map[K]V) -} diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index c0f50e9a2b1..651adb41a63 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -138,7 +138,7 @@ func (suite *RestAPITestSuite) SetupTest() { RestConfig: rest.Config{ ListenAddress: unittest.DefaultAddress, }, - WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/rpc/rate_limit_test.go b/engine/access/rpc/rate_limit_test.go index e4f923e98fb..7148cdfefad 100644 --- a/engine/access/rpc/rate_limit_test.go +++ b/engine/access/rpc/rate_limit_test.go @@ -116,7 +116,7 @@ func (suite *RateLimitTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, - WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/secure_grpcr_test.go b/engine/access/secure_grpcr_test.go index aa92c5db052..6ffa8f8d324 100644 --- a/engine/access/secure_grpcr_test.go +++ b/engine/access/secure_grpcr_test.go @@ -111,7 +111,7 @@ func (suite *SecureGRPCTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, - WebSocketConfig: *websockets.NewDefaultWebsocketConfig(), + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/common/worker/worker_builder_test.go b/engine/common/worker/worker_builder_test.go index c08da0769c3..160f23844f5 100644 --- a/engine/common/worker/worker_builder_test.go +++ b/engine/common/worker/worker_builder_test.go @@ -14,6 +14,7 @@ import ( "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/mempool/queue" "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -115,7 +116,7 @@ func TestWorkerPool_TwoWorkers_ConcurrentEvents(t *testing.T) { } q := queue.NewHeroStore(uint32(size), unittest.Logger(), metrics.NewNoopCollector()) - distributedEvents := unittest.NewProtectedMap[string, struct{}]() + distributedEvents := concurrentmap.NewConcurrentMap[string, struct{}]() allEventsDistributed := sync.WaitGroup{} allEventsDistributed.Add(size) diff --git a/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go b/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go index c43b7435f55..5f2ff0f0e6d 100644 --- a/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go +++ b/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go @@ -19,6 +19,7 @@ import ( "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/p2p" p2ptest "github.com/onflow/flow-go/network/p2p/test" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -36,7 +37,7 @@ func TestGossipSubIHaveBrokenPromises_Below_Threshold(t *testing.T) { sporkId := unittest.IdentifierFixture() blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId) - receivedIWants := unittest.NewProtectedMap[string, struct{}]() + receivedIWants := concurrentmap.NewConcurrentMap[string, struct{}]() idProvider := unittest.NewUpdatableIDProvider(flow.IdentityList{}) spammer := corruptlibp2p.NewGossipSubRouterSpammerWithRpcInspector(t, sporkId, role, idProvider, func(id peer.ID, rpc *corrupt.RPC) error { // override rpc inspector of the spammer node to keep track of the iwants it has received. @@ -188,7 +189,7 @@ func TestGossipSubIHaveBrokenPromises_Above_Threshold(t *testing.T) { sporkId := unittest.IdentifierFixture() blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId) - receivedIWants := unittest.NewProtectedMap[string, struct{}]() + receivedIWants := concurrentmap.NewConcurrentMap[string, struct{}]() idProvider := unittest.NewUpdatableIDProvider(flow.IdentityList{}) spammer := corruptlibp2p.NewGossipSubRouterSpammerWithRpcInspector(t, sporkId, role, idProvider, func(id peer.ID, rpc *corrupt.RPC) error { // override rpc inspector of the spammer node to keep track of the iwants it has received. @@ -437,7 +438,7 @@ func TestGossipSubIHaveBrokenPromises_Above_Threshold(t *testing.T) { func spamIHaveBrokenPromise(t *testing.T, spammer *corruptlibp2p.GossipSubRouterSpammer, topic string, - receivedIWants *unittest.ProtectedMap[string, struct{}], + receivedIWants *concurrentmap.ConcurrentMap[string, struct{}], victimNode p2p.LibP2PNode) { rpcCount := 10 // we can't send more than one iHave per RPC in this test, as each iHave should have a distinct topic, and we only have one subscribed topic. diff --git a/network/p2p/connection/connection_gater_test.go b/network/p2p/connection/connection_gater_test.go index ed8777d3f90..7794caa8110 100644 --- a/network/p2p/connection/connection_gater_test.go +++ b/network/p2p/connection/connection_gater_test.go @@ -24,6 +24,7 @@ import ( mockp2p "github.com/onflow/flow-go/network/p2p/mock" p2ptest "github.com/onflow/flow-go/network/p2p/test" "github.com/onflow/flow-go/network/p2p/unicast/stream" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -35,7 +36,7 @@ func TestConnectionGating(t *testing.T) { sporkID := unittest.IdentifierFixture() idProvider := mockmodule.NewIdentityProvider(t) // create 2 nodes - node1Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node1Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() node1, node1Id := p2ptest.NodeFixture( t, sporkID, @@ -49,7 +50,7 @@ func TestConnectionGating(t *testing.T) { }))) idProvider.On("ByPeerID", node1.ID()).Return(&node1Id, true).Maybe() - node2Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node2Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() node2, node2Id := p2ptest.NodeFixture( t, sporkID, @@ -246,7 +247,7 @@ func TestConnectionGater_InterceptUpgrade(t *testing.T) { inbounds := make([]chan string, 0, count) identities := make(flow.IdentityList, 0, count) - disallowedPeerIds := unittest.NewProtectedMap[peer.ID, struct{}]() + disallowedPeerIds := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() allPeerIds := make(peer.IDSlice, 0, count) idProvider := mockmodule.NewIdentityProvider(t) connectionGater := mockp2p.NewConnectionGater(t) @@ -331,7 +332,7 @@ func TestConnectionGater_Disallow_Integration(t *testing.T) { ids := flow.IdentityList{} inbounds := make([]chan string, 0, 5) - disallowedList := unittest.NewProtectedMap[*flow.Identity, struct{}]() + disallowedList := concurrentmap.NewConcurrentMap[*flow.Identity, struct{}]() for i := 0; i < count; i++ { handler, inbound := p2ptest.StreamHandlerFixture(t) diff --git a/network/p2p/node/libp2pNode_test.go b/network/p2p/node/libp2pNode_test.go index 9a538bd269b..b0c08560e43 100644 --- a/network/p2p/node/libp2pNode_test.go +++ b/network/p2p/node/libp2pNode_test.go @@ -24,6 +24,7 @@ import ( p2ptest "github.com/onflow/flow-go/network/p2p/test" "github.com/onflow/flow-go/network/p2p/utils" validator "github.com/onflow/flow-go/network/validator/pubsub" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -158,7 +159,7 @@ func TestConnGater(t *testing.T) { sporkID := unittest.IdentifierFixture() idProvider := mockmodule.NewIdentityProvider(t) - node1Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node1Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() node1, identity1 := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithConnectionGater(p2ptest.NewConnectionGater(idProvider, func(pid peer.ID) error { if !node1Peers.Has(pid) { return fmt.Errorf("peer id not found: %s", p2plogging.PeerId(pid)) @@ -173,7 +174,7 @@ func TestConnGater(t *testing.T) { node1Info, err := utils.PeerAddressInfo(identity1.IdentitySkeleton) assert.NoError(t, err) - node2Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node2Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() node2, identity2 := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithConnectionGater(p2ptest.NewConnectionGater(idProvider, func(pid peer.ID) error { if !node2Peers.Has(pid) { return fmt.Errorf("id not found: %s", p2plogging.PeerId(pid)) diff --git a/network/test/cohort1/network_test.go b/network/test/cohort1/network_test.go index bffd3ac52b7..f546dcfa54d 100644 --- a/network/test/cohort1/network_test.go +++ b/network/test/cohort1/network_test.go @@ -40,6 +40,7 @@ import ( "github.com/onflow/flow-go/network/p2p/unicast/ratelimit" "github.com/onflow/flow-go/network/p2p/utils/ratelimiter" "github.com/onflow/flow-go/network/underlay" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -617,7 +618,7 @@ func (suite *NetworkTestSuite) MultiPing(count int) { senderNodeIndex := 0 targetNodeIndex := suite.size - 1 - receivedPayloads := unittest.NewProtectedMap[string, struct{}]() // keep track of unique payloads received. + receivedPayloads := concurrentmap.NewConcurrentMap[string, struct{}]() // keep track of unique payloads received. // regex to extract the payload from the message regex := regexp.MustCompile(`^hello from: \d`) diff --git a/utils/unittest/protected_map.go b/utils/concurrentmap/concurrent_map.go similarity index 58% rename from utils/unittest/protected_map.go rename to utils/concurrentmap/concurrent_map.go index a2af2f5f513..fb946733a24 100644 --- a/utils/unittest/protected_map.go +++ b/utils/concurrentmap/concurrent_map.go @@ -1,36 +1,36 @@ -package unittest +package concurrentmap import "sync" -// ProtectedMap is a thread-safe map. -type ProtectedMap[K comparable, V any] struct { +// ConcurrentMap is a thread-safe map. +type ConcurrentMap[K comparable, V any] struct { mu sync.RWMutex m map[K]V } -// NewProtectedMap returns a new ProtectedMap with the given types -func NewProtectedMap[K comparable, V any]() *ProtectedMap[K, V] { - return &ProtectedMap[K, V]{ +// NewConcurrentMap returns a new ConcurrentMap with the given types +func NewConcurrentMap[K comparable, V any]() *ConcurrentMap[K, V] { + return &ConcurrentMap[K, V]{ m: make(map[K]V), } } // Add adds a key-value pair to the map -func (p *ProtectedMap[K, V]) Add(key K, value V) { +func (p *ConcurrentMap[K, V]) Add(key K, value V) { p.mu.Lock() defer p.mu.Unlock() p.m[key] = value } // Remove removes a key-value pair from the map -func (p *ProtectedMap[K, V]) Remove(key K) { +func (p *ConcurrentMap[K, V]) Remove(key K) { p.mu.Lock() defer p.mu.Unlock() delete(p.m, key) } // Has returns true if the map contains the given key -func (p *ProtectedMap[K, V]) Has(key K) bool { +func (p *ConcurrentMap[K, V]) Has(key K) bool { p.mu.RLock() defer p.mu.RUnlock() _, ok := p.m[key] @@ -38,7 +38,7 @@ func (p *ProtectedMap[K, V]) Has(key K) bool { } // Get returns the value for the given key and a boolean indicating if the key was found -func (p *ProtectedMap[K, V]) Get(key K) (V, bool) { +func (p *ConcurrentMap[K, V]) Get(key K) (V, bool) { p.mu.RLock() defer p.mu.RUnlock() value, ok := p.m[key] @@ -47,7 +47,7 @@ func (p *ProtectedMap[K, V]) Get(key K) (V, bool) { // ForEach iterates over the map and calls the given function for each key-value pair. // If the function returns an error, the iteration is stopped and the error is returned. -func (p *ProtectedMap[K, V]) ForEach(fn func(k K, v V) error) error { +func (p *ConcurrentMap[K, V]) ForEach(fn func(k K, v V) error) error { p.mu.RLock() defer p.mu.RUnlock() for k, v := range p.m { @@ -59,8 +59,14 @@ func (p *ProtectedMap[K, V]) ForEach(fn func(k K, v V) error) error { } // Size returns the size of the map. -func (p *ProtectedMap[K, V]) Size() int { +func (p *ConcurrentMap[K, V]) Size() int { p.mu.RLock() defer p.mu.RUnlock() return len(p.m) } + +func (p *ConcurrentMap[K, V]) Clear() { + p.mu.Lock() + defer p.mu.Unlock() + p.m = make(map[K]V) +} From dbaa54524d350ce74f890f397c6c20b8bfd06f09 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Mon, 11 Nov 2024 18:06:47 +0200 Subject: [PATCH 06/10] add additional space --- engine/access/rest/websockets/handler.go | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index ff385f826ef..911c8fc55b4 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -40,6 +40,7 @@ func NewWebSocketHandler( streamConfig: streamConfig, } } + func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { //TODO: change to accept topic instead of URL logger := h.HttpHandler.Logger.With().Str("websocket_subscribe_url", r.URL.String()).Logger() From b30d63d4494c341ccc0d38fa89ab7a7ebc40f68e Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Mon, 11 Nov 2024 18:41:41 +0200 Subject: [PATCH 07/10] regen data provider mock --- .../websockets/data_provider/mock/data_provider.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/engine/access/rest/websockets/data_provider/mock/data_provider.go b/engine/access/rest/websockets/data_provider/mock/data_provider.go index 6a4aab6e130..4a2a22a44a0 100644 --- a/engine/access/rest/websockets/data_provider/mock/data_provider.go +++ b/engine/access/rest/websockets/data_provider/mock/data_provider.go @@ -3,8 +3,11 @@ package mock import ( - uuid "github.com/google/uuid" + context "context" + mock "github.com/stretchr/testify/mock" + + uuid "github.com/google/uuid" ) // DataProvider is an autogenerated mock type for the DataProvider type @@ -37,9 +40,9 @@ func (_m *DataProvider) ID() uuid.UUID { return r0 } -// Run provides a mock function with given fields: -func (_m *DataProvider) Run() { - _m.Called() +// Run provides a mock function with given fields: ctx +func (_m *DataProvider) Run(ctx context.Context) { + _m.Called(ctx) } // Topic provides a mock function with given fields: From 839c35c26d462a22793d50d6ede39a52602bec10 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 14 Nov 2024 17:24:48 +0200 Subject: [PATCH 08/10] rename concurrent map. add more todos for error handling --- .../rest/websockets/data_provider/blocks.go | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/engine/access/rest/websockets/data_provider/blocks.go b/engine/access/rest/websockets/data_provider/blocks.go index 7ec83e30fcd..43d551539bc 100644 --- a/engine/access/rest/websockets/data_provider/blocks.go +++ b/engine/access/rest/websockets/data_provider/blocks.go @@ -11,39 +11,40 @@ import ( type MockBlockProvider struct { id uuid.UUID - ch chan<- interface{} + topicChan chan<- interface{} // provider is not the one who is responsible to close this channel topic string logger zerolog.Logger - ctx context.Context stopProviderFunc context.CancelFunc streamApi state_stream.API } func NewMockBlockProvider( - ctx context.Context, ch chan<- interface{}, topic string, logger zerolog.Logger, streamApi state_stream.API, ) *MockBlockProvider { - ctx, cancel := context.WithCancel(ctx) return &MockBlockProvider{ id: uuid.New(), - ch: ch, + topicChan: ch, topic: topic, logger: logger.With().Str("component", "block-provider").Logger(), - ctx: ctx, - stopProviderFunc: cancel, + stopProviderFunc: nil, streamApi: streamApi, } } -func (p *MockBlockProvider) Run(_ context.Context) { - select { - case <-p.ctx.Done(): - return - default: - p.ch <- "hello world" +func (p *MockBlockProvider) Run(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + p.stopProviderFunc = cancel + + for { + select { + case <-ctx.Done(): + return + case p.topicChan <- "hello world": + return + } } } From 48aaa566123d4759f7ea4d83ff01af4b2d3aece3 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 19 Nov 2024 15:15:19 +0200 Subject: [PATCH 09/10] Fix comments * make handle_connection blocking * rename concurrent_map * use type switch instead of switch * add todos for error handling --- engine/access/rest/websockets/controller.go | 42 ++++++++++++------- .../rest/websockets/data_provider/blocks.go | 2 +- .../rest/websockets/data_provider/factory.go | 6 +-- engine/access/rest/websockets/handler_test.go | 2 +- engine/common/worker/worker_builder_test.go | 2 +- .../test/gossipsub/scoring/ihave_spam_test.go | 6 +-- .../p2p/connection/connection_gater_test.go | 8 ++-- network/p2p/node/libp2pNode_test.go | 4 +- network/test/cohort1/network_test.go | 2 +- utils/concurrentmap/concurrent_map.go | 26 ++++++------ 10 files changed, 55 insertions(+), 45 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 87ceae35b7a..fe873f5f61c 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -21,7 +21,7 @@ type Controller struct { config Config conn *websocket.Conn communicationChannel chan interface{} - dataProviders *concurrentmap.ConcurrentMap[uuid.UUID, dp.DataProvider] + dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] dataProvidersFactory *dp.Factory } @@ -37,7 +37,7 @@ func NewWebSocketController( config: config, conn: conn, communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? - dataProviders: concurrentmap.NewConcurrentMap[uuid.UUID, dp.DataProvider](), + dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), } } @@ -47,10 +47,14 @@ func (c *Controller) HandleConnection(ctx context.Context) { //TODO: configure the connection with ping-pong and deadlines //TODO: spin up a response limit tracker routine go c.readMessagesFromClient(ctx) - go c.writeMessagesToClient(ctx) + c.writeMessagesToClient(ctx) } +// writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection. +// The communication channel is filled by data providers. Besides, the response limit tracker is involved in +// write message regulation func (c *Controller) writeMessagesToClient(ctx context.Context) { + //TODO: can it run forever? maybe we should cancel the ctx in the reader routine for { select { case <-ctx.Done(): @@ -66,6 +70,8 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) { } } +// readMessagesFromClient continuously reads messages from a client WebSocket connection, +// processes each message, and handles actions based on the message type. func (c *Controller) readMessagesFromClient(ctx context.Context) { defer c.shutdownConnection() @@ -90,7 +96,7 @@ func (c *Controller) readMessagesFromClient(ctx context.Context) { return } - if err := c.handleAction(ctx, baseMsg.Action, validatedMsg); err != nil { + if err := c.handleAction(ctx, validatedMsg); err != nil { c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") } } @@ -143,30 +149,35 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.Ba return baseMsg, validatedMsg, nil } -func (c *Controller) handleAction(ctx context.Context, action string, message interface{}) error { - switch action { - case "subscribe": - c.handleSubscribe(ctx, message.(models.SubscribeMessageRequest)) - case "unsubscribe": - c.handleUnsubscribe(ctx, message.(models.UnsubscribeMessageRequest)) - case "list_subscriptions": - c.handleListSubscriptions(ctx, message.(models.ListSubscriptionsMessageRequest)) +func (c *Controller) handleAction(ctx context.Context, message interface{}) error { + switch msg := message.(type) { + case models.SubscribeMessageRequest: + c.handleSubscribe(ctx, msg) + case models.UnsubscribeMessageRequest: + c.handleUnsubscribe(ctx, msg) + case models.ListSubscriptionsMessageRequest: + c.handleListSubscriptions(ctx, msg) default: - return fmt.Errorf("unknown action type: %s", action) + return fmt.Errorf("unknown message type: %T", msg) } return nil } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { - dp := c.dataProvidersFactory.NewDataProvider(ctx, c.communicationChannel, msg.Topic) + dp := c.dataProvidersFactory.NewDataProvider(c.communicationChannel, msg.Topic) c.dataProviders.Add(dp.ID(), dp) dp.Run(ctx) + + //TODO: return OK response to client + c.communicationChannel <- msg } -func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { +func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) { id, err := uuid.Parse(msg.ID) if err != nil { c.logger.Debug().Err(err).Msg("error parsing message ID") + //TODO: return an error response to client + c.communicationChannel <- err return } @@ -178,6 +189,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri } func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { + //TODO: return a response to client } func (c *Controller) shutdownConnection() { diff --git a/engine/access/rest/websockets/data_provider/blocks.go b/engine/access/rest/websockets/data_provider/blocks.go index 43d551539bc..01b4d07d2e7 100644 --- a/engine/access/rest/websockets/data_provider/blocks.go +++ b/engine/access/rest/websockets/data_provider/blocks.go @@ -42,7 +42,7 @@ func (p *MockBlockProvider) Run(ctx context.Context) { select { case <-ctx.Done(): return - case p.topicChan <- "hello world": + case p.topicChan <- "block{height: 42}": return } } diff --git a/engine/access/rest/websockets/data_provider/factory.go b/engine/access/rest/websockets/data_provider/factory.go index 86d69475377..6a2658b1b95 100644 --- a/engine/access/rest/websockets/data_provider/factory.go +++ b/engine/access/rest/websockets/data_provider/factory.go @@ -1,8 +1,6 @@ package data_provider import ( - "context" - "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/state_stream" @@ -23,10 +21,10 @@ func NewDataProviderFactory(logger zerolog.Logger, streamApi state_stream.API, s } } -func (f *Factory) NewDataProvider(ctx context.Context, ch chan<- interface{}, topic string) DataProvider { +func (f *Factory) NewDataProvider(ch chan<- interface{}, topic string) DataProvider { switch topic { case "blocks": - return NewMockBlockProvider(ctx, ch, topic, f.logger, f.streamApi) + return NewMockBlockProvider(ch, topic, f.logger, f.streamApi) default: return nil } diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go index ebc83b00bdd..6b9cce06572 100644 --- a/engine/access/rest/websockets/handler_test.go +++ b/engine/access/rest/websockets/handler_test.go @@ -81,6 +81,6 @@ func (s *WsHandlerSuite) TestSubscribeRequest() { require.NoError(s.T(), err) actualMsg := strings.Trim(string(msg), "\n\"\\ ") - require.Equal(s.T(), "hello world", actualMsg) + require.Equal(s.T(), "block{height: 42}", actualMsg) }) } diff --git a/engine/common/worker/worker_builder_test.go b/engine/common/worker/worker_builder_test.go index 160f23844f5..09aebe1cc41 100644 --- a/engine/common/worker/worker_builder_test.go +++ b/engine/common/worker/worker_builder_test.go @@ -116,7 +116,7 @@ func TestWorkerPool_TwoWorkers_ConcurrentEvents(t *testing.T) { } q := queue.NewHeroStore(uint32(size), unittest.Logger(), metrics.NewNoopCollector()) - distributedEvents := concurrentmap.NewConcurrentMap[string, struct{}]() + distributedEvents := concurrentmap.New[string, struct{}]() allEventsDistributed := sync.WaitGroup{} allEventsDistributed.Add(size) diff --git a/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go b/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go index 5f2ff0f0e6d..8debc74e7d7 100644 --- a/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go +++ b/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go @@ -37,7 +37,7 @@ func TestGossipSubIHaveBrokenPromises_Below_Threshold(t *testing.T) { sporkId := unittest.IdentifierFixture() blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId) - receivedIWants := concurrentmap.NewConcurrentMap[string, struct{}]() + receivedIWants := concurrentmap.New[string, struct{}]() idProvider := unittest.NewUpdatableIDProvider(flow.IdentityList{}) spammer := corruptlibp2p.NewGossipSubRouterSpammerWithRpcInspector(t, sporkId, role, idProvider, func(id peer.ID, rpc *corrupt.RPC) error { // override rpc inspector of the spammer node to keep track of the iwants it has received. @@ -189,7 +189,7 @@ func TestGossipSubIHaveBrokenPromises_Above_Threshold(t *testing.T) { sporkId := unittest.IdentifierFixture() blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId) - receivedIWants := concurrentmap.NewConcurrentMap[string, struct{}]() + receivedIWants := concurrentmap.New[string, struct{}]() idProvider := unittest.NewUpdatableIDProvider(flow.IdentityList{}) spammer := corruptlibp2p.NewGossipSubRouterSpammerWithRpcInspector(t, sporkId, role, idProvider, func(id peer.ID, rpc *corrupt.RPC) error { // override rpc inspector of the spammer node to keep track of the iwants it has received. @@ -438,7 +438,7 @@ func TestGossipSubIHaveBrokenPromises_Above_Threshold(t *testing.T) { func spamIHaveBrokenPromise(t *testing.T, spammer *corruptlibp2p.GossipSubRouterSpammer, topic string, - receivedIWants *concurrentmap.ConcurrentMap[string, struct{}], + receivedIWants *concurrentmap.Map[string, struct{}], victimNode p2p.LibP2PNode) { rpcCount := 10 // we can't send more than one iHave per RPC in this test, as each iHave should have a distinct topic, and we only have one subscribed topic. diff --git a/network/p2p/connection/connection_gater_test.go b/network/p2p/connection/connection_gater_test.go index 7794caa8110..e84bfe0042f 100644 --- a/network/p2p/connection/connection_gater_test.go +++ b/network/p2p/connection/connection_gater_test.go @@ -36,7 +36,7 @@ func TestConnectionGating(t *testing.T) { sporkID := unittest.IdentifierFixture() idProvider := mockmodule.NewIdentityProvider(t) // create 2 nodes - node1Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() + node1Peers := concurrentmap.New[peer.ID, struct{}]() node1, node1Id := p2ptest.NodeFixture( t, sporkID, @@ -50,7 +50,7 @@ func TestConnectionGating(t *testing.T) { }))) idProvider.On("ByPeerID", node1.ID()).Return(&node1Id, true).Maybe() - node2Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() + node2Peers := concurrentmap.New[peer.ID, struct{}]() node2, node2Id := p2ptest.NodeFixture( t, sporkID, @@ -247,7 +247,7 @@ func TestConnectionGater_InterceptUpgrade(t *testing.T) { inbounds := make([]chan string, 0, count) identities := make(flow.IdentityList, 0, count) - disallowedPeerIds := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() + disallowedPeerIds := concurrentmap.New[peer.ID, struct{}]() allPeerIds := make(peer.IDSlice, 0, count) idProvider := mockmodule.NewIdentityProvider(t) connectionGater := mockp2p.NewConnectionGater(t) @@ -332,7 +332,7 @@ func TestConnectionGater_Disallow_Integration(t *testing.T) { ids := flow.IdentityList{} inbounds := make([]chan string, 0, 5) - disallowedList := concurrentmap.NewConcurrentMap[*flow.Identity, struct{}]() + disallowedList := concurrentmap.New[*flow.Identity, struct{}]() for i := 0; i < count; i++ { handler, inbound := p2ptest.StreamHandlerFixture(t) diff --git a/network/p2p/node/libp2pNode_test.go b/network/p2p/node/libp2pNode_test.go index b0c08560e43..d53fabb0e17 100644 --- a/network/p2p/node/libp2pNode_test.go +++ b/network/p2p/node/libp2pNode_test.go @@ -159,7 +159,7 @@ func TestConnGater(t *testing.T) { sporkID := unittest.IdentifierFixture() idProvider := mockmodule.NewIdentityProvider(t) - node1Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() + node1Peers := concurrentmap.New[peer.ID, struct{}]() node1, identity1 := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithConnectionGater(p2ptest.NewConnectionGater(idProvider, func(pid peer.ID) error { if !node1Peers.Has(pid) { return fmt.Errorf("peer id not found: %s", p2plogging.PeerId(pid)) @@ -174,7 +174,7 @@ func TestConnGater(t *testing.T) { node1Info, err := utils.PeerAddressInfo(identity1.IdentitySkeleton) assert.NoError(t, err) - node2Peers := concurrentmap.NewConcurrentMap[peer.ID, struct{}]() + node2Peers := concurrentmap.New[peer.ID, struct{}]() node2, identity2 := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithConnectionGater(p2ptest.NewConnectionGater(idProvider, func(pid peer.ID) error { if !node2Peers.Has(pid) { return fmt.Errorf("id not found: %s", p2plogging.PeerId(pid)) diff --git a/network/test/cohort1/network_test.go b/network/test/cohort1/network_test.go index f546dcfa54d..723df438960 100644 --- a/network/test/cohort1/network_test.go +++ b/network/test/cohort1/network_test.go @@ -618,7 +618,7 @@ func (suite *NetworkTestSuite) MultiPing(count int) { senderNodeIndex := 0 targetNodeIndex := suite.size - 1 - receivedPayloads := concurrentmap.NewConcurrentMap[string, struct{}]() // keep track of unique payloads received. + receivedPayloads := concurrentmap.New[string, struct{}]() // keep track of unique payloads received. // regex to extract the payload from the message regex := regexp.MustCompile(`^hello from: \d`) diff --git a/utils/concurrentmap/concurrent_map.go b/utils/concurrentmap/concurrent_map.go index fb946733a24..148c3741428 100644 --- a/utils/concurrentmap/concurrent_map.go +++ b/utils/concurrentmap/concurrent_map.go @@ -2,35 +2,35 @@ package concurrentmap import "sync" -// ConcurrentMap is a thread-safe map. -type ConcurrentMap[K comparable, V any] struct { +// Map is a thread-safe map. +type Map[K comparable, V any] struct { mu sync.RWMutex m map[K]V } -// NewConcurrentMap returns a new ConcurrentMap with the given types -func NewConcurrentMap[K comparable, V any]() *ConcurrentMap[K, V] { - return &ConcurrentMap[K, V]{ +// New returns a new Map with the given types +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ m: make(map[K]V), } } // Add adds a key-value pair to the map -func (p *ConcurrentMap[K, V]) Add(key K, value V) { +func (p *Map[K, V]) Add(key K, value V) { p.mu.Lock() defer p.mu.Unlock() p.m[key] = value } // Remove removes a key-value pair from the map -func (p *ConcurrentMap[K, V]) Remove(key K) { +func (p *Map[K, V]) Remove(key K) { p.mu.Lock() defer p.mu.Unlock() delete(p.m, key) } // Has returns true if the map contains the given key -func (p *ConcurrentMap[K, V]) Has(key K) bool { +func (p *Map[K, V]) Has(key K) bool { p.mu.RLock() defer p.mu.RUnlock() _, ok := p.m[key] @@ -38,7 +38,7 @@ func (p *ConcurrentMap[K, V]) Has(key K) bool { } // Get returns the value for the given key and a boolean indicating if the key was found -func (p *ConcurrentMap[K, V]) Get(key K) (V, bool) { +func (p *Map[K, V]) Get(key K) (V, bool) { p.mu.RLock() defer p.mu.RUnlock() value, ok := p.m[key] @@ -47,7 +47,7 @@ func (p *ConcurrentMap[K, V]) Get(key K) (V, bool) { // ForEach iterates over the map and calls the given function for each key-value pair. // If the function returns an error, the iteration is stopped and the error is returned. -func (p *ConcurrentMap[K, V]) ForEach(fn func(k K, v V) error) error { +func (p *Map[K, V]) ForEach(fn func(k K, v V) error) error { p.mu.RLock() defer p.mu.RUnlock() for k, v := range p.m { @@ -59,14 +59,14 @@ func (p *ConcurrentMap[K, V]) ForEach(fn func(k K, v V) error) error { } // Size returns the size of the map. -func (p *ConcurrentMap[K, V]) Size() int { +func (p *Map[K, V]) Size() int { p.mu.RLock() defer p.mu.RUnlock() return len(p.m) } -func (p *ConcurrentMap[K, V]) Clear() { +func (p *Map[K, V]) Clear() { p.mu.Lock() defer p.mu.Unlock() - p.m = make(map[K]V) + clear(p.m) } From c38f6cef3acf3da2e98e9fc5c155088cf9de2128 Mon Sep 17 00:00:00 2001 From: Peter Argue <89119817+peterargue@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:35:38 -0800 Subject: [PATCH 10/10] Update engine/access/rest/websockets/handler.go --- engine/access/rest/websockets/handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index 911c8fc55b4..247890c2a62 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -48,7 +48,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { err := h.HttpHandler.VerifyRequest(w, r) if err != nil { // VerifyRequest sets the response error before returning - logger.Warn().Err(err).Msg("error validating websocket request") + logger.Debug().Err(err).Msg("error validating websocket request") return }