Skip to content

Commit

Permalink
Merge pull request #99 from rollbar/pawel/add_context_support
Browse files Browse the repository at this point in the history
add context support
  • Loading branch information
pawelsz-rb authored Sep 27, 2022
2 parents 204202a + d66f9f5 commit 54913f7
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 8 deletions.
33 changes: 30 additions & 3 deletions async_transport.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rollbar

import (
"context"
"fmt"
"runtime"
"sync"
Expand All @@ -10,6 +11,7 @@ import (
// AsyncTransport is a concrete implementation of the Transport type which communicates with the
// Rollbar API asynchronously using a buffered channel.
type AsyncTransport struct {
ctx context.Context
baseTransport
// Buffer is the size of the channel used for queueing asynchronous payloads for sending to
// Rollbar.
Expand All @@ -36,7 +38,7 @@ func isClosed(ch chan payload) bool {
// NewAsyncTransport builds an asynchronous transport which sends data to the Rollbar API at the
// specified endpoint using the given access token. The channel is limited to the size of the input
// buffer argument.
func NewAsyncTransport(token string, endpoint string, buffer int) *AsyncTransport {
func NewAsyncTransport(token string, endpoint string, buffer int, opts ...transportOption) *AsyncTransport {
transport := &AsyncTransport{
baseTransport: baseTransport{
Token: token,
Expand All @@ -48,7 +50,14 @@ func NewAsyncTransport(token string, endpoint string, buffer int) *AsyncTranspor
bodyChannel: make(chan payload, buffer),
Buffer: buffer,
}

for _, opt := range opts {
// Call the option giving the instantiated
// Transport as the argument
opt(transport)
}
if transport.ctx == nil {
transport.ctx = context.Background()
}
go func() {
defer func() {
if r := recover(); r != nil {
Expand All @@ -74,6 +83,10 @@ func NewAsyncTransport(token string, endpoint string, buffer int) *AsyncTranspor
if canRetry && p.retriesLeft > 0 {
p.retriesLeft -= 1
select {
case <-transport.ctx.Done(): // check for early termination
writePayloadToStderr(transport.Logger, p.body)
transport.waitGroup.Done()
return
case transport.bodyChannel <- p:
default:
// This can happen if the bodyChannel had an item added to it from another
Expand Down Expand Up @@ -125,7 +138,13 @@ func (t *AsyncTransport) Send(body map[string]interface{}) (err error) {
body: body,
retriesLeft: t.RetryAttempts,
}
t.bodyChannel <- p
select {
case <-t.ctx.Done(): // check for early termination
writePayloadToStderr(t.Logger, body)
return t.ctx.Err()
case t.bodyChannel <- p:
default:
}
} else {
err = ErrBufferFull{}
rollbarError(t.Logger, err.Error())
Expand All @@ -148,3 +167,11 @@ func (t *AsyncTransport) Close() error {
t.Wait()
return nil
}

func (t *AsyncTransport) setContext(ctx context.Context) {
t.ctx = ctx
}

func (t *AsyncTransport) getContext() context.Context {
return t.ctx
}
27 changes: 25 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
// independent instances of a Client, then you can use the constructors provided for this
// type.
type Client struct {
ctx context.Context
io.Closer
// Transport used to send data to the Rollbar API. By default an asynchronous
// implementation of the Transport interface is used.
Expand All @@ -29,23 +30,41 @@ type Client struct {
diagnostic diagnostic
}

type clientOption func(*Client)

func WithClientContext(ctx context.Context) clientOption {
return func(c *Client) {
c.ctx = ctx
}
}

// New returns the default implementation of a Client.
// This uses the AsyncTransport.
func New(token, environment, codeVersion, serverHost, serverRoot string) *Client {
return NewAsync(token, environment, codeVersion, serverHost, serverRoot)
}

// NewAsync builds a Client with the asynchronous implementation of the transport interface.
func NewAsync(token, environment, codeVersion, serverHost, serverRoot string) *Client {
func NewAsync(token, environment, codeVersion, serverHost, serverRoot string, opts ...clientOption) *Client {
configuration := createConfiguration(token, environment, codeVersion, serverHost, serverRoot)
transport := NewTransport(token, configuration.endpoint)
diagnostic := createDiagnostic()
return &Client{
c := &Client{
Transport: transport,
Telemetry: NewTelemetry(nil),
configuration: configuration,
diagnostic: diagnostic,
}
for _, opt := range opts {
// Call the option giving the instantiated
// *Client as the argument
opt(c)
}
if c.ctx == nil {
c.ctx = context.Background()
}
c.Transport.setContext(c.ctx)
return c
}

// NewSync builds a Client with the synchronous implementation of the transport interface.
Expand Down Expand Up @@ -77,6 +96,10 @@ func (c *Client) CaptureTelemetryEvent(eventType, eventlevel string, eventData m
func (c *Client) SetTelemetry(options ...OptionFunc) {
c.Telemetry = NewTelemetry(c.configuration.scrubHeaders, options...)
}
func (c *Client) SetContext(ctx context.Context) {
c.ctx = ctx
c.Transport.setContext(ctx)
}

// SetEnabled sets whether or not Rollbar is enabled.
// If this is true then this library works as normal.
Expand Down
2 changes: 2 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ func (t *TestTransport) Close() error {
func (t *TestTransport) Wait() {
t.WaitCalled = true
}
func (t *TestTransport) setContext(ctx context.Context) {
}

func (t *TestTransport) SetToken(_t string) {}
func (t *TestTransport) SetEndpoint(_e string) {}
Expand Down
4 changes: 4 additions & 0 deletions rollbar.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ var DefaultStackTracer StackTracerFunc = func(err error) ([]runtime.Frame, bool)
return nil, false
}

func SetContext(ctx context.Context) {
std.SetContext(ctx)
}

// SetTelemetry sets the telemetry
func SetTelemetry(options ...OptionFunc) {
std.SetTelemetry(options...)
Expand Down
33 changes: 32 additions & 1 deletion rollbar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"runtime"
"strings"
"testing"
"time"
)

type CustomError struct {
Expand Down Expand Up @@ -104,6 +105,26 @@ func TestEverything(t *testing.T) {

type someNonstandardTypeForLogFailing struct{}

func TestSetContext(t *testing.T) {
SetToken(os.Getenv("TOKEN"))
SetEnvironment("test")
if std.ctx != context.Background() {
t.Error("Client ctx must be properly set")
}
tr := std.Transport.(*AsyncTransport)
if tr.getContext() != context.Background() {
t.Error("Transport ctx must be properly set")
}
ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
SetContext(ctx)
if std.ctx != ctx {
t.Error("Client ctx must be properly set")
}
if tr.getContext() != ctx {
t.Error("Transport ctx must be properly set")
}
}

func TestEverythingGeneric(t *testing.T) {
SetToken(os.Getenv("TOKEN"))
SetEnvironment("test")
Expand All @@ -115,7 +136,6 @@ func TestEverythingGeneric(t *testing.T) {
if Environment() != "test" {
t.Error("Token should be as set")
}

Critical(errors.New("Normal generic critical error"))
Error(&CustomError{"This is a generic custom error"})

Expand Down Expand Up @@ -645,6 +665,17 @@ func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return s(r)
}

func TestNewAsyncWithContext(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), 4*time.Second)
client := NewAsync("example", "test", "0.0.0", "", "", WithClientContext(ctx))
if client.ctx != ctx {
t.Error("Client ctx must be properly set")
}
tr := client.Transport.(*AsyncTransport)
if tr.getContext() != ctx {
t.Error("Transport ctx must be properly set")
}
}
func TestSetHttpClient(t *testing.T) {
used := false
c := &http.Client{
Expand Down
3 changes: 3 additions & 0 deletions sync_transport.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rollbar

import (
"context"
"time"
)

Expand Down Expand Up @@ -64,3 +65,5 @@ func (t *SyncTransport) Wait() {}
func (t *SyncTransport) Close() error {
return nil
}
func (t *SyncTransport) setContext(ctx context.Context) {
}
15 changes: 13 additions & 2 deletions transport.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rollbar

import (
"context"
"fmt"
"io"
"log"
Expand All @@ -18,6 +19,14 @@ const (
DefaultRetryAttempts = 3
)

type transportOption func(Transport)

func WithTransportContext(ctx context.Context) transportOption {
return func(t Transport) {
t.setContext(ctx)
}
}

// Transport represents an object used for communicating with the Rollbar API.
type Transport interface {
io.Closer
Expand All @@ -42,6 +51,8 @@ type Transport interface {
SetHTTPClient(httpClient *http.Client)
// SetItemsPerMinute sets the max number of items to send in a given minute
SetItemsPerMinute(itemsPerMinute int)

setContext(ctx context.Context)
}

// ClientLogger is the interface used by the rollbar Client/Transport to report problems.
Expand All @@ -56,8 +67,8 @@ type SilentClientLogger struct{}
func (s *SilentClientLogger) Printf(format string, args ...interface{}) {}

// NewTransport creates a transport that sends items to the Rollbar API asynchronously.
func NewTransport(token, endpoint string) Transport {
return NewAsyncTransport(token, endpoint, DefaultBuffer)
func NewTransport(token, endpoint string, opts ...transportOption) Transport {
return NewAsyncTransport(token, endpoint, DefaultBuffer, opts...)
}

// -- rollbarError
Expand Down

0 comments on commit 54913f7

Please sign in to comment.