diff --git a/Makefile b/Makefile index 2578fffe4b6..36075edd094 100644 --- a/Makefile +++ b/Makefile @@ -203,7 +203,8 @@ generate-mocks: install-mock-generators mockery --name 'API' --dir="./engine/protocol" --case=underscore --output="./engine/protocol/mock" --outpkg="mock" mockery --name '.*' --dir="./engine/access/state_stream" --case=underscore --output="./engine/access/state_stream/mock" --outpkg="mock" mockery --name 'BlockTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" - mockery --name 'DataProvider' --dir="./engine/access/rest/websockets/data_provider" --case=underscore --output="./engine/access/rest/websockets/data_provider/mock" --outpkg="mock" + mockery --name 'DataProvider' --dir="./engine/access/rest/websockets/data_providers" --case=underscore --output="./engine/access/rest/websockets/data_providers/mock" --outpkg="mock" + mockery --name 'DataProviderFactory' --dir="./engine/access/rest/websockets/data_providers" --case=underscore --output="./engine/access/rest/websockets/data_providers/mock" --outpkg="mock" mockery --name 'ExecutionDataTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'ConnectionFactory' --dir="./engine/access/rpc/connection" --case=underscore --output="./engine/access/rpc/connection/mock" --outpkg="mock" mockery --name 'Communicator' --dir="./engine/access/rpc/backend" --case=underscore --output="./engine/access/rpc/backend/mock" --outpkg="mock" diff --git a/access/handler.go b/access/handler.go index 25316e7f3dd..b974e7034fc 100644 --- a/access/handler.go +++ b/access/handler.go @@ -1066,7 +1066,7 @@ func (h *Handler) SubscribeBlocksFromStartBlockID(request *access.SubscribeBlock } sub := h.api.SubscribeBlocksFromStartBlockID(stream.Context(), startBlockID, blockStatus) - return subscription.HandleSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) + return subscription.HandleRPCSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) } // SubscribeBlocksFromStartHeight handles subscription requests for blocks started from block height. @@ -1093,7 +1093,7 @@ func (h *Handler) SubscribeBlocksFromStartHeight(request *access.SubscribeBlocks } sub := h.api.SubscribeBlocksFromStartHeight(stream.Context(), request.GetStartBlockHeight(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) + return subscription.HandleRPCSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) } // SubscribeBlocksFromLatest handles subscription requests for blocks started from latest sealed block. @@ -1120,7 +1120,7 @@ func (h *Handler) SubscribeBlocksFromLatest(request *access.SubscribeBlocksFromL } sub := h.api.SubscribeBlocksFromLatest(stream.Context(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) + return subscription.HandleRPCSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) } // handleBlocksResponse handles the subscription to block updates and sends @@ -1179,7 +1179,7 @@ func (h *Handler) SubscribeBlockHeadersFromStartBlockID(request *access.Subscrib } sub := h.api.SubscribeBlockHeadersFromStartBlockID(stream.Context(), startBlockID, blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) } // SubscribeBlockHeadersFromStartHeight handles subscription requests for block headers started from block height. @@ -1206,7 +1206,7 @@ func (h *Handler) SubscribeBlockHeadersFromStartHeight(request *access.Subscribe } sub := h.api.SubscribeBlockHeadersFromStartHeight(stream.Context(), request.GetStartBlockHeight(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) } // SubscribeBlockHeadersFromLatest handles subscription requests for block headers started from latest sealed block. @@ -1233,7 +1233,7 @@ func (h *Handler) SubscribeBlockHeadersFromLatest(request *access.SubscribeBlock } sub := h.api.SubscribeBlockHeadersFromLatest(stream.Context(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) } // handleBlockHeadersResponse handles the subscription to block updates and sends @@ -1293,7 +1293,7 @@ func (h *Handler) SubscribeBlockDigestsFromStartBlockID(request *access.Subscrib } sub := h.api.SubscribeBlockDigestsFromStartBlockID(stream.Context(), startBlockID, blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) } // SubscribeBlockDigestsFromStartHeight handles subscription requests for lightweight blocks started from block height. @@ -1320,7 +1320,7 @@ func (h *Handler) SubscribeBlockDigestsFromStartHeight(request *access.Subscribe } sub := h.api.SubscribeBlockDigestsFromStartHeight(stream.Context(), request.GetStartBlockHeight(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) } // SubscribeBlockDigestsFromLatest handles subscription requests for lightweight block started from latest sealed block. @@ -1347,7 +1347,7 @@ func (h *Handler) SubscribeBlockDigestsFromLatest(request *access.SubscribeBlock } sub := h.api.SubscribeBlockDigestsFromLatest(stream.Context(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) } // handleBlockDigestsResponse handles the subscription to block updates and sends @@ -1433,7 +1433,7 @@ func (h *Handler) SendAndSubscribeTransactionStatuses( sub := h.api.SubscribeTransactionStatuses(ctx, &tx, request.GetEventEncodingVersion()) messageIndex := counters.NewMonotonousCounter(0) - return subscription.HandleSubscription(sub, func(txResults []*TransactionResult) error { + return subscription.HandleRPCSubscription(sub, func(txResults []*TransactionResult) error { for i := range txResults { index := messageIndex.Value() if ok := messageIndex.Set(index + 1); !ok { diff --git a/engine/access/rest/common/parser/block_status.go b/engine/access/rest/common/parser/block_status.go new file mode 100644 index 00000000000..efb34519894 --- /dev/null +++ b/engine/access/rest/common/parser/block_status.go @@ -0,0 +1,24 @@ +package parser + +import ( + "fmt" + + "github.com/onflow/flow-go/model/flow" +) + +// Finalized and Sealed represents the status of a block. +// It is used in rest arguments to provide block status. +const ( + Finalized = "finalized" + Sealed = "sealed" +) + +func ParseBlockStatus(blockStatus string) (flow.BlockStatus, error) { + switch blockStatus { + case Finalized: + return flow.BlockStatusFinalized, nil + case Sealed: + return flow.BlockStatusSealed, nil + } + return flow.BlockStatusUnknown, fmt.Errorf("invalid 'block_status', must be '%s' or '%s'", Finalized, Sealed) +} diff --git a/engine/access/rest/common/parser/block_status_test.go b/engine/access/rest/common/parser/block_status_test.go new file mode 100644 index 00000000000..0bbaa30c56b --- /dev/null +++ b/engine/access/rest/common/parser/block_status_test.go @@ -0,0 +1,39 @@ +package parser + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/onflow/flow-go/model/flow" +) + +// TestParseBlockStatus_Invalid tests the ParseBlockStatus function with invalid inputs. +// It verifies that for each invalid block status string, the function returns an error +// matching the expected error message format. +func TestParseBlockStatus_Invalid(t *testing.T) { + tests := []string{"unknown", "pending", ""} + expectedErr := fmt.Sprintf("invalid 'block_status', must be '%s' or '%s'", Finalized, Sealed) + + for _, input := range tests { + _, err := ParseBlockStatus(input) + assert.EqualError(t, err, expectedErr) + } +} + +// TestParseBlockStatus_Valid tests the ParseBlockStatus function with valid inputs. +// It ensures that the function returns the correct flow.BlockStatus for valid status +// strings "finalized" and "sealed" without errors. +func TestParseBlockStatus_Valid(t *testing.T) { + tests := map[string]flow.BlockStatus{ + Finalized: flow.BlockStatusFinalized, + Sealed: flow.BlockStatusSealed, + } + + for input, expectedStatus := range tests { + status, err := ParseBlockStatus(input) + assert.NoError(t, err) + assert.Equal(t, expectedStatus, status) + } +} diff --git a/engine/access/rest/http/request/id.go b/engine/access/rest/common/parser/id.go similarity index 98% rename from engine/access/rest/http/request/id.go rename to engine/access/rest/common/parser/id.go index ba3c1200527..7b1436b4761 100644 --- a/engine/access/rest/http/request/id.go +++ b/engine/access/rest/common/parser/id.go @@ -1,4 +1,4 @@ -package request +package parser import ( "errors" diff --git a/engine/access/rest/http/request/id_test.go b/engine/access/rest/common/parser/id_test.go similarity index 98% rename from engine/access/rest/http/request/id_test.go rename to engine/access/rest/common/parser/id_test.go index 1096fdbe696..a663c915e7a 100644 --- a/engine/access/rest/http/request/id_test.go +++ b/engine/access/rest/common/parser/id_test.go @@ -1,4 +1,4 @@ -package request +package parser import ( "testing" diff --git a/engine/access/rest/http/request/get_block.go b/engine/access/rest/http/request/get_block.go index fd74b0e4be0..972cd2ee97b 100644 --- a/engine/access/rest/http/request/get_block.go +++ b/engine/access/rest/http/request/get_block.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/model/flow" ) @@ -122,7 +123,7 @@ func (g *GetBlockByIDs) Build(r *common.Request) error { } func (g *GetBlockByIDs) Parse(rawIds []string) error { - var ids IDs + var ids parser.IDs err := ids.Parse(rawIds) if err != nil { return err diff --git a/engine/access/rest/http/request/get_events.go b/engine/access/rest/http/request/get_events.go index 39f2ba9faef..c864cf24a47 100644 --- a/engine/access/rest/http/request/get_events.go +++ b/engine/access/rest/http/request/get_events.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/model/flow" ) @@ -50,7 +51,7 @@ func (g *GetEvents) Parse(rawType string, rawStart string, rawEnd string, rawBlo } g.EndHeight = height.Flow() - var blockIDs IDs + var blockIDs parser.IDs err = blockIDs.Parse(rawBlockIDs) if err != nil { return err diff --git a/engine/access/rest/http/request/get_execution_result.go b/engine/access/rest/http/request/get_execution_result.go index cdf216766c1..4947cd8f07f 100644 --- a/engine/access/rest/http/request/get_execution_result.go +++ b/engine/access/rest/http/request/get_execution_result.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/model/flow" ) @@ -30,7 +31,7 @@ func (g *GetExecutionResultByBlockIDs) Build(r *common.Request) error { } func (g *GetExecutionResultByBlockIDs) Parse(rawIDs []string) error { - var ids IDs + var ids parser.IDs err := ids.Parse(rawIDs) if err != nil { return err diff --git a/engine/access/rest/http/request/get_script.go b/engine/access/rest/http/request/get_script.go index de8da72cac1..a01a025465a 100644 --- a/engine/access/rest/http/request/get_script.go +++ b/engine/access/rest/http/request/get_script.go @@ -5,6 +5,7 @@ import ( "io" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/model/flow" ) @@ -42,7 +43,7 @@ func (g *GetScript) Parse(rawHeight string, rawID string, rawScript io.Reader) e } g.BlockHeight = height.Flow() - var id ID + var id parser.ID err = id.Parse(rawID) if err != nil { return err diff --git a/engine/access/rest/http/request/get_transaction.go b/engine/access/rest/http/request/get_transaction.go index 359570cd71d..0d5df1e541e 100644 --- a/engine/access/rest/http/request/get_transaction.go +++ b/engine/access/rest/http/request/get_transaction.go @@ -2,6 +2,7 @@ package request import ( "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/model/flow" ) @@ -15,14 +16,14 @@ type TransactionOptionals struct { } func (t *TransactionOptionals) Parse(r *common.Request) error { - var blockId ID + var blockId parser.ID err := blockId.Parse(r.GetQueryParam(blockIDQueryParam)) if err != nil { return err } t.BlockID = blockId.Flow() - var collectionId ID + var collectionId parser.ID err = collectionId.Parse(r.GetQueryParam(collectionIDQueryParam)) if err != nil { return err diff --git a/engine/access/rest/http/request/helpers.go b/engine/access/rest/http/request/helpers.go index 5591cc6df9b..38a669d0ad1 100644 --- a/engine/access/rest/http/request/helpers.go +++ b/engine/access/rest/http/request/helpers.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/model/flow" ) @@ -60,7 +61,7 @@ func (g *GetByIDRequest) Build(r *common.Request) error { } func (g *GetByIDRequest) Parse(rawID string) error { - var id ID + var id parser.ID err := id.Parse(rawID) if err != nil { return err diff --git a/engine/access/rest/http/request/transaction.go b/engine/access/rest/http/request/transaction.go index 614d78f1e07..68bad0009f2 100644 --- a/engine/access/rest/http/request/transaction.go +++ b/engine/access/rest/http/request/transaction.go @@ -4,6 +4,7 @@ import ( "fmt" "io" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/engine/access/rest/http/models" "github.com/onflow/flow-go/engine/access/rest/util" "github.com/onflow/flow-go/engine/common/rpc/convert" @@ -89,7 +90,7 @@ func (t *Transaction) Parse(raw io.Reader, chain flow.Chain) error { return fmt.Errorf("invalid transaction script encoding") } - var blockID ID + var blockID parser.ID err = blockID.Parse(tx.ReferenceBlockId) if err != nil { return fmt.Errorf("invalid reference block ID: %w", err) diff --git a/engine/access/rest/http/routes/events.go b/engine/access/rest/http/routes/events.go index 038a4a98aeb..fed682555d0 100644 --- a/engine/access/rest/http/routes/events.go +++ b/engine/access/rest/http/routes/events.go @@ -7,7 +7,6 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/common" - "github.com/onflow/flow-go/engine/access/rest/http/models" "github.com/onflow/flow-go/engine/access/rest/http/request" ) diff --git a/engine/access/rest/router/router.go b/engine/access/rest/router/router.go index a2d81cb0a58..93879da6aaa 100644 --- a/engine/access/rest/router/router.go +++ b/engine/access/rest/router/router.go @@ -14,6 +14,7 @@ import ( flowhttp "github.com/onflow/flow-go/engine/access/rest/http" "github.com/onflow/flow-go/engine/access/rest/http/models" "github.com/onflow/flow-go/engine/access/rest/websockets" + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" legacyws "github.com/onflow/flow-go/engine/access/rest/websockets/legacy" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -89,11 +90,10 @@ func (b *RouterBuilder) AddLegacyWebsocketsRoutes( func (b *RouterBuilder) AddWebsocketsRoute( chain flow.Chain, config websockets.Config, - streamApi state_stream.API, - streamConfig backend.Config, maxRequestSize int64, + dataProviderFactory dp.DataProviderFactory, ) *RouterBuilder { - handler := websockets.NewWebSocketHandler(b.logger, config, chain, streamApi, streamConfig, maxRequestSize) + handler := websockets.NewWebSocketHandler(b.logger, config, chain, maxRequestSize, dataProviderFactory) b.v1SubRouter. Methods(http.MethodGet). Path("/ws"). diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 0e582d0bee4..4f0e2260ae5 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -10,6 +10,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/router" "github.com/onflow/flow-go/engine/access/rest/websockets" + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/model/flow" @@ -50,7 +51,8 @@ func NewServer(serverAPI access.API, builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) } - builder.AddWebsocketsRoute(chain, wsConfig, stateStreamApi, stateStreamConfig, config.MaxRequestSize) + dataProviderFactory := dp.NewDataProviderFactory(logger, stateStreamApi, serverAPI) + builder.AddWebsocketsRoute(chain, wsConfig, config.MaxRequestSize, dataProviderFactory) c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index fe873f5f61c..38bc7306b55 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -9,10 +9,8 @@ import ( "github.com/gorilla/websocket" "github.com/rs/zerolog" - dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream" - "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/utils/concurrentmap" ) @@ -22,15 +20,14 @@ type Controller struct { conn *websocket.Conn communicationChannel chan interface{} dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] - dataProvidersFactory *dp.Factory + dataProviderFactory dp.DataProviderFactory } func NewWebSocketController( logger zerolog.Logger, config Config, - streamApi state_stream.API, - streamConfig backend.Config, conn *websocket.Conn, + dataProviderFactory dp.DataProviderFactory, ) *Controller { return &Controller{ logger: logger.With().Str("component", "websocket-controller").Logger(), @@ -38,7 +35,7 @@ func NewWebSocketController( conn: conn, communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), - dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), + dataProviderFactory: dataProviderFactory, } } @@ -164,12 +161,24 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { - dp := c.dataProvidersFactory.NewDataProvider(c.communicationChannel, msg.Topic) + dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.communicationChannel) + if err != nil { + // TODO: handle error here + c.logger.Error().Err(err).Msgf("error while creating data provider for topic: %s", msg.Topic) + } + c.dataProviders.Add(dp.ID(), dp) - dp.Run(ctx) //TODO: return OK response to client c.communicationChannel <- msg + + go func() { + err := dp.Run() + if err != nil { + //TODO: Log or handle the error from Run + c.logger.Error().Err(err).Msgf("error while running data provider for topic: %s", msg.Topic) + } + }() } func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) { diff --git a/engine/access/rest/websockets/data_provider/blocks.go b/engine/access/rest/websockets/data_provider/blocks.go deleted file mode 100644 index 01b4d07d2e7..00000000000 --- a/engine/access/rest/websockets/data_provider/blocks.go +++ /dev/null @@ -1,61 +0,0 @@ -package data_provider - -import ( - "context" - - "github.com/google/uuid" - "github.com/rs/zerolog" - - "github.com/onflow/flow-go/engine/access/state_stream" -) - -type MockBlockProvider struct { - id uuid.UUID - topicChan chan<- interface{} // provider is not the one who is responsible to close this channel - topic string - logger zerolog.Logger - stopProviderFunc context.CancelFunc - streamApi state_stream.API -} - -func NewMockBlockProvider( - ch chan<- interface{}, - topic string, - logger zerolog.Logger, - streamApi state_stream.API, -) *MockBlockProvider { - return &MockBlockProvider{ - id: uuid.New(), - topicChan: ch, - topic: topic, - logger: logger.With().Str("component", "block-provider").Logger(), - stopProviderFunc: nil, - streamApi: streamApi, - } -} - -func (p *MockBlockProvider) Run(ctx context.Context) { - ctx, cancel := context.WithCancel(ctx) - p.stopProviderFunc = cancel - - for { - select { - case <-ctx.Done(): - return - case p.topicChan <- "block{height: 42}": - return - } - } -} - -func (p *MockBlockProvider) ID() uuid.UUID { - return p.id -} - -func (p *MockBlockProvider) Topic() string { - return p.topic -} - -func (p *MockBlockProvider) Close() { - p.stopProviderFunc() -} diff --git a/engine/access/rest/websockets/data_provider/factory.go b/engine/access/rest/websockets/data_provider/factory.go deleted file mode 100644 index 6a2658b1b95..00000000000 --- a/engine/access/rest/websockets/data_provider/factory.go +++ /dev/null @@ -1,31 +0,0 @@ -package data_provider - -import ( - "github.com/rs/zerolog" - - "github.com/onflow/flow-go/engine/access/state_stream" - "github.com/onflow/flow-go/engine/access/state_stream/backend" -) - -type Factory struct { - logger zerolog.Logger - streamApi state_stream.API - streamConfig backend.Config -} - -func NewDataProviderFactory(logger zerolog.Logger, streamApi state_stream.API, streamConfig backend.Config) *Factory { - return &Factory{ - logger: logger, - streamApi: streamApi, - streamConfig: streamConfig, - } -} - -func (f *Factory) NewDataProvider(ch chan<- interface{}, topic string) DataProvider { - switch topic { - case "blocks": - return NewMockBlockProvider(ch, topic, f.logger, f.streamApi) - default: - return nil - } -} diff --git a/engine/access/rest/websockets/data_provider/provider.go b/engine/access/rest/websockets/data_provider/provider.go deleted file mode 100644 index ce2914140ba..00000000000 --- a/engine/access/rest/websockets/data_provider/provider.go +++ /dev/null @@ -1,14 +0,0 @@ -package data_provider - -import ( - "context" - - "github.com/google/uuid" -) - -type DataProvider interface { - Run(ctx context.Context) - ID() uuid.UUID - Topic() string - Close() -} diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go new file mode 100644 index 00000000000..cf1ee1313d9 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -0,0 +1,52 @@ +package data_providers + +import ( + "context" + + "github.com/google/uuid" + + "github.com/onflow/flow-go/engine/access/subscription" +) + +// baseDataProvider holds common objects for the provider +type baseDataProvider struct { + id uuid.UUID + topic string + cancel context.CancelFunc + send chan<- interface{} + subscription subscription.Subscription +} + +// newBaseDataProvider creates a new instance of baseDataProvider. +func newBaseDataProvider( + topic string, + cancel context.CancelFunc, + send chan<- interface{}, + subscription subscription.Subscription, +) *baseDataProvider { + return &baseDataProvider{ + id: uuid.New(), + topic: topic, + cancel: cancel, + send: send, + subscription: subscription, + } +} + +// ID returns the unique identifier of the data provider. +func (b *baseDataProvider) ID() uuid.UUID { + return b.id +} + +// Topic returns the topic associated with the data provider. +func (b *baseDataProvider) Topic() string { + return b.topic +} + +// Close terminates the data provider. +// +// No errors are expected during normal operations. +func (b *baseDataProvider) Close() error { + b.cancel() + return nil +} diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider.go b/engine/access/rest/websockets/data_providers/block_digests_provider.go new file mode 100644 index 00000000000..1fa3f7a6dc7 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/block_digests_provider.go @@ -0,0 +1,82 @@ +package data_providers + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/access" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" +) + +// BlockDigestsDataProvider is responsible for providing block digests +type BlockDigestsDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + api access.API +} + +var _ DataProvider = (*BlockDigestsDataProvider)(nil) + +// NewBlockDigestsDataProvider creates a new instance of BlockDigestsDataProvider. +func NewBlockDigestsDataProvider( + ctx context.Context, + logger zerolog.Logger, + api access.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, +) (*BlockDigestsDataProvider, error) { + p := &BlockDigestsDataProvider{ + logger: logger.With().Str("component", "block-digests-data-provider").Logger(), + api: api, + } + + // Parse arguments passed to the provider. + blockArgs, err := ParseBlocksArguments(arguments) + if err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, blockArgs), // Set up a subscription to block digests based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for block digests and handles responses. +// +// No errors are expected during normal operations. +func (p *BlockDigestsDataProvider) Run() error { + return subscription.HandleSubscription( + p.subscription, + subscription.HandleResponse(p.send, func(block *flow.BlockDigest) (interface{}, error) { + return &models.BlockDigestMessageResponse{ + Block: block, + }, nil + }), + ) +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *BlockDigestsDataProvider) createSubscription(ctx context.Context, args BlocksArguments) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.api.SubscribeBlockDigestsFromStartBlockID(ctx, args.StartBlockID, args.BlockStatus) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.api.SubscribeBlockDigestsFromStartHeight(ctx, args.StartBlockHeight, args.BlockStatus) + } + + return p.api.SubscribeBlockDigestsFromLatest(ctx, args.BlockStatus) +} diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go new file mode 100644 index 00000000000..476edf77111 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go @@ -0,0 +1,129 @@ +package data_providers + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + statestreamsmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/model/flow" +) + +type BlockDigestsProviderSuite struct { + BlocksProviderSuite +} + +func TestBlockDigestsProviderSuite(t *testing.T) { + suite.Run(t, new(BlockDigestsProviderSuite)) +} + +// SetupTest initializes the test suite with required dependencies. +func (s *BlockDigestsProviderSuite) SetupTest() { + s.BlocksProviderSuite.SetupTest() +} + +// TestBlockDigestsDataProvider_InvalidArguments tests the behavior of the block digests data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Missing 'block_status' argument. +// 2. Invalid 'block_status' argument. +// 3. Providing both 'start_block_id' and 'start_block_height' simultaneously. +func (s *BlockDigestsProviderSuite) TestBlockDigestsDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := BlockDigestsTopic + + for _, test := range s.invalidArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewBlockDigestsDataProvider(ctx, s.log, s.api, topic, test.arguments, send) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// validBlockDigestsArgumentsTestCases defines test happy cases for block digests data providers. +// Each test case specifies input arguments, and setup functions for the mock API used in the test. +func (s *BlockDigestsProviderSuite) validBlockDigestsArgumentsTestCases() []testType { + return []testType{ + { + name: "happy path with start_block_id argument", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlockDigestsFromStartBlockID", + mock.Anything, + s.rootBlock.ID(), + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + { + name: "happy path with start_block_height argument", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlockDigestsFromStartHeight", + mock.Anything, + s.rootBlock.Header.Height, + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + { + name: "happy path without any start argument", + arguments: models.Arguments{ + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlockDigestsFromLatest", + mock.Anything, + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + } +} + +// TestBlockDigestsDataProvider_HappyPath tests the behavior of the block digests data provider +// when it is configured correctly and operating under normal conditions. It +// validates that block digests are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *BlockDigestsProviderSuite) TestBlockDigestsDataProvider_HappyPath() { + s.testHappyPath( + BlockDigestsTopic, + s.validBlockDigestsArgumentsTestCases(), + func(dataChan chan interface{}, blocks []*flow.Block) { + for _, block := range blocks { + dataChan <- flow.NewBlockDigest(block.Header.ID(), block.Header.Height, block.Header.Timestamp) + } + }, + s.requireBlockDigests, + ) +} + +// requireBlockHeaders ensures that the received block header information matches the expected data. +func (s *BlocksProviderSuite) requireBlockDigests(v interface{}, expectedBlock *flow.Block) { + actualResponse, ok := v.(*models.BlockDigestMessageResponse) + require.True(s.T(), ok, "unexpected response type: %T", v) + + s.Require().Equal(expectedBlock.Header.ID(), actualResponse.Block.ID()) + s.Require().Equal(expectedBlock.Header.Height, actualResponse.Block.Height) + s.Require().Equal(expectedBlock.Header.Timestamp, actualResponse.Block.Timestamp) +} diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider.go b/engine/access/rest/websockets/data_providers/block_headers_provider.go new file mode 100644 index 00000000000..4f9e29e2428 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/block_headers_provider.go @@ -0,0 +1,82 @@ +package data_providers + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/access" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" +) + +// BlockHeadersDataProvider is responsible for providing block headers +type BlockHeadersDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + api access.API +} + +var _ DataProvider = (*BlockHeadersDataProvider)(nil) + +// NewBlockHeadersDataProvider creates a new instance of BlockHeadersDataProvider. +func NewBlockHeadersDataProvider( + ctx context.Context, + logger zerolog.Logger, + api access.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, +) (*BlockHeadersDataProvider, error) { + p := &BlockHeadersDataProvider{ + logger: logger.With().Str("component", "block-headers-data-provider").Logger(), + api: api, + } + + // Parse arguments passed to the provider. + blockArgs, err := ParseBlocksArguments(arguments) + if err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, blockArgs), // Set up a subscription to block headers based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for block headers and handles responses. +// +// No errors are expected during normal operations. +func (p *BlockHeadersDataProvider) Run() error { + return subscription.HandleSubscription( + p.subscription, + subscription.HandleResponse(p.send, func(header *flow.Header) (interface{}, error) { + return &models.BlockHeaderMessageResponse{ + Header: header, + }, nil + }), + ) +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *BlockHeadersDataProvider) createSubscription(ctx context.Context, args BlocksArguments) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.api.SubscribeBlockHeadersFromStartBlockID(ctx, args.StartBlockID, args.BlockStatus) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.api.SubscribeBlockHeadersFromStartHeight(ctx, args.StartBlockHeight, args.BlockStatus) + } + + return p.api.SubscribeBlockHeadersFromLatest(ctx, args.BlockStatus) +} diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go new file mode 100644 index 00000000000..57c262d8795 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go @@ -0,0 +1,127 @@ +package data_providers + +import ( + "context" + "strconv" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + statestreamsmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/model/flow" +) + +type BlockHeadersProviderSuite struct { + BlocksProviderSuite +} + +func TestBlockHeadersProviderSuite(t *testing.T) { + suite.Run(t, new(BlockHeadersProviderSuite)) +} + +// SetupTest initializes the test suite with required dependencies. +func (s *BlockHeadersProviderSuite) SetupTest() { + s.BlocksProviderSuite.SetupTest() +} + +// TestBlockHeadersDataProvider_InvalidArguments tests the behavior of the block headers data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Missing 'block_status' argument. +// 2. Invalid 'block_status' argument. +// 3. Providing both 'start_block_id' and 'start_block_height' simultaneously. +func (s *BlockHeadersProviderSuite) TestBlockHeadersDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := BlockHeadersTopic + + for _, test := range s.invalidArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewBlockHeadersDataProvider(ctx, s.log, s.api, topic, test.arguments, send) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// validBlockHeadersArgumentsTestCases defines test happy cases for block headers data providers. +// Each test case specifies input arguments, and setup functions for the mock API used in the test. +func (s *BlockHeadersProviderSuite) validBlockHeadersArgumentsTestCases() []testType { + return []testType{ + { + name: "happy path with start_block_id argument", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlockHeadersFromStartBlockID", + mock.Anything, + s.rootBlock.ID(), + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + { + name: "happy path with start_block_height argument", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlockHeadersFromStartHeight", + mock.Anything, + s.rootBlock.Header.Height, + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + { + name: "happy path without any start argument", + arguments: models.Arguments{ + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlockHeadersFromLatest", + mock.Anything, + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + } +} + +// TestBlockHeadersDataProvider_HappyPath tests the behavior of the block headers data provider +// when it is configured correctly and operating under normal conditions. It +// validates that block headers are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *BlockHeadersProviderSuite) TestBlockHeadersDataProvider_HappyPath() { + s.testHappyPath( + BlockHeadersTopic, + s.validBlockHeadersArgumentsTestCases(), + func(dataChan chan interface{}, blocks []*flow.Block) { + for _, block := range blocks { + dataChan <- block.Header + } + }, + s.requireBlockHeaders, + ) +} + +// requireBlockHeaders ensures that the received block header information matches the expected data. +func (s *BlockHeadersProviderSuite) requireBlockHeaders(v interface{}, expectedBlock *flow.Block) { + actualResponse, ok := v.(*models.BlockHeaderMessageResponse) + require.True(s.T(), ok, "unexpected response type: %T", v) + + s.Require().Equal(expectedBlock.Header, actualResponse.Header) +} diff --git a/engine/access/rest/websockets/data_providers/blocks_provider.go b/engine/access/rest/websockets/data_providers/blocks_provider.go new file mode 100644 index 00000000000..72cfaa6f554 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/blocks_provider.go @@ -0,0 +1,138 @@ +package data_providers + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/access" + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/util" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" +) + +// BlocksArguments contains the arguments required for subscribing to blocks / block headers / block digests +type BlocksArguments struct { + StartBlockID flow.Identifier // ID of the block to start subscription from + StartBlockHeight uint64 // Height of the block to start subscription from + BlockStatus flow.BlockStatus // Status of blocks to subscribe to +} + +// BlocksDataProvider is responsible for providing blocks +type BlocksDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + api access.API +} + +var _ DataProvider = (*BlocksDataProvider)(nil) + +// NewBlocksDataProvider creates a new instance of BlocksDataProvider. +func NewBlocksDataProvider( + ctx context.Context, + logger zerolog.Logger, + api access.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, +) (*BlocksDataProvider, error) { + p := &BlocksDataProvider{ + logger: logger.With().Str("component", "blocks-data-provider").Logger(), + api: api, + } + + // Parse arguments passed to the provider. + blockArgs, err := ParseBlocksArguments(arguments) + if err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, blockArgs), // Set up a subscription to blocks based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for blocks and handles responses. +// +// No errors are expected during normal operations. +func (p *BlocksDataProvider) Run() error { + return subscription.HandleSubscription( + p.subscription, + subscription.HandleResponse(p.send, func(block *flow.Block) (interface{}, error) { + return &models.BlockMessageResponse{ + Block: block, + }, nil + }), + ) +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *BlocksDataProvider) createSubscription(ctx context.Context, args BlocksArguments) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.api.SubscribeBlocksFromStartBlockID(ctx, args.StartBlockID, args.BlockStatus) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.api.SubscribeBlocksFromStartHeight(ctx, args.StartBlockHeight, args.BlockStatus) + } + + return p.api.SubscribeBlocksFromLatest(ctx, args.BlockStatus) +} + +// ParseBlocksArguments validates and initializes the blocks arguments. +func ParseBlocksArguments(arguments models.Arguments) (BlocksArguments, error) { + var args BlocksArguments + + // Parse 'block_status' + if blockStatusIn, ok := arguments["block_status"]; ok { + blockStatus, err := parser.ParseBlockStatus(blockStatusIn) + if err != nil { + return args, err + } + args.BlockStatus = blockStatus + } else { + return args, fmt.Errorf("'block_status' must be provided") + } + + startBlockIDIn, hasStartBlockID := arguments["start_block_id"] + startBlockHeightIn, hasStartBlockHeight := arguments["start_block_height"] + + // Ensure only one of start_block_id or start_block_height is provided + if hasStartBlockID && hasStartBlockHeight { + return args, fmt.Errorf("can only provide either 'start_block_id' or 'start_block_height'") + } + + // Parse 'start_block_id' if provided + if hasStartBlockID { + var startBlockID parser.ID + err := startBlockID.Parse(startBlockIDIn) + if err != nil { + return args, err + } + args.StartBlockID = startBlockID.Flow() + } + + // Parse 'start_block_height' if provided + if hasStartBlockHeight { + var err error + args.StartBlockHeight, err = util.ToUint64(startBlockHeightIn) + if err != nil { + return args, fmt.Errorf("invalid 'start_block_height': %w", err) + } + } else { + args.StartBlockHeight = request.EmptyHeight + } + + return args, nil +} diff --git a/engine/access/rest/websockets/data_providers/blocks_provider_test.go b/engine/access/rest/websockets/data_providers/blocks_provider_test.go new file mode 100644 index 00000000000..9e07f9459e9 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/blocks_provider_test.go @@ -0,0 +1,273 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + accessmock "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + statestreamsmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +const unknownBlockStatus = "unknown_block_status" + +type testErrType struct { + name string + arguments models.Arguments + expectedErrorMsg string +} + +// testType represents a valid test scenario for subscribing +type testType struct { + name string + arguments models.Arguments + setupBackend func(sub *statestreamsmock.Subscription) +} + +// BlocksProviderSuite is a test suite for testing the block providers functionality. +type BlocksProviderSuite struct { + suite.Suite + + log zerolog.Logger + api *accessmock.API + + blocks []*flow.Block + rootBlock flow.Block + finalizedBlock *flow.Header + + factory *DataProviderFactoryImpl +} + +func TestBlocksProviderSuite(t *testing.T) { + suite.Run(t, new(BlocksProviderSuite)) +} + +func (s *BlocksProviderSuite) SetupTest() { + s.log = unittest.Logger() + s.api = accessmock.NewAPI(s.T()) + + blockCount := 5 + s.blocks = make([]*flow.Block, 0, blockCount) + + s.rootBlock = unittest.BlockFixture() + s.rootBlock.Header.Height = 0 + parent := s.rootBlock.Header + + for i := 0; i < blockCount; i++ { + block := unittest.BlockWithParentFixture(parent) + // update for next iteration + parent = block.Header + s.blocks = append(s.blocks, block) + + } + s.finalizedBlock = parent + + s.factory = NewDataProviderFactory(s.log, nil, s.api) + s.Require().NotNil(s.factory) +} + +// invalidArgumentsTestCases returns a list of test cases with invalid argument combinations +// for testing the behavior of block, block headers, block digests data providers. Each test case includes a name, +// a set of input arguments, and the expected error message that should be returned. +// +// The test cases cover scenarios such as: +// 1. Missing the required 'block_status' argument. +// 2. Providing an unknown or invalid 'block_status' value. +// 3. Supplying both 'start_block_id' and 'start_block_height' simultaneously, which is not allowed. +func (s *BlocksProviderSuite) invalidArgumentsTestCases() []testErrType { + return []testErrType{ + { + name: "missing 'block_status' argument", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + }, + expectedErrorMsg: "'block_status' must be provided", + }, + { + name: "unknown 'block_status' argument", + arguments: models.Arguments{ + "block_status": unknownBlockStatus, + }, + expectedErrorMsg: fmt.Sprintf("invalid 'block_status', must be '%s' or '%s'", parser.Finalized, parser.Sealed), + }, + { + name: "provide both 'start_block_id' and 'start_block_height' arguments", + arguments: models.Arguments{ + "block_status": parser.Finalized, + "start_block_id": s.rootBlock.ID().String(), + "start_block_height": fmt.Sprintf("%d", s.rootBlock.Header.Height), + }, + expectedErrorMsg: "can only provide either 'start_block_id' or 'start_block_height'", + }, + } +} + +// TestBlocksDataProvider_InvalidArguments tests the behavior of the block data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Missing 'block_status' argument. +// 2. Invalid 'block_status' argument. +// 3. Providing both 'start_block_id' and 'start_block_height' simultaneously. +func (s *BlocksProviderSuite) TestBlocksDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + for _, test := range s.invalidArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewBlocksDataProvider(ctx, s.log, s.api, BlocksTopic, test.arguments, send) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// validBlockArgumentsTestCases defines test happy cases for block data providers. +// Each test case specifies input arguments, and setup functions for the mock API used in the test. +func (s *BlocksProviderSuite) validBlockArgumentsTestCases() []testType { + return []testType{ + { + name: "happy path with start_block_id argument", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlocksFromStartBlockID", + mock.Anything, + s.rootBlock.ID(), + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + { + name: "happy path with start_block_height argument", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlocksFromStartHeight", + mock.Anything, + s.rootBlock.Header.Height, + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + { + name: "happy path without any start argument", + arguments: models.Arguments{ + "block_status": parser.Finalized, + }, + setupBackend: func(sub *statestreamsmock.Subscription) { + s.api.On( + "SubscribeBlocksFromLatest", + mock.Anything, + flow.BlockStatusFinalized, + ).Return(sub).Once() + }, + }, + } +} + +// TestBlocksDataProvider_HappyPath tests the behavior of the block data provider +// when it is configured correctly and operating under normal conditions. It +// validates that blocks are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *BlocksProviderSuite) TestBlocksDataProvider_HappyPath() { + s.testHappyPath( + BlocksTopic, + s.validBlockArgumentsTestCases(), + func(dataChan chan interface{}, blocks []*flow.Block) { + for _, block := range blocks { + dataChan <- block + } + }, + s.requireBlock, + ) +} + +// requireBlocks ensures that the received block information matches the expected data. +func (s *BlocksProviderSuite) requireBlock(v interface{}, expectedBlock *flow.Block) { + actualResponse, ok := v.(*models.BlockMessageResponse) + require.True(s.T(), ok, "unexpected response type: %T", v) + + s.Require().Equal(expectedBlock, actualResponse.Block) +} + +// testHappyPath tests a variety of scenarios for data providers in +// happy path scenarios. This function runs parameterized test cases that +// simulate various configurations and verifies that the data provider operates +// as expected without encountering errors. +// +// Arguments: +// - topic: The topic associated with the data provider. +// - tests: A slice of test cases to run, each specifying setup and validation logic. +// - sendData: A function to simulate emitting data into the subscription's data channel. +// - requireFn: A function to validate the output received in the send channel. +func (s *BlocksProviderSuite) testHappyPath( + topic string, + tests []testType, + sendData func(chan interface{}, []*flow.Block), + requireFn func(interface{}, *flow.Block), +) { + for _, test := range tests { + s.Run(test.name, func() { + ctx := context.Background() + send := make(chan interface{}, 10) + + // Create a channel to simulate the subscription's data channel + dataChan := make(chan interface{}) + + // Create a mock subscription and mock the channel + sub := statestreamsmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(dataChan)) + sub.On("Err").Return(nil) + test.setupBackend(sub) + + // Create the data provider instance + provider, err := s.factory.NewDataProvider(ctx, topic, test.arguments, send) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting data to the data channel + go func() { + defer close(dataChan) + sendData(dataChan, s.blocks) + }() + + // Collect responses + for _, b := range s.blocks { + unittest.RequireReturnsBefore(s.T(), func() { + v, ok := <-send + s.Require().True(ok, "channel closed while waiting for block %x %v: err: %v", b.Header.Height, b.ID(), sub.Err()) + + requireFn(v, b) + }, time.Second, fmt.Sprintf("timed out waiting for block %d %v", b.Header.Height, b.ID())) + } + + // Ensure the provider is properly closed after the test + provider.Close() + }) + } +} diff --git a/engine/access/rest/websockets/data_providers/data_provider.go b/engine/access/rest/websockets/data_providers/data_provider.go new file mode 100644 index 00000000000..08dc497808b --- /dev/null +++ b/engine/access/rest/websockets/data_providers/data_provider.go @@ -0,0 +1,33 @@ +package data_providers + +import ( + "github.com/google/uuid" +) + +// The DataProvider is the interface abstracts of the actual data provider used by the WebSocketCollector. +// It provides methods for retrieving the provider's unique ID, topic, and a methods to close and run the provider. +type DataProvider interface { + // ID returns the unique identifier of the data provider. + ID() uuid.UUID + // Topic returns the topic associated with the data provider. + Topic() string + // Close terminates the data provider. + // + // No errors are expected during normal operations. + Close() error + // Run starts processing the subscription and handles responses. + // + // The separation of the data provider's creation and its Run() method + // allows for better control over the subscription lifecycle. By doing so, + // a confirmation message can be sent to the client immediately upon + // successful subscription creation or failure. This ensures any required + // setup or preparation steps can be handled prior to initiating the + // subscription and data streaming process. + // + // Run() begins the actual processing of the subscription. At this point, + // the context used for provider creation is no longer needed, as all + // necessary preparation steps should have been completed. + // + // No errors are expected during normal operations. + Run() error +} diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go new file mode 100644 index 00000000000..72f4a6b7633 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -0,0 +1,98 @@ +package data_providers + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/access" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" +) + +// Constants defining various topic names used to specify different types of +// data providers. +const ( + EventsTopic = "events" + AccountStatusesTopic = "account_statuses" + BlocksTopic = "blocks" + BlockHeadersTopic = "block_headers" + BlockDigestsTopic = "block_digests" + TransactionStatusesTopic = "transaction_statuses" +) + +// DataProviderFactory defines an interface for creating data providers +// based on specified topics. The factory abstracts the creation process +// and ensures consistent access to required APIs. +type DataProviderFactory interface { + // NewDataProvider creates a new data provider based on the specified topic + // and configuration parameters. + // + // No errors are expected during normal operations. + NewDataProvider(ctx context.Context, topic string, arguments models.Arguments, ch chan<- interface{}) (DataProvider, error) +} + +var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) + +// DataProviderFactoryImpl is an implementation of the DataProviderFactory interface. +// It is responsible for creating data providers based on the +// requested topic. It manages access to logging and relevant APIs needed to retrieve data. +type DataProviderFactoryImpl struct { + logger zerolog.Logger + + stateStreamApi state_stream.API + accessApi access.API +} + +// NewDataProviderFactory creates a new DataProviderFactory +// +// Parameters: +// - logger: Used for logging within the data providers. +// - eventFilterConfig: Configuration for filtering events from state streams. +// - stateStreamApi: API for accessing data from the Flow state stream API. +// - accessApi: API for accessing data from the Flow Access API. +func NewDataProviderFactory( + logger zerolog.Logger, + stateStreamApi state_stream.API, + accessApi access.API, +) *DataProviderFactoryImpl { + return &DataProviderFactoryImpl{ + logger: logger, + stateStreamApi: stateStreamApi, + accessApi: accessApi, + } +} + +// NewDataProvider creates a new data provider based on the specified topic +// and configuration parameters. +// +// Parameters: +// - ctx: Context for managing request lifetime and cancellation. +// - topic: The topic for which a data provider is to be created. +// - arguments: Configuration arguments for the data provider. +// - ch: Channel to which the data provider sends data. +// +// No errors are expected during normal operations. +func (s *DataProviderFactoryImpl) NewDataProvider( + ctx context.Context, + topic string, + arguments models.Arguments, + ch chan<- interface{}, +) (DataProvider, error) { + switch topic { + case BlocksTopic: + return NewBlocksDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + case BlockHeadersTopic: + return NewBlockHeadersDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + case BlockDigestsTopic: + return NewBlockDigestsDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + // TODO: Implemented handlers for each topic should be added in respective case + case EventsTopic, + AccountStatusesTopic, + TransactionStatusesTopic: + return nil, fmt.Errorf(`topic "%s" not implemented yet`, topic) + default: + return nil, fmt.Errorf("unsupported topic \"%s\"", topic) + } +} diff --git a/engine/access/rest/websockets/data_providers/factory_test.go b/engine/access/rest/websockets/data_providers/factory_test.go new file mode 100644 index 00000000000..2ed2b075d0c --- /dev/null +++ b/engine/access/rest/websockets/data_providers/factory_test.go @@ -0,0 +1,136 @@ +package data_providers + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + accessmock "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + statestreammock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +// DataProviderFactorySuite is a test suite for testing the DataProviderFactory functionality. +type DataProviderFactorySuite struct { + suite.Suite + + ctx context.Context + ch chan interface{} + + accessApi *accessmock.API + stateStreamApi *statestreammock.API + + factory *DataProviderFactoryImpl +} + +func TestDataProviderFactorySuite(t *testing.T) { + suite.Run(t, new(DataProviderFactorySuite)) +} + +// SetupTest sets up the initial context and dependencies for each test case. +// It initializes the factory with mock instances and validates that it is created successfully. +func (s *DataProviderFactorySuite) SetupTest() { + log := unittest.Logger() + s.stateStreamApi = statestreammock.NewAPI(s.T()) + s.accessApi = accessmock.NewAPI(s.T()) + + s.ctx = context.Background() + s.ch = make(chan interface{}) + + s.factory = NewDataProviderFactory(log, s.stateStreamApi, s.accessApi) + s.Require().NotNil(s.factory) +} + +// setupSubscription creates a mock subscription instance for testing purposes. +// It configures the return value of the specified API call to the mock subscription. +func (s *DataProviderFactorySuite) setupSubscription(apiCall *mock.Call) { + subscription := statestreammock.NewSubscription(s.T()) + apiCall.Return(subscription).Once() +} + +// TODO: add others topic to check when they will be implemented +// TestSupportedTopics verifies that supported topics return a valid provider and no errors. +// Each test case includes a topic and arguments for which a data provider should be created. +func (s *DataProviderFactorySuite) TestSupportedTopics() { + // Define supported topics and check if each returns the correct provider without errors + testCases := []struct { + name string + topic string + arguments models.Arguments + setupSubscription func() + assertExpectations func() + }{ + { + name: "block topic", + topic: BlocksTopic, + arguments: models.Arguments{"block_status": parser.Finalized}, + setupSubscription: func() { + s.setupSubscription(s.accessApi.On("SubscribeBlocksFromLatest", mock.Anything, flow.BlockStatusFinalized)) + }, + assertExpectations: func() { + s.accessApi.AssertExpectations(s.T()) + }, + }, + { + name: "block headers topic", + topic: BlockHeadersTopic, + arguments: models.Arguments{"block_status": parser.Finalized}, + setupSubscription: func() { + s.setupSubscription(s.accessApi.On("SubscribeBlockHeadersFromLatest", mock.Anything, flow.BlockStatusFinalized)) + }, + assertExpectations: func() { + s.accessApi.AssertExpectations(s.T()) + }, + }, + { + name: "block digests topic", + topic: BlockDigestsTopic, + arguments: models.Arguments{"block_status": parser.Finalized}, + setupSubscription: func() { + s.setupSubscription(s.accessApi.On("SubscribeBlockDigestsFromLatest", mock.Anything, flow.BlockStatusFinalized)) + }, + assertExpectations: func() { + s.accessApi.AssertExpectations(s.T()) + }, + }, + } + + for _, test := range testCases { + s.Run(test.name, func() { + s.T().Parallel() + test.setupSubscription() + + provider, err := s.factory.NewDataProvider(s.ctx, test.topic, test.arguments, s.ch) + s.Require().NotNil(provider, "Expected provider for topic %s", test.topic) + s.Require().NoError(err, "Expected no error for topic %s", test.topic) + s.Require().Equal(test.topic, provider.Topic()) + + test.assertExpectations() + }) + } +} + +// TestUnsupportedTopics verifies that unsupported topics do not return a provider +// and instead return an error indicating the topic is unsupported. +func (s *DataProviderFactorySuite) TestUnsupportedTopics() { + s.T().Parallel() + + // Define unsupported topics + unsupportedTopics := []string{ + "unknown_topic", + "", + } + + for _, topic := range unsupportedTopics { + provider, err := s.factory.NewDataProvider(s.ctx, topic, nil, s.ch) + s.Require().Nil(provider, "Expected no provider for unsupported topic %s", topic) + s.Require().Error(err, "Expected error for unsupported topic %s", topic) + s.Require().EqualError(err, fmt.Sprintf("unsupported topic \"%s\"", topic)) + } +} diff --git a/engine/access/rest/websockets/data_provider/mock/data_provider.go b/engine/access/rest/websockets/data_providers/mock/data_provider.go similarity index 72% rename from engine/access/rest/websockets/data_provider/mock/data_provider.go rename to engine/access/rest/websockets/data_providers/mock/data_provider.go index 4a2a22a44a0..3fe8bc5d15b 100644 --- a/engine/access/rest/websockets/data_provider/mock/data_provider.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider.go @@ -3,11 +3,8 @@ package mock import ( - context "context" - - mock "github.com/stretchr/testify/mock" - uuid "github.com/google/uuid" + mock "github.com/stretchr/testify/mock" ) // DataProvider is an autogenerated mock type for the DataProvider type @@ -16,8 +13,21 @@ type DataProvider struct { } // Close provides a mock function with given fields: -func (_m *DataProvider) Close() { - _m.Called() +func (_m *DataProvider) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 } // ID provides a mock function with given fields: @@ -40,9 +50,22 @@ func (_m *DataProvider) ID() uuid.UUID { return r0 } -// Run provides a mock function with given fields: ctx -func (_m *DataProvider) Run(ctx context.Context) { - _m.Called(ctx) +// Run provides a mock function with given fields: +func (_m *DataProvider) Run() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Run") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 } // Topic provides a mock function with given fields: diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go new file mode 100644 index 00000000000..c2e46e58d1d --- /dev/null +++ b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go @@ -0,0 +1,61 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + context "context" + + data_providers "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" + mock "github.com/stretchr/testify/mock" + + models "github.com/onflow/flow-go/engine/access/rest/websockets/models" +) + +// DataProviderFactory is an autogenerated mock type for the DataProviderFactory type +type DataProviderFactory struct { + mock.Mock +} + +// NewDataProvider provides a mock function with given fields: ctx, topic, arguments, ch +func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string, arguments models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { + ret := _m.Called(ctx, topic, arguments, ch) + + if len(ret) == 0 { + panic("no return value specified for NewDataProvider") + } + + var r0 data_providers.DataProvider + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { + return rf(ctx, topic, arguments, ch) + } + if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { + r0 = rf(ctx, topic, arguments, ch) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(data_providers.DataProvider) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, models.Arguments, chan<- interface{}) error); ok { + r1 = rf(ctx, topic, arguments, ch) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewDataProviderFactory creates a new instance of DataProviderFactory. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDataProviderFactory(t interface { + mock.TestingT + Cleanup(func()) +}) *DataProviderFactory { + mock := &DataProviderFactory{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index 247890c2a62..c93548d5f9e 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -8,18 +8,16 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/common" - "github.com/onflow/flow-go/engine/access/state_stream" - "github.com/onflow/flow-go/engine/access/state_stream/backend" + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" "github.com/onflow/flow-go/model/flow" ) type Handler struct { *common.HttpHandler - logger zerolog.Logger - websocketConfig Config - streamApi state_stream.API - streamConfig backend.Config + logger zerolog.Logger + websocketConfig Config + dataProviderFactory dp.DataProviderFactory } var _ http.Handler = (*Handler)(nil) @@ -28,16 +26,14 @@ func NewWebSocketHandler( logger zerolog.Logger, config Config, chain flow.Chain, - streamApi state_stream.API, - streamConfig backend.Config, maxRequestSize int64, + dataProviderFactory dp.DataProviderFactory, ) *Handler { return &Handler{ - HttpHandler: common.NewHttpHandler(logger, chain, maxRequestSize), - websocketConfig: config, - logger: logger, - streamApi: streamApi, - streamConfig: streamConfig, + HttpHandler: common.NewHttpHandler(logger, chain, maxRequestSize), + websocketConfig: config, + logger: logger, + dataProviderFactory: dataProviderFactory, } } @@ -65,6 +61,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, conn) + controller := NewWebSocketController(logger, h.websocketConfig, conn, h.dataProviderFactory) controller.HandleConnection(context.TODO()) } diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go deleted file mode 100644 index 6b9cce06572..00000000000 --- a/engine/access/rest/websockets/handler_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package websockets_test - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/gorilla/websocket" - "github.com/rs/zerolog" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - - "github.com/onflow/flow-go/engine/access/rest/websockets" - "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream/backend" - streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/onflow/flow-go/model/flow" - "github.com/onflow/flow-go/utils/unittest" -) - -var ( - chainID = flow.Testnet -) - -type WsHandlerSuite struct { - suite.Suite - - logger zerolog.Logger - handler *websockets.Handler - wsConfig websockets.Config - streamApi *streammock.API - streamConfig backend.Config -} - -func (s *WsHandlerSuite) SetupTest() { - s.logger = unittest.Logger() - s.wsConfig = websockets.NewDefaultWebsocketConfig() - s.streamApi = streammock.NewAPI(s.T()) - s.streamConfig = backend.Config{} - s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, s.streamConfig, 1024) -} - -func TestWsHandlerSuite(t *testing.T) { - suite.Run(t, new(WsHandlerSuite)) -} - -func ClientConnection(url string) (*websocket.Conn, *http.Response, error) { - wsURL := "ws" + strings.TrimPrefix(url, "http") - return websocket.DefaultDialer.Dial(wsURL, nil) -} - -func (s *WsHandlerSuite) TestSubscribeRequest() { - s.Run("Happy path", func() { - server := httptest.NewServer(s.handler) - defer server.Close() - - conn, _, err := ClientConnection(server.URL) - defer func(conn *websocket.Conn) { - err := conn.Close() - require.NoError(s.T(), err) - }(conn) - require.NoError(s.T(), err) - - args := map[string]interface{}{ - "start_block_height": 10, - } - body := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: "blocks", - Arguments: args, - } - bodyJSON, err := json.Marshal(body) - require.NoError(s.T(), err) - - err = conn.WriteMessage(websocket.TextMessage, bodyJSON) - require.NoError(s.T(), err) - - _, msg, err := conn.ReadMessage() - require.NoError(s.T(), err) - - actualMsg := strings.Trim(string(msg), "\n\"\\ ") - require.Equal(s.T(), "block{height: 42}", actualMsg) - }) -} diff --git a/engine/access/rest/websockets/legacy/request/subscribe_events.go b/engine/access/rest/websockets/legacy/request/subscribe_events.go index 5b2574ccc82..1110d3582d4 100644 --- a/engine/access/rest/websockets/legacy/request/subscribe_events.go +++ b/engine/access/rest/websockets/legacy/request/subscribe_events.go @@ -5,6 +5,7 @@ import ( "strconv" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/engine/access/rest/http/request" "github.com/onflow/flow-go/model/flow" ) @@ -56,7 +57,7 @@ func (g *SubscribeEvents) Parse( rawContracts []string, rawHeartbeatInterval string, ) error { - var startBlockID request.ID + var startBlockID parser.ID err := startBlockID.Parse(rawStartBlockID) if err != nil { return err diff --git a/engine/access/rest/websockets/models/block_models.go b/engine/access/rest/websockets/models/block_models.go new file mode 100644 index 00000000000..fa7af987236 --- /dev/null +++ b/engine/access/rest/websockets/models/block_models.go @@ -0,0 +1,26 @@ +package models + +import ( + "github.com/onflow/flow-go/model/flow" +) + +// BlockMessageResponse is the response message for 'blocks' topic. +type BlockMessageResponse struct { + // The sealed or finalized blocks according to the block status + // in the request. + Block *flow.Block `json:"block"` +} + +// BlockHeaderMessageResponse is the response message for 'block_headers' topic. +type BlockHeaderMessageResponse struct { + // The sealed or finalized block headers according to the block status + // in the request. + Header *flow.Header `json:"header"` +} + +// BlockDigestMessageResponse is the response message for 'block_digests' topic. +type BlockDigestMessageResponse struct { + // The sealed or finalized block digest according to the block status + // in the request. + Block *flow.BlockDigest `json:"block_digest"` +} diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe.go index 993bd63b811..95ad17e3708 100644 --- a/engine/access/rest/websockets/models/subscribe.go +++ b/engine/access/rest/websockets/models/subscribe.go @@ -1,10 +1,12 @@ package models +type Arguments map[string]string + // SubscribeMessageRequest represents a request to subscribe to a topic. type SubscribeMessageRequest struct { BaseMessageRequest - Topic string `json:"topic"` // Topic to subscribe to - Arguments map[string]interface{} `json:"arguments"` // Additional arguments for subscription + Topic string `json:"topic"` // Topic to subscribe to + Arguments Arguments `json:"arguments"` // Additional arguments for subscription } // SubscribeMessageResponse represents the response to a subscription request. diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index 651adb41a63..ca14d4fca72 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -22,7 +22,7 @@ import ( accessmock "github.com/onflow/flow-go/engine/access/mock" "github.com/onflow/flow-go/engine/access/rest" "github.com/onflow/flow-go/engine/access/rest/common" - "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/engine/access/rest/router" "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" @@ -232,8 +232,8 @@ func TestRestAPI(t *testing.T) { func (suite *RestAPITestSuite) TestGetBlock() { - testBlockIDs := make([]string, request.MaxIDsLength) - testBlocks := make([]*flow.Block, request.MaxIDsLength) + testBlockIDs := make([]string, parser.MaxIDsLength) + testBlocks := make([]*flow.Block, parser.MaxIDsLength) for i := range testBlockIDs { collections := unittest.CollectionListFixture(1) block := unittest.BlockWithGuaranteesFixture( @@ -283,7 +283,7 @@ func (suite *RestAPITestSuite) TestGetBlock() { actualBlocks, resp, err := client.BlocksApi.BlocksIdGet(ctx, blockIDSlice, optionsForBlockByID()) require.NoError(suite.T(), err) assert.Equal(suite.T(), http.StatusOK, resp.StatusCode) - assert.Len(suite.T(), actualBlocks, request.MaxIDsLength) + assert.Len(suite.T(), actualBlocks, parser.MaxIDsLength) for i, b := range testBlocks { assert.Equal(suite.T(), b.ID().String(), actualBlocks[i].Header.Id) } @@ -381,13 +381,13 @@ func (suite *RestAPITestSuite) TestGetBlock() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - blockIDs := make([]string, request.MaxIDsLength+1) + blockIDs := make([]string, parser.MaxIDsLength+1) copy(blockIDs, testBlockIDs) - blockIDs[request.MaxIDsLength] = unittest.IdentifierFixture().String() + blockIDs[parser.MaxIDsLength] = unittest.IdentifierFixture().String() blockIDSlice := []string{strings.Join(blockIDs, ",")} _, resp, err := client.BlocksApi.BlocksIdGet(ctx, blockIDSlice, optionsForBlockByID()) - assertError(suite.T(), resp, err, http.StatusBadRequest, fmt.Sprintf("at most %d IDs can be requested at a time", request.MaxIDsLength)) + assertError(suite.T(), resp, err, http.StatusBadRequest, fmt.Sprintf("at most %d IDs can be requested at a time", parser.MaxIDsLength)) }) suite.Run("GetBlockByID with one non-existing block ID", func() { diff --git a/engine/access/state_stream/backend/handler.go b/engine/access/state_stream/backend/handler.go index b2066440bb8..3acf1bad6ca 100644 --- a/engine/access/state_stream/backend/handler.go +++ b/engine/access/state_stream/backend/handler.go @@ -102,7 +102,7 @@ func (h *Handler) SubscribeExecutionData(request *executiondata.SubscribeExecuti sub := h.api.SubscribeExecutionData(stream.Context(), startBlockID, request.GetStartBlockHeight()) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeExecutionDataFromStartBlockID handles subscription requests for @@ -129,7 +129,7 @@ func (h *Handler) SubscribeExecutionDataFromStartBlockID(request *executiondata. sub := h.api.SubscribeExecutionDataFromStartBlockID(stream.Context(), startBlockID) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeExecutionDataFromStartBlockHeight handles subscription requests for @@ -150,7 +150,7 @@ func (h *Handler) SubscribeExecutionDataFromStartBlockHeight(request *executiond sub := h.api.SubscribeExecutionDataFromStartBlockHeight(stream.Context(), request.GetStartBlockHeight()) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeExecutionDataFromLatest handles subscription requests for @@ -171,7 +171,7 @@ func (h *Handler) SubscribeExecutionDataFromLatest(request *executiondata.Subscr sub := h.api.SubscribeExecutionDataFromLatest(stream.Context()) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeEvents is deprecated and will be removed in a future version. @@ -213,7 +213,7 @@ func (h *Handler) SubscribeEvents(request *executiondata.SubscribeEventsRequest, sub := h.api.SubscribeEvents(stream.Context(), startBlockID, request.GetStartBlockHeight(), filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // SubscribeEventsFromStartBlockID handles subscription requests for events starting at the specified block ID. @@ -248,7 +248,7 @@ func (h *Handler) SubscribeEventsFromStartBlockID(request *executiondata.Subscri sub := h.api.SubscribeEventsFromStartBlockID(stream.Context(), startBlockID, filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // SubscribeEventsFromStartHeight handles subscription requests for events starting at the specified block height. @@ -278,7 +278,7 @@ func (h *Handler) SubscribeEventsFromStartHeight(request *executiondata.Subscrib sub := h.api.SubscribeEventsFromStartHeight(stream.Context(), request.GetStartBlockHeight(), filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // SubscribeEventsFromLatest handles subscription requests for events started from latest sealed block.. @@ -308,7 +308,7 @@ func (h *Handler) SubscribeEventsFromLatest(request *executiondata.SubscribeEven sub := h.api.SubscribeEventsFromLatest(stream.Context(), filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // handleSubscribeExecutionData handles the subscription to execution data and sends it to the client via the provided stream. @@ -546,7 +546,7 @@ func (h *Handler) SubscribeAccountStatusesFromStartBlockID( sub := h.api.SubscribeAccountStatusesFromStartBlockID(stream.Context(), startBlockID, filter) - return subscription.HandleSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) } // SubscribeAccountStatusesFromStartHeight streams account statuses for all blocks starting at the requested @@ -573,7 +573,7 @@ func (h *Handler) SubscribeAccountStatusesFromStartHeight( sub := h.api.SubscribeAccountStatusesFromStartHeight(stream.Context(), request.GetStartBlockHeight(), filter) - return subscription.HandleSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) } // SubscribeAccountStatusesFromLatestBlock streams account statuses for all blocks starting @@ -600,5 +600,5 @@ func (h *Handler) SubscribeAccountStatusesFromLatestBlock( sub := h.api.SubscribeAccountStatusesFromLatestBlock(stream.Context(), filter) - return subscription.HandleSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) } diff --git a/engine/access/subscription/util.go b/engine/access/subscription/util.go index 593f3d78499..9ef98044bb8 100644 --- a/engine/access/subscription/util.go +++ b/engine/access/subscription/util.go @@ -1,8 +1,9 @@ package subscription import ( + "fmt" + "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/onflow/flow-go/engine/common/rpc" ) @@ -14,21 +15,20 @@ import ( // - sub: The subscription. // - handleResponse: The function responsible for handling the response of the subscribed type. // -// Expected errors during normal operation: -// - codes.Internal: If the subscription encounters an error or gets an unexpected response. +// No errors are expected during normal operations. func HandleSubscription[T any](sub Subscription, handleResponse func(resp T) error) error { for { v, ok := <-sub.Channel() if !ok { if sub.Err() != nil { - return rpc.ConvertError(sub.Err(), "stream encountered an error", codes.Internal) + return fmt.Errorf("stream encountered an error: %w", sub.Err()) } return nil } resp, ok := v.(T) if !ok { - return status.Errorf(codes.Internal, "unexpected response type: %T", v) + return fmt.Errorf("unexpected response type: %T", v) } err := handleResponse(resp) @@ -37,3 +37,42 @@ func HandleSubscription[T any](sub Subscription, handleResponse func(resp T) err } } } + +// HandleRPCSubscription is a generic handler for subscriptions to a specific type for rpc calls. +// +// Parameters: +// - sub: The subscription. +// - handleResponse: The function responsible for handling the response of the subscribed type. +// +// Expected errors during normal operation: +// - codes.Internal: If the subscription encounters an error or gets an unexpected response. +func HandleRPCSubscription[T any](sub Subscription, handleResponse func(resp T) error) error { + err := HandleSubscription(sub, handleResponse) + if err != nil { + return rpc.ConvertError(err, "handle subscription error", codes.Internal) + } + + return nil +} + +// HandleResponse processes a generic response of type and sends it to the provided channel. +// +// Parameters: +// - send: The channel to which the processed response is sent. +// - transform: A function to transform the response into the expected interface{} type. +// +// No errors are expected during normal operations. +func HandleResponse[T any](send chan<- interface{}, transform func(resp T) (interface{}, error)) func(resp T) error { + return func(response T) error { + // Transform the response + resp, err := transform(response) + if err != nil { + return fmt.Errorf("failed to transform response: %w", err) + } + + // send to the channel + send <- resp + + return nil + } +}