Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add failure test for kinesis-consumer #69

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions pkg/providers/kinesis/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@ package consumer

import (
"context"
"github.com/aws/aws-sdk-go/aws/request"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
"github.com/doublecloud/transfer/internal/logger"
"github.com/doublecloud/transfer/library/go/core/xerrors"
"github.com/doublecloud/transfer/library/go/slices"
"go.ytsaurus.tech/library/go/core/log"
)

// KinesisReader is a lightweight interface that narrow down usage to just what really needed by this code
type KinesisReader interface {
ListShards(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error)
GetRecords(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
GetShardIteratorWithContext(aws.Context, *kinesis.GetShardIteratorInput, ...request.Option) (*kinesis.GetShardIteratorOutput, error)
}

// Record wraps the record returned from the Kinesis library and
// extends to include the shard id.
type Record struct {
Expand Down Expand Up @@ -65,7 +72,7 @@ type Consumer struct {
streamName string
initialShardIteratorType string
initialTimestamp *time.Time
client kinesisiface.KinesisAPI
client KinesisReader
group Group
logger log.Logger
store Store
Expand Down
7 changes: 3 additions & 4 deletions pkg/providers/kinesis/consumer/group_all.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
"github.com/doublecloud/transfer/library/go/core/xerrors"
"go.ytsaurus.tech/library/go/core/log"
)

// NewAllGroup returns an intitialized AllGroup for consuming
// all shards on a stream
func NewAllGroup(ksis kinesisiface.KinesisAPI, store Store, streamName string, logger log.Logger) *AllGroup {
func NewAllGroup(ksis KinesisReader, store Store, streamName string, logger log.Logger) *AllGroup {
return &AllGroup{
Store: store,
ksis: ksis,
Expand All @@ -31,7 +30,7 @@ func NewAllGroup(ksis kinesisiface.KinesisAPI, store Store, streamName string, l
type AllGroup struct {
Store

ksis kinesisiface.KinesisAPI
ksis KinesisReader
streamName string
logger log.Logger

Expand Down Expand Up @@ -88,7 +87,7 @@ func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) {
}

// listShards pulls a list of shard IDs from the kinesis api
func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) {
func listShards(ksis KinesisReader, streamName string) ([]*kinesis.Shard, error) {
var ss []*kinesis.Shard
var listShardsInput = &kinesis.ListShardsInput{
StreamName: aws.String(streamName),
Expand Down
4 changes: 1 addition & 3 deletions pkg/providers/kinesis/consumer/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package consumer

import (
"time"

"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)

// Option is used to override defaults when creating a new Consumer
Expand All @@ -24,7 +22,7 @@ func WithStore(store Store) Option {
}

// WithClient overrides the default client
func WithClient(client kinesisiface.KinesisAPI) Option {
func WithClient(client KinesisReader) Option {
return func(c *Consumer) {
c.client = client
}
Expand Down
85 changes: 85 additions & 0 deletions pkg/providers/kinesis/source_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package kinesis

import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/doublecloud/transfer/internal/logger"
"github.com/doublecloud/transfer/library/go/core/metrics/solomon"
"github.com/doublecloud/transfer/library/go/core/xerrors"
"github.com/doublecloud/transfer/pkg/abstract"
"github.com/doublecloud/transfer/pkg/abstract/coordinator"
"github.com/doublecloud/transfer/pkg/parsequeue"
"github.com/doublecloud/transfer/pkg/providers/kinesis/consumer"
"github.com/doublecloud/transfer/pkg/stats"
"github.com/stretchr/testify/require"
"testing"
"time"
)

type fakeClient struct {
cntr int
}

func (f *fakeClient) ListShards(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{Shards: []*kinesis.Shard{
{ShardId: aws.String("s-1")},
{ShardId: aws.String("s-2")},
{ShardId: aws.String("s-3")},
}}, nil
}

func (f *fakeClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
f.cntr++
if f.cntr < 3 {
return &kinesis.GetRecordsOutput{
Records: []*kinesis.Record{
{
ApproximateArrivalTimestamp: aws.Time(time.Now()),
Data: []byte("test"),
EncryptionType: nil,
PartitionKey: nil,
SequenceNumber: aws.String(fmt.Sprintf("s1-%v", f.cntr)),
},
},
NextShardIterator: aws.String("next-1"),
}, nil
}
return nil, awserr.New("non-retryable-code", "asd", xerrors.New("demo error"))
}

func (f *fakeClient) GetShardIteratorWithContext(a aws.Context, input *kinesis.GetShardIteratorInput, option ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("s1"),
}, nil
}

type mockSync struct {
}

func (m mockSync) Close() error {
return nil
}

func (m mockSync) AsyncPush(items []abstract.ChangeItem) chan error {
resCh := make(chan error)
return resCh
}

func TestFailure(t *testing.T) {
var err error
s := new(Source)
s.cp = coordinator.NewFakeClient()
s.logger = logger.Log
s.ctx = context.Background()
s.config = new(KinesisSource)
s.config.WithDefaults()
s.metrics = stats.NewSourceStats(solomon.NewRegistry(solomon.NewRegistryOpts()))
s.consumer, err = consumer.New("abc", consumer.WithClient(&fakeClient{}))
require.NoError(t, err)
parseQ := parsequeue.NewWaitable(s.logger, 10, &mockSync{}, s.parse, s.ack)
require.Error(t, s.run(parseQ))
}
Loading