diff --git a/async_transport.go b/async_transport.go index 747fd65..b935b43 100644 --- a/async_transport.go +++ b/async_transport.go @@ -1,6 +1,7 @@ package rollbar import ( + "context" "fmt" "runtime" "sync" @@ -10,6 +11,7 @@ import ( // AsyncTransport is a concrete implementation of the Transport type which communicates with the // Rollbar API asynchronously using a buffered channel. type AsyncTransport struct { + ctx context.Context baseTransport // Buffer is the size of the channel used for queueing asynchronous payloads for sending to // Rollbar. @@ -36,7 +38,7 @@ func isClosed(ch chan payload) bool { // NewAsyncTransport builds an asynchronous transport which sends data to the Rollbar API at the // specified endpoint using the given access token. The channel is limited to the size of the input // buffer argument. -func NewAsyncTransport(token string, endpoint string, buffer int) *AsyncTransport { +func NewAsyncTransport(token string, endpoint string, buffer int, opts ...transportOption) *AsyncTransport { transport := &AsyncTransport{ baseTransport: baseTransport{ Token: token, @@ -48,7 +50,14 @@ func NewAsyncTransport(token string, endpoint string, buffer int) *AsyncTranspor bodyChannel: make(chan payload, buffer), Buffer: buffer, } - + for _, opt := range opts { + // Call the option giving the instantiated + // Transport as the argument + opt(transport) + } + if transport.ctx == nil { + transport.ctx = context.Background() + } go func() { defer func() { if r := recover(); r != nil { @@ -74,6 +83,10 @@ func NewAsyncTransport(token string, endpoint string, buffer int) *AsyncTranspor if canRetry && p.retriesLeft > 0 { p.retriesLeft -= 1 select { + case <-transport.ctx.Done(): // check for early termination + writePayloadToStderr(transport.Logger, p.body) + transport.waitGroup.Done() + return case transport.bodyChannel <- p: default: // This can happen if the bodyChannel had an item added to it from another @@ -125,7 +138,13 @@ func (t *AsyncTransport) Send(body map[string]interface{}) (err error) { body: body, retriesLeft: t.RetryAttempts, } - t.bodyChannel <- p + select { + case <-t.ctx.Done(): // check for early termination + writePayloadToStderr(t.Logger, body) + return t.ctx.Err() + case t.bodyChannel <- p: + default: + } } else { err = ErrBufferFull{} rollbarError(t.Logger, err.Error()) @@ -148,3 +167,11 @@ func (t *AsyncTransport) Close() error { t.Wait() return nil } + +func (t *AsyncTransport) setContext(ctx context.Context) { + t.ctx = ctx +} + +func (t *AsyncTransport) getContext() context.Context { + return t.ctx +} diff --git a/client.go b/client.go index 8ebd1fe..da6b524 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,7 @@ import ( // independent instances of a Client, then you can use the constructors provided for this // type. type Client struct { + ctx context.Context io.Closer // Transport used to send data to the Rollbar API. By default an asynchronous // implementation of the Transport interface is used. @@ -29,6 +30,14 @@ type Client struct { diagnostic diagnostic } +type clientOption func(*Client) + +func WithClientContext(ctx context.Context) clientOption { + return func(c *Client) { + c.ctx = ctx + } +} + // New returns the default implementation of a Client. // This uses the AsyncTransport. func New(token, environment, codeVersion, serverHost, serverRoot string) *Client { @@ -36,16 +45,26 @@ func New(token, environment, codeVersion, serverHost, serverRoot string) *Client } // NewAsync builds a Client with the asynchronous implementation of the transport interface. -func NewAsync(token, environment, codeVersion, serverHost, serverRoot string) *Client { +func NewAsync(token, environment, codeVersion, serverHost, serverRoot string, opts ...clientOption) *Client { configuration := createConfiguration(token, environment, codeVersion, serverHost, serverRoot) transport := NewTransport(token, configuration.endpoint) diagnostic := createDiagnostic() - return &Client{ + c := &Client{ Transport: transport, Telemetry: NewTelemetry(nil), configuration: configuration, diagnostic: diagnostic, } + for _, opt := range opts { + // Call the option giving the instantiated + // *Client as the argument + opt(c) + } + if c.ctx == nil { + c.ctx = context.Background() + } + c.Transport.setContext(c.ctx) + return c } // NewSync builds a Client with the synchronous implementation of the transport interface. @@ -77,6 +96,10 @@ func (c *Client) CaptureTelemetryEvent(eventType, eventlevel string, eventData m func (c *Client) SetTelemetry(options ...OptionFunc) { c.Telemetry = NewTelemetry(c.configuration.scrubHeaders, options...) } +func (c *Client) SetContext(ctx context.Context) { + c.ctx = ctx + c.Transport.setContext(ctx) +} // SetEnabled sets whether or not Rollbar is enabled. // If this is true then this library works as normal. diff --git a/client_test.go b/client_test.go index ee7eea4..fb0498d 100644 --- a/client_test.go +++ b/client_test.go @@ -23,6 +23,8 @@ func (t *TestTransport) Close() error { func (t *TestTransport) Wait() { t.WaitCalled = true } +func (t *TestTransport) setContext(ctx context.Context) { +} func (t *TestTransport) SetToken(_t string) {} func (t *TestTransport) SetEndpoint(_e string) {} diff --git a/rollbar.go b/rollbar.go index e68270e..7680c30 100644 --- a/rollbar.go +++ b/rollbar.go @@ -90,6 +90,10 @@ var DefaultStackTracer StackTracerFunc = func(err error) ([]runtime.Frame, bool) return nil, false } +func SetContext(ctx context.Context) { + std.SetContext(ctx) +} + // SetTelemetry sets the telemetry func SetTelemetry(options ...OptionFunc) { std.SetTelemetry(options...) diff --git a/rollbar_test.go b/rollbar_test.go index 3a2e367..84c0ed5 100644 --- a/rollbar_test.go +++ b/rollbar_test.go @@ -10,6 +10,7 @@ import ( "runtime" "strings" "testing" + "time" ) type CustomError struct { @@ -104,6 +105,26 @@ func TestEverything(t *testing.T) { type someNonstandardTypeForLogFailing struct{} +func TestSetContext(t *testing.T) { + SetToken(os.Getenv("TOKEN")) + SetEnvironment("test") + if std.ctx != context.Background() { + t.Error("Client ctx must be properly set") + } + tr := std.Transport.(*AsyncTransport) + if tr.getContext() != context.Background() { + t.Error("Transport ctx must be properly set") + } + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + SetContext(ctx) + if std.ctx != ctx { + t.Error("Client ctx must be properly set") + } + if tr.getContext() != ctx { + t.Error("Transport ctx must be properly set") + } +} + func TestEverythingGeneric(t *testing.T) { SetToken(os.Getenv("TOKEN")) SetEnvironment("test") @@ -115,7 +136,6 @@ func TestEverythingGeneric(t *testing.T) { if Environment() != "test" { t.Error("Token should be as set") } - Critical(errors.New("Normal generic critical error")) Error(&CustomError{"This is a generic custom error"}) @@ -645,6 +665,17 @@ func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return s(r) } +func TestNewAsyncWithContext(t *testing.T) { + ctx, _ := context.WithTimeout(context.Background(), 4*time.Second) + client := NewAsync("example", "test", "0.0.0", "", "", WithClientContext(ctx)) + if client.ctx != ctx { + t.Error("Client ctx must be properly set") + } + tr := client.Transport.(*AsyncTransport) + if tr.getContext() != ctx { + t.Error("Transport ctx must be properly set") + } +} func TestSetHttpClient(t *testing.T) { used := false c := &http.Client{ diff --git a/sync_transport.go b/sync_transport.go index db54ad8..bcd276b 100644 --- a/sync_transport.go +++ b/sync_transport.go @@ -1,6 +1,7 @@ package rollbar import ( + "context" "time" ) @@ -64,3 +65,5 @@ func (t *SyncTransport) Wait() {} func (t *SyncTransport) Close() error { return nil } +func (t *SyncTransport) setContext(ctx context.Context) { +} diff --git a/transport.go b/transport.go index ccbc986..6c3ba4a 100644 --- a/transport.go +++ b/transport.go @@ -1,6 +1,7 @@ package rollbar import ( + "context" "fmt" "io" "log" @@ -18,6 +19,14 @@ const ( DefaultRetryAttempts = 3 ) +type transportOption func(Transport) + +func WithTransportContext(ctx context.Context) transportOption { + return func(t Transport) { + t.setContext(ctx) + } +} + // Transport represents an object used for communicating with the Rollbar API. type Transport interface { io.Closer @@ -42,6 +51,8 @@ type Transport interface { SetHTTPClient(httpClient *http.Client) // SetItemsPerMinute sets the max number of items to send in a given minute SetItemsPerMinute(itemsPerMinute int) + + setContext(ctx context.Context) } // ClientLogger is the interface used by the rollbar Client/Transport to report problems. @@ -56,8 +67,8 @@ type SilentClientLogger struct{} func (s *SilentClientLogger) Printf(format string, args ...interface{}) {} // NewTransport creates a transport that sends items to the Rollbar API asynchronously. -func NewTransport(token, endpoint string) Transport { - return NewAsyncTransport(token, endpoint, DefaultBuffer) +func NewTransport(token, endpoint string, opts ...transportOption) Transport { + return NewAsyncTransport(token, endpoint, DefaultBuffer, opts...) } // -- rollbarError