diff --git a/access/grpc/client.go b/access/grpc/client.go index ed31ed45e..a7aa3d7ba 100644 --- a/access/grpc/client.go +++ b/access/grpc/client.go @@ -165,6 +165,14 @@ func (c *Client) GetCollection(ctx context.Context, colID flow.Identifier) (*flo return c.grpc.GetCollection(ctx, colID) } +func (c *Client) GetCollectionByID(ctx context.Context, id flow.Identifier) (*flow.Collection, error) { + return c.grpc.GetLightCollectionByID(ctx, id) +} + +func (c *Client) GetFullCollectionByID(ctx context.Context, id flow.Identifier) (*flow.FullCollection, error) { + return c.grpc.GetFullCollectionByID(ctx, id) +} + func (c *Client) SendTransaction(ctx context.Context, tx flow.Transaction) error { return c.grpc.SendTransaction(ctx, tx) } diff --git a/access/grpc/convert/convert.go b/access/grpc/convert/convert.go index 1415db5f3..5238a4aab 100644 --- a/access/grpc/convert/convert.go +++ b/access/grpc/convert/convert.go @@ -267,6 +267,21 @@ func CollectionToMessage(c flow.Collection) *entities.Collection { } } +func FullCollectionToTransactionsMessage(tx flow.FullCollection) ([]*entities.Transaction, error) { + var convertedTxs []*entities.Transaction + + for _, tx := range tx.Transactions { + convertedTx, err := TransactionToMessage(*tx) + if err != nil { + return nil, err + } + + convertedTxs = append(convertedTxs, convertedTx) + } + + return convertedTxs, nil +} + func MessageToCollection(m *entities.Collection) (flow.Collection, error) { if m == nil { return flow.Collection{}, ErrEmptyMessage @@ -284,6 +299,21 @@ func MessageToCollection(m *entities.Collection) (flow.Collection, error) { }, nil } +func MessageToFullCollection(m []*entities.Transaction) (flow.FullCollection, error) { + var collection flow.FullCollection + + for _, tx := range m { + convertedTx, err := MessageToTransaction(tx) + if err != nil { + return flow.FullCollection{}, err + } + + collection.Transactions = append(collection.Transactions, &convertedTx) + } + + return collection, nil +} + func CollectionGuaranteeToMessage(g flow.CollectionGuarantee) *entities.CollectionGuarantee { return &entities.CollectionGuarantee{ CollectionId: g.CollectionID.Bytes(), diff --git a/access/grpc/convert/convert_test.go b/access/grpc/convert/convert_test.go index 480dc4f56..2c255a301 100644 --- a/access/grpc/convert/convert_test.go +++ b/access/grpc/convert/convert_test.go @@ -139,7 +139,7 @@ func TestConvert_CadenceValue(t *testing.T) { } func TestConvert_Collection(t *testing.T) { - colA := test.CollectionGenerator().New() + colA := test.LightCollectionGenerator().New() msg := CollectionToMessage(*colA) diff --git a/access/grpc/grpc.go b/access/grpc/grpc.go index 853a16e78..ca6aff34f 100644 --- a/access/grpc/grpc.go +++ b/access/grpc/grpc.go @@ -318,6 +318,50 @@ func (c *BaseClient) GetCollection( return &result, nil } +func (c *BaseClient) GetLightCollectionByID( + ctx context.Context, + id flow.Identifier, + opts ...grpc.CallOption, +) (*flow.Collection, error) { + req := &access.GetCollectionByIDRequest{ + Id: id.Bytes(), + } + + res, err := c.rpcClient.GetCollectionByID(ctx, req, opts...) + if err != nil { + return nil, newRPCError(err) + } + + result, err := convert.MessageToCollection(res.GetCollection()) + if err != nil { + return nil, newMessageToEntityError(entityCollection, err) + } + + return &result, nil +} + +func (c *BaseClient) GetFullCollectionByID( + ctx context.Context, + id flow.Identifier, + opts ...grpc.CallOption, +) (*flow.FullCollection, error) { + req := &access.GetFullCollectionByIDRequest{ + Id: id.Bytes(), + } + + res, err := c.rpcClient.GetFullCollectionByID(ctx, req, opts...) + if err != nil { + return nil, newRPCError(err) + } + + result, err := convert.MessageToFullCollection(res.GetTransactions()) + if err != nil { + return nil, newMessageToEntityError(entityCollection, err) + } + + return &result, nil +} + func (c *BaseClient) SendTransaction( ctx context.Context, tx flow.Transaction, diff --git a/access/grpc/grpc_test.go b/access/grpc/grpc_test.go index 34a7f4a51..3ea6fb980 100644 --- a/access/grpc/grpc_test.go +++ b/access/grpc/grpc_test.go @@ -386,7 +386,7 @@ func TestClient_GetBlockByHeight(t *testing.T) { } func TestClient_GetCollection(t *testing.T) { - cols := test.CollectionGenerator() + cols := test.LightCollectionGenerator() ids := test.IdentifierGenerator() t.Run("Success", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { @@ -417,6 +417,44 @@ func TestClient_GetCollection(t *testing.T) { })) } +func TestClient_GetFullCollectionById(t *testing.T) { + collections := test.FullCollectionGenerator() + ids := test.IdentifierGenerator() + + t.Run("Success", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + expectedCollection := collections.New() + txs, err := convert.FullCollectionToTransactionsMessage(*expectedCollection) + require.NoError(t, err) + + response := &access.FullCollectionResponse{ + Transactions: txs, + } + + rpc. + On("GetFullCollectionByID", ctx, mock.Anything). + Return(response, nil) + + id := ids.New() + actualCollection, err := c.GetFullCollectionByID(ctx, id) + require.NoError(t, err) + + require.Equal(t, expectedCollection, actualCollection) + + })) + + t.Run("Not found error", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + rpc. + On("GetFullCollectionByID", ctx, mock.Anything). + Return(nil, errNotFound) + + id := ids.New() + col, err := c.GetFullCollectionByID(ctx, id) + assert.Error(t, err) + assert.Equal(t, codes.NotFound, status.Code(err)) + assert.Nil(t, col) + })) +} + func TestClient_SendTransaction(t *testing.T) { transactions := test.TransactionGenerator() diff --git a/access/http/internal/unittest/fixtures.go b/access/http/internal/unittest/fixtures.go index 81881c390..2c5695202 100644 --- a/access/http/internal/unittest/fixtures.go +++ b/access/http/internal/unittest/fixtures.go @@ -99,7 +99,7 @@ func BlockFlowFixture() models.Block { } func CollectionFlowFixture() models.Collection { - collection := test.CollectionGenerator().New() + collection := test.LightCollectionGenerator().New() return models.Collection{ Id: collection.ID().String(), diff --git a/collection.go b/collection.go index bc4e02645..3de5fff16 100644 --- a/collection.go +++ b/collection.go @@ -47,3 +47,20 @@ func (c Collection) Encode() []byte { type CollectionGuarantee struct { CollectionID Identifier } + +type FullCollection struct { + Transactions []*Transaction +} + +// Light returns the light, reference-only version of the collection. +func (c FullCollection) Light() Collection { + lc := Collection{TransactionIDs: make([]Identifier, 0, len(c.Transactions))} + for _, tx := range c.Transactions { + lc.TransactionIDs = append(lc.TransactionIDs, tx.ID()) + } + return lc +} + +func (c FullCollection) ID() Identifier { + return c.Light().ID() +} diff --git a/test/entities.go b/test/entities.go index 0f00be42b..689078d12 100644 --- a/test/entities.go +++ b/test/entities.go @@ -188,17 +188,17 @@ func (g *BlockHeaders) New() flow.BlockHeader { } } -type Collections struct { +type LightCollection struct { ids *Identifiers } -func CollectionGenerator() *Collections { - return &Collections{ +func LightCollectionGenerator() *LightCollection { + return &LightCollection{ ids: IdentifierGenerator(), } } -func (g *Collections) New() *flow.Collection { +func (g *LightCollection) New() *flow.Collection { return &flow.Collection{ TransactionIDs: []flow.Identifier{ g.ids.New(), @@ -207,6 +207,22 @@ func (g *Collections) New() *flow.Collection { } } +type FullCollection struct { + Transactions *Transactions +} + +func FullCollectionGenerator() *FullCollection { + return &FullCollection{ + Transactions: TransactionGenerator(), + } +} + +func (c *FullCollection) New() *flow.FullCollection { + return &flow.FullCollection{ + Transactions: []*flow.Transaction{c.Transactions.New(), c.Transactions.New()}, + } +} + type CollectionGuarantees struct { ids *Identifiers }