Skip to content

Commit

Permalink
Merge pull request #1 from dc0d/wip/after-batch
Browse files Browse the repository at this point in the history
add after batch func
  • Loading branch information
dc0d authored Aug 24, 2021
2 parents 606cb16 + 99e0523 commit bae2866
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 47 deletions.
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,19 @@
[![PkgGoDev](https://pkg.go.dev/badge/dc0d/sqstransport)](https://pkg.go.dev/github.com/dc0d/sqstransport)

# sqstransport
go-kit transport for sqs

This package contains a go-kit transport implementation for AWS SQS.

```go
sub := &Subscriber{
InputFactory: ..., // create a *sqs.ReceiveMessageInput instance,
DecodeRequest: ..., // decode the incoming message into an endpoint request object,
Handler: func(ctx context.Context, request interface{}) (response interface{}, err error) {
// handle the request,
},
ResponseHandler: ..., // handle the response,
}

go func() { _ = sub.Serve(client) }()
```

11 changes: 11 additions & 0 deletions subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/go-kit/kit/transport"
)

// Subscriber is a go-kit sqs transport.
type Subscriber struct {
// Before is optional. Can be used for starting a keep-in-flight hearbeat - an example.
// They run before DecodeRequest and can put additional data inside the context.
Expand All @@ -27,6 +28,9 @@ type Subscriber struct {
// Like deleting the message after being successfully processed.
ResponseHandler ResponseHandlerFunc

// AfterBatch is optional. It is called after a batch of messages passed to the Runner.
AfterBatch AfterBatchFunc

// InputFactory is required.
// It must return a non-nil params.
// It can return nil for optFns.
Expand All @@ -36,6 +40,8 @@ type Subscriber struct {
BaseContext context.Context

// Runner if not provided, the default runner will be used.
// All the Befor functions, decoding the message, handling the message
// and handling the response are executed by the Runner.
Runner Runner

// ErrorHandler is optional.
Expand Down Expand Up @@ -81,6 +87,10 @@ func (obj *Subscriber) Serve(l Client) error {
for _, msg := range output.Messages {
obj.runHandler(ctx, msg)
}

if obj.AfterBatch != nil {
obj.AfterBatch(ctx)
}
}
}

Expand Down Expand Up @@ -162,6 +172,7 @@ type (
RequestFunc func(context.Context, types.Message) context.Context
DecodeRequestFunc func(context.Context, types.Message) (request interface{}, err error)
ResponseHandlerFunc func(ctx context.Context, msg types.Message, response interface{}, err error)
AfterBatchFunc func(ctx context.Context)

Runner interface {
Run(func())
Expand Down
126 changes: 80 additions & 46 deletions subscriber_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -21,14 +22,10 @@ import (
func Test_Subscriber_should_stop_when_shutdown_is_called(t *testing.T) {
t.Parallel()

client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

stopped := make(chan struct{})
sut := &Subscriber{
Handler: nopHandler,
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
onExit: func() { close(stopped) },
Expand All @@ -37,7 +34,7 @@ func Test_Subscriber_should_stop_when_shutdown_is_called(t *testing.T) {
serverStarted := make(chan struct{})
go func() {
close(serverStarted)
_ = sut.Serve(client)
_ = sut.Serve(mockClient10msec())
}()
<-serverStarted

Expand All @@ -57,16 +54,12 @@ func Test_Subscriber_should_stop_when_shutdown_is_called(t *testing.T) {
func Test_Subscriber_should_stop_when_base_context_is_canceled(t *testing.T) {
t.Parallel()

client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

ctx, cancel := context.WithCancel(context.Background())

stopped := make(chan struct{})
sut := &Subscriber{
Handler: nopHandler,
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
Expand All @@ -76,7 +69,7 @@ func Test_Subscriber_should_stop_when_base_context_is_canceled(t *testing.T) {
serverStarted := make(chan struct{})
go func() {
close(serverStarted)
_ = sut.Serve(client)
_ = sut.Serve(mockClient10msec())
}()
<-serverStarted

Expand All @@ -97,15 +90,13 @@ func Test_Subscriber_should_error_if_called_more_than_once(t *testing.T) {
t.Parallel()

client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sut := &Subscriber{
Handler: nopHandler,
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
Expand Down Expand Up @@ -171,10 +162,6 @@ func Test_Subscriber_should_call_the_handler_on_first_new_message(t *testing.T)
func Test_Subscriber_should_call_the_handler_on_each_new_message(t *testing.T) {
t.Parallel()

client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -188,13 +175,13 @@ func Test_Subscriber_should_call_the_handler_on_each_new_message(t *testing.T) {
incoming = append(incoming, fmt.Sprint(request))
return nil, nil
},
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
}

go func() { _ = sut.Serve(client) }()
go func() { _ = sut.Serve(mockClient10msec()) }()

assert.Eventually(t, func() bool {
incomingLock.Lock()
Expand All @@ -217,10 +204,6 @@ func Test_Subscriber_should_call_the_handler_on_each_new_message(t *testing.T) {
func Test_Subscriber_should_call_the_ResponseHandler_after_handler(t *testing.T) {
t.Parallel()

client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -233,7 +216,7 @@ func Test_Subscriber_should_call_the_ResponseHandler_after_handler(t *testing.T)
Handler: func(ctx context.Context, request interface{}) (response interface{}, err error) {
return expectedResponse, nil
},
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: func(ctx context.Context, msg types.Message, response interface{}, err error) {
actualResponseLock.Lock()
Expand All @@ -243,7 +226,7 @@ func Test_Subscriber_should_call_the_ResponseHandler_after_handler(t *testing.T)
BaseContext: ctx,
}

go func() { _ = sut.Serve(client) }()
go func() { _ = sut.Serve(mockClient10msec()) }()

assert.Eventually(t, func() bool {
actualResponseLock.Lock()
Expand All @@ -256,9 +239,6 @@ func Test_Subscriber_should_call_the_ResponseHandler_after_handler(t *testing.T)
func Test_Subscriber_should_call_the_error_handler_on_returned_error_from_receive_message(t *testing.T) {
t.Parallel()

expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

expectedError := errors.New("an error")

client := &ClientSpy{
Expand All @@ -279,7 +259,7 @@ func Test_Subscriber_should_call_the_error_handler_on_returned_error_from_receiv

sut := &Subscriber{
Handler: unreachableHandler,
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
Expand All @@ -297,12 +277,40 @@ func Test_Subscriber_should_call_the_error_handler_on_returned_error_from_receiv
}, time.Millisecond*300, time.Millisecond*20)
}

func Test_Subscriber_should_call_the_error_handler_on_returned_error_from_decode_request(t *testing.T) {
func Test_Subscriber_should_continue_if_error_handler_is_not_provided(t *testing.T) {
t.Parallel()

client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)
expectedError := errors.New("an error")

client := &ClientSpy{
ReceiveMessageFunc: func(
ctx context.Context,
params *sqs.ReceiveMessageInput,
optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
return nil, expectedError
},
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sut := &Subscriber{
Handler: unreachableHandler,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
}

go func() { _ = sut.Serve(client) }()

assert.Eventually(t, func() bool {
return len(client.ReceiveMessageCalls()) > 3
}, time.Millisecond*300, time.Millisecond*20)
}

func Test_Subscriber_should_call_the_error_handler_on_returned_error_from_decode_request(t *testing.T) {
t.Parallel()

expectedError := errors.New("an error")

Expand All @@ -321,14 +329,14 @@ func Test_Subscriber_should_call_the_error_handler_on_returned_error_from_decode
Handler: func(ctx context.Context, request interface{}) (response interface{}, err error) {
return nil, nil
},
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: decodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
ErrorHandler: errorHandler,
}

go func() { _ = sut.Serve(client) }()
go func() { _ = sut.Serve(mockClient10msec()) }()

assert.Eventually(t, func() bool {
if len(errorHandler.HandleCalls()) == 0 {
Expand All @@ -342,10 +350,6 @@ func Test_Subscriber_should_call_the_error_handler_on_returned_error_from_decode
func Test_Subscriber_should_call_the_before_functions(t *testing.T) {
t.Parallel()

client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

type contextKey string
const counterKey contextKey = "a_counter"

Expand All @@ -365,7 +369,7 @@ func Test_Subscriber_should_call_the_before_functions(t *testing.T) {

return nil, nil
},
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
Expand All @@ -380,7 +384,7 @@ func Test_Subscriber_should_call_the_before_functions(t *testing.T) {
},
}

go func() { _ = sut.Serve(client) }()
go func() { _ = sut.Serve(mockClient10msec()) }()

assert.Eventually(t, func() bool {
handlerCtxLock.Lock()
Expand All @@ -398,6 +402,34 @@ func Test_Subscriber_should_call_the_before_functions(t *testing.T) {
}, time.Millisecond*100, time.Millisecond*20)
}

func Test_Subscriber_should_call_AfterBatch_after_calling_the_handler_for_received_messages(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var afterBatchCalls int64

sut := &Subscriber{
Handler: func(ctx context.Context, request interface{}) (response interface{}, err error) {
return nil, nil
},
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
AfterBatch: func(ctx context.Context) {
atomic.AddInt64(&afterBatchCalls, 1)
},
}

go func() { _ = sut.Serve(mockClient10msec()) }()

assert.Eventually(t, func() bool {
return atomic.LoadInt64(&afterBatchCalls) > 3
}, time.Millisecond*100, time.Millisecond*20)
}

func Test_Subscriber_should_panic_if_any_before_function_returns_a_nil_context(t *testing.T) {
sut := &Subscriber{
Before: []RequestFunc{
Expand Down Expand Up @@ -500,8 +532,6 @@ func Test_Subscriber_init(t *testing.T) {

func ExampleSubscriber() {
client := mockClient10msec()
expectedInput := makeReceiveMessageInput()
inputFactory := makeInputFactory(expectedInput)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -522,7 +552,7 @@ func ExampleSubscriber() {

return nil, nil
},
InputFactory: inputFactory,
InputFactory: defaultInputFactory,
DecodeRequest: nopDecodeRequest,
ResponseHandler: nopResponseHandler,
BaseContext: ctx,
Expand Down Expand Up @@ -570,6 +600,10 @@ func mockClient10msec() *ClientSpy {
return client
}

func defaultInputFactory() (params *sqs.ReceiveMessageInput, optFns []func(*sqs.Options)) {
return makeReceiveMessageInput(), nil
}

func makeInputFactory(expectedInput *sqs.ReceiveMessageInput) func() (*sqs.ReceiveMessageInput, []func(*sqs.Options)) {
return func() (params *sqs.ReceiveMessageInput, optFns []func(*sqs.Options)) {
return expectedInput, nil
Expand Down
8 changes: 8 additions & 0 deletions test-watch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/sh

while true
do
watchman-wait -p "**/*.go" -- .
clear
make
done

0 comments on commit bae2866

Please sign in to comment.